"""Detection model evaluator.

This file provides a generic evaluation method that can be used to evaluate a
DetectionModel.
"""

import logging

import tensorflow as tf

from object_detection import eval_util
from object_detection.core import prefetcher
from object_detection.core import standard_fields as fields
from object_detection.metrics import coco_evaluation
from object_detection.utils import object_detection_evaluation

EVAL_METRICS_CLASS_DICT = {
    'pascal_voc_detection_metrics':
        object_detection_evaluation.PascalDetectionEvaluator,
    'weighted_pascal_voc_detection_metrics':
        object_detection_evaluation.WeightedPascalDetectionEvaluator,
    'pascal_voc_instance_segmentation_metrics':
        object_detection_evaluation.PascalInstanceSegmentationEvaluator,
    'weighted_pascal_voc_instance_segmentation_metrics':
        object_detection_evaluation.WeightedPascalInstanceSegmentationEvaluator,
    'oid_V2_detection_metrics':
        object_detection_evaluation.OpenImagesDetectionEvaluator,
    'open_images_V2_detection_metrics':
        object_detection_evaluation.OpenImagesDetectionEvaluator,
    'coco_detection_metrics':
        coco_evaluation.CocoDetectionEvaluator,
    'coco_mask_metrics':
        coco_evaluation.CocoMaskEvaluator,
    'oid_challenge_detection_metrics':
        object_detection_evaluation.OpenImagesDetectionChallengeEvaluator,
    'oid_challenge_object_detection_metrics':
        object_detection_evaluation.OpenImagesDetectionChallengeEvaluator,
}

EVAL_DEFAULT_METRIC = 'pascal_voc_detection_metrics'
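
# EVAL_METRICS_CLASS_DICT above maps the `metrics_set` names accepted in the
# eval_pb2.EvalConfig proto to DetectionEvaluator classes; get_evaluators()
# below instantiates them, falling back to EVAL_DEFAULT_METRIC when
# `metrics_set` is left empty.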
|
def _extract_predictions_and_losses(model,
                                    create_input_dict_fn,
                                    ignore_groundtruth=False):
  """Constructs the TensorFlow detection graph and returns output tensors.

  Args:
    model: model to perform predictions with.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored.

  Returns:
    prediction_groundtruth_dict: A dictionary with postprocessed tensors (keyed
      by standard_fields.DetectionResultFields) and optional groundtruth
      tensors (keyed by standard_fields.InputDataFields).
    losses_dict: A dictionary containing detection losses. This is empty when
      ignore_groundtruth is true.
  """
  input_dict = create_input_dict_fn()
  # Prefetch input tensors so that input reading does not block evaluation of
  # the detection graph.
  prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
  input_dict = prefetch_queue.dequeue()
  original_image = tf.expand_dims(input_dict[fields.InputDataFields.image], 0)
  preprocessed_image, true_image_shapes = model.preprocess(
      tf.to_float(original_image))
  prediction_dict = model.predict(preprocessed_image, true_image_shapes)
  detections = model.postprocess(prediction_dict, true_image_shapes)

  groundtruth = None
  losses_dict = {}
  if not ignore_groundtruth:
    groundtruth = {
        fields.InputDataFields.groundtruth_boxes:
            input_dict[fields.InputDataFields.groundtruth_boxes],
        fields.InputDataFields.groundtruth_classes:
            input_dict[fields.InputDataFields.groundtruth_classes],
        fields.InputDataFields.groundtruth_area:
            input_dict[fields.InputDataFields.groundtruth_area],
        fields.InputDataFields.groundtruth_is_crowd:
            input_dict[fields.InputDataFields.groundtruth_is_crowd],
        fields.InputDataFields.groundtruth_difficult:
            input_dict[fields.InputDataFields.groundtruth_difficult]
    }
    if fields.InputDataFields.groundtruth_group_of in input_dict:
      groundtruth[fields.InputDataFields.groundtruth_group_of] = (
          input_dict[fields.InputDataFields.groundtruth_group_of])
    groundtruth_masks_list = None
    if fields.DetectionResultFields.detection_masks in detections:
      groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
          input_dict[fields.InputDataFields.groundtruth_instance_masks])
      groundtruth_masks_list = [
          input_dict[fields.InputDataFields.groundtruth_instance_masks]]
    groundtruth_keypoints_list = None
    if fields.DetectionResultFields.detection_keypoints in detections:
      groundtruth[fields.InputDataFields.groundtruth_keypoints] = (
          input_dict[fields.InputDataFields.groundtruth_keypoints])
      groundtruth_keypoints_list = [
          input_dict[fields.InputDataFields.groundtruth_keypoints]]
    # Groundtruth class labels are 1-indexed; shift them down by the label id
    # offset so the one-hot encoding lines up with the model's class logits.
    label_id_offset = 1
    model.provide_groundtruth(
        [input_dict[fields.InputDataFields.groundtruth_boxes]],
        [tf.one_hot(input_dict[fields.InputDataFields.groundtruth_classes]
                    - label_id_offset, depth=model.num_classes)],
        groundtruth_masks_list, groundtruth_keypoints_list)
    losses_dict.update(model.loss(prediction_dict, true_image_shapes))

  result_dict = eval_util.result_dict_for_single_example(
      original_image,
      input_dict[fields.InputDataFields.source_id],
      detections,
      groundtruth,
      class_agnostic=(
          fields.DetectionResultFields.detection_classes not in detections),
      scale_to_absolute=True)
  return result_dict, losses_dict
|
|
def get_evaluators(eval_config, categories):
  """Returns the evaluators according to eval_config, valid for categories.

  Args:
    eval_config: evaluation configurations.
    categories: a list of categories to evaluate.

  Returns:
    A list of instances of DetectionEvaluator.

  Raises:
    ValueError: if metric is not in the metric class dictionary.
  """
  eval_metric_fn_keys = eval_config.metrics_set
  if not eval_metric_fn_keys:
    eval_metric_fn_keys = [EVAL_DEFAULT_METRIC]
  evaluators_list = []
  for eval_metric_fn_key in eval_metric_fn_keys:
    if eval_metric_fn_key not in EVAL_METRICS_CLASS_DICT:
      raise ValueError('Metric not found: {}'.format(eval_metric_fn_key))
    if eval_metric_fn_key == 'oid_challenge_object_detection_metrics':
      logging.warning(
          'oid_challenge_object_detection_metrics is deprecated; '
          'use oid_challenge_detection_metrics instead'
      )
    if eval_metric_fn_key == 'open_images_V2_detection_metrics':
      logging.warning(
          'open_images_V2_detection_metrics is deprecated; '
          'use oid_V2_detection_metrics instead'
      )
    evaluators_list.append(
        EVAL_METRICS_CLASS_DICT[eval_metric_fn_key](categories=categories))
  return evaluators_list
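
# Example (illustrative): with eval_config.metrics_set set to
# ['coco_detection_metrics'], get_evaluators returns
# [coco_evaluation.CocoDetectionEvaluator(categories=categories)].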
|
|
def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
             checkpoint_dir, eval_dir, graph_hook_fn=None, evaluator_list=None):
  """Evaluation function for detection models.

  Args:
    create_input_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel.
    eval_config: an eval_pb2.EvalConfig protobuf.
    categories: a list of category dictionaries. Each dict in the list should
      have an integer 'id' field and string 'name' field.
    checkpoint_dir: directory to load the checkpoints to evaluate from.
    eval_dir: directory to write evaluation metrics summary to.
    graph_hook_fn: Optional function that is called after the evaluation graph
      is completely built. This is helpful to perform additional changes to the
      graph such as optimizing batchnorm. The function should modify the
      default graph.
    evaluator_list: Optional list of instances of DetectionEvaluator. If not
      given, this list of metrics is created according to the eval_config.

  Returns:
    metrics: A dictionary containing metric names and values from the latest
      run.
  """

  model = create_model_fn()

  if eval_config.ignore_groundtruth and not eval_config.export_path:
    logging.fatal('If ignore_groundtruth=True then an export_path is '
                  'required. Aborting!!!')

  tensor_dict, losses_dict = _extract_predictions_and_losses(
      model=model,
      create_input_dict_fn=create_input_dict_fn,
      ignore_groundtruth=eval_config.ignore_groundtruth)
|
  def _process_batch(tensor_dict, sess, batch_index, counters,
                     losses_dict=None):
    """Evaluates tensors in tensor_dict, losses_dict and visualizes examples.

    This function calls sess.run on tensor_dict, evaluating the original_image
    tensor only on the first K examples and visualizing detections overlaid
    on this original_image.

    Args:
      tensor_dict: a dictionary of tensors.
      sess: a TensorFlow session.
      batch_index: the index of the batch amongst all batches in the run.
      counters: a dictionary holding 'success' and 'skipped' fields which can
        be updated to keep track of number of successful and failed runs,
        respectively. If these fields are not updated, then the success/skipped
        counter values shown at the end of evaluation will be incorrect.
      losses_dict: Optional dictionary of scalar loss tensors.

    Returns:
      result_dict: a dictionary of numpy arrays.
      result_losses_dict: a dictionary of scalar losses. This is empty if input
        losses_dict is None.
    """
    try:
      if not losses_dict:
        losses_dict = {}
      result_dict, result_losses_dict = sess.run([tensor_dict, losses_dict])
      counters['success'] += 1
    except tf.errors.InvalidArgumentError:
      logging.info('Skipping image')
      counters['skipped'] += 1
      return {}, {}
    global_step = tf.train.global_step(sess, tf.train.get_global_step())
    if batch_index < eval_config.num_visualizations:
      tag = 'image-{}'.format(batch_index)
      eval_util.visualize_detection_results(
          result_dict,
          tag,
          global_step,
          categories=categories,
          summary_dir=eval_dir,
          export_dir=eval_config.visualization_export_dir,
          show_groundtruth=eval_config.visualize_groundtruth_boxes,
          groundtruth_box_visualization_color=eval_config.
          groundtruth_box_visualization_color,
          min_score_thresh=eval_config.min_score_threshold,
          max_num_predictions=eval_config.max_num_boxes_to_visualize,
          skip_scores=eval_config.skip_scores,
          skip_labels=eval_config.skip_labels,
          keep_image_id_for_visualization_export=eval_config.
          keep_image_id_for_visualization_export)
    return result_dict, result_losses_dict
|
  if graph_hook_fn:
    graph_hook_fn()

  variables_to_restore = tf.global_variables()
  global_step = tf.train.get_or_create_global_step()
  variables_to_restore.append(global_step)

  if eval_config.use_moving_averages:
    # Restore the exponential moving averages of the variables rather than the
    # raw variable values when the config requests it.
    variable_averages = tf.train.ExponentialMovingAverage(0.0)
    variables_to_restore = variable_averages.variables_to_restore()
  saver = tf.train.Saver(variables_to_restore)

  def _restore_latest_checkpoint(sess):
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    saver.restore(sess, latest_checkpoint)

  if not evaluator_list:
    evaluator_list = get_evaluators(eval_config, categories)
|
  metrics = eval_util.repeated_checkpoint_run(
      tensor_dict=tensor_dict,
      summary_dir=eval_dir,
      evaluators=evaluator_list,
      batch_processor=_process_batch,
      checkpoint_dirs=[checkpoint_dir],
      variables_to_restore=None,
      restore_fn=_restore_latest_checkpoint,
      num_batches=eval_config.num_examples,
      eval_interval_secs=eval_config.eval_interval_secs,
      max_number_of_evaluations=(1 if eval_config.ignore_groundtruth else
                                 eval_config.max_evals
                                 if eval_config.max_evals else None),
      master=eval_config.eval_master,
      save_graph=eval_config.save_graph,
      save_graph_dir=(eval_dir if eval_config.save_graph else ''),
      losses_dict=losses_dict,
      eval_export_path=eval_config.export_path)

  return metrics
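
# Example usage (illustrative sketch only; `my_input_dict_fn`, `my_model_fn`,
# `my_categories` and the paths below are hypothetical placeholders, and
# `eval_config` would come from a parsed eval_pb2.EvalConfig proto):
#
#   metrics = evaluate(
#       create_input_dict_fn=my_input_dict_fn,
#       create_model_fn=my_model_fn,
#       eval_config=eval_config,
#       categories=my_categories,
#       checkpoint_dir='/path/to/train_dir',
#       eval_dir='/path/to/eval_dir')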
|
|