Oisin Mac Aodha
added bat code
9ace58a
raw
history blame
No virus
12.5 kB
import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import accuracy_score, balanced_accuracy_score
def compute_error_auc(op_str, gt, pred, prob):
# classification error
pred_int = (pred > prob).astype(np.int)
class_acc = (pred_int == gt).mean() * 100.0
# ROC - area under curve
fpr, tpr, thresholds = roc_curve(gt, pred)
roc_auc = auc(fpr, tpr)
print(op_str + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc))
#return class_acc, roc_auc
def calc_average_precision(recall, precision):
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0]-2, -1,-1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii+1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0]+1
ave_prec = ((mrec[inds] - mrec[inds-1])*mprec[inds]).sum()
return float(ave_prec)
def calc_recall_at_x(recall, precision, x=0.95):
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
inds = np.where(precision[::-1]>x)[0]
if len(inds) > 0:
return float(recall[::-1][inds[0]])
else:
return 0.0
def compute_affinity_1d(pred_box, gt_boxes, threshold):
# first entry is start time
score = np.abs(pred_box[0] - gt_boxes[:, 0])
valid_detection = np.min(score) <= threshold
return valid_detection, np.argmin(score)
def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, threshold, ignore_start_end):
"""
Computes precision and recall. Assumes that each file has been exhaustively
annotated. Will not count predicted detection with a start time that is within
ignore_start_end miliseconds of the start or end of the file.
eval_mode == 'detection'
Returns overall detection results (not per class)
eval_mode == 'per_class'
Filters ground truth based on class of interest. This will ignore predictions
assigned to gt with unknown class.
eval_mode = 'top_class'
Turns the problem into a binary one and selects the top predicted class
for each predicted detection
"""
# get predictions and put in array
pred_boxes = []
confidence = []
pred_class = []
file_ids = []
for pid, pp in enumerate(preds):
# filter predicted calls that are too near the start or end of the file
file_dur = gts[pid]['duration']
valid_inds = (pp['start_times'] >= ignore_start_end) & (pp['start_times'] <= (file_dur - ignore_start_end))
pred_boxes.append(np.vstack((pp['start_times'][valid_inds], pp['end_times'][valid_inds],
pp['low_freqs'][valid_inds], pp['high_freqs'][valid_inds])).T)
if eval_mode == 'detection':
# overall detection
confidence.append(pp['det_probs'][valid_inds])
elif eval_mode == 'per_class':
# per class
confidence.append(pp['class_probs'].T[valid_inds, class_of_interest])
elif eval_mode == 'top_class':
# per class - note that sometimes 'class_probs' can be num_classes+1 in size
top_class = np.argmax(pp['class_probs'].T[valid_inds, :num_classes], 1)
confidence.append(pp['class_probs'].T[valid_inds, top_class])
pred_class.append(top_class)
# be careful, assuming the order in the list is same as GT
file_ids.append([pid]*valid_inds.sum())
confidence = np.hstack(confidence)
file_ids = np.hstack(file_ids).astype(np.int)
pred_boxes = np.vstack(pred_boxes)
if len(pred_class) > 0:
pred_class = np.hstack(pred_class)
# extract relevant ground truth boxes
gt_boxes = []
gt_assigned = []
gt_class = []
gt_generic_class = []
num_positives = 0
for gg in gts:
# filter ground truth calls that are too near the start or end of the file
file_dur = gg['duration']
valid_inds = (gg['start_times'] >= ignore_start_end) & (gg['start_times'] <= (file_dur - ignore_start_end))
# note, files with the incorrect duration will cause a problem
if (gg['start_times'] > file_dur).sum() > 0:
print('Error: file duration incorrect for', gg['id'])
assert(False)
boxes = np.vstack((gg['start_times'][valid_inds], gg['end_times'][valid_inds],
gg['low_freqs'][valid_inds], gg['high_freqs'][valid_inds])).T
gen_class = gg['class_ids'][valid_inds] == -1
class_ids = gg['class_ids'][valid_inds]
# keep track of the number of relevant ground truth calls
if eval_mode == 'detection':
# all valid ones
num_positives += len(gg['start_times'][valid_inds])
elif eval_mode == 'per_class':
# all valid ones with class of interest
num_positives += (gg['class_ids'][valid_inds] == class_of_interest).sum()
elif eval_mode == 'top_class':
# all valid ones with non generic class
num_positives += (gg['class_ids'][valid_inds] > -1).sum()
# find relevant classes (i.e. class_of_interest) and events without known class (i.e. generic class, -1)
if eval_mode == 'per_class':
class_inds = (class_ids == class_of_interest) | (class_ids == -1)
boxes = boxes[class_inds, :]
gen_class = gen_class[class_inds]
class_ids = class_ids[class_inds]
gt_assigned.append(np.zeros(boxes.shape[0]))
gt_boxes.append(boxes)
gt_generic_class.append(gen_class)
gt_class.append(class_ids)
# loop through detections and keep track of those that have been assigned
true_pos = np.zeros(confidence.shape[0])
valid_inds = np.ones(confidence.shape[0]) == 1 # intialize to True
sorted_inds = np.argsort(confidence)[::-1] # sort high to low
for ii, ind in enumerate(sorted_inds):
gt_id = file_ids[ind]
valid_det = False
if gt_boxes[gt_id].shape[0] > 0:
# compute overlap
valid_det, det_ind = compute_affinity_1d(pred_boxes[ind], gt_boxes[gt_id],
threshold)
# valid detection that has not already been assigned
if valid_det and (gt_assigned[gt_id][det_ind] == 0):
count_as_true_pos = True
if eval_mode == 'top_class' and (gt_class[gt_id][det_ind] != pred_class[ind]):
# needs to be the same class
count_as_true_pos = False
if count_as_true_pos:
true_pos[ii] = 1
gt_assigned[gt_id][det_ind] = 1
# if event is generic class (i.e. gt_generic_class[gt_id][det_ind] is True)
# and eval_mode != 'detection', then ignore it
if gt_generic_class[gt_id][det_ind]:
if eval_mode == 'per_class' or eval_mode == 'top_class':
valid_inds[ii] = False
# store threshold values - used for plotting
conf_sorted = np.sort(confidence)[::-1][valid_inds]
thresholds = np.linspace(0.1, 0.9, 9)
thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
for ii, tt in enumerate(thresholds):
thresholds_inds[ii] = np.argmin(conf_sorted > tt)
thresholds_inds[thresholds_inds==0] = -1
# compute precision and recall
true_pos = true_pos[valid_inds]
false_pos_c = np.cumsum(1-true_pos)
true_pos_c = np.cumsum(true_pos)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(true_pos_c + false_pos_c, np.finfo(np.float64).eps)
results = {}
results['recall'] = recall
results['precision'] = precision
results['num_gt'] = num_positives
results['thresholds'] = thresholds
results['thresholds_inds'] = thresholds_inds
if num_positives == 0:
results['avg_prec'] = np.nan
results['rec_at_x'] = np.nan
else:
results['avg_prec'] = np.round(calc_average_precision(recall, precision), 5)
results['rec_at_x'] = np.round(calc_recall_at_x(recall, precision), 5)
return results
def compute_file_accuracy_simple(gts, preds, num_classes):
"""
Evaluates the prediction accuracy at a file level.
Does not include files that have more than one class (or the generic class).
Simply chooses the class per file that has the highest probability overall.
"""
gt_valid = []
pred_valid = []
for ii in range(len(gts)):
gt_class = np.unique(gts[ii]['class_ids'])
if len(gt_class) == 1 and gt_class[0] != -1:
gt_valid.append(gt_class[0])
pred = preds[ii]['class_probs'][:num_classes, :].T
pred_valid.append(np.argmax(pred.mean(0)))
acc = (np.array(gt_valid) == np.array(pred_valid)).mean()
res = {}
res['num_valid_files'] = len(gt_valid)
res['num_total_files'] = len(gts)
res['gt_valid_file'] = gt_valid
res['pred_valid_file'] = pred_valid
res['file_acc'] = np.round(acc, 5)
return res
def compute_file_accuracy(gts, preds, num_classes):
"""
Evaluates the prediction accuracy at a file level.
Does not include files that have more than one class (or the unknown class).
Tries several different detection thresholds and picks the best one.
"""
# compute min and max scoring range - then threshold
min_val = 0
mins = [pp['class_probs'].min() for pp in preds if pp['class_probs'].shape[1] > 0]
if len(mins) > 0:
min_val = np.min(mins)
max_val = 1.0
maxes = [pp['class_probs'].max() for pp in preds if pp['class_probs'].shape[1] > 0]
if len(maxes) > 0:
max_val = np.max(maxes)
thresh = np.linspace(min_val, max_val, 11)[:10]
# loop over the files and store the accuracy at different prediction thresholds
# only include gt files that have one valid species
gt_valid = []
pred_valid_all = []
for ii in range(len(gts)):
gt_class = np.unique(gts[ii]['class_ids'])
if len(gt_class) == 1 and gt_class[0] != -1:
gt_valid.append(gt_class[0])
pred = preds[ii]['class_probs'][:num_classes, :].T
p_class = np.zeros(len(thresh))
for tt in range(len(thresh)):
p_class[tt] = (pred*(pred>=thresh[tt])).sum(0).argmax()
pred_valid_all.append(p_class)
# pick the result corresponding to the overall best threshold
pred_valid_all = np.vstack(pred_valid_all)
acc_per_thresh = (np.array(gt_valid)[..., np.newaxis] == pred_valid_all).mean(0)
best_thresh = np.argmax(acc_per_thresh)
best_acc = acc_per_thresh[best_thresh]
pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
res = {}
res['num_valid_files'] = len(gt_valid)
res['num_total_files'] = len(gts)
res['gt_valid_file'] = gt_valid
res['pred_valid_file'] = pred_valid
res['file_acc'] = np.round(best_acc, 5)
return res
def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_start_end=0.0):
"""
Computes metrics derived from the precision and recall.
Assumes that gts and preds are both lists of the same lengths, with ground
truth and predictions contained within.
Returns the overall detection results, and per class results
"""
assert(len(gts) == len(preds))
num_classes = len(class_names)
# evaluate detection on its own i.e. ignoring class
det_results = compute_pre_rec(gts, preds, 'detection', None, num_classes, detection_overlap, ignore_start_end)
top_class = compute_pre_rec(gts, preds, 'top_class', None, num_classes, detection_overlap, ignore_start_end)
det_results['top_class'] = top_class
# per class evaluation
det_results['class_pr'] = []
for cc in range(num_classes):
res = compute_pre_rec(gts, preds, 'per_class', cc, num_classes, detection_overlap, ignore_start_end)
res['name'] = class_names[cc]
det_results['class_pr'].append(res)
# ignores classes that are not present in the test set
det_results['avg_prec_class'] = np.mean([rs['avg_prec'] for rs in det_results['class_pr'] if rs['num_gt'] > 0])
det_results['avg_prec_class'] = np.round(det_results['avg_prec_class'], 5)
# file level evaluation
res_file = compute_file_accuracy(gts, preds, num_classes)
det_results.update(res_file)
return det_results