"""Builds the Shake-Shake Model.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import custom_ops as ops |
|
import tensorflow as tf |
|
|
|
|
|
def _shake_shake_skip_connection(x, output_filters, stride):
  """Builds the skip connection, projecting x to `output_filters` channels."""
  curr_filters = int(x.shape[3])
  if curr_filters == output_filters:
    return x
  stride_spec = ops.stride_arr(stride, stride)

  # Path 1: strided average pooling followed by a 1x1 conv that produces
  # half of the output channels.
  path1 = tf.nn.avg_pool(
      x, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
  path1 = ops.conv2d(path1, int(output_filters / 2), 1, scope='path1_conv')

  # Path 2: the same, but on the input shifted by one pixel so the strided
  # pooling samples the complementary spatial positions.
  pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
  path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :]
  concat_axis = 3

  path2 = tf.nn.avg_pool(
      path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
  path2 = ops.conv2d(path2, int(output_filters / 2), 1, scope='path2_conv')

  # Concatenate the two halves along the channel axis and normalize.
  final_path = tf.concat(values=[path1, path2], axis=concat_axis)
  final_path = ops.batch_norm(final_path, scope='final_path_bn')
  return final_path
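

# Shape example for the skip connection above: an NHWC input of shape
# [N, 32, 32, 16] with stride=2 and output_filters=64 is pooled to
# [N, 16, 16, 16] on each path, projected to [N, 16, 16, 32] by the 1x1
# convs, and concatenated to give [N, 16, 16, 64].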


def _shake_shake_branch(x, output_filters, stride, rand_forward, rand_backward,
                        is_training):
  """Builds one branch (ReLU-conv-BN, twice) of a shake-shake block."""
  x = tf.nn.relu(x)
  x = ops.conv2d(x, output_filters, 3, stride=stride, scope='conv1')
  x = ops.batch_norm(x, scope='bn1')
  x = tf.nn.relu(x)
  x = ops.conv2d(x, output_filters, 3, scope='conv2')
  x = ops.batch_norm(x, scope='bn2')
  if is_training:
    # Forward pass weights the branch by rand_forward; the backward pass
    # sees rand_backward instead, because no gradient flows through
    # tf.stop_gradient (see the note below this function).
    x = x * rand_backward + tf.stop_gradient(x * rand_forward -
                                             x * rand_backward)
  else:
    # At eval time the two branches are weighted equally.
    x *= 1.0 / 2
  return x
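

# Note on the stop_gradient trick above: in the forward pass the expression
# evaluates to
#   x * rand_backward + (x * rand_forward - x * rand_backward) = x * rand_forward,
# while in the backward pass the stop_gradient term contributes no gradient,
# so d(output)/dx = rand_backward. The branch is therefore scaled by
# rand_forward when computing activations and by rand_backward when computing
# gradients, which is the shake-shake regularization scheme.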


def _shake_shake_block(x, output_filters, stride, is_training):
  """Builds a full shake-shake block (two branches plus a skip connection)."""
  batch_size = tf.shape(x)[0]

  # Sample per-example branch weights, separately for the forward and
  # backward passes.
  rand_forward = [
      tf.random_uniform(
          [batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)
      for _ in range(2)
  ]
  rand_backward = [
      tf.random_uniform(
          [batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)
      for _ in range(2)
  ]

  # Normalize so the two branch weights sum to 1 for every example.
  total_forward = tf.add_n(rand_forward)
  total_backward = tf.add_n(rand_backward)
  rand_forward = [samp / total_forward for samp in rand_forward]
  rand_backward = [samp / total_backward for samp in rand_backward]
  zipped_rand = zip(rand_forward, rand_backward)

  branches = []
  for branch, (r_forward, r_backward) in enumerate(zipped_rand):
    with tf.variable_scope('branch_{}'.format(branch)):
      b = _shake_shake_branch(x, output_filters, stride, r_forward, r_backward,
                              is_training)
      branches.append(b)
  res = _shake_shake_skip_connection(x, output_filters, stride)
  return res + tf.add_n(branches)
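

# Net effect of a block at training time:
#   output = skip(x) + alpha_1 * branch_1(x) + alpha_2 * branch_2(x)
# where alpha_1 + alpha_2 = 1 per example in the forward pass, with an
# independent pair of weights in the backward pass. At eval time both
# weights are fixed at 0.5.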


def _shake_shake_layer(x, output_filters, num_blocks, stride, is_training):
  """Stacks `num_blocks` shake-shake blocks into one full layer."""
  for block_num in range(num_blocks):
    # Only the first block of a layer downsamples; the rest keep stride 1.
    curr_stride = stride if (block_num == 0) else 1
    with tf.variable_scope('layer_{}'.format(block_num)):
      x = _shake_shake_block(x, output_filters, curr_stride, is_training)
  return x


def build_shake_shake_model(images, num_classes, hparams, is_training):
  """Builds the Shake-Shake model.

  Builds the Shake-Shake model from https://arxiv.org/abs/1705.07485.

  Args:
    images: Tensor of images that will be fed into the Shake-Shake model.
    num_classes: Number of classes that the model needs to predict.
    hparams: HParams object that contains additional hparams needed to
      construct the model. In this case it is the `shake_shake_widen_factor`
      that is used to determine how many filters the model has.
    is_training: Whether the model is training.

  Returns:
    The logits of the Shake-Shake model.
  """
  depth = 26
  k = hparams.shake_shake_widen_factor
  # For depth 26, each of the three stages has n = (26 - 2) / 6 = 4 blocks.
  n = int((depth - 2) / 6)
  x = images

  x = ops.conv2d(x, 16, 3, scope='init_conv')
  x = ops.batch_norm(x, scope='init_bn')
  with tf.variable_scope('L1'):
    x = _shake_shake_layer(x, 16 * k, n, 1, is_training)
  with tf.variable_scope('L2'):
    x = _shake_shake_layer(x, 32 * k, n, 2, is_training)
  with tf.variable_scope('L3'):
    x = _shake_shake_layer(x, 64 * k, n, 2, is_training)
  x = tf.nn.relu(x)
  x = ops.global_avg_pool(x)

  logits = ops.fc(x, num_classes)
  return logits
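

# Minimal usage sketch, not part of the original module: builds the graph for
# CIFAR-sized 32x32 RGB inputs. Assumes TF 1.x (tf.contrib available) and that
# `custom_ops` is importable.
if __name__ == '__main__':
  hparams = tf.contrib.training.HParams(shake_shake_widen_factor=2)
  images = tf.placeholder(tf.float32, [None, 32, 32, 3])
  logits = build_shake_shake_model(images, 10, hparams, is_training=True)
  print(logits.shape)  # Expected: (?, 10)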