Spaces:
Running
Running
File size: 11,622 Bytes
a560c26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 |
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Initializers module library for equivariant slot attention."""
import functools
from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
from flax import linen as nn
import jax
import jax.numpy as jnp
from invariant_slot_attention.lib import utils
# Type aliases used by the initializer modules below.
# Shape of an array of arbitrary rank; `Tuple[int]` would only describe a
# 1-tuple, so the variadic form is required here.
Shape = Tuple[int, ...]
DType = Any
Array = Any  # jnp.ndarray
# Recursive pytree-like container of arrays.
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ProcessorState = ArrayTree
PRNGKey = Array
NestedDict = Dict[str, Any]
def get_uniform_initializer(vmin, vmax):
  """Build a uniform initializer over an arbitrary range [vmin, vmax).

  Args:
    vmin: Lower bound of the sampled range.
    vmax: Upper bound of the sampled range.

  Returns:
    An initializer with the same call signature as the wrapped
    `nn.initializers.uniform` initializer, whose samples are shifted by vmin.
  """
  width = vmax - vmin
  base_init = nn.initializers.uniform(scale=width)

  def shifted_init(*args, **kwargs):
    # The base initializer samples from [0, width); move the range to vmin.
    samples = base_init(*args, **kwargs)
    return samples + vmin

  return shifted_init
def get_normal_initializer(mean, sd):
  """Build a normal initializer with an arbitrary mean.

  Args:
    mean: Offset added to every sample.
    sd: Standard deviation passed to `nn.initializers.normal`.

  Returns:
    An initializer with the same call signature as the wrapped
    `nn.initializers.normal` initializer, whose samples are shifted by mean.
  """
  zero_mean_init = nn.initializers.normal(stddev=sd)

  def shifted_init(*args, **kwargs):
    # The base initializer is centred at zero; recentre it at `mean`.
    samples = zero_mean_init(*args, **kwargs)
    return samples + mean

  return shifted_init
class ParamStateInitRandomPositions(nn.Module):
  """Fixed, learnable slot state initialization with random positions.

  A learnable parameter of shape `shape` is broadcast across the batch. Two
  extra trailing features per slot are then appended, holding positions
  sampled from U[slot_positions_min, slot_positions_max) using the
  "state_init" RNG stream.

  Note: This module ignores any conditional input (by design).
  """

  shape: Sequence[int]
  init_fn: str = "normal"  # Default init with unit variance.
  conditioning_key: Optional[str] = None
  slot_positions_min: float = -1.
  slot_positions_max: float = 1.

  @nn.compact
  def __call__(self, inputs, batch_size, train=False):
    del inputs, train  # Unused.

    if self.init_fn == "normal":
      state_initializer = nn.initializers.normal(stddev=1.)
    elif self.init_fn == "zeros":
      state_initializer = nn.initializers.zeros
    else:
      raise ValueError("Unknown init_fn: {}.".format(self.init_fn))

    state = self.param("state_init", state_initializer, self.shape)
    batched_state = utils.broadcast_across_batch(state, batch_size=batch_size)

    # Two position coordinates per slot, sharing all leading dims of the state.
    positions = jax.random.uniform(
        self.make_rng("state_init"),
        shape=[*batched_state.shape[:-1], 2],
        minval=self.slot_positions_min,
        maxval=self.slot_positions_max)
    return jnp.concatenate((batched_state, positions), axis=-1)
class ParamStateInitLearnablePositions(nn.Module):
  """Fixed, learnable slot state initialization with learnable positions.

  Both the slot state (shape `shape`) and the slot positions are parameters.
  The positions add two trailing features per slot and are initialized from
  U[slot_positions_min, slot_positions_max). The two parameters are
  concatenated along the last axis and broadcast across the batch.

  Note: This module ignores any conditional input (by design).
  """

  shape: Sequence[int]
  init_fn: str = "normal"  # Default init with unit variance.
  conditioning_key: Optional[str] = None
  slot_positions_min: float = -1.
  slot_positions_max: float = 1.

  @nn.compact
  def __call__(self, inputs, batch_size, train=False):
    del inputs, train  # Unused.

    if self.init_fn == "normal":
      state_initializer = nn.initializers.normal(stddev=1.)
    elif self.init_fn == "zeros":
      state_initializer = nn.initializers.zeros
    else:
      raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
    position_initializer = get_uniform_initializer(
        self.slot_positions_min, self.slot_positions_max)

    slot_dims = tuple(self.shape[:-1])
    # Parameter creation order is kept: state first, then positions.
    state = self.param("state_init", state_initializer, self.shape)
    positions = self.param(
        "state_init_position", position_initializer, (*slot_dims, 2))

    combined = jnp.concatenate((state, positions), axis=-1)
    return utils.broadcast_across_batch(combined, batch_size=batch_size)  # pytype: disable=bad-return-type # jax-ndarray
class ParamStateInitRandomPositionsScales(nn.Module):
  """Fixed, learnable state initialization with random positions and scales.

  A learnable state parameter of shape `shape` is broadcast across the batch.
  Four trailing features per slot are appended:
  - two positions sampled from U[slot_positions_min, slot_positions_max);
  - two scales sampled from a normal distribution with mean
    `slot_scales_mean` and standard deviation `slot_scales_sd`.

  Positions and scales are drawn with separate keys from the "state_init"
  RNG stream, so the two samples are independent.

  Note: This module ignores any conditional input (by design).
  """

  shape: Sequence[int]
  init_fn: str = "normal"  # Default init with unit variance.
  conditioning_key: Optional[str] = None
  slot_positions_min: float = -1.
  slot_positions_max: float = 1.
  slot_scales_mean: float = 0.1
  slot_scales_sd: float = 0.1

  @nn.compact
  def __call__(self, inputs, batch_size, train=False):
    del inputs, train  # Unused.
    if self.init_fn == "normal":
      init_fn = functools.partial(nn.initializers.normal, stddev=1.)
    elif self.init_fn == "zeros":
      init_fn = lambda: nn.initializers.zeros
    else:
      raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
    param = self.param("state_init", init_fn(), self.shape)
    out = utils.broadcast_across_batch(param, batch_size=batch_size)
    shape = out.shape[:-1]

    rng = self.make_rng("state_init")
    slot_positions = jax.random.uniform(
        rng, shape=[*shape, 2], minval=self.slot_positions_min,
        maxval=self.slot_positions_max)

    # Use a fresh key for the scales. Reusing the positions key would feed the
    # same random bits to both draws and make the scales a deterministic
    # function of the positions (JAX keys must never be reused).
    rng = self.make_rng("state_init")
    slot_scales = jax.random.normal(rng, shape=[*shape, 2])
    slot_scales = self.slot_scales_mean + self.slot_scales_sd * slot_scales

    out = jnp.concatenate((out, slot_positions, slot_scales), axis=-1)
    return out
class ParamStateInitLearnablePositionsScales(nn.Module):
  """Fixed, learnable state initialization with learnable positions and scales.

  The slot state (shape `shape`), the slot positions and the slot scales are
  all parameters, concatenated along the last axis and broadcast across the
  batch:
  - positions add two trailing features per slot, initialized from
    U[slot_positions_min, slot_positions_max);
  - scales add two more, initialized from a normal distribution with mean
    `slot_scales_mean` and standard deviation `slot_scales_sd`.

  Note: This module ignores any conditional input (by design).
  """

  shape: Sequence[int]
  init_fn: str = "normal"  # Default init with unit variance.
  conditioning_key: Optional[str] = None
  slot_positions_min: float = -1.
  slot_positions_max: float = 1.
  slot_scales_mean: float = 0.1
  slot_scales_sd: float = 0.01

  @nn.compact
  def __call__(self, inputs, batch_size, train=False):
    del inputs, train  # Unused.

    if self.init_fn == "normal":
      state_initializer = nn.initializers.normal(stddev=1.)
    elif self.init_fn == "zeros":
      state_initializer = nn.initializers.zeros
    else:
      raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
    position_initializer = get_uniform_initializer(
        self.slot_positions_min, self.slot_positions_max)
    scale_initializer = get_normal_initializer(
        self.slot_scales_mean, self.slot_scales_sd)

    slot_dims = tuple(self.shape[:-1])
    # Parameter creation order is kept: state, positions, scales.
    state = self.param("state_init", state_initializer, self.shape)
    positions = self.param(
        "state_init_position", position_initializer, (*slot_dims, 2))
    scales = self.param(
        "state_init_scale", scale_initializer, (*slot_dims, 2))

    combined = jnp.concatenate((state, positions, scales), axis=-1)
    return utils.broadcast_across_batch(combined, batch_size=batch_size)  # pytype: disable=bad-return-type # jax-ndarray
class ParamStateInitLearnablePositionsRotationsScales(nn.Module):
  """Fixed, learnable state initialization with learnable pos., rot. and scales.

  The slot state (shape `shape`) is followed by learnable per-slot
  components, concatenated along the last axis and broadcast across the
  batch:
  - positions (2 features), initialized from
    U[slot_positions_min, slot_positions_max);
  - scales (2 features), initialized from a normal distribution with mean
    `slot_scales_mean` and standard deviation `slot_scales_sd`;
  - a flattened 2x2 rotation matrix (4 features), built from a learnable raw
    angle squashed by tanh into (-pi/4, pi/4).

  Note: This module ignores any conditional input (by design).
  """

  shape: Sequence[int]
  init_fn: str = "normal"  # Default init with unit variance.
  conditioning_key: Optional[str] = None
  slot_positions_min: float = -1.
  slot_positions_max: float = 1.
  slot_scales_mean: float = 0.1
  slot_scales_sd: float = 0.01
  slot_angles_mean: float = 0.
  slot_angles_sd: float = 0.1

  @nn.compact
  def __call__(self, inputs, batch_size, train=False):
    del inputs, train  # Unused.

    if self.init_fn == "normal":
      state_initializer = nn.initializers.normal(stddev=1.)
    elif self.init_fn == "zeros":
      state_initializer = nn.initializers.zeros
    else:
      raise ValueError("Unknown init_fn: {}.".format(self.init_fn))
    position_initializer = get_uniform_initializer(
        self.slot_positions_min, self.slot_positions_max)
    scale_initializer = get_normal_initializer(
        self.slot_scales_mean, self.slot_scales_sd)
    angle_initializer = get_normal_initializer(
        self.slot_angles_mean, self.slot_angles_sd)

    slot_dims = tuple(self.shape[:-1])
    # Parameter creation order is kept: state, positions, scales, angles.
    state = self.param("state_init", state_initializer, self.shape)
    positions = self.param(
        "state_init_position", position_initializer, (*slot_dims, 2))
    scales = self.param(
        "state_init_scale", scale_initializer, (*slot_dims, 2))
    raw_angles = self.param(
        "state_init_angles", angle_initializer, (*slot_dims, 1))

    # tanh keeps the angles in (-pi / 4, pi / 4) <=> (-45, 45) degrees at all
    # times, not only at initialization.
    angles = jnp.tanh(raw_angles) * (jnp.pi / 4)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    # Feature order [cos, sin, -sin, cos]; read row-major this is the matrix
    # [[cos, sin], [-sin, cos]].
    rotation = jnp.concatenate([cos, sin, -sin, cos], axis=-1)

    combined = jnp.concatenate(
        (state, positions, scales, rotation), axis=-1)
    return utils.broadcast_across_batch(combined, batch_size=batch_size)  # pytype: disable=bad-return-type # jax-ndarray
class ParamStateInitRandomPositionsRotationsScales(nn.Module):
  """Fixed, learnable state initialization with random pos., rot. and scales.

  A learnable state parameter of shape `shape` is broadcast across the batch.
  Eight trailing features per slot, all sampled from the "state_init" RNG
  stream, are then appended:
  - positions (2 features) from U[slot_positions_min, slot_positions_max);
  - scales (2 features) from a normal distribution with mean
    `slot_scales_mean` and standard deviation `slot_scales_sd`;
  - a flattened 2x2 rotation matrix (4 features) for an angle drawn uniformly
    between slot_angles_min and slot_angles_max (+-45 degrees by default).

  Note: This module ignores any conditional input (by design).
  """

  shape: Sequence[int]
  init_fn: str = "normal"  # Default init with unit variance.
  conditioning_key: Optional[str] = None
  slot_positions_min: float = -1.
  slot_positions_max: float = 1.
  slot_scales_mean: float = 0.1
  slot_scales_sd: float = 0.1
  slot_angles_min: float = -jnp.pi / 4.
  slot_angles_max: float = jnp.pi / 4.

  @nn.compact
  def __call__(self, inputs, batch_size, train=False):
    del inputs, train  # Unused.

    if self.init_fn == "normal":
      state_initializer = nn.initializers.normal(stddev=1.)
    elif self.init_fn == "zeros":
      state_initializer = nn.initializers.zeros
    else:
      raise ValueError("Unknown init_fn: {}.".format(self.init_fn))

    state = self.param("state_init", state_initializer, self.shape)
    batched_state = utils.broadcast_across_batch(state, batch_size=batch_size)
    slot_dims = batched_state.shape[:-1]

    # Each component draws its own fresh key: positions, scales, then angles.
    positions = jax.random.uniform(
        self.make_rng("state_init"), shape=[*slot_dims, 2],
        minval=self.slot_positions_min, maxval=self.slot_positions_max)

    unit_normal = jax.random.normal(
        self.make_rng("state_init"), shape=[*slot_dims, 2])
    scales = self.slot_scales_mean + self.slot_scales_sd * unit_normal

    unit_uniform = jax.random.uniform(
        self.make_rng("state_init"), shape=[*slot_dims, 1])
    angle_range = self.slot_angles_max - self.slot_angles_min
    angles = unit_uniform * angle_range + self.slot_angles_min

    cos, sin = jnp.cos(angles), jnp.sin(angles)
    # Feature order [cos, sin, -sin, cos]; read row-major this is the matrix
    # [[cos, sin], [-sin, cos]].
    rotation = jnp.concatenate([cos, sin, -sin, cos], axis=-1)

    return jnp.concatenate(
        (batched_state, positions, scales, rotation), axis=-1)
|