File size: 13,118 Bytes
a560c26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Miscellaneous modules."""

from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

from flax import linen as nn
import jax
import jax.numpy as jnp

from invariant_slot_attention.lib import utils

# Shape of an array: a tuple with any number of ints. (`Tuple[int]` would
# describe a tuple holding exactly one int.)
Shape = Tuple[int, ...]

DType = Any
Array = Any  # jnp.ndarray
# Arbitrarily nested structure of arrays (iterables / string-keyed mappings).
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ProcessorState = ArrayTree
PRNGKey = Array
NestedDict = Dict[str, Any]


class Identity(nn.Module):
  """Pass-through module: returns `inputs` exactly as given.

  Extra keyword arguments are accepted and discarded, so this module can be
  dropped in wherever a module with a richer call signature is expected.
  """

  @nn.compact
  def __call__(self, inputs, **args):
    del args  # Accepted only for call-signature compatibility.
    return inputs


class Readout(nn.Module):
  """Applies one readout head per target to a shared embedding.

  Attributes:
    keys: Target names; they become the keys of the returned dict.
    readout_modules: Zero-argument factories, one per key. Each produced module
      is called as `module(x, train)`.
    stop_gradient: Optional per-key flags. Where a flag is truthy, that head
      receives `jax.lax.stop_gradient(inputs)` instead of `inputs`.
  """

  keys: Sequence[str]
  readout_modules: Sequence[Callable[[], nn.Module]]
  stop_gradient: Optional[Sequence[bool]] = None

  @nn.compact
  def __call__(self, inputs, train = False):
    n_heads = len(self.keys)
    assert n_heads >= 1, "Need to have at least one target."
    assert len(self.readout_modules) == n_heads, (
        "len(modules) and len(keys) must match.")
    detach_flags = self.stop_gradient
    if detach_flags is not None:
      assert len(detach_flags) == n_heads, (
          "len(stop_gradient) and len(keys) must match.")

    predictions = {}
    for idx, key in enumerate(self.keys):
      detach = detach_flags is not None and detach_flags[idx]
      head_input = jax.lax.stop_gradient(inputs) if detach else inputs
      head = self.readout_modules[idx]()  # pytype: disable=not-callable
      predictions[key] = head(head_input, train)
    return predictions


class MLP(nn.Module):
  """Feed-forward network with optional residual connection and layer norm.

  Attributes:
    hidden_size: Width of every hidden Dense layer.
    output_size: Width of the final Dense layer; when falsy, the width of the
      last input axis is used.
    num_hidden_layers: Number of hidden Dense + activation layers that precede
      the output layer.
    activation_fn: Nonlinearity applied after each hidden layer (and after the
      output layer when `activate_output` is set).
    layernorm: "pre" normalizes the inputs before the first layer, "post"
      normalizes the final result; any other value disables layer norm.
    activate_output: Whether to apply `activation_fn` to the output layer.
    residual: Whether to add the original (un-normalized) `inputs` to the
      output; this requires the output width to match the input width.
  """

  hidden_size: int
  output_size: Optional[int] = None
  num_hidden_layers: int = 1
  activation_fn: Callable[[Array], Array] = nn.relu
  layernorm: Optional[str] = None
  activate_output: bool = False
  residual: bool = False

  @nn.compact
  def __call__(self, inputs, train = False):
    del train  # Unused.

    out_features = self.output_size or inputs.shape[-1]

    h = nn.LayerNorm()(inputs) if self.layernorm == "pre" else inputs

    for layer in range(self.num_hidden_layers):
      dense = nn.Dense(self.hidden_size, name=f"dense_mlp_{layer}")
      h = self.activation_fn(dense(h))

    h = nn.Dense(out_features, name=f"dense_mlp_{self.num_hidden_layers}")(h)
    if self.activate_output:
      h = self.activation_fn(h)

    if self.residual:
      h = h + inputs

    return nn.LayerNorm()(h) if self.layernorm == "post" else h


class GRU(nn.Module):
  """`nn.GRUCell` wrapped as a module that returns only the updated carry.

  The `train` flag is accepted for interface uniformity and ignored.
  """

  @nn.compact
  def __call__(self, carry, inputs,
               train = False):
    del train  # Unused.
    cell = nn.GRUCell()
    new_carry, _ = cell(carry, inputs)
    return new_carry


class Dense(nn.Module):
  """Thin wrapper around `nn.Dense` whose call also accepts a `train` flag.

  Attributes:
    features: Number of output features.
    use_bias: Whether the wrapped layer adds a bias.
  """

  features: int
  use_bias: bool = True

  @nn.compact
  def __call__(self, inputs, train = False):
    del train  # Ignored; present so callers can pass it uniformly.
    layer = nn.Dense(features=self.features, use_bias=self.use_bias)
    return layer(inputs)


class PositionEmbedding(nn.Module):
  """A module for applying N-dimensional position embedding.

  Attr:
    embedding_type: A string defining the type of position embedding to use. One
      of ["linear", "discrete_1d", "fourier", "gaussian_fourier"].
    update_type: A string defining how the input is updated with the position
      embedding. One of ["project_add", "concat"]; "proj_add" is accepted as an
      alias of "project_add".
    num_fourier_bases: The number of Fourier bases to use. For embedding_type ==
      "fourier", the embedding dimensionality is 2 x number of position
      dimensions x num_fourier_bases. For embedding_type == "gaussian_fourier",
      the embedding dimensionality is 2 x num_fourier_bases. For embedding_type
      "linear" or "discrete_1d", this parameter is ignored.
    gaussian_sigma: Standard deviation of sampled Gaussians (only used when
      embedding_type == "gaussian_fourier").
    pos_transform: Optional transform for the embedding.
    output_transform: Optional transform for the combined input and embedding.
    trainable_pos_embedding: Boolean flag for allowing gradients to flow into
      the position embedding, so that the optimizer can update it.
  """

  embedding_type: str
  update_type: str
  num_fourier_bases: int = 0
  gaussian_sigma: float = 1.0
  pos_transform: Callable[[], nn.Module] = Identity
  output_transform: Callable[[], nn.Module] = Identity
  trainable_pos_embedding: bool = False

  def _make_pos_embedding_tensor(self, rng, input_shape):
    """Builds the position-embedding tensor; used as a `self.param` initializer.

    Args:
      rng: PRNG key supplied by `self.param`; only consumed by the
        "gaussian_fourier" embedding.
      input_shape: Shape of the module inputs. Axes `input_shape[1:-1]` define
        the position grid (the first and last axes are excluded — presumably
        batch and channels).

    Returns:
      The embedding with a leading singleton (batch) axis prepended.

    Raises:
      ValueError: If `embedding_type` is not supported.
    """
    if self.embedding_type == "discrete_1d":
      # An integer tensor in [0, input_shape[-2]-1] reflecting
      # 1D discrete position encoding (encode the second-to-last axis).
      pos_embedding = jnp.broadcast_to(
          jnp.arange(input_shape[-2]), input_shape[1:-1])
    else:
      # A tensor grid in [-1, +1] for each input dimension.
      pos_embedding = utils.create_gradient_grid(input_shape[1:-1], [-1.0, 1.0])

    if self.embedding_type == "linear":
      pass
    elif self.embedding_type == "discrete_1d":
      pos_embedding = jax.nn.one_hot(pos_embedding, input_shape[-2])
    elif self.embedding_type == "fourier":
      # NeRF-style Fourier/sinusoidal position encoding.
      pos_embedding = utils.convert_to_fourier_features(
          pos_embedding * jnp.pi, basis_degree=self.num_fourier_bases)
    elif self.embedding_type == "gaussian_fourier":
      # Gaussian Fourier features. Reference: https://arxiv.org/abs/2006.10739
      num_dims = pos_embedding.shape[-1]
      projection = jax.random.normal(
          rng, [num_dims, self.num_fourier_bases]) * self.gaussian_sigma
      pos_embedding = jnp.pi * pos_embedding.dot(projection)
      # sin(x + pi/2) == cos(x), so a single sin over the concatenation yields
      # [sin(x), cos(x)] slightly faster than separate sin and cos calls.
      pos_embedding = jnp.sin(
          jnp.concatenate([pos_embedding, pos_embedding + 0.5 * jnp.pi],
                          axis=-1))
    else:
      raise ValueError("Invalid embedding type provided.")

    # Add batch dimension.
    pos_embedding = jnp.expand_dims(pos_embedding, axis=0)

    return pos_embedding

  @nn.compact
  def __call__(self, inputs):
    """Combines `inputs` with the position embedding.

    Args:
      inputs: Input array. Its shape determines the position grid (see
        `_make_pos_embedding_tensor`) and its last axis is the feature axis.

    Returns:
      The combined array after `output_transform`.

    Raises:
      ValueError: If `update_type` is not supported, or (when the parameter is
        initialized) if `embedding_type` is not supported.
    """
    # The embedding is computed only once, by the parameter initializer at
    # init time, using the same rng that initializes the learnable parameters.
    pos_embedding = self.param("pos_embedding", self._make_pos_embedding_tensor,
                               inputs.shape)

    if not self.trainable_pos_embedding:
      pos_embedding = jax.lax.stop_gradient(pos_embedding)

    # Apply optional transformation on the position embedding.
    pos_embedding = self.pos_transform()(pos_embedding)  # pytype: disable=not-callable

    # Apply position encoding to inputs. "proj_add" is the spelling used in
    # older documentation and is treated as an alias of "project_add".
    if self.update_type in ("project_add", "proj_add"):
      # Here, we project the position encodings to the same dimensionality as
      # the inputs and add them to the inputs (broadcast along batch dimension).
      # This is roughly equivalent to concatenation of position encodings to the
      # inputs (if followed by a Dense layer), but is slightly more efficient.
      n_features = inputs.shape[-1]
      x = inputs + nn.Dense(n_features, name="dense_pe_0")(pos_embedding)
    elif self.update_type == "concat":
      # Repeat the position embedding along the first (batch) dimension.
      pos_embedding = jnp.broadcast_to(
          pos_embedding, shape=inputs.shape[:-1] + pos_embedding.shape[-1:])
      # concatenate along the channel dimension.
      x = jnp.concatenate((inputs, pos_embedding), axis=-1)
    else:
      raise ValueError("Invalid update type provided.")

    # Apply optional output transformation.
    x = self.output_transform()(x)  # pytype: disable=not-callable
    return x


class RelativePositionEmbedding(nn.Module):
  """A module for applying embedding of input position relative to slots.

  Attr:
    update_type: A string defining how the input is updated with the position
      embedding. One of ["project_add", "concat"]; "proj_add" is accepted as an
      alias of "project_add".
    embedding_type: A string defining the type of position embedding to use.
      Currently only "linear" is supported.
    num_fourier_bases: Currently unused by this module (only "linear"
      embeddings are implemented).
    gaussian_sigma: Currently unused by this module.
    pos_transform: Optional transform for the embedding.
    output_transform: Optional transform for the combined input and embedding.
    trainable_pos_embedding: Boolean flag for allowing gradients to flow into
      the position embedding, so that the optimizer can update it.
    scales_factor: Extra divisor applied to the relative grid whenever
      `slot_scales` is given (before dividing by the slot scales).
  """

  update_type: str
  embedding_type: str = "linear"
  num_fourier_bases: int = 0
  gaussian_sigma: float = 1.0
  pos_transform: Callable[[], nn.Module] = Identity
  output_transform: Callable[[], nn.Module] = Identity
  trainable_pos_embedding: bool = False
  scales_factor: float = 1.0

  def _make_pos_embedding_tensor(self, rng, input_shape):
    """Builds the absolute [-1, 1] grid; used as a `self.param` initializer.

    Args:
      rng: PRNG key supplied by `self.param`; unused here.
      input_shape: Shape of the module inputs. Axes `input_shape[1:-1]` define
        the grid (the first and last axes are excluded — presumably batch and
        channels).

    Returns:
      The grid with a leading singleton (batch) axis prepended.
    """

    # A tensor grid in [-1, +1] for each input dimension.
    pos_embedding = utils.create_gradient_grid(input_shape[1:-1], [-1.0, 1.0])

    # Add batch dimension.
    pos_embedding = jnp.expand_dims(pos_embedding, axis=0)

    return pos_embedding

  @nn.compact
  def __call__(self, inputs, slot_positions,
               slot_scales = None,
               slot_rotm = None):
    """Embeds input positions relative to each slot's frame.

    The absolute grid is shifted by `slot_positions`. It is then optionally
    rotated by `slot_rotm` and divided by `scales_factor * slot_scales`. The
    result is combined with `inputs`.

    Args:
      inputs: Input array. Axes 1..-2 of its shape define the position grid
        and its last axis is the feature axis.
      slot_positions: Slot centers. Two singleton axes are inserted before the
        last axis so they broadcast over a 2-D grid (presumably
        (batch, num_slots, 2) — TODO confirm against callers).
      slot_scales: Optional slot scales, broadcast the same way as
        `slot_positions`.
      slot_rotm: Optional rotation matrices with trailing (2, 2) axes; see
        `transform`.

    Returns:
      The combined array after `output_transform`.

    Raises:
      ValueError: If `embedding_type` or `update_type` is not supported.
    """

    # The grid is computed only once, by the parameter initializer at init
    # time, using the same rng that initializes the learnable parameters.
    pos_embedding = self.param("pos_embedding", self._make_pos_embedding_tensor,
                               inputs.shape)

    if not self.trainable_pos_embedding:
      pos_embedding = jax.lax.stop_gradient(pos_embedding)

    # Relativize pos_embedding with respect to slot positions
    # and optionally slot scales.
    slot_positions = jnp.expand_dims(
        jnp.expand_dims(slot_positions, axis=-2), axis=-2)
    if slot_scales is not None:
      slot_scales = jnp.expand_dims(
          jnp.expand_dims(slot_scales, axis=-2), axis=-2)

    if self.embedding_type == "linear":
      pos_embedding = pos_embedding - slot_positions
      if slot_rotm is not None:
        pos_embedding = self.transform(slot_rotm, pos_embedding)
      if slot_scales is not None:
        # Scales are usually small so the grid might get too large.
        pos_embedding = pos_embedding / self.scales_factor
        pos_embedding = pos_embedding / slot_scales
    else:
      raise ValueError("Invalid embedding type provided.")

    # Apply optional transformation on the position embedding.
    pos_embedding = self.pos_transform()(pos_embedding)  # pytype: disable=not-callable

    # Define intermediate for logging.
    pos_embedding = Identity(name="pos_emb")(pos_embedding)

    # Apply position encoding to inputs. "proj_add" is the spelling used in
    # older documentation and is treated as an alias of "project_add".
    if self.update_type in ("project_add", "proj_add"):
      # Here, we project the position encodings to the same dimensionality as
      # the inputs and add them to the inputs (broadcast along batch dimension).
      # This is roughly equivalent to concatenation of position encodings to the
      # inputs (if followed by a Dense layer), but is slightly more efficient.
      n_features = inputs.shape[-1]
      x = inputs + nn.Dense(n_features, name="dense_pe_0")(pos_embedding)
    elif self.update_type == "concat":
      # Repeat the position embedding along the first (batch) dimension.
      pos_embedding = jnp.broadcast_to(
          pos_embedding, shape=inputs.shape[:-1] + pos_embedding.shape[-1:])
      # concatenate along the channel dimension.
      x = jnp.concatenate((inputs, pos_embedding), axis=-1)
    else:
      raise ValueError("Invalid update type provided.")

    # Apply optional output transformation.
    x = self.output_transform()(x)  # pytype: disable=not-callable
    return x

  @classmethod
  def transform(cls, rot, coords):
    """Rotates a (y, x) coordinate grid by R^T (the inverse of a rotation R).

    Args:
      rot: Rotation matrices with trailing (2, 2) axes, acting on (x, y)
        row vectors.
      coords: Coordinate grid whose last axis holds (y, x) in its first two
        entries; the einsum below requires at least three trailing axes.

    Returns:
      The rotated grid, again in (y, x) order, with a last axis of size 2.
    """
    # The coordinate grid coords is in the (y, x) format, so we need to swap
    # the coordinates on the input and output.
    coords = jnp.stack([coords[Ellipsis, 1], coords[Ellipsis, 0]], axis=-1)
    # Equivalent to inv(R) * coords^T = R^T * coords^T = (coords * R)^T.
    # We are multiplying by the inverse of the rotation matrix because
    # we are rotating the coordinate grid *against* the rotation of the object.
    new_coords = jnp.einsum("...hij,...jk->...hik", coords, rot)
    # Swap coordinates again.
    return jnp.stack([new_coords[Ellipsis, 1], new_coords[Ellipsis, 0]], axis=-1)