| import os |
| import argparse |
| from glob import glob |
| import prettytable as pt |
|
|
| from evaluation.metrics import evaluator |
| from config import Config |
|
|
|
|
# Project-wide settings object; this script reads config.task,
# config.data_root_dir, config.testsets and config.verbose_eval from it.
config = Config()
|
|
|
|
# Column order of the results table for each ``config.task``; tasks not
# listed here fall back to _DEFAULT_COLUMNS.
_DIS_COLUMNS = (
    "maxFm",
    "wFmeasure",
    "MAE",
    "Smeasure",
    "meanEm",
    "HCE",
    "maxEm",
    "meanFm",
    "adpEm",
    "adpFm",
    "mBA",
    "maxBIoU",
    "meanBIoU",
)
_TASK_COLUMNS = {
    "DIS5K": _DIS_COLUMNS,
    "COD": (
        "Smeasure",
        "wFmeasure",
        "meanFm",
        "meanEm",
        "maxEm",
        "MAE",
        "maxFm",
        "adpEm",
        "adpFm",
        "HCE",
        "mBA",
        "maxBIoU",
        "meanBIoU",
    ),
    "HRSOD": (
        "Smeasure",
        "maxFm",
        "meanEm",
        "MAE",
        "maxEm",
        "meanFm",
        "wFmeasure",
        "adpEm",
        "adpFm",
        "HCE",
        "mBA",
        "maxBIoU",
        "meanBIoU",
    ),
    "General": _DIS_COLUMNS,
    "General-2K": _DIS_COLUMNS,
    "Matting": (
        "Smeasure",
        "maxFm",
        "meanEm",
        "MSE",
        "maxEm",
        "meanFm",
        "wFmeasure",
        "adpEm",
        "adpFm",
        "HCE",
        "mBA",
        "maxBIoU",
        "meanBIoU",
    ),
}
_DEFAULT_COLUMNS = (
    "Smeasure",
    "MAE",
    "maxEm",
    "meanEm",
    "maxFm",
    "meanFm",
    "wFmeasure",
    "adpEm",
    "adpFm",
    "HCE",
    "mBA",
    "maxBIoU",
    "meanBIoU",
)
# Decimal places displayed per column (default 3). MSE is rounded to 5
# decimals in _extract_scores, so it is displayed with 5 as well.
_COLUMN_DECIMALS = {"MSE": 5}


def _columns_for_task(task):
    """Return the ordered metric columns shown for ``task``."""
    return _TASK_COLUMNS.get(task, _DEFAULT_COLUMNS)


def _extract_scores(columns, em, sm, fm, mae, mse, wfm, hce, mba, biou):
    """Return the rounded value of every metric named in ``columns``, in order.

    The arguments after ``columns`` are the nine results returned by
    ``evaluator``. Each value is computed lazily, so only the metrics a task
    actually displays are touched (e.g. ``mse`` only for "MSE").
    """
    getters = {
        "Smeasure": lambda: sm.round(3),
        "MAE": lambda: mae.round(3),
        "MSE": lambda: mse.round(5),
        "wFmeasure": lambda: wfm.round(3),
        "maxFm": lambda: fm["curve"].max().round(3),
        "meanFm": lambda: fm["curve"].mean().round(3),
        "adpFm": lambda: fm["adp"].round(3),
        "maxEm": lambda: em["curve"].max().round(3),
        "meanEm": lambda: em["curve"].mean().round(3),
        "adpEm": lambda: em["adp"].round(3),
        "HCE": lambda: int(hce.round()),
        "mBA": lambda: mba.round(3),
        "maxBIoU": lambda: biou["curve"].max().round(3),
        "meanBIoU": lambda: biou["curve"].mean().round(3),
    }
    return [getters[name]() for name in columns]


def _format_score(score, decimals=3):
    """Render one score for the results table.

    Integers (HCE) are left-aligned to width 4. Floats are formatted with
    ``decimals`` places; values that format as "0.xxx" drop the leading zero
    (".xxx"). Anything else, including 1.000 and negative values, keeps its
    integer part so it is not mistaken for a fraction.
    """
    if isinstance(score, int):
        return format(score, "<4")
    text = format(score, ".{}f".format(decimals))
    # Check the formatted text rather than the raw value so that e.g. 0.9996
    # (which prints as "1.000") is not truncated to ".000".
    if text.startswith("0."):
        return text[1:]
    return text


def do_eval(args):
    """Evaluate every model's predictions on every dataset and tabulate the scores.

    For each dataset in ``args.data_lst`` ("+"-separated), ground-truth maps in
    ``<gt_root>/<dataset>/gt/`` are paired by file name with predictions in
    ``<pred_root>/<model>/<dataset>/`` for every model in ``args.model_lst``,
    and scored with ``evaluator``. One row per model is added to a PrettyTable
    whose columns depend on ``config.task``. The table is rewritten to
    ``<save_dir>/<dataset>_eval.txt`` after each model and printed once all
    models of that dataset are done.

    Args:
        args: namespace with ``gt_root``, ``pred_root``, ``data_lst``,
            ``save_dir``, ``metrics`` ("+"-separated) and ``model_lst``.
    """
    if not args.model_lst:
        # The per-dataset check below needs at least one model folder.
        print("No model predictions found under {}.".format(args.pred_root))
        return

    columns = _columns_for_task(config.task)
    for _data_name in args.data_lst.split("+"):
        # Only the first model's folder decides whether this dataset has predictions.
        if not os.path.exists(
            os.path.join(args.pred_root, args.model_lst[0], _data_name)
        ):
            print("Skip dataset {}.".format(_data_name))
            continue
        gt_paths = sorted(glob(os.path.join(args.gt_root, _data_name, "gt", "*")))
        if not gt_paths:
            print("Skip dataset {}: no ground-truth files found.".format(_data_name))
            continue
        print("#" * 20, _data_name, "#" * 20)
        filename = os.path.join(args.save_dir, "{}_eval.txt".format(_data_name))
        tb = pt.PrettyTable()
        tb.vertical_char = "&"
        tb.field_names = ["Dataset", "Method"] + list(columns)
        for _model_name in args.model_lst:
            print("\t", "Evaluating model: {}...".format(_model_name))
            # Predictions reuse the ground-truth file names but sit directly in
            # <pred_root>/<model>/<dataset>/ (no "gt" sub-folder). Joining path
            # parts works with any OS separator and any root path contents.
            pred_dir = os.path.join(args.pred_root, _model_name, _data_name)
            pred_paths = [os.path.join(pred_dir, os.path.basename(p)) for p in gt_paths]

            em, sm, fm, mae, mse, wfm, hce, mba, biou = evaluator(
                gt_paths=gt_paths,
                pred_paths=pred_paths,
                metrics=args.metrics.split("+"),
                verbose=config.verbose_eval,
            )
            scores = _extract_scores(
                columns, em, sm, fm, mae, mse, wfm, hce, mba, biou
            )
            row = [
                _format_score(score, _COLUMN_DECIMALS.get(column, 3))
                for column, score in zip(columns, scores)
            ]
            tb.add_row([_data_name, _model_name] + row)
            # Rewrite the table after every model so partial results survive
            # an interrupted run.
            with open(filename, "w", encoding="utf-8") as file_to_write:
                file_to_write.write(str(tb) + "\n")
        print(tb)
|
|
|
|
# Metric codes that can be passed to ``evaluator`` through --metrics.
_ALL_METRICS = ["S", "MAE", "E", "F", "WF", "MBA", "BIoU", "MSE", "HCE"]


def _str2bool(value):
    """Parse a command-line boolean such as "true", "False", "1" or "no".

    Unlike ``type=bool`` (where any non-empty string, including "False", is
    True), this maps the usual spellings to the intended value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


def _default_metrics(data_lst):
    """Return the "+"-joined metric codes for a "+"-separated dataset list.

    HCE is only included when at least one dataset name contains "DIS-".
    """
    if any("DIS-" in name for name in data_lst.split("+")):
        return "+".join(_ALL_METRICS)
    return "+".join(m for m in _ALL_METRICS if m != "HCE")


def _list_models(pred_root):
    """Return the model folders under ``pred_root``.

    Plain files are ignored. Folders whose names end in "epoch_<N>" are
    ordered by N, largest first; if any folder name does not end in an
    integer epoch, all folders are sorted alphabetically instead.
    """
    names = [
        name
        for name in os.listdir(pred_root)
        if os.path.isdir(os.path.join(pred_root, name))
    ]
    try:
        return sorted(names, key=lambda n: int(n.split("epoch_")[-1]), reverse=True)
    except ValueError:
        return sorted(names)


def _check_integrity(args):
    """Report every (dataset, model) pair whose prediction files differ from the GT files.

    Ground-truth files are listed from ``<gt_root>/<dataset>/gt`` (the same
    folder ``do_eval`` reads) and compared by name with
    ``<pred_root>/<model>/<dataset>``. A missing prediction folder counts as
    an empty one.
    """
    for _data_name in args.data_lst.split("+"):
        gt_dir = os.path.join(args.gt_root, _data_name, "gt")
        if not os.path.isdir(gt_dir):
            print("Ground-truth folder {} not found; skip.".format(gt_dir))
            continue
        gt_files = sorted(os.listdir(gt_dir))
        for _model_name in args.model_lst:
            pred_dir = os.path.join(args.pred_root, _model_name, _data_name)
            pred_files = sorted(os.listdir(pred_dir)) if os.path.isdir(pred_dir) else []
            if gt_files != pred_files:
                print(len(gt_files), len(pred_files))
                print(
                    "The {} Dataset of {} Model is not matching to the ground-truth".format(
                        _data_name, _model_name
                    )
                )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gt_root",
        type=str,
        help="ground-truth root",
        default=os.path.join(config.data_root_dir, config.task),
    )
    parser.add_argument(
        "--pred_root", type=str, help="prediction root", default="./e_preds"
    )
    parser.add_argument(
        "--data_lst",
        type=str,
        help="test dataset",
        default=config.testsets.replace(",", "+"),
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        help="directory for the per-dataset result tables",
        default="e_results",
    )
    parser.add_argument(
        "--check_integrity",
        type=_str2bool,
        nargs="?",
        const=True,
        help="whether to check the file integrity",
        default=False,
    )
    parser.add_argument(
        "--metrics",
        type=str,
        help='"+"-separated metric codes; by default all except HCE, '
        "plus HCE when a DIS- dataset is evaluated",
        default=None,
    )
    args = parser.parse_args()
    # Only derive the metric list when the user did not pass --metrics.
    if args.metrics is None:
        args.metrics = _default_metrics(args.data_lst)

    os.makedirs(args.save_dir, exist_ok=True)
    args.model_lst = _list_models(args.pred_root)

    if args.check_integrity:
        _check_integrity(args)
    else:
        print(">>> skip check the integrity of each candidates")

    do_eval(args)
|
|