# DR-App / object_detection /core /anchor_generator.py
# pat229988's picture
# Upload 653 files
# 9a393e2
# raw
# history blame contribute delete
# No virus
# 5.71 kB
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base anchor generator.
The job of the anchor generator is to create (or load) a collection
of bounding boxes to be used as anchors.
Generated anchors are assumed to match some convolutional grid or list of grid
shapes. For example, we might want to generate anchors matching an 8x8
feature map and a 4x4 feature map. If we place 3 anchors per grid location
on the first feature map and 6 anchors per grid location on the second feature
map, then 3*8*8 + 6*4*4 = 288 anchors are generated in total.
To support fully convolutional settings, feature map shapes are passed
dynamically at generation time. The number of anchors to place at each location
is static --- implementations of AnchorGenerator must always be able to return
the number of anchors that it uses per location for each feature map.
"""
from abc import ABCMeta
from abc import abstractmethod
import tensorflow as tf
class AnchorGenerator(ABCMeta('_AnchorGeneratorBase', (object,), {})):
    """Abstract base class for anchor generators.

    Subclasses implement `name_scope`, `num_anchors_per_location` and
    `_generate`. Callers use `generate`, which runs `_generate` inside the
    subclass's name scope and, when `check_num_anchors` is true, attaches a
    check on the number of anchors produced.

    The base class is built by calling `ABCMeta` directly so the metaclass
    takes effect on both Python 2 and Python 3. A `__metaclass__ = ABCMeta`
    class attribute is ignored by Python 3, which would leave the
    `@abstractmethod` markers below unenforced.
    """

    @abstractmethod
    def name_scope(self):
        """Name scope.

        Must be defined by implementations.

        Returns:
          a string representing the name scope of the anchor generation
          operation.
        """

    @property
    def check_num_anchors(self):
        """Whether to dynamically check the number of anchors generated.

        Can be overridden by implementations that would like to disable this
        behavior.

        Returns:
          a boolean controlling whether the `generate` function should
          dynamically check the number of anchors generated against the
          mathematically expected number of anchors. This base
          implementation always returns True.
        """
        return True

    @abstractmethod
    def num_anchors_per_location(self):
        """Returns the number of anchors per spatial location.

        Returns:
          a list of integers, one for each expected feature map to be passed
          to the `generate` function.
        """

    def generate(self, feature_map_shape_list, **params):
        """Generates a collection of bounding boxes to be used as anchors.

        TODO(rathodv): remove **params from argument list and make stride and
          offsets (for multiple_grid_anchor_generator) constructor arguments.

        Args:
          feature_map_shape_list: list of (height, width) pairs in the format
            [(height_0, width_0), (height_1, width_1), ...] that the
            generated anchors must align with. Pairs can be provided as
            1-dimensional integer tensors of length 2 or simply as tuples of
            integers.
          **params: parameters for anchor generation op; forwarded unchanged
            to `_generate`.

        Returns:
          boxes_list: a list of BoxLists each holding anchor boxes
            corresponding to the input feature map shapes.

        Raises:
          ValueError: if `check_num_anchors` is true and the number of
            feature map shapes does not match the length of
            `num_anchors_per_location()`.
        """
        # The static length check is skipped entirely when the subclass has
        # turned off anchor-count checking.
        if self.check_num_anchors and (
                len(feature_map_shape_list)
                != len(self.num_anchors_per_location())):
            raise ValueError(
                'Number of feature maps is expected to equal the length '
                'of `num_anchors_per_location`.')
        with tf.name_scope(self.name_scope()):
            anchors_list = self._generate(feature_map_shape_list, **params)
            if self.check_num_anchors:
                count_check = self._assert_correct_number_of_anchors(
                    anchors_list, feature_map_shape_list)
                # Pass every anchor tensor through an identity op created
                # under the assertion's control dependency and store it back,
                # so the returned anchors depend on the count check.
                with tf.control_dependencies([count_check]):
                    for item in anchors_list:
                        item.set(tf.identity(item.get()))
            return anchors_list

    @abstractmethod
    def _generate(self, feature_map_shape_list, **params):
        """To be overridden by implementations.

        Args:
          feature_map_shape_list: list of (height, width) pairs in the format
            [(height_0, width_0), (height_1, width_1), ...] that the
            generated anchors must align with.
          **params: parameters for anchor generation op

        Returns:
          boxes_list: a list of BoxList, each holding a collection of N
            anchor boxes.
        """

    def _assert_correct_number_of_anchors(self, anchors_list,
                                          feature_map_shape_list):
        """Assert that correct number of anchors was generated.

        The expected count is the sum, over feature maps, of
        anchors-per-location * height * width; the actual count is the sum of
        `num_boxes()` over `anchors_list`.

        Args:
          anchors_list: A list of box_list.BoxList object holding anchors
            generated.
          feature_map_shape_list: list of (height, width) pairs in the format
            [(height_0, width_0), (height_1, width_1), ...] that the
            generated anchors must align with.

        Returns:
          Op that raises InvalidArgumentError if the number of anchors does
          not match the number of expected anchors.
        """
        expected_num_anchors = 0
        actual_num_anchors = 0
        # NOTE: zip stops at the shortest of its three inputs, so any
        # surplus entries in a longer list are not counted.
        for num_anchors_per_location, feature_map_shape, anchors in zip(
                self.num_anchors_per_location(), feature_map_shape_list,
                anchors_list):
            expected_num_anchors += (num_anchors_per_location
                                     * feature_map_shape[0]
                                     * feature_map_shape[1])
            actual_num_anchors += anchors.num_boxes()
        return tf.assert_equal(expected_num_anchors, actual_num_anchors)