# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
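"""Helpers for initializing torch.distributed from environment variables,
with rank-aware convenience wrappers (get_rank, get_world_size, print0)."""
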
import os
import socket
import torch
from . import training_stats

#----------------------------------------------------------------------------

def is_port_in_use(port):
    # Return True if a TCP socket cannot be bound to `port` on localhost.
    # Best-effort probe only: there is an inherent race between this check
    # and the later bind performed by torch.distributed.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("localhost", port))
        except OSError:
            return True
        else:
            return False

#----------------------------------------------------------------------------

def init():
    # Fill in single-process defaults for any torch.distributed environment
    # variables that the launcher did not set.
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = 'localhost'
    if 'MASTER_PORT' not in os.environ:
        # Pick the first free port in [29500, 30500).
        for port in range(29500, 29500 + 1000):
            if not is_port_in_use(port):
                os.environ['MASTER_PORT'] = str(port)
                break
    if 'RANK' not in os.environ:
        os.environ['RANK'] = '0'
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = '0'
    if 'WORLD_SIZE' not in os.environ:
        os.environ['WORLD_SIZE'] = '1'

    # NCCL is unavailable on Windows; fall back to Gloo there.
    backend = 'gloo' if os.name == 'nt' else 'nccl'
    torch.distributed.init_process_group(backend=backend, init_method='env://')
    torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))

    # Cross-process stat synchronization is only needed with multiple processes.
    sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)

#----------------------------------------------------------------------------

def get_rank():
    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

#----------------------------------------------------------------------------

def get_world_size():
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1

#----------------------------------------------------------------------------

def should_stop():
    # Hook for requesting early shutdown; always False in this implementation.
    return False

#----------------------------------------------------------------------------

def update_progress(cur, total):
    # Hook for reporting training progress; intentionally a no-op here.
    _ = cur, total

#----------------------------------------------------------------------------

def print0(*args, **kwargs):
    # Print only on rank 0 to avoid duplicated log lines across processes.
    if get_rank() == 0:
        print(*args, **kwargs)

#----------------------------------------------------------------------------
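# Example usage (a sketch; the import path `torch_utils.distributed` is an
# assumption based on the relative import of `training_stats` above). With
# `torchrun --nproc_per_node=2 main.py`, torchrun sets RANK, LOCAL_RANK,
# WORLD_SIZE, MASTER_ADDR, and MASTER_PORT; running `python main.py` directly
# falls back to the single-process defaults chosen in init():
#
#   from torch_utils import distributed as dist
#
#   def main():
#       dist.init()
#       dist.print0('world size:', dist.get_world_size())  # rank 0 only
#
#   if __name__ == '__main__':
#       main()
#
#----------------------------------------------------------------------------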