import os import numpy as np import os.path as osp import cv2 import argparse import torch from torch.utils.data import DataLoader import torchvision from dataset import BSDS_Dataset from models import RCF def single_scale_test(model, test_loader, test_list, save_dir): model.eval() if not osp.isdir(save_dir): os.makedirs(save_dir) for idx, image in enumerate(test_loader): image = image.cuda() _, _, H, W = image.shape results = model(image) all_res = torch.zeros((len(results), 1, H, W)) for i in range(len(results)): all_res[i, 0, :, :] = results[i] filename = osp.splitext(test_list[idx])[0] torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename)) fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy() fuse_res = ((1 - fuse_res) * 255).astype(np.uint8) cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res) #print('\rRunning single-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='') print('Running single-scale test done') def multi_scale_test(model, test_loader, test_list, save_dir): model.eval() if not osp.isdir(save_dir): os.makedirs(save_dir) scale = [0.5, 1, 1.5] for idx, image in enumerate(test_loader): in_ = image[0].numpy().transpose((1, 2, 0)) _, _, H, W = image.shape ms_fuse = np.zeros((H, W), np.float32) for k in range(len(scale)): im_ = cv2.resize(in_, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR) im_ = im_.transpose((2, 0, 1)) results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0)) fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy() fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR) ms_fuse += fuse_res ms_fuse = ms_fuse / len(scale) ### rescale trick # ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min()) filename = osp.splitext(test_list[idx])[0] ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8) cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse) #print('\rRunning multi-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='') print('Running multi-scale test done') if __name__ == '__main__': parser = argparse.ArgumentParser(description='PyTorch Testing') parser.add_argument('--gpu', default='0', type=str, help='GPU ID') parser.add_argument('--checkpoint', default=None, type=str, help='path to latest checkpoint') parser.add_argument('--save-dir', help='output folder', default='results/RCF') parser.add_argument('--dataset', help='root folder of dataset', default='data/HED-BSDS') args = parser.parse_args() os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu if not osp.isdir(args.save_dir): os.makedirs(args.save_dir) test_dataset = BSDS_Dataset(root=args.dataset, split='test') test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, drop_last=False, shuffle=False) test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list] assert len(test_list) == len(test_loader) model = RCF().cuda() if osp.isfile(args.checkpoint): print("=> loading checkpoint from '{}'".format(args.checkpoint)) checkpoint = torch.load(args.checkpoint) model.load_state_dict(checkpoint) print("=> checkpoint loaded") else: print("=> no checkpoint found at '{}'".format(args.checkpoint)) print('Performing the testing...') single_scale_test(model, test_loader, test_list, args.save_dir) multi_scale_test(model, test_loader, test_list, args.save_dir)