import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from ..layers import BlockImages, SwapAxes, UnblockImages


def BlockGatingUnit(use_bias: bool = True, name: str = "block_gating_unit"):
    """A SpatialGatingUnit as defined in the gMLP paper.

    The 'spatial' dim is defined as the **second last**.
    If applied on other dims, you should swapaxes first.
    """

    def apply(x):
        u, v = tf.split(x, 2, axis=-1)
        v = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(v)
        n = K.int_shape(x)[-2]  # get spatial dim
        v = SwapAxes()(v, -1, -2)
        v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
        v = SwapAxes()(v, -1, -2)
        return u * (v + 1.0)

    return apply


def BlockGmlpLayer(
    block_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "block_gmlp",
):
    """Block gMLP layer that performs local mixing of tokens."""

    def apply(x):
        n, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )
        fh, fw = block_size
        gh, gw = h // fh, w // fw
        x = BlockImages()(x, patch_size=(fh, fw))
        # MLP2: Local (block) mixing part, provides within-block communication.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = layers.Dense(
            num_channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(y)
        y = tf.nn.gelu(y, approximate=True)
        y = BlockGatingUnit(use_bias=use_bias, name=f"{name}_BlockGatingUnit")(y)
        y = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(y)
        y = layers.Dropout(dropout_rate)(y)
        x = x + y
        x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
        return x

    return apply
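
# Usage sketch (not part of the module API): a minimal shape check, assuming
# the package's BlockImages/UnblockImages layers split (h, w) into
# non-overlapping (fh, fw) blocks and exactly invert that split. Both spatial
# dims must be divisible by `block_size`, since gh and gw above come from
# integer division. Run from within the package so the relative import resolves.
if __name__ == "__main__":
    x = tf.random.normal((1, 16, 16, 32))  # (batch, h, w, channels)
    y = BlockGmlpLayer(block_size=(4, 4), factor=2)(x)
    print(y.shape)  # (1, 16, 16, 32): block gMLP mixing preserves input shape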