|
|
import os |
|
|
import argparse |
|
|
import numpy as np |
|
|
import cv2 |
|
|
|
|
|
from evaluator import Eval_thread |
|
|
from dataloader import EvalDataset |
|
|
|
|
|
import sys |
|
|
sys.path.append('..') |
|
|
|
|
|
|
|
|
def main(cfg):
    """Export predictions selected by S-measure plus side-by-side comparisons.

    For every dataset named in ``cfg.datasets`` (names joined by ``+``), the
    predictions found under a ``gconet_*`` entry of the current working
    directory are evaluated against the ground truth in ``cfg.gt_dir``. The
    selection itself is delegated to ``Eval_thread.select_by_Smeasure``.
    Each selected prediction is re-saved to
    ``good_ones/<dataset>/<category>/`` and a ``*_comp`` image is written
    next to it showing, left to right: source image, ground truth,
    prediction and the comparison prediction, separated by grey bars.

    Args:
        cfg: Namespace providing ``datasets`` (str), ``gt_dir`` (str) and
            ``cuda`` attributes.

    Raises:
        FileNotFoundError: If the current working directory has no entry
            whose name contains ``gconet_``.
    """
    dataset_names = cfg.datasets.split('+')
    root_dir_predictions = [dr for dr in os.listdir('.') if 'gconet_' in dr]
    # Predictions of the comparison model live next to the ground truth,
    # with the '/gts' path component swapped for '/gconet'.
    root_dir_prediction_comp = cfg.gt_dir.replace('/gts', '/gconet')
    print('root_dir_predictions:', root_dir_predictions)
    if not root_dir_predictions:
        raise FileNotFoundError(
            "no prediction directory containing 'gconet_' found in {}".format(os.getcwd())
        )
    # Only the first match is used; os.listdir returns entries in arbitrary order.
    root_dir_prediction = root_dir_predictions[0]
    root_dir_good_ones = 'good_ones'

    for dataset in dataset_names:
        dir_prediction = os.path.join(root_dir_prediction, dataset)
        dir_prediction_comp = os.path.join(root_dir_prediction_comp, dataset)
        dir_gt = os.path.join(cfg.gt_dir, dataset)

        loader = EvalDataset(
            dir_prediction,
            dir_gt,
            return_predpath=True,
            return_gtpath=True
        )
        loader_comp = EvalDataset(
            dir_prediction_comp,
            dir_gt,
            return_predpath=True
        )

        print('Selecting predictions from {}'.format(dir_prediction))
        thread = Eval_thread(loader, cuda=cfg.cuda)
        # NOTE(review): presumably keeps samples where this model scores at
        # least `bar` while the comparison model scores at most `bar_comp`
        # -- confirm against Eval_thread.select_by_Smeasure.
        _, good_ones, good_ones_comp, good_ones_gt = thread.select_by_Smeasure(
            bar=0.95, loader_comp=loader_comp, bar_comp=0.2
        )

        dir_good_ones = os.path.join(root_dir_good_ones, dataset)
        os.makedirs(dir_good_ones, exist_ok=True)
        print('have good_ones {}'.format(len(good_ones)))

        for good_one, good_one_comp, good_one_gt in zip(good_ones, good_ones_comp, good_ones_gt):
            # The parent directory of a prediction file is used as its category.
            category = os.path.basename(os.path.dirname(good_one))
            dir_category = os.path.join(dir_good_ones, category)
            os.makedirs(dir_category, exist_ok=True)
            save_path = os.path.join(dir_category, os.path.basename(good_one))

            # cv2.imread signals failure by returning None instead of raising.
            sal_map = cv2.imread(good_one)
            if sal_map is None:
                print('Skipping unreadable prediction {}'.format(good_one))
                continue
            cv2.imwrite(save_path, sal_map)

            image_path = good_one_gt.replace('/gts', '/images').replace('.png', '.jpg')
            tiles = {
                image_path: cv2.imread(image_path),
                good_one_gt: cv2.imread(good_one_gt),
                good_one_comp: cv2.imread(good_one_comp),
            }
            unreadable = [path for path, tile in tiles.items() if tile is None]
            if unreadable:
                print('Skipping comparison for {}: cannot read {}'.format(good_one, unreadable))
                continue

            # Grey vertical separator, 10 px wide, as tall as the prediction.
            split_line = np.full((sal_map.shape[0], 10, 3), 127, dtype=sal_map.dtype)
            # NOTE(review): cv2.hconcat needs every tile to share height and
            # dtype; assumes image, GT and both predictions match -- confirm.
            comp = cv2.hconcat([
                tiles[image_path], split_line,
                tiles[good_one_gt], split_line,
                sal_map, split_line,
                tiles[good_one_comp],
            ])
            # splitext keeps the suffix intact whatever the extension length.
            stem, ext = os.path.splitext(save_path)
            save_path_comp = ''.join((stem, '_comp', ext))
            cv2.imwrite(save_path_comp, comp)
|
|
|
|
|
|
|
|
def str2bool(value):
    """Convert a command-line token into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string as True,
    so ``--cuda False`` would silently enable CUDA. This converter accepts
    the usual spellings instead.

    Args:
        value: A bool (returned unchanged) or a string such as ``true``,
            ``False``, ``1``, ``no``. The empty string maps to False.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the selection.
    parser = argparse.ArgumentParser()
    parser.add_argument('--datasets', type=str, default='CoCA+CoSOD3k+CoSal2015')
    parser.add_argument('--gt_dir', type=str, default='/root/datasets/sod/gts', help='GT')
    parser.add_argument('--cuda', type=str2bool, default=True)
    config = parser.parse_args()
    main(config)