Spaces:
Runtime error
Runtime error
File size: 1,393 Bytes
1f418ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import time
import torch
from mmengine.dist import (broadcast, get_dist_info, init_dist, is_distributed, get_local_rank)
from mmengine.utils.dl_utils import (set_multi_processing)
def setup_env(env_cfg, distributed, launcher):
    """Configure the runtime environment and report distributed info.

    Only the ``cudnn_benchmark``, ``mp_cfg`` and ``dist_cfg`` keys of
    ``env_cfg`` are read by this function; any other key (such as
    ``resource_limit`` in the example below) is not consulted here.

    An example of ``env_cfg``::

        env_cfg = dict(
            cudnn_benchmark=True,
            mp_cfg=dict(
                mp_start_method='fork',
                opencv_num_threads=0
            ),
            dist_cfg=dict(backend='nccl', timeout=1800),
            resource_limit=4096
        )

    Args:
        env_cfg (dict): Config for setting environment.
        distributed: Truthy when a distributed run is requested. It is
            forwarded to ``set_multi_processing`` and gates the call to
            ``init_dist``.
        launcher: Passed through as the first positional argument of
            ``init_dist`` when the distributed environment is initialised.

    Returns:
        tuple: ``(rank, world_size, timestamp)`` where ``rank`` and
        ``world_size`` are the values returned by ``get_dist_info()`` and
        ``timestamp`` is a ``'%Y%m%d_%H%M%S'`` string rendered in local
        time from the broadcast timestamp.
    """
    if env_cfg.get('cudnn_benchmark'):
        torch.backends.cudnn.benchmark = True

    # Entries of ``mp_cfg`` are forwarded verbatim as keyword arguments.
    set_multi_processing(**env_cfg.get('mp_cfg', {}), distributed=distributed)

    # init distributed env first, since logger depends on the dist info.
    # ``is_distributed()`` is only queried when ``distributed`` is truthy.
    needs_init = distributed and not is_distributed()
    if needs_init:
        init_dist(launcher, **env_cfg.get('dist_cfg', {}))

    rank, world_size = get_dist_info()

    # broadcast timestamp from 0 process to other processes, so the
    # formatted string below is derived from the same value everywhere.
    now = torch.tensor(time.time(), dtype=torch.float64)
    broadcast(now)
    local_struct = time.localtime(now.item())
    stamp = time.strftime('%Y%m%d_%H%M%S', local_struct)

    return rank, world_size, stamp
|