from typing import List from torch.distributed._shard.metadata import ShardMetadata def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata): """ Checks if two shards overlap. """ # For each dim of each shard, check if one shard resides on the other # end of second shard with respect to that dim. As an example for a 2D # shard, we would check if one shard is above or on the left of the # other shard. ndims = len(shard1.shard_offsets) for i in range(ndims): if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]: return False if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]: return False return True def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]): """ Ensures none of the shards overlap with each other. Args: shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing each shard. Raises: ``ValueError`` if there's overlap in any two shards. """ # TODO: evaluate optimizing this if needed. for i in range(len(shards)): for j in range(i + 1, len(shards)): if _check_shard_metadata_pair_overlap(shards[i], shards[j]): raise ValueError(f'Shards {shards[i]} and {shards[j]} overlap') def check_tensor(shards_metadata, tensor_dims) -> None: """ Checks if the shards_metadata is compatible with the provided tensor dims. Args: shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing each shard of the tensor. tensor_dims(Sequence of int): Dimensions of tensor to verify Raises: ``ValueError`` if not compatible. """ # If the tensor's volume matches the total volume of all shards and # all shard boundaries are within tensor dims, we have a compatible # sharding spec for this tensor. Note that we have already verified # we don't have overlapping shards. tensor_rank = len(tensor_dims) shards_rank = len(shards_metadata[0].shard_offsets) if tensor_rank != shards_rank: raise ValueError(f'Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}') total_shard_volume = 0 for shard in shards_metadata: shard_volume = 1 for i, shard_length in enumerate(shard.shard_sizes): shard_volume *= shard_length if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]: raise ValueError( f'Shard offset {shard.shard_offsets[i]} and length ' f'{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}') total_shard_volume += shard_volume tensor_volume = 1 for size in tensor_dims: tensor_volume *= size if total_shard_volume != tensor_volume: # TODO: Can we improve this error message to point out the gaps? raise ValueError( f'Total volume of shards: {total_shard_volume} ' f'does not match tensor volume: {tensor_volume}, in other words ' f'all the individual shards do not cover the entire tensor') def get_split_size(dim_size, chunks): """ Computes the split size inline with ``torch.chunk`` Args: dim_size(int): Size of the dimension being chunked. chunks(int): Number of chunks to create for ``dim_size``. Returns: An int indicating the split size to use. """ return (dim_size + chunks - 1) // chunks def get_chunked_dim_size(dim_size, split_size, idx): """ Computes the dim size of the chunk for provided ``idx`` given ``dim_size`` and ``split_size``. Args: dim_size(int): Size of the dimension being chunked. split_size(int): The chunk size for each chunk of ``dim_size``. idx(int): The index of chunk whose dim size is being requested. Returns: An int indicating the dim size of the chunk. """ return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0) def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank): """ Generate the start pos and offset length for the current rank for chunk sharding. Args: sharding_dim_size(int): The dimension length which we shard on. world_size(int): number of ranks. spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`): sharding spec. rank(int): # of cuda process. Returns: start_pos(int): start position of sharded tensor on the given rank. chunk_size(int): chunk size of sharded tensor on the given rank. """ split_size = get_split_size(sharding_dim_size, world_size) current_offsets = 0 start_pos = current_offsets for idx, placement in enumerate(spec.placements): chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) if rank == placement.rank(): start_pos = current_offsets break current_offsets += chunk_size return start_pos, chunk_size