import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from ..layers import BlockImages, SwapAxes, UnblockImages


def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
    """A SpatialGatingUnit as defined in the gMLP paper.

    The 'spatial' dim is defined as the second last.
    If applied on other dims, you should swapaxes first.
    """

    def apply(x):
        u, v = tf.split(x, 2, axis=-1)
        v = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(v)
        n = K.int_shape(x)[-3]  # get spatial dim
        v = SwapAxes()(v, -1, -3)
        v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
        v = SwapAxes()(v, -1, -3)
        return u * (v + 1.0)

    return apply


def GridGmlpLayer(
    grid_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "grid_gmlp",
):
    """Grid gMLP layer that performs global mixing of tokens."""

    def apply(x):
        n, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )
        gh, gw = grid_size
        fh, fw = h // gh, w // gw
        x = BlockImages()(x, patch_size=(fh, fw))
        # gMLP1: Global (grid) mixing part, provides global grid communication.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = layers.Dense(
            num_channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(y)
        y = tf.nn.gelu(y, approximate=True)
        y = GridGatingUnit(use_bias=use_bias, name=f"{name}_GridGatingUnit")(y)
        y = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(y)
        y = layers.Dropout(dropout_rate)(y)
        x = x + y
        x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
        return x

    return apply
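

if __name__ == "__main__":
    # Usage sketch (not part of the original module): a minimal check, assuming
    # this file is run as a module within its package so the relative import of
    # BlockImages / SwapAxes / UnblockImages resolves. The spatial dimensions
    # must be divisible by the grid size, since fh = h // gh and fw = w // gw
    # have to tile the feature map exactly.
    x = tf.random.normal((1, 64, 64, 32))  # (batch, height, width, channels)
    y = GridGmlpLayer(grid_size=(8, 8), dropout_rate=0.1)(x)
    print(y.shape)  # (1, 64, 64, 32): grid mixing preserves the input shape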