ondrejbiza committed
Commit 90e5776
1 parent: 1ccf223

V2 config, revert requirements.

.gitignore CHANGED
@@ -1,4 +1,5 @@
 /venv
 /flagged
 /clevr_isa_ts
-*.pyc
+*.pyc
+*.DS_Store
app.py CHANGED
@@ -10,7 +10,7 @@ import jax.numpy as jnp
 import numpy as np
 from PIL import Image
 
-from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale import get_config
+from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale_v2 import get_config
 from invariant_slot_attention.lib import utils
 
 
@@ -61,8 +61,8 @@ def load_image(name):
     return img
 
 
-download_path = snapshot_download(repo_id="ondrejbiza/isa", allow_patterns="clevr_isa_ts*")
-checkpoint_dir = os.path.join(download_path, "clevr_isa_ts")
+download_path = snapshot_download(repo_id="ondrejbiza/isa", allow_patterns="clevr_isa_ts_v2*")
+checkpoint_dir = os.path.join(download_path, "clevr_isa_ts_v2")
 
 model, state, rng = load_model(get_config(), checkpoint_dir)
 
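Taken together, the two app.py hunks switch the demo from the V1 to the V2 checkpoint: the imported config and the downloaded checkpoint directory must name the same variant. A minimal sketch of the resulting load path, assuming huggingface_hub is installed and that load_model is the helper defined earlier in app.py (not shown in this diff):

    import os

    from huggingface_hub import snapshot_download

    from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale_v2 import get_config

    # Download only the V2 checkpoint files from the Hub repo.
    download_path = snapshot_download(repo_id="ondrejbiza/isa",
                                      allow_patterns="clevr_isa_ts_v2*")
    checkpoint_dir = os.path.join(download_path, "clevr_isa_ts_v2")

    # load_model is defined in app.py; it restores the model, its
    # parameters/state, and a PRNG key from the checkpoint directory.
    model, state, rng = load_model(get_config(), checkpoint_dir)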
invariant_slot_attention/configs/clevr_with_masks/equiv_transl_scale_v2.py ADDED
@@ -0,0 +1,202 @@
+# coding=utf-8
+# Copyright 2023 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Config for unsupervised training on CLEVR."""
+
+import ml_collections
+
+
+def get_config():
+  """Get the default hyperparameter configuration."""
+  config = ml_collections.ConfigDict()
+
+  config.seed = 42
+  config.seed_data = True
+
+  config.batch_size = 64
+  config.num_train_steps = 500000  # from the original Slot Attention
+  config.init_checkpoint = ml_collections.ConfigDict()
+  config.init_checkpoint.xid = 0  # Disabled by default.
+  config.init_checkpoint.wid = 1
+
+  config.optimizer_configs = ml_collections.ConfigDict()
+  config.optimizer_configs.optimizer = "adam"
+
+  config.optimizer_configs.grad_clip = ml_collections.ConfigDict()
+  config.optimizer_configs.grad_clip.clip_method = "clip_by_global_norm"
+  config.optimizer_configs.grad_clip.clip_value = 0.05
+
+  config.lr_configs = ml_collections.ConfigDict()
+  config.lr_configs.learning_rate_schedule = "compound"
+  config.lr_configs.factors = "constant * cosine_decay * linear_warmup"
+  config.lr_configs.warmup_steps = 10000  # from the original Slot Attention
+  config.lr_configs.steps_per_cycle = config.get_ref("num_train_steps")
+  # from the original Slot Attention
+  config.lr_configs.base_learning_rate = 4e-4
+
+  config.eval_pad_last_batch = False  # True
+  config.log_loss_every_steps = 50
+  config.eval_every_steps = 5000
+  config.checkpoint_every_steps = 5000
+
+  config.train_metrics_spec = {
+      "loss": "loss",
+      "ari": "ari",
+      "ari_nobg": "ari_nobg",
+  }
+  config.eval_metrics_spec = {
+      "eval_loss": "loss",
+      "eval_ari": "ari",
+      "eval_ari_nobg": "ari_nobg",
+  }
+
+  config.data = ml_collections.ConfigDict({
+      "dataset_name": "clevr_with_masks",
+      "shuffle_buffer_size": config.batch_size * 8,
+      "resolution": (128, 128)
+  })
+
+  config.max_instances = 11
+  config.num_slots = config.max_instances  # Only used for metrics.
+  config.logging_min_n_colors = config.max_instances
+
+  config.preproc_train = [
+      "tfds_image_to_tfds_video",
+      "video_from_tfds",
+      "top_left_crop(top=29, left=64, height=192)",
+      "resize_small({size})".format(size=min(*config.data.resolution))
+  ]
+
+  config.preproc_eval = [
+      "tfds_image_to_tfds_video",
+      "video_from_tfds",
+      "top_left_crop(top=29, left=64, height=192)",
+      "resize_small({size})".format(size=min(*config.data.resolution))
+  ]
+
+  config.eval_slice_size = 1
+  config.eval_slice_keys = ["video", "segmentations_video"]
+
+  # Dictionary of targets and corresponding channels. Losses need to match.
+  targets = {"video": 3}
+  config.losses = {"recon": {"targets": list(targets)}}
+  config.losses = ml_collections.ConfigDict({
+      f"recon_{target}": {"loss_type": "recon", "key": target}
+      for target in targets})
+
+  config.model = ml_collections.ConfigDict({
+      "module": "invariant_slot_attention.modules.SAVi",
+
+      # Encoder.
+      "encoder": ml_collections.ConfigDict({
+          "module": "invariant_slot_attention.modules.FrameEncoder",
+          "reduction": "spatial_flatten",
+          "backbone": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.SimpleCNN",
+              "features": [64, 64, 64, 64],
+              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
+              "strides": [(2, 2), (2, 2), (2, 2), (1, 1)]
+          }),
+          "pos_emb": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.PositionEmbedding",
+              "embedding_type": "linear",
+              "update_type": "concat"
+          }),
+      }),
+
+      # Corrector.
+      "corrector": ml_collections.ConfigDict({
+          "module": "invariant_slot_attention.modules.SlotAttentionTranslScaleEquiv",  # pylint: disable=line-too-long
+          "num_iterations": 3,
+          "qkv_size": 64,
+          "mlp_size": 128,
+          "grid_encoder": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.MLP",
+              "hidden_size": 128,
+              "layernorm": "pre"
+          }),
+          "add_rel_pos_to_values": False,  # V2
+          "zero_position_init": False,  # Random positions.
+          "init_with_fixed_scale": None,  # Random scales.
+          "scales_factor": 5.0,
+      }),
+
+      # Predictor.
+      # Removed since we are running a single frame.
+      "predictor": ml_collections.ConfigDict({
+          "module": "invariant_slot_attention.modules.Identity"
+      }),
+
+      # Initializer.
+      "initializer": ml_collections.ConfigDict({
+          "module": "invariant_slot_attention.modules.ParamStateInitRandomPositionsScales",  # pylint: disable=line-too-long
+          "shape": (11, 64),  # (num_slots, slot_size)
+      }),
+
+      # Decoder.
+      "decoder": ml_collections.ConfigDict({
+          "module":
+              "invariant_slot_attention.modules.SiameseSpatialBroadcastDecoder",
+          "resolution": (16, 16),  # Update if data resolution or strides change
+          "backbone": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.CNN",
+              "features": [64, 64, 64, 64, 64],
+              "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5), (5, 5)],
+              "strides": [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)],
+              "max_pool_strides": [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)],
+              "layer_transpose": [True, True, True, False, False]
+          }),
+          "target_readout": ml_collections.ConfigDict({
+              "module": "invariant_slot_attention.modules.Readout",
+              "keys": list(targets),
+              "readout_modules": [ml_collections.ConfigDict({  # pylint: disable=g-complex-comprehension
+                  "module": "invariant_slot_attention.modules.MLP",
+                  "num_hidden_layers": 0,
+                  "hidden_size": 0,
+                  "output_size": targets[k]}) for k in targets],
+          }),
+          "relative_positions_and_scales": True,
+          "pos_emb": ml_collections.ConfigDict({
+              "module":
+                  "invariant_slot_attention.modules.RelativePositionEmbedding",
+              "embedding_type":
+                  "linear",
+              "update_type":
+                  "project_add",
+              "scales_factor":
+                  5.0,
+          }),
+      }),
+      "decode_corrected": True,
+      "decode_predicted": False,
+  })
+
+  # Which video-shaped variables to visualize.
+  config.debug_var_video_paths = {
+      "recon_masks": "decoder/alphas_softmaxed/__call__/0",  # pylint: disable=line-too-long
+  }
+
+  # Define which attention matrices to log/visualize.
+  config.debug_var_attn_paths = {
+      "corrector_attn": "corrector/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
+  }
+
+  # Widths of attention matrices (for reshaping to image grid).
+  config.debug_var_attn_widths = {
+      "corrector_attn": 16,
+  }
+
+  return config
+
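The new config mirrors the V1 equiv_transl_scale config; the substantive change is the corrector flag add_rel_pos_to_values = False, marked "# V2" above. A short sketch of how the config could be inspected, assuming only that ml_collections is installed (ConfigDict resolves get_ref references on attribute access):

    from invariant_slot_attention.configs.clevr_with_masks.equiv_transl_scale_v2 import get_config

    config = get_config()

    # The V2-specific corrector flag (marked "# V2" in the file above).
    print(config.model.corrector.add_rel_pos_to_values)  # False

    # steps_per_cycle was set via config.get_ref("num_train_steps"), so it
    # tracks later changes to num_train_steps automatically.
    config.num_train_steps = 100000
    print(config.lr_configs.steps_per_cycle)  # 100000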
requirements.txt CHANGED
@@ -4,7 +4,7 @@ tensorflow-cpu>=2.12.0
 tensorflow-datasets>=4.4.0
 matplotlib>=3.5.0
 clu>=0.0.3
-flax>=0.3.5
+flax==0.3.5
 chex>=0.0.7
 optax>=0.1.0
 ml-collections>=0.1.0
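The "revert requirements" part of the commit pins flax to exactly 0.3.5 instead of accepting any newer release; the diff does not state why. A hypothetical guard, assuming the app depends on flax 0.3.x behavior, could fail fast near the top of app.py:

    import flax

    # Hypothetical check: requirements.txt pins flax==0.3.5, so surface a
    # mismatch early (the reason for the pin is not stated in the diff).
    if flax.__version__ != "0.3.5":
        raise RuntimeError(
            f"Expected flax==0.3.5 per requirements.txt, found {flax.__version__}")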