|
import contextlib |
|
import numpy as np |
|
import random |
|
import shutil |
|
import os |
|
|
|
import torch |
|
|
|
|
|
def set_seed(seed):
    """Seed every RNG used in training and make cuDNN deterministic.

    Seeds, in order, Python's ``random`` module, NumPy's global RNG and
    PyTorch's CPU RNG; when CUDA is available, the current and all CUDA
    devices are seeded too.  Finally cuDNN is switched to deterministic
    algorithms with autotuning (``benchmark``) disabled.

    Args:
        seed: value forwarded unchanged to every seeding call.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
|
def save_checkpoint(state, is_best, checkpoint_path, filename="checkpoint.pt"):
    """Serialize ``state`` into ``checkpoint_path`` with ``torch.save``.

    Args:
        state: object to serialize.
        is_best: when truthy, the written file is also copied to
            ``model_best.pt`` in the same directory.
        checkpoint_path: existing directory that receives the files.
        filename: name of the checkpoint file inside ``checkpoint_path``.
    """
    target = os.path.join(checkpoint_path, filename)
    torch.save(state, target)

    if not is_best:
        return

    best_target = os.path.join(checkpoint_path, "model_best.pt")
    shutil.copyfile(target, best_target)
|
|
|
|
|
def load_checkpoint(model, path, map_location="cpu"):
    """Load the ``"state_dict"`` entry of a checkpoint file into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the stored weights.
        path: checkpoint file whose top-level object is a mapping with a
            ``"state_dict"`` key.
        map_location: forwarded to ``torch.load``.  Defaults to ``"cpu"`` so
            a checkpoint written on a GPU machine can also be restored on a
            CPU-only one; ``load_state_dict`` then copies the tensors onto
            the model's own devices.

    Returns:
        The full loaded checkpoint, so callers can read any extra entries
        stored alongside the weights.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
    """
    # NOTE: depending on the torch version/`weights_only` default, torch.load
    # may unpickle arbitrary objects -- only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    return checkpoint
|
|
|
def log_metrics(set_name, metrics, logger):
    """Emit one ``logger.info`` line summarising a split's metrics.

    Args:
        set_name: label printed at the start of the line.
        metrics: mapping read for the ``"loss"``, ``"spec_acc"`` and
            ``"rgb_acc"`` keys; each value is formatted with five decimals.
        logger: object with an ``info(message)`` method.

    Raises:
        KeyError: if any of the three metric keys is missing.
    """
    template = "{}: Loss: {:.5f} | spec_acc: {:.5f}, rgb_acc: {:.5f}"
    values = (set_name,) + tuple(
        metrics[key] for key in ("loss", "spec_acc", "rgb_acc")
    )
    logger.info(template.format(*values))
|
|
|
|
|
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily reseed NumPy's global RNG, restoring its prior state on exit.

    With ``seed=None`` the block runs with the RNG untouched.  When extra
    seeds are supplied, they are combined with ``seed`` by hashing the tuple
    and folding the result into the range [0, 1e6).  The previous RNG state
    is restored even if the ``with`` body raises.

    NOTE(review): ``hash()`` of str/bytes elements is salted per process
    unless PYTHONHASHSEED is fixed, so a derived seed built from strings may
    differ between runs.
    """
    if seed is not None:
        if addl_seeds:
            seed = int(hash((seed,) + addl_seeds) % 1e6)
        saved_state = np.random.get_state()
        np.random.seed(seed)
        try:
            yield
        finally:
            np.random.set_state(saved_state)
    else:
        yield
|
|