Spaces:
Running
on
T4
Running
on
T4
File size: 9,134 Bytes
85bd48b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 |
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Function to stack repeats of a layer function without shared parameters."""
import collections
import contextlib
import functools
import inspect
from typing import Any, Callable, Optional, Tuple, Union
import haiku as hk
import jax
import jax.numpy as jnp
# Loop carry threaded through `hk.scan`: the layer input/output `x` and the
# random key `rng` (which may be None when no RNG is available).
LayerStackCarry = collections.namedtuple('LayerStackCarry', ['x', 'rng'])
# Per-iteration scanned inputs: the layer index `i` and the per-layer inputs
# `args_ys`.
LayerStackScanned = collections.namedtuple('LayerStackScanned',
                                           ['i', 'args_ys'])

# WrappedFn should take in arbitrarily nested `jnp.ndarray`, and return the
# exact same type. We cannot express this with `typing`. So we just use it
# to inform the user. In reality, the typing below will accept anything.
NestedArray = Any
# `Tuple[X, ...]` (not `Tuple[X]`, which means a 1-tuple) so the alias
# documents "a tuple of any length" as intended.
WrappedFn = Callable[..., Union[NestedArray, Tuple[NestedArray, ...]]]
def _check_no_varargs(f):
if list(inspect.signature(
f).parameters.values())[0].kind == inspect.Parameter.VAR_POSITIONAL:
raise ValueError(
'The function `f` should not have any `varargs` (that is *args) '
'argument. Instead, it should only use explicit positional'
'arguments.')
def nullcontext():
  """Returns a context manager that does nothing and yields `None`.

  Delegates to the standard library's `contextlib.nullcontext` instead of
  re-implementing it; kept as a module-level name for existing callers.
  """
  return contextlib.nullcontext()
def maybe_with_rng(key):
  """Returns `hk.with_rng(key)`, or a no-op context manager when `key` is None."""
  if key is None:
    return nullcontext()
  return hk.with_rng(key)
def maybe_fold_in(key, data):
  """Folds `data` into `key` via `jax.random.fold_in`; a None key stays None."""
  return None if key is None else jax.random.fold_in(key, data)
class _LayerStack(hk.Module):
  """Module to compose parameterized functions, implemented as a scan.

  Subclasses implement `_call_wrapped(x, *args)`, which runs a single layer and
  returns `(new_x, z)`, where `z` is a per-layer output or `None`.

  * Under `hk.running_init()`, one layer is run with a custom creator that
    gives every parameter an extra axis of size `count` (each slice initialized
    under its own folded-in RNG key) and a custom getter that hands slice 0 to
    the layer.
  * Otherwise, `hk.scan` runs `count` layers with a custom getter that slices
    each stacked parameter at the current layer index.
  """

  def __init__(self,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    """Iterate a function `f` `count` times, with non-shared parameters.

    Args:
      count: Number of layers: the size of the axis added to every parameter
        and the `length` passed to `hk.scan`.
      unroll: Forwarded as `unroll` to `hk.scan`.
      name: Optional Haiku module name.
    """
    super().__init__(name=name)
    self._count = count
    self._unroll = unroll

  def __call__(self, x, *args_ys):
    """Runs the stack of `count` layers.

    Args:
      x: Input carried from layer to layer.
      *args_ys: Per-layer inputs indexed along their leading axis (only index 0
        is used at init time). A single `None` is forwarded as-is instead of
        being indexed.

    Returns:
      `(x, zs)`: the last layer's output and the per-layer outputs (`None`
      when the layer's `z` is `None`). At init time `zs` is the first layer's
      `z` broadcast `count` times, not genuine per-layer values.
    """
    count = self._count
    if hk.running_init():
      # At initialization time, we run just one layer but add an extra first
      # dimension to every initialized tensor, making sure to use different
      # random keys for different slices.
      def creator(next_creator, shape, dtype, init, context):
        del context  # Unused by this creator.

        def multi_init(shape, dtype):
          # Receives the stacked shape requested via `next_creator` below.
          assert shape[0] == count
          # May be None when no RNG is available, hence the `maybe_*` helpers.
          key = hk.maybe_next_rng_key()

          def rng_context_init(slice_idx):
            # Each slice gets its own key derived from its index, so the
            # `count` copies are initialized independently.
            slice_key = maybe_fold_in(key, slice_idx)
            with maybe_with_rng(slice_key):
              return init(shape[1:], dtype)

          # Run the original initializer once per slice, stacking the results.
          return jax.vmap(rng_context_init)(jnp.arange(count))

        return next_creator((count,) + tuple(shape), dtype, multi_init)

      def getter(next_getter, value, context):
        # The stacked axis sits just before the parameter's original dims,
        # counted from the end (presumably so any extra outer leading axes,
        # e.g. from nested stacks, are left alone -- confirm).
        trailing_dims = len(context.original_shape) + 1
        # During init the single layer sees slice 0 of each stacked parameter.
        sliced_value = jax.lax.index_in_dim(
            value, index=0, axis=value.ndim - trailing_dims, keepdims=False)
        return next_getter(sliced_value)

      with hk.experimental.custom_creator(
          creator), hk.experimental.custom_getter(getter):
        if len(args_ys) == 1 and args_ys[0] is None:
          # A lone None is passed through rather than indexed.
          args0 = (None,)
        else:
          # Feed only the first layer's slice of each per-layer input.
          args0 = [
              jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
              for ys in args_ys
          ]
        x, z = self._call_wrapped(x, *args0)
        if z is None:
          return x, z

        # Broadcast state to hold each layer state.
        def broadcast_state(layer_state):
          return jnp.broadcast_to(
              layer_state, [count,] + list(layer_state.shape))

        zs = jax.tree_util.tree_map(broadcast_state, z)
        return x, zs
    else:
      # Use scan during apply, threading through random seed so that it's
      # unique for each layer.
      def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
        rng = carry.rng

        def getter(next_getter, value, context):
          # Getter slices the full param at the current loop index.
          trailing_dims = len(context.original_shape) + 1
          assert value.shape[value.ndim - trailing_dims] == count, (
              f'Attempting to use a parameter stack of size '
              f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
              f'size {count}.')

          sliced_value = jax.lax.dynamic_index_in_dim(
              value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False)
          return next_getter(sliced_value)

        with hk.experimental.custom_getter(getter):
          if rng is None:
            out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
          else:
            # Split: one key for this layer, the other carried to the next.
            rng, rng_ = jax.random.split(rng)
            with hk.with_rng(rng_):
              out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
        return LayerStackCarry(x=out_x, rng=rng), z

      carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
      # `i` is the layer index used by `getter` to pick each parameter slice.
      scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
                                  args_ys=args_ys)

      carry, zs = hk.scan(
          layer, carry, scanned, length=count, unroll=self._unroll)
      return carry.x, zs

  def _call_wrapped(self,
                    x: jnp.ndarray,
                    *args,
                    ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """Runs a single layer; must be overridden by subclasses.

    Returns:
      `(new_x, z)` where `z` is this layer's per-layer output, or `None`.
    """
    raise NotImplementedError()
class _LayerStackNoState(_LayerStack):
  """Layer stack variant whose layers receive no per-layer inputs.

  `_call_wrapped` receives the full tuple of positional arguments as `args`
  and calls `f(*args)`; the result becomes the next carry, and the per-layer
  output is always `None`.
  """

  def __init__(self,
               f: WrappedFn,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    super().__init__(count=count, unroll=unroll, name=name)
    # Reject functions whose first parameter is `*args`.
    _check_no_varargs(f)
    self._f = f

  @hk.transparent
  def _call_wrapped(self, args, y):
    del y  # Placeholder per-layer input; ignored by this variant.
    result = self._f(*args)
    # A one-argument `f` returns a bare value; rewrap it as a 1-tuple so the
    # carry keeps the same structure as `args`.
    carried = (result,) if len(args) == 1 else result
    return carried, None
class _LayerStackWithState(_LayerStack):
  """Layer stack variant whose layers also consume per-layer inputs.

  Each layer calls `f(x, *args)`, where `args` holds this layer's slice of
  every per-layer input, and `f` returns `(new_x, z)`.
  """

  def __init__(self,
               f: WrappedFn,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    super().__init__(count=count, unroll=unroll, name=name)
    self._f = f

  @hk.transparent
  def _call_wrapped(self, x, *args):
    layer_fn = self._f
    return layer_fn(x, *args)
def layer_stack(num_layers: int,
                with_state: bool = False,
                unroll: int = 1,
                name: Optional[str] = None):
  """Utility to wrap a Haiku function and recursively apply it to an input.

  A function is valid if it uses only explicit positional parameters, and
  its return type matches its input type. The positional parameters can be
  arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note
  that kwargs are not supported, neither are functions with variable number
  of parameters (specified by `*args`).

  If `with_state=False` then the new, wrapped function can be understood as
  performing the following:
  ```
  for i in range(num_layers):
    x = f(x)
  return x
  ```

  And if `with_state=True`, assuming `f` takes two arguments on top of `x`:
  ```
  for i in range(num_layers):
    x, zs[i] = f(x, ys_0[i], ys_1[i])
  return x, zs
  ```
  The code using `layer_stack` for the above function would be:
  ```
  def f(x, y_0, y_1):
    ...
    return new_x, z
  x, zs = layer_stack.layer_stack(num_layers,
                                  with_state=True)(f)(x, ys_0, ys_1)
  ```

  Crucially, any parameters created inside `f` will not be shared across
  iterations.

  Args:
    num_layers: The number of times to iterate the wrapped function.
    with_state: Whether or not to pass per-layer state to the wrapped function.
    unroll: the unroll used by `scan`.
    name: Name of the Haiku context.

  Returns:
    Callable that will produce a layer stack when called with a valid function.
    With `with_state=True`, the produced function raises `AssertionError` if a
    leaf of a per-layer input does not have leading dimension `num_layers`.
    With `with_state=False`, wrapping a function whose first parameter is
    `*args` raises `ValueError`.
  """
  def iterate(f):
    if with_state:
      @functools.wraps(f)
      def wrapped(x, *args):
        # Every per-layer input is stacked along a leading axis of size
        # `num_layers`. Checking pytree leaves (rather than the top-level args)
        # validates nested structures instead of failing on `.shape`.
        for ys in jax.tree_util.tree_leaves(args):
          assert ys.shape[0] == num_layers, (
              f'Per-layer input has leading dimension {ys.shape[0]}, '
              f'expected num_layers={num_layers}.')
        return _LayerStackWithState(
            f, num_layers, unroll=unroll, name=name)(x, *args)
    else:
      _check_no_varargs(f)

      @functools.wraps(f)
      def wrapped(*args):
        # The stack carries the whole tuple of positional args; keep only the
        # final carry (`[0]`), discarding the always-None per-layer output.
        ret = _LayerStackNoState(
            f, num_layers, unroll=unroll, name=name)(args, None)[0]
        if len(args) == 1:
          # If the function takes a single argument, we must also return a
          # single value, and not a tuple of length 1.
          ret = ret[0]
        return ret
    return wrapped
  return iterate
|