import bisect
import functools
import logging
import numbers
import os
import signal
import sys
import traceback
import warnings

import torch

try:
    from pytorch_lightning import seed_everything
except ImportError:
    # Only handle_deterministic_config needs Lightning. Keep the rest of the
    # helpers importable without it and fail at the point of use instead.
    seed_everything = None


LOGGER = logging.getLogger(__name__)

# Environment variables whose presence marks a DDP worker process.
_DDP_ENV_VARS = ('MASTER_PORT', 'NODE_RANK', 'LOCAL_RANK', 'WORLD_SIZE')


def check_and_warn_input_range(tensor, min_value, max_value, name):
    """Emit a warning when any value of `tensor` lies outside [min_value, max_value].

    `name` is only used to label the tensor in the warning message.
    """
    actual_min = tensor.min()
    actual_max = tensor.max()
    if actual_min < min_value or actual_max > max_value:
        warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}")


def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
    """Add every value of `cur_dict` into `target` under the key `prefix + key`.

    `target` is modified in place; keys missing from it start at `default`.
    """
    for k, v in cur_dict.items():
        target_key = prefix + k
        target[target_key] = target.get(target_key, default) + v


def average_dicts(dict_list):
    """Return the key-wise sum of `dict_list` divided by (len(dict_list) + 1e-3).

    Every key is divided by the same divisor, even if it is absent from some
    of the dicts. The 1e-3 offset makes results marginally smaller than the
    exact mean; it is kept so existing numbers do not shift.
    """
    result = {}
    norm = 1e-3
    for dct in dict_list:
        sum_dict_with_prefix(result, dct, '')
        norm += 1
    for k in list(result):
        result[k] /= norm
    return result


def add_prefix_to_keys(dct, prefix):
    """Return a copy of `dct` with `prefix` prepended to every key."""
    return {prefix + k: v for k, v in dct.items()}


def set_requires_grad(module, value):
    """Set `requires_grad` to `value` on every parameter of `module`."""
    for param in module.parameters():
        param.requires_grad = value


def flatten_dict(dct):
    """Flatten nested dicts into one level, joining key parts with '_'.

    Tuple keys are joined with '_' as well; their items are converted with
    `str`, so tuples containing non-string parts (e.g. ints) are accepted.
    """
    result = {}
    for k, v in dct.items():
        if isinstance(k, tuple):
            k = '_'.join(map(str, k))
        if isinstance(v, dict):
            for sub_k, sub_v in flatten_dict(v).items():
                result[f'{k}_{sub_k}'] = sub_v
        else:
            result[k] = v
    return result


class LinearRamp:
    """Schedule that interpolates linearly from `start_value` to `end_value`.

    Returns `start_value` for i < start_iter, `end_value` for i >= end_iter
    and a linear blend of the two in between.
    """

    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
        self.start_value = start_value
        self.end_value = end_value
        self.start_iter = start_iter
        self.end_iter = end_iter

    def __call__(self, i):
        if i < self.start_iter:
            return self.start_value
        if i >= self.end_iter:
            return self.end_value
        # Here start_iter <= i < end_iter, so the denominator is positive.
        part = (i - self.start_iter) / (self.end_iter - self.start_iter)
        return self.start_value * (1 - part) + self.end_value * part


class LadderRamp:
    """Piecewise-constant schedule.

    `start_iters` are the (sorted) iterations at which the value switches;
    `values` must hold exactly one more entry than `start_iters`.
    """

    def __init__(self, start_iters, values):
        self.start_iters = start_iters
        self.values = values
        assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))

    def __call__(self, i):
        # bisect_right: an iteration equal to a boundary already belongs to
        # the next segment.
        segment_i = bisect.bisect_right(self.start_iters, i)
        return self.values[segment_i]


def get_ramp(kind='ladder', **kwargs):
    """Build a ramp by name ('linear' or 'ladder'), forwarding `kwargs` to it.

    Raises:
        ValueError: if `kind` is not a known ramp name.
    """
    if kind == 'linear':
        return LinearRamp(**kwargs)
    if kind == 'ladder':
        return LadderRamp(**kwargs)
    raise ValueError(f'Unexpected ramp kind: {kind}')


def print_traceback_handler(sig, frame):
    """Signal handler that logs the stack of the interrupted code."""
    LOGGER.warning('Received signal %s', sig)
    # Format the frame that was executing when the signal arrived, not the
    # handler's own frame. format_stack falls back to the current stack when
    # `frame` is None.
    bt = ''.join(traceback.format_stack(frame))
    LOGGER.warning('Requested stack trace:\n%s', bt)


def register_debug_signal_handlers(sig=getattr(signal, 'SIGUSR1', None), handler=print_traceback_handler):
    """Install `handler` for signal `sig` (SIGUSR1 by default).

    The default is looked up with getattr so that importing this module does
    not fail on platforms that lack SIGUSR1.

    Raises:
        ValueError: if no signal is given and SIGUSR1 is unavailable.
    """
    if sig is None:
        raise ValueError('No signal given and signal.SIGUSR1 is not available on this platform')
    LOGGER.warning('Setting signal %s handler %s', sig, handler)
    signal.signal(sig, handler)


def handle_deterministic_config(config):
    """Seed everything from `config['seed']` if it is set.

    Returns:
        True if a seed was applied, False if the config has no seed.

    Raises:
        ImportError: if a seed is configured but pytorch_lightning is missing.
    """
    seed = dict(config).get('seed', None)
    if seed is None:
        return False

    if seed_everything is None:
        raise ImportError('pytorch_lightning is required to apply the configured seed')
    seed_everything(seed)
    return True


def get_shape(t):
    """Describe the structure of `t` recursively.

    Tensors become their shape tuple, dicts and lists/tuples are traversed
    (sequences always come back as lists) and plain numbers become their type.

    Raises:
        ValueError: for any other input type.
    """
    if torch.is_tensor(t):
        return tuple(t.shape)
    elif isinstance(t, dict):
        return {n: get_shape(q) for n, q in t.items()}
    elif isinstance(t, (list, tuple)):
        return [get_shape(q) for q in t]
    elif isinstance(t, numbers.Number):
        return type(t)
    else:
        raise ValueError('unexpected type {}'.format(type(t)))


def get_has_ddp_rank():
    """Return True if any of MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE is set."""
    return any(var in os.environ for var in _DDP_ENV_VARS)


def handle_ddp_subprocess():
    """Return a decorator that prepares a main function for DDP workers.

    In a worker process the decorated function gets `hydra.run.dir=<parent cwd>`
    appended to `sys.argv` before it runs. The wrapped function's return value
    is passed through to the caller.
    """
    def main_decorator(main_func):
        @functools.wraps(main_func)
        def new_main(*args, **kwargs):
            # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
            parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
            has_parent = parent_cwd is not None
            has_rank = get_has_ddp_rank()
            assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

            if has_parent:
                # we are in the worker
                sys.argv.extend([
                    f'hydra.run.dir={parent_cwd}',
                    # 'hydra/hydra_logging=disabled',
                    # 'hydra/job_logging=disabled'
                ])
            # do nothing if this is a top-level process
            # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization

            return main_func(*args, **kwargs)
        return new_main
    return main_decorator


def handle_ddp_parent_process():
    """Record the current working directory for worker processes.

    In the top-level process (TRAINING_PARENT_WORK_DIR unset) the variable is
    set to the current working directory.

    Returns:
        True if this process is a worker (the variable was already set),
        False if it is the top-level process.
    """
    parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
    has_parent = parent_cwd is not None
    has_rank = get_has_ddp_rank()
    assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

    if parent_cwd is None:
        os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()

    return has_parent