import os
import time
import random
import logging
from collections import OrderedDict
import torch
import torch.linalg
import numpy as np
import yaml
from easydict import EasyDict
from glob import glob
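
# No-op sink: swallows attribute assignment and returns itself on attribute
# access and calls, so e.g. a disabled logger/TensorBoard writer can be
# replaced by BlackHole() and every call on it silently does nothing.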
class BlackHole(object):
def __setattr__(self, name, value):
pass
def __call__(self, *args, **kwargs):
return self
def __getattr__(self, name):
return self
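
# Simple iteration counter: step(delta) returns the count before the
# increment and then advances it by `delta`.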
class Counter(object):
def __init__(self, start=0):
super().__init__()
self.now = start
def step(self, delta=1):
prev = self.now
self.now += delta
return prev
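
# Builds a logger that writes to stderr and, if `log_dir` is given, also to
# `<log_dir>/log.txt`, using a '[time::name::level] message' format.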
def get_logger(name, log_dir=None):
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.DEBUG)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_dir is not None:
file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
return logger
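
# Creates and returns a fresh log directory under `root`, named by the current
# timestamp with an optional prefix and tag,
# e.g. ./logs/<prefix>_2024_01_01__12_00_00_<tag>.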
def get_new_log_dir(root='./logs', prefix='', tag=''):
fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
if prefix != '':
fn = prefix + '_' + fn
if tag != '':
fn = fn + '_' + tag
log_dir = os.path.join(root, fn)
os.makedirs(log_dir)
return log_dir
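
# Seeds Python's random, NumPy and PyTorch (CPU and all CUDA devices) and
# forces deterministic cuDNN kernels for reproducibility.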
def seed_all(seed):
torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
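
# Wraps an iterable (e.g. a DataLoader) into an infinite generator that
# restarts from the beginning whenever the underlying iterator is exhausted.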
def inf_iterator(iterable):
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)
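
# Writes the attributes of `args` (e.g. an argparse Namespace) to TensorBoard
# as hparams; non-string values are stored via repr() so they serialize cleanly.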
def log_hyperparams(writer, args):
from torch.utils.tensorboard.summary import hparams
vars_args = {k: v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
exp, ssi, sei = hparams(vars_args, {})
writer.file_writer.add_summary(exp)
writer.file_writer.add_summary(ssi)
writer.file_writer.add_summary(sei)
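
# argparse-friendly converters: parse comma-separated strings such as '1,2,3'
# into tuples of ints or strings.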
def int_tuple(argstr):
return tuple(map(int, argstr.split(',')))
def str_tuple(argstr):
return tuple(argstr.split(','))
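
# Returns the checkpoint path '<folder>/<it>.pt' together with its iteration
# number; if `it` is None, picks the checkpoint with the largest iteration
# found in `folder`.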
def get_checkpoint_path(folder, it=None):
if it is not None:
return os.path.join(folder, '%d.pt' % it), it
all_iters = list(map(lambda x: int(os.path.basename(x[:-3])), glob(os.path.join(folder, '*.pt'))))
all_iters.sort()
return os.path.join(folder, '%d.pt' % all_iters[-1]), all_iters[-1]
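
# Loads a YAML config into an attribute-accessible EasyDict and returns it
# together with the config file's base name without its extension.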
def load_config(config_path):
with open(config_path, 'r') as f:
config = EasyDict(yaml.safe_load(f))
    config_name = os.path.splitext(os.path.basename(config_path))[0]
return config, config_name
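
# Extracts the sub-state-dict of entries whose keys start with `prefix`,
# stripping the prefix from the keys (e.g. to load a single sub-module).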
def extract_weights(weights: OrderedDict, prefix):
extracted = OrderedDict()
for k, v in weights.items():
if k.startswith(prefix):
extracted.update({
k[len(prefix):]: v
})
return extracted
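
# Current wall-clock time in milliseconds since the Unix epoch.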
def current_milli_time():
return round(time.time() * 1000)