"""Train and evaluate RCF (Richer Convolutional Features) edge detection on BSDS."""
import os
import os.path as osp
import argparse
import time

import cv2
import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader

from dataset import BSDS_Dataset
from models import RCF
from utils import Logger, Averagvalue, Cross_entropy_loss


def train(args, model, train_loader, optimizer, epoch, logger):
    batch_time = Averagvalue()
    losses = Averagvalue()
    model.train()
    end = time.time()
    counter = 0
    for i, (image, label) in enumerate(train_loader):
        image, label = image.cuda(), label.cuda()
        outputs = model(image)
        # Deep supervision: sum the loss over every side output plus the fused output.
        loss = torch.zeros(1).cuda()
        for o in outputs:
            loss = loss + Cross_entropy_loss(o, label)
        counter += 1
        loss = loss / args.iter_size
        loss.backward()
        # Gradient accumulation: step the optimizer once every `iter_size` batches,
        # emulating a larger effective batch size than fits per forward pass.
        if counter == args.iter_size:
            optimizer.step()
            optimizer.zero_grad()
            counter = 0
        # Record the (accumulation-scaled) loss and the elapsed batch time.
        losses.update(loss.item(), image.size(0))
        batch_time.update(time.time() - end)
        if i % args.print_freq == 0:
            logger.info('Epoch: [{0}/{1}][{2}/{3}] '.format(epoch + 1, args.max_epoch, i, len(train_loader)) +
                        'Time {batch_time.val:.3f} (avg: {batch_time.avg:.3f}) '.format(batch_time=batch_time) +
                        'Loss {loss.val:f} (avg: {loss.avg:f}) '.format(loss=losses))
        end = time.time()
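
# The Cross_entropy_loss imported from utils is not shown in this file. For
# orientation, the sketch below illustrates the class-balanced binary
# cross-entropy that HED/RCF-style edge training typically uses: edge and
# non-edge pixels are reweighted inversely to their frequency, since edges are
# a tiny fraction of all pixels. This is an illustrative assumption only and
# is NOT called by this script; consult utils.py for the actual implementation.
import torch.nn.functional as F

def _balanced_bce_sketch(pred, label):
    # pred: per-pixel probabilities in [0, 1]; label: binary edge map of the same shape.
    pos = (label == 1).float().sum()
    neg = (label == 0).float().sum()
    weight = torch.zeros_like(label, dtype=torch.float32)
    weight[label == 1] = neg / (pos + neg)  # rare positives get a large weight
    weight[label == 0] = pos / (pos + neg)  # abundant negatives get a small weight
    return F.binary_cross_entropy(pred, label.float(), weight=weight, reduction='sum')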

def single_scale_test(model, test_loader, test_list, save_dir, logger):
    model.eval()
    if not osp.isdir(save_dir):
        os.makedirs(save_dir)
    for idx, image in enumerate(test_loader):
        image = image.cuda()
        _, _, H, W = image.shape
        with torch.no_grad():
            results = model(image)
        # Save all side outputs plus the fused output as one tiled image.
        all_res = torch.zeros((len(results), 1, H, W))
        for i in range(len(results)):
            all_res[i, 0, :, :] = results[i]
        filename = osp.splitext(test_list[idx])[0]
        torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))
        # Save the fused output alone, inverted so edges are dark on white.
        fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
        fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
        cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
        #print('\rRunning single-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
    logger.info('Running single-scale test done')


def multi_scale_test(model, test_loader, test_list, save_dir, logger):
    model.eval()
    if not osp.isdir(save_dir):
        os.makedirs(save_dir)
    scale = [0.5, 1, 1.5]
    for idx, image in enumerate(test_loader):
        in_ = image[0].numpy().transpose((1, 2, 0))
        _, _, H, W = image.shape
        ms_fuse = np.zeros((H, W), np.float32)
        # Run the network at each scale and average the fused outputs after
        # resizing them back to the original resolution.
        for k in range(len(scale)):
            im_ = cv2.resize(in_, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR)
            im_ = im_.transpose((2, 0, 1))
            with torch.no_grad():
                results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0))
            fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
            fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
            ms_fuse += fuse_res
        ms_fuse = ms_fuse / len(scale)
        ### rescale trick
        # ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
        filename = osp.splitext(test_list[idx])[0]
        ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
        cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
        #print('\rRunning multi-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
    logger.info('Running multi-scale test done')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('--batch-size', default=1, type=int, help='batch size')
    parser.add_argument('--lr', default=1e-6, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight-decay', default=2e-4, type=float, help='weight decay')
    parser.add_argument('--stepsize', default=3, type=int, help='learning rate step size')
    parser.add_argument('--gamma', default=0.1, type=float, help='learning rate decay rate')
    parser.add_argument('--max-epoch', default=10, type=int, help='the number of training epochs')
    parser.add_argument('--iter-size', default=10, type=int, help='iter size')
    parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number')
    parser.add_argument('--print-freq', default=200, type=int, help='print frequency')
    parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
    parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')
    parser.add_argument('--save-dir', help='output folder', default='results/RCF')
    parser.add_argument('--dataset', help='root folder of dataset', default='data')
    args = parser.parse_args()

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if not osp.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    logger = Logger(osp.join(args.save_dir, 'log.txt'))
    logger.info('Called with args:')
    for key, value in vars(args).items():
        logger.info('{0:15} | {1}'.format(key, value))

    train_dataset = BSDS_Dataset(root=args.dataset, split='train')
    test_dataset = BSDS_Dataset(root=osp.join(args.dataset, 'HED-BSDS'), split='test')
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, drop_last=True, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4, drop_last=False, shuffle=False)
    test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
    # Holds only when the test batch size is 1 (the default), i.e. one image per step.
    assert len(test_list) == len(test_loader)

    model = RCF(pretrained='vgg16convs.mat').cuda()

    # Group the parameters by role so each group can receive its own learning
    # rate and weight decay in the optimizer below.
    parameters = {'conv1-4.weight': [], 'conv1-4.bias': [],
                  'conv5.weight': [], 'conv5.bias': [],
                  'conv_down_1-5.weight': [], 'conv_down_1-5.bias': [],
                  'score_dsn_1-5.weight': [], 'score_dsn_1-5.bias': [],
                  'score_fuse.weight': [], 'score_fuse.bias': []}
    for pname, p in model.named_parameters():
        if pname in ['conv1_1.weight', 'conv1_2.weight',
                     'conv2_1.weight', 'conv2_2.weight',
                     'conv3_1.weight', 'conv3_2.weight', 'conv3_3.weight',
                     'conv4_1.weight', 'conv4_2.weight', 'conv4_3.weight']:
            parameters['conv1-4.weight'].append(p)
        elif pname in ['conv1_1.bias', 'conv1_2.bias',
                       'conv2_1.bias', 'conv2_2.bias',
                       'conv3_1.bias', 'conv3_2.bias', 'conv3_3.bias',
                       'conv4_1.bias', 'conv4_2.bias', 'conv4_3.bias']:
            parameters['conv1-4.bias'].append(p)
        elif pname in ['conv5_1.weight', 'conv5_2.weight', 'conv5_3.weight']:
            parameters['conv5.weight'].append(p)
        elif pname in ['conv5_1.bias', 'conv5_2.bias', 'conv5_3.bias']:
            parameters['conv5.bias'].append(p)
        elif pname in ['conv1_1_down.weight', 'conv1_2_down.weight',
                       'conv2_1_down.weight', 'conv2_2_down.weight',
                       'conv3_1_down.weight', 'conv3_2_down.weight', 'conv3_3_down.weight',
                       'conv4_1_down.weight', 'conv4_2_down.weight', 'conv4_3_down.weight',
                       'conv5_1_down.weight', 'conv5_2_down.weight', 'conv5_3_down.weight']:
            parameters['conv_down_1-5.weight'].append(p)
        elif pname in ['conv1_1_down.bias', 'conv1_2_down.bias',
                       'conv2_1_down.bias', 'conv2_2_down.bias',
                       'conv3_1_down.bias', 'conv3_2_down.bias', 'conv3_3_down.bias',
                       'conv4_1_down.bias', 'conv4_2_down.bias', 'conv4_3_down.bias',
                       'conv5_1_down.bias', 'conv5_2_down.bias', 'conv5_3_down.bias']:
            parameters['conv_down_1-5.bias'].append(p)
        elif pname in ['score_dsn1.weight', 'score_dsn2.weight', 'score_dsn3.weight',
                       'score_dsn4.weight', 'score_dsn5.weight']:
            parameters['score_dsn_1-5.weight'].append(p)
        elif pname in ['score_dsn1.bias', 'score_dsn2.bias', 'score_dsn3.bias',
                       'score_dsn4.bias', 'score_dsn5.bias']:
            parameters['score_dsn_1-5.bias'].append(p)
        elif pname in ['score_fuse.weight']:
            parameters['score_fuse.weight'].append(p)
        elif pname in ['score_fuse.bias']:
            parameters['score_fuse.bias'].append(p)
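
    # Added sanity check (not part of the original recipe): a parameter whose
    # name matches none of the branches above lands in no optimizer group and
    # would silently receive no updates, so log it here instead.
    num_grouped = sum(len(v) for v in parameters.values())
    num_total = len(list(model.named_parameters()))
    if num_grouped != num_total:
        logger.info('Warning: {} of {} parameters matched no group and will not be optimized'
                    .format(num_total - num_grouped, num_total))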

    # HED/RCF-style per-group learning rates: each group scales the base LR by
    # a fixed multiplier; bias terms get double the LR and no weight decay.
    optimizer = torch.optim.SGD([
        {'params': parameters['conv1-4.weight'],       'lr': args.lr * 1,     'weight_decay': args.weight_decay},
        {'params': parameters['conv1-4.bias'],         'lr': args.lr * 2,     'weight_decay': 0.},
        {'params': parameters['conv5.weight'],         'lr': args.lr * 100,   'weight_decay': args.weight_decay},
        {'params': parameters['conv5.bias'],           'lr': args.lr * 200,   'weight_decay': 0.},
        {'params': parameters['conv_down_1-5.weight'], 'lr': args.lr * 0.1,   'weight_decay': args.weight_decay},
        {'params': parameters['conv_down_1-5.bias'],   'lr': args.lr * 0.2,   'weight_decay': 0.},
        {'params': parameters['score_dsn_1-5.weight'], 'lr': args.lr * 0.01,  'weight_decay': args.weight_decay},
        {'params': parameters['score_dsn_1-5.bias'],   'lr': args.lr * 0.02,  'weight_decay': 0.},
        {'params': parameters['score_fuse.weight'],    'lr': args.lr * 0.001, 'weight_decay': args.weight_decay},
        {'params': parameters['score_fuse.bias'],      'lr': args.lr * 0.002, 'weight_decay': 0.},
    ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)

    if args.resume is not None:
        if osp.isfile(args.resume):
            logger.info("=> loading checkpoint from '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            logger.info("=> checkpoint loaded")
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    for epoch in range(args.start_epoch, args.max_epoch):
        train(args, model, train_loader, optimizer, epoch, logger)
        # Evaluate after every epoch, writing results to a per-epoch folder.
        save_dir = osp.join(args.save_dir, 'epoch%d-test' % (epoch + 1))
        single_scale_test(model, test_loader, test_list, save_dir, logger)
        multi_scale_test(model, test_loader, test_list, save_dir, logger)
        # Save a full checkpoint so training can be resumed via --resume.
        save_file = osp.join(args.save_dir, 'checkpoint_epoch{}.pth'.format(epoch + 1))
        torch.save({
            'epoch': epoch,
            'args': args,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
        }, save_file)
        lr_scheduler.step()  # decay the learning rate every `stepsize` epochs

    logger.close()
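
# Example invocations, based on the defaults above (the file name `train.py`
# is an assumption; adjust it to however this script is saved):
#
#   # train from scratch on GPU 0 with the default hyper-parameters
#   python train.py --dataset data --save-dir results/RCF --gpu 0
#
#   # resume from a previously saved checkpoint
#   python train.py --resume results/RCF/checkpoint_epoch5.pth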