# (Residue from the Hugging Face file page this script was copied from:)
# fmsfm's picture
# Upload 13 files
# 1ff2d47
# raw
# history blame
# No virus
# 3.82 kB
import os
import numpy as np
import os.path as osp
import cv2
import argparse
import torch
from torch.utils.data import DataLoader
import torchvision
from dataset import BSDS_Dataset
from models import RCF
def single_scale_test(model, test_loader, test_list, save_dir):
    """Run single-scale inference over ``test_loader`` and save the outputs.

    For every image two files are written into ``save_dir``:

    * ``<name>.jpg`` -- all outputs returned by ``model`` stacked into one
      image grid (inverted with ``1 - x``) via ``torchvision.utils.save_image``;
    * ``<name>_ss.png`` -- the last output returned by ``model`` (presumably
      the fused edge map -- confirm against ``models.RCF``), inverted and
      quantized to 8 bits.

    Args:
        model: Network called as ``model(image)``. It must return a sequence
            of tensors, each of which fits into an ``(H, W)`` slot. Inputs are
            moved with ``.cuda()``, so the model must already be on CUDA.
        test_loader: Iterable of image tensors shaped ``(N, C, H, W)``. The
            per-output ``(H, W)`` assignment below only works for N == 1.
        test_list: File names aligned index-by-index with ``test_loader``.
            The extension is stripped to form the output names.
        save_dir: Output directory. It is created if missing.
    """
    model.eval()
    # exist_ok avoids the check-then-create race of an isdir() test.
    os.makedirs(save_dir, exist_ok=True)
    # Inference only: disabling autograd stops PyTorch from building (and
    # keeping alive) a computation graph for every image.
    with torch.no_grad():
        for idx, image in enumerate(test_loader):
            image = image.cuda()
            _, _, H, W = image.shape
            results = model(image)
            # Gather every network output into one (len(results), 1, H, W)
            # CPU tensor so they can be written as a single image grid.
            all_res = torch.zeros((len(results), 1, H, W))
            for i, res in enumerate(results):
                all_res[i, 0, :, :] = res
            filename = osp.splitext(test_list[idx])[0]
            torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))
            fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
            # Invert and scale to 8 bits. This assumes outputs lie in [0, 1]
            # (presumably sigmoid probabilities -- TODO confirm).
            fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
            cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
    print('Running single-scale test done')
def multi_scale_test(model, test_loader, test_list, save_dir, scales=(0.5, 1, 1.5)):
    """Run multi-scale inference and save one averaged map per image.

    Each image is bilinearly resized by every factor in ``scales`` and passed
    through ``model``. The last output returned at each scale (presumably the
    fused edge map -- confirm against ``models.RCF``) is resized back to the
    original ``(H, W)``. These maps are averaged, inverted, quantized to 8 bits
    and written as ``<name>_ms.png`` in ``save_dir``.

    Args:
        model: Network called on a ``(1, C, h, w)`` tensor moved with
            ``.cuda()``, so the model must already be on CUDA. It must return
            a sequence of tensors.
        test_loader: Iterable of image tensors shaped ``(N, C, H, W)``. Only
            the first sample of each batch is used.
        test_list: File names aligned index-by-index with ``test_loader``.
            The extension is stripped to form the output names.
        save_dir: Output directory. It is created if missing.
        scales: Resize factors to average over. The default is the
            previously hard-coded ``(0.5, 1, 1.5)``.

    Raises:
        ValueError: If ``scales`` is empty, because averaging would then
            divide by zero.
    """
    scales = tuple(scales)
    if not scales:
        raise ValueError('scales must contain at least one scale factor')
    model.eval()
    # exist_ok avoids the check-then-create race of an isdir() test.
    os.makedirs(save_dir, exist_ok=True)
    # Inference only: disabling autograd stops PyTorch from building (and
    # keeping alive) a computation graph for every forward pass.
    with torch.no_grad():
        for idx, image in enumerate(test_loader):
            # (C, H, W) tensor -> (H, W, C) array, the layout cv2.resize expects.
            in_ = image[0].numpy().transpose((1, 2, 0))
            _, _, H, W = image.shape
            ms_fuse = np.zeros((H, W), np.float32)
            for scale in scales:
                im_ = cv2.resize(in_, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
                im_ = im_.transpose((2, 0, 1))
                results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0))
                fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
                # Return every prediction to the input resolution before averaging.
                fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
                ms_fuse += fuse_res
            ms_fuse = ms_fuse / len(scales)
            # Optional "rescale trick" (disabled): min-max normalize before saving.
            # ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
            filename = osp.splitext(test_list[idx])[0]
            # Invert and scale to 8 bits. This assumes the averaged map lies in
            # [0, 1] (presumably sigmoid probabilities -- TODO confirm).
            ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
            cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
    print('Running multi-scale test done')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Testing')
    parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
    parser.add_argument('--checkpoint', default=None, type=str, help='path to latest checkpoint')
    parser.add_argument('--save-dir', help='output folder', default='results/RCF')
    parser.add_argument('--dataset', help='root folder of dataset', default='data/HED-BSDS')
    args = parser.parse_args()

    # GPU selection via environment variables only takes effect if it happens
    # before CUDA is first initialized (the first .cuda() call below).
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    os.makedirs(args.save_dir, exist_ok=True)

    test_dataset = BSDS_Dataset(root=args.dataset, split='test')
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, drop_last=False, shuffle=False)
    # Output file names: the basename of each entry in the dataset's file list.
    test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(test_list) != len(test_loader):
        raise RuntimeError('test_list has %d entries but test_loader yields %d batches'
                           % (len(test_list), len(test_loader)))

    model = RCF().cuda()
    # --checkpoint defaults to None, and osp.isfile(None) raises TypeError,
    # so test for None before probing the filesystem.
    if args.checkpoint is not None and osp.isfile(args.checkpoint):
        print("=> loading checkpoint from '{}'".format(args.checkpoint))
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        # map_location='cpu' lets a checkpoint saved on any GPU index load no matter
        # what CUDA_VISIBLE_DEVICES is. load_state_dict then copies the weights
        # into the model's (CUDA) parameters.
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        model.load_state_dict(checkpoint)
        print("=> checkpoint loaded")
    else:
        print("=> no checkpoint found at '{}'".format(args.checkpoint))

    print('Performing the testing...')
    single_scale_test(model, test_loader, test_list, args.save_dir)
    multi_scale_test(model, test_loader, test_list, args.save_dir)