# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Function to stack repeats of a layer function without shared parameters."""

import collections
import contextlib
import functools
import inspect
from typing import Any, Callable, Optional, Tuple, Union

import haiku as hk
import jax
import jax.numpy as jnp

LayerStackCarry = collections.namedtuple('LayerStackCarry', ['x', 'rng'])
LayerStackScanned = collections.namedtuple('LayerStackScanned',
                                           ['i', 'args_ys'])

# WrappedFn should take in arbitrarily nested `jnp.ndarray`, and return the
# exact same type. We cannot express this with `typing`. So we just use it
# to inform the user. In reality, the typing below will accept anything.
NestedArray = Any
WrappedFn = Callable[..., Union[NestedArray, Tuple[NestedArray]]]
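
# For illustration, a minimal sketch of a valid `WrappedFn` (assuming a simple
# Haiku `Linear` block; any function satisfying this contract works):
#
#   def block(x):
#     return jax.nn.relu(hk.Linear(x.shape[-1])(x))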


def _check_no_varargs(f):
  if list(inspect.signature(
      f).parameters.values())[0].kind == inspect.Parameter.VAR_POSITIONAL:
    raise ValueError(
        'The function `f` should not have any `varargs` (that is *args) '
        'argument. Instead, it should only use explicit positional '
        'arguments.')


@contextlib.contextmanager
def nullcontext():
  yield


def maybe_with_rng(key):
  if key is not None:
    return hk.with_rng(key)
  else:
    return nullcontext()


def maybe_fold_in(key, data):
  if key is not None:
    return jax.random.fold_in(key, data)
  else:
    return None


class _LayerStack(hk.Module):
  """Module to compose parameterized functions, implemented as a scan."""

  def __init__(self,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    """Iterate a function `f` `count` times, with non-shared parameters."""
    super().__init__(name=name)
    self._count = count
    self._unroll = unroll

  def __call__(self, x, *args_ys):
    count = self._count
    if hk.running_init():
      # At initialization time, we run just one layer but add an extra first
      # dimension to every initialized tensor, making sure to use different
      # random keys for different slices.
      def creator(next_creator, shape, dtype, init, context):
        del context

        def multi_init(shape, dtype):
          assert shape[0] == count
          key = hk.maybe_next_rng_key()

          def rng_context_init(slice_idx):
            slice_key = maybe_fold_in(key, slice_idx)
            with maybe_with_rng(slice_key):
              return init(shape[1:], dtype)

          return jax.vmap(rng_context_init)(jnp.arange(count))

        return next_creator((count,) + tuple(shape), dtype, multi_init)

      def getter(next_getter, value, context):
        trailing_dims = len(context.original_shape) + 1
        sliced_value = jax.lax.index_in_dim(
            value, index=0, axis=value.ndim - trailing_dims, keepdims=False)
        return next_getter(sliced_value)

      with hk.experimental.custom_creator(
          creator), hk.experimental.custom_getter(getter):
        if len(args_ys) == 1 and args_ys[0] is None:
          args0 = (None,)
        else:
          args0 = [
              jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
              for ys in args_ys
          ]
        x, z = self._call_wrapped(x, *args0)
        if z is None:
          return x, z

        # Broadcast state to hold each layer state.
        def broadcast_state(layer_state):
          return jnp.broadcast_to(
              layer_state, [count,] + list(layer_state.shape))
        zs = jax.tree_util.tree_map(broadcast_state, z)
        return x, zs
    else:
      # Use scan during apply, threading through random seed so that it's
      # unique for each layer.
      def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
        rng = carry.rng

        def getter(next_getter, value, context):
          # Getter slices the full param at the current loop index.
          trailing_dims = len(context.original_shape) + 1
          assert value.shape[value.ndim - trailing_dims] == count, (
              f'Attempting to use a parameter stack of size '
              f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
              f'size {count}.')

          sliced_value = jax.lax.dynamic_index_in_dim(
              value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False)
          return next_getter(sliced_value)

        with hk.experimental.custom_getter(getter):
          if rng is None:
            out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
          else:
            rng, rng_ = jax.random.split(rng)
            with hk.with_rng(rng_):
              out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
        return LayerStackCarry(x=out_x, rng=rng), z

      carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
      scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
                                  args_ys=args_ys)

      carry, zs = hk.scan(
          layer, carry, scanned, length=count, unroll=self._unroll)
      return carry.x, zs

  def _call_wrapped(self,
                    x: jnp.ndarray,
                    *args,
                    ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    raise NotImplementedError()


class _LayerStackNoState(_LayerStack):
  """_LayerStack impl with no per-layer state provided to the function."""

  def __init__(self,
               f: WrappedFn,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    super().__init__(count=count, unroll=unroll, name=name)
    _check_no_varargs(f)
    self._f = f

  @hk.transparent
  def _call_wrapped(self, args, y):
    del y
    ret = self._f(*args)
    if len(args) == 1:
      # If the function takes a single argument, the wrapped function receives
      # a tuple of length 1, and therefore it must return a tuple of length 1.
      ret = (ret,)
    return ret, None


class _LayerStackWithState(_LayerStack):
  """_LayerStack impl with per-layer state provided to the function."""

  def __init__(self,
               f: WrappedFn,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    super().__init__(count=count, unroll=unroll, name=name)
    self._f = f

  @hk.transparent
  def _call_wrapped(self, x, *args):
    return self._f(x, *args)


def layer_stack(num_layers: int,
                with_state: bool = False,
                unroll: int = 1,
                name: Optional[str] = None):
  """Utility to wrap a Haiku function and recursively apply it to an input.

  A function is valid if it uses only explicit position parameters, and
  its return type matches its input type. The position parameters can be
  arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note
  that kwargs are not supported, neither are functions with variable number
  of parameters (specified by `*args`).

  If `with_state=False` then the new, wrapped function can be understood as
  performing the following:
  ```
  for i in range(num_layers):
    x = f(x)
  return x
  ```
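
  For example, a minimal sketch of this case (assuming `f` is a
  single-argument Haiku function, here an illustrative linear block):
  ```
  def f(x):
    return jax.nn.relu(hk.Linear(x.shape[-1])(x))

  x = layer_stack.layer_stack(num_layers)(f)(x)
  ```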

  And if `with_state=True`, assuming `f` takes two arguments on top of `x`
  (the corresponding `ys` arrays each have a leading dimension of size
  `num_layers`):
  ```
  for i in range(num_layers):
    x, zs[i] = f(x, ys_0[i], ys_1[i])
  return x, zs
  ```
  The code using `layer_stack` for the above function would be:
  ```
  def f(x, y_0, y_1):
    ...
    return new_x, z
  x, zs = layer_stack.layer_stack(num_layers,
                                  with_state=True)(f)(x, ys_0, ys_1)
  ```

  Crucially, any parameters created inside `f` will not be shared across
  iterations.

  Args:
    num_layers: The number of times to iterate the wrapped function.
    with_state: Whether or not to pass per-layer state to the wrapped function.
    unroll: The unroll parameter passed to `hk.scan`.
    name: Name of the Haiku context.

  Returns:
    Callable that will produce a layer stack when called with a valid function.
  """
  def iterate(f):
    if with_state:
      @functools.wraps(f)
      def wrapped(x, *args):
        for ys in args:
          assert ys.shape[0] == num_layers
        return _LayerStackWithState(
            f, num_layers, unroll=unroll, name=name)(x, *args)
    else:
      _check_no_varargs(f)
      @functools.wraps(f)
      def wrapped(*args):
        ret = _LayerStackNoState(
            f, num_layers, unroll=unroll, name=name)(args, None)[0]
        if len(args) == 1:
          # If the function takes a single argument, we must also return a
          # single value, and not a tuple of length 1.
          ret = ret[0]
        return ret

    return wrapped
  return iterate