# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Specialized mapping functions."""

import functools
from typing import Any, Callable, Optional, Sequence, Union

import haiku as hk
import jax
import jax.numpy as jnp


PYTREE = Any
PYTREE_JAX_ARRAY = Any

partial = functools.partial
PROXY = object()


def _maybe_slice(array, i, slice_size, axis):
  if axis is PROXY:
    return array
  else:
    return jax.lax.dynamic_slice_in_dim(
        array, i, slice_size=slice_size, axis=axis)


def _maybe_get_size(array, axis):
  if axis == PROXY:
    return -1
  else:
    return array.shape[axis]


def _expand_axes(axes, values, name='sharded_apply'):
  values_tree_def = jax.tree_util.tree_flatten(values)[1]
  flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
  # Replace None's with PROXY.
  flat_axes = [PROXY if x is None else x for x in flat_axes]
  return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)


def sharded_map(
    fun: Callable[..., PYTREE_JAX_ARRAY],
    shard_size: Union[int, None] = 1,
    in_axes: Union[int, PYTREE] = 0,
    out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]:
  """Sharded vmap.

  Maps `fun` over axes, in a way similar to vmap, but does so in shards of
  `shard_size`. This allows a smooth trade-off between memory usage
  (as in a plain map) vs higher throughput (as in a vmap).

  Args:
    fun: Function to apply the smap transform to.
    shard_size: Integer denoting shard size.
    in_axes: Either an integer or a pytree describing which axis to map over
      for each input to `fun`; None denotes broadcasting.
    out_axes: Integer or pytree denoting which output axis the mapped-over
      axis maps to.

  Returns:
    Function with smap applied.
  """
  vmapped_fun = hk.vmap(fun, in_axes, out_axes)
  return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
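

# Illustrative usage sketch (not part of the original module): a minimal
# example of `sharded_map`, assuming a stateless per-example function. The
# names `_example_sharded_map` and `per_row` are hypothetical, and the
# function is never called at import time. A batch of 10 rows is mapped in
# shards of 4 (two full shards plus a remainder of 2).
def _example_sharded_map():
  def per_row(x):  # x: [features]
    return jnp.sum(x ** 2)

  def forward(x):  # x: [batch, features]
    return sharded_map(per_row, shard_size=4)(x)

  x = jnp.ones((10, 8))
  # `hk.vmap`/`hk.scan` are used internally, so run under a Haiku transform.
  forward_t = hk.transform(forward)
  params = forward_t.init(jax.random.PRNGKey(0), x)
  return forward_t.apply(params, None, x)  # shape (10,)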


def sharded_apply(
    fun: Callable[..., PYTREE_JAX_ARRAY],  # pylint: disable=g-bare-generic
    shard_size: Union[int, None] = 1,
    in_axes: Union[int, PYTREE] = 0,
    out_axes: Union[int, PYTREE] = 0,
    new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]:
  """Sharded apply.

  Applies `fun` over the mapped axes, in a way similar to vmap, but does so
  in shards of `shard_size`, stacking the shard outputs afterwards. This
  allows a smooth trade-off between memory usage (as in a plain map) vs
  higher throughput (as in a vmap).

  Args:
    fun: Function to apply the smap transform to.
    shard_size: Integer denoting shard size.
    in_axes: Either an integer or a pytree describing which axis to map over
      for each input to `fun`; None denotes broadcasting.
    out_axes: Integer or pytree denoting which output axis the mapped-over
      axis maps to.
    new_out_axes: Whether to stack outputs on new axes. This assumes that the
      output sizes for each shard (including the possible remainder shard) are
      the same.

  Returns:
    Function with smap applied.
  """
  docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} '
            'but with additional array axes over which {fun} is mapped.')
  if new_out_axes:
    raise NotImplementedError('New output axes not yet implemented.')

  # A shard_size of None denotes no sharding.
  if shard_size is None:
    return fun

  def mapped_fn(*args):
    # Expand in_axes and determine the loop range.
    in_axes_ = _expand_axes(in_axes, args)

    in_sizes = jax.tree_util.tree_map(_maybe_get_size, args, in_axes_)
    flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
    in_size = max(flat_sizes)
    assert all(i in {in_size, -1} for i in flat_sizes)

    num_extra_shards = (in_size - 1) // shard_size

    # Fix up the size of the last (possibly partial) shard.
    last_shard_size = in_size % shard_size
    last_shard_size = shard_size if last_shard_size == 0 else last_shard_size

    def apply_fun_to_slice(slice_start, slice_size):
      input_slice = jax.tree_util.tree_map(
          lambda array, axis: _maybe_slice(array, slice_start, slice_size,
                                           axis),
          args, in_axes_)
      return fun(*input_slice)

    remainder_shape_dtype = hk.eval_shape(
        partial(apply_fun_to_slice, 0, last_shard_size))
    out_dtypes = jax.tree_util.tree_map(
        lambda x: x.dtype, remainder_shape_dtype)
    out_shapes = jax.tree_util.tree_map(
        lambda x: x.shape, remainder_shape_dtype)
    out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)

    if num_extra_shards > 0:
      regular_shard_shape_dtype = hk.eval_shape(
          partial(apply_fun_to_slice, 0, shard_size))
      shard_shapes = jax.tree_util.tree_map(
          lambda x: x.shape, regular_shard_shape_dtype)

      def make_output_shape(axis, shard_shape, remainder_shape):
        return shard_shape[:axis] + (
            shard_shape[axis] * num_extra_shards +
            remainder_shape[axis],) + shard_shape[axis + 1:]

      out_shapes = jax.tree_util.tree_map(
          make_output_shape, out_axes_, shard_shapes, out_shapes)

    # Calls dynamic_update_slice_in_dim with a different argument order.
    # This is here since tree_map only works with positional arguments.
    def dynamic_update_slice_in_dim(full_array, update, axis, i):
      return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)

    def compute_shard(outputs, slice_start, slice_size):
      slice_out = apply_fun_to_slice(slice_start, slice_size)
      update_slice = partial(dynamic_update_slice_in_dim, i=slice_start)
      return jax.tree_util.tree_map(
          update_slice, outputs, slice_out, out_axes_)

    def scan_iteration(outputs, i):
      new_outputs = compute_shard(outputs, i, shard_size)
      return new_outputs, ()

    slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)

    def allocate_buffer(dtype, shape):
      return jnp.zeros(shape, dtype=dtype)

    outputs = jax.tree_util.tree_map(allocate_buffer, out_dtypes, out_shapes)

    if slice_starts.shape[0] > 0:
      outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)

    if last_shard_size != shard_size:
      remainder_start = in_size - last_shard_size
      outputs = compute_shard(outputs, remainder_start, last_shard_size)

    return outputs

  return mapped_fn
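

# Illustrative usage sketch (not part of the original module): a minimal
# example of `sharded_apply`, assuming `fun` already accepts the mapped
# (batched) leading axis. The names `_example_sharded_apply` and `batched_fn`
# are hypothetical, and the function is never called at import time. The
# 7-row input is processed in shards of 3 (two full shards plus a remainder
# of 1), and each shard output is written into a single preallocated buffer.
def _example_sharded_apply():
  def batched_fn(x):  # x: [num_rows, features]
    return x * 2.0    # output keeps the mapped leading axis

  x = jnp.arange(14.0).reshape(7, 2)
  f = sharded_apply(batched_fn, shard_size=3, in_axes=0, out_axes=0)
  # `hk.eval_shape`/`hk.scan` are used internally, so run under a transform.
  return hk.transform(lambda v: f(v)).apply({}, None, x)  # shape (7, 2)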


def inference_subbatch(
    module: Callable[..., PYTREE_JAX_ARRAY],
    subbatch_size: int,
    batched_args: Sequence[PYTREE_JAX_ARRAY],
    nonbatched_args: Sequence[PYTREE_JAX_ARRAY],
    low_memory: bool = True,
    input_subbatch_dim: int = 0,
    output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY:
  """Run through subbatches (like batch apply but with split and concat)."""
  assert len(batched_args) > 0  # pylint: disable=g-explicit-length-test

  if not low_memory:
    args = list(batched_args) + list(nonbatched_args)
    return module(*args)

  if output_subbatch_dim is None:
    output_subbatch_dim = input_subbatch_dim

  def run_module(*batched_args):
    args = list(batched_args) + list(nonbatched_args)
    return module(*args)

  sharded_module = sharded_apply(run_module,
                                 shard_size=subbatch_size,
                                 in_axes=input_subbatch_dim,
                                 out_axes=output_subbatch_dim)
  return sharded_module(*batched_args)
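

# Illustrative usage sketch (not part of the original module): a minimal
# example of `inference_subbatch`, using a simple stand-in for an
# attention-style module. The names `_example_inference_subbatch`,
# `scaled_shift`, `queries` and `bias` are hypothetical, and the function is
# never called at import time. The batched queries are processed over axis 0
# in sub-batches of 2 while the non-batched bias is broadcast to every
# sub-batch.
def _example_inference_subbatch():
  def scaled_shift(q, bias):  # q: [num_queries, channels], bias: [channels]
    return q * 0.5 + bias

  queries = jnp.ones((6, 4))
  bias = jnp.zeros((4,))

  def forward():
    return inference_subbatch(
        scaled_shift,
        subbatch_size=2,
        batched_args=[queries],
        nonbatched_args=[bias],
        low_memory=True)

  # Internally built on `sharded_apply`, so run under a Haiku transform.
  return hk.transform(forward).apply({}, None)  # shape (6, 4)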