# Provenance (from the file page this was copied from): uploaded by
# SivaResearch in "demo", revision b6d5990, 1.46 kB.
import contextlib
import numpy as np
import random
import shutil
import os
import torch
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with *seed*.

    CUDA generators are seeded only when ``torch.cuda.is_available()``.
    cuDNN is then put in deterministic mode with benchmark autotuning off,
    so repeated runs pick the same kernels.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning speed in exchange for reproducible kernel choice.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def save_checkpoint(state, is_best, checkpoint_path, filename="checkpoint.pt"):
    """Serialize *state* into *checkpoint_path*, optionally mirroring it as the best model.

    Args:
        state: Object handed to ``torch.save``. ``load_checkpoint`` in this
            module expects a dict with a ``"state_dict"`` key.
        is_best: When true, the written file is also copied to
            ``model_best.pt`` in the same directory.
        checkpoint_path: Directory to write into. It is created if missing;
            an empty string means the current working directory.
        filename: Name of the checkpoint file inside *checkpoint_path*.
    """
    # Create the run directory up front so torch.save does not fail with
    # FileNotFoundError on a fresh run. Skip this for "" (current directory),
    # where os.makedirs would raise.
    if checkpoint_path:
        os.makedirs(checkpoint_path, exist_ok=True)

    checkpoint_file = os.path.join(checkpoint_path, filename)
    torch.save(state, checkpoint_file)
    if is_best:
        # Copy rather than re-serialize so both files are byte-identical.
        shutil.copyfile(checkpoint_file, os.path.join(checkpoint_path, "model_best.pt"))
def load_checkpoint(model, path, map_location="cpu"):
    """Load the ``"state_dict"`` entry of the checkpoint at *path* into *model*.

    Args:
        model: Module whose ``load_state_dict`` receives the stored weights.
        path: Checkpoint file, e.g. one written by ``save_checkpoint``.
        map_location: Forwarded to ``torch.load``. The ``"cpu"`` default lets a
            checkpoint saved from CUDA tensors be read on a CPU-only host.
            ``load_state_dict`` copies the values onto whatever device
            *model*'s parameters already live on.

    Returns:
        The full deserialized checkpoint, so callers can read extra fields.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    return checkpoint
def log_metrics(set_name, metrics, logger):
    """Write one ``info`` line with loss and both accuracies for *set_name*.

    *metrics* must map ``"loss"``, ``"spec_acc"`` and ``"rgb_acc"`` to values
    that accept the ``:.5f`` format spec. *logger* only needs an ``info`` method.
    """
    loss, spec_acc, rgb_acc = (metrics[key] for key in ("loss", "spec_acc", "rgb_acc"))
    template = "{}: Loss: {:.5f} | spec_acc: {:.5f}, rgb_acc: {:.5f}"
    logger.info(template.format(set_name, loss, spec_acc, rgb_acc))
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global RNG, restoring its previous state on exit.

    Args:
        seed: Value passed to ``np.random.seed``. ``None`` turns the context
            into a no-op that leaves the global RNG untouched.
        *addl_seeds: Optional extra values combined with *seed* through
            ``hash`` into a single integer seed in ``[0, 1e6)``.
            NOTE(review): ``hash`` of ``str``/``bytes`` is randomized per
            process unless PYTHONHASHSEED is fixed, so only int-like extra
            seeds give cross-run reproducibility.
    """
    if seed is not None:
        if addl_seeds:
            seed = int(hash((seed,) + addl_seeds) % 1e6)
        saved_state = np.random.get_state()
        np.random.seed(seed)
        try:
            yield
        finally:
            # Runs even when the with-body raises, so the temporary seed never leaks.
            np.random.set_state(saved_state)
    else:
        yield