# -*- coding: utf-8 -*- import importlib from omegaconf import OmegaConf, DictConfig, ListConfig import torch import torch.distributed as dist from typing import Union def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]: config_file = OmegaConf.load(config_file) if 'base_config' in config_file.keys(): if config_file['base_config'] == "default_base": base_config = OmegaConf.create() # base_config = get_default_config() elif config_file['base_config'].endswith(".yaml"): base_config = get_config_from_file(config_file['base_config']) else: raise ValueError(f"{config_file} must be `.yaml` file or it contains `base_config` key.") config_file = {key: value for key, value in config_file if key != "base_config"} return OmegaConf.merge(base_config, config_file) return config_file def get_obj_from_str(string, reload=False): module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def get_obj_from_config(config): if "target" not in config: raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"]) def instantiate_from_config(config, **kwargs): if "target" not in config: raise KeyError("Expected key `target` to instantiate.") cls = get_obj_from_str(config["target"]) params = config.get("params", dict()) # params.update(kwargs) # instance = cls(**params) kwargs.update(params) instance = cls(**kwargs) return instance def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_rank(): if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size() def all_gather_batch(tensors): """ Performs all_gather operation on the provided tensors. """ # Queue the gathered tensors world_size = get_world_size() # There is no need for reduction in the single-proc case if world_size == 1: return tensors tensor_list = [] output_tensor = [] for tensor in tensors: tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] dist.all_gather( tensor_all, tensor, async_op=False # performance opt ) tensor_list.append(tensor_all) for tensor_all in tensor_list: output_tensor.append(torch.cat(tensor_all, dim=0)) return output_tensor