import argparse import os, sys import math BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(BASE_DIR) import pprint import time import torch import torch.nn.parallel from torch.nn.parallel import DistributedDataParallel as DDP from torch.cuda import amp import torch.distributed as dist import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms import numpy as np from lib.utils import DataLoaderX, torch_distributed_zero_first from tensorboardX import SummaryWriter import lib.dataset as dataset from lib.config import cfg from lib.config import update_config from lib.core.loss import get_loss from lib.core.function import train from lib.core.function import validate from lib.core.general import fitness from lib.models import get_net from lib.utils import is_parallel from lib.utils.utils import get_optimizer from lib.utils.utils import save_checkpoint from lib.utils.utils import create_logger, select_device from lib.utils import run_anchor def parse_args(): parser = argparse.ArgumentParser(description='Train Multitask network') # general # parser.add_argument('--cfg', # help='experiment configure file name', # required=True, # type=str) # philly parser.add_argument('--modelDir', help='model directory', type=str, default='') parser.add_argument('--logDir', help='log directory', type=str, default='runs/') parser.add_argument('--dataDir', help='data directory', type=str, default='') parser.add_argument('--prevModelDir', help='prev Model directory', type=str, default='') parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS') args = parser.parse_args() return args def main(): # set all the configurations args = parse_args() update_config(cfg, args) # Set DDP variables world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1 rank = global_rank #print(rank) # TODO: handle distributed training logger # set the logger, tb_log_dir means tensorboard logdir logger, final_output_dir, tb_log_dir = create_logger( cfg, cfg.LOG_DIR, 'train', rank=rank) if rank in [-1, 0]: logger.info(pprint.pformat(args)) logger.info(cfg) writer_dict = { 'writer': SummaryWriter(log_dir=tb_log_dir), 'train_global_steps': 0, 'valid_global_steps': 0, } else: writer_dict = None # cudnn related setting cudnn.benchmark = cfg.CUDNN.BENCHMARK torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED # bulid up model # start_time = time.time() print("begin to bulid up model...") # DP mode device = select_device(logger, batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU* len(cfg.GPUS)) if not cfg.DEBUG \ else select_device(logger, 'cpu') if args.local_rank != -1: assert torch.cuda.device_count() > args.local_rank torch.cuda.set_device(args.local_rank) device = torch.device('cuda', args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') # distributed backend print("load model to device") model = get_net(cfg).to(device) # print("load finished") #model = model.to(device) # print("finish build model") # define loss function (criterion) and optimizer criterion = get_loss(cfg, device=device) optimizer = get_optimizer(cfg, model) # load checkpoint model best_perf = 0.0 best_model = False last_epoch = -1 Encoder_para_idx = [str(i) for i in range(0, 17)] Det_Head_para_idx = [str(i) for i in range(17, 25)] Da_Seg_Head_para_idx = [str(i) for i in range(25, 34)] Ll_Seg_Head_para_idx = [str(i) for i in range(34,43)] lf = lambda x: ((1 + math.cos(x * math.pi / cfg.TRAIN.END_EPOCH)) / 2) * \ (1 - cfg.TRAIN.LRF) + cfg.TRAIN.LRF # cosine lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) begin_epoch = cfg.TRAIN.BEGIN_EPOCH if rank in [-1, 0]: checkpoint_file = os.path.join( os.path.join(cfg.LOG_DIR, cfg.DATASET.DATASET), 'checkpoint.pth' ) if os.path.exists(cfg.MODEL.PRETRAINED): logger.info("=> loading model '{}'".format(cfg.MODEL.PRETRAINED)) checkpoint = torch.load(cfg.MODEL.PRETRAINED) begin_epoch = checkpoint['epoch'] # best_perf = checkpoint['perf'] last_epoch = checkpoint['epoch'] model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) logger.info("=> loaded checkpoint '{}' (epoch {})".format( cfg.MODEL.PRETRAINED, checkpoint['epoch'])) #cfg.NEED_AUTOANCHOR = False #disable autoanchor if os.path.exists(cfg.MODEL.PRETRAINED_DET): logger.info("=> loading model weight in det branch from '{}'".format(cfg.MODEL.PRETRAINED)) det_idx_range = [str(i) for i in range(0,25)] model_dict = model.state_dict() checkpoint_file = cfg.MODEL.PRETRAINED_DET checkpoint = torch.load(checkpoint_file) begin_epoch = checkpoint['epoch'] last_epoch = checkpoint['epoch'] checkpoint_dict = {k: v for k, v in checkpoint['state_dict'].items() if k.split(".")[1] in det_idx_range} model_dict.update(checkpoint_dict) model.load_state_dict(model_dict) logger.info("=> loaded det branch checkpoint '{}' ".format(checkpoint_file)) if cfg.AUTO_RESUME and os.path.exists(checkpoint_file): logger.info("=> loading checkpoint '{}'".format(checkpoint_file)) checkpoint = torch.load(checkpoint_file) begin_epoch = checkpoint['epoch'] # best_perf = checkpoint['perf'] last_epoch = checkpoint['epoch'] model.load_state_dict(checkpoint['state_dict']) # optimizer = get_optimizer(cfg, model) optimizer.load_state_dict(checkpoint['optimizer']) logger.info("=> loaded checkpoint '{}' (epoch {})".format( checkpoint_file, checkpoint['epoch'])) #cfg.NEED_AUTOANCHOR = False #disable autoanchor # model = model.to(device) if cfg.TRAIN.SEG_ONLY: #Only train two segmentation branchs logger.info('freeze encoder and Det head...') for k, v in model.named_parameters(): v.requires_grad = True # train all layers if k.split(".")[1] in Encoder_para_idx + Det_Head_para_idx: print('freezing %s' % k) v.requires_grad = False if cfg.TRAIN.DET_ONLY: #Only train detection branch logger.info('freeze encoder and two Seg heads...') # print(model.named_parameters) for k, v in model.named_parameters(): v.requires_grad = True # train all layers if k.split(".")[1] in Encoder_para_idx + Da_Seg_Head_para_idx + Ll_Seg_Head_para_idx: print('freezing %s' % k) v.requires_grad = False if cfg.TRAIN.ENC_SEG_ONLY: # Only train encoder and two segmentation branchs logger.info('freeze Det head...') for k, v in model.named_parameters(): v.requires_grad = True # train all layers if k.split(".")[1] in Det_Head_para_idx: print('freezing %s' % k) v.requires_grad = False if cfg.TRAIN.ENC_DET_ONLY or cfg.TRAIN.DET_ONLY: # Only train encoder and detection branchs logger.info('freeze two Seg heads...') for k, v in model.named_parameters(): v.requires_grad = True # train all layers if k.split(".")[1] in Da_Seg_Head_para_idx + Ll_Seg_Head_para_idx: print('freezing %s' % k) v.requires_grad = False if cfg.TRAIN.LANE_ONLY: logger.info('freeze encoder and Det head and Da_Seg heads...') # print(model.named_parameters) for k, v in model.named_parameters(): v.requires_grad = True # train all layers if k.split(".")[1] in Encoder_para_idx + Da_Seg_Head_para_idx + Det_Head_para_idx: print('freezing %s' % k) v.requires_grad = False if cfg.TRAIN.DRIVABLE_ONLY: logger.info('freeze encoder and Det head and Ll_Seg heads...') # print(model.named_parameters) for k, v in model.named_parameters(): v.requires_grad = True # train all layers if k.split(".")[1] in Encoder_para_idx + Ll_Seg_Head_para_idx + Det_Head_para_idx: print('freezing %s' % k) v.requires_grad = False if rank == -1 and torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model, device_ids=cfg.GPUS) # model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda() # # DDP mode if rank != -1: model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank,find_unused_parameters=True) # assign model params model.gr = 1.0 model.nc = 1 # print('bulid model finished') print("begin to load data") # Data loading normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) train_dataset = eval('dataset.' + cfg.DATASET.DATASET)( cfg=cfg, is_train=True, inputsize=cfg.MODEL.IMAGE_SIZE, transform=transforms.Compose([ transforms.ToTensor(), normalize, ]) ) train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if rank != -1 else None train_loader = DataLoaderX( train_dataset, batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS), shuffle=(cfg.TRAIN.SHUFFLE & rank == -1), num_workers=cfg.WORKERS, sampler=train_sampler, pin_memory=cfg.PIN_MEMORY, collate_fn=dataset.AutoDriveDataset.collate_fn ) num_batch = len(train_loader) if rank in [-1, 0]: valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)( cfg=cfg, is_train=False, inputsize=cfg.MODEL.IMAGE_SIZE, transform=transforms.Compose([ transforms.ToTensor(), normalize, ]) ) valid_loader = DataLoaderX( valid_dataset, batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS), shuffle=False, num_workers=cfg.WORKERS, pin_memory=cfg.PIN_MEMORY, collate_fn=dataset.AutoDriveDataset.collate_fn ) print('load data finished') if rank in [-1, 0]: if cfg.NEED_AUTOANCHOR: logger.info("begin check anchors") run_anchor(logger,train_dataset, model=model, thr=cfg.TRAIN.ANCHOR_THRESHOLD, imgsz=min(cfg.MODEL.IMAGE_SIZE)) else: logger.info("anchors loaded successfully") det = model.module.model[model.module.detector_index] if is_parallel(model) \ else model.model[model.detector_index] logger.info(str(det.anchors)) # training num_warmup = max(round(cfg.TRAIN.WARMUP_EPOCHS * num_batch), 1000) scaler = amp.GradScaler(enabled=device.type != 'cpu') print('=> start training...') for epoch in range(begin_epoch+1, cfg.TRAIN.END_EPOCH+1): if rank != -1: train_loader.sampler.set_epoch(epoch) # train for one epoch train(cfg, train_loader, model, criterion, optimizer, scaler, epoch, num_batch, num_warmup, writer_dict, logger, device, rank) lr_scheduler.step() # evaluate on validation set if (epoch % cfg.TRAIN.VAL_FREQ == 0 or epoch == cfg.TRAIN.END_EPOCH) and rank in [-1, 0]: # print('validate') da_segment_results,ll_segment_results,detect_results, total_loss,maps, times = validate( epoch,cfg, valid_loader, valid_dataset, model, criterion, final_output_dir, tb_log_dir, writer_dict, logger, device, rank ) fi = fitness(np.array(detect_results).reshape(1, -1)) #目标检测评价指标 msg = 'Epoch: [{0}] Loss({loss:.3f})\n' \ 'Driving area Segment: Acc({da_seg_acc:.3f}) IOU ({da_seg_iou:.3f}) mIOU({da_seg_miou:.3f})\n' \ 'Lane line Segment: Acc({ll_seg_acc:.3f}) IOU ({ll_seg_iou:.3f}) mIOU({ll_seg_miou:.3f})\n' \ 'Detect: P({p:.3f}) R({r:.3f}) mAP@0.5({map50:.3f}) mAP@0.5:0.95({map:.3f})\n'\ 'Time: inference({t_inf:.4f}s/frame) nms({t_nms:.4f}s/frame)'.format( epoch, loss=total_loss, da_seg_acc=da_segment_results[0],da_seg_iou=da_segment_results[1],da_seg_miou=da_segment_results[2], ll_seg_acc=ll_segment_results[0],ll_seg_iou=ll_segment_results[1],ll_seg_miou=ll_segment_results[2], p=detect_results[0],r=detect_results[1],map50=detect_results[2],map=detect_results[3], t_inf=times[0], t_nms=times[1]) logger.info(msg) # if perf_indicator >= best_perf: # best_perf = perf_indicator # best_model = True # else: # best_model = False # save checkpoint model and best model if rank in [-1, 0]: savepath = os.path.join(final_output_dir, f'epoch-{epoch}.pth') logger.info('=> saving checkpoint to {}'.format(savepath)) save_checkpoint( epoch=epoch, name=cfg.MODEL.NAME, model=model, # 'best_state_dict': model.module.state_dict(), # 'perf': perf_indicator, optimizer=optimizer, output_dir=final_output_dir, filename=f'epoch-{epoch}.pth' ) save_checkpoint( epoch=epoch, name=cfg.MODEL.NAME, model=model, # 'best_state_dict': model.module.state_dict(), # 'perf': perf_indicator, optimizer=optimizer, output_dir=os.path.join(cfg.LOG_DIR, cfg.DATASET.DATASET), filename='checkpoint.pth' ) # save final model if rank in [-1, 0]: final_model_state_file = os.path.join( final_output_dir, 'final_state.pth' ) logger.info('=> saving final model state to {}'.format( final_model_state_file) ) model_state = model.module.state_dict() if is_parallel(model) else model.state_dict() torch.save(model_state, final_model_state_file) writer_dict['writer'].close() else: dist.destroy_process_group() if __name__ == '__main__': main()