"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def setup_for_distributed(is_master):
    """
    Disable printing when not in the master process, unless force=True is passed.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # Only the master process prints by default; any rank can opt in
        # with print(..., force=True).
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    # Replace the built-in print for the whole process.
    __builtin__.print = print

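# Usage sketch for the print override above: after calling
# setup_for_distributed(is_master), plain print(...) is silent on non-master
# ranks, while print(..., force=True) still prints on every rank, e.g.:
#
#     setup_for_distributed(get_rank() == 0)
#     print("only rank 0 sees this")
#     print("every rank sees this", force=True)
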
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0

def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # Launched with torchrun / torch.distributed.launch.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        # Launched under SLURM.
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(
            args.rank, args.world_size, args.dist_url
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        # Very long timeout so that slow, one-off work on rank 0 (e.g. dataset
        # download or decompression) does not trip the collective timeout.
        timeout=datetime.timedelta(days=365),
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

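# Minimal usage sketch, under the assumption of an argparse-style namespace
# that carries a `dist_url` attribute (e.g. "env://") and a launch via
# `torchrun --nproc_per_node=N train.py`; argument names other than the ones
# read above are illustrative only:
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--dist_url", default="env://")
#     args = parser.parse_args()
#     init_distributed_mode(args)  # sets args.rank, args.world_size, args.gpu, args.distributed
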
def get_dist_info():
    if torch.__version__ < "1.0":
        # Very old PyTorch exposed initialization state only via this private flag.
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        # Non-distributed setting: behave like a single-process "world".
        rank = 0
        world_size = 1
    return rank, world_size

def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        # Run the wrapped function on rank 0 only; other ranks return None.
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper

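# Usage sketch for the decorator above: the wrapped function runs on rank 0
# only and the wrapper returns None on every other rank. `save_checkpoint`
# here is a hypothetical helper, not defined in this module:
#
#     @main_process
#     def save_checkpoint(state, path):
#         torch.save(state, path)
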
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # Resolve the path that timm's hub helper will have used for this URL.
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        # Make every other rank wait until rank 0 has finished downloading.
        dist.barrier()

    return get_cached_file_path()
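

# Hedged, self-contained smoke test for the single-process fallback path; it
# only runs when the module is executed directly. The commented download call
# uses a placeholder URL, not a real checkpoint.
if __name__ == "__main__":
    print("world size:", get_world_size())        # 1 without a distributed launcher
    print("rank:", get_rank())                    # 0 without a distributed launcher
    print("is main process:", is_main_process())  # True in the single-process case

    # cached = download_cached_file(
    #     "https://example.com/model.pth",  # placeholder URL
    #     check_hash=False,
    #     progress=True,
    # )
    # print("cached checkpoint at:", cached)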