# Scraped from a Hugging Face Spaces page; the Space's reported status was "Runtime error".
import os | |
import numpy as np | |
import os.path as osp | |
import cv2 | |
import argparse | |
import torch | |
from torch.utils.data import DataLoader | |
import torchvision | |
from dataset import BSDS_Dataset | |
from models import RCF | |
def single_scale_test(model, test_loader, test_list, save_dir, device=None):
    """Run the model once per image at its native resolution and save the outputs.

    For every batch yielded by ``test_loader`` two files are written to
    ``save_dir``:

    * ``<stem>.jpg``    -- a grid (via ``torchvision.utils.save_image``) of every
      entry in the model's output sequence, inverted with ``1 - x``;
    * ``<stem>_ss.png`` -- the last output entry, inverted and scaled to uint8.

    Args:
        model: network whose forward pass returns a sequence of maps; the last
            entry is written as the single-scale result.
        test_loader: iterable of image tensors shaped (B, C, H, W). Presumably
            B == 1 (the caller builds it with ``batch_size=1``) -- TODO confirm
            before using larger batches.
        test_list: file names aligned with the loader order; ``test_list[idx]``
            names the idx-th batch and only its stem (extension removed) is used.
        save_dir: output directory, created if it does not exist.
        device: device to run inference on. Defaults to the device holding the
            model's parameters, so the function works on CPU-only hosts too.
    """
    model.eval()
    os.makedirs(save_dir, exist_ok=True)
    if device is None:
        # Follow the model instead of hard-coding ``.cuda()``.
        device = next(model.parameters()).device

    # Pure inference: without no_grad an autograd graph is kept for every image.
    with torch.no_grad():
        for idx, image in enumerate(test_loader):
            image = image.to(device)
            _, _, H, W = image.shape
            results = model(image)

            # Collect every output map into one (N, 1, H, W) CPU tensor for the grid.
            all_res = torch.zeros((len(results), 1, H, W))
            for i, res in enumerate(results):
                all_res[i, 0, :, :] = res

            filename = osp.splitext(test_list[idx])[0]
            torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))

            fuse_res = torch.squeeze(results[-1]).cpu().numpy()
            fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
            cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
    print('Running single-scale test done')
def multi_scale_test(model, test_loader, test_list, save_dir, scales=(0.5, 1, 1.5), device=None):
    """Average the model's last output over several input scales and save it.

    Each image is resized by every factor in ``scales``, passed through the
    model, and the last output is resized back to the original H x W. The
    per-scale maps are averaged, inverted (``1 - x``), scaled to uint8 and
    written to ``save_dir/<stem>_ms.png``.

    Args:
        model: network whose forward pass returns a sequence of maps; the last
            entry is the one averaged across scales.
        test_loader: iterable of image tensors shaped (B, C, H, W); only sample
            0 of each batch is used (the caller uses ``batch_size=1``).
        test_list: file names aligned with the loader order; only the stem of
            ``test_list[idx]`` is used for the output name.
        save_dir: output directory, created if it does not exist.
        scales: resize factors applied to the input; defaults to the original
            hard-coded ``0.5, 1, 1.5``.
        device: device to run inference on. Defaults to the device holding the
            model's parameters, so the function works on CPU-only hosts too.

    Raises:
        ValueError: if ``scales`` is empty (the average would divide by zero).
    """
    scales = tuple(scales)
    if not scales:
        raise ValueError('scales must contain at least one resize factor')

    model.eval()
    os.makedirs(save_dir, exist_ok=True)
    if device is None:
        # Follow the model instead of hard-coding ``.cuda()``.
        device = next(model.parameters()).device

    # Pure inference: without no_grad an autograd graph is kept for every pass.
    with torch.no_grad():
        for idx, image in enumerate(test_loader):
            # cv2 works on channel-last arrays: (C, H, W) -> (H, W, C).
            in_ = image[0].numpy().transpose((1, 2, 0))
            _, _, H, W = image.shape
            ms_fuse = np.zeros((H, W), np.float32)
            for factor in scales:
                im_ = cv2.resize(in_, None, fx=factor, fy=factor, interpolation=cv2.INTER_LINEAR)
                im_ = im_.transpose((2, 0, 1))  # back to (C, H', W') for the model
                results = model(torch.unsqueeze(torch.from_numpy(im_).to(device), 0))
                fuse_res = torch.squeeze(results[-1]).cpu().numpy()
                # Resize the prediction back to the original resolution before averaging.
                fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
                ms_fuse += fuse_res
            ms_fuse = ms_fuse / len(scales)
            # Optional rescale trick (disabled):
            # ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
            filename = osp.splitext(test_list[idx])[0]
            ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
            cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
    print('Running multi-scale test done')
if __name__ == '__main__':
    # Command-line entry point: evaluate RCF on the BSDS test split, running
    # single-scale and then multi-scale inference and writing results to --save-dir.
    parser = argparse.ArgumentParser(description='PyTorch Testing')
    parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
    parser.add_argument('--checkpoint', default=None, type=str, help='path to latest checkpoint')
    parser.add_argument('--save-dir', help='output folder', default='results/RCF')
    parser.add_argument('--dataset', help='root folder of dataset', default='data/HED-BSDS')
    args = parser.parse_args()

    # These must be set before CUDA is first initialised to take effect.
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # Fall back to CPU instead of crashing on hosts without a usable GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    os.makedirs(args.save_dir, exist_ok=True)

    test_dataset = BSDS_Dataset(root=args.dataset, split='test')
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, drop_last=False, shuffle=False)
    # Output names come from the dataset's file list; they must line up 1:1 with the loader.
    test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
    if len(test_list) != len(test_loader):
        raise RuntimeError('file list has {} entries but the loader yields {} batches'.format(
            len(test_list), len(test_loader)))

    model = RCF().to(device)
    # --checkpoint defaults to None; os.path.isfile(None) raises TypeError, so check first.
    if args.checkpoint is not None and osp.isfile(args.checkpoint):
        print("=> loading checkpoint from '{}'".format(args.checkpoint))
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        checkpoint = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(checkpoint)
        print("=> checkpoint loaded")
    else:
        print("=> no checkpoint found at '{}'".format(args.checkpoint))

    print('Performing the testing...')
    single_scale_test(model, test_loader, test_list, args.save_dir)
    multi_scale_test(model, test_loader, test_list, args.save_dir)