| |
| |
| |
|
|
| import torch |
| import numpy as np |
| from itertools import product |
|
|
|
|
def compute_prf1(count, miss, fp):
    """
    Code modified from https://github.com/Arthur151/ROMP/blob/4eebd3647f57d291d26423e51f0d514ff7197cb3/simple_romp/evaluation/RH_evaluation/evaluation.py#L90

    Compute precision, recall and F1 as percentages.

    ``count - miss`` is taken as the number of true positives, ``miss`` as
    the number of false negatives and ``fp`` as the number of false positives.

    :param count: total number of instances; ``0`` short-circuits to zeros.
    :param miss: number of missed instances (false negatives).
    :param fp: number of false positives.
    :return: tuple ``(precision, recall, f1)``. Each score is rounded to two
        decimals *before* being scaled by 100, so values land on (approximately)
        whole percentages.
    """
    if count == 0:
        return 0, 0, 0

    true_pos = count - miss
    if true_pos == 0:
        return 0.0, 0.0, 0.0

    def _as_percent(denominator):
        # Round the ratio first, then scale, exactly like the reference code.
        return 100.0 * round(true_pos / denominator, 2)

    f1_score = _as_percent(true_pos + 0.5 * (fp + miss))
    recall = _as_percent(true_pos + miss)
    precision = _as_percent(true_pos + fp)
    return precision, recall, f1_score
|
|
|
|
def match_2d_greedy(
    pred_kps,
    gtkp,
    valid_mask,
    imgPath=None,
    baseline=None,
    iou_thresh=0.05,
    valid=None,
    ind=-1,
):
    """
    Code modified from: https://github.com/Arthur151/ROMP/blob/4eebd3647f57d291d26423e51f0d514ff7197cb3/simple_romp/trace2/evaluation/eval_3DPW.py#L232

    Greedily match predicted keypoint sets to ground-truth keypoint sets.

    Every (prediction, ground truth) pair is scored by its 2D keypoint error
    and pairs are consumed from the lowest error upwards. A pair becomes a
    match when both members are still unassigned and the bounding boxes of
    their keypoints overlap with an IoU of at least ``iou_thresh``. A pair
    whose IoU is below the threshold is counted as a false-positive detection
    instead of a match.

    :param pred_kps: sequence of predicted keypoint arrays, one per detection;
        columns 0 and 1 are used for the error.
    :param gtkp: sequence of ground-truth keypoint arrays, one per person.
    :param valid_mask: indexed by ground-truth index; ``valid_mask[j]`` selects
        the keypoint rows compared for ground truth ``j`` and must have a
        positive ``.sum()``.
    :param imgPath: unused here; kept for interface compatibility.
    :param baseline: unused here; kept for interface compatibility.
    :param iou_thresh: minimum bounding-box IoU for a pair to count as a match.
    :param valid: optional per-ground-truth flags. A ground truth whose flag is
        falsy is marked as handled without being matched and is never
        reported as a miss.
    :param ind: unused here; kept for interface compatibility.
    :return: ``(bestMatch, falsePositives, misses)`` where ``bestMatch`` is an
        array of ``(idx_pred_kps, idx_gt_kps)`` pairs, ``falsePositives`` lists
        the prediction indices that appear in no match and ``misses`` lists
        the ground-truth indices that appear in no match (restricted to valid
        ones when ``valid`` is given).
    :raises AssertionError: if a ground truth has no valid keypoints.
    """
    num_pred = len(pred_kps)
    num_gt = len(gtkp)

    # All (prediction index, ground-truth index) pairs, in row-major order.
    combs = list(product(np.arange(num_pred), np.arange(num_gt)))

    pair_errors = []
    for pred_idx, gt_idx in combs:
        vmask = valid_mask[gt_idx]
        assert vmask.sum() > 0, "no valid points"
        # NOTE(review): if the masked difference is 2-D, ``ord=2`` makes this
        # the matrix 2-norm (largest singular value), not the Frobenius norm.
        # Kept as-is so the matching order does not change.
        pair_errors.append(
            np.linalg.norm(
                pred_kps[pred_idx][vmask, :2] - gtkp[gt_idx][vmask, :2], 2
            )
        )
    pair_errors = np.array(pair_errors)

    gt_assigned = np.zeros((num_gt,), dtype=bool)
    pred_assigned = np.zeros((num_pred,), dtype=bool)

    best_pairs = []
    false_positive_count = 0
    while (
        not gt_assigned.all()
        and pred_assigned.sum() + false_positive_count < num_pred
    ):
        found = False
        false_positive = False
        while not found:
            # Diagnostic only: execution continues after the message.
            if np.all(pair_errors == np.inf):
                print("something went wrong here")

            min_idx = np.argmin(pair_errors)
            pred_idx, gt_idx = combs[min_idx]
            iou = get_bbx_overlap(pred_kps[pred_idx], gtkp[gt_idx])

            # A pair is consumed as soon as it has been looked at; overwriting
            # its error with inf keeps argmin from returning it again.
            pair_errors[min_idx] = np.inf

            if (
                not pred_assigned[pred_idx]
                and not gt_assigned[gt_idx]
                and iou >= iou_thresh
            ):
                found = True
            elif iou < iou_thresh:
                # Closest remaining pair does not overlap enough: count it as
                # a false positive rather than a match.
                found = True
                false_positive = True
                false_positive_count += 1

        if valid is not None and not valid[gt_idx]:
            # Invalid ground truth: mark it handled so the loop can finish,
            # but record no match for it.
            gt_assigned[gt_idx] = True
        elif not false_positive:
            best_pairs.append((pred_idx, gt_idx))
            pred_assigned[pred_idx] = True
            gt_assigned[gt_idx] = True

    bestMatch = np.array(best_pairs)

    # Unmatched predictions are false positives, unmatched ground truths are
    # misses. Only actual matches count here, not the bookkeeping flags above.
    matched_pred = [pair[0] for pair in best_pairs]
    matched_gt = [pair[1] for pair in best_pairs]

    falsePositives = list(np.setdiff1d(np.arange(num_pred), matched_pred))
    unmatched_gt = np.setdiff1d(np.arange(num_gt), matched_gt)
    misses = [
        gt_id for gt_id in unmatched_gt if valid is None or valid[gt_id]
    ]

    return bestMatch, falsePositives, misses
|
|
|
|
def get_bbx_overlap(p1, p2):
    """
    Code modified from https://github.com/Arthur151/ROMP/blob/4eebd3647f57d291d26423e51f0d514ff7197cb3/simple_romp/trace2/evaluation/eval_3DPW.py#L185

    Return the IoU of the axis-aligned bounding boxes of two point sets.

    Each box is spanned by the minimum and maximum over axis 0 of columns 0
    (x) and 1 (y); any further columns are ignored. Widths and heights are
    computed with a ``+ 1`` offset (presumably an inclusive-pixel convention).

    :param p1: first point set, indexable as rows of ``(x, y, ...)``.
    :param p2: second point set, same layout as ``p1``.
    :return: intersection area divided by union area.
    :raises AssertionError: if either box has no strictly positive extent
        along x or y.
    """
    lower1 = np.min(p1, axis=0)
    lower2 = np.min(p2, axis=0)
    upper1 = np.max(p1, axis=0)
    upper2 = np.max(p2, axis=0)

    left1, top1, right1, bottom1 = lower1[0], lower1[1], upper1[0], upper1[1]
    left2, top2, right2, bottom2 = lower2[0], lower2[1], upper2[0], upper2[1]

    assert left1 < right1
    assert top1 < bottom1
    assert left2 < right2
    assert top2 < bottom2

    # Corners of the overlapping rectangle (may be inverted if disjoint).
    inter_left = max(left1, left2)
    inter_top = max(top1, top2)
    inter_right = min(right1, right2)
    inter_bottom = min(bottom1, bottom2)

    # Clamp at zero so disjoint boxes contribute no intersection.
    inter_width = max(0, inter_right - inter_left + 1)
    inter_height = max(0, inter_bottom - inter_top + 1)
    intersection_area = inter_width * inter_height

    area1 = (right1 - left1 + 1) * (bottom1 - top1 + 1)
    area2 = (right2 - left2 + 1) * (bottom2 - top2 + 1)

    union_area = float(area1 + area2 - intersection_area)
    return intersection_area / union_area
|
|
|
|
class AverageMeter(object):
    """
    Code modified from https://github.com/pytorch/examples/blob/main/imagenet/main.py#L423
    Computes and stores the average and current value
    """

    def __init__(self, name, fmt=":f"):
        """
        :param name: label printed in front of the values by ``__str__``.
        :param fmt: format spec (including the leading colon) applied to both
            the current value and the average when the meter is printed.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the current value, the running sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record a new value.

        :param val: latest value; tensors are detached before being stored.
        :param n: weight of ``val`` (it is added to the sum ``n`` times and
            the count grows by ``n``).
        """
        # isinstance (not an exact type comparison) so that Tensor subclasses
        # such as nn.Parameter are detached too; otherwise the running sum
        # would keep an autograd graph alive.
        if isinstance(val, torch.Tensor):
            val = val.detach()
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        """Return ``"<name> <val> (<avg>)"`` with both numbers formatted by ``fmt``."""
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
|
|