# Versatile-Diffusion / lib / log_service.py
import timeit
import os.path as osp
import torch
import torch.distributed as dist
from .cfg_holder import cfg_unique_holder as cfguh
from . import sync
# When True, only local rank 0 prints to the console; file logging below is
# restricted to local rank 0 regardless of this flag.
print_console_local_rank0_only = True


def print_log(*console_info):
    local_rank = sync.get_rank('local')
    if print_console_local_rank0_only and (local_rank != 0):
        return
    console_info = ' '.join(str(i) for i in console_info)
    print(console_info)
    if local_rank != 0:
        return
    # Mirror the console line into the run's log file, preferring the train
    # config over the eval config when both exist.
    log_file = None
    try:
        log_file = cfguh().cfg.train.log_file
    except Exception:
        try:
            log_file = cfguh().cfg.eval.log_file
        except Exception:
            return
    if log_file is not None:
        with open(log_file, 'a') as f:
            f.write(console_info + '\n')
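
# Illustrative usage of print_log (a sketch; assumes a populated config with
# cfg.train.log_file set, which happens elsewhere in the repo):
#
#   print_log('Iter:', 100, 'Loss:', 0.123)
#
# prints "Iter: 100 Loss: 0.123" on local rank 0 and appends the same line to
# cfg.train.log_file; all other ranks return without printing.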

class distributed_log_manager(object):
    """Accumulate per-item running sums over iterations, average them across
    DDP ranks on demand, and optionally mirror values to tensorboardX."""

    def __init__(self):
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()
        cfgt = cfguh().cfg.train
        use_tensorboard = getattr(cfgt, 'log_tensorboard', False)
        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')
        self.tb = None
        if use_tensorboard and (self.rank == 0):
            # Only local rank 0 owns the tensorboard writer.
            import tensorboardX
            monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)

    def accumulate(self, n, **data):
        # Values in data are per-sample means, so they are re-weighted by the
        # sample count n before being added to the running sums.
        if n < 0:
            raise ValueError('sample count n must be non-negative')
        for itemn, di in data.items():
            if itemn in self.sum:
                self.sum[itemn] += di * n
                self.cnt[itemn] += n
            else:
                self.sum[itemn] = di * n
                self.cnt[itemn] = n
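
    # Illustrative bookkeeping for accumulate (a sketch, not executed): two
    # batches of sizes 4 and 6 with mean losses 1.0 and 2.0 give
    # sum['Loss']=16.0 and cnt['Loss']=10, so the reported mean is 1.6, a
    # sample-weighted average rather than a plain average of batch means:
    #
    #   logm.accumulate(4, Loss=1.0)
    #   logm.accumulate(6, Loss=2.0)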

    def get_mean_value_dict(self):
        # Keys are sorted so every rank packs the tensor in the same order
        # before the collective below.
        itemns = sorted(self.sum.keys())
        value_gather = [self.sum[itemn] / self.cnt[itemn] for itemn in itemns]
        value_gather_tensor = torch.FloatTensor(value_gather).to(self.rank)
        if self.ddp:
            dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
            value_gather_tensor /= self.world_size
        mean = {}
        for idx, itemn in enumerate(itemns):
            mean[itemn] = value_gather_tensor[idx].item()
        return mean
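
    # Illustrative behaviour under DDP (a sketch): with world_size=2 and local
    # means {'Loss': 1.0} on rank 0 and {'Loss': 3.0} on rank 1, the SUM
    # all_reduce followed by division by world_size yields {'Loss': 2.0} on
    # both ranks.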

    def tensorboard_log(self, step, data, mode='train', **extra):
        if self.tb is None:
            return
        if mode == 'train':
            self.tb.add_scalar('other/epochn', extra['epochn'], step)
            if 'lr' in extra:
                self.tb.add_scalar('other/lr', extra['lr'], step)
            for itemn, di in data.items():
                # Scalars are routed by naming convention: 'loss*' items go
                # under loss/, the aggregate 'Loss' gets its own tag, and
                # everything else lands under other/.
                if itemn.startswith('loss'):
                    self.tb.add_scalar('loss/' + itemn, di, step)
                elif itemn == 'Loss':
                    self.tb.add_scalar('Loss', di, step)
                else:
                    self.tb.add_scalar('other/' + itemn, di, step)
        elif mode == 'eval':
            if isinstance(data, dict):
                for itemn, di in data.items():
                    self.tb.add_scalar('eval/' + itemn, di, step)
            else:
                self.tb.add_scalar('eval', data, step)
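
    # Resulting tensorboard layout (a sketch; 'loss_kl' is a hypothetical item
    # name): accumulating Loss=... and loss_kl=... produces the tags 'Loss'
    # and 'loss/loss_kl' alongside 'other/epochn' and 'other/lr' in train
    # mode, and 'eval/<name>' tags in eval mode.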

    def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
        console_info = [
            'Iter:{}'.format(itern),
            'Epoch:{}'.format(epochn),
            'Sample:{}'.format(samplen),]
        if lr is not None:
            console_info += ['LR:{:.4E}'.format(lr)]
        mean = self.get_mean_value_dict()
        tbstep = itern if tbstep is None else tbstep
        self.tensorboard_log(
            tbstep, mean, mode='train',
            itern=itern, epochn=epochn, lr=lr)
        # 'Loss' is expected to have been accumulated every iteration; it is
        # printed first, followed by any 'loss*'-prefixed components.
        loss = mean.pop('Loss')
        mean_info = ['Loss:{:.4f}'.format(loss)] + [
            '{}:{:.4f}'.format(itemn, mean[itemn])
            for itemn in sorted(mean.keys())
            if itemn.startswith('loss')]
        console_info += mean_info
        console_info.append('Time:{:.2f}s'.format(
            timeit.default_timer() - self.time_check))
        return ' , '.join(console_info)
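
    # Illustrative training-loop usage (a sketch; 'bs', 'loss', and 'kl' are
    # hypothetical names standing in for the surrounding training code):
    #
    #   logm = distributed_log_manager()
    #   for itern, batch in enumerate(loader):
    #       ...
    #       logm.accumulate(bs, Loss=loss.item(), loss_kl=kl.item())
    #       if itern % 100 == 0:
    #           print_log(logm.train_summary(itern, epochn, samplen, lr))
    #           logm.clear()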

    def clear(self):
        # Reset the accumulators and the wall-clock reference for the next
        # reporting window.
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

    def tensorboard_close(self):
        if self.tb is not None:
            self.tb.close()

# ----- also include some small utils -----
def torch_to_numpy(*argv):
    """Recursively detach tensors to cpu numpy arrays while preserving the
    surrounding list/tuple/dict structure; non-tensor leaves pass through."""
    if len(argv) > 1:
        data = list(argv)
    else:
        data = argv[0]
    if isinstance(data, torch.Tensor):
        return data.to('cpu').detach().numpy()
    elif isinstance(data, (list, tuple)):
        return [torch_to_numpy(di) for di in data]
    elif isinstance(data, dict):
        return {ni: torch_to_numpy(di) for ni, di in data.items()}
    else:
        return data
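

# Minimal runnable sketch of torch_to_numpy (the shapes and keys here are
# arbitrary examples, not values used elsewhere in the repo; run it as a
# module, e.g. `python -m lib.log_service`, so the relative imports resolve):
if __name__ == '__main__':
    demo = {'a': torch.zeros(2, 3), 'b': [torch.ones(4), 'tag']}
    out = torch_to_numpy(demo)
    print(type(out['a']))      # <class 'numpy.ndarray'>
    print(out['b'][0].shape)   # (4,)
    print(out['b'][1])         # tag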