import os
import sys
import time
from argparse import ArgumentParser

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim.lr_scheduler
from torch.nn.parallel import gather

# Make the working directory importable *before* any project-local import below.
sys.path.insert(0, '.')

import dataset.dataset as myDataLoader
import dataset.Transforms as myTransforms
from model.metric_tool import ConfuseMatrixMeter
from model.trainer import Trainer
from model.utils import BCEDiceLoss, init_seed, adjust_learning_rate
|
|
|
|
|
|
@torch.no_grad()
def val(args, val_loader, model):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Each batch is an ``(img, target)`` pair. Channels 0:3 of ``img`` are fed
    as the pre image and channels 3:6 as the post image.

    Args:
        args: parsed CLI namespace; only ``args.onGPU`` is read here.
        val_loader: iterable of ``(img, target)`` batches with a ``len()``.
        model: network called as ``model(pre_img, post_img)``.

    Returns:
        ``(average_loss, scores)``. ``scores`` is the result of
        ``ConfuseMatrixMeter.get_scores()``. The average loss is NaN when the
        loader yields no batches.
    """
    model.eval()

    salEvalVal = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    total_batches = len(val_loader)
    print(total_batches)
    for step, (img, target) in enumerate(val_loader):
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        start_time = time.time()

        if args.onGPU:
            pre_img = pre_img.cuda()
            target = target.cuda()
            post_img = post_img.cuda()

        # Plain tensors replace the deprecated torch.autograd.Variable wrapper.
        pre_img = pre_img.float()
        post_img = post_img.float()
        target = target.float()

        output = model(pre_img, post_img)
        loss = BCEDiceLoss(output, target)

        # Binarise the prediction at 0.5 for the confusion-matrix metrics.
        pred = (output > 0.5).long()

        time_taken = time.time() - start_time

        loss_value = loss.item()
        epoch_loss.append(loss_value)

        # The former multi-GPU `gather` branch was removed: its result was
        # assigned to `output` and never read, and the model is not wrapped in
        # DataParallel, so it had no effect on the scores.
        f1 = salEvalVal.update_cm(pr=pred.cpu().numpy(), gt=target.cpu().numpy())
        if step % 5 == 0:
            print('\r[%d/%d] F1: %.3f loss: %.3f time: %.3f' % (step, total_batches, f1, loss_value, time_taken),
                  end='')

    # Guard against an empty loader instead of raising ZeroDivisionError.
    average_epoch_loss_val = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float('nan')
    scores = salEvalVal.get_scores()

    return average_epoch_loss_val, scores
|
|
|
|
def train(args, train_loader, model, optimizer, epoch, max_batches, cur_iter=0, lr_factor=1.):
    """Run one training epoch.

    Each batch is an ``(img, target)`` pair. Channels 0:3 of ``img`` are fed
    as the pre image and channels 3:6 as the post image.

    Args:
        args: parsed CLI namespace. ``onGPU`` and ``max_epochs`` are read
            here, and the whole namespace is forwarded to
            ``adjust_learning_rate``.
        train_loader: iterable of ``(img, target)`` batches.
        model: network called as ``model(pre_img, post_img)``.
        optimizer: optimizer stepped once per batch. Its learning rate is
            updated every iteration by ``adjust_learning_rate``.
        epoch: current epoch index, forwarded to ``adjust_learning_rate``.
        max_batches: number of batches per epoch.
        cur_iter: global iteration index at the start of this epoch.
        lr_factor: multiplier forwarded to ``adjust_learning_rate``.

    Returns:
        ``(average_loss, scores, lr)``.
        - ``scores`` comes from ``ConfuseMatrixMeter.get_scores()``.
        - ``lr`` is the learning rate returned for the last iteration.
        - For an empty loader the loss is NaN and ``lr`` is the optimizer's
          current first-group learning rate.
    """
    model.train()

    salEvalVal = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    total_iters = max_batches * args.max_epochs
    # Fallback so `lr` is always bound, even if the loader yields nothing.
    lr = optimizer.param_groups[0]['lr']

    for step, (img, target) in enumerate(train_loader):
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        start_time = time.time()

        # The learning-rate schedule is driven by the global iteration index.
        lr = adjust_learning_rate(args, optimizer, epoch, step + cur_iter, max_batches, lr_factor=lr_factor)

        if args.onGPU:
            pre_img = pre_img.cuda()
            target = target.cuda()
            post_img = post_img.cuda()

        # Plain tensors replace the deprecated torch.autograd.Variable wrapper.
        pre_img = pre_img.float()
        post_img = post_img.float()
        target = target.float()

        output = model(pre_img, post_img)
        loss = BCEDiceLoss(output, target)

        # Binarise at 0.5 for the metrics; a comparison builds no autograd graph.
        pred = (output > 0.5).long()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Single host sync for the loss value, reused for logging below.
        loss_value = loss.item()
        epoch_loss.append(loss_value)
        time_taken = time.time() - start_time
        # Estimated remaining hours, extrapolated from this iteration's time.
        res_time = (total_iters - step - cur_iter) * time_taken / 3600

        # The former multi-GPU `gather` branch was removed: its result was
        # assigned to `output` and never read afterwards.
        f1 = salEvalVal.update_cm(pr=pred.cpu().numpy(), gt=target.cpu().numpy())

        if step % 5 == 0:
            print('\riteration: [%d/%d] f1: %.3f lr: %.7f loss: %.3f time:%.3f h' % (
                step + cur_iter, total_iters, f1, lr, loss_value,
                res_time),
                end='')

    # Guard against an empty loader instead of raising ZeroDivisionError.
    average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float('nan')
    scores = salEvalVal.get_scores()

    return average_epoch_loss_train, scores, lr
|
|
|
|
# Dataset short names accepted by --file_root, mapped to their data directories.
_DATASET_ROOTS = {
    'LEVIR': './levir_cd_256',
    'WHU': './whu_cd_256',
    'CLCD': './clcd_256',
    'SYSU': './sysu_256',
    'OSCD': 'oscd_256',
}


def _resolve_file_root(name):
    """Return the data directory for dataset short name ``name``.

    Raises:
        TypeError: if ``name`` is not a known dataset. TypeError is kept for
            backward compatibility with the original script.
    """
    try:
        return _DATASET_ROOTS[name]
    except KeyError:
        raise TypeError('%s has not defined' % name) from None


def _build_loaders(args):
    """Build the augmented, shuffled train loader and the ordered test loader."""
    # Normalisation statistics for the 6-channel (pre + post) input.
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]

    train_transform = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.RandomCropResize(int(7. / 224. * args.inWidth)),
        myTransforms.RandomFlip(),
        myTransforms.RandomExchange(),
        myTransforms.ToTensor()
    ])

    val_transform = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor()
    ])

    train_data = myDataLoader.Dataset(file_root=args.file_root, mode="train", transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, drop_last=False
    )

    test_data = myDataLoader.Dataset(file_root=args.file_root, mode="test", transform=val_transform)
    test_loader = torch.utils.data.DataLoader(
        test_data, shuffle=False,
        batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)

    return train_loader, test_loader


def _open_logger(path):
    """Open the run log for appending; a new file first gets the column header."""
    if os.path.isfile(path):
        return open(path, 'a')
    logger = open(path, 'w')
    logger.write(
        "\n%s\t%s\t%s\t%s\t%s\t%s\t%s" % ('Epoch', 'Kappa (val)', 'IoU (val)', 'F1 (val)', 'R (val)', 'P (val)', 'OA (val)'))
    logger.flush()
    return logger


def _resume_from_checkpoint(args, model, optimizer):
    """Restore model and optimizer state from ``<savedir>checkpoint.pth.tar``.

    ``args.resume`` is overwritten with that path, so any non-None ``--resume``
    value acts as a flag.

    Returns:
        The epoch to resume from, or 0 when no checkpoint file exists.
    """
    args.resume = args.savedir + 'checkpoint.pth.tar'
    if not os.path.isfile(args.resume):
        print("=> no checkpoint found at '{}'".format(args.resume))
        return 0

    print("=> loading checkpoint '{}'".format(args.resume))
    # Load to CPU so a GPU-saved checkpoint also resumes on a CPU-only run.
    # load_state_dict copies the tensors onto the model's/optimizer's devices.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    # The optimizer state (e.g. Adam moments) is saved every epoch; restore it
    # too so a resumed run continues the same optimisation.
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(args.resume, checkpoint['epoch']))
    return checkpoint['epoch']


def trainValidateSegmentation(args):
    """Train, validate and finally test a change-detection model.

    Mutates ``args``:
    - ``savedir`` gets the run suffix.
    - ``file_root`` becomes the data directory.
    - ``max_epochs`` is derived from ``max_steps``.
    - ``resume`` is set to the checkpoint path when resuming.

    Training runs with a per-iteration learning-rate schedule. From epoch 1
    onwards, each epoch is validated on the test split, logged, and
    checkpointed. The weights with the best validation F1 are kept in
    ``best_model.pth``. Those weights, or the current ones if none were
    saved, are then evaluated once more and logged as ``Test``.

    Raises:
        TypeError: unknown ``args.file_root``.
        ValueError: the training loader yields no batches.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    init_seed(args.seed)

    args.savedir = args.savedir + '_' + args.file_root + '_iter_' + str(args.max_steps) + '_lr_' + str(args.lr) + '/'
    args.file_root = _resolve_file_root(args.file_root)

    os.makedirs(args.savedir, exist_ok=True)

    model = Trainer(args.model_type).float()
    if args.onGPU:
        model = model.cuda()

    trainLoader, testLoader = _build_loaders(args)

    max_batches = len(trainLoader)
    print('For each epoch, we have {} batches'.format(max_batches))
    if max_batches == 0:
        raise ValueError('training loader is empty; cannot derive the number of epochs')

    if args.onGPU:
        # Set after init_seed so it is the effective setting for the run.
        cudnn.benchmark = True

    args.max_epochs = int(np.ceil(args.max_steps / max_batches))
    max_F1_val = 0

    # Created before resuming so its state can be restored from the checkpoint.
    optimizer = torch.optim.Adam(model.parameters(), args.lr, (0.9, 0.99), eps=1e-08, weight_decay=1e-4)

    start_epoch = 0
    if args.resume is not None:
        start_epoch = _resume_from_checkpoint(args, model, optimizer)
    cur_iter = start_epoch * max_batches
    # NOTE(review): max_F1_val is not persisted in the checkpoint. After a
    # resume, the first validated epoch therefore overwrites best_model.pth,
    # even if an earlier run scored higher.

    # Defined before the loop so it exists even if no epoch is validated.
    model_file_name = args.savedir + 'best_model.pth'
    checkpoint_file_name = args.savedir + 'checkpoint.pth.tar'

    with _open_logger(args.savedir + args.logFile) as logger:
        for epoch in range(start_epoch, args.max_epochs):
            lossTr, score_tr, lr = \
                train(args, trainLoader, model, optimizer, epoch, max_batches, cur_iter)
            cur_iter += max_batches

            torch.cuda.empty_cache()

            # Epoch 0 is neither validated nor checkpointed.
            if epoch == 0:
                continue

            lossVal, score_val = val(args, testLoader, model)
            torch.cuda.empty_cache()
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" % (epoch, score_val['Kappa'], score_val['IoU'],
                                                                                  score_val['F1'], score_val['recall'],
                                                                                  score_val['precision'], score_val['OA']))
            logger.flush()

            torch.save({
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lossTr': lossTr,
                'lossVal': lossVal,
                'F_Tr': score_tr['F1'],
                'F_val': score_val['F1'],
                'lr': lr
            }, checkpoint_file_name)

            # Keep the weights with the best validation F1 seen so far (ties update).
            if max_F1_val <= score_val['F1']:
                max_F1_val = score_val['F1']
                torch.save(model.state_dict(), model_file_name)

            print("Epoch " + str(epoch) + ': Details')
            print("\nEpoch No. %d:\tTrain Loss = %.4f\tVal Loss = %.4f\t F1(tr) = %.4f\t F1(val) = %.4f"
                  % (epoch, lossTr, lossVal, score_tr['F1'], score_val['F1']))
            torch.cuda.empty_cache()

        # With max_epochs <= 1 (or a resume past the end), no best model may have
        # been written; fall back to the current weights instead of crashing.
        if os.path.isfile(model_file_name):
            state_dict = torch.load(model_file_name, map_location='cpu')
            model.load_state_dict(state_dict)
        else:
            print("=> no best model found at '{}'; testing current weights".format(model_file_name))

        loss_test, score_test = val(args, testLoader, model)
        print("\nTest :\t Kappa (te) = %.4f\t IoU (te) = %.4f\t F1 (te) = %.4f\t R (te) = %.4f\t P (te) = %.4f"
              % (score_test['Kappa'], score_test['IoU'], score_test['F1'], score_test['recall'], score_test['precision']))
        # Seven values need seven placeholders; the original format had only six
        # and raised TypeError here.
        logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" % ('Test', score_test['Kappa'], score_test['IoU'],
                                                                              score_test['F1'], score_test['recall'],
                                                                              score_test['precision'], score_test['OA']))
        logger.flush()
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the training configuration and run.
    parser = ArgumentParser()
    parser.add_argument('--file_root', default="LEVIR", help='Data directory | LEVIR | WHU | CLCD | SYSU | OSCD ')
    parser.add_argument('--inWidth', type=int, default=256, help='Width of RGB image')
    parser.add_argument('--inHeight', type=int, default=256, help='Height of RGB image')
    parser.add_argument('--max_steps', type=int, default=80000, help='Max. number of iterations')
    parser.add_argument('--num_workers', type=int, default=4, help='No. of parallel threads')
    parser.add_argument('--model_type', type=str, default='small', help='select vit model type | tiny | small')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--step_loss', type=int, default=100, help='Decrease learning rate after how many epochs')
    parser.add_argument('--lr', type=float, default=2e-4, help='Initial learning rate')
    parser.add_argument('--lr_mode', default='poly', help='Learning rate policy, step or poly')
    # type=int: without it a seed given on the command line arrives as a str,
    # while the default is an int.
    parser.add_argument('--seed', type=int, default=16, help='initialization seed number')
    parser.add_argument('--savedir', default='./results', help='Directory to save the results')
    # NOTE(review): trainValidateSegmentation treats any non-None --resume as a
    # flag and reloads <savedir>checkpoint.pth.tar; the given path is not used.
    parser.add_argument('--resume', default=None, help='Use this checkpoint to continue training | '
                                                       './results_ep100/checkpoint.pth.tar')
    parser.add_argument('--logFile', default='trainValLog.txt',
                        help='File that stores the training and validation logs')
    # Only the literal string "true" (any case) enables the GPU.
    parser.add_argument('--onGPU', default=True, type=lambda x: (str(x).lower() == 'true'),
                        help='Run on CPU or GPU. If TRUE, then GPU.')
    parser.add_argument('--gpu_id', default=0, type=int, help='GPU id number')

    args = parser.parse_args()
    print('Called with args:')
    print(args)

    trainValidateSegmentation(args)
|
|