# Source: commit 8aa9c9a ("add app") by zcxu-eric.
# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import os
import socket
import warnings
import torch
from torch import distributed as dist
def distributed_init(args):
    """Set up the default NCCL process group described by ``args``.

    ``args`` must provide ``rank``, ``world_size`` and ``init_method``. Its
    ``rank`` attribute is overwritten with the rank that ``torch.distributed``
    reports, and that value is returned.

    If a process group already exists, a warning is emitted and the existing
    group's rank is used; no second initialisation is attempted.
    """
    if dist.is_initialized():
        warnings.warn("Distributed is already initialized, cannot initialize twice!")
        args.rank = dist.get_rank()
        return args.rank

    print(f"Distributed Init (Rank {args.rank}): {args.init_method}")
    dist.init_process_group(
        backend="nccl",
        init_method=args.init_method,
        world_size=args.world_size,
        rank=args.rank,
    )
    host = socket.gethostname()
    print(f"Initialized Host {host} as Rank {args.rank}")

    env_missing = "MASTER_ADDR" not in os.environ or "MASTER_PORT" not in os.environ
    if env_missing:
        # Derive MASTER_ADDR / MASTER_PORT from an init_method shaped like
        # "tcp://<host>:<port>" (original note: set for onboxdataloader support).
        scheme_and_rest = args.init_method.split("//")
        assert len(scheme_and_rest) == 2, (
            "host url for distributed should be split by '//' "
            + "into exactly two elements"
        )
        host_and_port = scheme_and_rest[1].split(":")
        assert (
            len(host_and_port) == 2
        ), "host url should be of the form <host_url>:<host_port>"
        master_addr, master_port = host_and_port
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = master_port

    # A throw-away all-reduce makes NCCL build its communicator right away.
    dist.all_reduce(torch.zeros(1).cuda())
    suppress_output(is_master())
    args.rank = dist.get_rank()
    return args.rank
def get_rank():
    """Return this process's rank in the default process group.

    Returns 0 when ``torch.distributed`` is unavailable in this build or when
    no default process group has been initialised, so single-process runs
    behave like rank 0.
    """
    # NCCL availability is deliberately not checked. A group initialised with
    # another backend (e.g. gloo) still has meaningful ranks, and reporting 0
    # on every process would make all of them act as master.
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()
def is_master():
    """Return True when ``get_rank()`` reports rank 0 (also the case when not distributed)."""
    rank = get_rank()
    return rank == 0
def synchronize():
    """Block until every process in the default group reaches this barrier.

    Does nothing when ``torch.distributed`` is unavailable in this build or
    when no process group has been initialised.
    """
    # Check availability first, mirroring get_rank(): on builds compiled
    # without distributed support, is_initialized() may not be defined.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
def suppress_output(is_master):
    """Suppress printing on the current device. Force printing with `force=True`.

    Rebinds ``builtins.print`` and ``warnings.warn`` process-wide. When
    ``is_master`` is falsy, calls to either are dropped unless the caller
    passes ``force=True``. When it is truthy, calls pass straight through. The
    ``force`` keyword is always stripped before delegating. Also installs a
    ``"once"`` filter for ``UserWarning``.

    Args:
        is_master: Whether the current process should keep its output.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

    import warnings

    builtin_warn = warnings.warn

    def warn(message, category=None, stacklevel=1, *args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            # This wrapper is an extra stack frame. Bump stacklevel so the
            # warning is attributed to our caller rather than to this line;
            # attribution drives module-based filters such as the default
            # DeprecationWarning handling for __main__.
            builtin_warn(message, category, stacklevel + 1, *args, **kwargs)

    # Log warnings only once
    warnings.warn = warn
    warnings.simplefilter("once", UserWarning)