elia / mask2former_utils /summary.py
yxchng
add files
a166479
raw
history blame
No virus
3.65 kB
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : summary.py
@Time : 2022/10/15 23:38:13
@Author : BQH
@Version : 1.0
@Contact : raogx.vip@hotmail.com
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc : Runtime logging utilities (summary writers, logger, checkpoint saver)
'''
# here put the import lib
import os
import sys
import torch
import logging
from datetime import datetime
# Fall back to a minimal stand-in SummaryWriter if tensorboardX is not installed
try:
    from tensorboardX import SummaryWriter
except ImportError:
    class SummaryWriter:
        """Minimal stand-in used when tensorboardX cannot be imported.

        Scalars are accumulated in an in-memory dict (``self.logs``) and
        written to ``<log_dir>/log_<timestamp>.pickle`` with ``torch.save``
        when :meth:`close` is called.

        Args:
            log_dir: Directory the pickle is written to; defaults to
                ``'./logs'`` when ``None``.
            comment: Free-form text stored under the ``'comment'`` key.
            **kwargs: Accepted and ignored.
        """

        def __init__(self, log_dir=None, comment='', **kwargs):
            print('\nunable to import tensorboardX, log will be recorded by pytorch!\n')
            self.log_dir = log_dir if log_dir is not None else './logs'
            # Create the directory close() actually writes into, not a
            # hard-coded './logs', so a custom log_dir works.
            os.makedirs(self.log_dir, exist_ok=True)
            self.logs = {'comment': comment}

        def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
            """Append ``(scalar_value, global_step, walltime)`` to the list kept for ``tag``."""
            self.logs.setdefault(tag, []).append((scalar_value, global_step, walltime))

        def close(self):
            """Dump all recorded values to a timestamped pickle file in ``log_dir``."""
            # Spaces and colons are replaced to keep the file name portable.
            timestamp = str(datetime.now()).replace(' ', '_').replace(':', '_')
            torch.save(self.logs, os.path.join(self.log_dir, 'log_%s.pickle' % timestamp))
class EmptySummaryWriter:
    """Writer with the same method names as SummaryWriter that records nothing."""

    def __init__(self, **kwargs):
        """Accept any keyword arguments and discard them."""

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        """Ignore the scalar; nothing is stored."""
        return None

    def close(self):
        """Do nothing; there is no state to write out."""
        return None
def create_summary(distributed_rank=0, **kwargs):
    """Build the summary writer for this process.

    Args:
        distributed_rank: Rank of the calling process. When it is greater
            than 0 an ``EmptySummaryWriter`` is returned; otherwise a
            ``SummaryWriter``.
        **kwargs: Forwarded unchanged to the chosen writer's constructor.

    Returns:
        The constructed writer instance.
    """
    writer_cls = EmptySummaryWriter if distributed_rank > 0 else SummaryWriter
    return writer_cls(**kwargs)
def create_logger(distributed_rank=0, save_dir=None):
    """Configure and return the shared logger named ``'logger'``.

    Args:
        distributed_rank: Rank of the calling process. When it is greater
            than 0 the logger is returned without attaching any handler.
        save_dir: Optional directory for a ``log_<timestamp>.txt`` file.
            It is created if missing. When ``None`` only stdout is used.

    Returns:
        The ``logging.Logger`` instance, with level set to ``DEBUG``.
    """
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger

    # logging.getLogger returns the same object on every call, so handlers
    # from an earlier create_logger() call are removed first; otherwise each
    # message would be emitted once per call and old log files would stay open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter("%(message)s [%(asctime)s]")

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir is not None:
        # FileHandler opens the file immediately and fails if the directory
        # does not exist, so make sure it is there.
        os.makedirs(save_dir, exist_ok=True)
        filename = "log_%s.txt" % (datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
        fh = logging.FileHandler(os.path.join(save_dir, filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    return logger
class Saver:
    """Checkpoint writer that only saves from the rank-0 process.

    Args:
        distributed_rank: Rank of the calling process.
        save_dir: Directory checkpoints are written to; created on
            construction if it does not exist.
    """

    def __init__(self, distributed_rank, save_dir):
        self.distributed_rank = distributed_rank
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def save(self, obj, save_name):
        """Serialize ``obj`` with ``torch.save`` to ``<save_dir>/<save_name>.t7``.

        Args:
            obj: Object to serialize.
            save_name: File name without the ``.t7`` suffix.

        Returns:
            A status message naming the written file when
            ``distributed_rank == 0``; an empty string otherwise (nothing
            is written).
        """
        if self.distributed_rank != 0:
            return ''
        save_path = os.path.join(self.save_dir, save_name + '.t7')
        torch.save(obj, save_path)
        # Report the path that was actually written, including '.t7'.
        return 'checkpoint saved in %s !' % save_path
def create_saver(distributed_rank, save_dir):
    """Construct a ``Saver`` for the given process rank and output directory."""
    saver = Saver(distributed_rank=distributed_rank, save_dir=save_dir)
    return saver
class DisablePrint:
    """Context manager that redirects ``sys.stdout`` to ``os.devnull``.

    The redirection only happens when ``local_rank`` is non-zero; for
    ``local_rank == 0`` entering and leaving the context changes nothing.

    Args:
        local_rank: Rank of the calling process (default 0).
    """

    def __init__(self, local_rank=0):
        self.local_rank = local_rank
        self._original_stdout = None
        self._devnull = None

    def __enter__(self):
        if self.local_rank != 0:
            self._original_stdout = sys.stdout
            # Keep our own reference to the handle so __exit__ closes exactly
            # this file, even if sys.stdout is reassigned inside the block.
            self._devnull = open(os.devnull, 'w')
            sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._devnull is not None:
            # Restore first so stdout is back even if close() raises.
            sys.stdout = self._original_stdout
            devnull, self._devnull = self._devnull, None
            devnull.close()