"""
Generic execution scaffolding for distributed training / evaluation.

Contents:
    train_stage     -- template training loop (subclasses implement `main`).
    eval_stage      -- template evaluation loop (subclasses implement `main`
                       and, optionally, `output_f`).
    exec_container  -- per-process functor that initialises torch.distributed,
                       builds dataloaders / model and runs registered stages.
    train / eval    -- concrete exec_containers.
    torch_to_numpy, get_obj_from_str -- small helpers.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
# cudnn.enabled = True
# cudnn.benchmark = True
import torch.distributed as dist
import torch.multiprocessing as mp

import os
import os.path as osp
import sys
import numpy as np
import random
import pprint
import timeit
import time
import copy
import itertools
import importlib
import matplotlib.pyplot as plt

from .cfg_holder import cfg_unique_holder as cfguh

from .data_factory import (
    get_dataset, collate,
    get_loader,
    get_transform,
    get_estimator,
    get_formatter,
    get_sampler,
)

from .model_zoo import (
    get_model, get_optimizer, get_scheduler,
)

from .log_service import print_log, distributed_log_manager

from .evaluator import get_evaluator
from . import sync


def _unwrap_parallel(net):
    """Return the wrapped module if *net* is a (Distributed)DataParallel
    wrapper, otherwise *net* itself."""
    if isinstance(net, (torch.nn.DataParallel,
                        torch.nn.parallel.DistributedDataParallel)):
        return net.module
    return net


class train_stage(object):
    """
    Template for a training stage (it can be used for train, test or any
    other step-driven loop).

    A stage usually consumes one dataloader, one model, one optimizer and one
    scheduler, passed as keyword arguments to ``__call__``, but it is not
    limited to these. Subclasses must implement ``main`` (one step on one
    batch) and may set ``nested_eval_stage`` to a callable that is run
    periodically for evaluation.
    """

    def __init__(self):
        # Callable invoked for periodic evaluation; None disables evaluation.
        self.nested_eval_stage = None
        # Best evaluation result seen so far (compared with `>`).
        self.rv_keep = None

    def is_better(self, x):
        """Return True if *x* beats the best result kept so far (or none is kept)."""
        return (self.rv_keep is None) or (x > self.rv_keep)

    def set_model(self, net, mode):
        """Switch *net* to 'train' or 'eval' mode and return it.

        Raises:
            ValueError: if *mode* is neither 'train' nor 'eval'.
        """
        if mode == 'train':
            return net.train()
        elif mode == 'eval':
            return net.eval()
        else:
            raise ValueError('Unknown mode: {}'.format(mode))

    @staticmethod
    def _batch_size(batch, default_bs):
        """Per-GPU batch size of *batch*.

        The first element of the batch (usually the image) may be a tensor
        or a list of tensors. A None batch (no real dataloader) counts as
        *default_bs*.
        """
        if batch is None:
            return default_bs
        if isinstance(batch[0], list):
            return len(batch[0])
        return batch[0].shape[0]

    @staticmethod
    def _crossed(prev, nxt, every):
        """True when moving from *prev* to *nxt* crosses a multiple of *every*."""
        return (prev // every) != (nxt // every)

    @staticmethod
    def _make_step(epochn, itern, samplen, step_type):
        """Build the step-description dict passed to ``save`` (resume info)."""
        return {
            'epochn': epochn,
            'itern': itern,
            'samplen': samplen,
            'type': step_type,
        }

    def __call__(self, **paras):
        """Run the training loop.

        Config keys read from ``cfg.train``: step_type ('epoch' | 'iter' |
        'sample', default 'iter'), step_num (required), gradacc_every,
        log_every, ckpt_every, eval_start, eval_every, log_dir,
        save_init_model, skip_partial_batch, batch_size_per_gpu.

        Keyword args (``paras``) must contain 'trainloader', 'optimizer',
        'scheduler' and 'net'. An optional 'resume_step' dict (keys 'type',
        'epochn', 'itern', 'samplen') restores the counters and is removed
        from ``paras`` before ``main`` is called. All remaining ``paras`` are
        forwarded to ``main`` and to ``nested_eval_stage``; the dict returned
        by ``main`` is merged back into ``paras`` and must contain 'log_info'.

        Returns:
            An empty dict.

        Raises:
            ValueError: if ``cfg.train.step_num`` is not set.
        """
        cfg = cfguh().cfg
        cfgt = cfg.train
        logm = distributed_log_manager()
        epochn, itern_local, itern, samplen = 0, 0, 0, 0

        step_type = cfgt.get('step_type', 'iter')
        assert step_type in ['epoch', 'iter', 'sample'], \
            'Step type must be in [epoch, iter, sample]'

        step_num      = cfgt.get('step_num'     , None)
        gradacc_every = cfgt.get('gradacc_every', 1   )
        log_every     = cfgt.get('log_every'    , None)
        ckpt_every    = cfgt.get('ckpt_every'   , None)
        eval_start    = cfgt.get('eval_start'   , 0   )
        eval_every    = cfgt.get('eval_every'   , None)

        # Without step_num the loop has no termination condition (the old
        # code crashed with a TypeError on `itern >= None`); fail fast.
        if step_num is None:
            raise ValueError('cfg.train.step_num must be set.')

        if paras.get('resume_step', None) is not None:
            resume_step = paras['resume_step']
            assert step_type == resume_step['type']
            epochn = resume_step['epochn']
            itern = resume_step['itern']
            # itern counts optimizer updates; itern_local counts batches.
            itern_local = itern * gradacc_every
            samplen = resume_step['samplen']
            del paras['resume_step']

        trainloader = paras['trainloader']
        if trainloader is None:
            # No data: feed an endless stream of None batches.
            trainloader = itertools.cycle([None])
        optimizer = paras['optimizer']
        scheduler = paras['scheduler']
        net = paras['net']

        GRANK, LRANK, NRANK = sync.get_rank('all')
        GWSIZE, LWSIZE, NODES = sync.get_world_size('all')

        weight_path = osp.join(cfgt.log_dir, 'weight')
        if (GRANK == 0) and (not osp.isdir(weight_path)):
            os.makedirs(weight_path, exist_ok=True)
        if (GRANK == 0) and (cfgt.save_init_model):
            self.save(net, is_init=True, step=0, optimizer=optimizer)

        epoch_time = timeit.default_timer()
        end_flag = False
        net.train()

        while True:
            if step_type == 'epoch':
                lr = scheduler[epochn] if scheduler is not None else None

            for batch in trainloader:
                bs = self._batch_size(batch, cfgt.batch_size_per_gpu)
                if cfgt.skip_partial_batch and (bs != cfgt.batch_size_per_gpu):
                    continue

                itern_local_next = itern_local + 1
                samplen_next = samplen + bs * GWSIZE

                # Gradient accumulation: flag every `gradacc_every`-th local
                # batch. Passed to `main` as `grad_update`; presumably `main`
                # only steps the optimizer when it is True.
                # (Previously only set for step_type == 'iter', so 'epoch'
                # and 'sample' crashed with a NameError.)
                grad_update = itern_local % gradacc_every == (gradacc_every - 1)
                if step_type == 'iter':
                    lr = scheduler[itern] if scheduler is not None else None
                elif step_type == 'sample':
                    lr = scheduler[samplen] if scheduler is not None else None

                itern_next = itern + 1 if grad_update else itern

                paras_new = self.main(
                    batch=batch,
                    lr=lr,
                    itern_local=itern_local,
                    itern=itern,
                    epochn=epochn,
                    samplen=samplen,
                    isinit=False,
                    grad_update=grad_update,
                    **paras)

                paras.update(paras_new)
                logm.accumulate(bs, **paras['log_info'])

                #######
                # log #
                #######

                display_flag = False
                if log_every is not None:
                    if step_type == 'iter':
                        display_flag = self._crossed(itern, itern_next, log_every)
                    elif step_type == 'sample':
                        display_flag = self._crossed(samplen, samplen_next, log_every)

                if display_flag:
                    tbstep = itern_next if step_type == 'iter' else samplen_next
                    console_info = logm.train_summary(
                        itern_next, epochn, samplen_next, lr, tbstep=tbstep)
                    logm.clear()
                    print_log(console_info)

                ########
                # eval #
                ########

                eval_flag = False
                if (self.nested_eval_stage is not None) \
                        and (eval_every is not None) and (NRANK == 0):
                    if step_type == 'iter':
                        eval_flag = self._crossed(itern, itern_next, eval_every) \
                            and (itern_next >= eval_start)
                        # Also evaluate on the very first local iteration.
                        eval_flag = eval_flag or itern_local == 0
                    elif step_type == 'sample':
                        eval_flag = self._crossed(samplen, samplen_next, eval_every) \
                            and (samplen_next >= eval_start)
                        eval_flag = eval_flag or samplen == 0

                if eval_flag:
                    eval_cnt = itern_next if step_type == 'iter' else samplen_next
                    net = self.set_model(net, 'eval')
                    rv = self.nested_eval_stage(eval_cnt=eval_cnt, **paras)
                    rv = rv.get('eval_rv', None)
                    # Only a real result can be logged or become the new best.
                    if rv is not None:
                        logm.tensorboard_log(eval_cnt, rv, mode='eval')
                        if self.is_better(rv):
                            self.rv_keep = rv
                            if GRANK == 0:
                                step = self._make_step(
                                    epochn, itern_next, samplen_next, step_type)
                                self.save(net, is_best=True, step=step,
                                          optimizer=optimizer)
                    net = self.set_model(net, 'train')

                ########
                # ckpt #
                ########

                ckpt_flag = False
                if (GRANK == 0) and (ckpt_every is not None):
                    # not distributed: only global rank 0 writes checkpoints
                    if step_type == 'iter':
                        ckpt_flag = self._crossed(itern, itern_next, ckpt_every)
                    elif step_type == 'sample':
                        ckpt_flag = self._crossed(samplen, samplen_next, ckpt_every)

                if ckpt_flag:
                    step = self._make_step(
                        epochn, itern_next, samplen_next, step_type)
                    if step_type == 'iter':
                        print_log('Checkpoint... {}'.format(itern_next))
                        self.save(net, itern=itern_next, step=step,
                                  optimizer=optimizer)
                    else:
                        print_log('Checkpoint... {}'.format(samplen_next))
                        self.save(net, samplen=samplen_next, step=step,
                                  optimizer=optimizer)

                #######
                # end #
                #######

                itern_local = itern_local_next
                itern = itern_next
                samplen = samplen_next

                end_flag = (step_type == 'iter' and itern >= step_num) \
                    or (step_type == 'sample' and samplen >= step_num)
                if end_flag:
                    break
            # loop end

            epochn += 1
            print_log('Epoch {} time:{:.2f}s.'.format(
                epochn, timeit.default_timer() - epoch_time))
            epoch_time = timeit.default_timer()

            if end_flag:
                break
            elif step_type != 'epoch':
                # This is temporarily added to resolve the data issue
                trainloader = self.trick_update_trainloader(trainloader)
                continue

            # Everything below only runs when step_type == 'epoch'.
            # `itern` is used instead of `itern_next`: it holds the same value
            # whenever `itern_next` exists, and is always bound.

            #######
            # log #
            #######

            display_flag = False
            if (log_every is not None) and (step_type == 'epoch'):
                display_flag = (epochn == 1) or (epochn % log_every == 0)

            if display_flag:
                console_info = logm.train_summary(
                    itern, epochn, samplen, lr, tbstep=epochn)
                logm.clear()
                print_log(console_info)

            ########
            # eval #
            ########

            eval_flag = False
            if (self.nested_eval_stage is not None) and (eval_every is not None) \
                    and (step_type == 'epoch') and (NRANK == 0):
                eval_flag = (epochn % eval_every == 0) and (itern >= eval_start)
                eval_flag = (epochn == 1) or eval_flag

            if eval_flag:
                net = self.set_model(net, 'eval')
                rv = self.nested_eval_stage(
                    eval_cnt=epochn, **paras).get('eval_rv', None)
                if rv is not None:
                    logm.tensorboard_log(epochn, rv, mode='eval')
                    if self.is_better(rv):
                        self.rv_keep = rv
                        if GRANK == 0:
                            step = self._make_step(
                                epochn, itern, samplen, step_type)
                            self.save(net, is_best=True, step=step,
                                      optimizer=optimizer)
                net = self.set_model(net, 'train')

            ########
            # ckpt #
            ########

            ckpt_flag = False
            if (ckpt_every is not None) and (GRANK == 0) and (step_type == 'epoch'):
                # not distributed: only global rank 0 writes checkpoints
                ckpt_flag = epochn % ckpt_every == 0

            if ckpt_flag:
                print_log('Checkpoint... {}'.format(itern))
                step = self._make_step(epochn, itern, samplen, step_type)
                self.save(net, epochn=epochn, step=step, optimizer=optimizer)

            #######
            # end #
            #######

            if (step_type == 'epoch') and (epochn >= step_num):
                break
            # loop end

            # This is temporarily added to resolve the data issue
            trainloader = self.trick_update_trainloader(trainloader)

        logm.tensorboard_close()
        return {}

    def main(self, **paras):
        """Run one step on one batch; must return a dict with 'log_info'."""
        raise NotImplementedError

    def trick_update_trainloader(self, trainloader):
        """Hook to replace the trainloader between epochs (identity by default)."""
        return trainloader

    def save_model(self, net, path_noext, **paras):
        """Save the state_dict of *net* (unwrapped from DP/DDP) to ``path_noext + '.pth'``.

        Extra keyword arguments (e.g. ``step``, ``optimizer``) are accepted
        but ignored here; subclasses may override to persist them.
        """
        path = path_noext + '.pth'
        netm = _unwrap_parallel(net)
        torch.save(netm.state_dict(), path)
        print_log('Saving model file {0}'.format(path))

    def save(self, net, itern=None, epochn=None, samplen=None,
             is_init=False, is_best=False, is_last=False, **paras):
        """Save *net* under ``<log_dir>/weight/`` with a tag-derived filename.

        At most one of itern / samplen / epochn / is_init / is_best / is_last
        may be given; it selects the suffix of
        ``<experiment_id>_<model symbol>_<suffix>``. With none given the
        suffix is 'default'. Remaining keyword arguments are forwarded to
        ``save_model``.
        """
        exid = cfguh().cfg.env.experiment_id
        cfgt = cfguh().cfg.train
        cfgm = cfguh().cfg.model
        net_symbol = cfgm.symbol

        check = sum([
            itern is not None, samplen is not None, epochn is not None,
            is_init, is_best, is_last])
        assert check < 2

        if itern is not None:
            path_noexp = '{}_{}_iter_{}'.format(exid, net_symbol, itern)
        elif samplen is not None:
            path_noexp = '{}_{}_samplen_{}'.format(exid, net_symbol, samplen)
        elif epochn is not None:
            path_noexp = '{}_{}_epoch_{}'.format(exid, net_symbol, epochn)
        elif is_init:
            path_noexp = '{}_{}_init'.format(exid, net_symbol)
        elif is_best:
            path_noexp = '{}_{}_best'.format(exid, net_symbol)
        elif is_last:
            path_noexp = '{}_{}_last'.format(exid, net_symbol)
        else:
            path_noexp = '{}_{}_default'.format(exid, net_symbol)

        path_noexp = osp.join(cfgt.log_dir, 'weight', path_noexp)
        self.save_model(net, path_noexp, **paras)


class eval_stage(object):
    """
    Template for an evaluation stage.

    Subclasses implement ``main(batch, net)`` returning a dict that is passed
    as keyword arguments to the evaluator's ``add_batch``. If
    ``cfg.eval.output_result`` is set, ``output_f(**rv[, cnt=...])`` must also
    be provided.
    """

    def __init__(self):
        # Evaluator is created lazily on first call and reused afterwards.
        self.evaluator = None

    def create_dir(self, path):
        """Create *path* on global rank 0, then barrier across the node group."""
        grank = sync.get_rank('global')
        if (not osp.isdir(path)) and (grank == 0):
            os.makedirs(path, exist_ok=True)
        sync.nodewise_sync().barrier()

    def __call__(self, evalloader, net, **paras):
        """Evaluate *net* over *evalloader* and return ``{'eval_rv': result}``.

        ``paras['eval_cnt']``, when present, is forwarded to ``output_f`` as
        ``cnt``.
        """
        cfgt = cfguh().cfg.eval
        local_rank = sync.get_rank('local')
        if self.evaluator is None:
            evaluator = get_evaluator()(cfgt.evaluator)
            self.evaluator = evaluator
        else:
            evaluator = self.evaluator

        time_check = timeit.default_timer()

        for idx, batch in enumerate(evalloader):
            rv = self.main(batch, net)
            evaluator.add_batch(**rv)
            if cfgt.output_result:
                if 'eval_cnt' in paras:
                    try:
                        self.output_f(**rv, cnt=paras['eval_cnt'])
                    except TypeError:
                        # output_f does not accept a `cnt` argument.
                        self.output_f(**rv)
                else:
                    self.output_f(**rv)
            if idx % cfgt.log_display == cfgt.log_display - 1:
                print_log('processed.. {}, Time:{:.2f}s'.format(
                    idx + 1, timeit.default_timer() - time_check))
                time_check = timeit.default_timer()

        evaluator.set_sample_n(len(evalloader.dataset))
        eval_rv = evaluator.compute()
        if local_rank == 0:
            evaluator.one_line_summary()
            evaluator.save(cfgt.log_dir)
        evaluator.clear_data()
        return {
            'eval_rv': eval_rv
        }


class exec_container(object):
    """
    Base functor for all types of executions.

    One execution may contain multiple stages, but all stages share the same
    config, network and dataloaders. In most cases one exec_container is one
    training / evaluation / demo run.

    When DDP is in use, this functor should be spawned once per local GPU
    (``__call__`` receives the local rank).
    """

    def __init__(self, cfg, **kwargs):
        self.cfg = cfg
        self.registered_stages = []
        self.node_rank = None
        self.local_rank = None
        self.global_rank = None
        self.nodes = None
        self.local_world_size = None
        self.global_world_size = None
        self.nodewise_sync_global_obj = sync.nodewise_sync_global()

    def register_stage(self, stage):
        """Append a stage callable; stages run in registration order."""
        self.registered_stages.append(stage)

    def __call__(self, local_rank, **kwargs):
        """Initialise distributed state for *local_rank* and run all stages.

        Each stage is called with the accumulated parameter dict; a non-None
        dict it returns is merged into that dict for later stages. The
        process group is destroyed even if a stage raises.
        """
        cfg = self.cfg
        cfguh().save_cfg(cfg)
        self.node_rank = cfg.env.node_rank
        self.local_rank = local_rank
        self.nodes = cfg.env.nodes
        self.local_world_size = cfg.env.gpu_count
        self.global_rank = self.local_rank + self.node_rank * self.local_world_size
        self.global_world_size = self.nodes * self.local_world_size

        print('init {}/{}'.format(self.global_rank, self.global_world_size))
        dist.init_process_group(
            backend=cfg.env.dist_backend,
            init_method=cfg.env.dist_url,
            rank=self.global_rank,
            world_size=self.global_world_size,)
        try:
            torch.cuda.set_device(local_rank)
            sync.nodewise_sync().copy_global(self.nodewise_sync_global_obj).local_init()

            if isinstance(cfg.env.rnd_seed, int):
                # Offsets keep the python / numpy / torch streams distinct
                # and different on every rank.
                random.seed(cfg.env.rnd_seed + self.global_rank + 200)
                np.random.seed(cfg.env.rnd_seed + self.global_rank + 100)
                torch.manual_seed(cfg.env.rnd_seed + self.global_rank)

            time_start = timeit.default_timer()

            para = {'itern_total': 0,}
            dl_para = self.prepare_dataloader()
            assert isinstance(dl_para, dict)
            para.update(dl_para)

            md_para = self.prepare_model()
            assert isinstance(md_para, dict)
            para.update(md_para)

            for stage in self.registered_stages:
                stage_para = stage(**para)
                if stage_para is not None:
                    para.update(stage_para)

            if self.global_rank == 0:
                self.save_last_model(**para)

            print_log(
                'Total {:.2f} seconds'.format(timeit.default_timer() - time_start))
        finally:
            dist.destroy_process_group()

    def prepare_dataloader(self):
        """
        Prepare the dataloader from config.
        """
        return {
            'trainloader': None,
            'evalloader': None}

    def prepare_model(self):
        """
        Prepare the model from config.
        """
        return {'net': None}

    def save_last_model(self, **para):
        """Hook to save the final model (no-op by default)."""
        return

    def destroy(self):
        """Release the node-wise synchronisation object."""
        self.nodewise_sync_global_obj.destroy()


class train(exec_container):
    """exec_container that builds train/eval loaders, a DDP model, optimizer and scheduler."""

    def prepare_dataloader(self):
        """Build the training loader and, if ``cfg.eval`` exists, the eval loader."""
        cfg = cfguh().cfg
        trainset = get_dataset()(cfg.train.dataset)
        trainloader = None
        if trainset is not None:
            sampler = get_sampler()(
                dataset=trainset,
                cfg=cfg.train.dataset.get('sampler', 'default_train'))
            trainloader = torch.utils.data.DataLoader(
                trainset,
                batch_size=cfg.train.batch_size_per_gpu,
                sampler=sampler,
                num_workers=cfg.train.dataset_num_workers_per_gpu,
                drop_last=False,
                pin_memory=cfg.train.dataset.get('pin_memory', False),
                collate_fn=collate(),)

        evalloader = None
        if 'eval' in cfg:
            evalset = get_dataset()(cfg.eval.dataset)
            if evalset is not None:
                sampler = get_sampler()(
                    dataset=evalset,
                    cfg=cfg.eval.dataset.get('sampler', 'default_eval'))
                evalloader = torch.utils.data.DataLoader(
                    evalset,
                    batch_size=cfg.eval.batch_size_per_gpu,
                    sampler=sampler,
                    num_workers=cfg.eval.dataset_num_workers_per_gpu,
                    drop_last=False,
                    pin_memory=cfg.eval.dataset.get('pin_memory', False),
                    collate_fn=collate(),)

        return {
            'trainloader': trainloader,
            'evalloader': evalloader,}

    def prepare_model(self):
        """Build the model (DDP-wrapped when ``cfg.env.cuda``), scheduler and optimizer."""
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        find_unused_parameters = cfg.model.get('find_unused_parameters', False)
        if cfg.env.cuda:
            net.to(self.local_rank)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[self.local_rank],
                find_unused_parameters=find_unused_parameters)
        net.train()
        scheduler = get_scheduler()(cfg.train.scheduler)
        optimizer = get_optimizer()(net, cfg.train.optimizer)
        return {
            'net': net,
            'optimizer': optimizer,
            'scheduler': scheduler,}

    def save_last_model(self, **para):
        """Save ``para['net']`` as ``<log_dir>/<experiment_id>_<symbol>_last.pth``."""
        cfgt = cfguh().cfg.train
        net = para['net']
        net_symbol = cfguh().cfg.model.symbol
        netm = _unwrap_parallel(net)
        # NOTE(review): train_stage.save reads cfg.env.experiment_id while
        # this reads cfg.train.experiment_id -- confirm both are populated.
        path = osp.join(cfgt.log_dir, '{}_{}_last.pth'.format(
            cfgt.experiment_id, net_symbol))
        torch.save(netm.state_dict(), path)
        print_log('Saving model file {0}'.format(path))


class eval(exec_container):
    """exec_container for evaluation only: eval loader plus a DDP model in eval mode."""

    def prepare_dataloader(self):
        """Build the eval loader when ``cfg.eval.dataset`` is configured.

        Always returns a dict (``evalloader`` is None when no dataset is
        available), as ``exec_container.__call__`` asserts.
        """
        cfg = cfguh().cfg
        evalloader = None
        if cfg.eval.get('dataset', None) is not None:
            evalset = get_dataset()(cfg.eval.dataset)
            if evalset is not None:
                sampler = get_sampler()(
                    dataset=evalset,
                    cfg=cfg.eval.dataset.get('sampler', 'default_eval'))
                evalloader = torch.utils.data.DataLoader(
                    evalset,
                    batch_size=cfg.eval.batch_size_per_gpu,
                    sampler=sampler,
                    num_workers=cfg.eval.dataset_num_workers_per_gpu,
                    drop_last=False,
                    pin_memory=cfg.eval.dataset.get('pin_memory', False),
                    collate_fn=collate(),)
        return {
            'trainloader': None,
            'evalloader': evalloader,}

    def prepare_model(self):
        """Build the model (DDP-wrapped when ``cfg.env.cuda``) in eval mode."""
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        if cfg.env.cuda:
            net.to(self.local_rank)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[self.local_rank],
                find_unused_parameters=True)
        net.eval()
        return {'net': net,}

    def save_last_model(self, **para):
        """Evaluation does not save a model."""
        return

###############
# some helper #
###############


def torch_to_numpy(*argv):
    """Recursively convert tensors to numpy arrays.

    With several positional arguments they are treated as one list. Lists
    and tuples become lists, dicts keep their keys, and any other object is
    returned unchanged.
    """
    if len(argv) > 1:
        data = list(argv)
    else:
        data = argv[0]

    if isinstance(data, torch.Tensor):
        return data.detach().to('cpu').numpy()
    elif isinstance(data, (list, tuple)):
        return [torch_to_numpy(di) for di in data]
    elif isinstance(data, dict):
        return {ni: torch_to_numpy(di) for ni, di in data.items()}
    else:
        return data


def get_obj_from_str(string, reload=False):
    """Return the attribute named by a dotted path like ``'pkg.mod.Name'``.

    If *reload* is True the module is re-imported first.
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)