File size: 5,038 Bytes
9bd9a8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Quiet down logging before TensorFlow is loaded:
#  - logging.disable(logging.WARNING) drops every Python logging call of
#    severity WARNING and below, process-wide.
#  - TF_CPP_MIN_LOG_LEVEL="3" hides TensorFlow's C++ INFO/WARNING/ERROR log
#    output. The variable is read when TensorFlow loads, which is why it is
#    set before the `import tensorflow` line below.
import logging, os
logging.disable(logging.WARNING)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
# NOTE(review): wildcard imports. The layer helpers used below
# (convolution_2D, convolution_3D, down_convolution, up_transposed_convolution,
# res_block, down_res_block) are presumably provided by these two modules --
# confirm which module defines which.
from basic_ops import *
from resnet_module import *


"""This module builds a U-Net architecture as described by the conf_unet
configuration dictionary.
"""

class UNet(object):
    """Builds a 2D or 3D U-Net graph from a configuration dictionary.

    Layout: an initial convolution, ``depth`` encoding blocks (every block
    after the first starts with a downsampling layer), a bottom res_block,
    and ``depth - 1`` decoding blocks (each starts with a transposed-
    convolution upsampling layer, then merges the skip connection from the
    encoding block at the same level).
    """

    def __init__(self, conf_unet):
        """Read the architecture hyper-parameters from ``conf_unet``.

        Args:
            conf_unet: A dict with the keys
                'depth': number of encoding levels (encoding_block_1 ..
                    encoding_block_depth).
                'dimension': '2D' or '3D'; selects the convolution ops and is
                    forwarded to every block.
                'first_output_filters': filter count of level 1; level i uses
                    first_output_filters * 2**(i-1).
                'encoding_block_sizes': number of res_blocks per encoding
                    level; index k is used for encoding_block_{k+1}.
                'downsampling': downsampling function name per encoding level
                    2 .. depth ('down_res_block' or 'convolution'); index k is
                    used for encoding_block_{k+2}.
                'decoding_block_sizes': number of res_blocks per decoding
                    level, ordered from the deepest decoder
                    (decoding_block_{depth-1}) to decoding_block_1.
                'skip_method': 'add' or 'concat'. Any other value disables
                    the skip connections.
        """
        self.depth = conf_unet['depth']
        self.dimension = conf_unet['dimension']
        self.first_output_filters = conf_unet['first_output_filters']
        self.encoding_block_sizes = conf_unet['encoding_block_sizes']
        self.downsampling = conf_unet['downsampling']
        self.decoding_block_sizes = conf_unet['decoding_block_sizes']
        self.skip_method = conf_unet['skip_method']

    def __call__(self, inputs, training):
        """Add the U-Net operations on top of ``inputs``.

        Args:
            inputs: A Tensor holding a batch of input images. Channels-last
                layout is assumed (skip connections concatenate on axis -1)
                -- confirm against the data pipeline.
            training: A boolean. Set to True to add operations required only
                when training; it is forwarded to every block.

        Returns:
            The feature-map Tensor produced by decoding_block_1, or by the
            bottom block when depth == 1. No projection to class scores is
            applied here.

        Raises:
            ValueError: if ``dimension`` or a ``downsampling`` entry is not
                supported.
        """
        return self._build_network(inputs, training)


    ################################################################################
    # Composite blocks building the network
    ################################################################################
    def _build_network(self, inputs, training):
        """Assemble first conv, encoder, bottom block and decoder; see __call__."""
        # Resolve this first so an unsupported dimension fails with a clear
        # ValueError instead of an unbound local variable.
        convolution = self._get_convolution_function(self.dimension)

        # first_convolution. The positional args (3, 1, False) presumably mean
        # kernel size, stride and use_bias -- see basic_ops to confirm.
        inputs = convolution(inputs, self.first_output_filters, 3, 1, False,
                             'first_convolution')

        # encoding_block_1: res_blocks only, no downsampling.
        with tf.variable_scope('encoding_block_1'):
            for block_index in range(self.encoding_block_sizes[0]):
                inputs = res_block(inputs, self.first_output_filters, training,
                                   self.dimension, 'res_%d' % block_index)

        # encoding_block_i (down) = downsampling + zero or more res_blocks,
        # i = 2, 3, ..., depth.
        # skip_inputs[k] is the output of encoding_block_{k+1}; after this
        # loop len(skip_inputs) == depth - 1.
        skip_inputs = []
        for i in range(2, self.depth + 1):
            skip_inputs.append(inputs)
            with tf.variable_scope('encoding_block_%d' % i):
                output_filters = self.first_output_filters * (2 ** (i - 1))

                downsampling_func = self._get_downsampling_function(
                    self.downsampling[i - 2])
                inputs = downsampling_func(inputs, output_filters, training,
                                           self.dimension, 'downsampling')

                for block_index in range(self.encoding_block_sizes[i - 1]):
                    inputs = res_block(inputs, output_filters, training,
                                       self.dimension, 'res_%d' % block_index)

        # bottom_block: a single res_block at the deepest filter count. Keep
        # the 'block_0' scope name unchanged; renaming it would change the
        # graph's variable names.
        with tf.variable_scope('bottom_block'):
            output_filters = self.first_output_filters * (2 ** (self.depth - 1))
            inputs = res_block(inputs, output_filters, training, self.dimension,
                               'block_0')

        # decoding_block_j (up) = upsampling + skip merge + zero or more
        # res_blocks, j = depth-1, depth-2, ..., 1.
        # decoding_block_j merges skip_inputs[j-1], i.e. the output of
        # encoding_block_j, which was built with the same filter count
        # (first_output_filters * 2**(j-1)) as the upsampled tensor.
        for j in range(self.depth - 1, 0, -1):
            with tf.variable_scope('decoding_block_%d' % j):
                output_filters = self.first_output_filters * (2 ** (j - 1))

                inputs = up_transposed_convolution(inputs, output_filters,
                                                   training, self.dimension,
                                                   'upsampling')

                skip = skip_inputs[j - 1]
                if self.skip_method == 'add':
                    inputs = tf.add(inputs, skip)
                elif self.skip_method == 'concat':
                    # Channel count doubles here; res_block is presumably
                    # responsible for projecting back to output_filters
                    # (confirm in resnet_module).
                    inputs = tf.concat([inputs, skip], axis=-1)
                # Any other skip_method: no skip connection (kept as-is).

                # decoding_block_sizes is ordered deepest decoder first.
                for block_index in range(self.decoding_block_sizes[self.depth - 1 - j]):
                    inputs = res_block(inputs, output_filters, training,
                                       self.dimension, 'res_%d' % block_index)

        return inputs


    def _get_convolution_function(self, dimension):
        """Return the plain convolution op for ``dimension``.

        Raises:
            ValueError: if ``dimension`` is not '2D' or '3D'.
        """
        if dimension == '2D':
            return convolution_2D
        elif dimension == '3D':
            return convolution_3D
        else:
            raise ValueError("Unsupported dimension: %s." % (dimension))


    def _get_downsampling_function(self, name):
        """Return the downsampling op named ``name``.

        Raises:
            ValueError: if ``name`` is not 'down_res_block' or 'convolution'.
        """
        if name == 'down_res_block':
            return down_res_block
        elif name == 'convolution':
            return down_convolution
        else:
            raise ValueError("Unsupported function: %s." % (name))