import logging
import os
import random
import time

import numpy as np
import torch
import yaml
from easydict import EasyDict


class BlackHole(object):
    """No-op stand-in that silently swallows attribute access, assignment, and calls.

    Useful as a drop-in replacement for a logger or SummaryWriter when output
    should be discarded.
    """

    def __setattr__(self, name, value):
        pass

    def __call__(self, *args, **kwargs):
        return self

    def __getattr__(self, name):
        return self


def load_config(path):
    """Load a YAML config file and return it as an EasyDict for attribute-style access."""
    with open(path, 'r') as f:
        return EasyDict(yaml.safe_load(f))


def get_logger(name, log_dir=None):
    """Create a DEBUG-level logger that writes to the console and, if log_dir is given, to log_dir/log.txt."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if log_dir is not None:
        file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger


def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create and return a timestamped log directory, optionally decorated with a prefix and tag."""
    fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    if prefix != '':
        fn = prefix + '_' + fn
    if tag != '':
        fn = fn + '_' + tag
    log_dir = os.path.join(root, fn)
    os.makedirs(log_dir)
    return log_dir


def seed_all(seed):
    """Seed the PyTorch, NumPy, and Python random number generators for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def log_hyperparams(writer, args):
    """Record argparse hyperparameters in a TensorBoard SummaryWriter's hparams view."""
    from torch.utils.tensorboard.summary import hparams
    # hparams() only accepts TensorBoard-serializable values, so anything that is
    # not already a string is stored via its repr().
    vars_args = {k: v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
    exp, ssi, sei = hparams(vars_args, {})
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)


def int_tuple(argstr):
    """Parse a comma-separated string such as '1,2,3' into a tuple of ints."""
    return tuple(map(int, argstr.split(',')))


def str_tuple(argstr):
    """Split a comma-separated string into a tuple of strings."""
    return tuple(argstr.split(','))


def count_parameters(model):
    """Count the trainable parameters of a torch.nn.Module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
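

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original training pipeline): it wires the
    # helpers above together the way a typical training script might. The seed,
    # prefix, and logged values below are illustrative placeholders.
    seed_all(2024)
    log_dir = get_new_log_dir(root='./logs', prefix='demo')
    logger = get_logger('demo', log_dir)
    logger.info('Logging to %s', log_dir)

    # BlackHole can stand in for a SummaryWriter when TensorBoard logging is disabled:
    # every attribute access and call is silently absorbed.
    writer = BlackHole()
    writer.add_scalar('loss', 0.0, 0)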