"""Big vision sharding utilities.""" |
|
|
|
from absl import logging |
|
|
|
from big_vision.pp.registry import Registry |
|
import big_vision.utils as u |
|
import flax.linen as nn |
|
import jax |
|
import numpy as np |
|
|
|
|
|
NamedSharding = jax.sharding.NamedSharding |
|
P = jax.sharding.PartitionSpec |
|
|
|
|
|
def _replicated(mesh): |
|
return NamedSharding(mesh, P()) |
|
|
|
|
|
def _shard_along_axis(mesh, i, axis_name): |
|
return NamedSharding(mesh, P(*((None,) * i + (axis_name,)))) |
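
# Illustrative note on the helpers above (a sketch, assuming a mesh with a
# "data" axis): `_replicated(mesh)` is `NamedSharding(mesh, P())`, i.e. full
# replication, while `_shard_along_axis(mesh, 2, "data")` is
# `NamedSharding(mesh, P(None, None, "data"))`, sharding only dimension 2.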


def infer_sharding(params, strategy, mesh):
  """Infers `params` sharding based on the given strategy.

  Args:
    params: a pytree of arrays.
    strategy: sharding strategy: a list of `(pattern, tactic)` pairs, where
      `pattern` is a regex matched against parameter names and `tactic` is a
      `"|"`-separated chain of registered sharding rules, e.g. `"replicate"`
      or `"fsdp(axis='data')"`.
    mesh: jax device mesh.

  Returns:
    A pytree of shardings with the same structure as `params`.
  """
  patterns, tactics = zip(*strategy)

  x_with_names, tree_def = u.tree_flatten_with_names(params)
  names = tree_def.unflatten(list(zip(*x_with_names))[0])

  # Each parameter is matched at most once; earlier patterns take priority.
  mask_trees = u.make_mask_trees(params, patterns)

  # Start from fully-replicated specs and let the rules refine them.
  specs = jax.tree.map(lambda x: (None,) * x.ndim, params)

  for mask_tree, tactic in zip(mask_trees, tactics):
    for op_str in tactic.split("|"):
      op = Registry.lookup(f"shardings.{op_str}")()
      specs = jax.tree.map(
          lambda x, n, match, spec, op=op: op(spec, mesh, n, x)
          if match else spec,
          params, names, mask_tree, specs,
          is_leaf=lambda v: isinstance(v, nn.Partitioned))

  # Use (unboxed) params as the reference structure so that each spec tuple is
  # passed to the lambda whole instead of being traversed as a pytree.
  specs = jax.tree.map(lambda _, spec: P(*spec), nn.unbox(params), specs)
  return jax.tree.map(lambda spec: NamedSharding(mesh, spec), specs)
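
# A minimal usage sketch (assumes an existing `mesh` with a "data" axis and a
# `params` pytree; the pattern and axis names are illustrative, not fixed API):
#
#   strategy = [
#       (".*/pos_embedding", "replicate"),
#       (".*", "fsdp(axis='data')"),
#   ]
#   shardings = infer_sharding(params, strategy, mesh)
#   params = jax.device_put(params, shardings)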


@Registry.register("shardings.replicate")
def replicate():
  """Full replication sharding rule.

  Note that full replication is the default, so this rule can be skipped; it is
  mainly useful to state explicitly in the config that certain parameters are
  replicated.
  TODO: can be generalized to support replication over a sub-mesh.

  Returns:
    A function that updates the sharding spec.
  """
  def _update_spec(cur_spec, mesh, name, x):
    del x, mesh
    if not all(axis is None for axis in cur_spec):
      raise ValueError(f"Inconsistent sharding instructions: "
                       f"parameter {name} has spec {cur_spec}, "
                       f"so it can't be fully replicated.")
    return cur_spec
  return _update_spec
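
# Example (illustrative): a strategy entry like ("head/.*", "replicate") keeps
# the matched parameters' specs all-`None` (fully replicated). Within a chained
# tactic such as "fsdp(axis='data')|replicate" the rule raises instead, since
# an earlier op in the same tactic may have already sharded a dimension.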


@Registry.register("shardings.fsdp")
def fsdp(axis, min_size_to_shard_mb=4):
  """FSDP sharding rule.

  Shards the largest dimension that is not sharded already and is divisible by
  the total size of the requested mesh axes.

  Args:
    axis: mesh axis name for FSDP, or a collection of names.
    min_size_to_shard_mb: minimal tensor size (in MB) to bother with sharding.

  Returns:
    A function that updates the sharding spec.
  """
  axis = axis if isinstance(axis, str) else tuple(axis)
  axis_tuple = axis if isinstance(axis, tuple) else (axis,)
  def _update_spec(cur_spec, mesh, name, x):
    shape = x.shape
    axis_size = np.prod([mesh.shape[a] for a in axis_tuple])

    # Leave small tensors fully replicated.
    if np.prod(shape) * x.dtype.itemsize <= min_size_to_shard_mb * (2 ** 20):
      return cur_spec

    # Partition along the largest dimension that is divisible and not taken.
    idx = np.argsort(shape)[::-1]
    for i in idx:
      if shape[i] % axis_size == 0 and cur_spec[i] is None:
        return cur_spec[:i] + (axis,) + cur_spec[i+1:]

    logging.info("Failed to apply `fsdp` rule to the parameter %s:%s, as none "
                 "of its dimensions is divisible by the requested axis %s:%i, "
                 "or they are already taken by other sharding rules: %s",
                 name, shape, axis, axis_size, cur_spec)
    return cur_spec
  return _update_spec
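
# A worked sketch (numbers are illustrative): on a mesh where
# mesh.shape == {"data": 8}, a float32 parameter of shape (50_304, 1024)
# (~200 MB) with spec (None, None) becomes ("data", None) under
# fsdp(axis="data"): dim 0 is the largest unsharded dim divisible by 8.
# A (1024,) float32 bias (~4 KB) stays as-is due to `min_size_to_shard_mb`.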


@Registry.register("shardings.logical_partitioning")
def logical_partitioning():
  """Manual sharding based on Flax's logical partitioning annotations.

  Uses logical sharding annotations added in model code with
  `nn.with_logical_partitioning`. Respects the logical-to-mesh axis name
  mapping rules (typically defined in the dynamic context using
  `with nn.logical_axis_rules(rules): ...`).

  Returns:
    A function that resolves `nn.LogicallyPartitioned`-boxed specs into mesh
    sharding specs.
  """
  def _update_spec(cur_spec, mesh, name, x):
    del x, name, mesh
    if isinstance(cur_spec, nn.LogicallyPartitioned):
      return nn.logical_to_mesh_axes(cur_spec.names)
    return cur_spec
  return _update_spec
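
# A sketch of the annotations this rule consumes (logical names, rules, and
# `init_fn` are illustrative placeholders):
#
#   # In model code:
#   kernel_init = nn.with_logical_partitioning(init_fn, ("embed", "mlp"))
#   # When inferring shardings, map logical names to mesh axes:
#   with nn.logical_axis_rules([("embed", None), ("mlp", "data")]):
#     shardings = infer_sharding(params, [(".*", "logical_partitioning")], mesh)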


@Registry.register("shardings.shard_dim")
def shard_dim(axis, dim, ignore_ndim_error=False):
  """Shards the given dimension along the given axis.

  Args:
    axis: mesh axis name for sharding.
    dim: dimension to shard (can be negative).
    ignore_ndim_error: if True, a warning is logged instead of raising an
      exception when the given dimension is not compatible with the number of
      dimensions of the array.

  Returns:
    A function that updates the sharding spec.
  """
  def _update_spec(cur_spec, mesh, name, x):
    del mesh, x
    if np.abs(dim) >= len(cur_spec):
      msg = f"Cannot shard_dim({axis}, {dim}): name={name} cur_spec={cur_spec}"
      if ignore_ndim_error:
        logging.warning(msg)
        return cur_spec
      else:
        raise ValueError(msg)
    pos_dim = dim
    if pos_dim < 0:
      pos_dim += len(cur_spec)
    if cur_spec[pos_dim] is not None:
      raise ValueError(
          f"Already sharded: shard_dim({axis}, {dim}):"
          f" name={name} cur_spec={cur_spec}"
      )
    new_spec = cur_spec[:pos_dim] + (axis,) + cur_spec[pos_dim + 1:]
    return new_spec

  return _update_spec
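
# A hedged usage sketch (axis and parameter names are illustrative): a tactic
# string like "shard_dim(axis='model', dim=-1)" turns a (None, None) spec for a
# matched (1024, 4096) kernel into (None, "model"), i.e. tensor-parallel
# sharding of the last dimension. Rules can be chained with "|", e.g.
# "shard_dim(axis='model', dim=-1)|fsdp(axis='data')", and are applied in order
# to the same matched parameters.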