import contextlib
import os
import random
import shutil

import numpy as np
import torch


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and CUDA).

    When CUDA is available, cuDNN is also switched to deterministic mode and
    its auto-tuner (``benchmark``) is disabled, trading speed for run-to-run
    reproducibility.

    Args:
        seed: Integer seed passed to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all seeds every visible device, including the current one,
        # so a separate torch.cuda.manual_seed call is unnecessary.
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def save_checkpoint(state, is_best, checkpoint_path, filename="checkpoint.pt"):
    """Serialize ``state`` into ``checkpoint_path`` and optionally mark it as best.

    Args:
        state: Object handed to ``torch.save``. ``load_checkpoint`` expects it
            to be a mapping with a ``"state_dict"`` entry.
        is_best: If true, the written file is also copied to ``model_best.pt``
            in the same directory.
        checkpoint_path: Directory to write into; created if it does not exist.
        filename: Name of the checkpoint file inside ``checkpoint_path``.
    """
    if checkpoint_path:
        # Avoid FileNotFoundError on the first save into a fresh run directory.
        # (An empty path means the current directory, which always exists.)
        os.makedirs(checkpoint_path, exist_ok=True)
    file_path = os.path.join(checkpoint_path, filename)
    torch.save(state, file_path)
    if is_best:
        # Copy the bytes already on disk instead of serializing a second time.
        shutil.copyfile(file_path, os.path.join(checkpoint_path, "model_best.pt"))


def load_checkpoint(model, path, map_location=None):
    """Load the ``"state_dict"`` entry of a checkpoint file into ``model``.

    Args:
        model: Object whose ``load_state_dict`` receives the stored weights.
        path: Path to a file written by ``save_checkpoint`` (or any
            ``torch.save``-d mapping with a ``"state_dict"`` key).
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load a
            checkpoint that was saved from GPU tensors on a machine without
            CUDA. ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        The full loaded checkpoint object, so callers can read any extra
        entries stored alongside the weights.
    """
    # NOTE: torch.load may unpickle arbitrary objects depending on the torch
    # version / weights_only setting; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    return checkpoint


def log_metrics(set_name, metrics, logger):
    """Log loss and the two accuracy metrics for one dataset split.

    Args:
        set_name: Label printed at the start of the line (e.g. a split name).
        metrics: Mapping with numeric ``"loss"``, ``"spec_acc"`` and
            ``"rgb_acc"`` entries.
        logger: Object exposing ``info(str)`` (e.g. a ``logging.Logger``).
    """
    logger.info(
        "{}: Loss: {:.5f} | spec_acc: {:.5f}, rgb_acc: {:.5f}".format(
            set_name, metrics["loss"], metrics["spec_acc"], metrics["rgb_acc"]
        )
    )


@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global RNG and restore its previous state on exit.

    Args:
        seed: Base seed. ``None`` makes the context manager a no-op: the global
            RNG is neither reseeded nor restored.
        *addl_seeds: Extra values combined with ``seed`` so that, for example,
            per-epoch or per-index streams differ while staying reproducible.

    Note:
        Extra seeds are combined with Python's built-in ``hash``. That is
        stable across processes for ints, but salted per process for
        ``str``/``bytes`` (see ``PYTHONHASHSEED``). Pass integers if results
        must be reproducible across runs.
    """
    if seed is None:
        yield
        return
    if addl_seeds:
        # Fold the tuple into one seed in [0, 1e6): float modulo by a positive
        # divisor is always non-negative, which np.random.seed requires.
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)