import logging
import warnings
import weakref
from typing import cast, List, Optional

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
""" | |
Moved eager kernel implementations to a separate file partly for readability and partly as it is currently | |
easier in dynamo to set tracing policy on a file-by-file level. | |
Do not put code in this file that Dynamo is expected to trace into, as dynamo may disallow this whole file. | |
DEBUG/TESTING HELPERS: | |
This module includes some helpers that are quite useful when debugging or testing functional collectives: | |
_tensor_needs_wait | |
_outstanding_wait_count | |
_wait_all | |
""" | |

logger = logging.getLogger(__name__)

data_ptr_to_work = dict()
work_version = 0


class _WaitRegistration:
    def __init__(self, work):
        global work_version
        self.work = work
        self.version = work_version
        self.ptrs = set()
        self.ptr_alias_count = {}
        self.cleanup_count = 0
        work_version += 1

    def _register_tensor_ptr(self, data_ptr):
        global data_ptr_to_work
        data_ptr_to_work[data_ptr] = self
        self.ptrs.add(data_ptr)

    def _record_wrapper(self, ptr):
        self._register_tensor_ptr(ptr)
        self.ptr_alias_count.setdefault(ptr, 0)
        self.ptr_alias_count[ptr] += 1
        self.cleanup_count += 1

    def wait(self):
        if self.work is not None:
            self.work.wait()
            self.work = None
        self.cleanup()

    def decrement_live_tensor(self, ptr):
        self.cleanup_count -= 1
        if self.cleanup_count == 0:
            self.wait()
        else:
            self.ptr_alias_count[ptr] -= 1
            if self.ptr_alias_count[ptr] < 1 and data_ptr_to_work.get(ptr, None) == self:
                del data_ptr_to_work[ptr]

    def cleanup(self):
        for ptr in self.ptrs:
            if data_ptr_to_work.get(ptr, None) == self:
                del data_ptr_to_work[ptr]
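
# Lifecycle sketch (illustrative, not how callers normally drive this class): each async
# collective gets one _WaitRegistration keyed by the data_ptr of its output tensor(s);
# every wrapper aliasing that storage bumps cleanup_count, and once the last wrapper is
# finalized (or the tensor is explicitly waited on) the work object is waited and the
# registry entry is dropped. `work` below stands for a hypothetical dist.Work handle:
#
#   reg = _WaitRegistration(work)
#   reg._register_tensor_ptr(out.data_ptr())
#   reg._record_wrapper(out.data_ptr())          # one live wrapper over this storage
#   reg.decrement_live_tensor(out.data_ptr())    # last wrapper gone -> waits and cleans up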


def _register_tensor_work(tensor_or_list, work_or_list):
    if not isinstance(tensor_or_list, list):
        tensor_or_list = [tensor_or_list]
    if not isinstance(work_or_list, list):
        reg = _WaitRegistration(work_or_list)
        for tensor in tensor_or_list:
            reg._register_tensor_ptr(tensor.data_ptr())
    else:
        for tensor, work in zip(tensor_or_list, work_or_list):
            reg = _WaitRegistration(work)
            reg._register_tensor_ptr(tensor.data_ptr())


def _wait_reg_dec(ptr, wait_reg):
    wait_reg.decrement_live_tensor(ptr)


def _register_tensor_wrapper(tensor) -> None:
    global data_ptr_to_work
    data_ptr = tensor.elem.data_ptr()
    # Note: we should NEVER try to trace this, because it registers runtime stuff during trace.
    # Instead, backends must call this themselves when implementing traced collectives.
    wait_reg = data_ptr_to_work.get(data_ptr, None)
    if wait_reg is None:
        warnings.warn(
            "Trying to register finalizer to AsyncCollectiveTensor but the inner tensor is already gone"
        )
    else:
        # We force the collective to be waited on if this tensor goes away, to reduce the chance of deadlocks.
        # NOTE: we register the callback on the ACT wrapper class, for the following reasons:
        # 1. The inner tensor is referenced by the associated Work object, so it's uncollectable until we
        #    release the associated work object.
        # 2. There's an n-to-1 relationship between wrappers and the inner tensor due to non-waitable ops like view().
        wait_reg._record_wrapper(data_ptr)
        weakref.finalize(tensor, _wait_reg_dec, data_ptr, wait_reg)


def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    global data_ptr_to_work
    data_ptr = tensor.data_ptr()
    wait_reg = data_ptr_to_work.get(data_ptr)
    if wait_reg is not None:
        wait_reg.wait()
    return tensor


def _tensor_needs_wait(tensor: torch.Tensor) -> bool:
    """Returns true if ``tensor`` needs to be waited on. Works with ACS wrappers and inner tensors."""
    if hasattr(tensor, "_get_acs_underlying_tensor"):
        tensor = tensor._get_acs_underlying_tensor()
    data_ptr = tensor.data_ptr()
    wait_reg = data_ptr_to_work.get(data_ptr)
    return wait_reg is not None and wait_reg.work is not None


def _outstanding_wait_count() -> int:
    """Returns the number of outstanding work registrations that still need to be waited on."""
    return len(data_ptr_to_work)


def _wait_all() -> None:
    """Wait for all outstanding collectives."""
    for work_reg in list(data_ptr_to_work.values()):
        work_reg.wait()


def _str_to_reduce_op(reduceOp: str) -> dist.ReduceOp:
    reduceOp = reduceOp.upper()
    op = dist.ReduceOp.RedOpType.__members__.get(reduceOp)
    if op is None:
        raise ValueError(f"Invalid reduce operation {reduceOp}")
    return cast(dist.ReduceOp, op)
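
# Example (sketch): the accepted strings are the member names of dist.ReduceOp.RedOpType,
# matched case-insensitively.
#
#   _str_to_reduce_op("sum")      # -> ReduceOp.RedOpType.SUM (usable wherever a ReduceOp is expected)
#   _str_to_reduce_op("avg")      # -> ReduceOp.RedOpType.AVG
#   _str_to_reduce_op("bogus")    # -> raises ValueError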
""" | |
Kernel implementations (for eager runtime only) - should never be traced by torch.compile | |
These functions should all be bound to dispatcher ops. During tracing, the op itself should be | |
captured in the graph and the backend should implement the op however it prefers. | |
""" | |


def _broadcast(self, src, tag, ranks, group_size):
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert group is not None

    inplace_tensor = self.clone(memory_format=torch.contiguous_format)
    work = dist.broadcast(inplace_tensor, src, group=group, async_op=True)
    _register_tensor_work(inplace_tensor, work)

    return inplace_tensor


# TODO: assert that `ranks` has no duplicate entries
def _all_reduce(self, reduceOp, tag, ranks, group_size):
    op = _str_to_reduce_op(reduceOp)
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert group is not None

    inplace_tensor = self.clone(memory_format=torch.contiguous_format)
    work = dist.all_reduce(inplace_tensor, op=op, group=group, async_op=True)
    _register_tensor_work(inplace_tensor, work)

    return inplace_tensor


def _all_reduce_coalesced(self, reduceOp, tag, ranks, group_size):
    op = _str_to_reduce_op(reduceOp)
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert group is not None

    inplace_tensor_list = [t.clone(memory_format=torch.contiguous_format) for t in self]
    work = dist.all_reduce_coalesced(inplace_tensor_list, op=op, group=group, async_op=True)
    _register_tensor_work(inplace_tensor_list, work)

    return inplace_tensor_list


def _all_gather_into_tensor(shard, tag, ranks, group_size):
    # TODO add dim support?
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert group is not None
    out_size = list(shard.size())
    out_size[0] *= group_size
    out_tensor = shard.new_empty(out_size)
    assert out_tensor.is_contiguous()
    # FIXME gloo doesn't support _allgather_base
    if dist.get_backend(group) == dist.Backend.GLOO or shard.is_cpu:
        tensor_list = list(torch.chunk(out_tensor, group_size))
        work = dist.all_gather(tensor_list, shard, group=group, async_op=True)
    else:
        work = dist.all_gather_into_tensor(out_tensor, shard, group=group, async_op=True)
    _register_tensor_work(out_tensor, work)

    return out_tensor
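
# Shape sketch (illustrative): the gather concatenates along dim 0 in rank order, so with
# group_size == 4 a local shard of shape [8, 16] yields an output of shape [32, 16].
# `tag`, `ranks`, and `group_size` are whatever the dispatcher op was called with:
#
#   out = _all_gather_into_tensor(shard, tag, ranks, group_size)
#   assert out.shape[0] == shard.shape[0] * group_size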


def _all_gather_into_tensor_coalesced(self, tag, rankset, group_size):
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, rankset, group_size)
    assert group is not None

    def mk_out_tensor(shard):
        out_size = list(shard.size())
        out_size[0] *= group_size
        out_tensor = shard.new_empty(out_size)
        assert out_tensor.is_contiguous()
        return out_tensor

    out_tensors = [mk_out_tensor(t) for t in self]
    work_list = _all_gather_into_tensor_coalesced_fallback(
        output_tensors=out_tensors,
        input_tensors=self,
        group=group,
        async_op=True)

    _register_tensor_work(out_tensors, work_list)
    return out_tensors


def _reduce_scatter_tensor(
    input: torch.Tensor,
    reduceOp: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    # TODO add dim support?
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert group is not None
    op = _str_to_reduce_op(reduceOp)
    if dist.get_backend(group) == dist.Backend.GLOO or input.is_cpu:
        # The cpu:gloo backend does not have reduce_scatter, so we fall back to
        # an all_reduce followed by taking the local chunk.
        logger.warning(
            "ProcessGroupGloo does not support reduce_scatter, falling back with all reduce!"
        )
        reduction_input = input.clone()
        group_rank = dist.get_rank(group)
        work = dist.all_reduce(reduction_input, op=op, group=group, async_op=True)
        out_tensor = reduction_input.chunk(group_size, dim=0)[group_rank]
        _register_tensor_work(out_tensor, work)
    else:
        out_size = list(input.size())
        out_size[0] //= group_size
        out_tensor = input.new_empty(out_size)
        work = dist.reduce_scatter_tensor(
            out_tensor, input, op=op, group=group, async_op=True
        )
        _register_tensor_work(out_tensor, work)

    return out_tensor
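
# Shape sketch (illustrative): reduce_scatter_tensor is the inverse layout of the
# all_gather above — dim 0 of the input must be divisible by group_size and each rank
# keeps one reduced chunk. With group_size == 4 and an input of shape [32, 16], every
# rank ends up with an output of shape [8, 16]:
#
#   out = _reduce_scatter_tensor(inp, "sum", tag, ranks, group_size)
#   assert out.shape[0] == inp.shape[0] // group_size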


def _reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
    assert group is not None
    op = _str_to_reduce_op(reduce_op)

    def mk_out_tensor(shard):
        out_size = list(shard.size())
        out_size[0] //= group_size
        out_tensor = shard.new_empty(out_size)
        assert out_tensor.is_contiguous()
        return out_tensor

    out_tensors = [mk_out_tensor(t) for t in inputs]
    work_list = _reduce_scatter_tensor_coalesced_fallback(
        output_tensors=out_tensors,
        input_tensors=inputs,
        op=op,
        group=group,
        async_op=False)

    _register_tensor_work(out_tensors, work_list)
    return out_tensors


def _all_gather_into_tensor_coalesced_fallback(output_tensors, input_tensors, group, async_op=False):
    # all_gather_coalesced is useless here: it doesn't work under NCCL and does lots of copies under Gloo.
    # all_gather is unsuitable too because it is single-tensor.
    # NCCL's PG::all_gather with multiple tensors is broken: it only works for the multi-device setting
    # and fails if you mix same-size with different-size tensor lists.
    # _coalescing_manager crashed NCCL when used with all_gather_into_tensor.
    if input_tensors[0].is_cpu or not async_op:
        work_list = []
        out_tensors_sliced = [
            list(torch.chunk(out_tensor, dist.get_world_size(group)))
            for out_tensor in output_tensors
        ]
        for shard, out_tensor in zip(input_tensors, out_tensors_sliced):
            work = c10d.all_gather(out_tensor, shard, group=group, async_op=async_op)
            work_list.append(work)
        return work_list
    else:
        with c10d._coalescing_manager(group=group, async_ops=True) as cm:
            for in_t, out_t in zip(input_tensors, output_tensors):
                dist.all_gather_into_tensor(out_t, in_t, group=group, async_op=True)
        return cm


def _reduce_scatter_tensor_coalesced_fallback(output_tensors, input_tensors, op, group, async_op=False):
    # All the same reasons as the all_gather fallback
    work_list = []
    for shard, out_tensor in zip(input_tensors, output_tensors):
        work = c10d.reduce_scatter_tensor(out_tensor, shard, op=op, group=group, async_op=async_op)
        work_list.append(work)
    return work_list


def _all_to_all_single(
    input: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)

    if output_split_sizes is not None:
        torch._check(
            input.dim() >= 1,
            lambda: f"Expected input to have at least 1 dim but got {input.dim()} dim",
        )
        out_size = list(input.size())
        out_size[0] = sum(output_split_sizes)
        out_tensor = input.new_empty(out_size)
    else:
        out_tensor = input.new_empty(input.size())

    work = c10d.all_to_all_single(
        out_tensor, input, output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes, group=group, async_op=True
    )
    _register_tensor_work(out_tensor, work)

    return out_tensor
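
# Split-size sketch (illustrative): with group_size == 2, input_split_sizes == [3, 1]
# means this rank sends its first 3 rows to rank 0 and the next 1 row to rank 1, while
# output_split_sizes == [2, 5] means it expects 2 rows from rank 0 and 5 from rank 1,
# so the output has sum(output_split_sizes) == 7 rows. Passing None for both keeps an
# even split and an output shaped like the input:
#
#   out = _all_to_all_single(inp, [2, 5], [3, 1], tag, ranks, group_size)
#   assert out.shape[0] == 7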