"""This script generates the U-Net architecture according to conf_unet."""

import logging
import os

# Logging is silenced (Python side and TF C++ side) before TensorFlow is
# imported below; keep this ordering.
logging.disable(logging.WARNING)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

from basic_ops import *
from resnet_module import *


class UNet(object):
    """Builds a U-Net graph (encoder, bottom block, decoder with skips).

    The layout is driven entirely by the ``conf_unet`` dictionary passed to
    the constructor; calling the instance adds the operations to the current
    TensorFlow graph and returns the output tensor of the last decoding block.
    """

    def __init__(self, conf_unet):
        """Store the architecture configuration.

        Args:
            conf_unet: A dict with the following keys.
                'depth': number of resolution levels (encoding blocks).
                'dimension': '2D' or '3D'; selects the convolution flavour.
                'first_output_filters': number of filters of the first
                    convolution; level i uses this value times 2**(i-1).
                'encoding_block_sizes': number of res_blocks per encoding
                    block; indexed 0 .. depth-1.
                'downsampling': name of the downsampling function for
                    encoding blocks 2 .. depth ('down_res_block' or
                    'convolution'); indexed 0 .. depth-2.
                'decoding_block_sizes': number of res_blocks per decoding
                    block, listed from the deepest decoding block to the
                    shallowest; indexed 0 .. depth-2.
                'skip_method': 'add' or 'concat'; any other value leaves the
                    skip connections out.

        Raises:
            KeyError: If a required key is missing from conf_unet.
        """
        self.depth = conf_unet['depth']
        self.dimension = conf_unet['dimension']
        self.first_output_filters = conf_unet['first_output_filters']
        self.encoding_block_sizes = conf_unet['encoding_block_sizes']
        self.downsampling = conf_unet['downsampling']
        self.decoding_block_sizes = conf_unet['decoding_block_sizes']
        self.skip_method = conf_unet['skip_method']

    def __call__(self, inputs, training):
        """Add the U-Net operations for a batch of inputs to the graph.

        Args:
            inputs: A Tensor representing a batch of input images.
            training: A boolean. Forwarded to every block so that operations
                required only when training can be added.

        Returns:
            The output Tensor of decoding_block_1 (or of the bottom block
            when depth == 1). No classification layer is applied here.

        Raises:
            ValueError: If the configured dimension or a downsampling name
                is not supported.
        """
        return self._build_network(inputs, training)

    ################################################################################
    # Composite blocks building the network
    ################################################################################

    def _build_network(self, inputs, training):
        """Assemble first convolution, encoder, bottom block and decoder.

        Variable scope names are part of the checkpoint layout and must not
        be changed.

        Args:
            inputs: Input Tensor of the network.
            training: A boolean forwarded to every block.

        Returns:
            The Tensor produced by the last decoding block.

        Raises:
            ValueError: If self.dimension or a downsampling name is not
                supported.
        """
        # first_convolution
        # NOTE(review): positional args 3, 1, False are presumably kernel
        # size, stride and use_bias -- confirm against basic_ops.
        convolution = self._get_convolution_function(self.dimension)
        inputs = convolution(inputs, self.first_output_filters, 3, 1, False,
                             'first_convolution')

        # encoding_block_1
        with tf.variable_scope('encoding_block_1'):
            for block_index in range(self.encoding_block_sizes[0]):
                inputs = res_block(inputs, self.first_output_filters, training,
                                   self.dimension, 'res_%d' % block_index)

        # encoding_block_i (down) = downsampling + zero or more res_block,
        # i = 2, 3, ..., depth
        skip_inputs = []  # for identity skip connections
        for i in range(2, self.depth + 1):
            # Output of encoding_block_(i-1), saved before downsampling.
            skip_inputs.append(inputs)
            with tf.variable_scope('encoding_block_%d' % i):
                output_filters = self.first_output_filters * (2 ** (i - 1))

                # downsampling
                downsampling_func = self._get_downsampling_function(
                    self.downsampling[i - 2])
                inputs = downsampling_func(inputs, output_filters, training,
                                           self.dimension, 'downsampling')

                for block_index in range(self.encoding_block_sizes[i - 1]):
                    inputs = res_block(inputs, output_filters, training,
                                       self.dimension, 'res_%d' % block_index)

        # bottom_block = a single res_block at the deepest resolution
        with tf.variable_scope('bottom_block'):
            output_filters = self.first_output_filters * (2 ** (self.depth - 1))
            inputs = res_block(inputs, output_filters, training,
                               self.dimension, 'block_0')

        # Note: Identity skip connections are between the output of
        # encoding_block_i and the output of upsampling in decoding_block_i,
        # i = 1, 2, ..., depth-1.
        # skip_inputs[i-1] is the output of encoding_block_i now.
        # len(skip_inputs) == depth - 1
        # skip_inputs[depth-2] should be combined during decoding_block_depth-1
        # skip_inputs[0] should be combined during decoding_block_1

        # decoding_block_j (up) = upsampling + zero or more res_block,
        # j = depth-1, depth-2, ..., 1
        for j in range(self.depth - 1, 0, -1):
            with tf.variable_scope('decoding_block_%d' % j):
                output_filters = self.first_output_filters * (2 ** (j - 1))

                # upsampling
                inputs = up_transposed_convolution(inputs, output_filters,
                                                   training, self.dimension,
                                                   'upsampling')

                # combine with skip connections; any skip_method other than
                # 'add' / 'concat' leaves the skip connection out.
                if self.skip_method == 'add':
                    inputs = tf.add(inputs, skip_inputs[j - 1])
                elif self.skip_method == 'concat':
                    # assumes channels-last layout (axis=-1) -- TODO confirm
                    inputs = tf.concat([inputs, skip_inputs[j - 1]], axis=-1)

                # decoding_block_sizes is ordered deepest block first.
                for block_index in range(
                        self.decoding_block_sizes[self.depth - 1 - j]):
                    inputs = res_block(inputs, output_filters, training,
                                       self.dimension, 'res_%d' % block_index)

        return inputs

    def _get_convolution_function(self, dimension):
        """Return the convolution helper matching '2D' or '3D'.

        Raises:
            ValueError: If dimension is neither '2D' nor '3D'.
        """
        if dimension == '2D':
            return convolution_2D
        elif dimension == '3D':
            return convolution_3D
        else:
            raise ValueError("Unsupported dimension: %s." % (dimension))

    def _get_downsampling_function(self, name):
        """Return the downsampling helper registered under name.

        Args:
            name: 'down_res_block' or 'convolution'.

        Raises:
            ValueError: If name is not one of the supported values.
        """
        if name == 'down_res_block':
            return down_res_block
        elif name == 'convolution':
            return down_convolution
        else:
            raise ValueError("Unsupported function: %s." % (name))