Image-Generator / torch /distributed /_functional_collectives_impl.py
Adi-69s's picture
Upload 5061 files
b2659ad verified
import logging
import warnings
import weakref
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from typing import List, Optional, cast
"""
Moved eager kernel implementations to a separate file partly for readability and partly as it is currently
easier in dynamo to set tracing policy on a file-by-file level.
Do not put code in this file that Dynamo is expected to trace into, as dynamo may disallow this whole file.
DEBUG/TESTING HELPERS:
This module includes some helpers that are quite useful when debugging or testing functional collectives:
_tensor_needs_wait
_outstanding_wait_count
_wait_all
"""
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Registry: tensor data pointer -> the _WaitRegistration tracking its collective.
data_ptr_to_work = {}

# Counter stamped onto each new _WaitRegistration and then incremented.
work_version = 0
class _WaitRegistration:
    """Pairs one work handle with the tensor data pointers registered for it.

    Instances insert themselves into the module-level ``data_ptr_to_work``
    registry and remove themselves again once the work has been waited on.
    """

    def __init__(self, work):
        global work_version
        self.work = work
        # Stamp this registration with the current counter value, then advance it.
        self.version = work_version
        work_version += 1
        # Every data pointer ever registered against this work.
        self.ptrs = set()
        # Per-pointer count of wrappers recorded through _record_wrapper.
        self.ptr_alias_count = {}
        # Total number of recorded wrappers that have not been finalized yet.
        self.cleanup_count = 0

    def _register_tensor_ptr(self, data_ptr):
        """Remember ``data_ptr`` and make this registration its registry entry."""
        self.ptrs.add(data_ptr)
        data_ptr_to_work[data_ptr] = self

    def _record_wrapper(self, ptr):
        """Register ``ptr`` and count one more live wrapper that refers to it."""
        self._register_tensor_ptr(ptr)
        self.ptr_alias_count[ptr] = self.ptr_alias_count.get(ptr, 0) + 1
        self.cleanup_count += 1

    def wait(self):
        """Wait on the work (only the first time) and drop our registry entries."""
        pending = self.work
        if pending is not None:
            pending.wait()
            # Clear the handle so later calls skip straight to cleanup.
            self.work = None
        self.cleanup()

    def decrement_live_tensor(self, ptr):
        """Account for one wrapper of ``ptr`` going away.

        When the last recorded wrapper disappears the work is waited on;
        otherwise only the alias count for ``ptr`` is lowered, and the registry
        entry is removed once no aliases of that pointer remain.
        """
        self.cleanup_count -= 1
        if self.cleanup_count == 0:
            self.wait()
            return
        remaining = self.ptr_alias_count[ptr] - 1
        self.ptr_alias_count[ptr] = remaining
        if remaining < 1 and data_ptr_to_work.get(ptr) is self:
            del data_ptr_to_work[ptr]

    def cleanup(self):
        """Remove every registry entry that still points at this registration."""
        # A pointer may have been re-registered to a newer registration; leave those alone.
        owned = [ptr for ptr in self.ptrs if data_ptr_to_work.get(ptr) is self]
        for ptr in owned:
            del data_ptr_to_work[ptr]
def _register_tensor_work(tensor_or_list, work_or_list):
    """Associate output tensor(s) with the work handle(s) that produce them.

    A single work handle is shared by every tensor given. A list of work
    handles is paired with the tensors position by position, each pair
    getting its own registration.
    """
    tensors = tensor_or_list if isinstance(tensor_or_list, list) else [tensor_or_list]

    if isinstance(work_or_list, list):
        # One registration per (tensor, work) pair.
        for tensor, work in zip(tensors, work_or_list):
            _WaitRegistration(work)._register_tensor_ptr(tensor.data_ptr())
        return

    # One registration shared by all tensors.
    shared_reg = _WaitRegistration(work_or_list)
    for tensor in tensors:
        shared_reg._register_tensor_ptr(tensor.data_ptr())
def _wait_reg_dec(ptr, wait_reg):
wait_reg.decrement_live_tensor(ptr)
def _register_tensor_wrapper(tensor) -> None:
    """Attach a finalizer to a wrapper so its collective is waited on when it dies.

    ``tensor`` is the wrapper object; its inner tensor is read from
    ``tensor.elem``. If no registration exists for the inner tensor's data
    pointer, a warning is emitted and nothing is registered.
    """
    # Note: we should NEVER try to trace this, because it registers runtime state during trace.
    # Instead, backends must call this themselves when implementing traced collectives.
    inner_ptr = tensor.elem.data_ptr()
    registration = data_ptr_to_work.get(inner_ptr)
    if registration is None:
        warnings.warn(
            "Trying to register finalizer to AsyncCollectiveTensor but the inner tensor is already gone"
        )
        return

    # We force the collective to be waited on if this tensor goes away, to reduce the chance of deadlocks.
    # NOTE: the callback is registered on the ACT wrapper class, for the following reasons:
    # 1. The inner tensor is referenced by the associated Work object, so it is uncollectable until
    #    the associated work object is released.
    # 2. There is an n-to-1 relationship between wrappers and the inner tensor, due to non-waitable
    #    ops like view().
    registration._record_wrapper(inner_ptr)
    weakref.finalize(tensor, _wait_reg_dec, inner_ptr, registration)
def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Wait on the collective registered for ``tensor`` (if any) and return ``tensor``."""
    registration = data_ptr_to_work.get(tensor.data_ptr())
    if registration is not None:
        registration.wait()
    return tensor
def _tensor_needs_wait(tensor: torch.Tensor) -> bool:
    """Return True if ``tensor`` still has an un-waited work handle registered.

    Accepts either a wrapper exposing ``_get_acs_underlying_tensor`` (which is
    unwrapped first) or a plain inner tensor.
    """
    if hasattr(tensor, "_get_acs_underlying_tensor"):
        inner = tensor._get_acs_underlying_tensor()
    else:
        inner = tensor
    registration = data_ptr_to_work.get(inner.data_ptr())
    if registration is None:
        return False
    return registration.work is not None
def _outstanding_wait_count() -> int:
    """Return the number of entries currently held in ``data_ptr_to_work``.

    Entries are keyed by data pointer, so a work handle registered under
    several pointers contributes one entry per pointer.
    """
    registered_entries = len(data_ptr_to_work)
    return registered_entries
def _wait_all() -> None:
    """Wait on every registration currently present in ``data_ptr_to_work``."""
    # Snapshot the values first: wait() removes entries from the registry.
    pending = list(data_ptr_to_work.values())
    for registration in pending:
        registration.wait()
def _str_to_reduce_op(reduceOp: str) -> dist.ReduceOp:
reduceOp = reduceOp.upper()
op = dist.ReduceOp.RedOpType.__members__.get(reduceOp)
if op is None:
raise ValueError(f"Invalid reduce operation {reduceOp}")
return cast(dist.ReduceOp, op)
"""
Kernel implementations (for eager runtime only) - should never be traced by torch.compile
These functions should all be bound to dispatcher ops. During tracing, the op itself should be
captured in the graph and the backend should implement the op however it prefers.
"""
def _broadcast(self, src, tag, ranks, group_size):
    """Eager kernel: launch an async broadcast from ``src`` into a clone of ``self``.

    The clone is registered against the returned work handle and handed back
    without being waited on here.
    """
    pg = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert pg is not None

    # The collective operates in place on a contiguous clone, not on the input itself.
    result = self.clone(memory_format=torch.contiguous_format)
    handle = dist.broadcast(result, src, group=pg, async_op=True)
    _register_tensor_work(result, handle)
    return result
# TODO assert if ranks has duplicated entries
def _all_reduce(self, reduceOp, tag, ranks, group_size):
    """Eager kernel: launch an async all-reduce on a contiguous clone of ``self``.

    ``reduceOp`` is the reduce-op name as a string. The clone is registered
    against the returned work handle and handed back without being waited on here.
    """
    reduce_kind = _str_to_reduce_op(reduceOp)
    pg = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert pg is not None

    # The collective operates in place on a contiguous clone, not on the input itself.
    result = self.clone(memory_format=torch.contiguous_format)
    handle = dist.all_reduce(result, op=reduce_kind, group=pg, async_op=True)
    _register_tensor_work(result, handle)
    return result
def _all_reduce_coalesced(self, reduceOp, tag, ranks, group_size):
    """Eager kernel: launch one async coalesced all-reduce over clones of ``self``.

    ``self`` is iterated as a sequence of tensors. All clones are registered
    against the single returned work handle.
    """
    reduce_kind = _str_to_reduce_op(reduceOp)
    pg = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert pg is not None

    # Each input gets its own contiguous clone for the in-place collective.
    results = [
        source.clone(memory_format=torch.contiguous_format) for source in self
    ]
    handle = dist.all_reduce_coalesced(
        results, op=reduce_kind, group=pg, async_op=True
    )
    _register_tensor_work(results, handle)
    return results
def _all_gather_into_tensor(shard, tag, ranks, group_size):
    """Eager kernel: async all-gather of ``shard`` along dim 0 into a new tensor.

    The output has the shape of ``shard`` with its first dimension multiplied
    by ``group_size``. It is registered against the work handle and returned
    without being waited on here.
    """
    # TODO add dim support?
    pg = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert pg is not None

    gathered_shape = list(shard.size())
    gathered_shape[0] = gathered_shape[0] * group_size
    out_tensor = shard.new_empty(gathered_shape)
    assert out_tensor.is_contiguous()

    # FIXME gloo doesn't support _allgather_base
    use_list_gather = dist.get_backend(pg) == dist.Backend.GLOO or shard.is_cpu
    if use_list_gather:
        # Gather into per-rank chunks, which are views of out_tensor.
        per_rank_views = list(torch.chunk(out_tensor, group_size))
        handle = dist.all_gather(per_rank_views, shard, group=pg, async_op=True)
    else:
        handle = dist.all_gather_into_tensor(
            out_tensor, shard, group=pg, async_op=True
        )

    _register_tensor_work(out_tensor, handle)
    return out_tensor
def _all_gather_into_tensor_coalesced(self, tag, rankset, group_size):
    """Eager kernel: async all-gather of every tensor in ``self`` along dim 0.

    One output tensor is allocated per input, with the first dimension scaled
    by ``group_size``. The outputs are registered against whatever the
    fallback helper returns and handed back without being waited on here.
    """
    pg = c10d._find_or_create_pg_by_ranks_and_tag(tag, rankset, group_size)
    assert pg is not None

    def _alloc_gathered(shard):
        # Same shape as the shard, except dim 0 holds group_size shards.
        gathered_shape = list(shard.size())
        gathered_shape[0] = gathered_shape[0] * group_size
        gathered = shard.new_empty(gathered_shape)
        assert gathered.is_contiguous()
        return gathered

    out_tensors = [_alloc_gathered(shard) for shard in self]

    work_list = _all_gather_into_tensor_coalesced_fallback(
        output_tensors=out_tensors,
        input_tensors=self,
        group=pg,
        async_op=True,
    )

    _register_tensor_work(out_tensors, work_list)
    return out_tensors
def _reduce_scatter_tensor(
    input: torch.Tensor,
    reduceOp: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    """Eager kernel: async reduce-scatter of ``input`` along dim 0.

    On the Gloo backend, or for CPU input, this falls back to an all-reduce of
    a clone followed by taking this rank's chunk. Otherwise a tensor with the
    first dimension floor-divided by ``group_size`` is filled by
    ``dist.reduce_scatter_tensor``. The result is registered against the work
    handle and returned without being waited on here.
    """
    # TODO add dim support?
    pg = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert pg is not None
    reduce_kind = _str_to_reduce_op(reduceOp)

    needs_fallback = dist.get_backend(pg) == dist.Backend.GLOO or input.is_cpu
    if not needs_fallback:
        scattered_shape = list(input.size())
        scattered_shape[0] = scattered_shape[0] // group_size
        out_tensor = input.new_empty(scattered_shape)
        handle = dist.reduce_scatter_tensor(
            out_tensor, input, op=reduce_kind, group=pg, async_op=True
        )
        _register_tensor_work(out_tensor, handle)
        return out_tensor

    # cpu::gloo backend does not have reduce_scatter, so we fall back to
    # all_reduce + local chunk.
    logger.warning(
        "ProcessGroupGloo does not support reduce_scatter, falling back with all reduce!"
    )
    full_copy = input.clone()
    my_rank = dist.get_rank(pg)
    handle = dist.all_reduce(full_copy, op=reduce_kind, group=pg, async_op=True)
    # This rank's slice of the fully reduced clone.
    out_tensor = full_copy.chunk(group_size, dim=0)[my_rank]
    _register_tensor_work(out_tensor, handle)
    return out_tensor
def _reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    """Eager kernel: reduce-scatter every tensor in ``inputs`` along dim 0.

    One output tensor is allocated per input, with the first dimension
    floor-divided by ``group_size``. The outputs are registered against the
    list returned by the fallback helper and then returned.
    """
    pg = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert pg is not None
    reduce_kind = _str_to_reduce_op(reduce_op)

    def _alloc_scattered(shard):
        # Same shape as the input, except dim 0 holds only this rank's share.
        scattered_shape = list(shard.size())
        scattered_shape[0] = scattered_shape[0] // group_size
        scattered = shard.new_empty(scattered_shape)
        assert scattered.is_contiguous()
        return scattered

    out_tensors = [_alloc_scattered(shard) for shard in inputs]

    # NOTE(review): async_op=False is passed here, unlike the all-gather coalesced
    # kernel which passes True - presumably the returned entries are then not live
    # async handles; confirm against c10d.reduce_scatter_tensor.
    work_list = _reduce_scatter_tensor_coalesced_fallback(
        output_tensors=out_tensors,
        input_tensors=inputs,
        op=reduce_kind,
        group=pg,
        async_op=False,
    )

    _register_tensor_work(out_tensors, work_list)
    return out_tensors
def _all_gather_into_tensor_coalesced_fallback(output_tensors, input_tensors, group, async_op=False):
# all_gather_coalesced is useless, it doesn't work under NCCL and does lots of copies under Gloo
# all_gather is useless too because it's single tensor
# NCCL's PG::all_gather with multiple tensors is broken, it only works for the multi-device setting
# and fails if you mix same-size with different-size tensor lists.
# _coalescing_manager crashed NCCL when used with all_gather_into_tensor.
if input_tensors[0].is_cpu or not async_op:
work_list = []
out_tensors_sliced = [
list(torch.chunk(out_tensor, dist.get_world_size(group)))
for out_tensor in output_tensors
]
for shard, out_tensor in zip(input_tensors, out_tensors_sliced):
work = c10d.all_gather(out_tensor, shard, group=group, async_op=async_op)
work_list.append(work)
return work_list
else:
with c10d._coalescing_manager(group=group, async_ops=True) as cm:
for in_t, out_t in zip(input_tensors, output_tensors):
dist.all_gather_into_tensor(out_t, in_t, group=group, async_op=True)
return cm
def _reduce_scatter_tensor_coalesced_fallback(output_tensors, input_tensors, op, group, async_op=False):
# All the same reasons as the all_gather fallback
work_list = []
for shard, out_tensor in zip(input_tensors, output_tensors):
work = c10d.reduce_scatter_tensor(out_tensor, shard, op=op, group=group, async_op=async_op)
work_list.append(work)
return work_list
def _all_to_all_single(
    input: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    tag: str,
    ranks: List[int],
    group_size: int,
):
    """Eager kernel: launch an async all-to-all of ``input`` into a new tensor.

    Without ``output_split_sizes`` the output has the same size as ``input``;
    with them, its first dimension is the sum of the split sizes. The output
    is registered against the work handle and returned without being waited
    on here.
    """
    pg = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)

    if output_split_sizes is None:
        out_tensor = input.new_empty(input.size())
    else:
        torch._check(input.dim() >= 1, lambda: f"Expected input to have at least 1 dim but got {input.dim()} dim")
        received_shape = list(input.size())
        received_shape[0] = sum(output_split_sizes)
        out_tensor = input.new_empty(received_shape)

    handle = c10d.all_to_all_single(
        out_tensor,
        input,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=pg,
        async_op=True,
    )
    _register_tensor_work(out_tensor, handle)
    return out_tensor