import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
# cudnn.enabled = True
# cudnn.benchmark = True
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import os.path as osp
import sys
import numpy as np
import pprint
import timeit
import time
import copy
import importlib
import matplotlib.pyplot as plt

from .cfg_holder import cfg_unique_holder as cfguh
from .data_factory import \
    get_dataset, collate, \
    get_loader, \
    get_transform, \
    get_estimator, \
    get_formatter, \
    get_sampler
from .model_zoo import \
    get_model, get_optimizer, get_scheduler
from .log_service import print_log, distributed_log_manager
from .evaluator import get_evaluator
from . import sync
class train_stage(object):
    """
    This is a template for a train stage
    (it can be train, test, or anything else).
    Usually, it takes one dataloader, one model,
    one optimizer, and one scheduler,
    but it is not limited to these parameters.
    """
    def __init__(self):
        self.nested_eval_stage = None
        self.rv_keep = None

    def is_better(self, x):
        # A higher eval return value is treated as better.
        return (self.rv_keep is None) or (x > self.rv_keep)

    def set_model(self, net, mode):
        if mode == 'train':
            return net.train()
        elif mode == 'eval':
            return net.eval()
        else:
            raise ValueError('Unknown mode: {}'.format(mode))
    def __call__(self, **paras):
        cfg = cfguh().cfg
        cfgt = cfg.train
        logm = distributed_log_manager()
        epochn, itern, samplen = 0, 0, 0

        step_type = cfgt.get('step_type', 'iter')
        assert step_type in ['epoch', 'iter', 'sample'], \
            'Step type must be in [epoch, iter, sample]'
        step_num      = cfgt.get('step_num'     , None)
        gradacc_every = cfgt.get('gradacc_every', 1   )
        log_every     = cfgt.get('log_every'    , None)
        ckpt_every    = cfgt.get('ckpt_every'   , None)
        eval_start    = cfgt.get('eval_start'   , 0   )
        eval_every    = cfgt.get('eval_every'   , None)

        if paras.get('resume_step', None) is not None:
            resume_step = paras['resume_step']
            assert step_type == resume_step['type']
            epochn  = resume_step['epochn']
            itern   = resume_step['itern']
            samplen = resume_step['samplen']
            del paras['resume_step']

        trainloader = paras['trainloader']
        optimizer   = paras['optimizer']
        scheduler   = paras['scheduler']
        net         = paras['net']

        GRANK, LRANK, NRANK = sync.get_rank('all')
        GWSIZE, LWSIZE, NODES = sync.get_world_size('all')

        weight_path = osp.join(cfgt.log_dir, 'weight')
        if (GRANK == 0) and (not osp.isdir(weight_path)):
            os.makedirs(weight_path)
        if (GRANK == 0) and cfgt.save_init_model:
            self.save(net, is_init=True, step=0, optimizer=optimizer)

        epoch_time = timeit.default_timer()
        end_flag = False
        net.train()
        while True:
            if step_type == 'epoch':
                lr = scheduler[epochn] if scheduler is not None else None
            for batch in trainloader:
                # The first element of batch (usually the image) can be a [tensor].
                if not isinstance(batch[0], list):
                    bs = batch[0].shape[0]
                else:
                    bs = len(batch[0])
                if cfgt.skip_partial_batch and (bs != cfgt.batch_size_per_gpu):
                    continue
                itern_next = itern + 1
                samplen_next = samplen + bs * GWSIZE

                # Default to updating every step; only the 'iter' step type
                # implements gradient accumulation for now.
                grad_update = True
                if step_type == 'iter':
                    lr = scheduler[itern // gradacc_every] if scheduler is not None else None
                    grad_update = itern % gradacc_every == (gradacc_every - 1)
                elif step_type == 'sample':
                    lr = scheduler[samplen] if scheduler is not None else None
                    # TODO: gradient accumulation for sample-based stepping, e.g.
                    # grad_update = samplen % gradacc_every == (gradacc_every-1)

                # timeDebug = timeit.default_timer()
                paras_new = self.main(
                    batch=batch,
                    lr=lr,
                    itern=itern,
                    epochn=epochn,
                    samplen=samplen,
                    isinit=False,
                    grad_update=grad_update,
                    **paras)
                # print_log(timeit.default_timer() - timeDebug)
                paras.update(paras_new)
                logm.accumulate(bs, **paras['log_info'])
                #######
                # log #
                #######
                display_flag = False
                if log_every is not None:
                    display_i = (itern // log_every) != (itern_next // log_every)
                    display_s = (samplen // log_every) != (samplen_next // log_every)
                    display_flag = (display_i and (step_type == 'iter')) \
                        or (display_s and (step_type == 'sample'))
                if display_flag:
                    tbstep = itern_next if step_type == 'iter' else samplen_next
                    console_info = logm.train_summary(
                        itern_next, epochn, samplen_next, lr, tbstep=tbstep)
                    logm.clear()
                    print_log(console_info)
                ########
                # eval #
                ########
                eval_flag = False
                if (self.nested_eval_stage is not None) and (eval_every is not None) and (NRANK == 0):
                    if step_type == 'iter':
                        eval_flag = (itern // eval_every) != (itern_next // eval_every)
                        eval_flag = eval_flag and (itern_next >= eval_start)
                        eval_flag = eval_flag or itern == 0
                    elif step_type == 'sample':
                        eval_flag = (samplen // eval_every) != (samplen_next // eval_every)
                        eval_flag = eval_flag and (samplen_next >= eval_start)
                        eval_flag = eval_flag or samplen == 0
                if eval_flag:
                    eval_cnt = itern_next if step_type == 'iter' else samplen_next
                    net = self.set_model(net, 'eval')
                    rv = self.nested_eval_stage(eval_cnt=eval_cnt, **paras)
                    rv = rv.get('eval_rv', None)
                    if rv is not None:
                        logm.tensorboard_log(eval_cnt, rv, mode='eval')
                    if self.is_better(rv):
                        self.rv_keep = rv
                        if GRANK == 0:
                            step = {'epochn': epochn, 'itern': itern_next,
                                    'samplen': samplen_next, 'type': step_type, }
                            self.save(net, is_best=True, step=step, optimizer=optimizer)
                    net = self.set_model(net, 'train')
                ########
                # ckpt #
                ########
                ckpt_flag = False
                if (GRANK == 0) and (ckpt_every is not None):
                    # Checkpointing is not distributed.
                    ckpt_i = (itern // ckpt_every) != (itern_next // ckpt_every)
                    ckpt_s = (samplen // ckpt_every) != (samplen_next // ckpt_every)
                    ckpt_flag = (ckpt_i and (step_type == 'iter')) \
                        or (ckpt_s and (step_type == 'sample'))
                if ckpt_flag:
                    step = {'epochn': epochn, 'itern': itern_next,
                            'samplen': samplen_next, 'type': step_type, }
                    if step_type == 'iter':
                        print_log('Checkpoint... {}'.format(itern_next))
                        self.save(net, itern=itern_next, step=step, optimizer=optimizer)
                    else:
                        print_log('Checkpoint... {}'.format(samplen_next))
                        self.save(net, samplen=samplen_next, step=step, optimizer=optimizer)
                #######
                # end #
                #######
                itern = itern_next
                samplen = samplen_next
                if step_num is not None:
                    end_flag = (itern >= step_num and (step_type == 'iter')) \
                        or (samplen >= step_num and (step_type == 'sample'))
                if end_flag:
                    break
            # batch loop end

            epochn += 1
            print_log('Epoch {} time:{:.2f}s.'.format(
                epochn, timeit.default_timer() - epoch_time))
            epoch_time = timeit.default_timer()
            if end_flag:
                break
            elif step_type != 'epoch':
                # This is temporarily added to resolve the data issue.
                trainloader = self.trick_update_trainloader(trainloader)
                continue
            #######
            # log #
            #######
            display_flag = False
            if (log_every is not None) and (step_type == 'epoch'):
                display_flag = (epochn == 1) or (epochn % log_every == 0)
            if display_flag:
                console_info = logm.train_summary(
                    itern, epochn, samplen, lr, tbstep=epochn)
                logm.clear()
                print_log(console_info)
            ########
            # eval #
            ########
            eval_flag = False
            if (self.nested_eval_stage is not None) and (eval_every is not None) \
                    and (step_type == 'epoch') and (NRANK == 0):
                # eval_start is counted in epochs when step_type is 'epoch'.
                eval_flag = (epochn % eval_every == 0) and (epochn >= eval_start)
                eval_flag = (epochn == 1) or eval_flag
            if eval_flag:
                net = self.set_model(net, 'eval')
                rv = self.nested_eval_stage(
                    eval_cnt=epochn, **paras)['eval_rv']
                if rv is not None:
                    logm.tensorboard_log(epochn, rv, mode='eval')
                if self.is_better(rv):
                    self.rv_keep = rv
                    if GRANK == 0:
                        step = {'epochn': epochn, 'itern': itern,
                                'samplen': samplen, 'type': step_type, }
                        self.save(net, is_best=True, step=step, optimizer=optimizer)
                net = self.set_model(net, 'train')
            ########
            # ckpt #
            ########
            ckpt_flag = False
            if (ckpt_every is not None) and (GRANK == 0) and (step_type == 'epoch'):
                # Checkpointing is not distributed.
                ckpt_flag = epochn % ckpt_every == 0
            if ckpt_flag:
                print_log('Checkpoint... {}'.format(epochn))
                step = {'epochn': epochn, 'itern': itern,
                        'samplen': samplen, 'type': step_type, }
                self.save(net, epochn=epochn, step=step, optimizer=optimizer)
            #######
            # end #
            #######
            if (step_type == 'epoch') and (step_num is not None) and (epochn >= step_num):
                break
            # epoch loop end

            # This is temporarily added to resolve the data issue.
            trainloader = self.trick_update_trainloader(trainloader)

        logm.tensorboard_close()
        return {}
    def main(self, **paras):
        raise NotImplementedError

    def trick_update_trainloader(self, trainloader):
        return trainloader

    def save_model(self, net, path_noext, **paras):
        path = path_noext + '.pth'
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net
        torch.save(netm.state_dict(), path)
        print_log('Saving model file {0}'.format(path))

    def save(self, net, itern=None, epochn=None, samplen=None,
             is_init=False, is_best=False, is_last=False, **paras):
        exid = cfguh().cfg.env.experiment_id
        cfgt = cfguh().cfg.train
        cfgm = cfguh().cfg.model
        net_symbol = cfgm.symbol
        # At most one naming flag may be set at a time.
        check = sum([
            itern is not None, samplen is not None, epochn is not None,
            is_init, is_best, is_last])
        assert check < 2
        if itern is not None:
            path_noext = '{}_{}_iter_{}'.format(exid, net_symbol, itern)
        elif samplen is not None:
            path_noext = '{}_{}_samplen_{}'.format(exid, net_symbol, samplen)
        elif epochn is not None:
            path_noext = '{}_{}_epoch_{}'.format(exid, net_symbol, epochn)
        elif is_init:
            path_noext = '{}_{}_init'.format(exid, net_symbol)
        elif is_best:
            path_noext = '{}_{}_best'.format(exid, net_symbol)
        elif is_last:
            path_noext = '{}_{}_last'.format(exid, net_symbol)
        else:
            path_noext = '{}_{}_default'.format(exid, net_symbol)
        path_noext = osp.join(cfgt.log_dir, 'weight', path_noext)
        self.save_model(net, path_noext, **paras)
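
# A minimal sketch (not part of this file) of how a concrete stage might
# subclass train_stage. The batch layout and loss below are hypothetical;
# the only contract train_stage imposes is that main() returns a dict whose
# 'log_info' entry feeds logm.accumulate().
#
# class classification_train_stage(train_stage):
#     def main(self, batch, lr, grad_update=True,
#              net=None, optimizer=None, **paras):
#         x, y = batch[0], batch[1]          # hypothetical batch layout
#         if lr is not None:                 # apply the scheduled lr
#             for g in optimizer.param_groups:
#                 g['lr'] = lr
#         loss = F.cross_entropy(net(x), y)
#         loss.backward()
#         if grad_update:                    # honor gradient accumulation
#             optimizer.step()
#             optimizer.zero_grad()
#         return {'log_info': {'loss': loss.item()}}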
class eval_stage(object):
    def __init__(self):
        self.evaluator = None

    def create_dir(self, path):
        local_rank = sync.get_rank('local')
        if (not osp.isdir(path)) and (local_rank == 0):
            os.makedirs(path)
        sync.nodewise_sync().barrier()

    def __call__(self, evalloader, net, **paras):
        cfgt = cfguh().cfg.eval
        local_rank = sync.get_rank('local')
        if self.evaluator is None:
            evaluator = get_evaluator()(cfgt.evaluator)
            self.evaluator = evaluator
        else:
            evaluator = self.evaluator

        time_check = timeit.default_timer()
        for idx, batch in enumerate(evalloader):
            rv = self.main(batch, net)
            evaluator.add_batch(**rv)
            if cfgt.output_result:
                # output_f may or may not accept the eval counter.
                try:
                    self.output_f(**rv, cnt=paras['eval_cnt'])
                except (KeyError, TypeError):
                    self.output_f(**rv)
            if idx % cfgt.log_display == cfgt.log_display - 1:
                print_log('processed.. {}, Time:{:.2f}s'.format(
                    idx + 1, timeit.default_timer() - time_check))
                time_check = timeit.default_timer()

        evaluator.set_sample_n(len(evalloader.dataset))
        eval_rv = evaluator.compute()
        if local_rank == 0:
            evaluator.one_line_summary()
            evaluator.save(cfgt.log_dir)
        evaluator.clear_data()
        return {'eval_rv': eval_rv}
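
# A minimal sketch of a concrete eval stage, assuming the configured
# evaluator's add_batch() accepts the (hypothetical) 'pred'/'gt' keywords;
# eval_stage.__call__ only requires main() to return the kwargs that
# evaluator.add_batch() expects.
#
# class classification_eval_stage(eval_stage):
#     def main(self, batch, net):
#         x, y = batch[0], batch[1]          # hypothetical batch layout
#         with torch.no_grad():
#             pred = net(x).argmax(-1)
#         return {'pred': torch_to_numpy(pred), 'gt': torch_to_numpy(y)}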
class exec_container(object):
    """
    This is the base functor for all types of executions.
    One execution can have multiple stages,
    but they are only allowed to use the same
    config, network, and dataloader.
    Thus, in most cases, one exec_container is one
    training/evaluation/demo...
    If DDP is in use, this functor should be spawned.
    """
    def __init__(self, cfg, **kwargs):
        self.cfg = cfg
        self.registered_stages = []
        self.node_rank = None
        self.local_rank = None
        self.global_rank = None
        self.local_world_size = None
        self.global_world_size = None
        self.nodewise_sync_global_obj = sync.nodewise_sync_global()

    def register_stage(self, stage):
        self.registered_stages.append(stage)

    def __call__(self, local_rank, **kwargs):
        cfg = self.cfg
        cfguh().save_cfg(cfg)
        self.node_rank = cfg.env.node_rank
        self.local_rank = local_rank
        self.nodes = cfg.env.nodes
        self.local_world_size = cfg.env.gpu_count
        # Global rank is offset by the number of processes per node,
        # not by the number of nodes.
        self.global_rank = self.local_rank + self.node_rank * self.local_world_size
        self.global_world_size = self.nodes * self.local_world_size

        dist.init_process_group(
            backend=cfg.env.dist_backend,
            init_method=cfg.env.dist_url,
            rank=self.global_rank,
            world_size=self.global_world_size,)
        torch.cuda.set_device(local_rank)
        sync.nodewise_sync().copy_global(self.nodewise_sync_global_obj).local_init()

        if isinstance(cfg.env.rnd_seed, int):
            np.random.seed(cfg.env.rnd_seed + self.global_rank)
            torch.manual_seed(cfg.env.rnd_seed + self.global_rank)

        time_start = timeit.default_timer()
        para = {'itern_total': 0, }

        dl_para = self.prepare_dataloader()
        assert isinstance(dl_para, dict)
        para.update(dl_para)

        md_para = self.prepare_model()
        assert isinstance(md_para, dict)
        para.update(md_para)

        for stage in self.registered_stages:
            stage_para = stage(**para)
            if stage_para is not None:
                para.update(stage_para)

        if self.global_rank == 0:
            self.save_last_model(**para)
        print_log(
            'Total {:.2f} seconds'.format(timeit.default_timer() - time_start))
        dist.destroy_process_group()

    def prepare_dataloader(self):
        """
        Prepare the dataloaders from the config.
        """
        return {
            'trainloader': None,
            'evalloader': None}

    def prepare_model(self):
        """
        Prepare the model from the config.
        """
        return {'net': None}

    def save_last_model(self, **para):
        return

    def destroy(self):
        self.nodewise_sync_global_obj.destroy()
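
# A minimal launch sketch, assuming a single node: exec_container.__call__
# takes the local rank as its first argument, so one process per GPU can be
# started with torch.multiprocessing.spawn. Only cfg.env.gpu_count below is
# a field this file already reads; the helper itself is hypothetical.
#
# def launch(cfg, container):
#     try:
#         mp.spawn(container, nprocs=cfg.env.gpu_count, join=True)
#     finally:
#         container.destroy()   # release the nodewise sync object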
class train(exec_container):
    def prepare_dataloader(self):
        cfg = cfguh().cfg
        trainset = get_dataset()(cfg.train.dataset)
        sampler = get_sampler()(
            dataset=trainset, cfg=cfg.train.dataset.get('sampler', 'default_train'))
        trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=cfg.train.batch_size_per_gpu,
            sampler=sampler,
            num_workers=cfg.train.dataset_num_workers_per_gpu,
            drop_last=False,
            pin_memory=cfg.train.dataset.get('pin_memory', False),
            collate_fn=collate(),)
        evalloader = None
        if 'eval' in cfg:
            evalset = get_dataset()(cfg.eval.dataset)
            if evalset is not None:
                sampler = get_sampler()(
                    dataset=evalset, cfg=cfg.eval.dataset.get('sampler', 'default_eval'))
                evalloader = torch.utils.data.DataLoader(
                    evalset,
                    batch_size=cfg.eval.batch_size_per_gpu,
                    sampler=sampler,
                    num_workers=cfg.eval.dataset_num_workers_per_gpu,
                    drop_last=False,
                    pin_memory=cfg.eval.dataset.get('pin_memory', False),
                    collate_fn=collate(),)
        return {
            'trainloader': trainloader,
            'evalloader': evalloader, }

    def prepare_model(self):
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        if cfg.env.cuda:
            net.to(self.local_rank)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[self.local_rank],
                find_unused_parameters=True)
        net.train()
        scheduler = get_scheduler()(cfg.train.scheduler)
        optimizer = get_optimizer()(net, cfg.train.optimizer)
        return {
            'net': net,
            'optimizer': optimizer,
            'scheduler': scheduler, }

    def save_last_model(self, **para):
        cfgt = cfguh().cfg.train
        net = para['net']
        net_symbol = cfguh().cfg.model.symbol
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net
        path = osp.join(cfgt.log_dir, '{}_{}_last.pth'.format(
            cfgt.experiment_id, net_symbol))
        torch.save(netm.state_dict(), path)
        print_log('Saving model file {0}'.format(path))
class eval(exec_container):
    def prepare_dataloader(self):
        cfg = cfguh().cfg
        evalloader = None
        if cfg.eval.get('dataset', None) is not None:
            evalset = get_dataset()(cfg.eval.dataset)
            if evalset is None:
                # Keep the return type consistent with the other containers.
                return {
                    'trainloader': None,
                    'evalloader': None, }
            sampler = get_sampler()(
                dataset=evalset, cfg=getattr(cfg.eval.dataset, 'sampler', 'default_eval'))
            evalloader = torch.utils.data.DataLoader(
                evalset,
                batch_size=cfg.eval.batch_size_per_gpu,
                sampler=sampler,
                num_workers=cfg.eval.dataset_num_workers_per_gpu,
                drop_last=False,
                pin_memory=False,
                collate_fn=collate(), )
        return {
            'trainloader': None,
            'evalloader': evalloader, }

    def prepare_model(self):
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        if cfg.env.cuda:
            net.to(self.local_rank)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[self.local_rank],
                find_unused_parameters=True)
        net.eval()
        return {'net': net, }

    def save_last_model(self, **para):
        return
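
# End-to-end wiring sketch, assuming concrete my_train_stage/my_eval_stage
# subclasses like those sketched above and a cfg with the fields this file
# already consumes (env.*, model.*, train.*, eval.*):
#
# trainer = train(cfg)
# tstage = my_train_stage()
# tstage.nested_eval_stage = my_eval_stage()   # enables periodic eval
# trainer.register_stage(tstage)
# mp.spawn(trainer, nprocs=cfg.env.gpu_count, join=True)
# trainer.destroy()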
###############
# some helper #
###############

def torch_to_numpy(*argv):
    if len(argv) > 1:
        data = list(argv)
    else:
        data = argv[0]
    if isinstance(data, torch.Tensor):
        return data.to('cpu').detach().numpy()
    elif isinstance(data, (list, tuple)):
        out = []
        for di in data:
            out.append(torch_to_numpy(di))
        return out
    elif isinstance(data, dict):
        out = {}
        for ni, di in data.items():
            out[ni] = torch_to_numpy(di)
        return out
    else:
        return data
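
# Usage sketch: torch_to_numpy recurses through nested lists/tuples/dicts,
# converts tensors to numpy arrays, and passes everything else through.
#
# out = torch_to_numpy({'pred': torch.zeros(2), 'meta': [torch.ones(1), 'id']})
# # -> {'pred': array([0., 0.]), 'meta': [array([1.]), 'id']}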
def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
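
# Usage sketch: resolve a dotted path to an object; the path below is
# illustrative only.
#
# Dataset = get_obj_from_str('mypackage.datasets.MyDataset')
# ds = Dataset()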