Spaces:
Runtime error
Runtime error
import random | |
import shutil | |
import time | |
import torch | |
# from torch.utils.tensorboard import SummaryWriter | |
from utils.visualization import * | |
from loguru import logger | |
# def get_tensorboard_logger_from_args(tensorboard_dir, reset_version=False): | |
# if reset_version: | |
# shutil.rmtree(os.path.join(tensorboard_dir)) | |
# return SummaryWriter(log_dir=tensorboard_dir) | |
def get_optimizer_from_args(model, lr, weight_decay, **kwargs) -> torch.optim.Optimizer:
    """Build an AdamW optimizer over the model's trainable parameters.

    Parameters whose ``requires_grad`` flag is False are left out, so the
    optimizer never updates them.

    Args:
        model: Module whose ``parameters()`` are optimized.
        lr: Learning rate passed to AdamW.
        weight_decay: Weight-decay coefficient passed to AdamW.
        **kwargs: Accepted and ignored.

    Returns:
        A ``torch.optim.AdamW`` instance.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)
def get_lr_schedule(optimizer, gamma=0.95):
    """Return an exponential learning-rate decay scheduler.

    Each call to ``scheduler.step()`` multiplies the learning rate of every
    param group in *optimizer* by *gamma*.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        gamma: Multiplicative decay factor applied per scheduler step
            (default 0.95).

    Returns:
        A ``torch.optim.lr_scheduler.ExponentialLR`` instance.
    """
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
def setup_seed(seed):
    """Seed every RNG used during training and configure cuDNN for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch on the
    CPU and all CUDA devices. Then puts cuDNN in deterministic,
    non-benchmarking mode.

    Args:
        seed: Seed value applied to all generators.
    """
    # Explicit import; `np` was previously only reachable via a star import.
    import numpy as np

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # deterministic=True restricts cuDNN to deterministic kernels.
    # benchmark=False stops cuDNN from auto-tuning, which can select
    # different algorithms from run to run. Both are needed for repeatable
    # results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def get_dir_from_args(root_dir, class_name, **kwargs):
    """Create the output directories for one experiment and start file logging.

    The experiment name is ``"{dataset}-k-{k_shot}"``. Under *root_dir* this
    creates (if missing)::

        csv/
        <exp_name>/models/
        <exp_name>/imgs/
        <exp_name>/logger/<class_name>/

    It then attaches a loguru file sink writing to a timestamped ``.log``
    file inside the logger directory.

    Args:
        root_dir: Base directory for all experiment outputs.
        class_name: Used as the logger sub-directory name and as the model name.
        **kwargs: Must contain ``'dataset'``, ``'k_shot'`` and
            ``'experiment_indx'``. Other keys are ignored.

    Returns:
        Tuple ``(model_dir, img_dir, logger_dir, model_name, csv_path)``.

    Raises:
        KeyError: If a required key is missing from *kwargs*.
    """
    # Explicit import; `os` was previously only reachable via a star import.
    import os

    exp_name = f"{kwargs['dataset']}-k-{kwargs['k_shot']}"

    csv_dir = os.path.join(root_dir, 'csv')
    csv_path = os.path.join(csv_dir, f"{exp_name}-indx-{kwargs['experiment_indx']}.csv")
    model_dir = os.path.join(root_dir, exp_name, 'models')
    img_dir = os.path.join(root_dir, exp_name, 'imgs')
    logger_dir = os.path.join(root_dir, exp_name, 'logger', class_name)

    # %M is minutes. The old format used %I (12-hour clock hour) here, so runs
    # in different minutes could map to the same log file name.
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime(time.time()))
    log_file_name = os.path.join(logger_dir, f'log_{timestamp}.log')
    model_name = f'{class_name}'

    for directory in (model_dir, img_dir, logger_dir, csv_dir):
        os.makedirs(directory, exist_ok=True)

    # loguru removed the deprecated logger.start(); logger.add() attaches the file sink.
    # NOTE(review): each call adds another sink, so repeated calls in one
    # process will log to every file added so far.
    logger.add(log_file_name)
    logger.info(f"===> Root dir for this experiment: {logger_dir}")
    return model_dir, img_dir, logger_dir, model_name, csv_path