from abc import ABC, abstractmethod
from dataclasses import dataclass
import functools
from typing import Callable, Dict, List, TYPE_CHECKING

import torch

from ._internals import (
    check_tensor,
    get_chunked_dim_size,
    get_split_size,
    validate_non_overlapping_shards_metadata
)
from torch.distributed._shard.metadata import ShardMetadata

import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
from torch.distributed._shard.op_registry_utils import _decorator_func

if TYPE_CHECKING:
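    # ShardedTensor is only needed for type annotations; keeping the import under
    # TYPE_CHECKING avoids pulling it (and a potential circular import) in at runtime.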
    from torch.distributed._shard.sharded_tensor import ShardedTensor


class PlacementSpec(ABC):
    """
    Base class representing the placement of an entity. Subclasses of this
    class can be used to specify customized placements which might not be
    covered by existing APIs.
    """
    pass


@dataclass
class DevicePlacementSpec(PlacementSpec):
    """
    Associates placement of an entity with a single device.

    Args:
        device(:class:`torch.distributed._remote_device`): The device to place the entity on.
    """

    device: torch.distributed._remote_device

    def __post_init__(self):
        if not isinstance(self.device, torch.distributed._remote_device):
            self.device = torch.distributed._remote_device(self.device)
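
# Illustrative sketch: anything torch.distributed._remote_device accepts can be passed,
# e.g. a "rank:<rank>/<device>" string, and __post_init__ normalizes it:
#
#   spec = DevicePlacementSpec("rank:0/cuda:0")
#   # spec.device is now a torch.distributed._remote_device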


class ShardingSpec(ABC):
    """
    Base class representing sharding specifications.
    """
    @abstractmethod
    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: sharded_tensor_meta.TensorProperties,
                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
        """
        Given a global tensor size, define how to shard a tensor of this shape
        across ranks and return the corresponding ShardedTensorMetadata.

        Args:
            tensor_sizes (:class:`torch.Size`):
                The tensor shape to shard on: a `torch.Size` object representing the
                shape of the tensor to be sharded according to the ShardingSpec.
            tensor_properties (:class:`torch.distributed._shard.sharded_tensor.TensorProperties`):
                Tensor properties used to create a ShardedTensor.

        Returns:
            A :class:`ShardedTensorMetadata` object that encodes the information about
            the layout of the ShardedTensor and its properties.
        """

    @abstractmethod
    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        """
        Given a global tensor on src_rank, shard this tensor
        across ranks within the process group and return a ShardedTensor.

        Args:
            tensor (:class:`torch.Tensor`): The tensor to be sharded.

        Keyword args:
            src_rank (int, optional): The source rank which is used as the ground truth of
                the data for the parameter that would be sharded and scattered
                across the rest of the ranks.
                Default: 0.
            process_group (ProcessGroup, optional): The process group to work on. If None,
                the default process group will be used.

        Returns:
            A :class:`ShardedTensor` sharded from the given tensor.
        """
_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}


def _has_custom_op(sharding_spec, op):
    """
    Returns whether or not the ShardingSpec has a custom op implementation.
    """
    class_name = type(sharding_spec).__qualname__
    return class_name in _CUSTOM_SHARDING_SPEC_OPS and op in _CUSTOM_SHARDING_SPEC_OPS[class_name]


def _dispatch_custom_op(sharding_spec, op: Callable, types, args, kwargs, process_group):
    """
    Calls the custom op for this ShardingSpec if it exists.
    """
    class_name = type(sharding_spec).__qualname__
    if not _has_custom_op(sharding_spec, op):
        raise RuntimeError(f'Custom op: {op} not registered for {class_name}')
    func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op]
    return func(types, args, kwargs, process_group)


def custom_sharding_spec_op(sharding_spec_class, func):
    """
    Decorator that registers a custom op implementation for a given ShardingSpec class.

    Args:
        sharding_spec_class(type): The ShardingSpec for which we need to add this custom op.
        func(Callable): The op to override (ex: torch.bmm)
    """
    class_name = sharding_spec_class.__qualname__
    if class_name not in _CUSTOM_SHARDING_SPEC_OPS:
        _CUSTOM_SHARDING_SPEC_OPS[class_name] = {}
    return functools.partial(
        _decorator_func,
        op=func,
        op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name]
    )
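
# Illustrative usage sketch (assumes ChunkShardingSpec is imported; the decorated
# function signature must match what _dispatch_custom_op passes through):
#
#   @custom_sharding_spec_op(ChunkShardingSpec, torch.bmm)
#   def chunk_sharded_bmm(types, args, kwargs, process_group):
#       ...  # custom torch.bmm behavior for ChunkShardingSpec sharded tensors
#
# Afterwards, _dispatch_custom_op(spec, torch.bmm, types, args, kwargs, process_group)
# routes to chunk_sharded_bmm whenever `spec` is a ChunkShardingSpec.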


@dataclass
class EnumerableShardingSpec(ShardingSpec):
    """
    This is a type of ShardingSpec that allows users to specify a generic
    sharding scheme by enumerating exactly how each shard is laid out.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard. Note that none of the shards should overlap.
    """

    shards: List[ShardMetadata]

    def __post_init__(self):
        if len(self.shards) == 0:
            raise ValueError(f'Empty shard list provided: {self.shards}')
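
        # All shards must describe tensors with the same number of dimensions
        # (the tensor "rank" here, not the process rank).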
        rank = -1
        for shard in self.shards:
            if rank != -1 and rank != len(shard.shard_offsets):
                raise ValueError(f'Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}')
            rank = len(shard.shard_offsets)

        validate_non_overlapping_shards_metadata(self.shards)

    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: sharded_tensor_meta.TensorProperties,
                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
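        # check_tensor validates that the enumerated shards lie within, and together
        # cover, a tensor of size `tensor_sizes`.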
        check_tensor(self.shards, tensor_sizes)
        return sharded_tensor_meta.ShardedTensorMetadata(
            self.shards,
            tensor_sizes,
            tensor_properties
        )

    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!")


def _infer_sharding_spec_from_shards_metadata(shards_metadata):
    """
    Infer the sharding spec from the metadata of each shard of a ShardedTensor.
    If the tensor is sharded along exactly one dimension, we check whether it matches
    a ChunkShardingSpec by computing the total length along that dimension, performing
    a chunk sharding with the given placements, and comparing the resulting chunk sizes
    against the given shards_metadata. If they do not match, we assume the tensor is
    enumerably sharded.

    Args:
        shards_metadata (List[ShardMetadata]): List of metadata of the local shards.

    Returns:
        A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object representing
            the sharding spec for one sharded tensor.
    """
    placements = []
    chunk_sharding_dim = None
    chunk_offset_list = []
    shard_size_list = []

    for shard_metadata in shards_metadata:
        placements.append(shard_metadata.placement)
        local_offsets = shard_metadata.shard_offsets
        chunk_offset_list.append(sum(local_offsets))
        shard_size_list.append(shard_metadata.shard_sizes)
        shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0]
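
        # Offsets are all zero: this shard alone does not reveal the sharding dim.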
        if len(shard_dims) == 0:
            continue
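
        # Non-zero offsets along more than one dimension: the tensor is not sharded
        # along a single dimension, so it cannot be chunk-sharded.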
        if len(shard_dims) != 1:
            chunk_sharding_dim = None
            break
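
        # Exactly one non-zero offset: candidate chunk sharding dimension. All shards
        # must agree on it; otherwise fall back to enumerable sharding.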
        if chunk_sharding_dim is None:
            chunk_sharding_dim = shard_dims[0]
        elif chunk_sharding_dim != shard_dims[0]:
            chunk_sharding_dim = None
            break

    if chunk_sharding_dim is not None:
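        # Order the placements by shard offset so they follow the chunk order.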
        placements = [
            x for _, x in sorted(zip(chunk_offset_list, placements), key=lambda e: e[0])
        ]

        from .chunk_sharding_spec import ChunkShardingSpec
        chunk_spec = ChunkShardingSpec(
            dim=chunk_sharding_dim,
            placements=placements,
        )
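
        # Re-derive the chunk sizes that ChunkShardingSpec would produce for this many
        # placements and compare them with the actual shard sizes along the sharding dim.
        # If they match, the tensor is chunk-sharded; otherwise treat it as enumerable.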
        shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list])
        shard_total_length = sum(shard_sizes)
        chunks = len(placements)
        split_size = get_split_size(shard_total_length, chunks)
        chunk_shard_sizes = sorted(
            [
                get_chunked_dim_size(shard_total_length, split_size, idx)
                for idx in range(len(placements))
            ]
        )
        if shard_sizes == chunk_shard_sizes:
            return chunk_spec
    return EnumerableShardingSpec(shards_metadata)