sayakpaul HF staff commited on
Commit
bcbf0c9
1 Parent(s): 541b585

initial files.

Browse files
1.png ADDED
15.png ADDED
55.png ADDED
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Rain13k Deraining Maxim
3
- emoji: 👀
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.7
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
+ title: Rain13K Deraining MAXIM
3
+ emoji: 💻
4
+ colorFrom: blue
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.5
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Some preprocessing utilities have been taken from:
3
+ https://github.com/google-research/maxim/blob/main/maxim/run_eval.py
4
+ """
5
+ import gradio as gr
6
+ import numpy as np
7
+ import tensorflow as tf
8
+ from huggingface_hub.keras_mixin import from_pretrained_keras
9
+ from PIL import Image
10
+
11
+ from create_maxim_model import Model
12
+ from maxim.configs import MAXIM_CONFIGS
13
+
14
+ CKPT = "sayakpaul/S-2_deraining_rain13k"
15
+ VARIANT = CKPT.split("/")[-1].split("_")[0]
16
+ _MODEL = from_pretrained_keras(CKPT)
17
+
18
+
19
+ def mod_padding_symmetric(image, factor=64):
20
+ """Padding the image to be divided by factor."""
21
+ height, width = image.shape[0], image.shape[1]
22
+ height_pad, width_pad = ((height + factor) // factor) * factor, (
23
+ (width + factor) // factor
24
+ ) * factor
25
+ padh = height_pad - height if height % factor != 0 else 0
26
+ padw = width_pad - width if width % factor != 0 else 0
27
+ image = tf.pad(
28
+ image, [(padh // 2, padh // 2), (padw // 2, padw // 2), (0, 0)], mode="REFLECT"
29
+ )
30
+ return image
31
+
32
+
33
+ def make_shape_even(image):
34
+ """Pad the image to have even shapes."""
35
+ height, width = image.shape[0], image.shape[1]
36
+ padh = 1 if height % 2 != 0 else 0
37
+ padw = 1 if width % 2 != 0 else 0
38
+ image = tf.pad(image, [(0, padh), (0, padw), (0, 0)], mode="REFLECT")
39
+ return image
40
+
41
+
42
+ def process_image(image: Image):
43
+ input_img = np.asarray(image) / 255.0
44
+ height, width = input_img.shape[0], input_img.shape[1]
45
+
46
+ # Padding images to have even shapes
47
+ input_img = make_shape_even(input_img)
48
+ height_even, width_even = input_img.shape[0], input_img.shape[1]
49
+
50
+ # padding images to be multiplies of 64
51
+ input_img = mod_padding_symmetric(input_img, factor=64)
52
+ input_img = tf.expand_dims(input_img, axis=0)
53
+ return input_img, height, width, height_even, width_even
54
+
55
+
56
+ def init_new_model(input_img):
57
+ configs = MAXIM_CONFIGS.get(VARIANT)
58
+ configs.update(
59
+ {
60
+ "variant": VARIANT,
61
+ "dropout_rate": 0.0,
62
+ "num_outputs": 3,
63
+ "use_bias": True,
64
+ "num_supervision_scales": 3,
65
+ }
66
+ )
67
+ configs.update({"input_resolution": (input_img.shape[1], input_img.shape[2])})
68
+ new_model = Model(**configs)
69
+ new_model.set_weights(_MODEL.get_weights())
70
+ return new_model
71
+
72
+
73
+ def infer(image):
74
+ preprocessed_image, height, width, height_even, width_even = process_image(image)
75
+ new_model = init_new_model(preprocessed_image)
76
+
77
+ preds = new_model.predict(preprocessed_image)
78
+ if isinstance(preds, list):
79
+ preds = preds[-1]
80
+ if isinstance(preds, list):
81
+ preds = preds[-1]
82
+
83
+ preds = np.array(preds[0], np.float32)
84
+
85
+ new_height, new_width = preds.shape[0], preds.shape[1]
86
+ h_start = new_height // 2 - height_even // 2
87
+ h_end = h_start + height
88
+ w_start = new_width // 2 - width_even // 2
89
+ w_end = w_start + width
90
+ preds = preds[h_start:h_end, w_start:w_end, :]
91
+
92
+ return Image.fromarray(np.array((np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)))
93
+
94
+
95
+ title = "Derain images containing rain drops or stripes."
96
+ description = f"The underlying model is [this](https://huggingface.co/{CKPT}). You can use the model to derain images containing rain drops or stripes. To quickly try out the model, you can choose from the available sample images below, or you can submit your own image. Not that, internally, the model is re-initialized based on the spatial dimensions of the input image and this process is time-consuming."
97
+
98
+ iface = gr.Interface(
99
+ infer,
100
+ inputs="image",
101
+ outputs=gr.Image().style(height=242),
102
+ title=title,
103
+ description=description,
104
+ allow_flagging="never",
105
+ examples=[["1.MP4.png"], ["15.png"], ["55.MP4.png"]],
106
+ )
107
+ iface.launch(debug=True)
create_maxim_model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tensorflow import keras
2
+
3
+ from maxim import maxim
4
+ from maxim.configs import MAXIM_CONFIGS
5
+
6
+
7
+ def Model(variant=None, input_resolution=(256, 256), **kw) -> keras.Model:
8
+ """Factory function to easily create a Model variant like "S".
9
+
10
+ Args:
11
+ variant: UNet model variants. Options: 'S-1' | 'S-2' | 'S-3'
12
+ | 'M-1' | 'M-2' | 'M-3'
13
+ input_resolution: Size of the input images.
14
+ **kw: Other UNet config dicts.
15
+
16
+ Returns:
17
+ The MAXIM model.
18
+ """
19
+
20
+ if variant is not None:
21
+ config = MAXIM_CONFIGS[variant]
22
+ for k, v in config.items():
23
+ kw.setdefault(k, v)
24
+
25
+ if "variant" in kw:
26
+ _ = kw.pop("variant")
27
+ if "input_resolution" in kw:
28
+ _ = kw.pop("input_resolution")
29
+ model_name = kw.pop("name")
30
+
31
+ maxim_model = maxim.MAXIM(**kw)
32
+
33
+ inputs = keras.Input((*input_resolution, 3))
34
+ outputs = maxim_model(inputs)
35
+ final_model = keras.Model(inputs, outputs, name=f"{model_name}_model")
36
+
37
+ return final_model
maxim/__init__.py ADDED
File without changes
maxim/blocks/__init__.py ADDED
File without changes
maxim/blocks/attentions.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import layers
5
+
6
+ from .others import MlpBlock
7
+
8
+ Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
9
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
10
+
11
+
12
+ def CALayer(
13
+ num_channels: int,
14
+ reduction: int = 4,
15
+ use_bias: bool = True,
16
+ name: str = "channel_attention",
17
+ ):
18
+ """Squeeze-and-excitation block for channel attention.
19
+
20
+ ref: https://arxiv.org/abs/1709.01507
21
+ """
22
+
23
+ def apply(x):
24
+ # 2D global average pooling
25
+ y = layers.GlobalAvgPool2D(keepdims=True)(x)
26
+ # Squeeze (in Squeeze-Excitation)
27
+ y = Conv1x1(
28
+ filters=num_channels // reduction, use_bias=use_bias, name=f"{name}_Conv_0"
29
+ )(y)
30
+ y = tf.nn.relu(y)
31
+ # Excitation (in Squeeze-Excitation)
32
+ y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
33
+ y = tf.nn.sigmoid(y)
34
+ return x * y
35
+
36
+ return apply
37
+
38
+
39
+ def RCAB(
40
+ num_channels: int,
41
+ reduction: int = 4,
42
+ lrelu_slope: float = 0.2,
43
+ use_bias: bool = True,
44
+ name: str = "residual_ca",
45
+ ):
46
+ """Residual channel attention block. Contains LN,Conv,lRelu,Conv,SELayer."""
47
+
48
+ def apply(x):
49
+ shortcut = x
50
+ x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
51
+ x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1")(x)
52
+ x = tf.nn.leaky_relu(x, alpha=lrelu_slope)
53
+ x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2")(x)
54
+ x = CALayer(
55
+ num_channels=num_channels,
56
+ reduction=reduction,
57
+ use_bias=use_bias,
58
+ name=f"{name}_channel_attention",
59
+ )(x)
60
+ return x + shortcut
61
+
62
+ return apply
63
+
64
+
65
+ def RDCAB(
66
+ num_channels: int,
67
+ reduction: int = 16,
68
+ use_bias: bool = True,
69
+ dropout_rate: float = 0.0,
70
+ name: str = "rdcab",
71
+ ):
72
+ """Residual dense channel attention block. Used in Bottlenecks."""
73
+
74
+ def apply(x):
75
+ y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
76
+ y = MlpBlock(
77
+ mlp_dim=num_channels,
78
+ dropout_rate=dropout_rate,
79
+ use_bias=use_bias,
80
+ name=f"{name}_channel_mixing",
81
+ )(y)
82
+ y = CALayer(
83
+ num_channels=num_channels,
84
+ reduction=reduction,
85
+ use_bias=use_bias,
86
+ name=f"{name}_channel_attention",
87
+ )(y)
88
+ x = x + y
89
+ return x
90
+
91
+ return apply
92
+
93
+
94
+ def SAM(
95
+ num_channels: int,
96
+ output_channels: int = 3,
97
+ use_bias: bool = True,
98
+ name: str = "sam",
99
+ ):
100
+
101
+ """Supervised attention module for multi-stage training.
102
+
103
+ Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet
104
+ """
105
+
106
+ def apply(x, x_image):
107
+ """Apply the SAM module to the input and num_channels.
108
+ Args:
109
+ x: the output num_channels from UNet decoder with shape (h, w, c)
110
+ x_image: the input image with shape (h, w, 3)
111
+ Returns:
112
+ A tuple of tensors (x1, image) where (x1) is the sam num_channels used for the
113
+ next stage, and (image) is the output restored image at current stage.
114
+ """
115
+ # Get num_channels
116
+ x1 = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
117
+
118
+ # Output restored image X_s
119
+ if output_channels == 3:
120
+ image = (
121
+ Conv3x3(
122
+ filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
123
+ )(x)
124
+ + x_image
125
+ )
126
+ else:
127
+ image = Conv3x3(
128
+ filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
129
+ )(x)
130
+
131
+ # Get attention maps for num_channels
132
+ x2 = tf.nn.sigmoid(
133
+ Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
134
+ )
135
+
136
+ # Get attended feature maps
137
+ x1 = x1 * x2
138
+
139
+ # Residual connection
140
+ x1 = x1 + x
141
+ return x1, image
142
+
143
+ return apply
maxim/blocks/block_gating.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import backend as K
3
+ from tensorflow.keras import layers
4
+
5
+ from ..layers import BlockImages, SwapAxes, UnblockImages
6
+
7
+
8
+ def BlockGatingUnit(use_bias: bool = True, name: str = "block_gating_unit"):
9
+ """A SpatialGatingUnit as defined in the gMLP paper.
10
+
11
+ The 'spatial' dim is defined as the **second last**.
12
+ If applied on other dims, you should swapaxes first.
13
+ """
14
+
15
+ def apply(x):
16
+ u, v = tf.split(x, 2, axis=-1)
17
+ v = layers.LayerNormalization(
18
+ epsilon=1e-06, name=f"{name}_intermediate_layernorm"
19
+ )(v)
20
+ n = K.int_shape(x)[-2] # get spatial dim
21
+ v = SwapAxes()(v, -1, -2)
22
+ v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
23
+ v = SwapAxes()(v, -1, -2)
24
+ return u * (v + 1.0)
25
+
26
+ return apply
27
+
28
+
29
+ def BlockGmlpLayer(
30
+ block_size,
31
+ use_bias: bool = True,
32
+ factor: int = 2,
33
+ dropout_rate: float = 0.0,
34
+ name: str = "block_gmlp",
35
+ ):
36
+ """Block gMLP layer that performs local mixing of tokens."""
37
+
38
+ def apply(x):
39
+ n, h, w, num_channels = (
40
+ K.int_shape(x)[0],
41
+ K.int_shape(x)[1],
42
+ K.int_shape(x)[2],
43
+ K.int_shape(x)[3],
44
+ )
45
+ fh, fw = block_size
46
+ gh, gw = h // fh, w // fw
47
+ x = BlockImages()(x, patch_size=(fh, fw))
48
+ # MLP2: Local (block) mixing part, provides within-block communication.
49
+ y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
50
+ y = layers.Dense(
51
+ num_channels * factor,
52
+ use_bias=use_bias,
53
+ name=f"{name}_in_project",
54
+ )(y)
55
+ y = tf.nn.gelu(y, approximate=True)
56
+ y = BlockGatingUnit(use_bias=use_bias, name=f"{name}_BlockGatingUnit")(y)
57
+ y = layers.Dense(
58
+ num_channels,
59
+ use_bias=use_bias,
60
+ name=f"{name}_out_project",
61
+ )(y)
62
+ y = layers.Dropout(dropout_rate)(y)
63
+ x = x + y
64
+ x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
65
+ return x
66
+
67
+ return apply
maxim/blocks/bottleneck.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ from tensorflow.keras import layers
4
+
5
+ from .attentions import RDCAB
6
+ from .misc_gating import ResidualSplitHeadMultiAxisGmlpLayer
7
+
8
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
9
+
10
+
11
+ def BottleneckBlock(
12
+ features: int,
13
+ block_size,
14
+ grid_size,
15
+ num_groups: int = 1,
16
+ block_gmlp_factor: int = 2,
17
+ grid_gmlp_factor: int = 2,
18
+ input_proj_factor: int = 2,
19
+ channels_reduction: int = 4,
20
+ dropout_rate: float = 0.0,
21
+ use_bias: bool = True,
22
+ name: str = "bottleneck_block",
23
+ ):
24
+ """The bottleneck block consisting of multi-axis gMLP block and RDCAB."""
25
+
26
+ def apply(x):
27
+ # input projection
28
+ x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_input_proj")(x)
29
+ shortcut_long = x
30
+
31
+ for i in range(num_groups):
32
+ x = ResidualSplitHeadMultiAxisGmlpLayer(
33
+ grid_size=grid_size,
34
+ block_size=block_size,
35
+ grid_gmlp_factor=grid_gmlp_factor,
36
+ block_gmlp_factor=block_gmlp_factor,
37
+ input_proj_factor=input_proj_factor,
38
+ use_bias=use_bias,
39
+ dropout_rate=dropout_rate,
40
+ name=f"{name}_SplitHeadMultiAxisGmlpLayer_{i}",
41
+ )(x)
42
+ # Channel-mixing part, which provides within-patch communication.
43
+ x = RDCAB(
44
+ num_channels=features,
45
+ reduction=channels_reduction,
46
+ use_bias=use_bias,
47
+ name=f"{name}_channel_attention_block_1_{i}",
48
+ )(x)
49
+
50
+ # long skip-connect
51
+ x = x + shortcut_long
52
+ return x
53
+
54
+ return apply
maxim/blocks/grid_gating.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import backend as K
3
+ from tensorflow.keras import layers
4
+
5
+ from ..layers import BlockImages, SwapAxes, UnblockImages
6
+
7
+
8
+ def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
9
+ """A SpatialGatingUnit as defined in the gMLP paper.
10
+
11
+ The 'spatial' dim is defined as the second last.
12
+ If applied on other dims, you should swapaxes first.
13
+ """
14
+
15
+ def apply(x):
16
+ u, v = tf.split(x, 2, axis=-1)
17
+ v = layers.LayerNormalization(
18
+ epsilon=1e-06, name=f"{name}_intermediate_layernorm"
19
+ )(v)
20
+ n = K.int_shape(x)[-3] # get spatial dim
21
+ v = SwapAxes()(v, -1, -3)
22
+ v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
23
+ v = SwapAxes()(v, -1, -3)
24
+ return u * (v + 1.0)
25
+
26
+ return apply
27
+
28
+
29
+ def GridGmlpLayer(
30
+ grid_size,
31
+ use_bias: bool = True,
32
+ factor: int = 2,
33
+ dropout_rate: float = 0.0,
34
+ name: str = "grid_gmlp",
35
+ ):
36
+ """Grid gMLP layer that performs global mixing of tokens."""
37
+
38
+ def apply(x):
39
+ n, h, w, num_channels = (
40
+ K.int_shape(x)[0],
41
+ K.int_shape(x)[1],
42
+ K.int_shape(x)[2],
43
+ K.int_shape(x)[3],
44
+ )
45
+ gh, gw = grid_size
46
+ fh, fw = h // gh, w // gw
47
+
48
+ x = BlockImages()(x, patch_size=(fh, fw))
49
+ # gMLP1: Global (grid) mixing part, provides global grid communication.
50
+ y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
51
+ y = layers.Dense(
52
+ num_channels * factor,
53
+ use_bias=use_bias,
54
+ name=f"{name}_in_project",
55
+ )(y)
56
+ y = tf.nn.gelu(y, approximate=True)
57
+ y = GridGatingUnit(use_bias=use_bias, name=f"{name}_GridGatingUnit")(y)
58
+ y = layers.Dense(
59
+ num_channels,
60
+ use_bias=use_bias,
61
+ name=f"{name}_out_project",
62
+ )(y)
63
+ y = layers.Dropout(dropout_rate)(y)
64
+ x = x + y
65
+ x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
66
+ return x
67
+
68
+ return apply
maxim/blocks/misc_gating.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras import layers
6
+
7
+ from ..layers import BlockImages, SwapAxes, UnblockImages
8
+ from .block_gating import BlockGmlpLayer
9
+ from .grid_gating import GridGmlpLayer
10
+
11
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
12
+ Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
13
+ ConvT_up = functools.partial(
14
+ layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
15
+ )
16
+ Conv_down = functools.partial(
17
+ layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
18
+ )
19
+
20
+
21
+ def ResidualSplitHeadMultiAxisGmlpLayer(
22
+ block_size,
23
+ grid_size,
24
+ block_gmlp_factor: int = 2,
25
+ grid_gmlp_factor: int = 2,
26
+ input_proj_factor: int = 2,
27
+ use_bias: bool = True,
28
+ dropout_rate: float = 0.0,
29
+ name: str = "residual_split_head_maxim",
30
+ ):
31
+ """The multi-axis gated MLP block."""
32
+
33
+ def apply(x):
34
+ shortcut = x
35
+ n, h, w, num_channels = (
36
+ K.int_shape(x)[0],
37
+ K.int_shape(x)[1],
38
+ K.int_shape(x)[2],
39
+ K.int_shape(x)[3],
40
+ )
41
+ x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
42
+
43
+ x = layers.Dense(
44
+ int(num_channels) * input_proj_factor,
45
+ use_bias=use_bias,
46
+ name=f"{name}_in_project",
47
+ )(x)
48
+ x = tf.nn.gelu(x, approximate=True)
49
+
50
+ u, v = tf.split(x, 2, axis=-1)
51
+
52
+ # GridGMLPLayer
53
+ u = GridGmlpLayer(
54
+ grid_size=grid_size,
55
+ factor=grid_gmlp_factor,
56
+ use_bias=use_bias,
57
+ dropout_rate=dropout_rate,
58
+ name=f"{name}_GridGmlpLayer",
59
+ )(u)
60
+
61
+ # BlockGMLPLayer
62
+ v = BlockGmlpLayer(
63
+ block_size=block_size,
64
+ factor=block_gmlp_factor,
65
+ use_bias=use_bias,
66
+ dropout_rate=dropout_rate,
67
+ name=f"{name}_BlockGmlpLayer",
68
+ )(v)
69
+
70
+ x = tf.concat([u, v], axis=-1)
71
+
72
+ x = layers.Dense(
73
+ num_channels,
74
+ use_bias=use_bias,
75
+ name=f"{name}_out_project",
76
+ )(x)
77
+ x = layers.Dropout(dropout_rate)(x)
78
+ x = x + shortcut
79
+ return x
80
+
81
+ return apply
82
+
83
+
84
+ def GetSpatialGatingWeights(
85
+ features: int,
86
+ block_size,
87
+ grid_size,
88
+ input_proj_factor: int = 2,
89
+ dropout_rate: float = 0.0,
90
+ use_bias: bool = True,
91
+ name: str = "spatial_gating",
92
+ ):
93
+
94
+ """Get gating weights for cross-gating MLP block."""
95
+
96
+ def apply(x):
97
+ n, h, w, num_channels = (
98
+ K.int_shape(x)[0],
99
+ K.int_shape(x)[1],
100
+ K.int_shape(x)[2],
101
+ K.int_shape(x)[3],
102
+ )
103
+
104
+ # input projection
105
+ x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
106
+ x = layers.Dense(
107
+ num_channels * input_proj_factor,
108
+ use_bias=use_bias,
109
+ name=f"{name}_in_project",
110
+ )(x)
111
+ x = tf.nn.gelu(x, approximate=True)
112
+ u, v = tf.split(x, 2, axis=-1)
113
+
114
+ # Get grid MLP weights
115
+ gh, gw = grid_size
116
+ fh, fw = h // gh, w // gw
117
+ u = BlockImages()(u, patch_size=(fh, fw))
118
+ dim_u = K.int_shape(u)[-3]
119
+ u = SwapAxes()(u, -1, -3)
120
+ u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u)
121
+ u = SwapAxes()(u, -1, -3)
122
+ u = UnblockImages()(u, grid_size=(gh, gw), patch_size=(fh, fw))
123
+
124
+ # Get Block MLP weights
125
+ fh, fw = block_size
126
+ gh, gw = h // fh, w // fw
127
+ v = BlockImages()(v, patch_size=(fh, fw))
128
+ dim_v = K.int_shape(v)[-2]
129
+ v = SwapAxes()(v, -1, -2)
130
+ v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v)
131
+ v = SwapAxes()(v, -1, -2)
132
+ v = UnblockImages()(v, grid_size=(gh, gw), patch_size=(fh, fw))
133
+
134
+ x = tf.concat([u, v], axis=-1)
135
+ x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x)
136
+ x = layers.Dropout(dropout_rate)(x)
137
+ return x
138
+
139
+ return apply
140
+
141
+
142
+ def CrossGatingBlock(
143
+ features: int,
144
+ block_size,
145
+ grid_size,
146
+ dropout_rate: float = 0.0,
147
+ input_proj_factor: int = 2,
148
+ upsample_y: bool = True,
149
+ use_bias: bool = True,
150
+ name: str = "cross_gating",
151
+ ):
152
+
153
+ """Cross-gating MLP block."""
154
+
155
+ def apply(x, y):
156
+ # Upscale Y signal, y is the gating signal.
157
+ if upsample_y:
158
+ y = ConvT_up(
159
+ filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
160
+ )(y)
161
+
162
+ x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x)
163
+ n, h, w, num_channels = (
164
+ K.int_shape(x)[0],
165
+ K.int_shape(x)[1],
166
+ K.int_shape(x)[2],
167
+ K.int_shape(x)[3],
168
+ )
169
+
170
+ y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
171
+
172
+ shortcut_x = x
173
+ shortcut_y = y
174
+
175
+ # Get gating weights from X
176
+ x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x)
177
+ x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(x)
178
+ x = tf.nn.gelu(x, approximate=True)
179
+ gx = GetSpatialGatingWeights(
180
+ features=num_channels,
181
+ block_size=block_size,
182
+ grid_size=grid_size,
183
+ dropout_rate=dropout_rate,
184
+ use_bias=use_bias,
185
+ name=f"{name}_SplitHeadMultiAxisGating_x",
186
+ )(x)
187
+
188
+ # Get gating weights from Y
189
+ y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y)
190
+ y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(y)
191
+ y = tf.nn.gelu(y, approximate=True)
192
+ gy = GetSpatialGatingWeights(
193
+ features=num_channels,
194
+ block_size=block_size,
195
+ grid_size=grid_size,
196
+ dropout_rate=dropout_rate,
197
+ use_bias=use_bias,
198
+ name=f"{name}_SplitHeadMultiAxisGating_y",
199
+ )(y)
200
+
201
+ # Apply cross gating: X = X * GY, Y = Y * GX
202
+ y = y * gx
203
+ y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(y)
204
+ y = layers.Dropout(dropout_rate)(y)
205
+ y = y + shortcut_y
206
+
207
+ x = x * gy # gating x using y
208
+ x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(x)
209
+ x = layers.Dropout(dropout_rate)(x)
210
+ x = x + y + shortcut_x # get all aggregated signals
211
+ return x, y
212
+
213
+ return apply
maxim/blocks/others.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras import layers
6
+
7
+ from ..layers import Resizing
8
+
9
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
10
+
11
+
12
+ def MlpBlock(
13
+ mlp_dim: int,
14
+ dropout_rate: float = 0.0,
15
+ use_bias: bool = True,
16
+ name: str = "mlp_block",
17
+ ):
18
+ """A 1-hidden-layer MLP block, applied over the last dimension."""
19
+
20
+ def apply(x):
21
+ d = K.int_shape(x)[-1]
22
+ x = layers.Dense(mlp_dim, use_bias=use_bias, name=f"{name}_Dense_0")(x)
23
+ x = tf.nn.gelu(x, approximate=True)
24
+ x = layers.Dropout(dropout_rate)(x)
25
+ x = layers.Dense(d, use_bias=use_bias, name=f"{name}_Dense_1")(x)
26
+ return x
27
+
28
+ return apply
29
+
30
+
31
+ def UpSampleRatio(
32
+ num_channels: int, ratio: float, use_bias: bool = True, name: str = "upsample"
33
+ ):
34
+ """Upsample features given a ratio > 0."""
35
+
36
+ def apply(x):
37
+ n, h, w, c = (
38
+ K.int_shape(x)[0],
39
+ K.int_shape(x)[1],
40
+ K.int_shape(x)[2],
41
+ K.int_shape(x)[3],
42
+ )
43
+
44
+ # Following `jax.image.resize()`
45
+ x = Resizing(
46
+ height=int(h * ratio),
47
+ width=int(w * ratio),
48
+ method="bilinear",
49
+ antialias=True,
50
+ name=f"{name}_resizing_{K.get_uid('Resizing')}",
51
+ )(x)
52
+
53
+ x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
54
+ return x
55
+
56
+ return apply
maxim/blocks/unet.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import layers
5
+
6
+ from .attentions import RCAB
7
+ from .misc_gating import CrossGatingBlock, ResidualSplitHeadMultiAxisGmlpLayer
8
+
9
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
10
+ Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
11
+ ConvT_up = functools.partial(
12
+ layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
13
+ )
14
+ Conv_down = functools.partial(
15
+ layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
16
+ )
17
+
18
+
19
+ def UNetEncoderBlock(
20
+ num_channels: int,
21
+ block_size,
22
+ grid_size,
23
+ num_groups: int = 1,
24
+ lrelu_slope: float = 0.2,
25
+ block_gmlp_factor: int = 2,
26
+ grid_gmlp_factor: int = 2,
27
+ input_proj_factor: int = 2,
28
+ channels_reduction: int = 4,
29
+ dropout_rate: float = 0.0,
30
+ downsample: bool = True,
31
+ use_global_mlp: bool = True,
32
+ use_bias: bool = True,
33
+ use_cross_gating: bool = False,
34
+ name: str = "unet_encoder",
35
+ ):
36
+ """Encoder block in MAXIM."""
37
+
38
+ def apply(x, skip=None, enc=None, dec=None):
39
+ if skip is not None:
40
+ x = tf.concat([x, skip], axis=-1)
41
+
42
+ # convolution-in
43
+ x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
44
+ shortcut_long = x
45
+
46
+ for i in range(num_groups):
47
+ if use_global_mlp:
48
+ x = ResidualSplitHeadMultiAxisGmlpLayer(
49
+ grid_size=grid_size,
50
+ block_size=block_size,
51
+ grid_gmlp_factor=grid_gmlp_factor,
52
+ block_gmlp_factor=block_gmlp_factor,
53
+ input_proj_factor=input_proj_factor,
54
+ use_bias=use_bias,
55
+ dropout_rate=dropout_rate,
56
+ name=f"{name}_SplitHeadMultiAxisGmlpLayer_{i}",
57
+ )(x)
58
+ x = RCAB(
59
+ num_channels=num_channels,
60
+ reduction=channels_reduction,
61
+ lrelu_slope=lrelu_slope,
62
+ use_bias=use_bias,
63
+ name=f"{name}_channel_attention_block_1{i}",
64
+ )(x)
65
+
66
+ x = x + shortcut_long
67
+
68
+ if enc is not None and dec is not None:
69
+ assert use_cross_gating
70
+ x, _ = CrossGatingBlock(
71
+ features=num_channels,
72
+ block_size=block_size,
73
+ grid_size=grid_size,
74
+ dropout_rate=dropout_rate,
75
+ input_proj_factor=input_proj_factor,
76
+ upsample_y=False,
77
+ use_bias=use_bias,
78
+ name=f"{name}_cross_gating_block",
79
+ )(x, enc + dec)
80
+
81
+ if downsample:
82
+ x_down = Conv_down(
83
+ filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1"
84
+ )(x)
85
+ return x_down, x
86
+ else:
87
+ return x
88
+
89
+ return apply
90
+
91
+
92
+ def UNetDecoderBlock(
93
+ num_channels: int,
94
+ block_size,
95
+ grid_size,
96
+ num_groups: int = 1,
97
+ lrelu_slope: float = 0.2,
98
+ block_gmlp_factor: int = 2,
99
+ grid_gmlp_factor: int = 2,
100
+ input_proj_factor: int = 2,
101
+ channels_reduction: int = 4,
102
+ dropout_rate: float = 0.0,
103
+ downsample: bool = True,
104
+ use_global_mlp: bool = True,
105
+ use_bias: bool = True,
106
+ name: str = "unet_decoder",
107
+ ):
108
+
109
+ """Decoder block in MAXIM."""
110
+
111
+ def apply(x, bridge=None):
112
+ x = ConvT_up(
113
+ filters=num_channels, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
114
+ )(x)
115
+ x = UNetEncoderBlock(
116
+ num_channels=num_channels,
117
+ num_groups=num_groups,
118
+ lrelu_slope=lrelu_slope,
119
+ block_size=block_size,
120
+ grid_size=grid_size,
121
+ block_gmlp_factor=block_gmlp_factor,
122
+ grid_gmlp_factor=grid_gmlp_factor,
123
+ channels_reduction=channels_reduction,
124
+ use_global_mlp=use_global_mlp,
125
+ dropout_rate=dropout_rate,
126
+ downsample=False,
127
+ use_bias=use_bias,
128
+ name=f"{name}_UNetEncoderBlock_0",
129
+ )(x, skip=bridge)
130
+
131
+ return x
132
+
133
+ return apply
maxim/configs.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MAXIM_CONFIGS = {
2
+ # params: 6.108515000000001 M, GFLOPS: 93.163716608
3
+ "S-1": {
4
+ "features": 32,
5
+ "depth": 3,
6
+ "num_stages": 1,
7
+ "num_groups": 2,
8
+ "num_bottleneck_blocks": 2,
9
+ "block_gmlp_factor": 2,
10
+ "grid_gmlp_factor": 2,
11
+ "input_proj_factor": 2,
12
+ "channels_reduction": 4,
13
+ "name": "s1",
14
+ },
15
+ # params: 13.35383 M, GFLOPS: 206.743273472
16
+ "S-2": {
17
+ "features": 32,
18
+ "depth": 3,
19
+ "num_stages": 2,
20
+ "num_groups": 2,
21
+ "num_bottleneck_blocks": 2,
22
+ "block_gmlp_factor": 2,
23
+ "grid_gmlp_factor": 2,
24
+ "input_proj_factor": 2,
25
+ "channels_reduction": 4,
26
+ "name": "s2",
27
+ },
28
+ # params: 20.599145 M, GFLOPS: 320.32194560000005
29
+ "S-3": {
30
+ "features": 32,
31
+ "depth": 3,
32
+ "num_stages": 3,
33
+ "num_groups": 2,
34
+ "num_bottleneck_blocks": 2,
35
+ "block_gmlp_factor": 2,
36
+ "grid_gmlp_factor": 2,
37
+ "input_proj_factor": 2,
38
+ "channels_reduction": 4,
39
+ "name": "s3",
40
+ },
41
+ # params: 19.361219000000002 M, 308.495712256 GFLOPs
42
+ "M-1": {
43
+ "features": 64,
44
+ "depth": 3,
45
+ "num_stages": 1,
46
+ "num_groups": 2,
47
+ "num_bottleneck_blocks": 2,
48
+ "block_gmlp_factor": 2,
49
+ "grid_gmlp_factor": 2,
50
+ "input_proj_factor": 2,
51
+ "channels_reduction": 4,
52
+ "name": "m1",
53
+ },
54
+ # params: 40.83911 M, 675.25541888 GFLOPs
55
+ "M-2": {
56
+ "features": 64,
57
+ "depth": 3,
58
+ "num_stages": 2,
59
+ "num_groups": 2,
60
+ "num_bottleneck_blocks": 2,
61
+ "block_gmlp_factor": 2,
62
+ "grid_gmlp_factor": 2,
63
+ "input_proj_factor": 2,
64
+ "channels_reduction": 4,
65
+ "name": "m2",
66
+ },
67
+ # params: 62.317001 M, 1042.014666752 GFLOPs
68
+ "M-3": {
69
+ "features": 64,
70
+ "depth": 3,
71
+ "num_stages": 3,
72
+ "num_groups": 2,
73
+ "num_bottleneck_blocks": 2,
74
+ "block_gmlp_factor": 2,
75
+ "grid_gmlp_factor": 2,
76
+ "input_proj_factor": 2,
77
+ "channels_reduction": 4,
78
+ "name": "m3",
79
+ },
80
+ }
maxim/layers.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ import tensorflow as tf
3
+ from tensorflow.experimental import numpy as tnp
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras import layers
6
+
7
+
8
+ @tf.keras.utils.register_keras_serializable("maxim")
9
+ class BlockImages(layers.Layer):
10
+ def __init__(self, **kwargs):
11
+ super().__init__(**kwargs)
12
+
13
+ def call(self, x, patch_size):
14
+ bs, h, w, num_channels = (
15
+ K.int_shape(x)[0],
16
+ K.int_shape(x)[1],
17
+ K.int_shape(x)[2],
18
+ K.int_shape(x)[3],
19
+ )
20
+
21
+ grid_height, grid_width = h // patch_size[0], w // patch_size[1]
22
+
23
+ x = einops.rearrange(
24
+ x,
25
+ "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
26
+ gh=grid_height,
27
+ gw=grid_width,
28
+ fh=patch_size[0],
29
+ fw=patch_size[1],
30
+ )
31
+
32
+ return x
33
+
34
+ def get_config(self):
35
+ config = super().get_config().copy()
36
+ return config
37
+
38
+
39
+ @tf.keras.utils.register_keras_serializable("maxim")
40
+ class UnblockImages(layers.Layer):
41
+ def __init__(self, **kwargs):
42
+ super().__init__(**kwargs)
43
+
44
+ def call(self, x, grid_size, patch_size):
45
+ x = einops.rearrange(
46
+ x,
47
+ "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c",
48
+ gh=grid_size[0],
49
+ gw=grid_size[1],
50
+ fh=patch_size[0],
51
+ fw=patch_size[1],
52
+ )
53
+
54
+ return x
55
+
56
+ def get_config(self):
57
+ config = super().get_config().copy()
58
+ return config
59
+
60
+
61
+ @tf.keras.utils.register_keras_serializable("maxim")
62
+ class SwapAxes(layers.Layer):
63
+ def __init__(self, **kwargs):
64
+ super().__init__(**kwargs)
65
+
66
+ def call(self, x, axis_one, axis_two):
67
+ return tnp.swapaxes(x, axis_one, axis_two)
68
+
69
+ def get_config(self):
70
+ config = super().get_config().copy()
71
+ return config
72
+
73
+
74
+ @tf.keras.utils.register_keras_serializable("maxim")
75
+ class Resizing(layers.Layer):
76
+ def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
77
+ super().__init__(**kwargs)
78
+ self.height = height
79
+ self.width = width
80
+ self.antialias = antialias
81
+ self.method = method
82
+
83
+ def call(self, x):
84
+ return tf.image.resize(
85
+ x,
86
+ size=(self.height, self.width),
87
+ antialias=self.antialias,
88
+ method=self.method,
89
+ )
90
+
91
+ def get_config(self):
92
+ config = super().get_config().copy()
93
+ config.update(
94
+ {
95
+ "height": self.height,
96
+ "width": self.width,
97
+ "antialias": self.antialias,
98
+ "method": self.method,
99
+ }
100
+ )
101
+ return config
maxim/maxim.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras import layers
6
+
7
+ from .blocks.attentions import SAM
8
+ from .blocks.bottleneck import BottleneckBlock
9
+ from .blocks.misc_gating import CrossGatingBlock
10
+ from .blocks.others import UpSampleRatio
11
+ from .blocks.unet import UNetDecoderBlock, UNetEncoderBlock
12
+ from .layers import Resizing
13
+
14
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
15
+ Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
16
+ ConvT_up = functools.partial(
17
+ layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
18
+ )
19
+ Conv_down = functools.partial(
20
+ layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
21
+ )
22
+
23
+
24
+ def MAXIM(
25
+ features: int = 64,
26
+ depth: int = 3,
27
+ num_stages: int = 2,
28
+ num_groups: int = 1,
29
+ use_bias: bool = True,
30
+ num_supervision_scales: int = 1,
31
+ lrelu_slope: float = 0.2,
32
+ use_global_mlp: bool = True,
33
+ use_cross_gating: bool = True,
34
+ high_res_stages: int = 2,
35
+ block_size_hr=(16, 16),
36
+ block_size_lr=(8, 8),
37
+ grid_size_hr=(16, 16),
38
+ grid_size_lr=(8, 8),
39
+ num_bottleneck_blocks: int = 1,
40
+ block_gmlp_factor: int = 2,
41
+ grid_gmlp_factor: int = 2,
42
+ input_proj_factor: int = 2,
43
+ channels_reduction: int = 4,
44
+ num_outputs: int = 3,
45
+ dropout_rate: float = 0.0,
46
+ ):
47
+ """The MAXIM model function with multi-stage and multi-scale supervision.
48
+
49
+ For more model details, please check the CVPR paper:
50
+ MAXIM: MUlti-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)
51
+
52
+ Attributes:
53
+ features: initial hidden dimension for the input resolution.
54
+ depth: the number of downsampling depth for the model.
55
+ num_stages: how many stages to use. It will also affects the output list.
56
+ num_groups: how many blocks each stage contains.
57
+ use_bias: whether to use bias in all the conv/mlp layers.
58
+ num_supervision_scales: the number of desired supervision scales.
59
+ lrelu_slope: the negative slope parameter in leaky_relu layers.
60
+ use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in each
61
+ layer.
62
+ use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
63
+ skip connections and multi-stage feature fusion layers.
64
+ high_res_stages: how many stages are specificied as high-res stages. The
65
+ rest (depth - high_res_stages) are called low_res_stages.
66
+ block_size_hr: the block_size parameter for high-res stages.
67
+ block_size_lr: the block_size parameter for low-res stages.
68
+ grid_size_hr: the grid_size parameter for high-res stages.
69
+ grid_size_lr: the grid_size parameter for low-res stages.
70
+ num_bottleneck_blocks: how many bottleneck blocks.
71
+ block_gmlp_factor: the input projection factor for block_gMLP layers.
72
+ grid_gmlp_factor: the input projection factor for grid_gMLP layers.
73
+ input_proj_factor: the input projection factor for the MAB block.
74
+ channels_reduction: the channel reduction factor for SE layer.
75
+ num_outputs: the output channels.
76
+ dropout_rate: Dropout rate.
77
+
78
+ Returns:
79
+ The output contains a list of arrays consisting of multi-stage multi-scale
80
+ outputs. For example, if num_stages = num_supervision_scales = 3 (the
81
+ model used in the paper), the output specs are: outputs =
82
+ [[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
83
+ [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
84
+ [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]
85
+ The final output can be retrieved by outputs[-1][-1].
86
+ """
87
+
88
+ def apply(x):
89
+ n, h, w, c = (
90
+ K.int_shape(x)[0],
91
+ K.int_shape(x)[1],
92
+ K.int_shape(x)[2],
93
+ K.int_shape(x)[3],
94
+ ) # input image shape
95
+
96
+ shortcuts = []
97
+ shortcuts.append(x)
98
+
99
+ # Get multi-scale input images
100
+ for i in range(1, num_supervision_scales):
101
+ resizing_layer = Resizing(
102
+ height=h // (2 ** i),
103
+ width=w // (2 ** i),
104
+ method="nearest",
105
+ antialias=True, # Following `jax.image.resize()`.
106
+ name=f"initial_resizing_{K.get_uid('Resizing')}",
107
+ )
108
+ shortcuts.append(resizing_layer(x))
109
+
110
+ # store outputs from all stages and all scales
111
+ # Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)], # Stage-1 outputs
112
+ # [(64, 64, 3), (128, 128, 3), (256, 256, 3)],] # Stage-2 outputs
113
+ outputs_all = []
114
+ sam_features, encs_prev, decs_prev = [], [], []
115
+
116
+ for idx_stage in range(num_stages):
117
+ # Input convolution, get multi-scale input features
118
+ x_scales = []
119
+ for i in range(num_supervision_scales):
120
+ x_scale = Conv3x3(
121
+ filters=(2 ** i) * features,
122
+ use_bias=use_bias,
123
+ name=f"stage_{idx_stage}_input_conv_{i}",
124
+ )(shortcuts[i])
125
+
126
+ # If later stages, fuse input features with SAM features from prev stage
127
+ if idx_stage > 0:
128
+ # use larger blocksize at high-res stages
129
+ if use_cross_gating:
130
+ block_size = (
131
+ block_size_hr if i < high_res_stages else block_size_lr
132
+ )
133
+ grid_size = grid_size_hr if i < high_res_stages else block_size_lr
134
+ x_scale, _ = CrossGatingBlock(
135
+ features=(2 ** i) * features,
136
+ block_size=block_size,
137
+ grid_size=grid_size,
138
+ dropout_rate=dropout_rate,
139
+ input_proj_factor=input_proj_factor,
140
+ upsample_y=False,
141
+ use_bias=use_bias,
142
+ name=f"stage_{idx_stage}_input_fuse_sam_{i}",
143
+ )(x_scale, sam_features.pop())
144
+ else:
145
+ x_scale = Conv1x1(
146
+ filters=(2 ** i) * features,
147
+ use_bias=use_bias,
148
+ name=f"stage_{idx_stage}_input_catconv_{i}",
149
+ )(tf.concat([x_scale, sam_features.pop()], axis=-1))
150
+
151
+ x_scales.append(x_scale)
152
+
153
+ # start encoder blocks
154
+ encs = []
155
+ x = x_scales[0] # First full-scale input feature
156
+
157
+ for i in range(depth): # 0, 1, 2
158
+ # use larger blocksize at high-res stages, vice versa.
159
+ block_size = block_size_hr if i < high_res_stages else block_size_lr
160
+ grid_size = grid_size_hr if i < high_res_stages else block_size_lr
161
+ use_cross_gating_layer = True if idx_stage > 0 else False
162
+
163
+ # Multi-scale input if multi-scale supervision
164
+ x_scale = x_scales[i] if i < num_supervision_scales else None
165
+
166
+ # UNet Encoder block
167
+ enc_prev = encs_prev.pop() if idx_stage > 0 else None
168
+ dec_prev = decs_prev.pop() if idx_stage > 0 else None
169
+
170
+ x, bridge = UNetEncoderBlock(
171
+ num_channels=(2 ** i) * features,
172
+ num_groups=num_groups,
173
+ downsample=True,
174
+ lrelu_slope=lrelu_slope,
175
+ block_size=block_size,
176
+ grid_size=grid_size,
177
+ block_gmlp_factor=block_gmlp_factor,
178
+ grid_gmlp_factor=grid_gmlp_factor,
179
+ input_proj_factor=input_proj_factor,
180
+ channels_reduction=channels_reduction,
181
+ use_global_mlp=use_global_mlp,
182
+ dropout_rate=dropout_rate,
183
+ use_bias=use_bias,
184
+ use_cross_gating=use_cross_gating_layer,
185
+ name=f"stage_{idx_stage}_encoder_block_{i}",
186
+ )(x, skip=x_scale, enc=enc_prev, dec=dec_prev)
187
+
188
+ # Cache skip signals
189
+ encs.append(bridge)
190
+
191
+ # Global MLP bottleneck blocks
192
+ for i in range(num_bottleneck_blocks):
193
+ x = BottleneckBlock(
194
+ block_size=block_size_lr,
195
+ grid_size=block_size_lr,
196
+ features=(2 ** (depth - 1)) * features,
197
+ num_groups=num_groups,
198
+ block_gmlp_factor=block_gmlp_factor,
199
+ grid_gmlp_factor=grid_gmlp_factor,
200
+ input_proj_factor=input_proj_factor,
201
+ dropout_rate=dropout_rate,
202
+ use_bias=use_bias,
203
+ channels_reduction=channels_reduction,
204
+ name=f"stage_{idx_stage}_global_block_{i}",
205
+ )(x)
206
+ # cache global feature for cross-gating
207
+ global_feature = x
208
+
209
+ # start cross gating. Use multi-scale feature fusion
210
+ skip_features = []
211
+ for i in reversed(range(depth)): # 2, 1, 0
212
+ # use larger blocksize at high-res stages
213
+ block_size = block_size_hr if i < high_res_stages else block_size_lr
214
+ grid_size = grid_size_hr if i < high_res_stages else block_size_lr
215
+
216
+ # get additional multi-scale signals
217
+ signal = tf.concat(
218
+ [
219
+ UpSampleRatio(
220
+ num_channels=(2 ** i) * features,
221
+ ratio=2 ** (j - i),
222
+ use_bias=use_bias,
223
+ name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
224
+ )(enc)
225
+ for j, enc in enumerate(encs)
226
+ ],
227
+ axis=-1,
228
+ )
229
+
230
+ # Use cross-gating to cross modulate features
231
+ if use_cross_gating:
232
+ skips, global_feature = CrossGatingBlock(
233
+ features=(2 ** i) * features,
234
+ block_size=block_size,
235
+ grid_size=grid_size,
236
+ input_proj_factor=input_proj_factor,
237
+ dropout_rate=dropout_rate,
238
+ upsample_y=True,
239
+ use_bias=use_bias,
240
+ name=f"stage_{idx_stage}_cross_gating_block_{i}",
241
+ )(signal, global_feature)
242
+ else:
243
+ skips = Conv1x1(
244
+ filters=(2 ** i) * features, use_bias=use_bias, name="Conv_0"
245
+ )(signal)
246
+ skips = Conv3x3(
247
+ filters=(2 ** i) * features, use_bias=use_bias, name="Conv_1"
248
+ )(skips)
249
+
250
+ skip_features.append(skips)
251
+
252
+ # start decoder. Multi-scale feature fusion of cross-gated features
253
+ outputs, decs, sam_features = [], [], []
254
+ for i in reversed(range(depth)):
255
+ # use larger blocksize at high-res stages
256
+ block_size = block_size_hr if i < high_res_stages else block_size_lr
257
+ grid_size = grid_size_hr if i < high_res_stages else block_size_lr
258
+
259
+ # get multi-scale skip signals from cross-gating block
260
+ signal = tf.concat(
261
+ [
262
+ UpSampleRatio(
263
+ num_channels=(2 ** i) * features,
264
+ ratio=2 ** (depth - j - 1 - i),
265
+ use_bias=use_bias,
266
+ name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
267
+ )(skip)
268
+ for j, skip in enumerate(skip_features)
269
+ ],
270
+ axis=-1,
271
+ )
272
+
273
+ # Decoder block
274
+ x = UNetDecoderBlock(
275
+ num_channels=(2 ** i) * features,
276
+ num_groups=num_groups,
277
+ lrelu_slope=lrelu_slope,
278
+ block_size=block_size,
279
+ grid_size=grid_size,
280
+ block_gmlp_factor=block_gmlp_factor,
281
+ grid_gmlp_factor=grid_gmlp_factor,
282
+ input_proj_factor=input_proj_factor,
283
+ channels_reduction=channels_reduction,
284
+ use_global_mlp=use_global_mlp,
285
+ dropout_rate=dropout_rate,
286
+ use_bias=use_bias,
287
+ name=f"stage_{idx_stage}_decoder_block_{i}",
288
+ )(x, bridge=signal)
289
+
290
+ # Cache decoder features for later-stage's usage
291
+ decs.append(x)
292
+
293
+ # output conv, if not final stage, use supervised-attention-block.
294
+ if i < num_supervision_scales:
295
+ if idx_stage < num_stages - 1: # not last stage, apply SAM
296
+ sam, output = SAM(
297
+ num_channels=(2 ** i) * features,
298
+ output_channels=num_outputs,
299
+ use_bias=use_bias,
300
+ name=f"stage_{idx_stage}_supervised_attention_module_{i}",
301
+ )(x, shortcuts[i])
302
+ outputs.append(output)
303
+ sam_features.append(sam)
304
+ else: # Last stage, apply output convolutions
305
+ output = Conv3x3(
306
+ num_outputs,
307
+ use_bias=use_bias,
308
+ name=f"stage_{idx_stage}_output_conv_{i}",
309
+ )(x)
310
+ output = output + shortcuts[i]
311
+ outputs.append(output)
312
+ # Cache encoder and decoder features for later-stage's usage
313
+ encs_prev = encs[::-1]
314
+ decs_prev = decs
315
+
316
+ # Store outputs
317
+ outputs_all.append(outputs)
318
+ return outputs_all
319
+
320
+ return apply
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ tensorflow==2.10.0
2
+ einops
3
+ numpy