|
from typing import List, Tuple |
|
|
|
import numpy as np |
|
from numpy import ndarray |
|
|
|
|
|
def get_mAP(
    preds: ndarray,
    gt_file: str,
    taglist: List[str]
) -> Tuple[float, ndarray]:
    """Compute per-tag average precision and its mean (mAP).

    The ground-truth file is read as UTF-8 text with one comma-separated
    record per line. The first field of each record is skipped (presumably
    a sample identifier -- confirm against the data producer); every
    remaining field is a tag name that marks the sample as positive for
    that tag.

    Args:
        preds: 2-D score array with one row per ground-truth line and one
            column per entry of ``taglist``.
        gt_file: Path to the ground-truth file.
        taglist: Tag name for each column of ``preds``. A name that occurs
            more than once marks all of its columns as positive.

    Returns:
        A tuple ``(mAP, APs)`` where ``APs`` holds the average precision of
        each column and ``mAP`` is their mean.

    Raises:
        AssertionError: If the column count of ``preds`` differs from
            ``len(taglist)`` or the row count differs from the number of
            lines in ``gt_file``.
        KeyError: If the ground-truth file contains a tag that is not in
            ``taglist``.
    """
    assert preds.shape[1] == len(taglist), (
        f"preds has {preds.shape[1]} columns but taglist has "
        f"{len(taglist)} entries"
    )

    # Map every tag name to all of its column indices; duplicated names
    # therefore address several columns at once.
    tag2idxs = {}
    for idx, tag in enumerate(taglist):
        tag2idxs.setdefault(tag, []).append(idx)

    # Binary target matrix with the same shape and dtype as the scores.
    targets = np.zeros_like(preds)
    # Explicit UTF-8 and full whitespace stripping match how get_PR parses
    # the same file format, so results do not depend on the platform locale.
    with open(gt_file, "r", encoding="utf-8") as f:
        lines = [line.strip().split(",") for line in f]
    assert len(lines) == targets.shape[0], (
        f"{gt_file} has {len(lines)} lines but preds has "
        f"{targets.shape[0]} rows"
    )
    for i, line in enumerate(lines):
        for tag in line[1:]:
            targets[i, tag2idxs[tag]] = 1.0

    # One average precision per column.
    APs = np.zeros(preds.shape[1])
    for k in range(preds.shape[1]):
        APs[k] = _average_precision(preds[:, k], targets[:, k])

    return APs.mean(), APs
|
|
|
|
|
def _average_precision(output: ndarray, target: ndarray) -> float: |
|
epsilon = 1e-8 |
|
|
|
|
|
indices = output.argsort()[::-1] |
|
|
|
total_count_ = np.cumsum(np.ones((len(output), 1))) |
|
|
|
target_ = target[indices] |
|
ind = target_ == 1 |
|
pos_count_ = np.cumsum(ind) |
|
total = pos_count_[-1] |
|
pos_count_[np.logical_not(ind)] = 0 |
|
pp = pos_count_ / total_count_ |
|
precision_at_i_ = np.sum(pp) |
|
precision_at_i = precision_at_i_ / (total + epsilon) |
|
|
|
return precision_at_i |
|
|
|
|
|
def get_PR(
    pred_file: str,
    gt_file: str,
    taglist: List[str]
) -> Tuple[float, float, ndarray, ndarray]:
    """Compute per-tag precision and recall from two tag files.

    Both files are read as UTF-8 text with one comma-separated record per
    line. The first field of each record is skipped (presumably a sample
    identifier -- confirm against the data producer); every remaining field
    is a tag name. Lines are matched between the files by position only.

    Args:
        pred_file: Path to the file with predicted tags.
        gt_file: Path to the file with ground-truth tags.
        taglist: Tag name for each output column. A name that occurs more
            than once sets all of its columns.

    Returns:
        A tuple ``(mean_precision, mean_recall, Ps, Rs)`` where ``Ps`` and
        ``Rs`` hold one precision / recall value per entry of ``taglist``.

    Raises:
        AssertionError: If the two files have a different number of lines.
        KeyError: If either file contains a tag that is not in ``taglist``.
    """
    # Map every tag name to all of its column indices.
    tag2idxs = {}
    for idx, tag in enumerate(taglist):
        tag2idxs.setdefault(tag, []).append(idx)

    # The stored indices are positions in ``taglist``, so the matrices need
    # ``len(taglist)`` columns. Sizing them by the number of unique names
    # makes those indices fall out of range when names are duplicated.
    num_tags = len(taglist)
    preds = _read_tag_indicators(pred_file, tag2idxs, num_tags)
    targets = _read_tag_indicators(gt_file, tag2idxs, num_tags)

    assert preds.shape == targets.shape, (
        f"{pred_file} has {preds.shape[0]} lines but {gt_file} has "
        f"{targets.shape[0]}"
    )

    # Per-column confusion counts.
    TPs = (preds & targets).sum(axis=0)
    FPs = (preds & ~targets).sum(axis=0)
    FNs = (~preds & targets).sum(axis=0)

    # eps keeps the ratios defined for columns with a zero denominator.
    eps = 1.e-9
    Ps = TPs / (TPs + FPs + eps)
    Rs = TPs / (TPs + FNs + eps)

    return Ps.mean(), Rs.mean(), Ps, Rs


def _read_tag_indicators(path: str, tag2idxs: dict, num_tags: int) -> ndarray:
    """Read a tag file into a boolean (lines x num_tags) indicator matrix."""
    with open(path, "r", encoding="utf-8") as f:
        rows = [line.strip().split(",") for line in f]

    indicators = np.zeros((len(rows), num_tags), dtype=bool)
    for i, row in enumerate(rows):
        # The first field is not a tag and is skipped.
        for tag in row[1:]:
            indicators[i, tag2idxs[tag]] = True
    return indicators
|
|