| |
|
| | import json |
| | import argparse |
| | from pycocotools import mask as mask_utils |
| | import numpy as np |
| | import tqdm |
| | from sklearn.metrics import balanced_accuracy_score |
| |
|
| | import utils |
| | import cv2 |
| | import os |
| | from PIL import Image |
| | from pycocotools.mask import encode, decode, frPyObjects |
| | from natsort import natsorted |
| |
|
# Root of the predicted masks, laid out as <pred_root>/<take_id>/<camera>/<frame>.png
pred_root = "/data/work-gcp-europe-west4-a/yuqian_fu/Ego/data_segswap/predictions/ego_query_finalnew"
# Split file path (not referenced by the evaluation code shown here).
split_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/SegSwap/data/split.json"
# Ground-truth root, laid out as <data_path>/<take_id>/annotation.json
data_path = "/data/work2-gcp-europe-west4-a/yuqian_fu/Ego/data_segswap"
# Takes to evaluate: one sub-directory of pred_root per take.
# os.listdir order is arbitrary, so sort for a reproducible run order.
# Non-directory entries are dropped because each take is later listed as a directory.
val_set = sorted(
    entry
    for entry in os.listdir(pred_root)
    if os.path.isdir(os.path.join(pred_root, entry))
)
| | |
| |
|
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def fuse_davis_mask(mask_list):
    """Merge several binary masks into a single union mask.

    A pixel of the result is 1 when at least one input mask equals exactly 1
    at that position, and 0 otherwise. Other nonzero values are ignored.

    Args:
        mask_list: Non-empty sequence of numpy masks. ``mask_list[0]`` is
            indexed, so an empty sequence raises IndexError.

    Returns:
        A new array with the shape and dtype of ``mask_list[0]``, holding 0/1.
    """
    union = np.zeros_like(mask_list[0])
    hits_per_layer = (layer == 1 for layer in mask_list)
    for hits in hits_per_layer:
        union[hits] = 1
    return union
| |
|
| | |
def _empty_scores():
    """Return four fresh empty lists, the result for a take with nothing scored."""
    return [], [], [], []


def _select_ego_camera(gt):
    """Return the ego camera name annotated in *gt*, or None if there is none.

    Camera names are gathered from every object's entry in ``gt["masks"]``.
    Any name containing ``"aria"`` counts as ego. Candidates are sorted so the
    pick is reproducible: iterating a ``set`` of strings is not stable across
    interpreter runs because of string hash randomization.
    """
    all_cams = set()
    for cam_masks in gt["masks"].values():
        all_cams.update(cam_masks.keys())
    ego_cams = sorted(cam for cam in all_cams if "aria" in cam)
    return ego_cams[0] if ego_cams else None


def _objects_at_first_ego_frame(gt, ego, exo):
    """Return the objects whose exo masks are evaluated for one take.

    An object qualifies when it has masks in both *ego* and *exo*, and also has
    an ego mask on the first annotated ego frame of the "reference" object.
    The reference object is the one with the most ego-annotated frames; ties
    keep the earliest. Returns an empty list when no object is annotated in
    both views.
    """
    masks = gt["masks"]
    objs_both_have = [
        obj for obj in masks if ego in masks[obj] and exo in masks[obj]
    ]
    if not objs_both_have:
        return []

    # max() returns the first maximal element, i.e. the earliest object on ties.
    obj_ref = max(objs_both_have, key=lambda obj: len(masks[obj][ego]))

    # Frame keys are natural-sorted and cast to int64, which raises if any key
    # is non-numeric. The first one is turned back into a string key.
    # NOTE(review): the int round-trip drops leading zeros ("0012" -> "12").
    all_ref_keys = np.asarray(natsorted(masks[obj_ref][ego])).astype(np.int64)
    first_anno_key = str(all_ref_keys[0])

    return [obj for obj in objs_both_have if first_anno_key in masks[obj][ego]]


def _score_frame(gt, exo, objs, frame_id, pred_dir):
    """Score one predicted exo frame against the fused GT of *objs*.

    Returns ``(iou, shape_acc, existence_acc, location_score)``, or ``None``
    when none of *objs* has an exo mask for *frame_id*.
    """
    gt_mask_list = [
        decode(gt["masks"][obj][exo][frame_id])
        for obj in objs
        if frame_id in gt["masks"][obj][exo]
    ]
    if not gt_mask_list:
        return None

    # Close the PNG handle promptly; np.array copies the pixel data first.
    with Image.open(f"{pred_dir}/{frame_id}.png") as img:
        pred_mask = np.array(img)
    pred_mask[pred_mask != 0] = 1
    h, w = pred_mask.shape

    # Resize GT to the prediction's resolution. Nearest neighbour keeps it binary.
    gt_mask = cv2.resize(
        fuse_davis_mask(gt_mask_list), (w, h), interpolation=cv2.INTER_NEAREST
    )

    iou, shape_acc = utils.eval_mask(gt_mask, pred_mask)
    ex_acc = utils.existence_accuracy(gt_mask, pred_mask)
    location_score = utils.location_score(gt_mask, pred_mask, size=(h, w))
    return iou, shape_acc, ex_acc, location_score


def evaluate_take(take_id):
    """Score the exo-view predicted masks of one take against ground truth.

    Predictions are read from ``<pred_root>/<take_id>/<exo>/<frame>.png``.
    ``<exo>`` is the first (sorted) camera directory under the take. Ground
    truth comes from ``<data_path>/<take_id>/annotation.json``, whose masks are
    read as ``gt["masks"][obj][camera][frame_id]`` and decoded with pycocotools.
    Only objects chosen by ``_objects_at_first_ego_frame`` are fused into each
    frame's GT mask. Frames without GT for those objects are skipped.

    Args:
        take_id: Name of the take directory under ``pred_root``.

    Returns:
        A tuple of four lists, with one entry per scored frame:
        IoU, shape accuracy, existence accuracy and location score, as
        produced by the project's ``utils`` metrics. All four lists are empty
        when the take cannot be evaluated.
    """
    take_pred_dir = os.path.join(pred_root, take_id)
    cams = sorted(os.listdir(take_pred_dir))
    if not cams:
        print(f"Skipping {take_id}: no prediction camera directory")
        return _empty_scores()
    exo = cams[0]
    pred_path = os.path.join(take_pred_dir, exo)

    gt_path = f"{data_path}/{take_id}/annotation.json"
    with open(gt_path, 'r') as fp:
        gt = json.load(fp)

    ego = _select_ego_camera(gt)
    if ego is None:
        # Skip rather than raise IndexError so the remaining takes still run.
        print(f"Skipping {take_id}: no ego ('aria') camera in annotations")
        return _empty_scores()

    objs = _objects_at_first_ego_frame(gt, ego, exo)

    scores = []
    for frame_name in sorted(os.listdir(pred_path)):
        frame_id = frame_name.split(".")[0]
        result = _score_frame(gt, exo, objs, frame_id, pred_path)
        if result is not None:
            scores.append(result)

    if not scores:
        # Avoid np.mean([]) (nan + RuntimeWarning) for takes with nothing to score.
        print(f"{take_id}: no frames with ground truth to score")
        return _empty_scores()

    ious, shape_accs, existence_accs, location_scores = (
        np.array(column).tolist() for column in zip(*scores)
    )
    print(np.mean(ious))
    return ious, shape_accs, existence_accs, location_scores
| |
|
def main():
    """Evaluate every take in ``val_set`` and print dataset-level means.

    Per-frame scores from all takes are pooled before averaging, so takes with
    more scored frames carry more weight.
    """
    total_iou = []
    total_shape_acc = []
    total_existence_acc = []
    total_location_scores = []
    for take_id in val_set:
        ious, shape_accs, existence_accs, location_scores = evaluate_take(take_id)
        total_iou.extend(ious)
        total_shape_acc.extend(shape_accs)
        total_existence_acc.extend(existence_accs)
        total_location_scores.extend(location_scores)

    if not total_iou:
        # np.mean([]) would only print nan with a RuntimeWarning.
        print('No frames were scored.')
        return

    print('TOTAL IOU: ', np.mean(total_iou))
    print('TOTAL LOCATION SCORE: ', np.mean(total_location_scores))
    print('TOTAL SHAPE ACC: ', np.mean(total_shape_acc))
    # Existence accuracy was collected per frame but previously never reported.
    print('TOTAL EXISTENCE ACC: ', np.mean(total_existence_acc))
| |
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()