File size: 1,393 Bytes
1f418ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import time
import torch

from mmengine.dist import (broadcast, get_dist_info, init_dist, is_distributed, get_local_rank)
from mmengine.utils.dl_utils import (set_multi_processing)
def setup_env(env_cfg, distributed, launcher):
    """Set up the runtime environment: cudnn, multiprocessing and dist init.

    An example of ``env_cfg``::

        env_cfg = dict(
            cudnn_benchmark=True,
            mp_cfg=dict(
                mp_start_method='fork',
                opencv_num_threads=0
            ),
            dist_cfg=dict(backend='nccl', timeout=1800),
            resource_limit=4096
        )

    Note:
        Only ``cudnn_benchmark``, ``mp_cfg`` and ``dist_cfg`` are consumed
        here; other keys such as ``resource_limit`` are ignored by this
        function and must be handled elsewhere.

    Args:
        env_cfg (dict): Config for setting environment.
        distributed (bool): Whether to set up a distributed environment.
        launcher (str): Launcher type forwarded to ``init_dist`` (e.g.
            ``'pytorch'``, ``'slurm'``, ``'mpi'``). Only used when
            ``distributed`` is True and the process group is not already
            initialized.

    Returns:
        tuple: ``(rank, world_size, timestamp)`` where ``timestamp`` is a
        ``'%Y%m%d_%H%M%S'`` string derived from rank 0's wall clock and
        shared by all processes via broadcast.
    """
    # Let cudnn autotune conv algorithms; beneficial for fixed-size inputs.
    if env_cfg.get('cudnn_benchmark'):
        torch.backends.cudnn.benchmark = True

    # Configure multiprocessing (start method, thread limits, ...).
    mp_cfg: dict = env_cfg.get('mp_cfg', {})
    set_multi_processing(**mp_cfg, distributed=distributed)

    # Init distributed env first, since logger depends on the dist info.
    if distributed and not is_distributed():
        dist_cfg: dict = env_cfg.get('dist_cfg', {})
        init_dist(launcher, **dist_cfg)

    rank, world_size = get_dist_info()

    # Broadcast rank 0's timestamp so every process formats the same
    # string (e.g. for a shared work-directory name). float64 keeps
    # enough precision for epoch seconds.
    timestamp_tensor = torch.tensor(time.time(), dtype=torch.float64)
    broadcast(timestamp_tensor)
    timestamp = time.strftime('%Y%m%d_%H%M%S',
                              time.localtime(timestamp_tensor.item()))
    return rank, world_size, timestamp