# Segment-Any-Anomaly/utils/training_utils.py
import os
import random
import shutil
import time

import numpy as np
import torch
from loguru import logger

# from torch.utils.tensorboard import SummaryWriter
from utils.visualization import *


# def get_tensorboard_logger_from_args(tensorboard_dir, reset_version=False):
#     if reset_version:
#         shutil.rmtree(os.path.join(tensorboard_dir))
#     return SummaryWriter(log_dir=tensorboard_dir)


def get_optimizer_from_args(model, lr, weight_decay, **kwargs) -> torch.optim.Optimizer:
    """Build an AdamW optimizer over the model's trainable parameters only."""
    return torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                             lr=lr, weight_decay=weight_decay)


def get_lr_schedule(optimizer):
    """Exponentially decay the learning rate by a factor of 0.95 per scheduler step."""
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
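
# Illustrative usage sketch (not part of the original module): `model`, `lr`,
# `weight_decay`, and `num_epochs` are assumed to come from the caller's
# training configuration.
#   optimizer = get_optimizer_from_args(model, lr=1e-4, weight_decay=1e-2)
#   scheduler = get_lr_schedule(optimizer)
#   for epoch in range(num_epochs):
#       ...  # forward / backward / optimizer.step()
#       scheduler.step()  # decay the learning rate once per epoch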


def setup_seed(seed):
    """Seed all random number generators for reproducible runs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
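
# Illustrative usage sketch (not part of the original module): call setup_seed()
# once, before constructing models and dataloaders, so repeated runs with the
# same seed produce the same results.
#   setup_seed(42)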


def get_dir_from_args(root_dir, class_name, **kwargs):
    """Create the output directory layout for one experiment and attach a log file.

    Returns the model/image/logger directories, the model name, and the CSV path.
    """
    exp_name = f"{kwargs['dataset']}-k-{kwargs['k_shot']}"

    csv_dir = os.path.join(root_dir, 'csv')
    csv_path = os.path.join(csv_dir, f"{exp_name}-indx-{kwargs['experiment_indx']}.csv")

    model_dir = os.path.join(root_dir, exp_name, 'models')
    img_dir = os.path.join(root_dir, exp_name, 'imgs')
    logger_dir = os.path.join(root_dir, exp_name, 'logger', class_name)

    log_file_name = os.path.join(logger_dir,
                                 f'log_{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime(time.time()))}.log')

    model_name = f'{class_name}'

    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(logger_dir, exist_ok=True)
    os.makedirs(csv_dir, exist_ok=True)

    logger.add(log_file_name)  # attach a file sink for this experiment's log
    logger.info(f"===> Logger dir for this experiment: {logger_dir}")

    return model_dir, img_dir, logger_dir, model_name, csv_path
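

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). The root
    # directory, class name, dataset, k_shot, and experiment_indx values below
    # are placeholders for illustration only.
    setup_seed(42)
    dirs = get_dir_from_args('./result_demo', 'bottle',
                             dataset='mvtec', k_shot=0, experiment_indx=0)
    logger.info(f"===> Experiment layout: {dirs}")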