import os
import time
import random
import logging
from collections import OrderedDict
from glob import glob

import numpy as np
import torch
import torch.linalg
import yaml
from easydict import EasyDict


class BlackHole(object):
    """A no-op sink: attribute writes are ignored, and attribute reads and
    calls return the BlackHole itself (useful as a disabled logger/writer)."""

    def __setattr__(self, name, value):
        pass

    def __call__(self, *args, **kwargs):
        return self

    def __getattr__(self, name):
        return self
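
# Illustrative usage (an assumption, not from the original file): BlackHole can
# stand in for a TensorBoard SummaryWriter when logging is disabled, so that
# calls like the ones below become silent no-ops.
#
#     writer = BlackHole()
#     writer.add_scalar('loss', 0.1, 0)   # ignored
#     writer.flush()                      # ignored

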
class Counter(object):
    """A step counter: `step()` returns the value *before* incrementing."""

    def __init__(self, start=0):
        super().__init__()
        self.now = start

    def step(self, delta=1):
        prev = self.now
        self.now += delta
        return prev
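
# Illustrative usage (an assumption): Counter can track an iteration index.
#
#     it_counter = Counter(start=1)
#     it_counter.step()    # returns 1; counter becomes 2
#     it_counter.step(5)   # returns 2; counter becomes 7

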
def get_logger(name, log_dir=None):
    """Build a DEBUG-level logger that prints to the console and, if `log_dir`
    is given, also appends to `<log_dir>/log.txt`."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if log_dir is not None:
        file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
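
# Illustrative usage (an assumption, not from the original file):
#
#     logger = get_logger('train', log_dir='./logs/example_run')
#     logger.info('Training started')
#
# Note that calling get_logger twice with the same name attaches duplicate
# handlers, so each record would then be emitted more than once.

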
def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create (and return the path of) a new timestamped log directory under `root`."""
    fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    if prefix != '':
        fn = prefix + '_' + fn
    if tag != '':
        fn = fn + '_' + tag
    log_dir = os.path.join(root, fn)
    os.makedirs(log_dir)
    return log_dir
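
# Illustrative usage (an assumption; the resulting name depends on the clock):
#
#     log_dir = get_new_log_dir('./logs', prefix='train', tag='baseline')
#     # e.g. './logs/train_2024_01_01__12_00_00_baseline'

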
def seed_all(seed):
    """Seed PyTorch (CPU and CUDA), NumPy, and Python's `random` for reproducibility."""
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
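
# Illustrative usage (an assumption): call once before constructing models and
# data loaders, e.g. seed_all(2021). Note that cudnn.benchmark is not touched
# here; only cudnn.deterministic is forced on.

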
def inf_iterator(iterable):
    """Cycle over `iterable` forever, restarting it whenever it is exhausted
    (useful for iteration-based training loops over a DataLoader)."""
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)
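
# Illustrative usage (an assumption; `train_loader` and `max_iters` are
# hypothetical names):
#
#     train_iterator = inf_iterator(train_loader)
#     for it in range(max_iters):
#         batch = next(train_iterator)

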
def log_hyperparams(writer, args):
    """Log hyperparameters to a TensorBoard SummaryWriter; non-string values
    are stored via `repr` so they remain displayable."""
    from torch.utils.tensorboard.summary import hparams
    vars_args = {k: v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
    exp, ssi, sei = hparams(vars_args, {})
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
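
# Illustrative usage (an assumption, not from the original file):
#
#     from torch.utils.tensorboard import SummaryWriter
#     writer = SummaryWriter(log_dir)
#     log_hyperparams(writer, args)   # `args` is typically an argparse.Namespace

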
def int_tuple(argstr):
    """Parse a comma-separated string like '1,2,3' into a tuple of ints."""
    return tuple(map(int, argstr.split(',')))


def str_tuple(argstr):
    """Parse a comma-separated string like 'a,b' into a tuple of strings."""
    return tuple(argstr.split(','))
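
# Illustrative usage (an assumption; the argument name is hypothetical): both
# helpers work well as argparse `type=` converters.
#
#     parser.add_argument('--hidden_dims', type=int_tuple, default=(256, 256))

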
def get_checkpoint_path(folder, it=None):
    """Return (path, iteration) for a checkpoint named `<iteration>.pt` in `folder`.
    If `it` is None, the checkpoint with the largest iteration number is chosen."""
    if it is not None:
        return os.path.join(folder, '%d.pt' % it), it
    all_iters = list(map(lambda x: int(os.path.basename(x[:-3])), glob(os.path.join(folder, '*.pt'))))
    all_iters.sort()
    return os.path.join(folder, '%d.pt' % all_iters[-1]), all_iters[-1]
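
# Illustrative usage (an assumption; the folder contents are hypothetical): if
# './ckpt' contains 1000.pt and 2000.pt, then get_checkpoint_path('./ckpt')
# returns ('./ckpt/2000.pt', 2000). An empty folder raises IndexError since
# there is no latest checkpoint to select.

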
def load_config(config_path):
    """Load a YAML config into an EasyDict; also return the file name without its extension."""
    with open(config_path, 'r') as f:
        config = EasyDict(yaml.safe_load(f))
    config_name = os.path.basename(config_path)[:os.path.basename(config_path).rfind('.')]
    return config, config_name
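
# Illustrative usage (an assumption; the path and keys are hypothetical):
#
#     config, config_name = load_config('./configs/train.yml')
#     config_name          # 'train'
#     config.train.seed    # nested keys are attribute-accessible via EasyDict

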
def extract_weights(weights: OrderedDict, prefix):
    """Extract the entries of a state dict whose keys start with `prefix`,
    returning them with the prefix stripped."""
    extracted = OrderedDict()
    for k, v in weights.items():
        if k.startswith(prefix):
            extracted[k[len(prefix):]] = v
    return extracted
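
# Illustrative usage (an assumption; the key names are hypothetical): given a
# checkpoint state dict with keys such as 'encoder.layer1.weight',
# extract_weights(state_dict, 'encoder.') returns an OrderedDict keyed by
# 'layer1.weight', suitable for encoder.load_state_dict(...).

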
def current_milli_time():
    """Current Unix time in milliseconds, rounded to the nearest integer."""
    return round(time.time() * 1000)