|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for axial_blocks.""" |
|
|
|
import tensorflow as tf |
|
|
|
from deeplab2.model.layers import axial_blocks |
|
|
|
|
|
class AxialBlocksTest(tf.test.TestCase):
  """Static output-shape checks for `axial_blocks.AxialBlock`.

  Every test runs the block on the same zero tensor of shape
  [2, 65, 65, 32], paired with a scalar float32 training flag of 0.0. It then
  checks the static shape of one element of the block's returned sequence.
  """

  def _static_output_shape(self, block, output_index):
    """Runs `block` on the shared fixture and returns one output's shape.

    Args:
      block: The `axial_blocks.AxialBlock` instance under test.
      output_index: Position of the element to inspect in the sequence the
        block returns.

    Returns:
      The static shape of the selected output, as a Python list.
    """
    features = tf.zeros([2, 65, 65, 32])
    training_flag = tf.constant(0.0, dtype=tf.float32)
    # The block is called with a single (features, training_flag) tuple.
    block_outputs = block((features, training_flag))
    return block_outputs[output_index].get_shape().as_list()

  def test_conv_basic_block_correct_output_shape(self):
    """Two-entry filters_list, stride 2: expects a [2, 33, 33, 256] output."""
    block = axial_blocks.AxialBlock(filters_list=[256, 256], strides=2)
    # NOTE(review): this test inspects element 1 while the bottleneck test
    # inspects element 0. Presumably both elements share a shape; confirm
    # against AxialBlock.call.
    observed = self._static_output_shape(block, output_index=1)
    self.assertListEqual(observed, [2, 33, 33, 256])

  def test_conv_bottleneck_block_correct_output_shape(self):
    """Three-entry filters_list, stride 1: spatial size 65x65 is preserved."""
    block = axial_blocks.AxialBlock(filters_list=[64, 64, 256], strides=1)
    observed = self._static_output_shape(block, output_index=0)
    self.assertListEqual(observed, [2, 65, 65, 256])

  def test_axial_block_correct_output_shape(self):
    """attention_type='axial', stride 2: expects a [2, 33, 33, 256] output."""
    block = axial_blocks.AxialBlock(
        filters_list=[128, 64, 256],
        strides=2,
        attention_type='axial')
    observed = self._static_output_shape(block, output_index=1)
    self.assertListEqual(observed, [2, 33, 33, 256])
|
|
|
if __name__ == '__main__':
  # Hand control to TensorFlow's test entry point when this file is executed
  # directly; importing the module runs nothing.
  tf.test.main()
|
|