|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for squeeze_and_excite.py.""" |
|
|
|
import tensorflow as tf |
|
|
|
from deeplab2.model.layers import squeeze_and_excite |
|
|
|
|
|
class SqueezeAndExciteTest(tf.test.TestCase):
  """Checks that the squeeze-and-excite layers preserve the input shape."""

  def _assert_shape_preserved(self, layer_op, input_tensor):
    """Applies `layer_op` to `input_tensor` and asserts the shape is unchanged.

    Named without a `test` prefix so the test runner does not collect it.

    Args:
      layer_op: The callable layer under test.
      input_tensor: The tensor fed to `layer_op`.
    """
    output_tensor = layer_op(input_tensor)
    self.assertListEqual(input_tensor.get_shape().as_list(),
                         output_tensor.get_shape().as_list())

  def test_simplified_squeeze_and_excite_input_output_shape(self):
    """SimplifiedSqueezeAndExcite output has the same shape as its input."""
    channels = 32
    input_tensor = tf.random.uniform(shape=(3, 65, 65, channels))
    layer_op = squeeze_and_excite.SimplifiedSqueezeAndExcite(channels)
    self._assert_shape_preserved(layer_op, input_tensor)

  def test_squeeze_and_excite_input_output_shape(self):
    """SqueezeAndExcite with in_filters == out_filters preserves the shape."""
    channels = 32
    input_tensor = tf.random.uniform(shape=(3, 65, 65, channels))
    layer_op = squeeze_and_excite.SqueezeAndExcite(
        in_filters=channels,
        out_filters=channels,
        se_ratio=8,
        name='se')
    self._assert_shape_preserved(layer_op, input_tensor)
|
|
|
|
|
if __name__ == '__main__':
  # Entry point when the file is executed directly: hand control to
  # TensorFlow's test runner so the TestCase classes above are run.
  tf.test.main()
|
|