|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for axial_layers.""" |
|
|
|
import tensorflow as tf |
|
|
|
from deeplab2.model.layers import axial_layers |
|
|
|
|
|
class AxialLayersTest(tf.test.TestCase):
    """Static output-shape checks for the layers exposed by `axial_layers`.

    Each test constructs one layer, calls it once on an all-zeros tensor and
    compares the static shape of the result against a hard-coded list.
    """

    def _check_output_shape(self, layer, input_shape, expected_shape):
        """Calls `layer` on zeros and asserts the static shape of the result.

        Args:
          layer: An already constructed layer object; it is called exactly once.
          input_shape: List of ints, the shape of the all-zeros input tensor.
          expected_shape: List of ints the output's static shape must equal.
        """
        result = layer(tf.zeros(input_shape))
        self.assertListEqual(result.get_shape().as_list(), expected_shape)

    def test_default_axial_attention_layer_output_shape(self):
        # Default-constructed AxialAttention on a rank-3 input.
        self._check_output_shape(
            axial_layers.AxialAttention(),
            input_shape=[10, 5, 32],
            expected_shape=[10, 5, 1024],
        )

    def test_axial_attention_2d_layer_output_shape(self):
        # Default-constructed AxialAttention2D on a rank-4 input.
        self._check_output_shape(
            axial_layers.AxialAttention2D(),
            input_shape=[2, 5, 5, 32],
            expected_shape=[2, 5, 5, 1024],
        )

    def test_change_filters_output_shape(self):
        # With filters=32 the last output dimension is expected to be 64.
        self._check_output_shape(
            axial_layers.AxialAttention2D(filters=32),
            input_shape=[2, 5, 5, 32],
            expected_shape=[2, 5, 5, 64],
        )

    def test_value_expansion_output_shape(self):
        # With value_expansion=1 the last output dimension is expected to be
        # 512 rather than the 1024 asserted for the default layer.
        self._check_output_shape(
            axial_layers.AxialAttention2D(value_expansion=1),
            input_shape=[2, 5, 5, 32],
            expected_shape=[2, 5, 5, 512],
        )

    def test_global_attention_output_shape(self):
        # Default-constructed GlobalAttention2D on a rank-4 input.
        self._check_output_shape(
            axial_layers.GlobalAttention2D(),
            input_shape=[2, 5, 5, 32],
            expected_shape=[2, 5, 5, 1024],
        )

    def test_stride_two_output_shape(self):
        # With strides=2 the 5x5 middle dimensions are expected to become 3x3.
        self._check_output_shape(
            axial_layers.AxialAttention2D(strides=2),
            input_shape=[2, 5, 5, 32],
            expected_shape=[2, 3, 3, 1024],
        )
|
|
|
# Run the tests through TensorFlow's test entry point, but only when this
# file is executed as a script (not when it is imported).
if __name__ == '__main__':
    tf.test.main()
|
|