import functools

import tensorflow as tf
from tensorflow.keras import layers

from .attentions import RCAB
from .misc_gating import CrossGatingBlock, ResidualSplitHeadMultiAxisGmlpLayer

Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
ConvT_up = functools.partial(
    layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
)
Conv_down = functools.partial(
    layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
)


def UNetEncoderBlock(
    num_channels: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    lrelu_slope: float = 0.2,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    downsample: bool = True,
    use_global_mlp: bool = True,
    use_bias: bool = True,
    use_cross_gating: bool = False,
    name: str = "unet_encoder",
):
    """Encoder block in MAXIM."""

    def apply(x, skip=None, enc=None, dec=None):
        # Fuse the skip connection (if any) along channels before projecting
        # back down to `num_channels`.
        if skip is not None:
            x = tf.concat([x, skip], axis=-1)

        # Convolution-in: 1x1 projection to the block's channel width.
        x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
        shortcut_long = x

        for i in range(num_groups):
            if use_global_mlp:
                # Multi-axis gated MLP: mixes spatial information both
                # locally (block axis) and globally (grid axis).
                x = ResidualSplitHeadMultiAxisGmlpLayer(
                    grid_size=grid_size,
                    block_size=block_size,
                    grid_gmlp_factor=grid_gmlp_factor,
                    block_gmlp_factor=block_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    use_bias=use_bias,
                    dropout_rate=dropout_rate,
                    name=f"{name}_SplitHeadMultiAxisGmlpLayer_{i}",
                )(x)
            # Residual channel attention block.
            x = RCAB(
                num_channels=num_channels,
                reduction=channels_reduction,
                lrelu_slope=lrelu_slope,
                use_bias=use_bias,
                name=f"{name}_channel_attention_block_1{i}",
            )(x)

        # Long residual connection around the whole group stack.
        x = x + shortcut_long

        if enc is not None and dec is not None:
            # Cross-gate the features with the summed encoder/decoder signals.
            assert use_cross_gating
            x, _ = CrossGatingBlock(
                features=num_channels,
                block_size=block_size,
                grid_size=grid_size,
                dropout_rate=dropout_rate,
                input_proj_factor=input_proj_factor,
                upsample_y=False,
                use_bias=use_bias,
                name=f"{name}_cross_gating_block",
            )(x, enc + dec)

        if downsample:
            # Strided 4x4 conv halves the spatial resolution; also return the
            # full-resolution features for use as a skip connection.
            x_down = Conv_down(
                filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1"
            )(x)
            return x_down, x
        else:
            return x

    return apply


def UNetDecoderBlock(
    num_channels: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    lrelu_slope: float = 0.2,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    downsample: bool = True,  # unused: the inner encoder always runs with downsample=False
    use_global_mlp: bool = True,
    use_bias: bool = True,
    name: str = "unet_decoder",
):
    """Decoder block in MAXIM."""

    def apply(x, bridge=None):
        # 2x transposed-conv upsampling, then an encoder block (without
        # downsampling) that fuses the `bridge` skip connection.
        x = ConvT_up(
            filters=num_channels, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
        )(x)
        x = UNetEncoderBlock(
            num_channels=num_channels,
            num_groups=num_groups,
            lrelu_slope=lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=block_gmlp_factor,
            grid_gmlp_factor=grid_gmlp_factor,
            channels_reduction=channels_reduction,
            use_global_mlp=use_global_mlp,
            dropout_rate=dropout_rate,
            downsample=False,
            use_bias=use_bias,
            name=f"{name}_UNetEncoderBlock_0",
        )(x, skip=bridge)
        return x

    return apply
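

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the upstream API): wires one
# encoder/decoder pair on random NHWC data to illustrate the shape contract.
# Assumptions: spatial dims are divisible by block_size/grid_size, and the
# block/grid sizes chosen here are placeholders. Because this module uses
# relative imports, run it via `python -m <package>.<this_module>` from the
# package root (package name is project-specific).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = tf.random.normal((1, 64, 64, 32))  # (batch, height, width, channels)

    encoder = UNetEncoderBlock(num_channels=32, block_size=(8, 8), grid_size=(8, 8))
    # Conv_down halves spatial dims: x_down is (1, 32, 32, 32);
    # skip keeps the full resolution (1, 64, 64, 32).
    x_down, skip = encoder(x)

    decoder = UNetDecoderBlock(num_channels=32, block_size=(8, 8), grid_size=(8, 8))
    # ConvT_up doubles spatial dims back to (1, 64, 64, 32) before fusing
    # the bridge features.
    y = decoder(x_down, bridge=skip)
    print(x_down.shape, skip.shape, y.shape)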