# NOTE(review): haienv.set_env presumably selects the runtime environment before
# the remaining imports are resolved -- keep it first.
import haienv
haienv.set_env('lavt2')

import torch.multiprocessing as mp
import torch.distributed as dist
import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn

from functools import reduce
import operator
from bert.modeling_bert import BertModel

import torchvision
from lib import segmentation

import transforms as T
import utils
import numpy as np

import torch.nn.functional as F

import gc
from collections import OrderedDict
import torch.backends.cudnn as cudnn
from ffrecord.torch import DataLoader, Dataset


def get_dataset(image_set, transform, args):
    """Build the ``ReferDataset`` for split ``image_set``.

    Args:
        image_set: split name passed through as ``split`` (e.g. "train", "val").
        transform: image transform applied by the dataset.
        args: parsed command-line arguments, forwarded to ``ReferDataset``.

    Returns:
        ``(dataset, num_classes)``; ``num_classes`` is always 2.
    """
    from data.dataset_refer_bert import ReferDataset
    ds = ReferDataset(args,
                      split=image_set,
                      image_transforms=transform,
                      target_transforms=None)
    num_classes = 2
    return ds, num_classes


def IoU(pred, gt):
    """IoU between the argmax of ``pred`` over dim 1 and ``gt``.

    Args:
        pred: model logits with the class axis at dim 1.
        gt: target mask -- assumed binary {0, 1} and shaped like
            ``pred.argmax(1)``; TODO confirm.

    Returns:
        ``(iou, intersection, union)``. ``iou`` is a Python number (0 when the
        intersection or the union is empty). ``intersection`` and ``union``
        are the summed tensors used to compute it.
    """
    pred = pred.argmax(1)
    intersection = torch.sum(torch.mul(pred, gt))
    union = torch.sum(torch.add(pred, gt)) - intersection
    if intersection == 0 or union == 0:
        iou = 0
    else:
        iou = float(intersection) / float(union)
    return iou, intersection, union


def get_transform(args):
    """Resize to ``args.img_size`` squared, convert to tensor, apply ImageNet normalization."""
    transforms = [T.Resize(args.img_size, args.img_size),
                  T.ToTensor(),
                  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
    return T.Compose(transforms)


def criterion(input, target):
    """Weighted cross-entropy: class 0 weighted 0.9, class 1 weighted 1.1."""
    # Build the weight on the logits' device instead of hard-coding .cuda().
    weight = torch.tensor([0.9, 1.1], dtype=torch.float32, device=input.device)
    return nn.functional.cross_entropy(input, target, weight=weight)


def _forward(model, bert_model, image, sentences, attentions):
    """Run ``model`` on one batch, encoding the text with ``bert_model`` when given.

    When ``bert_model`` is None the raw token ids are handed to ``model`` itself.
    """
    if bert_model is None:
        return model(image, sentences, l_mask=attentions)
    last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
    embedding = last_hidden_states.permute(0, 2, 1)  # (B, 768, N_l) to make Conv1d happy
    attentions = attentions.unsqueeze(dim=-1)  # (B, N_l, 1)
    return model(image, embedding, l_mask=attentions)


def evaluate(model, data_loader, bert_model, head_model=None):
    """Evaluate ``model`` on ``data_loader`` and print IoU / precision statistics.

    Args:
        model: segmentation model, called as ``model(image, text, l_mask=...)``.
        data_loader: yields ``(image, target, sentences, attentions)`` batches.
        bert_model: external text encoder, or None when ``model`` encodes the
            token ids itself.
        head_model: optional extra head. It is only switched to eval mode here;
            it is not part of the forward pass.

    Returns:
        ``(mean_iou, overall_iou)`` as percentages (Python floats). The first is
        the mean of the per-batch IoUs. The second is the cumulative
        intersection divided by the cumulative union.

    Raises:
        RuntimeError: if ``data_loader`` yields no batches.
    """
    model.eval()
    if head_model is not None:
        head_model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    per_batch_ious = []
    cum_I, cum_U = 0, 0

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            image, target, sentences, attentions = (t.cuda(non_blocking=True) for t in data)
            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)
            output = _forward(model, bert_model, image, sentences, attentions)

            iou, I, U = IoU(output, target)
            per_batch_ious.append(iou)
            cum_I += I
            cum_U += U
            for idx, threshold in enumerate(eval_seg_iou_list):
                seg_correct[idx] += int(iou >= threshold)

    seg_total = len(per_batch_ious)
    if seg_total == 0:
        raise RuntimeError('evaluate(): data_loader yielded no batches')

    mIoU = float(np.mean(per_batch_ious))
    # Convert to plain floats so callers print/compare numbers, not CUDA tensors;
    # an empty total union follows IoU()'s convention of scoring 0.
    total_union = float(cum_U)
    overall_iou = float(cum_I) / total_union if total_union > 0 else 0.0

    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''.join(
        ' precision@%s = %.2f\n' % (str(threshold), correct * 100. / seg_total)
        for threshold, correct in zip(eval_seg_iou_list, seg_correct))
    results_str += ' overall IoU = %.2f\n' % (overall_iou * 100.)
    print(results_str)
    return 100 * mIoU, 100 * overall_iou


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, bert_model, head_model=None):
    """Train for one epoch, stepping ``lr_scheduler`` after every optimizer step.

    Args:
        model: segmentation model, called as ``model(image, text, l_mask=...)``.
        criterion: loss function ``criterion(output, target)``.
        optimizer: optimizer over the trainable parameters.
        data_loader: yields ``(image, target, sentences, attentions)`` batches.
        lr_scheduler: per-iteration learning-rate scheduler.
        epoch: epoch index, used only for the log header.
        print_freq: logging interval in iterations.
        iterations: running iteration counter at the start of the epoch.
        bert_model: external text encoder, or None when ``model`` encodes the
            token ids itself.
        head_model: optional extra head. It is only switched to train mode here;
            it is not part of the forward pass.

    Returns:
        The updated iteration counter (``iterations`` plus batches processed).
    """
    model.train()
    if head_model is not None:
        head_model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)

    for data in metric_logger.log_every(data_loader, print_freq, header):
        image, target, sentences, attentions = (t.cuda(non_blocking=True) for t in data)
        sentences = sentences.squeeze(1)
        attentions = attentions.squeeze(1)
        output = _forward(model, bert_model, image, sentences, attentions)

        loss = criterion(output, target)
        optimizer.zero_grad()  # set_to_none=True is only available in pytorch 1.6+
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        torch.cuda.synchronize()

        iterations += 1
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        # Drop references before fetching the next batch to lower peak GPU memory.
        del image, target, sentences, attentions, loss, output, data

    return iterations


def _build_maskformer_head():
    """Construct the unwrapped MaskFormer head over feature levels 's1'-'s4'."""
    # NOTE(review): neither `Dict` (used below as an auto-nesting attribute dict,
    # addict-style) nor `MaskFormerHead` is imported anywhere in this file, so
    # this raises NameError unless they are provided elsewhere -- TODO add the
    # correct imports.
    input_shape = dict()
    input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
    input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
    input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
    input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})

    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
    cfg.MODEL.MASK_FORMER.PRE_NORM = False
    return MaskFormerHead(cfg, input_shape)


def _existing_checkpoints(output_dir, epochs):
    """Return the existing ``checkpoint-{e}.pth`` paths for e in [0, epochs), in epoch order."""
    candidates = (os.path.join(output_dir, f'checkpoint-{e}.pth') for e in range(epochs))
    return [path for path in candidates if os.path.exists(path)]


def _checkpoint_dict(single_model, single_bert_model, single_head, optimizer, lr_scheduler,
                     epoch, args):
    """Collect everything needed to resume training into one dict.

    The same schema is used for periodic and best checkpoints, so the resume
    path can load either.
    """
    state = {'model': single_model.state_dict(),
             'head_model': single_head.state_dict(),
             'optimizer': optimizer.state_dict(),
             'epoch': epoch,
             'args': args,
             'lr_scheduler': lr_scheduler.state_dict()}
    if single_bert_model is not None:
        state['bert_model'] = single_bert_model.state_dict()
    return state


def _save_checkpoint(state, checkpoint_path):
    """Save ``state`` to a ``_TEMP`` file, then move it onto ``checkpoint_path``.

    Only the process for which ``utils.is_main_process()`` is true performs the move.
    """
    tmp_path = str(checkpoint_path) + '_TEMP'
    utils.save_on_master(state, tmp_path)
    if utils.is_main_process():
        # os.replace also overwrites an existing target (e.g. model_best) on Windows.
        os.replace(tmp_path, str(checkpoint_path))


def _prune_old_checkpoints(output_dir, epochs, max_ckpt):
    """Delete all but the newest ``max_ckpt`` periodic checkpoints; non-positive disables pruning."""
    ckpt_paths = _existing_checkpoints(output_dir, epochs)
    print(ckpt_paths)
    if max_ckpt <= 0:
        return
    for ckpt_path in ckpt_paths[:-max_ckpt]:
        os.remove(ckpt_path)
        print("remove {:s}".format(ckpt_path))


def main(local_rank, args):
    """Per-process training entry point launched by ``mp.spawn``.

    Reads MASTER_IP, MASTER_PORT, WORLD_SIZE (number of machines) and RANK
    (index of this machine) from the environment. It then joins an NCCL
    process group with one process per local GPU. Finally it trains and
    evaluates for ``args.epochs`` epochs, writing checkpoints to
    ``args.output_dir``.

    Args:
        local_rank: GPU index of this process on the current machine
            (supplied positionally by ``mp.spawn``).
        args: parsed command-line arguments.
    """
    ip = os.environ['MASTER_IP']
    port = os.environ['MASTER_PORT']
    hosts = int(os.environ['WORLD_SIZE'])  # number of machines
    rank = int(os.environ['RANK'])  # index of this machine
    gpus = torch.cuda.device_count()  # number of GPUs per machine
    num_tasks = hosts * gpus
    global_rank = rank * gpus + local_rank
    print(local_rank, rank, gpus)
    dist.init_process_group(backend='nccl', init_method=f'tcp://{ip}:{port}',
                            world_size=num_tasks, rank=global_rank)
    torch.cuda.set_device(local_rank)
    dist.barrier()

    args.distributed = True
    args.gpu = local_rank
    # mp.spawn does not populate args.local_rank; mirror the spawned rank so the
    # DDP device_ids and the log lines below refer to this process's GPU.
    args.local_rank = local_rank
    print(args)
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset, num_classes = get_dataset("train", get_transform(args=args), args=args)
    dataset_test, _ = get_dataset("val", get_transform(args=args), args=args)
    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    # Not sharded: every rank iterates the full validation split.
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler,
                             num_workers=args.workers, pin_memory=True, drop_last=True)
    data_loader_test = DataLoader(dataset_test, batch_size=1, sampler=test_sampler,
                                  pin_memory=True, num_workers=args.workers)

    # model initialization
    print(args.model)
    model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights, args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], find_unused_parameters=True)
    single_model = model.module

    if args.model != 'lavt_one':
        bert_model = BertModel.from_pretrained(args.ck_bert)
        bert_model.pooler = None  # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel
        bert_model.cuda()
        bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
        bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank])
        single_bert_model = bert_model.module
    else:
        bert_model = None
        single_bert_model = None

    maskformer_head = _build_maskformer_head()
    maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
    maskformer_head.cuda()
    maskformer_head = torch.nn.parallel.DistributedDataParallel(
        maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False)
    single_head = maskformer_head.module
    print(single_head)

    if args.resume == "auto":
        existing = _existing_checkpoints(args.output_dir, args.epochs)
        args.resume = existing[-1] if existing else ""

    # resume training (model weights; optimizer/scheduler are restored below)
    checkpoint = None
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])
        # Earlier versions of this script wrote 'model_best' and 'lavt_one'
        # checkpoints without a head; keep those loadable.
        if 'head_model' in checkpoint:
            single_head.load_state_dict(checkpoint['head_model'])
        else:
            print('checkpoint has no head_model entry; head keeps its fresh initialization')
        if args.model != 'lavt_one':
            single_bert_model.load_state_dict(checkpoint['bert_model'])

    # parameters to optimize
    no_decay_keys = ('norm', 'absolute_pos_embed', 'relative_position_bias_table')
    backbone_no_decay = []
    backbone_decay = []
    for name, param in single_model.backbone.named_parameters():
        if any(key in name for key in no_decay_keys):
            backbone_no_decay.append(param)
        else:
            backbone_decay.append(param)

    text_encoder = single_bert_model if args.model != 'lavt_one' else single_model.text_encoder
    params_to_optimize = [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        # the following are the parameters of bert (first 10 encoder layers only)
        {"params": [p for i in range(10)
                    for p in text_encoder.encoder.layer[i].parameters() if p.requires_grad]},
    ]
    if args.model != 'lavt_one':
        # NOTE(review): the head is only optimized for non-'lavt_one' models; the
        # group layout is kept so existing optimizer state still loads.
        params_to_optimize.append({"params": single_head.parameters()})

    # optimizer
    optimizer = torch.optim.AdamW(params_to_optimize,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad)

    # learning rate scheduler: polynomial decay (power 0.9) over all iterations
    total_steps = len(data_loader) * args.epochs
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lambda x: (1 - x / total_steps) ** 0.9)

    # housekeeping
    start_time = time.time()
    iterations = 0
    best_oIoU = -0.1

    # resume training (optimizer, lr scheduler, and the epoch)
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        resume_epoch = checkpoint['epoch']
    else:
        resume_epoch = -999

    # training loops
    for epoch in range(max(0, resume_epoch + 1), args.epochs):
        data_loader.sampler.set_epoch(epoch)
        iterations = train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler,
                                     epoch, args.print_freq, iterations, bert_model, single_head)
        iou, overallIoU = evaluate(model, data_loader_test, bert_model, single_head)

        print('Average object IoU {}'.format(iou))
        print('Overall IoU {}'.format(overallIoU))

        dict_to_save = _checkpoint_dict(single_model, single_bert_model, single_head,
                                        optimizer, lr_scheduler, epoch, args)
        _save_checkpoint(dict_to_save,
                         os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch)))
        if utils.is_main_process():
            _prune_old_checkpoints(args.output_dir, args.epochs, args.max_ckpt)

        if best_oIoU < overallIoU:
            print('Better epoch: {}\n'.format(epoch))
            _save_checkpoint(dict_to_save,
                             os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id)))
            best_oIoU = overallIoU

    # summarize
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    dist.destroy_process_group()


if __name__ == "__main__":
    from args import get_parser
    parser = get_parser()
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    print('Image size: {}'.format(str(args.img_size)))
    # One training process per local GPU; each receives its local rank first.
    mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())