|
import tensorflow as tf |
|
from tensorflow.keras import backend as K |
|
from tensorflow.keras import layers |
|
|
|
from ..layers import BlockImages, SwapAxes, UnblockImages |
|
|
|
|
|
def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
    """A SpatialGatingUnit as defined in the gMLP paper.

    The 'spatial' dim is defined as the second last.
    If applied on other dims, you should swapaxes first.
    """

    def apply(x):
        # Halve the channels: `u` passes through untouched, `gate` is
        # transformed into a multiplicative gate for it.
        u, gate = tf.split(x, 2, axis=-1)

        normed = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(gate)

        # Static size of the axis being mixed (third-from-last of the input).
        spatial_dim = K.int_shape(x)[-3]

        # Bring that axis into the last position so Dense mixes over it,
        # apply the projection, then restore the original axis order.
        swapped = SwapAxes()(normed, -1, -3)
        projected = layers.Dense(
            spatial_dim, use_bias=use_bias, name=f"{name}_Dense_0"
        )(swapped)
        restored = SwapAxes()(projected, -1, -3)

        # +1.0 keeps the gate close to identity when the projection output
        # is near zero.
        return u * (restored + 1.0)

    return apply
|
|
|
|
|
def GridGmlpLayer(
    grid_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "grid_gmlp",
):
    """Grid gMLP layer that performs global mixing of tokens.

    Partitions the (h, w) feature map into a `grid_size = (gh, gw)` grid of
    blocks, runs a gated-MLP branch (LayerNorm -> channel expansion by
    `factor` -> GELU -> GridGatingUnit -> projection back -> dropout) on the
    blocked tensor, adds it residually, and un-blocks back to the original
    layout.

    Args:
      grid_size: (gh, gw) number of blocks along height and width.
      use_bias: whether the Dense layers use bias terms.
      factor: channel expansion factor for the in-projection.
      dropout_rate: dropout applied to the branch output.
      name: prefix for the names of the inner layers.

    Returns:
      A callable that maps a 4-D tensor (batch, h, w, channels) to a tensor
      of the same shape.
    """

    def apply(x):
        # Static spatial/channel dims; the batch size is not needed.
        _, h, w, num_channels = K.int_shape(x)
        gh, gw = grid_size
        # Per-block patch size so the (gh, gw) grid tiles the feature map.
        # NOTE(review): assumes h % gh == 0 and w % gw == 0 — confirm the
        # caller pads inputs accordingly.
        fh, fw = h // gh, w // gw

        x = BlockImages()(x, patch_size=(fh, fw))

        # gMLP branch over the blocked representation.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = layers.Dense(
            num_channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(y)
        y = tf.nn.gelu(y, approximate=True)
        y = GridGatingUnit(use_bias=use_bias, name=f"{name}_GridGatingUnit")(y)
        y = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(y)
        y = layers.Dropout(dropout_rate)(y)
        # Residual connection around the branch, then restore spatial layout.
        x = x + y
        x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
        return x

    return apply
|
|