import os
import time
import json
import datetime
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader
from torchvision import transforms

from dataloaders.train_datasets import DAVIS2017_Train, YOUTUBEVOS_Train, StaticTrain, TEST
import dataloaders.video_transforms as tr

from utils.meters import AverageMeter
from utils.image import label2colormap, masked_image, save_image
from utils.checkpoint import load_network_and_optimizer, load_network, save_network
from utils.learning import adjust_learning_rate, get_trainable_params
from utils.metric import pytorch_iou
from utils.ema import ExponentialMovingAverage, get_param_buffer_for_ema

from networks.models import build_vos_model
from networks.engines import build_engine


class Trainer(object):
    def __init__(self, rank, cfg, enable_amp=True):
        self.gpu = rank + cfg.DIST_START_GPU
        self.gpu_num = cfg.TRAIN_GPUS
        self.rank = rank
        self.cfg = cfg
        self.print_log("Exp {}:".format(cfg.EXP_NAME))
        self.print_log(json.dumps(cfg.__dict__, indent=4, sort_keys=True))

        print("Use GPU {} for training VOS.".format(self.gpu))
        torch.cuda.set_device(self.gpu)
        # Only enable cudnn autotuning when input sizes are fixed, i.e.,
        # square random crops and a non-Swin encoder.
        torch.backends.cudnn.benchmark = (
            cfg.DATA_RANDOMCROP[0] == cfg.DATA_RANDOMCROP[1]
            and 'swin' not in cfg.MODEL_ENCODER)

        self.print_log('Build VOS model.')
        self.model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(self.gpu)
        self.model_encoder = self.model.encoder
        self.engine = build_engine(
            cfg.MODEL_ENGINE,
            'train',
            aot_model=self.model,
            gpu_id=self.gpu,
            long_term_mem_gap=cfg.TRAIN_LONG_TERM_MEM_GAP)

        if cfg.MODEL_FREEZE_BACKBONE:
            for param in self.model_encoder.parameters():
                param.requires_grad = False

        if cfg.DIST_ENABLE:
            dist.init_process_group(backend=cfg.DIST_BACKEND,
                                    init_method=cfg.DIST_URL,
                                    world_size=cfg.TRAIN_GPUS,
                                    rank=rank,
                                    timeout=datetime.timedelta(seconds=300))
            self.model.encoder = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model.encoder).cuda(self.gpu)
            self.dist_engine = torch.nn.parallel.DistributedDataParallel(
                self.engine,
                device_ids=[self.gpu],
                output_device=self.gpu,
                find_unused_parameters=True,
                broadcast_buffers=False)
        else:
            self.dist_engine = self.engine

        self.use_frozen_bn = False
        if 'swin' in cfg.MODEL_ENCODER:
            self.print_log('Use LN in Encoder!')
        elif not cfg.MODEL_FREEZE_BN:
            if cfg.DIST_ENABLE:
                self.print_log('Use Sync BN in Encoder!')
            else:
                self.print_log('Use BN in Encoder!')
        else:
            self.use_frozen_bn = True
            self.print_log('Use Frozen BN in Encoder!')
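        # The EMA decay is tied to the schedule length:
        # decay = 1 - 1 / (total_steps * TRAIN_EMA_RATIO).
        # For example (illustrative numbers), with 100,000 total steps and a
        # ratio of 0.1, decay = 1 - 1/10,000 = 0.9999.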
        if self.rank == 0:
            try:
                total_steps = float(cfg.TRAIN_TOTAL_STEPS)
                ema_decay = 1. - 1. / (total_steps * cfg.TRAIN_EMA_RATIO)
                self.ema_params = get_param_buffer_for_ema(
                    self.model, update_buffer=(not cfg.MODEL_FREEZE_BN))
                self.ema = ExponentialMovingAverage(self.ema_params,
                                                    decay=ema_decay)
                self.ema_dir = cfg.DIR_EMA_CKPT
            except Exception as inst:
                self.print_log(inst)
                self.print_log('Error: failed to create EMA model!')

        self.print_log('Build optimizer.')
        trainable_params = get_trainable_params(
            model=self.dist_engine,
            base_lr=cfg.TRAIN_LR,
            use_frozen_bn=self.use_frozen_bn,
            weight_decay=cfg.TRAIN_WEIGHT_DECAY,
            exclusive_wd_dict=cfg.TRAIN_WEIGHT_DECAY_EXCLUSIVE,
            no_wd_keys=cfg.TRAIN_WEIGHT_DECAY_EXEMPTION)

        if cfg.TRAIN_OPT == 'sgd':
            self.optimizer = optim.SGD(trainable_params,
                                       lr=cfg.TRAIN_LR,
                                       momentum=cfg.TRAIN_SGD_MOMENTUM,
                                       nesterov=True)
        else:
            self.optimizer = optim.AdamW(trainable_params,
                                         lr=cfg.TRAIN_LR,
                                         weight_decay=cfg.TRAIN_WEIGHT_DECAY)

        self.enable_amp = enable_amp
        if enable_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        self.prepare_dataset()
        self.process_pretrained_model()

        if cfg.TRAIN_TBLOG and self.rank == 0:
            from tensorboardX import SummaryWriter
            self.tblogger = SummaryWriter(cfg.DIR_TB_LOG)

    def process_pretrained_model(self):
        cfg = self.cfg

        self.step = cfg.TRAIN_START_STEP
        self.epoch = 0

        if cfg.TRAIN_AUTO_RESUME:
            ckpts = os.listdir(cfg.DIR_CKPT)
            if len(ckpts) > 0:
                ckpts = list(
                    map(lambda x: int(x.split('_')[-1].split('.')[0]), ckpts))
                ckpt = np.sort(ckpts)[-1]
                cfg.TRAIN_RESUME = True
                cfg.TRAIN_RESUME_CKPT = ckpt
                cfg.TRAIN_RESUME_STEP = ckpt
            else:
                cfg.TRAIN_RESUME = False

        if cfg.TRAIN_RESUME:
            if self.rank == 0:
                try:
                    try:
                        ema_ckpt_dir = os.path.join(
                            self.ema_dir,
                            'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT))
                        ema_model, removed_dict = load_network(
                            self.model, ema_ckpt_dir, self.gpu)
                    except Exception as inst:
                        self.print_log(inst)
                        self.print_log('Try to use backup EMA checkpoint.')
                        DIR_RESULT = './backup/{}/{}'.format(
                            cfg.EXP_NAME, cfg.STAGE_NAME)
                        DIR_EMA_CKPT = os.path.join(DIR_RESULT, 'ema_ckpt')
                        ema_ckpt_dir = os.path.join(
                            DIR_EMA_CKPT,
                            'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT))
                        ema_model, removed_dict = load_network(
                            self.model, ema_ckpt_dir, self.gpu)

                    if len(removed_dict) > 0:
                        self.print_log(
                            'Remove {} from EMA model.'.format(removed_dict))

                    ema_decay = self.ema.decay
                    del self.ema
                    ema_params = get_param_buffer_for_ema(
                        ema_model, update_buffer=(not cfg.MODEL_FREEZE_BN))
                    self.ema = ExponentialMovingAverage(ema_params,
                                                        decay=ema_decay)
                    self.ema.num_updates = cfg.TRAIN_RESUME_CKPT
                except Exception as inst:
                    self.print_log(inst)
                    self.print_log('Error: EMA model not found!')

            try:
                resume_ckpt = os.path.join(
                    cfg.DIR_CKPT,
                    'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT))
                self.model, self.optimizer, removed_dict = load_network_and_optimizer(
                    self.model,
                    self.optimizer,
                    resume_ckpt,
                    self.gpu,
                    scaler=self.scaler)
            except Exception as inst:
                self.print_log(inst)
                self.print_log('Try to use backup checkpoint.')
                DIR_RESULT = './backup/{}/{}'.format(cfg.EXP_NAME,
                                                     cfg.STAGE_NAME)
                DIR_CKPT = os.path.join(DIR_RESULT, 'ckpt')
                resume_ckpt = os.path.join(
                    DIR_CKPT, 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT))
                self.model, self.optimizer, removed_dict = load_network_and_optimizer(
                    self.model,
                    self.optimizer,
                    resume_ckpt,
                    self.gpu,
                    scaler=self.scaler)

            if len(removed_dict) > 0:
                self.print_log(
                    'Remove {} from checkpoint.'.format(removed_dict))

            self.step = cfg.TRAIN_RESUME_STEP
            if cfg.TRAIN_TOTAL_STEPS <= self.step:
                self.print_log("Your training has finished!")
                exit()
            self.epoch = int(np.ceil(self.step / len(self.train_loader)))

            self.print_log('Resume from step {}'.format(self.step))
        elif cfg.PRETRAIN:
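            # PRETRAIN_FULL restores the entire VOS model from
            # cfg.PRETRAIN_MODEL (e.g. a checkpoint from the pre-training
            # stage); otherwise only the encoder backbone is initialized
            # from it.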
            if cfg.PRETRAIN_FULL:
                try:
                    self.model, removed_dict = load_network(
                        self.model, cfg.PRETRAIN_MODEL, self.gpu)
                except Exception as inst:
                    self.print_log(inst)
                    self.print_log('Try to use backup EMA checkpoint.')
                    DIR_RESULT = './backup/{}/{}'.format(
                        cfg.EXP_NAME, cfg.STAGE_NAME)
                    DIR_EMA_CKPT = os.path.join(DIR_RESULT, 'ema_ckpt')
                    PRETRAIN_MODEL = os.path.join(
                        DIR_EMA_CKPT, cfg.PRETRAIN_MODEL.split('/')[-1])
                    self.model, removed_dict = load_network(
                        self.model, PRETRAIN_MODEL, self.gpu)

                if len(removed_dict) > 0:
                    self.print_log('Remove {} from pretrained model.'.format(
                        removed_dict))
                self.print_log('Load pretrained VOS model from {}.'.format(
                    cfg.PRETRAIN_MODEL))
            else:
                model_encoder, removed_dict = load_network(
                    self.model_encoder, cfg.PRETRAIN_MODEL, self.gpu)
                if len(removed_dict) > 0:
                    self.print_log('Remove {} from pretrained model.'.format(
                        removed_dict))
                self.print_log(
                    'Load pretrained backbone model from {}.'.format(
                        cfg.PRETRAIN_MODEL))

    def prepare_dataset(self):
        cfg = self.cfg
        self.enable_prev_frame = cfg.TRAIN_ENABLE_PREV_FRAME
        self.print_log('Process dataset...')

        if cfg.TRAIN_AUG_TYPE == 'v1':
            composed_transforms = transforms.Compose([
                tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR,
                               cfg.DATA_MAX_SCALE_FACTOR,
                               cfg.DATA_SHORT_EDGE_LEN),
                tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP,
                                      max_obj_num=cfg.MODEL_MAX_OBJ_NUM),
                tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                tr.Resize(cfg.DATA_RANDOMCROP, use_padding=True),
                tr.ToTensor()
            ])
        elif cfg.TRAIN_AUG_TYPE == 'v2':
            composed_transforms = transforms.Compose([
                tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR,
                               cfg.DATA_MAX_SCALE_FACTOR,
                               cfg.DATA_SHORT_EDGE_LEN),
                tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP,
                                      max_obj_num=cfg.MODEL_MAX_OBJ_NUM),
                tr.RandomColorJitter(),
                tr.RandomGrayScale(),
                tr.RandomGaussianBlur(),
                tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                tr.Resize(cfg.DATA_RANDOMCROP, use_padding=True),
                tr.ToTensor()
            ])
        else:
            raise NotImplementedError(
                'Unknown TRAIN_AUG_TYPE: {}'.format(cfg.TRAIN_AUG_TYPE))

        train_datasets = []
        if 'static' in cfg.DATASETS:
            pretrain_vos_dataset = StaticTrain(
                cfg.DIR_STATIC,
                cfg.DATA_RANDOMCROP,
                seq_len=cfg.DATA_SEQ_LEN,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM,
                aug_type=cfg.TRAIN_AUG_TYPE)
            train_datasets.append(pretrain_vos_dataset)
            self.enable_prev_frame = False

        if 'davis2017' in cfg.DATASETS:
            train_davis_dataset = DAVIS2017_Train(
                root=cfg.DIR_DAVIS,
                full_resolution=cfg.TRAIN_DATASET_FULL_RESOLUTION,
                transform=composed_transforms,
                repeat_time=cfg.DATA_DAVIS_REPEAT,
                seq_len=cfg.DATA_SEQ_LEN,
                rand_gap=cfg.DATA_RANDOM_GAP_DAVIS,
                rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                enable_prev_frame=self.enable_prev_frame,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM)
            train_datasets.append(train_davis_dataset)

        if 'youtubevos' in cfg.DATASETS:
            train_ytb_dataset = YOUTUBEVOS_Train(
                root=cfg.DIR_YTB,
                transform=composed_transforms,
                seq_len=cfg.DATA_SEQ_LEN,
                rand_gap=cfg.DATA_RANDOM_GAP_YTB,
                rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                enable_prev_frame=self.enable_prev_frame,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM)
            train_datasets.append(train_ytb_dataset)

        if 'test' in cfg.DATASETS:
            test_dataset = TEST(transform=composed_transforms,
                                seq_len=cfg.DATA_SEQ_LEN)
            train_datasets.append(test_dataset)

        if len(train_datasets) > 1:
            train_dataset = torch.utils.data.ConcatDataset(train_datasets)
        elif len(train_datasets) == 1:
            train_dataset = train_datasets[0]
        else:
            self.print_log('No dataset!')
            exit(0)
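        # TRAIN_BATCH_SIZE is the global batch size; each of the TRAIN_GPUS
        # processes loads TRAIN_BATCH_SIZE / TRAIN_GPUS samples per step
        # (e.g., illustratively, 16 / 4 = 4 per GPU). Under DIST_ENABLE the
        # DistributedSampler shards the data, so shuffling is left to it.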
        self.train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset) if self.cfg.DIST_ENABLE else None
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=int(cfg.TRAIN_BATCH_SIZE / cfg.TRAIN_GPUS),
            shuffle=False if self.cfg.DIST_ENABLE else True,
            num_workers=cfg.DATA_WORKERS,
            pin_memory=True,
            sampler=self.train_sampler,
            drop_last=True,
            prefetch_factor=4)

        self.print_log('Done!')

    def sequential_training(self):
        cfg = self.cfg

        if self.enable_prev_frame:
            frame_names = ['Ref', 'Prev']
        else:
            frame_names = ['Ref(Prev)']

        for i in range(cfg.DATA_SEQ_LEN - 1):
            frame_names.append('Curr{}'.format(i + 1))

        seq_len = len(frame_names)

        running_losses = []
        running_ious = []
        for _ in range(seq_len):
            running_losses.append(AverageMeter())
            running_ious.append(AverageMeter())
        batch_time = AverageMeter()
        avg_obj = AverageMeter()

        optimizer = self.optimizer
        model = self.dist_engine
        train_sampler = self.train_sampler
        train_loader = self.train_loader

        step = self.step
        epoch = self.epoch
        max_itr = cfg.TRAIN_TOTAL_STEPS
        start_seq_training_step = int(cfg.TRAIN_SEQ_TRAINING_START_RATIO *
                                      max_itr)
        use_prev_prob = cfg.MODEL_USE_PREV_PROB

        self.print_log('Start training:')
        model.train()
        while step < cfg.TRAIN_TOTAL_STEPS:
            if self.cfg.DIST_ENABLE:
                train_sampler.set_epoch(epoch)
            epoch += 1
            last_time = time.time()
            for frame_idx, sample in enumerate(train_loader):
                if step > cfg.TRAIN_TOTAL_STEPS:
                    break

                tf_board = (step % cfg.TRAIN_TBLOG_STEP == 0
                            and self.rank == 0 and cfg.TRAIN_TBLOG)

                if step >= start_seq_training_step:
                    use_prev_pred = True
                    freeze_params = cfg.TRAIN_SEQ_TRAINING_FREEZE_PARAMS
                else:
                    use_prev_pred = False
                    freeze_params = []

                if step % cfg.TRAIN_LR_UPDATE_STEP == 0:
                    now_lr = adjust_learning_rate(
                        optimizer=optimizer,
                        base_lr=cfg.TRAIN_LR,
                        p=cfg.TRAIN_LR_POWER,
                        itr=step,
                        max_itr=max_itr,
                        restart=cfg.TRAIN_LR_RESTART,
                        warm_up_steps=cfg.TRAIN_LR_WARM_UP_RATIO * max_itr,
                        is_cosine_decay=cfg.TRAIN_LR_COSINE_DECAY,
                        min_lr=cfg.TRAIN_LR_MIN,
                        encoder_lr_ratio=cfg.TRAIN_LR_ENCODER_RATIO,
                        freeze_params=freeze_params)

                ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                prev_imgs = sample['prev_img']
                curr_imgs = sample['curr_img']
                ref_labels = sample['ref_label']  # batch_size * 1 * h * w
                prev_labels = sample['prev_label']
                curr_labels = sample['curr_label']
                obj_nums = sample['meta']['obj_num']

                bs, _, h, w = curr_imgs[0].size()

                ref_imgs = ref_imgs.cuda(self.gpu, non_blocking=True)
                prev_imgs = prev_imgs.cuda(self.gpu, non_blocking=True)
                curr_imgs = [
                    curr_img.cuda(self.gpu, non_blocking=True)
                    for curr_img in curr_imgs
                ]
                ref_labels = ref_labels.cuda(self.gpu, non_blocking=True)
                prev_labels = prev_labels.cuda(self.gpu, non_blocking=True)
                curr_labels = [
                    curr_label.cuda(self.gpu, non_blocking=True)
                    for curr_label in curr_labels
                ]
                obj_nums = list(obj_nums)
                obj_nums = [int(obj_num) for obj_num in obj_nums]

                batch_size = ref_imgs.size(0)

                all_frames = torch.cat([ref_imgs, prev_imgs] + curr_imgs,
                                       dim=0)
                all_labels = torch.cat([ref_labels, prev_labels] +
                                       curr_labels,
                                       dim=0)

                self.engine.restart_engine(batch_size, True)
                optimizer.zero_grad(set_to_none=True)
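                # AMP ordering: scale -> backward -> unscale -> clip -> step
                # -> update, so gradient clipping sees unscaled gradients and
                # the norm threshold keeps its usual meaning.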
                if self.enable_amp:
                    with torch.cuda.amp.autocast(enabled=True):
                        loss, all_pred, all_loss, boards = model(
                            all_frames,
                            all_labels,
                            batch_size,
                            use_prev_pred=use_prev_pred,
                            obj_nums=obj_nums,
                            step=step,
                            tf_board=tf_board,
                            enable_prev_frame=self.enable_prev_frame,
                            use_prev_prob=use_prev_prob)
                        loss = torch.mean(loss)

                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   cfg.TRAIN_CLIP_GRAD_NORM)
                    self.scaler.step(optimizer)
                    self.scaler.update()
                else:
                    loss, all_pred, all_loss, boards = model(
                        all_frames,
                        all_labels,
                        batch_size,
                        use_prev_pred=use_prev_pred,
                        obj_nums=obj_nums,
                        step=step,
                        tf_board=tf_board,
                        enable_prev_frame=self.enable_prev_frame,
                        use_prev_prob=use_prev_prob)
                    loss = torch.mean(loss)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   cfg.TRAIN_CLIP_GRAD_NORM)
                    optimizer.step()

                for idx in range(seq_len):
                    now_pred = all_pred[idx].detach()
                    now_label = all_labels[idx * bs:(idx + 1) * bs].detach()
                    now_loss = torch.mean(all_loss[idx].detach())
                    now_iou = pytorch_iou(now_pred.unsqueeze(1), now_label,
                                          obj_nums) * 100
                    if self.cfg.DIST_ENABLE:
                        dist.all_reduce(now_loss)
                        dist.all_reduce(now_iou)
                        now_loss /= self.gpu_num
                        now_iou /= self.gpu_num
                    if self.rank == 0:
                        running_losses[idx].update(now_loss.item())
                        running_ious[idx].update(now_iou.item())

                if self.rank == 0:
                    self.ema.update(self.ema_params)
                    avg_obj.update(sum(obj_nums) / float(len(obj_nums)))

                    curr_time = time.time()
                    batch_time.update(curr_time - last_time)
                    last_time = curr_time

                    if step % cfg.TRAIN_TBLOG_STEP == 0:
                        all_f = [ref_imgs, prev_imgs] + curr_imgs
                        self.process_log(ref_imgs, all_f[-2], all_f[-1],
                                         ref_labels, all_pred[-2], now_label,
                                         now_pred, boards, running_losses,
                                         running_ious, now_lr, step)

                    if step % cfg.TRAIN_LOG_STEP == 0:
                        strs = 'I:{}, LR:{:.5f}, T:{:.1f}({:.1f})s, Obj:{:.1f}({:.1f})'.format(
                            step, now_lr, batch_time.val,
                            batch_time.moving_avg, avg_obj.val,
                            avg_obj.moving_avg)
                        batch_time.reset()
                        avg_obj.reset()
                        for idx in range(seq_len):
                            strs += ', {}: L {:.3f}({:.3f}) IoU {:.1f}({:.1f})%'.format(
                                frame_names[idx], running_losses[idx].val,
                                running_losses[idx].moving_avg,
                                running_ious[idx].val,
                                running_ious[idx].moving_avg)
                            running_losses[idx].reset()
                            running_ious[idx].reset()
                        self.print_log(strs)

                step += 1

                if step % cfg.TRAIN_SAVE_STEP == 0 and self.rank == 0:
                    max_mem = torch.cuda.max_memory_allocated(
                        device=self.gpu) / (1024.**3)
                    ETA = str(
                        datetime.timedelta(
                            seconds=int(batch_time.moving_avg *
                                        (cfg.TRAIN_TOTAL_STEPS - step))))
                    self.print_log('ETA: {}, Max Mem: {:.2f}G.'.format(
                        ETA, max_mem))
                    self.print_log('Save CKPT (Step {}).'.format(step))
                    save_network(self.model,
                                 optimizer,
                                 step,
                                 cfg.DIR_CKPT,
                                 cfg.TRAIN_MAX_KEEP_CKPT,
                                 backup_dir='./backup/{}/{}/ckpt'.format(
                                     cfg.EXP_NAME, cfg.STAGE_NAME),
                                 scaler=self.scaler)
                    try:
                        torch.cuda.empty_cache()
                        # First save the original parameters before replacing
                        # them with the EMA version.
                        self.ema.store(self.ema_params)
                        # Copy EMA parameters to the model.
                        self.ema.copy_to(self.ema_params)
                        # Save the EMA model.
                        save_network(
                            self.model,
                            optimizer,
                            step,
                            self.ema_dir,
                            cfg.TRAIN_MAX_KEEP_CKPT,
                            backup_dir='./backup/{}/{}/ema_ckpt'.format(
                                cfg.EXP_NAME, cfg.STAGE_NAME),
                            scaler=self.scaler)
                        # Restore the original parameters to resume training.
                        self.ema.restore(self.ema_params)
                    except Exception as inst:
                        self.print_log(inst)
                        self.print_log('Error: failed to save EMA model!')

        self.print_log('Stop training!')

    def print_log(self, string):
        if self.rank == 0:
            print(string)

    def process_log(self, ref_imgs, prev_imgs, curr_imgs, ref_labels,
                    prev_labels, curr_labels, curr_pred, boards,
                    running_losses, running_ious, now_lr, step):
        cfg = self.cfg
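        # Undo the ImageNet normalization (per-channel mean/std) on the first
        # sample of the batch so the logged images are viewable RGB.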
        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

        show_ref_img, show_prev_img, show_curr_img = [
            img.cpu().numpy()[0] * sigma + mean
            for img in [ref_imgs, prev_imgs, curr_imgs]
        ]

        show_gt, show_prev_gt, show_ref_gt, show_preds_s = [
            label.cpu()[0].squeeze(0).numpy()
            for label in [curr_labels, prev_labels, ref_labels, curr_pred]
        ]
        show_gtf, show_prev_gtf, show_ref_gtf, show_preds_sf = [
            label2colormap(label).transpose((2, 0, 1))
            for label in [show_gt, show_prev_gt, show_ref_gt, show_preds_s]
        ]

        if cfg.TRAIN_IMG_LOG or cfg.TRAIN_TBLOG:
            show_ref_img = masked_image(show_ref_img, show_ref_gtf,
                                        show_ref_gt)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_ref_img,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_ref_img.jpeg' % (step)))

            show_prev_img = masked_image(show_prev_img, show_prev_gtf,
                                         show_prev_gt)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_prev_img,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_prev_img.jpeg' % (step)))

            show_img_pred = masked_image(show_curr_img, show_preds_sf,
                                         show_preds_s)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_img_pred,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_prediction.jpeg' % (step)))

            show_curr_img = masked_image(show_curr_img, show_gtf, show_gt)
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    show_curr_img,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_groundtruth.jpeg' % (step)))

        if cfg.TRAIN_TBLOG:
            for seq_step, running_loss, running_iou in zip(
                    range(len(running_losses)), running_losses,
                    running_ious):
                self.tblogger.add_scalar('S{}/Loss'.format(seq_step),
                                         running_loss.avg, step)
                self.tblogger.add_scalar('S{}/IoU'.format(seq_step),
                                         running_iou.avg, step)
            self.tblogger.add_scalar('LR', now_lr, step)

            self.tblogger.add_image('Ref/Image', show_ref_img, step)
            self.tblogger.add_image('Ref/GT', show_ref_gtf, step)

            self.tblogger.add_image('Prev/Image', show_prev_img, step)
            self.tblogger.add_image('Prev/GT', show_prev_gtf, step)

            self.tblogger.add_image('Curr/Image_GT', show_curr_img, step)
            self.tblogger.add_image('Curr/Image_Pred', show_img_pred, step)

            self.tblogger.add_image('Curr/Mask_GT', show_gtf, step)
            self.tblogger.add_image('Curr/Mask_Pred', show_preds_sf, step)

            for key in boards['image'].keys():
                tmp = boards['image'][key].cpu().numpy()
                self.tblogger.add_image('S{}/' + key, tmp, step)
            for key in boards['scalar'].keys():
                tmp = boards['scalar'][key].cpu().numpy()
                self.tblogger.add_scalar('S{}/' + key, tmp, step)

            self.tblogger.flush()

        del boards
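# ---------------------------------------------------------------------------
# Minimal launch sketch (illustrative, not part of this module): the repo's
# real entry point builds `cfg` from its config files, and the `main_worker`
# name below is an assumption for the example. With DIST_ENABLE, one Trainer
# is spawned per GPU rank and each process runs the loop on its data shard.
#
# import torch.multiprocessing as mp
#
# def main_worker(rank, cfg, enable_amp=True):
#     trainer = Trainer(rank, cfg, enable_amp=enable_amp)
#     trainer.sequential_training()
#
# if __name__ == '__main__':
#     cfg = ...  # build the experiment config (e.g. from configs/)
#     mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg,))
# ---------------------------------------------------------------------------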