"""Miscellaneous helpers: logging setup, seeding, checkpoint and config utilities."""
import collections
import logging
import os
import random
import time
from glob import glob
from typing import OrderedDict

import numpy as np
import torch
import torch.linalg
import yaml
from easydict import EasyDict


class BlackHole(object):
    """Null object: attribute writes are dropped, reads and calls return ``self``.

    Any chain such as ``obj.a.b(1, 2).c`` therefore evaluates without error
    and without effect.
    """

    def __setattr__(self, name, value):
        # Deliberately discard every assignment.
        pass

    def __call__(self, *args, **kwargs):
        return self

    def __getattr__(self, name):
        # Only invoked for names not found by normal lookup.
        return self


class Counter(object):
    """Running counter whose ``step`` returns the value *before* incrementing."""

    def __init__(self, start=0):
        super().__init__()
        self.now = start

    def step(self, delta=1):
        """Advance the counter by ``delta`` and return the previous value."""
        prev = self.now
        self.now += delta
        return prev


def get_logger(name, log_dir=None):
    """Return the logger called ``name`` with console (and optional file) output.

    The logger level is set to DEBUG. A stream handler is attached, and when
    ``log_dir`` is given a file handler writing to ``<log_dir>/log.txt`` is
    attached as well.

    ``logging.getLogger`` returns the same object for the same name, so
    handlers are only added when an equivalent one is not already attached.
    Otherwise every repeated call would duplicate each log message.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_dir: Directory that receives ``log.txt``; ``None`` disables file
            logging.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    # FileHandler subclasses StreamHandler, hence the exact type comparison.
    has_stream = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_stream:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    if log_dir is not None:
        log_path = os.path.join(log_dir, 'log.txt')
        # FileHandler stores the absolute path of its target in baseFilename.
        abs_log_path = os.path.abspath(log_path)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == abs_log_path
            for h in logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger


def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create and return a timestamped directory under ``root``.

    The directory name is ``[<prefix>_]<YYYY_mm_dd__HH_MM_SS>[_<tag>]`` using
    local time. ``os.makedirs`` is called without ``exist_ok``, so an already
    existing directory raises ``FileExistsError``.
    """
    fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    if prefix != '':
        fn = prefix + '_' + fn
    if tag != '':
        fn = fn + '_' + tag
    log_dir = os.path.join(root, fn)
    os.makedirs(log_dir)
    return log_dir


def seed_all(seed):
    """Seed torch (CPU and all CUDA devices), numpy and ``random``.

    Also switches cuDNN to deterministic mode.
    """
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def inf_iterator(iterable):
    """Yield the items of ``iterable`` over and over, restarting when exhausted.

    A new iterator is requested from ``iterable`` for every pass. If a fresh
    pass produces no item at all (empty iterable, or a one-shot iterator that
    has been consumed) the generator ends instead of spinning forever without
    ever yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            # Nothing left to cycle through; looping again could never yield.
            return


def log_hyperparams(writer, args):
    """Write the attributes of ``args`` as hyper-parameters to ``writer``.

    Every value that is not already a string is converted with ``repr``. The
    three summaries produced by tensorboard's ``hparams`` helper are added
    through ``writer.file_writer``.

    Args:
        writer: Object exposing ``file_writer.add_summary`` (presumably a
            ``SummaryWriter`` -- confirm against callers).
        args: Object supporting ``vars()``, e.g. an ``argparse.Namespace``.
    """
    from torch.utils.tensorboard.summary import hparams
    vars_args = {k: v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
    exp, ssi, sei = hparams(vars_args, {})
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)


def int_tuple(argstr):
    """Parse a comma-separated string such as ``'1,2,3'`` into a tuple of ints."""
    return tuple(map(int, argstr.split(',')))


def str_tuple(argstr):
    """Split a comma-separated string into a tuple of strings."""
    return tuple(argstr.split(','))


def get_checkpoint_path(folder, it=None):
    """Return ``(path, iteration)`` for a checkpoint named ``<iteration>.pt``.

    Args:
        folder: Directory containing the checkpoint files.
        it: Iteration to address. When ``None`` the largest iteration found
            among ``folder/*.pt`` is used.

    Returns:
        Tuple of the checkpoint path and its iteration number. When ``it`` is
        given the path is built without checking that the file exists.

    Raises:
        IndexError: ``it`` is ``None`` and the folder holds no ``*.pt`` file
            whose stem is an integer.
    """
    if it is not None:
        return os.path.join(folder, '%d.pt' % it), it

    all_iters = []
    for path in glob(os.path.join(folder, '*.pt')):
        stem = os.path.basename(path[:-3])
        try:
            all_iters.append(int(stem))
        except ValueError:
            # Ignore files such as 'best.pt' that do not encode an iteration.
            continue

    if not all_iters:
        # IndexError is what indexing the empty list raised before; the type
        # is kept so existing handlers keep working.
        raise IndexError('no numbered *.pt checkpoint found in %r' % (folder,))

    latest = max(all_iters)
    return os.path.join(folder, '%d.pt' % latest), latest


def load_config(config_path):
    """Load a YAML file into an ``EasyDict`` and derive its name.

    Returns:
        Tuple ``(config, config_name)`` where ``config_name`` is the file's
        base name without its last extension.
    """
    with open(config_path, 'r') as f:
        config = EasyDict(yaml.safe_load(f))
    # splitext leaves extension-less names intact; slicing at rfind('.') cut
    # off their last character because rfind returned -1.
    config_name = os.path.splitext(os.path.basename(config_path))[0]
    return config, config_name


def extract_weights(weights: OrderedDict, prefix):
    """Return the entries of ``weights`` whose key starts with ``prefix``.

    The prefix is removed from the keys of the returned ordered dict; the
    relative order of the entries is preserved.
    """
    extracted = collections.OrderedDict()
    for k, v in weights.items():
        if k.startswith(prefix):
            extracted[k[len(prefix):]] = v
    return extracted


def current_milli_time():
    """Return the current Unix time in milliseconds, rounded to an int."""
    return round(time.time() * 1000)