File size: 25,937 Bytes
9a393e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 |
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to generate a list of feature maps based on image features.
Provides several feature map generators that can be used to build object
detection feature extractors.
Object detection feature extractors usually are built by stacking two components
- A base feature extractor such as Inception V3 and a feature map generator.
Feature map generators build on the base feature extractors and produce a list
of final feature maps.
"""
import collections
import functools
import tensorflow as tf
from object_detection.utils import ops
slim = tf.contrib.slim
# Clipping bound for activations, introduced for TPU v1. When a generator is
# run with use_bounded_activations enabled, activations are clipped into the
# closed range [-ACTIVATION_BOUND, ACTIVATION_BOUND].
ACTIVATION_BOUND = 6.
def get_depth_fn(depth_multiplier, min_depth):
  """Returns a callable mapping a nominal filter depth to the depth to use.

  The returned callable scales its argument by `depth_multiplier`, truncates
  the product with `int()`, and never returns less than `min_depth`.

  Args:
    depth_multiplier: a multiplier for the nominal depth.
    min_depth: a lower bound on the depth of filters.

  Returns:
    A callable that takes in a nominal depth and returns the depth to use.
  """
  def _scaled_depth(nominal_depth):
    scaled = int(nominal_depth * depth_multiplier)
    # Same result as max(scaled, min_depth): `scaled` is kept on ties.
    return scaled if scaled >= min_depth else min_depth

  return _scaled_depth
class KerasMultiResolutionFeatureMaps(tf.keras.Model):
  """Generates multi resolution feature maps from input image features.

  A Keras model that generates multi-scale feature maps for detection as in the
  SSD papers by Liu et al: https://arxiv.org/pdf/1512.02325v2.pdf, See Sec 2.1.

  More specifically, when called on inputs it performs the following two tasks:
  1) If a layer name is provided in the configuration, returns that layer as a
     feature map.
  2) If a layer name is left as an empty string, constructs a new feature map
     based on the spatial shape and depth configuration. Note that the current
     implementation only supports generating new layers using convolution of
     stride 2 resulting in a spatial resolution reduction by a factor of 2.
     By default convolution kernel size is set to 3, and it can be customized
     by caller.

  An example of the configuration for Inception V3:
  {
    'from_layer': ['Mixed_5d', 'Mixed_6e', 'Mixed_7c', '', '', ''],
    'layer_depth': [-1, -1, -1, 512, 256, 128]
  }

  When this feature generator object is called on input image_features:
    Args:
      image_features: A dictionary of handles to activation tensors from the
        base feature extractor.

    Returns:
      feature_maps: an OrderedDict mapping keys (feature map names) to
        tensors where each tensor has shape [batch, height_i, width_i, depth_i].
  """

  def __init__(self,
               feature_map_layout,
               depth_multiplier,
               min_depth,
               insert_1x1_conv,
               is_training,
               conv_hyperparams,
               freeze_batchnorm,
               name=None):
    """Constructor.

    Args:
      feature_map_layout: Dictionary of specifications for the feature map
        layouts in the following format (Inception V2/V3 respectively):
        {
          'from_layer': ['Mixed_3c', 'Mixed_4c', 'Mixed_5c', '', '', ''],
          'layer_depth': [-1, -1, -1, 512, 256, 128]
        }
        or
        {
          'from_layer': ['Mixed_5d', 'Mixed_6e', 'Mixed_7c', '', '', ''],
          'layer_depth': [-1, -1, -1, 512, 256, 128]
        }
        If 'from_layer' is specified, the specified feature map is directly used
        as a box predictor layer, and the layer_depth is directly inferred from
        the feature map (instead of using the provided 'layer_depth' parameter).
        In this case, our convention is to set 'layer_depth' to -1 for clarity.
        Otherwise, if 'from_layer' is an empty string, then the box predictor
        layer will be built from the previous layer using convolution
        operations. Note that the current implementation only supports
        generating new layers using convolutions of stride 2 (resulting in a
        spatial resolution reduction by a factor of 2), and will be extended to
        a more flexible design. Convolution kernel size is set to 3 by default,
        and can be customized by 'conv_kernel_size' parameter (similarly,
        'conv_kernel_size' should be set to -1 if 'from_layer' is specified).
        The created convolution operation will be a normal 2D convolution by
        default, and a depthwise convolution followed by 1x1 convolution if
        'use_depthwise' is set to True.
      depth_multiplier: Depth multiplier for convolutional layers.
      min_depth: Minimum depth for convolutional layers.
      insert_1x1_conv: A boolean indicating whether an additional 1x1
        convolution should be inserted before shrinking the feature map.
      is_training: Indicates whether the feature generator is in training mode.
      conv_hyperparams: A `hyperparams_builder.KerasLayerHyperparams` object
        containing hyperparameters for convolution ops.
      freeze_batchnorm: Bool. Whether to freeze batch norm parameters during
        training or not. When training with a small batch size (e.g. 1), it is
        desirable to freeze batch norm update and use pretrained batch norm
        params.
      name: A string name scope to assign to the model. If 'None', Keras
        will auto-generate one from the class name.

    Raises:
      ValueError: if 'from_layer' and 'layer_depth' have a different number of
        entries, or if the first 'from_layer' entry is an empty string (a
        generated layer is built from the previous feature map, and there is
        none before the first entry).
    """
    from_layers = feature_map_layout['from_layer']
    layer_depths = feature_map_layout['layer_depth']
    # Validate up front: a shorter 'layer_depth' would otherwise fail with a
    # bare IndexError mid-construction, and an empty first entry would only
    # fail later, inside call(), when there is no previous feature map.
    if len(from_layers) != len(layer_depths):
      raise ValueError(
          "feature_map_layout['from_layer'] and feature_map_layout"
          "['layer_depth'] must have the same number of entries; "
          'got {} and {}.'.format(len(from_layers), len(layer_depths)))
    if from_layers and not from_layers[0]:
      raise ValueError(
          "The first entry of feature_map_layout['from_layer'] must name an "
          'existing feature map; generated layers are built from the previous '
          'feature map, and there is none before the first entry.')

    super(KerasMultiResolutionFeatureMaps, self).__init__(name=name)

    self.feature_map_layout = feature_map_layout
    self.convolutions = []

    depth_fn = get_depth_fn(depth_multiplier, min_depth)

    base_from_layer = ''
    use_explicit_padding = feature_map_layout.get('use_explicit_padding', False)
    use_depthwise = feature_map_layout.get('use_depthwise', False)
    # Batch norm layers run in training mode only when training and not frozen.
    batchnorm_training = is_training and not freeze_batchnorm
    for index, from_layer in enumerate(from_layers):
      # Layers applied (in order) to the previous feature map; stays empty
      # for entries that are taken directly from `image_features`.
      net = []
      layer_depth = layer_depths[index]
      conv_kernel_size = 3
      if 'conv_kernel_size' in feature_map_layout:
        conv_kernel_size = feature_map_layout['conv_kernel_size'][index]
      if from_layer:
        base_from_layer = from_layer
      else:
        if insert_1x1_conv:
          layer_name = '{}_1_Conv2d_{}_1x1_{}'.format(
              base_from_layer, index, depth_fn(layer_depth / 2))
          net.append(tf.keras.layers.Conv2D(depth_fn(layer_depth / 2),
                                            [1, 1],
                                            padding='SAME',
                                            strides=1,
                                            name=layer_name + '_conv',
                                            **conv_hyperparams.params()))
          net.append(
              conv_hyperparams.build_batch_norm(
                  training=batchnorm_training,
                  name=layer_name + '_batchnorm'))
          net.append(
              conv_hyperparams.build_activation_layer(
                  name=layer_name))

        layer_name = '{}_2_Conv2d_{}_{}x{}_s2_{}'.format(
            base_from_layer, index, conv_kernel_size, conv_kernel_size,
            depth_fn(layer_depth))
        stride = 2
        padding = 'SAME'
        if use_explicit_padding:
          padding = 'VALID'
          # We define this function here while capturing the value of
          # conv_kernel_size, to avoid holding a reference to the loop variable
          # conv_kernel_size inside of a lambda function
          def fixed_padding(features, kernel_size=conv_kernel_size):
            return ops.fixed_padding(features, kernel_size)
          net.append(tf.keras.layers.Lambda(fixed_padding))
        # TODO(rathodv): Add some utilities to simplify the creation of
        # Depthwise & non-depthwise convolutions w/ normalization & activations
        if use_depthwise:
          net.append(tf.keras.layers.DepthwiseConv2D(
              [conv_kernel_size, conv_kernel_size],
              depth_multiplier=1,
              padding=padding,
              strides=stride,
              name=layer_name + '_depthwise_conv',
              **conv_hyperparams.params()))
          net.append(
              conv_hyperparams.build_batch_norm(
                  training=batchnorm_training,
                  name=layer_name + '_depthwise_batchnorm'))
          net.append(
              conv_hyperparams.build_activation_layer(
                  name=layer_name + '_depthwise'))

          net.append(tf.keras.layers.Conv2D(depth_fn(layer_depth), [1, 1],
                                            padding='SAME',
                                            strides=1,
                                            name=layer_name + '_conv',
                                            **conv_hyperparams.params()))
          net.append(
              conv_hyperparams.build_batch_norm(
                  training=batchnorm_training,
                  name=layer_name + '_batchnorm'))
          net.append(
              conv_hyperparams.build_activation_layer(
                  name=layer_name))

        else:
          net.append(tf.keras.layers.Conv2D(
              depth_fn(layer_depth),
              [conv_kernel_size, conv_kernel_size],
              padding=padding,
              strides=stride,
              name=layer_name + '_conv',
              **conv_hyperparams.params()))
          net.append(
              conv_hyperparams.build_batch_norm(
                  training=batchnorm_training,
                  name=layer_name + '_batchnorm'))
          net.append(
              conv_hyperparams.build_activation_layer(
                  name=layer_name))

      # Until certain bugs are fixed in checkpointable lists,
      # this net must be appended only once it's been filled with layers
      self.convolutions.append(net)

  def call(self, image_features):
    """Generate the multi-resolution feature maps.

    Executed when calling the `.__call__` method on input.

    Args:
      image_features: A dictionary of handles to activation tensors from the
        base feature extractor.

    Returns:
      feature_maps: an OrderedDict mapping keys (feature map names) to
        tensors where each tensor has shape [batch, height_i, width_i, depth_i].
        Generated maps are keyed by the name of their last (activation) layer.
    """
    feature_maps = []
    feature_map_keys = []

    for index, from_layer in enumerate(self.feature_map_layout['from_layer']):
      if from_layer:
        feature_map = image_features[from_layer]
        feature_map_keys.append(from_layer)
      else:
        # Build on the previous feature map with the layers from __init__.
        feature_map = feature_maps[-1]
        for layer in self.convolutions[index]:
          feature_map = layer(feature_map)
        layer_name = self.convolutions[index][-1].name
        feature_map_keys.append(layer_name)
      feature_maps.append(feature_map)
    return collections.OrderedDict(zip(feature_map_keys, feature_maps))
def multi_resolution_feature_maps(feature_map_layout, depth_multiplier,
                                  min_depth, insert_1x1_conv, image_features,
                                  pool_residual=False):
  """Generates multi resolution feature maps from input image features.

  Generates multi-scale feature maps for detection as in the SSD papers by
  Liu et al: https://arxiv.org/pdf/1512.02325v2.pdf, See Sec 2.1.

  More specifically, it performs the following two tasks:
  1) If a layer name is provided in the configuration, returns that layer as a
     feature map.
  2) If a layer name is left as an empty string, constructs a new feature map
     based on the spatial shape and depth configuration. Note that the current
     implementation only supports generating new layers using convolution of
     stride 2 resulting in a spatial resolution reduction by a factor of 2.
     By default convolution kernel size is set to 3, and it can be customized
     by caller.

  An example of the configuration for Inception V3:
  {
    'from_layer': ['Mixed_5d', 'Mixed_6e', 'Mixed_7c', '', '', ''],
    'layer_depth': [-1, -1, -1, 512, 256, 128]
  }

  Args:
    feature_map_layout: Dictionary of specifications for the feature map
      layouts in the following format (Inception V2/V3 respectively):
      {
        'from_layer': ['Mixed_3c', 'Mixed_4c', 'Mixed_5c', '', '', ''],
        'layer_depth': [-1, -1, -1, 512, 256, 128]
      }
      or
      {
        'from_layer': ['Mixed_5d', 'Mixed_6e', 'Mixed_7c', '', '', ''],
        'layer_depth': [-1, -1, -1, 512, 256, 128]
      }
      If 'from_layer' is specified, the specified feature map is directly used
      as a box predictor layer, and the layer_depth is directly inferred from
      the feature map (instead of using the provided 'layer_depth' parameter).
      In this case, our convention is to set 'layer_depth' to -1 for clarity.
      Otherwise, if 'from_layer' is an empty string, then the box predictor
      layer will be built from the previous layer using convolution operations.
      Note that the current implementation only supports generating new layers
      using convolutions of stride 2 (resulting in a spatial resolution
      reduction by a factor of 2), and will be extended to a more flexible
      design. Convolution kernel size is set to 3 by default, and can be
      customized by 'conv_kernel_size' parameter (similarly, 'conv_kernel_size'
      should be set to -1 if 'from_layer' is specified). The created
      convolution operation will be a normal 2D convolution by default, and a
      depthwise convolution followed by 1x1 convolution if 'use_depthwise' is
      set to True.
    depth_multiplier: Depth multiplier for convolutional layers.
    min_depth: Minimum depth for convolutional layers.
    insert_1x1_conv: A boolean indicating whether an additional 1x1 convolution
      should be inserted before shrinking the feature map.
    image_features: A dictionary of handles to activation tensors from the
      base feature extractor.
    pool_residual: Whether to add an average pooling layer followed by a
      residual connection between subsequent feature maps when their channel
      depths match. For example, with option 'layer_depth': [-1, 512, 256, 256],
      a pooling and residual layer is added between the third and fourth
      feature maps. This option only takes effect when 'use_depthwise' is True.
      It is better used with Weight Shared Convolution Box Predictor when all
      feature maps have the same channel depth to encourage more consistent
      features across multi-scale feature maps.

  Returns:
    feature_maps: an OrderedDict mapping keys (feature map names) to
      tensors where each tensor has shape [batch, height_i, width_i, depth_i].

  Raises:
    ValueError: if the number of entries in 'from_layer' and 'layer_depth' do
      not match, or if the first 'from_layer' entry is an empty string (a
      generated layer is built from the previous feature map, and there is none
      before the first entry).
  """
  from_layers = feature_map_layout['from_layer']
  layer_depths = feature_map_layout['layer_depth']
  if len(from_layers) != len(layer_depths):
    raise ValueError(
        "feature_map_layout['from_layer'] and feature_map_layout"
        "['layer_depth'] must have the same number of entries; "
        'got {} and {}.'.format(len(from_layers), len(layer_depths)))
  if from_layers and not from_layers[0]:
    raise ValueError(
        "The first entry of feature_map_layout['from_layer'] must name an "
        'existing feature map; generated layers are built from the previous '
        'feature map, and there is none before the first entry.')

  depth_fn = get_depth_fn(depth_multiplier, min_depth)

  feature_map_keys = []
  feature_maps = []
  base_from_layer = ''
  use_explicit_padding = feature_map_layout.get('use_explicit_padding', False)
  use_depthwise = feature_map_layout.get('use_depthwise', False)
  for index, from_layer in enumerate(from_layers):
    layer_depth = layer_depths[index]
    conv_kernel_size = 3
    if 'conv_kernel_size' in feature_map_layout:
      conv_kernel_size = feature_map_layout['conv_kernel_size'][index]
    if from_layer:
      feature_map = image_features[from_layer]
      base_from_layer = from_layer
      feature_map_keys.append(from_layer)
    else:
      pre_layer = feature_maps[-1]
      pre_layer_depth = pre_layer.get_shape().as_list()[3]
      intermediate_layer = pre_layer
      if insert_1x1_conv:
        layer_name = '{}_1_Conv2d_{}_1x1_{}'.format(
            base_from_layer, index, depth_fn(layer_depth / 2))
        intermediate_layer = slim.conv2d(
            pre_layer,
            depth_fn(layer_depth / 2), [1, 1],
            padding='SAME',
            stride=1,
            scope=layer_name)
      layer_name = '{}_2_Conv2d_{}_{}x{}_s2_{}'.format(
          base_from_layer, index, conv_kernel_size, conv_kernel_size,
          depth_fn(layer_depth))
      stride = 2
      padding = 'SAME'
      if use_explicit_padding:
        padding = 'VALID'
        intermediate_layer = ops.fixed_padding(
            intermediate_layer, conv_kernel_size)
      if use_depthwise:
        feature_map = slim.separable_conv2d(
            intermediate_layer,
            None, [conv_kernel_size, conv_kernel_size],
            depth_multiplier=1,
            padding=padding,
            stride=stride,
            scope=layer_name + '_depthwise')
        feature_map = slim.conv2d(
            feature_map,
            depth_fn(layer_depth), [1, 1],
            padding='SAME',
            stride=1,
            scope=layer_name)
        # The residual is only added when the previous map already has the
        # target depth, so the two tensors can be summed channel-wise.
        if pool_residual and pre_layer_depth == depth_fn(layer_depth):
          feature_map += slim.avg_pool2d(
              pre_layer, [3, 3],
              padding='SAME',
              stride=2,
              scope=layer_name + '_pool')
      else:
        feature_map = slim.conv2d(
            intermediate_layer,
            depth_fn(layer_depth), [conv_kernel_size, conv_kernel_size],
            padding=padding,
            stride=stride,
            scope=layer_name)
      feature_map_keys.append(layer_name)
    feature_maps.append(feature_map)
  return collections.OrderedDict(zip(feature_map_keys, feature_maps))
def fpn_top_down_feature_maps(image_features,
                              depth,
                              use_depthwise=False,
                              use_explicit_padding=False,
                              use_bounded_activations=False,
                              scope=None,
                              use_native_resize_op=False):
  """Generates `top-down` feature maps for Feature Pyramid Networks.

  See https://arxiv.org/abs/1612.03144 for details.

  Args:
    image_features: list of tuples of (tensor_name, image_feature_tensor).
      Spatial resolutions of successive tensors must reduce exactly by a factor
      of 2.
    depth: depth of output feature maps.
    use_depthwise: whether to use depthwise separable conv instead of regular
      conv.
    use_explicit_padding: whether to use explicit padding. When True, the
      convolutions use 'VALID' padding, the 3x3 smoothing convolutions receive
      an explicitly padded input, and each upsampled top-down map is cropped to
      the spatial size of the lateral (projection) map before they are added.
    use_bounded_activations: Whether or not to clip activations to range
      [-ACTIVATION_BOUND, ACTIVATION_BOUND]. Bounded activations better lend
      themselves to quantized inference.
    scope: A scope name to wrap this op under.
    use_native_resize_op: If True, uses tf.image.resize_nearest_neighbor op for
      the upsampling process instead of reshape and broadcasting implementation.

  Returns:
    feature_maps: an OrderedDict mapping keys (feature map names) to
      tensors where each tensor has shape [batch, height_i, width_i, depth_i].
      Entries follow the order of `image_features`.
  """
  with tf.name_scope(scope, 'top_down'):
    num_levels = len(image_features)
    output_feature_maps_list = []
    output_feature_map_keys = []
    padding = 'VALID' if use_explicit_padding else 'SAME'
    kernel_size = 3
    with slim.arg_scope(
        [slim.conv2d, slim.separable_conv2d], padding=padding, stride=1):
      top_down = slim.conv2d(
          image_features[-1][1],
          depth, [1, 1], activation_fn=None, normalizer_fn=None,
          scope='projection_%d' % num_levels)
      if use_bounded_activations:
        top_down = tf.clip_by_value(top_down, -ACTIVATION_BOUND,
                                    ACTIVATION_BOUND)
      output_feature_maps_list.append(top_down)
      output_feature_map_keys.append(
          'top_down_%s' % image_features[-1][0])

      # The smoothing convolution is the same at every level; choose it once.
      if use_depthwise:
        conv_op = functools.partial(slim.separable_conv2d, depth_multiplier=1)
      else:
        conv_op = slim.conv2d

      for level in reversed(range(num_levels - 1)):
        if use_native_resize_op:
          with tf.name_scope('nearest_neighbor_upsampling'):
            top_down_shape = top_down.shape.as_list()
            top_down = tf.image.resize_nearest_neighbor(
                top_down, [top_down_shape[1] * 2, top_down_shape[2] * 2])
        else:
          top_down = ops.nearest_neighbor_upsampling(top_down, scale=2)
        residual = slim.conv2d(
            image_features[level][1], depth, [1, 1],
            activation_fn=None, normalizer_fn=None,
            scope='projection_%d' % (level + 1))
        if use_bounded_activations:
          residual = tf.clip_by_value(residual, -ACTIVATION_BOUND,
                                      ACTIVATION_BOUND)
        if use_explicit_padding:
          # slice top_down to the same shape as residual
          residual_shape = tf.shape(residual)
          top_down = top_down[:, :residual_shape[1], :residual_shape[2], :]
        top_down += residual
        if use_bounded_activations:
          top_down = tf.clip_by_value(top_down, -ACTIVATION_BOUND,
                                      ACTIVATION_BOUND)
        # Pad only the smoothing conv's input. `top_down` itself is upsampled
        # and added to the next level's lateral map; carrying the padded
        # (enlarged) tensor forward would let the top-left crop above keep
        # padding instead of real features (assuming ops.fixed_padding pads
        # the leading edge too -- it is defined outside this file).
        smoothing_input = top_down
        if use_explicit_padding:
          smoothing_input = ops.fixed_padding(smoothing_input, kernel_size)
        output_feature_maps_list.append(conv_op(
            smoothing_input,
            depth, [kernel_size, kernel_size],
            scope='smoothing_%d' % (level + 1)))
        output_feature_map_keys.append('top_down_%s' % image_features[level][0])
    # Levels were produced coarsest-first; reverse to match `image_features`.
    return collections.OrderedDict(reversed(
        list(zip(output_feature_map_keys, output_feature_maps_list))))
def pooling_pyramid_feature_maps(base_feature_map_depth, num_layers,
                                 image_features, replace_pool_with_conv=False):
  """Generates pooling pyramid feature maps.

  The pooling pyramid feature maps are motivated by
  multi_resolution_feature_maps. The main differences are that it is simpler
  and reduces the number of free parameters.

  More specifically:
   - Instead of using convolutions to shrink the feature map, it uses max
     pooling, therefore totally gets rid of the parameters in convolution.
   - By pooling feature from larger map up to a single cell, it generates
     features in the same feature space.
   - Instead of independently making box predictions from individual maps, it
     shares the same classifier across different feature maps, therefore
     reduces the "mis-calibration" across different scales.

  See go/ppn-detection for more details.

  Args:
    base_feature_map_depth: Depth of the base feature before the max pooling.
      If not positive, no 1x1 convolution is applied to the input feature.
    num_layers: Number of layers used to make predictions. They are pooled
      from the base feature (the base feature itself counts as one of them).
    image_features: A dictionary of handles to activation tensors from the
      feature extractor.
    replace_pool_with_conv: Whether or not to replace pooling operations with
      convolutions in the PPN. Default is False.

  Returns:
    feature_maps: an OrderedDict mapping keys (feature map names) to
      tensors where each tensor has shape [batch, height_i, width_i, depth_i].

  Raises:
    ValueError: image_features does not contain exactly one entry
  """
  if len(image_features) != 1:
    raise ValueError('image_features should be a dictionary of length 1.')
  # dict.keys() returns a non-indexable view on Python 3, so take the single
  # value through an iterator rather than `image_features.keys()[0]`.
  image_features = next(iter(image_features.values()))
  feature_map_keys = []
  feature_maps = []
  feature_map_key = 'Base_Conv2d_1x1_%d' % base_feature_map_depth
  if base_feature_map_depth > 0:
    image_features = slim.conv2d(
        image_features,
        base_feature_map_depth,
        [1, 1],  # kernel size
        padding='SAME', stride=1, scope=feature_map_key)
    # Add a 1x1 max-pooling node (a no op node) immediately after the conv2d for
    # TPU v1 compatibility. Without the following dummy op, TPU runtime
    # compiler will combine the convolution with one max-pooling below into a
    # single cycle, so getting the conv2d feature becomes impossible.
    image_features = slim.max_pool2d(
        image_features, [1, 1], padding='SAME', stride=1, scope=feature_map_key)
  feature_map_keys.append(feature_map_key)
  feature_maps.append(image_features)
  feature_map = image_features
  if replace_pool_with_conv:
    with slim.arg_scope([slim.conv2d], padding='SAME', stride=2):
      for i in range(num_layers - 1):
        feature_map_key = 'Conv2d_{}_3x3_s2_{}'.format(i,
                                                      base_feature_map_depth)
        feature_map = slim.conv2d(
            feature_map, base_feature_map_depth, [3, 3], scope=feature_map_key)
        feature_map_keys.append(feature_map_key)
        feature_maps.append(feature_map)
  else:
    with slim.arg_scope([slim.max_pool2d], padding='SAME', stride=2):
      for i in range(num_layers - 1):
        feature_map_key = 'MaxPool2d_%d_2x2' % i
        feature_map = slim.max_pool2d(
            feature_map, [2, 2], padding='SAME', scope=feature_map_key)
        feature_map_keys.append(feature_map_key)
        feature_maps.append(feature_map)
  return collections.OrderedDict(zip(feature_map_keys, feature_maps))
|