Spaces:
Runtime error
Runtime error
import tensorflow as tf | |
from tensorflow.keras import backend as K | |
from tensorflow.keras import layers | |
from ..layers import BlockImages, SwapAxes, UnblockImages | |
def BlockGatingUnit(use_bias: bool = True, name: str = "block_gating_unit"):
    """Return a gMLP spatial gating unit as a Keras functional-API callable.

    The axis this unit mixes along (the "spatial" axis) is the **second-to-last**
    axis of its input. To gate along a different axis, swap that axis into
    position -2 before calling.

    Args:
        use_bias: Whether the spatial-projection ``Dense`` layer has a bias.
        name: Prefix used for the names of the Keras layers this unit creates.

    Returns:
        A callable ``apply(x)``. It splits ``x`` into two equal halves along the
        last axis (so that axis must be even-sized). It layer-normalises the
        second half, projects it across the second-to-last axis with a ``Dense``
        layer of the same width, and returns ``first_half * (projected + 1.0)``.
    """

    def apply(x):
        # First half passes through unchanged; second half becomes the gate.
        passthrough, gate = tf.split(x, 2, axis=-1)
        # Static size of the axis being mixed; the projection preserves it.
        spatial_dim = K.int_shape(x)[-2]
        gate = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(gate)
        # Dense acts on the last axis, so move the spatial axis there and back.
        gate = SwapAxes()(gate, -1, -2)
        gate = layers.Dense(
            spatial_dim, use_bias=use_bias, name=f"{name}_Dense_0"
        )(gate)
        gate = SwapAxes()(gate, -1, -2)
        # Algebraically equal to passthrough + passthrough * gate.
        return passthrough * (gate + 1.0)

    return apply
def BlockGmlpLayer(
    block_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "block_gmlp",
):
    """Return a block gMLP layer that performs local mixing of tokens.

    The returned callable does the following:

    1. Blocks the input with ``BlockImages`` using ``patch_size=block_size``.
    2. Applies a pre-norm gMLP branch: LayerNorm, Dense expand to
       ``num_channels * factor``, GELU, ``BlockGatingUnit``, Dense project
       back to ``num_channels``, then Dropout.
    3. Adds that branch back to the blocked input as a residual.
    4. Reassembles the result with ``UnblockImages``.

    Args:
        block_size: ``(fh, fw)`` pair giving the block height and width.
        use_bias: Whether the Dense layers (including the one inside the
            gating unit) use a bias term.
        factor: Channel expansion factor for the input projection. The gating
            unit splits the expanded channels in half, so
            ``num_channels * factor`` must be even.
        dropout_rate: Dropout rate applied to the gMLP branch output.
        name: Prefix used for the names of the Keras layers created.

    Returns:
        A callable ``apply(x)``. It reads ``x``'s static shape at indices 1, 2
        and 3 as ``(h, w, num_channels)``. This assumes an NHWC tensor with
        statically known height/width (``h // fh`` fails on ``None``); TODO
        confirm against callers.
    """

    def apply(x):
        # Read the static shape once. The batch dimension (index 0) is unused.
        shape = K.int_shape(x)
        h, w, num_channels = shape[1], shape[2], shape[3]
        fh, fw = block_size
        # Grid size along each spatial axis, needed to undo the blocking below.
        gh, gw = h // fh, w // fw
        # NOTE(review): BlockImages/UnblockImages are project layers; presumably
        # they split into / reassemble from (fh, fw) patches - verify there.
        x = BlockImages()(x, patch_size=(fh, fw))
        # MLP2: Local (block) mixing part, provides within-block communication.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = layers.Dense(
            num_channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(y)
        y = tf.nn.gelu(y, approximate=True)
        y = BlockGatingUnit(use_bias=use_bias, name=f"{name}_BlockGatingUnit")(y)
        y = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(y)
        y = layers.Dropout(dropout_rate)(y)
        # Residual connection around the gMLP branch, in the blocked layout.
        x = x + y
        x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
        return x

    return apply