import collections.abc
import copy
from typing import Optional, List, Sequence

import torch
from torch.distributed import distributed_c10d
from torch.distributed import rpc
from torch.distributed._shard.sharding_spec._internals import (
    check_tensor,
    validate_non_overlapping_shards_metadata,
)
from torch.distributed._shard.metadata import ShardMetadata

from .metadata import TensorProperties, ShardedTensorMetadata
from .shard import Shard


def _parse_and_validate_remote_device(pg, remote_device):
    if remote_device is None:
        raise ValueError("remote device is None")

    worker_name = remote_device.worker_name()
    rank = remote_device.rank()
    device = remote_device.device()

    # Validate rank; skip validation if the rank is not part of the process group.
    if not distributed_c10d._rank_not_in_group(pg):
        if rank is not None and (rank < 0 or rank >= distributed_c10d.get_world_size(pg)):
            raise ValueError(f'Invalid rank: {rank}')

    if worker_name is not None:
        if not rpc._is_current_rpc_agent_set():
            raise RuntimeError(f'RPC framework needs to be initialized for using worker names: {worker_name}')

        workers = rpc._get_current_rpc_agent().get_worker_infos()
        for worker in workers:
            if worker.name == worker_name:
                return worker.id, device

        raise ValueError(f'Invalid worker name: {worker_name}')

    return rank, device


def _validate_output_tensor_for_gather(
    my_rank: int,
    dst_rank: int,
    size: torch.Size,
    dst_tensor: Optional[torch.Tensor],
) -> None:
    if dst_rank == my_rank:
        if dst_tensor is None:
            raise ValueError(
                f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}"
            )
        if tuple(size) != tuple(dst_tensor.size()):
            raise ValueError(
                f"Argument ``dst_tensor`` has size {tuple(dst_tensor.size())}, "
                f"but should be {tuple(size)}"
            )
    # Compare against None explicitly: truth-testing a multi-element tensor raises.
    elif dst_tensor is not None:
        raise ValueError(
            "Argument ``dst_tensor`` must NOT be specified "
            "on non-destination ranks."
        )


def _flatten_tensor_size(size) -> torch.Size:
    """
    Checks if the tensor size is valid, then flattens and returns it as a torch.Size.
    """
    if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
        dims = list(*size)
    else:
        dims = list(size)

    for dim in dims:
        if not isinstance(dim, int):
            raise TypeError(f'size has to be a sequence of ints, found: {dims}')

    return torch.Size(dims)


def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
    if is_local:
        assert isinstance(ranks, int)
        if expected != actual:
            raise ValueError(f"Local shards' tensor {prop_name} property needs to be the same on rank:{ranks}! "
                             f"Found one local shard tensor {prop_name}={expected}, "
                             f"the other local shard tensor {prop_name}={actual}.")
    else:
        # Cross-rank comparison; the ranks argument should contain exactly two ranks.
        assert len(ranks) == 2
        if expected != actual:
            raise ValueError(f"ShardedTensor {prop_name} property does not match from different ranks! "
                             f"Found {prop_name}={expected} on rank:{ranks[0]}, "
                             f"and {prop_name}={actual} on rank:{ranks[1]}.")

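# Illustrative note (assumed usage, not behavior added in this module): a shard's
# ``placement`` is a ``torch.distributed._remote_device``, usually built from a
# string such as "rank:0/cuda:0" (rank-based placement) or "trainer0/cuda:0"
# (RPC worker-name placement). For rank-based placements,
# ``_parse_and_validate_remote_device`` returns the rank and device directly; for
# worker-name placements it resolves the name to a worker id via the current RPC agent.
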
def build_metadata_from_local_shards(
    local_shards: List[Shard],
    global_size: torch.Size,
    current_rank: int,
    pg: distributed_c10d.ProcessGroup
) -> ShardedTensorMetadata:
    """
    Validate the local shards on this rank and build a "local" ShardedTensorMetadata
    describing them, to be gathered and merged across ranks afterwards.
    """
    assert len(local_shards) > 0, "must have local shards!"
    local_shard_metadatas: List[ShardMetadata] = []

    first_shard_dtype = local_shards[0].tensor.dtype
    first_shard_layout = local_shards[0].tensor.layout
    first_shard_requires_grad = local_shards[0].tensor.requires_grad
    first_shard_is_pinned = local_shards[0].tensor.is_pinned()

    # 1). Validate local tensors and their associated metadata.
    for local_shard in local_shards:
        local_shard_tensor = local_shard.tensor
        local_shard_meta = local_shard.metadata
        local_shard_metadatas.append(local_shard_meta)
        rank, local_device = _parse_and_validate_remote_device(pg, local_shard_meta.placement)

        if local_shard_tensor.layout != torch.strided or local_shard_tensor.layout != first_shard_layout:
            raise ValueError(
                f'Only torch.strided layout is currently supported, but found '
                f'{local_shard_tensor.layout} on rank:{current_rank}!'
            )

        if not local_shard_tensor.is_contiguous():
            raise ValueError('Only torch.contiguous_format memory_format is currently supported!')

        if rank != current_rank:
            raise ValueError(
                f"Local shard metadata's rank does not match with the rank in its process group! "
                f'Found current rank in the process group: {current_rank}, '
                f"local ShardMetadata placement's rank: {rank}"
            )
        if local_shard_tensor.device != local_device:
            raise ValueError(
                f"Local shard tensor device does not match with local Shard's placement! "
                f"Found local shard tensor device: {local_shard_tensor.device}, "
                f"local shard metadata placement device: {local_device}"
            )

        _raise_if_mismatch(local_shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
        _raise_if_mismatch(local_shard_tensor.is_pinned(), first_shard_is_pinned, "pin_memory", current_rank)
        _raise_if_mismatch(local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank)
        _raise_if_mismatch(local_shard_tensor.requires_grad, first_shard_requires_grad, "requires_grad", current_rank)

    # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then
    #     do an all_gather to collect the local_sharded_tensor_metadata from all ranks.
    local_tensor_properties = TensorProperties(
        dtype=first_shard_dtype,
        layout=first_shard_layout,
        requires_grad=first_shard_requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=first_shard_is_pinned)

    local_sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=local_shard_metadatas,
        size=global_size,
        tensor_properties=local_tensor_properties)

    return local_sharded_tensor_metadata

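# Illustrative example (assumed caller setup, not code in this module): with a
# 2-rank process group and a 4 x 4 tensor sharded row-wise, rank 0 would pass a
# single local Shard described roughly by
#
#     ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cuda:0")
#
# and rank 1 the complementary shard with shard_offsets=[2, 0]. The per-rank
# metadata returned above is then gathered across ranks (e.g. via
# torch.distributed.all_gather_object) and merged by ``build_global_metadata`` below.
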
def build_global_metadata(gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]):
    """
    Merge the per-rank ShardedTensorMetadata gathered from all ranks into a single
    global ShardedTensorMetadata, validating cross-rank consistency along the way.
    """
    global_sharded_tensor_metadata = None
    global_metadata_rank = 0

    for rank, rank_metadata in enumerate(gathered_metadatas):
        if rank_metadata is None:
            continue

        if global_sharded_tensor_metadata is None:
            global_sharded_tensor_metadata = copy.deepcopy(rank_metadata)
            global_metadata_rank = rank
        else:
            _raise_if_mismatch(global_sharded_tensor_metadata.size,
                               rank_metadata.size,
                               "global_size", [global_metadata_rank, rank], is_local=False)

            # No need to check layout and memory format here, as they were already
            # validated during the local shards validation stage.
            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.dtype,
                               rank_metadata.tensor_properties.dtype,
                               "dtype", [global_metadata_rank, rank], is_local=False)
            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.requires_grad,
                               rank_metadata.tensor_properties.requires_grad,
                               "requires_grad", [global_metadata_rank, rank], is_local=False)
            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.pin_memory,
                               rank_metadata.tensor_properties.pin_memory,
                               "pin_memory", [global_metadata_rank, rank], is_local=False)

            # All validations passed; extend the global shards metadata with this rank's shards.
            global_sharded_tensor_metadata.shards_metadata.extend(rank_metadata.shards_metadata)

    if global_sharded_tensor_metadata is not None:
        # Check that shards_metadata does not contain overlapping shards.
        validate_non_overlapping_shards_metadata(global_sharded_tensor_metadata.shards_metadata)

        # Check that shards_metadata is compatible with the global size of the sharded tensor.
        check_tensor(global_sharded_tensor_metadata.shards_metadata, global_sharded_tensor_metadata.size)
    else:
        raise ValueError("ShardedTensor has no local shards on any rank!")

    return global_sharded_tensor_metadata

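# ---------------------------------------------------------------------------
# Illustrative usage sketch. A minimal, assumed example of how the two builders
# above are typically combined when constructing a ShardedTensor from local
# shards; the helper name below and the use of ``all_gather_object`` for the
# metadata exchange are assumptions for illustration, not an API defined by
# this module. It requires an already-initialized process group.
def _example_build_global_metadata(
    local_shards: List[Shard],
    global_size: torch.Size,
    pg: distributed_c10d.ProcessGroup,
) -> ShardedTensorMetadata:
    import torch.distributed as dist

    # Validate the local shards and build this rank's view of the metadata.
    current_rank = dist.get_rank(pg)
    local_metadata = build_metadata_from_local_shards(
        local_shards, global_size, current_rank, pg
    )

    # Exchange per-rank metadata; ranks without local shards would contribute None.
    gathered: List[Optional[ShardedTensorMetadata]] = [None] * dist.get_world_size(pg)
    dist.all_gather_object(gathered, local_metadata, group=pg)

    # Merge and validate the gathered metadata into a single global view.
    return build_global_metadata(gathered)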