import gc import os import re import torch import torch.distributed as dist from packaging import version from craftsman.utils.config import config_to_primitive from craftsman.utils.typing import * def parse_version(ver: str): return version.parse(ver) def get_rank(): # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, # therefore LOCAL_RANK needs to be checked first rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") for key in rank_keys: rank = os.environ.get(key) if rank is not None: return int(rank) return 0 def get_world_size(): world_size_keys = ("WORLD_SIZE", "SLURM_NTASKS", "JSM_NAMESPACE_SIZE") for key in world_size_keys: world_size = os.environ.get(key) if world_size is not None: return int(world_size) return 1 def get_device(): return torch.device(f"cuda:{get_rank()}") def load_module_weights( path, module_name=None, ignore_modules=None, map_location=None ) -> Tuple[dict, int, int]: if module_name is not None and ignore_modules is not None: raise ValueError("module_name and ignore_modules cannot be both set") if map_location is None: map_location = get_device() ckpt = torch.load(path, map_location=map_location) state_dict = ckpt["state_dict"] state_dict_to_load = state_dict if ignore_modules is not None: state_dict_to_load = {} for k, v in state_dict.items(): ignore = any( [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] ) if ignore: continue state_dict_to_load[k] = v if module_name is not None: state_dict_to_load = {} for k, v in state_dict.items(): m = re.match(rf"^{module_name}\.(.*)$", k) if m is None: continue state_dict_to_load[m.group(1)] = v return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] def C(value, epoch: int, global_step: int) -> float: if isinstance(value, int) or isinstance(value, float): pass else: value = config_to_primitive(value) if not isinstance(value, list): raise TypeError("Scalar specification only supports list, got", type(value)) if len(value) == 3: value = [0] + value assert len(value) == 4 start_step, start_value, end_value, end_step = value if isinstance(end_step, int): current_step = global_step value = start_value + (end_value - start_value) * max( min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 ) elif isinstance(end_step, float): current_step = epoch value = start_value + (end_value - start_value) * max( min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 ) return value def cleanup(): gc.collect() torch.cuda.empty_cache() tcnn.free_temporary_memory() def finish_with_cleanup(func: Callable): def wrapper(*args, **kwargs): out = func(*args, **kwargs) cleanup() return out return wrapper def _distributed_available(): return torch.distributed.is_available() and torch.distributed.is_initialized() def barrier(): if not _distributed_available(): return else: torch.distributed.barrier() def broadcast(tensor, src=0): if not _distributed_available(): return tensor else: torch.distributed.broadcast(tensor, src=src) return tensor def enable_gradient(model, enabled: bool = True) -> None: for param in model.parameters(): param.requires_grad_(enabled) def all_gather_batch(tensors): """ Performs all_gather operation on the provided tensors. """ # Queue the gathered tensors world_size = get_world_size() # There is no need for reduction in the single-proc case if world_size == 1: if isinstance(tensors, list): return tensors return tensors if not isinstance(tensors, list): is_list = False tensors = [tensors] else: is_list = True output_tensor = [] tensor_list = [] for tensor in tensors: tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] dist.all_gather( tensor_all, tensor, async_op=False # performance opt ) tensor_list.append(tensor_all) for tensor_all in tensor_list: output_tensor.append(torch.cat(tensor_all, dim=0)) if not is_list: return output_tensor[0] return output_tensor