# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

import logging
import os
from datetime import timedelta
from typing import List

import torch
import torch.distributed as dist
import torch.multiprocessing

logger = logging.getLogger(__name__)

def is_dist_initialized() -> bool:
    """Return True if torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank() -> int:
    """Global rank of the current process, or 0 if not running distributed."""
    if not is_dist_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """Per-node rank of the current process, read from the LOCAL_RANK environment variable."""
    if not is_dist_initialized():
        return 0
    return int(os.environ["LOCAL_RANK"])


def get_world_size() -> int:
    """Total number of processes, or 1 if not running distributed."""
    if not is_dist_initialized():
        return 1
    return dist.get_world_size()


def is_main_process() -> bool:
    """True only on the rank-0 (master) process."""
    return get_rank() == 0

def init_distributed(loggers: List[logging.Logger]) -> None:
    """Initializes the distributed backend from torchrun-style environment variables."""
    torch.multiprocessing.set_start_method("spawn")
    if "RANK" not in os.environ:
        logger.error(
            "Cannot init distributed context, as environment variables are not set."
        )
        return
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    logger.info(
        f"Rank={rank} local rank={local_rank}, world_size={world_size}, is_master={rank == 0}"
    )
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=180),
    )
    logger.info(f"Setting cuda:{local_rank} as main device")
    if not is_main_process():
        # Mute the given loggers on non-master ranks to avoid duplicated log output.
        for to_mute in loggers:
            to_mute.setLevel(logging.ERROR)
    torch.cuda.set_device(local_rank)
    dist.barrier()
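

# Minimal usage sketch: how these helpers might be wired into an entry point
# launched with `torchrun --nproc_per_node=<N> train.py`. The "train" logger
# name and the placeholder training step below are hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    train_logger = logging.getLogger("train")
    # Initializes NCCL and mutes the given loggers on non-master ranks.
    init_distributed([logger, train_logger])
    if is_main_process():
        train_logger.info("Running on %d processes", get_world_size())
    device = torch.device(f"cuda:{get_local_rank()}")
    # ... build the model/dataloader on `device` and run training here ...
    dist.destroy_process_group()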