import functools

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from ..layers import BlockImages, SwapAxes, UnblockImages
from .block_gating import BlockGmlpLayer
from .grid_gating import GridGmlpLayer

Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
ConvT_up = functools.partial(
    layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
)
Conv_down = functools.partial(
    layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
)


def ResidualSplitHeadMultiAxisGmlpLayer(
    block_size,
    grid_size,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "residual_split_head_maxim",
):
    """The multi-axis gated MLP block."""
    def apply(x):
        shortcut = x
        n, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )

        # Input projection: normalize, widen the channels, then split into two heads.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            int(num_channels) * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)
        u, v = tf.split(x, 2, axis=-1)

        # Grid gMLP head: global (long-range) spatial mixing.
        u = GridGmlpLayer(
            grid_size=grid_size,
            factor=grid_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_GridGmlpLayer",
        )(u)

        # Block gMLP head: local spatial mixing within non-overlapping blocks.
        v = BlockGmlpLayer(
            block_size=block_size,
            factor=block_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_BlockGmlpLayer",
        )(v)

        # Merge the two heads, project back to the input width, and add the residual.
        x = tf.concat([u, v], axis=-1)
        x = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(x)
        x = layers.Dropout(dropout_rate)(x)
        x = x + shortcut
        return x

    return apply
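# A minimal usage sketch (illustrative only; the 64x64x32 input shape and the
# 8x8 block/grid sizes below are assumptions, not values taken from a MAXIM
# config). The spatial dimensions must be divisible by the grid and block sizes.
#
#   inp = layers.Input(shape=(64, 64, 32))
#   out = ResidualSplitHeadMultiAxisGmlpLayer(
#       block_size=(8, 8), grid_size=(8, 8), name="residual_maxim_0"
#   )(inp)  # output keeps the input shape: (None, 64, 64, 32)

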
def GetSpatialGatingWeights(
    features: int,
    block_size,
    grid_size,
    input_proj_factor: int = 2,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "spatial_gating",
):
    """Get gating weights for cross-gating MLP block."""

    def apply(x):
        n, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )

        # Input projection: normalize, widen the channels, then split into two paths.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            num_channels * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)
        u, v = tf.split(x, 2, axis=-1)

        # Grid gating weights: partition the feature map into a fixed (gh, gw)
        # grid and mix along the grid axis for global (long-range) interactions.
        gh, gw = grid_size
        fh, fw = h // gh, w // gw
        u = BlockImages()(u, patch_size=(fh, fw))
        dim_u = K.int_shape(u)[-3]
        u = SwapAxes()(u, -1, -3)
        u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u)
        u = SwapAxes()(u, -1, -3)
        u = UnblockImages()(u, grid_size=(gh, gw), patch_size=(fh, fw))

        # Block gating weights: partition into fixed (fh, fw) patches and mix
        # within each patch for local interactions.
        fh, fw = block_size
        gh, gw = h // fh, w // fw
        v = BlockImages()(v, patch_size=(fh, fw))
        dim_v = K.int_shape(v)[-2]
        v = SwapAxes()(v, -1, -2)
        v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v)
        v = SwapAxes()(v, -1, -2)
        v = UnblockImages()(v, grid_size=(gh, gw), patch_size=(fh, fw))

        # Fuse both paths and project back to the input channel width.
        x = tf.concat([u, v], axis=-1)
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x)
        x = layers.Dropout(dropout_rate)(x)
        return x

    return apply
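# A minimal usage sketch (illustrative only; the shape and the 8x8 block/grid
# sizes are assumptions). The returned closure maps an NHWC tensor to gating
# weights of the same shape, which CrossGatingBlock below multiplies
# element-wise into the opposite stream.
#
#   feats = layers.Input(shape=(64, 64, 32))
#   gates = GetSpatialGatingWeights(
#       features=32, block_size=(8, 8), grid_size=(8, 8), name="gating_demo"
#   )(feats)  # -> (None, 64, 64, 32)

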
def CrossGatingBlock(
    features: int,
    block_size,
    grid_size,
    dropout_rate: float = 0.0,
    input_proj_factor: int = 2,
    upsample_y: bool = True,
    use_bias: bool = True,
    name: str = "cross_gating",
):
    """Cross-gating MLP block."""

    def apply(x, y):
        # Optionally upsample y (e.g. a coarser-scale feature map) so that both
        # streams share the same spatial resolution.
        if upsample_y:
            y = ConvT_up(
                filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
            )(y)

        # Project both streams to `features` channels.
        x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x)
        n, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )
        y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)

        shortcut_x = x
        shortcut_y = y

        # Gating weights derived from the x stream.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x)
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(x)
        x = tf.nn.gelu(x, approximate=True)
        gx = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_x",
        )(x)

        # Gating weights derived from the y stream.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y)
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(y)
        y = tf.nn.gelu(y, approximate=True)
        gy = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_y",
        )(y)

        # Cross gating: each stream is modulated by the other's gating weights,
        # followed by an output projection and a residual connection.
        y = y * gx
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(y)
        y = layers.Dropout(dropout_rate)(y)
        y = y + shortcut_y

        x = x * gy
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(x)
        x = layers.Dropout(dropout_rate)(x)
        x = x + y + shortcut_x
        return x, y

    return apply
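# A minimal usage sketch (illustrative only; the shapes, channel counts, and
# 8x8 block/grid sizes are assumptions, not values from a MAXIM config). With
# upsample_y=True, y is expected at half the spatial resolution of x and is
# upsampled by ConvT_up before the two streams gate each other.
#
#   x_in = layers.Input(shape=(64, 64, 32))   # fine-scale features
#   y_in = layers.Input(shape=(32, 32, 64))   # coarse-scale features
#   x_out, y_out = CrossGatingBlock(
#       features=32, block_size=(8, 8), grid_size=(8, 8), upsample_y=True,
#       name="cross_gating_demo",
#   )(x_in, y_in)  # both outputs: (None, 64, 64, 32)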