File size: 4,545 Bytes
9a393e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for object_detection.core.freezable_batch_norm."""
import numpy as np
import tensorflow as tf

from object_detection.core import freezable_batch_norm


class FreezableBatchNormTest(tf.test.TestCase):
  """Tests for FreezableBatchNorm operations.

  Each test fits a one-layer model on normally distributed data, copies the
  resulting weights into a fresh layer constructed with a fixed `training`
  value, and then checks how data drawn from a different distribution is
  normalized by that layer.

  Note: the `training_var` / `testing_var` values used below are passed to
  `np.random.normal` as `scale`, so despite their names they act as standard
  deviations, not variances.
  """

  def _build_model(self, training=None):
    """Builds a Sequential model holding a single FreezableBatchNorm layer.

    Args:
      training: Value forwarded to the `training` argument of
        `FreezableBatchNorm`. Defaults to None.

    Returns:
      A `(model, norm)` tuple: the Keras model and the batch norm layer that
      was added to it.
    """
    model = tf.keras.models.Sequential()
    norm = freezable_batch_norm.FreezableBatchNorm(training=training,
                                                   input_shape=(10,),
                                                   momentum=0.8)
    model.add(norm)
    return model, norm

  def _sample_data(self, mean, scale):
    """Draws a (1000, 10) array of normal samples.

    Args:
      mean: Center of the distribution (`loc` of `np.random.normal`).
      scale: Spread of the distribution (`scale` of `np.random.normal`).

    Returns:
      A numpy array of shape (1000, 10).
    """
    return np.random.normal(loc=mean, scale=scale, size=(1000, 10))

  def _train_freezable_batch_norm(self, training_mean, training_var):
    """Fits a freshly built batch norm model and returns its weights.

    Args:
      training_mean: Mean of the generated training data.
      training_var: Spread of the generated training data. It is used as the
        `scale` argument of `np.random.normal`.

    Returns:
      The `weights` list of the fitted model.
    """
    model, _ = self._build_model()
    model.compile(loss='mse', optimizer='sgd')

    # The sampled inputs are also used as the regression targets.
    train_data = self._sample_data(training_mean, training_var)
    model.fit(train_data, train_data, epochs=4, verbose=0)
    return model.weights

  def _copy_weights(self, trained_weights, model):
    """Assigns each value in `trained_weights` to the weights of `model`.

    Args:
      trained_weights: Weights to copy from, paired positionally with
        `model.weights`.
      model: Keras model whose weights are overwritten.
    """
    for trained_weight, blank_weight in zip(trained_weights, model.weights):
      weight_copy = blank_weight.assign(tf.keras.backend.eval(trained_weight))
      # Evaluate the assign result so the update is executed.
      tf.keras.backend.eval(weight_copy)

  def _run_frozen_batch_norm(self, training, training_mean, training_var,
                             testing_mean, testing_var):
    """Trains a layer, reloads it with `training` fixed, applies it to data.

    Args:
      training: The `training` value the reloaded layer is built with.
      training_mean: Mean of the data used to fit the layer.
      training_var: Spread (numpy `scale`) of the data used to fit the layer.
      testing_mean: Mean of the data the reloaded layer is applied to.
      testing_var: Spread (numpy `scale`) of the data the reloaded layer is
        applied to.

    Returns:
      A numpy array with the layer output after subtracting the layer's beta
      and dividing by its gamma.
    """
    # Initially train the batch norm, and save the weights.
    trained_weights = self._train_freezable_batch_norm(training_mean,
                                                       training_var)

    # Load the trained weights into a layer whose `training` value is fixed.
    model, norm = self._build_model(training=training)
    self._copy_weights(trained_weights, model)

    test_data = self._sample_data(testing_mean, testing_var)
    out_tensor = norm(tf.convert_to_tensor(test_data, dtype=tf.float32))
    out = tf.keras.backend.eval(out_tensor)

    # Undo the learned shift and scale so only the normalization remains.
    out -= tf.keras.backend.eval(norm.beta)
    out /= tf.keras.backend.eval(norm.gamma)
    return out

  def _assert_standardized(self, values):
    """Asserts that `values` has mean ~0 and standard deviation ~1.

    Args:
      values: Numpy array to check; both statistics are taken over all
        elements and compared with an absolute tolerance of 0.15.
    """
    np.testing.assert_allclose(values.mean(), 0.0, atol=1.5e-1)
    np.testing.assert_allclose(values.std(), 1.0, atol=1.5e-1)

  def test_batchnorm_freezing_training_true(self):
    """With training frozen to True, batch statistics drive normalization."""
    with self.test_session():
      training_mean = 5.0
      training_var = 10.0

      testing_mean = -10.0
      testing_var = 5.0

      # Apply the batch norm layer to testing data and ensure it is normalized
      # according to the batch statistics.
      out = self._run_frozen_batch_norm(True, training_mean, training_var,
                                        testing_mean, testing_var)

      self._assert_standardized(out)

  def test_batchnorm_freezing_training_false(self):
    """With training frozen to False, training statistics drive normalization."""
    with self.test_session():
      training_mean = 5.0
      training_var = 10.0

      testing_mean = -10.0
      testing_var = 5.0

      # Apply the batch norm layer to testing data and ensure it is normalized
      # according to the training data's statistics.
      out = self._run_frozen_batch_norm(False, training_mean, training_var,
                                        testing_mean, testing_var)

      # Map the output from "normalized by training statistics" to
      # "normalized by testing statistics": rescale by the training spread,
      # shift by the difference of the means, then divide by the testing
      # spread. The result should be standardized if the layer really used
      # the training statistics.
      out *= training_var
      out += (training_mean - testing_mean)
      out /= testing_var

      self._assert_standardized(out)

# Run the test cases through TensorFlow's test runner when this file is
# executed as a script.
if __name__ == '__main__':
  tf.test.main()