UMMJ's picture
Upload 5875 files
9dd3461
from abc import ABC, abstractmethod
from dataclasses import dataclass
import functools
from typing import Callable, Dict, List, TYPE_CHECKING
import torch
from ._internals import (
check_tensor,
get_chunked_dim_size,
get_split_size,
validate_non_overlapping_shards_metadata
)
from torch.distributed._shard.metadata import ShardMetadata
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
from torch.distributed._shard.op_registry_utils import _decorator_func
if TYPE_CHECKING:
# Only include ShardedTensor when do type checking, exclude it
# from run-time to resolve circular dependency.
from torch.distributed._shard.sharded_tensor import ShardedTensor
class PlacementSpec(ABC):
    """
    Abstract root of all placement descriptions for an entity.

    Extend this class to express custom placements that the existing
    APIs do not already cover.
    """
@dataclass
class DevicePlacementSpec(PlacementSpec):
    """
    Placement spec that pins an entity to exactly one device.

    Args:
        device(:class:`torch.distributed._remote_device`): The device to place the entity on.
            A value that is not already a ``torch.distributed._remote_device`` is
            converted by passing it to the ``_remote_device`` constructor.
    """

    device: torch.distributed._remote_device

    def __post_init__(self):
        # Already normalized: nothing to do.
        if isinstance(self.device, torch.distributed._remote_device):
            return
        # Otherwise coerce the user-supplied value into a _remote_device.
        self.device = torch.distributed._remote_device(self.device)
class ShardingSpec(ABC):
    """
    Base class representing sharding specifications.

    Concrete subclasses must implement both :meth:`build_metadata` and
    :meth:`shard`; the class cannot be instantiated until they do.
    """

    @abstractmethod
    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: sharded_tensor_meta.TensorProperties,
                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
        """
        Given the size of a global tensor, describe how a tensor of that shape
        is sharded across ranks and return the corresponding
        ShardedTensorMetadata.

        Args:
            tensor_sizes (:class:`torch.Size`):
                The tensor shape to shard on, a `torch.Size` object that represents the
                tensor shape to be sharded according to the ShardingSpec.
            tensor_properties (:class:`torch.distributed._shard.sharded_tensor.TensorProperties`):
                Tensor properties used to create a ShardedTensor.

        Returns:
            A :class:`ShardedTensorMetadata` object that encodes the information about
            the layout of the ShardedTensor and its properties.
        """

    @abstractmethod
    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        """
        Given a global tensor on ``src_rank``, shard it across the ranks of
        the process group and return a ShardedTensor.

        Args:
            tensor (:class:`torch.Tensor`): The tensor to be sharded.

        Keyword args:
            src_rank (int, optional): The source rank which is used as the ground truth of
                the data for the parameter that would be sharded and scattered
                across the rest of the ranks.
                Default: 0.
            process_group (ProcessGroup, optional): The process group to work on. If None,
                the default process group will be used.

        Returns:
            A :class:`ShardedTensor` sharded from the given tensor.
        """
# Registry of ops customized for a particular ShardingSpec.
# Outer key: ``__qualname__`` of the ShardingSpec class.
# Inner mapping: overridden op (e.g. ``torch.bmm``) -> its replacement implementation.
_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = dict()
def _has_custom_op(sharding_spec, op):
    """
    Report whether a custom implementation of ``op`` is registered for the
    concrete class of ``sharding_spec``.

    The lookup is keyed by ``type(sharding_spec).__qualname__``, so only the
    exact class is consulted, not its base classes.
    """
    registered_ops = _CUSTOM_SHARDING_SPEC_OPS.get(type(sharding_spec).__qualname__, {})
    return op in registered_ops
def _dispatch_custom_op(sharding_spec, op: Callable, types, args, kwargs, process_group):
    """
    Invoke the custom implementation of ``op`` registered for this ShardingSpec.

    The registered function is called as
    ``func(types, args, kwargs, process_group)`` and its result is returned.

    Raises:
        RuntimeError: if no custom implementation of ``op`` is registered for
            ``type(sharding_spec)``.
    """
    spec_name = type(sharding_spec).__qualname__
    if _has_custom_op(sharding_spec, op):
        handler = _CUSTOM_SHARDING_SPEC_OPS[spec_name][op]
        return handler(types, args, kwargs, process_group)
    raise RuntimeError(f'Custom op: {op} not registered for {spec_name}')
def custom_sharding_spec_op(sharding_spec_class, func):
    """
    Decorator factory for registering a custom op for a ShardingSpec class.

    Ensures an (initially empty) op table exists for
    ``sharding_spec_class.__qualname__`` and returns ``_decorator_func``
    partially applied with ``op=func`` and that table.
    NOTE(review): the actual registration of the decorated function is done
    by ``_decorator_func`` in ``op_registry_utils``; confirm its contract there.

    Args:
        sharding_spec_class(type): The ShardingSpec for which we need to add this custom op.
        func(Callable): The op to override (ex: torch.bmm)
    """
    # setdefault creates the per-class table on first use and reuses it afterwards.
    op_table = _CUSTOM_SHARDING_SPEC_OPS.setdefault(sharding_spec_class.__qualname__, {})
    return functools.partial(_decorator_func, op=func, op_table=op_table)
@dataclass
class EnumerableShardingSpec(ShardingSpec):
    """
    A :class:`ShardingSpec` that lets users describe a generic sharding
    scheme by enumerating exactly how each shard is laid out.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard. The list must be non-empty, every shard must have the same
            number of dimensions (length of ``shard_offsets``), and none of the
            shards may overlap.

    Raises:
        ValueError: from ``__post_init__`` if ``shards`` is empty or if the
            shards disagree on their number of dimensions. Overlap is checked by
            ``validate_non_overlapping_shards_metadata``; see that helper for the
            exception it raises.
    """

    shards: List[ShardMetadata]

    def __post_init__(self):
        if not self.shards:
            raise ValueError(f'Empty shard list provided: {self.shards}')

        # "Rank" here is the tensor rank (number of dims), not a process rank:
        # every shard must have as many offsets as the first one.
        expected_rank = len(self.shards[0].shard_offsets)
        for shard in self.shards[1:]:
            shard_rank = len(shard.shard_offsets)
            if shard_rank != expected_rank:
                raise ValueError(f'Found inconsistent ranks for shards: {expected_rank} and {shard_rank}')

        validate_non_overlapping_shards_metadata(self.shards)

    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: sharded_tensor_meta.TensorProperties,
                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
        """
        Check that the enumerated shards form a valid tensor of
        ``tensor_sizes`` (via ``check_tensor``), then wrap them with the given
        sizes and properties in a ShardedTensorMetadata.
        """
        check_tensor(self.shards, tensor_sizes)
        return sharded_tensor_meta.ShardedTensorMetadata(
            self.shards,
            tensor_sizes,
            tensor_properties
        )

    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        """Not supported yet for enumerable sharding; always raises NotImplementedError."""
        # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec
        raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!")
def _infer_sharding_spec_from_shards_metadata(shards_metadata):
    """
    Infer the sharding spec from the metadata of each shard of a ShardedTensor.

    If every shard with a non-zero offset is offset along the same single
    dimension, the tensor is a candidate for a ChunkShardingSpec on that
    dimension. To confirm it, the shard lengths along that dimension are
    compared with the lengths that chunk sharding the total length over the
    given placements would produce. If they match, a ChunkShardingSpec is
    returned. Otherwise, or if shards are offset along more than one
    dimension, an EnumerableShardingSpec is returned.

    Args:
        shards_metadata (List[ShardMetadata]): List of Metadata of local shards.

    Returns:
        A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object of sharding
        spec for one sharded tensor.
    """
    placements = []
    chunk_sharding_dim = None
    chunk_offset_list = []
    shard_size_list = []
    # collect local shard metadatas from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        placements.append(shard_metadata.placement)
        local_offsets = shard_metadata.shard_offsets
        # For a single-dim sharding only one offset is non-zero, so the sum is
        # the shard's position along the sharding dim (used to order placements).
        chunk_offset_list.append(sum(local_offsets))
        shard_size_list.append(shard_metadata.shard_sizes)
        shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0]
        # If the offset is [0, 0, ..., 0] (all zeros),
        # we cannot decide how the tensor is sharded.
        if not shard_dims:
            continue
        # If the offset is [0, N, .,0, M, 0, .., 0],
        # we are sure it's sharded by more than one dimension.
        if len(shard_dims) != 1:
            chunk_sharding_dim = None
            break
        # If the offset is [0, 0, .,0, M, 0, .., 0], aka, it's sharded by just
        # one dimension, we need to make sure all ranks share the same dimension.
        # Compare against None explicitly: dim 0 is a valid sharding dim, and a
        # truthiness test would let a later shard on another dim overwrite it
        # instead of being detected as a mismatch.
        if chunk_sharding_dim is None:
            chunk_sharding_dim = shard_dims[0]
        elif chunk_sharding_dim != shard_dims[0]:
            chunk_sharding_dim = None
            break

    if chunk_sharding_dim is not None:
        # Ensure we infer the correct placement order from offsets
        placements = [
            x for _, x in sorted(zip(chunk_offset_list, placements), key=lambda e: e[0])
        ]

        from .chunk_sharding_spec import ChunkShardingSpec
        chunk_spec = ChunkShardingSpec(
            dim=chunk_sharding_dim,
            placements=placements,
        )

        # Accept the chunk spec only if chunking the total length reproduces
        # exactly the observed shard lengths along the sharding dim.
        shard_sizes = sorted(x[chunk_sharding_dim] for x in shard_size_list)
        shard_total_length = sum(shard_sizes)
        chunks = len(placements)
        split_size = get_split_size(shard_total_length, chunks)
        chunk_shard_sizes = sorted(
            get_chunked_dim_size(shard_total_length, split_size, idx)
            for idx in range(chunks)
        )
        if shard_sizes == chunk_shard_sizes:
            return chunk_spec
    return EnumerableShardingSpec(shards_metadata)