# python3.7
"""Misc utility functions."""

import os
import sys
import subprocess
from importlib import import_module
import argparse

from easydict import EasyDict

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

__all__ = [
    'init_dist', 'bool_parser', 'DictAction', 'parse_config', 'update_config'
]


def init_dist(launcher, backend='nccl', **kwargs):
    """Initializes distributed environment.

    Args:
        launcher: Name of the job launcher. Supported values are `pytorch`
            and `slurm`.
        backend: Backend passed to `torch.distributed.init_process_group()`.
            (default: `nccl`)
        **kwargs: Additional keyword arguments forwarded to
            `torch.distributed.init_process_group()` for either launcher.

    Raises:
        NotImplementedError: If `launcher` is not supported.
    """
    # Only set the start method when none has been chosen yet; setting it a
    # second time (without `force=True`) raises a RuntimeError.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        # `RANK` must already be present in the environment (read below).
        rank = int(os.environ['RANK'])
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        dist.init_process_group(backend=backend, **kwargs)
    elif launcher == 'slurm':
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(proc_id % num_gpus)
        # Use the first host of the allocation as the master address.
        addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
        port = os.environ.get('PORT', 29500)
        # Translate SLURM variables into the variables read by the default
        # `env://` initialization of `init_process_group()`.
        os.environ['MASTER_PORT'] = str(port)
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        # Forward `kwargs` here too (previously silently dropped).
        dist.init_process_group(backend=backend, **kwargs)
    else:
        raise NotImplementedError(f'Not implemented launcher type: '
                                  f'`{launcher}`!')


def bool_parser(arg):
    """Parses an argument to boolean.

    Booleans are returned unchanged. Any other value is converted with
    `str()` and matched case-insensitively: `1/true/t/yes/y` give `True`,
    and `0/false/f/no/n` give `False`.

    Raises:
        argparse.ArgumentTypeError: If `arg` cannot be interpreted.
    """
    if isinstance(arg, bool):
        return arg
    # `str()` lets non-string inputs such as the integer 1 be handled instead
    # of failing with an AttributeError on `.lower()`.
    lowered = str(arg).lower()
    if lowered in ['1', 'true', 't', 'yes', 'y']:
        return True
    if lowered in ['0', 'false', 'f', 'no', 'n']:
        return False
    raise argparse.ArgumentTypeError(f'`{arg}` cannot be converted to boolean!')


class DictAction(argparse.Action):
    """Argparse action to split an argument into key-value.

    Each item is expected to look like `KEY=VAL` or `KEY=V1,V2,...`. Values
    are converted to `int`, `float` or `bool` when possible, and
    comma-separated values become a list. For example, the items
    `['a=1', 'b=x,2']` produce `{'a': 1, 'b': ['x', 2]}`.

    NOTE: This class is borrowed from
    https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Converts `val` to int, float or bool if possible, else keeps it."""
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return val.lower() == 'true'
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        # Without `nargs`, argparse passes a single string; wrap it so it is
        # not iterated character by character.
        if isinstance(values, str):
            values = [values]
        options = {}
        for kv in values:
            key, sep, val = kv.partition('=')
            if not sep:
                # Report through argparse instead of crashing with an opaque
                # unpacking error.
                raise argparse.ArgumentError(
                    self, f'expected `KEY=VALUE`, but got `{kv}`!')
            val = [self._parse_int_float_bool(v) for v in val.split(',')]
            if len(val) == 1:
                val = val[0]
            options[key] = val
        setattr(namespace, self.dest, options)


def parse_config(config_file):
    """Parses configuration from python file.

    Every module-level name of the file that does not start with `__` is
    collected into an `EasyDict`. During the import, the file's directory is
    put at the front of `sys.path`, so the config can import sibling files.

    Args:
        config_file: Path to a `.py` configuration file.

    Returns:
        An `EasyDict` holding the configuration.
    """
    assert os.path.isfile(config_file), (
        f'Config file `{config_file}` does not exist!')
    directory = os.path.dirname(config_file)
    filename = os.path.basename(config_file)
    module_name, extension = os.path.splitext(filename)
    assert extension == '.py', (
        f'Config file `{config_file}` must be a `.py` file!')

    # Temporarily hide any already-imported module with the same name.
    # Otherwise `import_module()` would return that cached module instead of
    # loading `config_file`, and the cleanup below would evict it.
    shadowed_module = sys.modules.pop(module_name, None)
    sys.path.insert(0, directory)
    try:
        module = import_module(module_name)
    finally:
        # Restore `sys.path` / `sys.modules` even when the import fails.
        sys.path.pop(0)
        sys.modules.pop(module_name, None)
        if shadowed_module is not None:
            sys.modules[module_name] = shadowed_module

    config = EasyDict()
    for key, value in module.__dict__.items():
        if key.startswith('__'):
            continue
        config[key] = value
    return config


def update_config(config, new_config):
    """Updates configuration in a hierarchical level.

    For key-value pair {'a.b.c.d': v} in `new_config`, the `config` will be
    updated by

    config['a']['b']['c']['d'] = v

    `config` is modified in place and also returned. Intermediate keys must
    already exist; they are not created.

    Raises:
        KeyError: If an intermediate key of a dotted path is missing.
    """
    if new_config is None:
        return config
    assert isinstance(config, dict)
    assert isinstance(new_config, dict)
    for key, val in new_config.items():
        *parent_keys, leaf_key = key.split('.')
        node = config
        for sub_key in parent_keys:
            try:
                node = node[sub_key]
            except KeyError:
                raise KeyError(f'Key `{sub_key}` (from `{key}`) is not found '
                               f'in config!') from None
        node[leaf_key] = val
    return config