# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Specialized mapping functions."""

import functools

from typing import Any, Callable, Optional, Sequence, Union

import haiku as hk
import jax
import jax.numpy as jnp


PYTREE = Any
PYTREE_JAX_ARRAY = Any

partial = functools.partial
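# Sentinel used in place of an axis index for inputs that are broadcast
# (axis=None) rather than sliced.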
PROXY = object()


def _maybe_slice(array, i, slice_size, axis):
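  """Slices `array` along `axis`; PROXY (broadcast) inputs pass through."""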
  if axis is PROXY:
    return array
  else:
    return jax.lax.dynamic_slice_in_dim(
        array, i, slice_size=slice_size, axis=axis)


def _maybe_get_size(array, axis):
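  """Returns the size of `array` along `axis`, or -1 for broadcast inputs."""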
  if axis is PROXY:
    return -1
  else:
    return array.shape[axis]


def _expand_axes(axes, values, name='sharded_apply'):
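  """Expands `axes` to match the tree structure of `values`."""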
  values_tree_def = jax.tree_util.tree_flatten(values)[1]
  flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
  # Replace Nones with the PROXY sentinel so they survive tree operations.
  flat_axes = [PROXY if x is None else x for x in flat_axes]
  return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)


def sharded_map(
    fun: Callable[..., PYTREE_JAX_ARRAY],
    shard_size: Union[int, None] = 1,
    in_axes: Union[int, PYTREE] = 0,
    out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]:
  """Sharded vmap.

  Maps `fun` over axes in a way similar to vmap, but does so in shards of
  `shard_size`. This allows a smooth trade-off between memory usage
  (as in a plain map) and higher throughput (as in a vmap).

  Args:
    fun: Function to apply the sharded map transform to.
    shard_size: Number of elements to process per shard; `None` disables
      sharding.
    in_axes: Integer or pytree describing which axis of each input to `fun` to
      map over; `None` denotes broadcasting.
    out_axes: Integer or pytree denoting the output axis onto which the mapped
      axis is placed.

  Returns:
    Function with the sharded map applied.
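
  Example:
    An illustrative sketch (the function and shapes below are hypothetical):

      def square(x):
        return x ** 2

      smapped = sharded_map(square, shard_size=4)
      forward = hk.transform(lambda x: smapped(x))
      x = jnp.ones((10, 3))
      params = forward.init(jax.random.PRNGKey(0), x)
      # Maps over axis 0 in shards of 4, 4 and 2 rows; `out` is (10, 3).
      out = forward.apply(params, None, x)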
  """
  vmapped_fun = hk.vmap(fun, in_axes, out_axes)
  return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)


def sharded_apply(
    fun: Callable[..., PYTREE_JAX_ARRAY],  # pylint: disable=g-bare-generic
    shard_size: Union[int, None] = 1,
    in_axes: Union[int, PYTREE] = 0,
    out_axes: Union[int, PYTREE] = 0,
    new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]:
  """Sharded apply.

  Applies `fun` to shards of the mapped axes, in a way similar to vmap,
  but does so in shards of `shard_size`, concatenating the per-shard outputs
  along the output axes afterwards. This allows a smooth trade-off between
  memory usage (as in a plain map) and higher throughput (as in a vmap).

  Args:
    fun: Function to apply the sharded transform to.
    shard_size: Number of elements to process per shard; `None` disables
      sharding and returns `fun` unchanged.
    in_axes: Integer or pytree describing which axis of each input to `fun` to
      map over; `None` denotes broadcasting.
    out_axes: Integer or pytree denoting the output axis onto which the mapped
      axis is placed.
    new_out_axes: Whether to stack outputs on new axes. This assumes that the
      output sizes for each shard (including the possible remainder shard) are
      the same. Setting this is not yet implemented.

  Returns:
    Function with the sharded apply transform applied.
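
  Example:
    An illustrative sketch (the function and shapes below are hypothetical):

      def affine(x):
        return 2.0 * x + 1.0

      f = sharded_apply(affine, shard_size=3)
      forward = hk.transform(lambda x: f(x))
      x = jnp.ones((7, 5))
      params = forward.init(jax.random.PRNGKey(0), x)
      # `affine` is applied to two shards of 3 rows and a remainder shard of
      # 1 row, each written into a preallocated (7, 5) output buffer.
      out = forward.apply(params, None, x)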
  """
  docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} '
            'but with additional array axes over which {fun} is mapped.')
  if new_out_axes:
    raise NotImplementedError('New output axes not yet implemented.')

  # shard size None denotes no sharding
  if shard_size is None:
    return fun

  @jax.util.wraps(fun, docstr=docstr)
  def mapped_fn(*args):
    # Expand in_axes to match the input pytree and determine the loop range.
    in_axes_ = _expand_axes(in_axes, args)

    in_sizes = jax.tree_util.tree_map(_maybe_get_size, args, in_axes_)
    flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
    in_size = max(flat_sizes)
    assert all(i in {in_size, -1} for i in flat_sizes)

    # Number of full shards in addition to the final (possibly smaller) shard.
    num_extra_shards = (in_size - 1) // shard_size

    # Size of the final shard; equal to shard_size when in_size divides evenly.
    last_shard_size = in_size % shard_size
    last_shard_size = shard_size if last_shard_size == 0 else last_shard_size

    def apply_fun_to_slice(slice_start, slice_size):
      input_slice = jax.tree_util.tree_map(
          lambda array, axis: _maybe_slice(array, slice_start, slice_size,
                                           axis),
          args, in_axes_)
      return fun(*input_slice)

    # Use eval_shape to work out the output shapes/dtypes of a single shard
    # without doing any actual computation.
    remainder_shape_dtype = hk.eval_shape(
        partial(apply_fun_to_slice, 0, last_shard_size))
    out_dtypes = jax.tree_util.tree_map(
        lambda x: x.dtype, remainder_shape_dtype)
    out_shapes = jax.tree_util.tree_map(
        lambda x: x.shape, remainder_shape_dtype)
    out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)

    if num_extra_shards > 0:
      regular_shard_shape_dtype = hk.eval_shape(
          partial(apply_fun_to_slice, 0, shard_size))
      shard_shapes = jax.tree_util.tree_map(
          lambda x: x.shape, regular_shard_shape_dtype)

      def make_output_shape(axis, shard_shape, remainder_shape):
        return shard_shape[:axis] + (
            shard_shape[axis] * num_extra_shards +
            remainder_shape[axis],) + shard_shape[axis + 1:]

      out_shapes = jax.tree_util.tree_map(
          make_output_shape, out_axes_, shard_shapes, out_shapes)

    # Calls dynamic_update_slice_in_dim with a different argument order.
    # This is needed because tree_map passes its arguments positionally.
    def dynamic_update_slice_in_dim(full_array, update, axis, i):
      return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)

    def compute_shard(outputs, slice_start, slice_size):
      slice_out = apply_fun_to_slice(slice_start, slice_size)
      update_slice = partial(
          dynamic_update_slice_in_dim, i=slice_start)
      return jax.tree_util.tree_map(update_slice, outputs, slice_out, out_axes_)

    def scan_iteration(outputs, i):
      new_outputs = compute_shard(outputs, i, shard_size)
      return new_outputs, ()

    # Start indices of all full-size shards; the remainder is handled below.
    slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)

    def allocate_buffer(dtype, shape):
      return jnp.zeros(shape, dtype=dtype)

    outputs = jax.tree_util.tree_map(allocate_buffer, out_dtypes, out_shapes)

    if slice_starts.shape[0] > 0:
      outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)

    # Apply the remainder shard (if any) outside of the scan.
    if last_shard_size != shard_size:
      remainder_start = in_size - last_shard_size
      outputs = compute_shard(outputs, remainder_start, last_shard_size)

    return outputs

  return mapped_fn


def inference_subbatch(
    module: Callable[..., PYTREE_JAX_ARRAY],
    subbatch_size: int,
    batched_args: Sequence[PYTREE_JAX_ARRAY],
    nonbatched_args: Sequence[PYTREE_JAX_ARRAY],
    low_memory: bool = True,
    input_subbatch_dim: int = 0,
    output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY:
  """Run through subbatches (like batch apply but with split and concat)."""
  assert len(batched_args) > 0  # pylint: disable=g-explicit-length-test

  if not low_memory:
    args = list(batched_args) + list(nonbatched_args)
    return module(*args)

  if output_subbatch_dim is None:
    output_subbatch_dim = input_subbatch_dim

  def run_module(*batched_args):
    args = list(batched_args) + list(nonbatched_args)
    return module(*args)
  sharded_module = sharded_apply(run_module,
                                 shard_size=subbatch_size,
                                 in_axes=input_subbatch_dim,
                                 out_axes=output_subbatch_dim)
  return sharded_module(*batched_args)