# shark_detection/metrics.py (9.98 kB)
# Author: piperod91 — commit 690e199 ("fixing border and adding logo")
import numpy as np
import matplotlib.pyplot as plt
def get_top_predictions(prediction=None, threshold=0.7):
    """Select the detections whose confidence score exceeds ``threshold``.

    Detections are ordered by descending score before filtering, so every
    returned array is sorted from most to least confident.

    Parameters
    ----------
    prediction : object or None
        Detector output exposing ``pred_instances`` with ``scores``,
        ``labels`` and ``bboxes`` fields. The fields are indexed with
        ``argsort()[::-1]`` and boolean masks, so they are presumably NumPy
        arrays (TODO confirm: negative-step slicing does not work on torch
        tensors).
    threshold : float
        Strict lower bound: a detection is kept when ``score > threshold``.

    Returns
    -------
    dict
        ``pred_above_thresh``: class ids of kept detections;
        ``pred_above_thresh_id``: their indices in the original prediction;
        ``pred_above_thresh_scores``: their scores;
        ``pred_above_thresh_bboxes``: their boxes.
        When ``prediction`` is None the same keys map to empty arrays.
    """
    if prediction is None:
        # Return an empty result (rather than a tuple) so the dict-based
        # helpers in this module can consume it like any other prediction.
        return {
            'pred_above_thresh': np.array([], dtype=int),
            'pred_above_thresh_id': np.array([], dtype=int),
            'pred_above_thresh_scores': np.array([], dtype=float),
            'pred_above_thresh_bboxes': np.zeros((0, 4)),
        }

    instances = prediction.pred_instances
    # argsort is ascending; reverse it for highest-score-first order.
    order = instances.scores.argsort()[::-1]
    sorted_scores = instances.scores[order]
    keep = sorted_scores > threshold
    return {
        'pred_above_thresh': instances.labels[order][keep],
        'pred_above_thresh_id': order[keep],
        'pred_above_thresh_scores': sorted_scores[keep],
        'pred_above_thresh_bboxes': instances['bboxes'][order][keep],
    }
def add_class_labels(top_pred=None, class_labels=None):
    """Attach class names and shark/human/unknown summaries to ``top_pred``.

    ``top_pred`` is updated in place and also returned.

    Parameters
    ----------
    top_pred : dict or None
        Result of ``get_top_predictions``; must contain
        ``pred_above_thresh`` (class ids) when ``class_labels`` is given.
        None means a fresh empty dict.
    class_labels : sequence of str or None
        Indexable by class id. When None, ``top_pred`` is returned unchanged.

    Returns
    -------
    dict
        ``top_pred`` with these keys added:
        - ``pred_above_thresh_labels``: lower-cased names.
        - ``any_detection``: bool.
        - ``is_shark`` / ``is_human`` / ``is_unknown``: 0/1 arrays, or None
          when nothing was detected.
        - ``shark_n`` / ``human_n`` / ``unknown_n``: counts.
    """
    if top_pred is None:
        # Fresh dict per call: a shared ``{}`` default would be mutated
        # below and leak state between calls.
        top_pred = {}
    # ``is None`` (not ``== None``) so array-like class_labels don't trigger
    # an element-wise comparison with an ambiguous truth value.
    if class_labels is None:
        print('No class labels provided, returning original dictionary')
        return top_pred

    labels = [class_labels[x].lower() for x in top_pred['pred_above_thresh']]
    top_pred['pred_above_thresh_labels'] = labels
    top_pred['any_detection'] = len(labels) > 0

    if top_pred['any_detection']:
        # Per-detection 0/1 category flags, matched by substring of the name.
        top_pred['is_shark'] = np.array([1 if 'shark' in x else 0 for x in labels])
        top_pred['is_human'] = np.array(
            [1 if ('person' in x or 'surfer' in x) else 0 for x in labels])
        top_pred['is_unknown'] = np.array(
            [1 if 'unidentifiable' in x else 0 for x in labels])
        # Per-category detection counts.
        top_pred['shark_n'] = np.sum(top_pred['is_shark'])
        top_pred['human_n'] = np.sum(top_pred['is_human'])
        top_pred['unknown_n'] = np.sum(top_pred['is_unknown'])
    else:
        top_pred['is_shark'] = None
        top_pred['is_human'] = None
        top_pred['is_unknown'] = None
        top_pred['shark_n'] = 0
        top_pred['human_n'] = 0
        top_pred['unknown_n'] = 0
    return top_pred
def add_class_sizes(top_pred=None, class_sizes=None):
    """Attach per-detection size ranges and the biggest shark size (feet).

    ``top_pred`` is updated in place and also returned.

    Parameters
    ----------
    top_pred : dict or None
        Output of ``add_class_labels``. Reads ``any_detection``,
        ``pred_above_thresh_labels`` and ``shark_n``.
    class_sizes : dict
        Maps a lower-cased class name to None (size unknown) or to a dict
        with a ``'feet'`` entry, presumably a ``[min, max]`` range — TODO
        confirm against the config that builds it.

    Returns
    -------
    dict
        ``top_pred`` with these keys added:
        - ``pred_above_thresh_sizes``: list of ``'feet'`` entries (None for
          unknown classes), or None when nothing was detected.
        - ``biggest_shark_size``: largest mean shark size, or None.
    """
    if top_pred is None:
        top_pred = {}
    if not top_pred['any_detection']:
        top_pred['pred_above_thresh_sizes'] = None
        top_pred['biggest_shark_size'] = None
        return top_pred

    size_list = []
    shark_size_list = []
    for label in top_pred['pred_above_thresh_labels']:
        label = label.lower()
        sizes = class_sizes[label]
        if sizes is None:
            size_list.append(None)
            continue
        size_list.append(sizes['feet'])
        if 'shark' in label:
            shark_size_list.append(np.mean(sizes['feet']))
    top_pred['pred_above_thresh_sizes'] = size_list

    # A shark whose class has no size entry is counted in shark_n but adds
    # nothing to shark_size_list; np.max([]) would raise in that case.
    if top_pred['shark_n'] > 0 and shark_size_list:
        top_pred['biggest_shark_size'] = np.max(shark_size_list)
    else:
        top_pred['biggest_shark_size'] = None
    return top_pred
def add_class_weights(top_pred=None, class_weights=None):
    """Attach per-detection weight ranges and the biggest shark weight (pounds).

    ``top_pred`` is updated in place and also returned.

    Parameters
    ----------
    top_pred : dict or None
        Output of ``add_class_labels``. Reads ``any_detection``,
        ``pred_above_thresh_labels`` and ``shark_n``.
    class_weights : dict
        Maps a lower-cased class name to None (weight unknown) or to a dict
        with a ``'pounds'`` entry, presumably a ``[min, max]`` range — TODO
        confirm against the config that builds it.

    Returns
    -------
    dict
        ``top_pred`` with these keys added:
        - ``pred_above_thresh_weights``: list of ``'pounds'`` entries (None
          for unknown classes), or None when nothing was detected.
        - ``biggest_shark_weight``: largest mean shark weight, or None.
    """
    if top_pred is None:
        top_pred = {}
    if not top_pred['any_detection']:
        top_pred['pred_above_thresh_weights'] = None
        top_pred['biggest_shark_weight'] = None
        return top_pred

    weight_list = []
    shark_weight_list = []
    for label in top_pred['pred_above_thresh_labels']:
        label = label.lower()
        weights = class_weights[label]
        if weights is None:
            weight_list.append(None)
            continue
        weight_list.append(weights['pounds'])
        if 'shark' in label:
            shark_weight_list.append(np.mean(weights['pounds']))
    top_pred['pred_above_thresh_weights'] = weight_list

    # A shark whose class has no weight entry is counted in shark_n but adds
    # nothing to shark_weight_list; np.max([]) would raise in that case.
    if top_pred['shark_n'] > 0 and shark_weight_list:
        top_pred['biggest_shark_weight'] = np.max(shark_weight_list)
    else:
        top_pred['biggest_shark_weight'] = None
    return top_pred
# Shark-to-person distance estimation (scaled using per-class real-world sizes)
def get_min_distance_shark_person(top_pred, class_sizes=None, dangerous_distance=100):
    """Estimate the smallest distance (feet) between any shark and any human.

    Every (shark, human) detection pair is passed to
    ``_calculate_dist_estimate`` and the minimum is kept.

    Parameters
    ----------
    top_pred : dict
        Output of ``add_class_labels``. Reads ``is_shark`` and ``is_human``
        (0/1 vectors, or None when nothing was detected). For each pair it
        also reads ``pred_above_thresh_bboxes`` and
        ``pred_above_thresh_labels``.
    class_sizes : dict
        Per-class size table forwarded to ``_calculate_dist_estimate``.
    dangerous_distance : float
        Distances strictly below this are flagged as dangerous.

    Returns
    -------
    dict
        ``min_dist``: e.g. ``'12.3 feet'``, or ``''`` if no pair was found.
        ``any_dist_calculated``: whether at least one pair was evaluated.
        ``dangerous_dist``: ``min_dist < dangerous_distance``.
    """
    min_dist = 99999  # sentinel larger than any plausible distance
    dist_calculated = False

    is_shark = top_pred['is_shark']
    is_human = top_pred['is_human']
    # add_class_labels stores None for these when nothing was detected.
    if is_shark is not None and is_human is not None:
        shark_ids = np.flatnonzero(np.asarray(is_shark) == 1).tolist()
        human_ids = np.flatnonzero(np.asarray(is_human) == 1).tolist()
        # Only shark x human pairs are visited, in ascending index order.
        for i in shark_ids:
            for j in human_ids:
                dist_calculated = True
                dist_feet = _calculate_dist_estimate(
                    top_pred['pred_above_thresh_bboxes'][i],
                    top_pred['pred_above_thresh_bboxes'][j],
                    [top_pred['pred_above_thresh_labels'][i],
                     top_pred['pred_above_thresh_labels'][j]],
                    class_sizes,
                    measurement='feet')
                min_dist = min(min_dist, dist_feet)

    return {'min_dist': str(round(min_dist, 1)) + ' feet' if dist_calculated else '',
            'any_dist_calculated': dist_calculated,
            'dangerous_dist': min_dist < dangerous_distance}
def _calculate_dist_estimate(bbox1, bbox2, labels, class_sizes = None, measurement = 'feet'):
if class_sizes[labels[0]] == None or class_sizes[labels[1]] == None:
return 9999
class_feet_size_mean = np.array([class_sizes[labels[0]][measurement][0],
class_sizes[labels[1]][measurement][0]]).mean()
box_pixel_size_mean = np.array([np.linalg.norm(bbox1[[0, 1]] - bbox1[[2, 3]]),
np.linalg.norm(bbox2[[0, 1]] - bbox2[[2, 3]])]).mean()
# Calculate the max size of the two boxes
box_center_1 = np.array([(bbox1[2] - bbox1[0])/2 + bbox1[0],
(bbox1[3] - bbox1[1])/2 + bbox1[1]])
box_center_2 = np.array([(bbox2[2] - bbox2[0])/2 + bbox2[0],
(bbox2[3] - bbox2[1])/2 + bbox2[1]])
# Return ratio distance
return np.linalg.norm(box_center_1 - box_center_2) / box_pixel_size_mean * class_feet_size_mean
# Bounding-box layout (0-based index into each bbox array):
# 0: x1 (left edge, lower pixel number)
# 1: y1 (top edge, lower pixel number)
# 2: x2 (right edge, higher pixel number)
# 3: y2 (bottom edge, higher pixel number)
def process_results_for_plot(predictions=None, threshold=0.5, classes=None,
                             class_sizes=None, dangerous_distance=100):
    """Run the full post-processing pipeline and summarise it for display.

    Parameters
    ----------
    predictions : object or None
        Raw detector output passed to ``get_top_predictions``.
    threshold : float
        Score threshold for keeping detections.
    classes : sequence of str
        Class names indexed by class id.
    class_sizes : dict
        Per-class table. It is read for ``'feet'`` (sizes and distances) and
        ``'pounds'`` (weights).
    dangerous_distance : float
        Shark-human distances (feet) strictly below this are flagged as
        dangerous.

    Returns
    -------
    dict
        Display-ready summary:
        - ``min_dist_str``;
        - ``shark_sighted``, ``human_sighted``;
        - ``shark_n``, ``human_n``;
        - ``human_and_shark``;
        - ``dangerous_dist``, ``dist_calculated``;
        - ``biggest_shark_size``, ``biggest_shark_weight``: strings, ``''``
          when unknown.
    """
    top_pred = get_top_predictions(predictions, threshold=threshold)
    top_pred = add_class_labels(top_pred, class_labels=classes)
    top_pred = add_class_sizes(top_pred, class_sizes=class_sizes)
    # The same per-class table provides both 'feet' and 'pounds' entries.
    top_pred = add_class_weights(top_pred, class_weights=class_sizes)

    if len(top_pred['pred_above_thresh']) > 0:
        # Forward dangerous_distance; previously it was accepted but ignored.
        min_dist = get_min_distance_shark_person(
            top_pred, class_sizes=class_sizes,
            dangerous_distance=dangerous_distance)
    else:
        min_dist = {'any_dist_calculated': False,
                    'min_dist': '',
                    'dangerous_dist': False}

    biggest_size = top_pred['biggest_shark_size']
    biggest_weight = top_pred['biggest_shark_weight']
    return {'min_dist_str': min_dist['min_dist'],
            'shark_sighted': top_pred['shark_n'] > 0,
            'human_sighted': top_pred['human_n'] > 0,
            'shark_n': top_pred['shark_n'],
            'human_n': top_pred['human_n'],
            'human_and_shark': (top_pred['shark_n'] > 0) and (top_pred['human_n'] > 0),
            'dangerous_dist': min_dist['dangerous_dist'],
            'dist_calculated': min_dist['any_dist_calculated'],
            'biggest_shark_size': '' if biggest_size is None else str(round(biggest_size, 1)) + ' feet',
            'biggest_shark_weight': '' if biggest_weight is None else str(round(biggest_weight, 1)) + ' pounds',
            }
def prediction_dashboard(top_pred=None):
    """Render a text summary of processed detections and return it as an image.

    Parameters
    ----------
    top_pred : dict
        Summary dict in the shape produced by ``process_results_for_plot``.
        The keys read here are ``shark_sighted``, ``human_n``,
        ``biggest_shark_size``, ``biggest_shark_weight`` and
        ``dangerous_dist``.

    Returns
    -------
    numpy.ndarray
        ``uint8`` RGB image of shape ``(height, width, 3)``.
    """
    is_dangerous = top_pred['dangerous_dist']
    danger_color = 'orangered' if is_dangerous else 'yellowgreen'
    text_color = (0, 204/255, 153/255)
    lines = [
        'Shark Detected: ' + str(top_pred['shark_sighted']),
        'Number of Humans: ' + str(top_pred['human_n']),
        'Biggest shark size: ' + str(top_pred['biggest_shark_size']),
        'Biggest shark weight: ' + str(top_pred['biggest_shark_weight']),
        'Danger Level: ' + ('High' if is_dangerous else 'Low'),
    ]

    fig, ax = plt.subplots()
    fig.set_facecolor((35/255, 40/255, 54/255))
    ax.axis('off')

    # Axes-fraction y of the first line; each following line is 0.1 lower.
    y_pos = 0.7
    for line in lines:
        color = danger_color if 'danger' in line.lower() else text_color
        ax.text(0.05, y_pos, line, transform=ax.transAxes, fontsize=16, color=color)
        y_pos -= 0.1

    # Render so the canvas pixel buffer is populated.
    fig.canvas.draw()
    # Why buffer_rgba() rather than tostring_rgb():
    # - tostring_rgb() was deprecated in Matplotlib 3.8 and later removed.
    # - buffer_rgba() already has the physical (height, width, 4) shape, so
    #   no reshape via get_width_height() (logical size) is needed.
    # Drop alpha and copy so the result does not alias the canvas memory.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return data