Spaces:
Runtime error
Runtime error
File size: 3,824 Bytes
1ff2d47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import os
import numpy as np
import os.path as osp
import cv2
import argparse
import torch
from torch.utils.data import DataLoader
import torchvision
from dataset import BSDS_Dataset
from models import RCF
def single_scale_test(model, test_loader, test_list, save_dir):
model.eval()
if not osp.isdir(save_dir):
os.makedirs(save_dir)
for idx, image in enumerate(test_loader):
image = image.cuda()
_, _, H, W = image.shape
results = model(image)
all_res = torch.zeros((len(results), 1, H, W))
for i in range(len(results)):
all_res[i, 0, :, :] = results[i]
filename = osp.splitext(test_list[idx])[0]
torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))
fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
#print('\rRunning single-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
print('Running single-scale test done')
def multi_scale_test(model, test_loader, test_list, save_dir):
model.eval()
if not osp.isdir(save_dir):
os.makedirs(save_dir)
scale = [0.5, 1, 1.5]
for idx, image in enumerate(test_loader):
in_ = image[0].numpy().transpose((1, 2, 0))
_, _, H, W = image.shape
ms_fuse = np.zeros((H, W), np.float32)
for k in range(len(scale)):
im_ = cv2.resize(in_, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR)
im_ = im_.transpose((2, 0, 1))
results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0))
fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
ms_fuse += fuse_res
ms_fuse = ms_fuse / len(scale)
### rescale trick
# ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
filename = osp.splitext(test_list[idx])[0]
ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
#print('\rRunning multi-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
print('Running multi-scale test done')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch Testing')
parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
parser.add_argument('--checkpoint', default=None, type=str, help='path to latest checkpoint')
parser.add_argument('--save-dir', help='output folder', default='results/RCF')
parser.add_argument('--dataset', help='root folder of dataset', default='data/HED-BSDS')
args = parser.parse_args()
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if not osp.isdir(args.save_dir):
os.makedirs(args.save_dir)
test_dataset = BSDS_Dataset(root=args.dataset, split='test')
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, drop_last=False, shuffle=False)
test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
assert len(test_list) == len(test_loader)
model = RCF().cuda()
if osp.isfile(args.checkpoint):
print("=> loading checkpoint from '{}'".format(args.checkpoint))
checkpoint = torch.load(args.checkpoint)
model.load_state_dict(checkpoint)
print("=> checkpoint loaded")
else:
print("=> no checkpoint found at '{}'".format(args.checkpoint))
print('Performing the testing...')
single_scale_test(model, test_loader, test_list, args.save_dir)
multi_scale_test(model, test_loader, test_list, args.save_dir)
|