import os
import random

import numpy as np
import torch

from .log import getLogger

# TODO: finish implementing this
LOGGER = getLogger(__name__)


def worker_init_function(worker_id):
    """Seed the python, numpy and torch PRNGs inside a DataLoader worker.

    The seed is read from ``torch.utils.data.get_worker_info()``, so this
    must run inside a worker process (where that call does not return None).

    Args:
        worker_id: index of the worker; only used to perturb the numpy seed
            when the worker seed does not fit into 32 bits.
    """
    seed = torch.utils.data.get_worker_info().seed
    np_seed = seed
    if np_seed > 2**32 - 1:
        # np.random.seed only accepts values in [0, 2**32 - 1]. The final
        # modulo keeps the folded value inside that range even when the
        # subtraction would go negative (or a large worker_id would overflow),
        # without changing any value that was already valid.
        np_seed = (seed % (2**32 - 1) - 526 + int(worker_id)) % 2**32
    np.random.seed(np_seed)
    torch.manual_seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def get_randstate_magic_numbers(device=None):
    """Use these to check that randstate advances the same across runs.

    Draws one integer from each PRNG (numpy, python ``random``, torch CPU and,
    if ``device`` is given, torch on that device). Note that calling this
    advances every PRNG it samples from.

    Args:
        device: optional torch device to additionally sample from.

    Returns:
        Tuple ``(random_int, np_int, torch_cpu_int, torch_device_int)``;
        the last entry is ``None`` when ``device`` is ``None``.
    """
    upper = int(1e6)
    # Sampling order (numpy, random, torch cpu, torch device) is kept fixed so
    # that the PRNG streams advance identically from run to run.
    np_int = np.random.randint(0, upper)
    random_int = random.randint(0, upper)
    torch_cpu_int = torch.randint(upper, (1,), device='cpu').item()
    if device is not None:
        torch_device_int = torch.randint(upper, (1,), device=device).item()
    else:
        torch_device_int = None
    return (random_int, np_int, torch_cpu_int, torch_device_int)


class PytorchRNGState(torch.nn.Module):
    """Class to save/restore PRNG states that masquerades as nn.Module for checkpointing.

    The python, numpy and torch (CPU, and CUDA when applicable) PRNG states
    are added to ``state_dict()`` and restored by ``load_state_dict()``.
    """

    # Keys under which the PRNG states are stored in the state dict.
    __RANDOM_PRNG_STATE__ = '__random_prng_state__'
    __NUMPY_PRNG_STATE__ = '__numpy_prng_state__'
    __TORCH_PRNG_STATE__ = '__torch_prng_state__'
    __CUDA_PRNG_STATE__ = '__cuda_prng_state__'

    def __init__(self, seed=42):
        """Register the seed and the "already seeded" flag as persistent buffers.

        Args:
            seed: integer seed later applied by ``seed_everything``.
        """
        super().__init__()
        self.register_buffer('initial_seed', torch.tensor(seed, dtype=torch.long), persistent=True)
        self.register_buffer('already_seeded', torch.tensor(False, dtype=torch.bool), persistent=True)

    @property
    def device(self):
        """Device on which the ``initial_seed`` buffer currently lives."""
        return self.initial_seed.device

    def _uses_cuda(self):
        """Return True when CUDA is available and this module sits on a CUDA device."""
        return torch.cuda.is_available() and 'cuda' in str(self.device)

    def seed_everything(self):
        """Seed python, numpy and torch with ``initial_seed``, at most once.

        The ``already_seeded`` flag is a persistent buffer, so it is part of
        the checkpoint and a restored instance will not re-seed.
        """
        if torch.all(self.already_seeded):
            # sticky for checkpointing; do only once
            return
        seed = int(self.initial_seed.item())
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)
        os.environ["PYTHONHASHSEED"] = str(seed)
        # keep it as buffer i.e. tensor
        self.already_seeded = torch.logical_not(self.already_seeded)
        LOGGER.info(f'Seed set to {seed}')

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the module state dict extended with the current PRNG states.

        The PRNG entries are stored under the fixed class-level keys (the
        ``prefix`` is not applied to them). The CUDA state is only included
        when CUDA is available and the module lives on a CUDA device.
        """
        # Keyword arguments: recent PyTorch deprecates passing these positionally.
        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        state_dict[self.__RANDOM_PRNG_STATE__] = random.getstate()
        state_dict[self.__NUMPY_PRNG_STATE__] = np.random.get_state()
        state_dict[self.__TORCH_PRNG_STATE__] = torch.random.get_rng_state()
        if self._uses_cuda():
            state_dict[self.__CUDA_PRNG_STATE__] = torch.cuda.get_rng_state(self.device)
        return state_dict

    def load_state_dict(self, state_dict, strict=True):
        """Restore the PRNG states and then the regular module buffers.

        Args:
            state_dict: mapping as produced by ``state_dict()``. It is not
                modified; a shallow copy is consumed instead.
            strict: when True, a mismatch between the presence of a saved CUDA
                state and the ability to restore it raises ``RuntimeError``;
                it is also forwarded to ``nn.Module.load_state_dict``.

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns.

        Raises:
            KeyError: if the python, numpy or torch PRNG state is missing.
            RuntimeError: in strict mode, if the CUDA state is missing although
                it is needed, or present although it cannot be restored.
        """
        # Work on a shallow copy so the caller's dict keeps its PRNG entries;
        # carry over the `_metadata` attribute that torch attaches, if any.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # The PRNG entries are popped so the parent class does not report them
        # as unexpected keys.
        random.setstate(state_dict.pop(self.__RANDOM_PRNG_STATE__))
        np.random.set_state(state_dict.pop(self.__NUMPY_PRNG_STATE__))
        torch.set_rng_state(state_dict.pop(self.__TORCH_PRNG_STATE__))
        LOGGER.debug('Restored PRNG state of python, numpy and torch (CPU)')

        uses_cuda = self._uses_cuda()
        has_cuda_state = self.__CUDA_PRNG_STATE__ in state_dict
        if strict:
            if uses_cuda and not has_cuda_state:
                raise RuntimeError('Error in restoring CUDA PRNG state: state missing')
            if has_cuda_state and not uses_cuda:
                raise RuntimeError('Error in restoring CUDA PRNG state: CUDA not available')
        if has_cuda_state and uses_cuda:
            torch.cuda.set_rng_state(state_dict.pop(self.__CUDA_PRNG_STATE__), self.device)
        return super().load_state_dict(state_dict, strict)