from typing import List

from torch.distributed._shard.metadata import ShardMetadata

def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata):
    """
    Checks if two shards overlap.
    """
    # Two hyper-rectangular shards overlap only if their ranges intersect in
    # every dimension. Equivalently, they are disjoint if, along any single
    # dimension, one shard starts at or beyond the end of the other.
    ndims = len(shard1.shard_offsets)
    for i in range(ndims):
        if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]:
            return False
        if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]:
            return False

    return True
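
# Illustrative example (not part of the original module; the placement strings
# are arbitrary): two shards adjacent along dim 0 do not overlap, while shards
# whose ranges intersect in every dimension do.
#
#   a = ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu")
#   b = ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:1/cpu")
#   c = ShardMetadata(shard_offsets=[1, 0], shard_sizes=[2, 2], placement="rank:1/cpu")
#   _check_shard_metadata_pair_overlap(a, b)  # False: disjoint along dim 0
#   _check_shard_metadata_pair_overlap(a, c)  # True: row 1, columns 0-1 intersect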


def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]):
    """
    Ensures none of the shards overlap with each other.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard.
    Raises:
        ``ValueError`` if there's overlap in any two shards.
    """
    # Pairwise check of all shards; O(n^2) in the number of shards.
    for i in range(len(shards)):
        for j in range(i + 1, len(shards)):
            if _check_shard_metadata_pair_overlap(shards[i], shards[j]):
                raise ValueError(f'Shards {shards[i]} and {shards[j]} overlap')
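
# Illustrative example (not part of the original module): two non-overlapping
# shards that tile a 4 x 4 tensor pass validation; adding a third shard that
# intersects either of them would raise ValueError.
#
#   top = ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cpu")
#   bottom = ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cpu")
#   validate_non_overlapping_shards_metadata([top, bottom])  # no error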


def check_tensor(shards_metadata, tensor_dims) -> None:
    """
    Checks if the shards_metadata is compatible with the provided tensor dims.

    Args:
        shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata`
            objects representing each shard of the tensor.
        tensor_dims(Sequence of int): Dimensions of tensor to verify.
    Raises:
        ``ValueError`` if not compatible.
    """
    # The shards are compatible with the tensor if they have the same rank as
    # the tensor, every shard lies within the tensor's bounds, and the total
    # volume of the shards equals the tensor's volume. Assuming the shards have
    # already been checked for overlap (see
    # validate_non_overlapping_shards_metadata), a matching total volume
    # implies the shards cover the entire tensor.
    tensor_rank = len(tensor_dims)
    shards_rank = len(shards_metadata[0].shard_offsets)
    if tensor_rank != shards_rank:
        raise ValueError(f'Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}')

    total_shard_volume = 0
    for shard in shards_metadata:
        shard_volume = 1
        for i, shard_length in enumerate(shard.shard_sizes):
            shard_volume *= shard_length
            if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]:
                raise ValueError(
                    f'Shard offset {shard.shard_offsets[i]} and length '
                    f'{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}')
        total_shard_volume += shard_volume

    tensor_volume = 1
    for size in tensor_dims:
        tensor_volume *= size

    if total_shard_volume != tensor_volume:
        raise ValueError(
            f'Total volume of shards: {total_shard_volume} '
            f'does not match tensor volume: {tensor_volume}, in other words '
            f'all the individual shards do not cover the entire tensor')
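
# Illustrative example (not part of the original module): two [2, 4] shards
# tiling a 4 x 4 tensor have a combined volume of 16, matching the tensor's
# volume, so check_tensor returns without raising. Either shard on its own
# would fail the volume check.
#
#   shards = [
#       ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cpu"),
#       ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cpu"),
#   ]
#   check_tensor(shards, [4, 4])  # no error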


def get_split_size(dim_size, chunks):
    """
    Computes the split size in line with ``torch.chunk``.

    Args:
        dim_size(int): Size of the dimension being chunked.
        chunks(int): Number of chunks to create for ``dim_size``.

    Returns:
        An int indicating the split size to use.
    """
    # Ceiling division: every chunk except possibly the last has this size.
    return (dim_size + chunks - 1) // chunks
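
# Illustrative example (not part of the original module): ceil(11 / 4) == 3,
# matching the per-chunk size torch.chunk would use for a dimension of 11
# split into 4 chunks.
#
#   get_split_size(11, 4)  # 3
#   get_split_size(12, 4)  # 3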


def get_chunked_dim_size(dim_size, split_size, idx):
    """
    Computes the dim size of the chunk for provided ``idx`` given ``dim_size``
    and ``split_size``.

    Args:
        dim_size(int): Size of the dimension being chunked.
        split_size(int): The chunk size for each chunk of ``dim_size``.
        idx(int): The index of chunk whose dim size is being requested.

    Returns:
        An int indicating the dim size of the chunk.
    """
    # The last non-empty chunk may be smaller than ``split_size``; indices past
    # the end of the dimension yield a size of 0.
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)
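
# Illustrative example (not part of the original module): a dimension of 11
# with split_size 3 produces chunk sizes 3, 3, 3, 2, and 0 for any further
# index, mirroring torch.chunk(torch.arange(11), 4).
#
#   [get_chunked_dim_size(11, 3, idx) for idx in range(5)]  # [3, 3, 3, 2, 0]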


def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
    """
    Generate the start position and chunk size for the current rank for
    chunk sharding.

    Args:
        sharding_dim_size(int): The dimension length which we shard on.
        world_size(int): number of ranks.
        spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`):
            sharding spec.
        rank(int): Rank of the current process.

    Returns:
        start_pos(int): start position of sharded tensor on the given rank.
        chunk_size(int): chunk size of sharded tensor on the given rank.
    """
    split_size = get_split_size(sharding_dim_size, world_size)
    current_offsets = 0
    start_pos = current_offsets
    # Walk the placements in order, accumulating chunk sizes until we reach
    # the chunk assigned to ``rank``.
    for idx, placement in enumerate(spec.placements):
        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        if rank == placement.rank():
            start_pos = current_offsets
            break
        current_offsets += chunk_size
    return start_pos, chunk_size
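
# Illustrative example (not part of the original module; ``spec`` stands in
# for a ChunkShardingSpec whose placements are ordered rank 0..3): for a
# sharding dimension of length 11 across 4 ranks, split_size is 3 and the
# chunks start at offsets 0, 3, 6 and 9, so rank 2 gets (start_pos=6,
# chunk_size=3) and rank 3 gets (start_pos=9, chunk_size=2).
#
#   get_chunk_sharding_params(11, 4, spec, rank=2)  # (6, 3)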