import functools

import tensorflow as tf
from tensorflow.keras import layers

from .attentions import RCAB
from .misc_gating import CrossGatingBlock, ResidualSplitHeadMultiAxisGmlpLayer

Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
ConvT_up = functools.partial(
    layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
)
Conv_down = functools.partial(
    layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
)
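# For reference, a minimal sketch of what these partials expand to (illustrative
# only; shapes assume NHWC inputs): Conv_down(filters=64) is equivalent to
# layers.Conv2D(64, kernel_size=(4, 4), strides=(2, 2), padding="same"), a strided
# convolution that halves the spatial resolution, while ConvT_up(filters=64)
# doubles it via layers.Conv2DTranspose(64, kernel_size=(2, 2), strides=(2, 2),
# padding="same").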
def UNetEncoderBlock(
    num_channels: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    lrelu_slope: float = 0.2,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    downsample: bool = True,
    use_global_mlp: bool = True,
    use_bias: bool = True,
    use_cross_gating: bool = False,
    name: str = "unet_encoder",
):
    """Encoder block in MAXIM.

    Returns `(x_down, x)` when `downsample=True`, otherwise just `x`.
    """

    def apply(x, skip=None, enc=None, dec=None):
        if skip is not None:
            x = tf.concat([x, skip], axis=-1)

        # convolution-in: project the (possibly concatenated) features to num_channels
        x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
        shortcut_long = x  # long-range residual around all gMLP/RCAB groups

        for i in range(num_groups):
            if use_global_mlp:
                x = ResidualSplitHeadMultiAxisGmlpLayer(
                    grid_size=grid_size,
                    block_size=block_size,
                    grid_gmlp_factor=grid_gmlp_factor,
                    block_gmlp_factor=block_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    use_bias=use_bias,
                    dropout_rate=dropout_rate,
                    name=f"{name}_SplitHeadMultiAxisGmlpLayer_{i}",
                )(x)
            x = RCAB(
                num_channels=num_channels,
                reduction=channels_reduction,
                lrelu_slope=lrelu_slope,
                use_bias=use_bias,
                name=f"{name}_channel_attention_block_1{i}",
            )(x)
        x = x + shortcut_long

        if enc is not None and dec is not None:
            assert use_cross_gating
            # Cross-gate the current features with the summed encoder/decoder signals.
            x, _ = CrossGatingBlock(
                features=num_channels,
                block_size=block_size,
                grid_size=grid_size,
                dropout_rate=dropout_rate,
                input_proj_factor=input_proj_factor,
                upsample_y=False,
                use_bias=use_bias,
                name=f"{name}_cross_gating_block",
            )(x, enc + dec)

        if downsample:
            # Strided 4x4 conv halves the spatial resolution; also return the
            # pre-downsample features for use as a skip connection.
            x_down = Conv_down(
                filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1"
            )(x)
            return x_down, x
        else:
            return x

    return apply
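# Usage sketch for UNetEncoderBlock (illustrative, not part of the module; the
# shapes and sizes below are assumptions, and block_size/grid_size must evenly
# divide the spatial dimensions):
#
#   x = tf.ones((1, 64, 64, 3))
#   encoder = UNetEncoderBlock(num_channels=32, block_size=(8, 8), grid_size=(8, 8))
#   x_down, bridge = encoder(x)
#   # x_down: (1, 32, 32, 32) -- halved resolution, fed to the next encoder stage
#   # bridge: (1, 64, 64, 32) -- full-resolution features, used as a decoder skip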
def UNetDecoderBlock(
    num_channels: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    lrelu_slope: float = 0.2,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    downsample: bool = True,
    use_global_mlp: bool = True,
    use_bias: bool = True,
    name: str = "unet_decoder",
):
    """Decoder block in MAXIM.

    Upsamples `x` 2x, then fuses it with the `bridge` skip features through an
    encoder block that runs without downsampling.
    """

    def apply(x, bridge=None):
        # 2x spatial upsampling via transposed convolution.
        x = ConvT_up(
            filters=num_channels, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
        )(x)
        # Reuse the encoder block (downsample=False) to merge the bridge features.
        x = UNetEncoderBlock(
            num_channels=num_channels,
            num_groups=num_groups,
            lrelu_slope=lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=block_gmlp_factor,
            grid_gmlp_factor=grid_gmlp_factor,
            channels_reduction=channels_reduction,
            use_global_mlp=use_global_mlp,
            dropout_rate=dropout_rate,
            downsample=False,
            use_bias=use_bias,
            name=f"{name}_UNetEncoderBlock_0",
        )(x, skip=bridge)
        return x

    return apply
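# Usage sketch for UNetDecoderBlock (illustrative; continues the encoder example
# above, so the shapes are assumptions):
#
#   decoder = UNetDecoderBlock(num_channels=32, block_size=(8, 8), grid_size=(8, 8))
#   y = decoder(x_down, bridge=bridge)
#   # ConvT_up restores (1, 64, 64, 32) from x_down, the bridge is concatenated
#   # along channels, and the inner encoder block (downsample=False) projects the
#   # result back to num_channels: y has shape (1, 64, 64, 32).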