# Web file-viewer residue (not part of the module): wyysf's picture | e814557 | raw | history blame | No virus | 4.77 kB
import gc
import os
import re
import torch
import torch.distributed as dist
from packaging import version
from craftsman.utils.config import config_to_primitive
from craftsman.utils.typing import *
def parse_version(ver: str):
    """Parse the version string ``ver`` with ``packaging.version.parse``.

    Returns whatever ``version.parse`` returns for the given string.
    """
    parsed = version.parse(ver)
    return parsed
def get_rank():
    """Return this process's rank as read from the environment.

    The variables ``RANK``, ``LOCAL_RANK``, ``SLURM_PROCID`` and
    ``JSM_NAMESPACE_RANK`` are consulted in that order; the first one that is
    set is converted with ``int`` and returned. Returns 0 when none is set.
    """
    # SLURM_PROCID can be set even if SLURM is not managing the
    # multiprocessing, so it is only consulted after RANK / LOCAL_RANK.
    env_names = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    raw_values = (os.environ.get(name) for name in env_names)
    first_set = next((raw for raw in raw_values if raw is not None), None)
    return 0 if first_set is None else int(first_set)
def get_world_size():
    """Return the number of processes as read from the environment.

    ``WORLD_SIZE``, ``SLURM_NTASKS`` and ``JSM_NAMESPACE_SIZE`` are consulted
    in that order; the first one that is set is converted with ``int`` and
    returned. Returns 1 when none is set.
    """
    env_names = ("WORLD_SIZE", "SLURM_NTASKS", "JSM_NAMESPACE_SIZE")
    raw_values = (os.environ.get(name) for name in env_names)
    first_set = next((raw for raw in raw_values if raw is not None), None)
    return 1 if first_set is None else int(first_set)
def get_device():
    """Return the CUDA device this process should use.

    A CUDA ordinal is local to a node, so ``LOCAL_RANK`` is preferred when it
    is set. ``get_rank()`` checks the global ``RANK`` first, which on a
    multi-node job can exceed the number of GPUs on the node and yield an
    invalid device. When ``LOCAL_RANK`` is not set the previous behaviour is
    kept and the index comes from ``get_rank()``.

    Returns:
        A ``torch.device`` of the form ``cuda:<index>``.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    device_index = int(local_rank) if local_rank is not None else get_rank()
    return torch.device(f"cuda:{device_index}")
def load_module_weights(
    path, module_name=None, ignore_modules=None, map_location=None
) -> Tuple[dict, int, int]:
    """Load a checkpoint and return a (possibly filtered) state dict.

    Args:
        path: Checkpoint location, passed straight to ``torch.load``.
        module_name: If given, keep only entries whose key starts with
            ``"<module_name>."`` and strip that prefix from the returned keys.
        ignore_modules: If given, an iterable of module names; entries whose
            key starts with ``"<name>."`` for any of them are dropped.
        map_location: Forwarded to ``torch.load``; defaults to
            ``get_device()`` when ``None``.

    Returns:
        ``(state_dict, epoch, global_step)`` where the last two are read from
        the checkpoint's ``"epoch"`` and ``"global_step"`` entries.

    Raises:
        ValueError: If both ``module_name`` and ``ignore_modules`` are set.
    """
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]
    state_dict_to_load = state_dict

    if ignore_modules is not None:
        # str.startswith accepts a tuple, so one call covers every prefix.
        ignored_prefixes = tuple(name + "." for name in ignore_modules)
        state_dict_to_load = {
            key: value
            for key, value in state_dict.items()
            if not key.startswith(ignored_prefixes)
        }

    if module_name is not None:
        # Literal prefix match: interpolating module_name into a regex would
        # let the "." in a dotted name such as "a.b" match any character.
        prefix = f"{module_name}."
        state_dict_to_load = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
def C(value, epoch: int, global_step: int) -> float:
    """Evaluate a scalar that may be constant or linearly scheduled.

    Args:
        value: Either an ``int``/``float``, returned unchanged, or a
            specification that ``config_to_primitive`` turns into a list
            ``[start_step, start_value, end_value, end_step]``. A 3-element
            list omits ``start_step``, which then defaults to 0.
        epoch: Current epoch; used as the position when ``end_step`` is a
            ``float``.
        global_step: Current global step; used as the position when
            ``end_step`` is an ``int``.

    Returns:
        ``start_value`` linearly interpolated towards ``end_value``, with the
        progress clamped to ``[0, 1]``.

    Raises:
        TypeError: If the converted specification is not a list.
    """
    if isinstance(value, (int, float)):
        return value

    value = config_to_primitive(value)
    if not isinstance(value, list):
        raise TypeError(
            f"Scalar specification only supports list, got {type(value)}"
        )
    if len(value) == 3:
        value = [0] + value
    assert len(value) == 4
    start_step, start_value, end_value, end_step = value

    # The type of end_step selects the clock the schedule runs on.
    if isinstance(end_step, int):
        current_step = global_step
    elif isinstance(end_step, float):
        current_step = epoch
    else:
        # Unchanged from the original: an end_step of any other type leaves
        # the specification untouched.
        return value

    span = end_step - start_step
    if span == 0:
        # Zero-length schedule: act as a step function instead of dividing
        # by zero.
        progress = 1.0 if current_step >= end_step else 0.0
    else:
        progress = max(min(1.0, (current_step - start_step) / span), 0.0)
    return start_value + (end_value - start_value) * progress
def cleanup():
    """Release memory that is no longer needed.

    Runs the Python garbage collector, empties PyTorch's CUDA caching
    allocator, and, when a module-level ``tcnn`` binding exists, calls its
    ``free_temporary_memory()``.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # `tcnn` is not among this module's visible imports, so the bare
    # `tcnn.free_temporary_memory()` call raised NameError. Only call it when
    # the name is actually bound at module level (presumably tiny-cuda-nn --
    # TODO confirm where it is meant to be imported from).
    tcnn_module = globals().get("tcnn")
    if tcnn_module is not None:
        tcnn_module.free_temporary_memory()
def finish_with_cleanup(func: Callable):
    """Decorator that calls ``cleanup()`` after ``func`` returns.

    The wrapped function's return value is passed through unchanged.
    ``cleanup()`` runs only when ``func`` returns normally; if ``func``
    raises, the exception propagates without cleanup.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that keeps ``func``'s name, docstring and other metadata.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper
def _distributed_available():
    """Report whether ``torch.distributed`` is available and initialised."""
    # Check availability first and stop there if it fails, so
    # is_initialized() is only queried once is_available() has succeeded.
    if not torch.distributed.is_available():
        return False
    return torch.distributed.is_initialized()
def barrier():
    """Call ``torch.distributed.barrier()`` when distributed is initialised.

    Does nothing when ``_distributed_available()`` is false. Always returns
    ``None``.
    """
    if _distributed_available():
        torch.distributed.barrier()
def broadcast(tensor, src=0):
    """Broadcast ``tensor`` from rank ``src`` and return it.

    When ``_distributed_available()`` is false the tensor is returned
    untouched; otherwise ``torch.distributed.broadcast`` is called on it
    before the same object is returned.
    """
    if _distributed_available():
        torch.distributed.broadcast(tensor, src=src)
    return tensor
def enable_gradient(model, enabled: bool = True) -> None:
    """Set ``requires_grad`` to ``enabled`` on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()``.
        enabled: Flag passed to each parameter's ``requires_grad_``.
    """
    for weight in model.parameters():
        weight.requires_grad_(requires_grad=enabled)
def all_gather_batch(tensors):
    """
    Performs all_gather operation on the provided tensors.

    Args:
        tensors: A single tensor or a list of tensors.

    Returns:
        The input unchanged when no process group is initialised or the
        group has a single member. Otherwise each tensor is gathered from
        every rank and concatenated along dim 0; a list input yields a list,
        a single tensor yields a single tensor.
    """
    # Same guard as barrier()/broadcast(): environment variables such as
    # WORLD_SIZE or SLURM_NTASKS can be set without a process group existing,
    # in which case dist.all_gather would raise.
    if not _distributed_available():
        return tensors

    # all_gather needs one buffer per member of the process group, so take
    # the size from the group itself.
    world_size = dist.get_world_size()
    if world_size == 1:
        # There is no need for a gather in the single-proc case.
        return tensors

    is_list = isinstance(tensors, list)
    batch = tensors if is_list else [tensors]

    gathered = []
    for tensor in batch:
        # Every buffer is overwritten by all_gather, so no initial fill is
        # needed.
        buffers = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(buffers, tensor, async_op=False)
        gathered.append(torch.cat(buffers, dim=0))

    return gathered if is_list else gathered[0]