MolCRAFT / core /utils /misc.py
Atomu2014's picture
demo init commit
1f0c7b9
import logging
import os
import random
import time
import numpy as np
import torch
import yaml
from easydict import EasyDict
class BlackHole(object):
    """Null object that silently absorbs whatever is done to it.

    Attribute assignments are discarded, lookups of attributes that do not
    otherwise exist return the instance itself, and calling the instance
    returns the instance as well, so arbitrary chains such as
    ``obj.a.b(1).c`` evaluate to the same object without raising.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; hand back the
        # instance so further attribute access / calls keep working.
        return self

    def __call__(self, *args, **kwargs):
        # Accept and ignore any arguments.
        return self

    def __setattr__(self, name, value):
        # Intentionally drop the assignment: nothing is ever stored.
        return None
def load_config(path):
    """Read a YAML file and wrap its parsed content in an ``EasyDict``.

    Args:
        path: Location of the YAML file to open for reading.

    Returns:
        An ``EasyDict`` built from the value returned by ``yaml.safe_load``.
    """
    with open(path, 'r') as config_file:
        parsed = yaml.safe_load(config_file)
    return EasyDict(parsed)
def get_logger(name, log_dir=None):
    """Return a DEBUG-level logger that writes to the console and, optionally, a file.

    ``logging.getLogger`` returns the same logger object for the same name,
    so this function is safe to call repeatedly: a console handler or a file
    handler for the same ``log.txt`` is attached at most once. Without that
    guard every extra call would duplicate each emitted log line.

    Args:
        name: Name passed to ``logging.getLogger``.
        log_dir: Optional directory; when given, messages are also written
            to ``<log_dir>/log.txt``.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    # FileHandler subclasses StreamHandler, so exclude it explicitly when
    # looking for an already-attached console handler.
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    if log_dir is not None:
        log_path = os.path.join(log_dir, 'log.txt')
        # FileHandler stores the absolute path in ``baseFilename``.
        target = os.path.abspath(log_path)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger
def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create and return a fresh, timestamp-named directory under ``root``.

    The directory name is the local time formatted as
    ``YYYY_MM_DD__HH_MM_SS``, preceded by ``prefix`` and followed by ``tag``
    (each joined with an underscore) when they are non-empty.

    Args:
        root: Parent directory for the new log directory.
        prefix: Optional text placed before the timestamp.
        tag: Optional text placed after the timestamp.

    Returns:
        Path of the directory that was created.

    Raises:
        OSError: Propagated from ``os.makedirs`` (e.g. when the directory
            already exists).
    """
    stamp = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    # Keep only the pieces that were actually supplied.
    pieces = [piece for piece in (prefix, stamp, tag) if piece != '']
    log_dir = os.path.join(root, '_'.join(pieces))
    os.makedirs(log_dir)
    return log_dir
def seed_all(seed):
    """Seed the PyTorch, NumPy and Python ``random`` generators with ``seed``.

    Args:
        seed: Value handed to ``torch.manual_seed``, ``np.random.seed`` and
            ``random.seed`` in that order.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
def log_hyperparams(writer, args):
    """Record the attributes of ``args`` as hyperparameters via ``writer``.

    Every attribute of ``args`` is kept as-is when it is a string and
    converted with ``repr`` otherwise; the resulting mapping is turned into
    three summaries by ``hparams`` (with an empty metric dict), which are
    then added to ``writer.file_writer``.

    Args:
        writer: Object exposing ``file_writer.add_summary`` (presumably a
            ``torch.utils.tensorboard.SummaryWriter`` - confirm with callers).
        args: Object whose ``vars()`` hold the hyperparameters, e.g. an
            ``argparse.Namespace``.
    """
    from torch.utils.tensorboard.summary import hparams

    hparam_dict = {}
    for key, value in vars(args).items():
        hparam_dict[key] = value if isinstance(value, str) else repr(value)

    experiment, session_start, session_end = hparams(hparam_dict, {})
    for summary in (experiment, session_start, session_end):
        writer.file_writer.add_summary(summary)
def int_tuple(argstr):
    """Parse a comma-separated string such as ``"1,2,3"`` into a tuple of ints.

    Args:
        argstr: Text whose comma-separated fields are each passed to ``int``.

    Returns:
        Tuple with one integer per field.

    Raises:
        ValueError: If any field is not a valid integer literal.
    """
    return tuple(int(field) for field in argstr.split(','))
def str_tuple(argstr):
    """Split a comma-separated string into a tuple of its fields.

    Fields are returned verbatim (no whitespace stripping).

    Args:
        argstr: Text to split on ``','``.

    Returns:
        Tuple of the substrings between commas.
    """
    fields = argstr.split(',')
    return (*fields,)
def count_parameters(model):
    """Count the elements of all parameters of ``model`` that require gradients.

    Args:
        model: Object exposing ``parameters()``, whose items provide
            ``requires_grad`` and ``numel()``.

    Returns:
        Total number of elements across parameters with ``requires_grad`` set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total