Spaces:
Running
Running
| import json | |
| import os | |
| import sys | |
| import cv2 | |
| import numpy as np | |
| from shapely.geometry import Polygon | |
| from tabulate import tabulate | |
def get_image_path(image_dir, file_name_wo_ext):
    """Locate an image file in ``image_dir`` by trying a fixed list of suffixes.

    The name is tried as given first (empty suffix), then with ".jpg", ".JPG",
    ".png", ".PNG" and ".jpeg" appended, in that order.

    Args:
        image_dir: Directory to look in.
        file_name_wo_ext: File name, with or without its extension.

    Returns:
        The first candidate path that exists, or None when none of them do.
    """
    candidates = (
        os.path.join(image_dir, file_name_wo_ext + suffix)
        for suffix in ("", ".jpg", ".JPG", ".png", ".PNG", ".jpeg")
    )
    # The generator is lazy, so probing stops at the first existing path.
    return next((path for path in candidates if os.path.exists(path)), None)
def visual_badcase(image_path, pred_list, label_list, output_dir="visual_badcase", info=None, prefix=""):
    """Draw ground-truth and predicted polygons on an image and save the result.

    Ground-truth polygons are drawn with color (0, 255, 0) and tagged
    "gt:<category_id>" at their first vertex; predictions are drawn with color
    (255, 0, 0) and tagged "pred:<category_id>" at their last vertex.

    Args:
        image_path: Path of the source image. May be None or point to a
            missing file, in which case the image is skipped.
        pred_list: Iterable of dicts with "poly" (flat x, y coordinate
            sequence) and "category_id".
        label_list: Iterable of dicts with the same two keys.
        output_dir: Directory for the rendered image; created when missing.
        info: Optional value rendered as text at position (40, 40).
        prefix: String prepended to the output file name.

    Returns:
        Path of the written "<prefix><basename>_vis.jpg" file, or None when
        the source image could not be loaded.
    """
    # The previous check compared the bool returned by os.path.exists()
    # against None, which is always true; it also raised on image_path=None.
    img = None
    if image_path is not None and os.path.exists(image_path):
        img = cv2.imread(image_path)
    if img is None:
        print("--> Warning: skip, given iamge NOT exists: {}".format(image_path))
        return None

    os.makedirs(output_dir, exist_ok=True)
    font = cv2.FONT_HERSHEY_SIMPLEX

    for label in label_list:
        points, class_id = label["poly"], label["category_id"]
        # polylines expects an array of polygons, hence the leading axis of 1.
        pts = np.array(points).reshape((1, -1, 2)).astype(np.int32)
        cv2.polylines(img, pts, isClosed=True, color=(0, 255, 0), thickness=3)
        cv2.putText(img, "gt:" + str(class_id), tuple(pts[0][0].tolist()), font, 1, (0, 255, 0), 2)

    for label in pred_list:
        points, class_id = label["poly"], label["category_id"]
        pts = np.array(points).reshape((1, -1, 2)).astype(np.int32)
        cv2.polylines(img, pts, isClosed=True, color=(255, 0, 0), thickness=3)
        # The tag sits on the last vertex so it does not overlap the "gt:" tag.
        cv2.putText(img, "pred:" + str(class_id), tuple(pts[0][-1].tolist()), font, 1, (255, 0, 0), 2)

    if info is not None:
        cv2.putText(img, str(info), (40, 40), font, 1, (0, 0, 255), 2)

    output_path = os.path.join(output_dir, prefix + os.path.basename(image_path) + "_vis.jpg")
    cv2.imwrite(output_path, img)
    return output_path
def pub_load_gt_from_json(json_path):
    """Load bbox-style ground truth from a JSON file and group it by image.

    The file must hold an "images" list (items with "id", "file_name" and an
    optional "group_name", defaulting to "huntie") and an "annotations" list
    (items with "image_id", "category_id", "bbox" as [x, y, w, h] and optional
    "secondary_id" / "direction_id", both defaulting to -1).

    Each bbox is converted to a rounded 8-value polygon ordered
    (x0, y0), (x1, y0), (x1, y1), (x0, y1). A one-line summary of the loaded
    groups is printed.

    Args:
        json_path: Path to the ground-truth JSON file.

    Returns:
        dict mapping group name -> file name -> list of annotation dicts with
        keys "category_id", "poly" (numpy array), "secondary_id" and
        "direction_id".

    Raises:
        KeyError: If a required key is missing or an annotation references an
            unknown image id.
    """
    with open(json_path) as f:
        gt_info = json.load(f)

    id_to_image_info = {
        image_item["id"]: {
            "file_name": image_item["file_name"],
            "group_name": image_item.get("group_name", "huntie"),
        }
        for image_item in gt_info["images"]
    }

    group_info = {}
    for annotation_item in gt_info["annotations"]:
        image_info = id_to_image_info[annotation_item["image_id"]]
        image_name, group_name = image_info["file_name"], image_info["group_name"]

        x, y, w, h = annotation_item["bbox"][:4]
        x1, y1 = x + w, y + h
        # Corner order matches the quadrilateral layout used for "poly" fields.
        pts = np.round([x, y, x1, y, x1, y1, x, y1])

        anno_info = {
            "category_id": annotation_item["category_id"],
            "poly": pts,
            "secondary_id": annotation_item.get("secondary_id", -1),
            "direction_id": annotation_item.get("direction_id", -1),
        }
        group_info.setdefault(group_name, {}).setdefault(image_name, []).append(anno_info)

    group_info_str = ", ".join(["{}[{}]".format(k, len(v)) for k, v in group_info.items()])
    print("--> load {} groups: {}".format(len(group_info), group_info_str))
    return group_info
def load_gt_from_json(json_path):
    """Read polygon ground truth from a JSON file and bucket it by group and image.

    The file must hold an "images" list (items with "id", "file_name" and an
    optional "group_name", defaulting to "huntie") and an "annotations" list
    (items with "image_id", "category_id", "poly" and optional "secondary_id"
    / "direction_id", both defaulting to -1). A one-line summary of the
    loaded groups is printed.

    Args:
        json_path: Path to the ground-truth JSON file.

    Returns:
        dict mapping group name -> file name -> list of annotation dicts with
        keys "category_id", "poly", "secondary_id" and "direction_id".
    """
    with open(json_path) as handle:
        payload = json.load(handle)
    images = payload["images"]
    annotations = payload["annotations"]

    # image id -> the two fields needed to place an annotation
    image_lookup = {
        entry["id"]: {
            "file_name": entry["file_name"],
            "group_name": entry.get("group_name", "huntie"),
        }
        for entry in images
    }

    grouped = {}
    for anno in annotations:
        meta = image_lookup[anno["image_id"]]
        per_image = grouped.setdefault(meta["group_name"], {}).setdefault(meta["file_name"], [])
        per_image.append(
            {
                "category_id": anno["category_id"],
                "poly": anno["poly"],
                "secondary_id": anno.get("secondary_id", -1),
                "direction_id": anno.get("direction_id", -1),
            }
        )

    summary = ", ".join("{}[{}]".format(name, len(files)) for name, files in grouped.items())
    print("--> load {} groups: {}".format(len(grouped), summary))
    return grouped
def _quad_with_area(flat_poly, kind):
    """Build a shapely Polygon from the first four (x, y) pairs of ``flat_poly``.

    Returns a (corners, polygon, area) tuple. Raises ValueError, naming
    ``kind`` ("detects" or "labels"), when shapely rejects the coordinates.
    """
    corners = [[flat_poly[2 * k], flat_poly[2 * k + 1]] for k in range(4)]
    try:
        polygon = Polygon(corners)
        area = polygon.area
    except Exception as err:
        raise ValueError("invalid {}: {}".format(kind, flat_poly)) from err
    return corners, polygon, area


def calc_iou(label, detect):
    """Compute the best same-class overlap score for every label and detection.

    Each item of ``label`` / ``detect`` is a dict with "poly" (a flat
    coordinate sequence; only the first four x, y pairs are used) and
    "category_id" (compared after ``int()`` conversion).

    The score of a (detection, label) pair is
    ``min(overlap / detection_area, overlap / label_area)``, i.e. the smaller
    of the two coverage ratios rather than intersection-over-union. Pairs
    whose categories differ score 0.

    Args:
        label: Sequence of ground-truth items.
        detect: Sequence of predicted items.

    Returns:
        ``(l_ious, d_ious)``: the best score per label and per detection, in
        input order. Both lists are empty-safe.

    Raises:
        ValueError: If a polygon cannot be built from an item's coordinates.
            This replaces the former ``exit(-1)``: this function runs inside
            pool workers, where exiting the process loses the task result
            instead of reporting the failure to the parent.
    """
    # Polygons are built once here and reused for every pairwise intersection.
    det_quads = [_quad_with_area(item["poly"], "detects") for item in detect]
    gt_quads = [_quad_with_area(item["poly"], "labels") for item in label]

    overlaps = []
    for det_corners, det_poly, _ in det_quads:
        row = []
        for gt_corners, gt_poly, _ in gt_quads:
            try:
                row.append(gt_poly.intersection(det_poly).area)
            except Exception:
                # Best effort: geometries shapely cannot intersect count as
                # non-overlapping.
                print("invaild pair", det_corners, gt_corners)
                row.append(0.0)
        overlaps.append(row)

    d_ious = [0.0] * len(det_quads)
    l_ious = [0.0] * len(gt_quads)
    for i, (_, _, det_area) in enumerate(det_quads):
        for j, (_, _, gt_area) in enumerate(gt_quads):
            if int(label[j]["category_id"]) == int(detect[i]["category_id"]):
                # The 1e-10 term guards against zero-area polygons.
                iou = min(overlaps[i][j] / (det_area + 1e-10), overlaps[i][j] / (gt_area + 1e-10))
            else:
                iou = 0
            d_ious[i] = max(d_ious[i], iou)
            l_ious[j] = max(l_ious[j], iou)
    return l_ious, d_ious
def eval(instance_info):
    """Score one image's ground truth against its detections.

    The name shadows the builtin ``eval`` but is kept because callers pass
    this function by name.

    Args:
        instance_info: Pair ``(img_name, {"gt": labels, "det": detections})``.

    Returns:
        ``[img_name, d_ious, l_ious, detections, labels]`` where the two score
        lists come from ``calc_iou``.
    """
    img_name, pair = instance_info
    gt_items = pair["gt"]
    det_items = pair["det"]
    gt_scores, det_scores = calc_iou(gt_items, det_items)
    return [img_name, det_scores, gt_scores, det_items, gt_items]
def static_with_class(rets, iou_thresh=0.7, is_verbose=True, map_info=None, src_image_dir=None, visualization_dir=None):
    """Aggregate per-image scores into a per-class precision/recall/F-score table.

    Args:
        rets: Sequence of ``[img_name, d_ious, l_ious, detects, labels]``
            entries, one per image.
        iou_thresh: Minimum score for a detection or label to count as a hit.
        is_verbose: When true the table also carries the raw hit/total counts
            and class names are looked up under ``map_info["primary_map"]``;
            otherwise names are looked up directly in ``map_info``.
        map_info: Optional mapping used to translate class ids to names.
        src_image_dir: When given, images whose F-score is below 0.97 are
            rendered through ``visual_badcase``.
        visualization_dir: Output directory for those renderings; defaults to
            "./visualization_badcase".

    Returns:
        ``[table_head] + rows``: rows sorted by ``int(class_id)`` followed by
        one "average" row computed from the summed counts. The table is also
        printed via ``tabulate``.
    """
    if is_verbose:
        table_head = ["Class_id", "Class_name", "Pre_hit", "Pre_num", "GT_hit", "GT_num", "Precision", "Recall", "F-score"]
    else:
        table_head = ["Class_id", "Class_name", "Precision", "Recall", "F-score"]

    # class id -> counters: dm/dv = detection hits/total, lm/lv = label hits/total
    counters = {}
    for img_name, d_ious, l_ious, detects, labels in rets:
        for gt_item in labels:
            gt_cls = gt_item["category_id"]
            if gt_cls not in counters:
                counters[gt_cls] = {"dm": 0, "dv": 0, "lm": 0, "lv": 0}
            counters[gt_cls]["lv"] += 1

        for det_item in detects:
            det_cls = det_item["category_id"]
            if det_cls in counters:
                counters[det_cls]["dv"] += 1
            else:
                # Classes never seen in ground truth so far are reported and
                # left out of the per-class totals.
                print("--> category_id not exists in gt: {}".format(det_cls))

        matched_dets = 0
        for idx, score in enumerate(d_ious):
            if score >= iou_thresh:
                matched_dets += 1
                counters[detects[idx]["category_id"]]["dm"] += 1

        matched_gts = 0
        for idx, score in enumerate(l_ious):
            if score >= iou_thresh:
                matched_gts += 1
                counters[labels[idx]["category_id"]]["lm"] += 1

        # Per-image metrics count every detection, including unknown classes.
        item_p = matched_dets / (len(detects) + 1e-6)
        item_r = matched_gts / (len(labels) + 1e-6)
        item_f = 2 * item_p * item_r / (item_p + item_r + 1e-6)

        if src_image_dir is None or not (item_f < 0.97):
            continue
        image_path = get_image_path(src_image_dir, os.path.basename(img_name))
        if visualization_dir is not None:
            visualization_output = visualization_dir
        else:
            visualization_output = "./visualization_badcase"
        item_info = "IOU{}, {}, {}, {}".format(iou_thresh, item_r, item_p, item_f)
        vis_path = visual_badcase(
            image_path,
            detects,
            labels,
            output_dir=visualization_output,
            info=item_info,
            prefix="{:02d}_".format(int(item_f * 100)),
        )
        if is_verbose:
            print("--> info: save visualization at: {}".format(vis_path))

    name_map = {} if map_info is None else map_info
    total_dm = total_dv = total_lm = total_lv = 0
    rows = []
    for cls_id, stats in counters.items():
        total_dm += stats["dm"]
        total_dv += stats["dv"]
        total_lm += stats["lm"]
        total_lv += stats["lv"]
        cls_p = stats["dm"] / (stats["dv"] + 1e-6)
        cls_r = stats["lm"] / (stats["lv"] + 1e-6)
        cls_f = 2 * cls_p * cls_r / (cls_p + cls_r + 1e-6)
        if is_verbose:
            cls_name = name_map.get("primary_map", {}).get(str(cls_id), str(cls_id))
            rows.append((cls_id, cls_name, stats["dm"], stats["dv"], stats["lm"], stats["lv"], cls_p, cls_r, cls_f))
        else:
            rows.append((cls_id, name_map.get(str(cls_id), str(cls_id)), cls_p, cls_r, cls_f))

    avg_p = total_dm / (total_dv + 1e-6)
    avg_r = total_lm / (total_lv + 1e-6)
    avg_f = 2 * avg_p * avg_r / (avg_p + avg_r + 1e-6)

    table_body_sorted = sorted(rows, key=lambda row: int(row[0]))
    summary_label = "IOU_{}".format(iou_thresh)
    if is_verbose:
        table_body_sorted.append((summary_label, "average", total_dm, total_dv, total_lm, total_lv, avg_p, avg_r, avg_f))
    else:
        table_body_sorted.append((summary_label, "average", avg_p, avg_r, avg_f))

    print(tabulate(table_body_sorted, headers=table_head, tablefmt="pipe"))
    return [table_head] + table_body_sorted
def multiproc(func, task_list, proc_num=30, retv=True, progress_bar=False):
    """Apply ``func`` to every task in a process pool, keeping input order.

    Args:
        func: Picklable callable taking one task.
        task_list: Sequence of tasks (``len()`` is needed for the progress bar).
        proc_num: Number of worker processes.
        retv: When true the list of results is returned; otherwise None.
        progress_bar: When true a ``tqdm`` progress bar is shown.

    Returns:
        List of results in task order when ``retv`` is true, else None.

    Raises:
        Whatever ``func`` raises in a worker; the pool is torn down first so
        no worker processes are left behind.
    """
    from multiprocessing import Pool

    pool = Pool(proc_num)
    try:
        results = pool.imap(func, task_list)
        if progress_bar:
            import tqdm

            rets = []
            with tqdm.tqdm(total=len(task_list)) as t:
                for ret in results:
                    rets.append(ret)
                    t.update(1)
        else:
            rets = list(results)
        pool.close()
    except BaseException:
        # A failing task (or Ctrl-C) must not leave worker processes running.
        pool.terminate()
        raise
    finally:
        # join() is valid here because close() or terminate() has been called.
        pool.join()

    if retv:
        return rets
def eval_and_show(
    label_dict, detect_dict, output_dir, iou_thresh=0.7, map_info=None, src_image_dir=None, visualization_dir=None
):
    """Evaluate detections group by group and write the tables to a JSON file.

    Args:
        label_dict: Mapping group name -> file name -> list of ground-truth
            items.
        detect_dict: Mapping file name -> list of detected items. Files
            missing here are reported and skipped.
        output_dir: Directory receiving "results_val.json"; created when
            missing.
        iou_thresh: Hit threshold forwarded to ``static_with_class``.
        map_info: Optional mapping group name -> class-name map.
        src_image_dir: Optional image directory forwarded for bad-case
            visualization.
        visualization_dir: Optional output directory for visualizations.

    Returns:
        None. The per-group tables are written to
        ``<output_dir>/results_val.json``.
    """
    # Create the output directory first so a bad path fails before the
    # evaluation work instead of after it.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    evaluation_group_info = {}
    for group_name, gt_info in label_dict.items():
        group_pair_list = []
        for file_name, gt_items in gt_info.items():
            if file_name not in detect_dict:
                print("--> missing pred:", file_name)
                continue
            group_pair_list.append([file_name, {"gt": gt_items, "det": detect_dict[file_name]}])
        evaluation_group_info[group_name] = group_pair_list

    res_info_all = {}
    for group_name, group_pair_list in evaluation_group_info.items():
        print(" ------- group name: {} -----------".format(group_name))
        rets = multiproc(eval, group_pair_list, proc_num=16)
        group_name_map_info = map_info.get(group_name, None) if map_info is not None else None
        res_info = static_with_class(
            rets,
            iou_thresh=iou_thresh,
            map_info=group_name_map_info,
            src_image_dir=src_image_dir,
            visualization_dir=visualization_dir,
        )
        res_info_all[group_name] = res_info

    evaluation_res_info_path = os.path.join(output_dir, "results_val.json")
    # ensure_ascii=False emits raw non-ASCII text, so the encoding is explicit.
    with open(evaluation_res_info_path, "w", encoding="utf-8") as f:
        json.dump(res_info_all, f, ensure_ascii=False, indent=4)
    print("--> info: evaluation result is saved at {}".format(evaluation_res_info_path))
if __name__ == "__main__":
    # Command-line entry point:
    #   python <script> gt_json_path pred_json_path output_dir iou_thresh
    if len(sys.argv) != 5:
        print("Usage: python {} gt_json_path pred_json_path output_dir iou_thresh".format(__file__))
        sys.exit(-1)
    print("--> info: {}".format(sys.argv))

    gt_json_path, pred_json_path, output_dir = sys.argv[1:4]
    # argv entries are strings; the threshold is compared against float
    # scores, which raises TypeError unless it is converted first.
    iou_thresh = float(sys.argv[4])

    label_dict = load_gt_from_json(gt_json_path)
    with open(pred_json_path, "r") as f:
        detect_dict = json.load(f)

    # No source image directory is given, so bad-case visualization is off.
    src_image_dir = None
    eval_and_show(
        label_dict,
        detect_dict,
        output_dir,
        iou_thresh=iou_thresh,
        map_info=None,
        src_image_dir=src_image_dir,
        visualization_dir=None,
    )