import copy
from typing import List, Tuple

import torch
import torch.distributed as dist
from torch._C._distributed_c10d import (
    ProcessGroup,
)
import torch.distributed._shard.sharding_spec as shard_spec
from torch.distributed._shard.sharding_spec._internals import (
    get_split_size,
    get_chunked_dim_size,
)
from torch.distributed.nn.functional import (
    all_to_all,
    all_to_all_single,
)
from torch.distributed._shard.metadata import ShardMetadata

from .shard import Shard


def get_idx_from_placements(placements, current_rank) -> int:
    """
    Return the position of the current rank in the given placements.

    Args:
        placements(List[Union[_remote_device, str]]):
            Specifies the placement of each shard of the Tensor. The size of
            the list represents the number of shards to be created. This could
            be a list of
            :class:`torch.distributed._remote_device`'s. This list
            could also contain a string which represents remote
            device as accepted by
            :class:`torch.distributed._remote_device`
        current_rank (int): Rank of the current device.

    Returns:
        An int which is the position of the current rank in the placement list.
    """
    for idx, placement in enumerate(placements):
        if current_rank == placement.rank():
            return idx
    raise RuntimeError('current_rank not in the placement.')


def build_reshard_metadata(
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    world_size: int,
) -> Tuple[List[ShardMetadata], List[int]]:
    """
    Based on the given sharding spec, calculate the offset and local shard size
    for each rank, then build a ShardMetadata from the result.

    Args:
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded.
        world_size (int): number of ranks.

    Returns:
        A Tuple of the following:
            A List[`ShardMetadata`] which contains the metadata for each shard, including
                offsets, lengths and device placement.
            A List[int] which contains the ranks in the order of placement.
    """
    shard_dim = int(sharding_spec.dim)
    shards_metadata = [None] * world_size
    ranks = []
    offsets = [0] * len(st_size)
    split_size = get_split_size(st_size[shard_dim], world_size)
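    # Walk the placements in order, recording each rank and building its shard
    # metadata; the offset advances along the sharding dim after each shard.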
    for idx, placement in enumerate(sharding_spec.placements):
        ranks.append(placement.rank())
        sharded_dim_size = get_chunked_dim_size(st_size[shard_dim], split_size, idx)
        local_tensor_size = list(st_size)
        local_tensor_size[shard_dim] = sharded_dim_size
        shards_metadata[placement.rank()] = ShardMetadata(
            shard_offsets=copy.deepcopy(offsets),
            shard_sizes=local_tensor_size,
            placement=placement,
        )
        offsets[shard_dim] += sharded_dim_size
    return shards_metadata, ranks


def reshuffle_local_shard(
    local_shard: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    resharding_spec: shard_spec.ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshuffle the local shard directly when the reshard dim is the same as the
    original sharding dim. Logically we do this in two steps:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending the local tensor
    directly to the rank that owns the new shard, based on the resharding spec.

    Args:
        local_shard (Tensor): Local shard stored on the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the following:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for each shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
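    # Build the shard metadata for the target (resharding) spec first.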
    shards_metadata, ranks = build_reshard_metadata(
        st_size, resharding_spec, world_size
    )
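    # Compute the input split sizes for all_to_all_single: this rank's local
    # shard is sent in one piece to the rank that owns it under the new spec.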
    reshard_dim = int(resharding_spec.dim)
    split_size = get_split_size(st_size[reshard_dim], world_size)
    input_split_sizes = [0] * world_size
    idx = get_idx_from_placements(sharding_spec.placements, current_rank)
    new_rank = resharding_spec.placements[idx].rank()
    input_split_sizes[new_rank] = local_shard.size(reshard_dim)
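    # Compute the output split sizes: the length of the reshard dim this rank
    # receives under the new spec.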
    output_split_sizes = [0] * world_size
    new_idx = ranks.index(current_rank)
    sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx)
    output_split_sizes[new_rank] = sharded_dim_size
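    # all_to_all_single splits along dim 0, so move the reshard dim to the front
    # and allocate the receive buffer with the expected output size.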
    local_shard = local_shard.transpose(0, reshard_dim).contiguous()
    gathered_input_size = list(local_shard.size())
    gathered_input_size[0] = sharded_dim_size
    gathered_input = torch.empty(
        gathered_input_size, device=local_shard.device, dtype=local_shard.dtype
    )
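    # Exchange shards across ranks with an autograd-aware all_to_all_single.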
    local_shard = all_to_all_single(
        gathered_input,
        local_shard,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )
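    # Transpose back so the dim order matches the original tensor, then wrap the
    # result in a Shard with its metadata.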
    local_tensor = local_shard.transpose(0, reshard_dim).contiguous()
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata


def reshard_local_shard(
    local_tensor: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    resharding_spec: shard_spec.ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is
    different from the original sharding dim, we logically need to do two steps:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending each rank its new
    shard directly, based on the resharding spec.

    Args:
        local_tensor (Tensor): Local tensor stored on the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the following:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for each shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    current_sharding_dim = int(sharding_spec.dim)
    reshard_dim = int(resharding_spec.dim)
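    # Build the shard metadata for the target (resharding) spec.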
    shards_metadata, ranks = build_reshard_metadata(
        st_size, resharding_spec, world_size
    )
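    # Compute the input split sizes along the reshard dim, one entry per rank.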
    input_split_sizes = []
    for metadata in shards_metadata:
        input_split_sizes.append(metadata.shard_sizes[reshard_dim])
    rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))
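    # If the placements are not in ascending rank order, reorder local_tensor
    # along the reshard dim so the splits line up with rank order for all_to_all.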
    if rearrange_input:
        indices: List[int] = []
        for metadata in shards_metadata:
            offset_start_idx = metadata.shard_offsets[reshard_dim]
            split_size = metadata.shard_sizes[reshard_dim]
            indices += range(offset_start_idx, offset_start_idx + split_size)
        local_tensor = local_tensor.index_select(
            reshard_dim, torch.tensor(indices, device=local_tensor.device)
        )
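    # Since reshard_dim differs from the original sharding dim, compute the size
    # of the tensor expected from each rank and pre-allocate the output buffers
    # for all_to_all.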
    output_tensor_list = [torch.tensor(1)] * world_size
    split_size = get_split_size(st_size[current_sharding_dim], world_size)
    rearrange_output_list = False
    indices = []
    for idx, placement in enumerate(sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(
            st_size[current_sharding_dim], split_size, idx
        )
        output_tensor_size = list(st_size)
        output_tensor_size[current_sharding_dim] = sharded_dim_size
        output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
        output_tensor_list[placement.rank()] = torch.empty(
            output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype
        )
        indices.append(placement.rank())
        if idx != placement.rank():
            rearrange_output_list = True
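    # Split the local tensor along the reshard dim and perform an autograd-aware
    # all_to_all so every rank receives its slice from every other rank.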
    input_tensor_list = torch.split(local_tensor, input_split_sizes, dim=reshard_dim)
    input_tensor_list = [tensor.contiguous() for tensor in input_tensor_list]
    output_tensor_list = all_to_all(
        output_tensor_list,
        input_tensor_list,
        group=pg,
    )
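    # Reorder the received chunks into placement order if needed, then concatenate
    # them along the original sharding dim to form the new local shard.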
    if rearrange_output_list:
        output_tensor_list = [output_tensor_list[idx] for idx in indices]
    local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim)
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata
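

# A minimal usage sketch (illustrative only; the spec values, tensor names and
# process group below are assumptions, not part of this module):
#
#   spec = shard_spec.ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
#   reshard_spec = shard_spec.ChunkShardingSpec(dim=1, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
#   local_shards, shards_metadata = reshard_local_shard(
#       local_tensor, st_size, spec, reshard_spec, dist.group.WORLD
#   )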