add: files.
Browse files- 109fromGOPR1096.MP4.png +0 -0
- 110fromGOPR1087.MP4.png +0 -0
- 1fromGOPR0950.png +0 -0
- 1fromGOPR1096.MP4.png +0 -0
- README.md +5 -5
- app.py +112 -0
- create_maxim_model.py +37 -0
- maxim/__init__.py +0 -0
- maxim/blocks/__init__.py +0 -0
- maxim/blocks/attentions.py +143 -0
- maxim/blocks/block_gating.py +67 -0
- maxim/blocks/bottleneck.py +54 -0
- maxim/blocks/grid_gating.py +68 -0
- maxim/blocks/misc_gating.py +213 -0
- maxim/blocks/others.py +56 -0
- maxim/blocks/unet.py +133 -0
- maxim/configs.py +80 -0
- maxim/layers.py +101 -0
- maxim/maxim.py +320 -0
- requirements.txt +3 -0
109fromGOPR1096.MP4.png
ADDED
110fromGOPR1087.MP4.png
ADDED
1fromGOPR0950.png
ADDED
1fromGOPR1096.MP4.png
ADDED
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
1 |
---
|
2 |
+
title: GoPro Deblurring MAXIM
|
3 |
+
emoji: 💻
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.5
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
app.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Some preprocessing utilities have been taken from:
|
3 |
+
https://github.com/google-research/maxim/blob/main/maxim/run_eval.py
|
4 |
+
"""
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import tensorflow as tf
|
8 |
+
from huggingface_hub.keras_mixin import from_pretrained_keras
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
from create_maxim_model import Model
|
12 |
+
from maxim.configs import MAXIM_CONFIGS
|
13 |
+
|
14 |
+
CKPT = "sayakpaul/S-3_deblurring_gopro"
|
15 |
+
VARIANT = CKPT.split("/")[-1].split("_")[0]
|
16 |
+
_MODEL = from_pretrained_keras(CKPT)
|
17 |
+
|
18 |
+
|
19 |
+
def mod_padding_symmetric(image, factor=64):
|
20 |
+
"""Padding the image to be divided by factor."""
|
21 |
+
height, width = image.shape[0], image.shape[1]
|
22 |
+
height_pad, width_pad = ((height + factor) // factor) * factor, (
|
23 |
+
(width + factor) // factor
|
24 |
+
) * factor
|
25 |
+
padh = height_pad - height if height % factor != 0 else 0
|
26 |
+
padw = width_pad - width if width % factor != 0 else 0
|
27 |
+
image = tf.pad(
|
28 |
+
image, [(padh // 2, padh // 2), (padw // 2, padw // 2), (0, 0)], mode="REFLECT"
|
29 |
+
)
|
30 |
+
return image
|
31 |
+
|
32 |
+
|
33 |
+
def make_shape_even(image):
|
34 |
+
"""Pad the image to have even shapes."""
|
35 |
+
height, width = image.shape[0], image.shape[1]
|
36 |
+
padh = 1 if height % 2 != 0 else 0
|
37 |
+
padw = 1 if width % 2 != 0 else 0
|
38 |
+
image = tf.pad(image, [(0, padh), (0, padw), (0, 0)], mode="REFLECT")
|
39 |
+
return image
|
40 |
+
|
41 |
+
|
42 |
+
def process_image(image: Image):
|
43 |
+
input_img = np.asarray(image) / 255.0
|
44 |
+
height, width = input_img.shape[0], input_img.shape[1]
|
45 |
+
|
46 |
+
# Padding images to have even shapes
|
47 |
+
input_img = make_shape_even(input_img)
|
48 |
+
height_even, width_even = input_img.shape[0], input_img.shape[1]
|
49 |
+
|
50 |
+
# padding images to be multiplies of 64
|
51 |
+
input_img = mod_padding_symmetric(input_img, factor=64)
|
52 |
+
input_img = tf.expand_dims(input_img, axis=0)
|
53 |
+
return input_img, height, width, height_even, width_even
|
54 |
+
|
55 |
+
|
56 |
+
def init_new_model(input_img):
|
57 |
+
configs = MAXIM_CONFIGS.get(VARIANT)
|
58 |
+
configs.update(
|
59 |
+
{
|
60 |
+
"variant": VARIANT,
|
61 |
+
"dropout_rate": 0.0,
|
62 |
+
"num_outputs": 3,
|
63 |
+
"use_bias": True,
|
64 |
+
"num_supervision_scales": 3,
|
65 |
+
}
|
66 |
+
)
|
67 |
+
configs.update({"input_resolution": (input_img.shape[1], input_img.shape[2])})
|
68 |
+
new_model = Model(**configs)
|
69 |
+
new_model.set_weights(_MODEL.get_weights())
|
70 |
+
return new_model
|
71 |
+
|
72 |
+
|
73 |
+
def infer(image):
|
74 |
+
preprocessed_image, height, width, height_even, width_even = process_image(image)
|
75 |
+
new_model = init_new_model(preprocessed_image)
|
76 |
+
|
77 |
+
preds = new_model.predict(preprocessed_image)
|
78 |
+
if isinstance(preds, list):
|
79 |
+
preds = preds[-1]
|
80 |
+
if isinstance(preds, list):
|
81 |
+
preds = preds[-1]
|
82 |
+
|
83 |
+
preds = np.array(preds[0], np.float32)
|
84 |
+
|
85 |
+
new_height, new_width = preds.shape[0], preds.shape[1]
|
86 |
+
h_start = new_height // 2 - height_even // 2
|
87 |
+
h_end = h_start + height
|
88 |
+
w_start = new_width // 2 - width_even // 2
|
89 |
+
w_end = w_start + width
|
90 |
+
preds = preds[h_start:h_end, w_start:w_end, :]
|
91 |
+
|
92 |
+
return Image.fromarray(np.array((np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)))
|
93 |
+
|
94 |
+
|
95 |
+
title = "Deblur blurry images."
|
96 |
+
description = f"The underlying model is [this](https://huggingface.co/{CKPT}). You can use the model to deblur blurry images useful for a lot of applications. To quickly try out the model, you can choose from the available sample images below, or you can submit your own image. Not that, internally, the model is re-initialized based on the spatial dimensions of the input image and this process is time-consuming."
|
97 |
+
|
98 |
+
iface = gr.Interface(
|
99 |
+
infer,
|
100 |
+
inputs="image",
|
101 |
+
outputs=gr.Image().style(height=242),
|
102 |
+
title=title,
|
103 |
+
description=description,
|
104 |
+
allow_flagging="never",
|
105 |
+
examples=[
|
106 |
+
["1fromGOPR1096.MP4.png"],
|
107 |
+
["1fromGOPR0950.png"],
|
108 |
+
["109fromGOPR1096.MP4.png"],
|
109 |
+
["110fromGOPR1087.MP4.png"],
|
110 |
+
],
|
111 |
+
)
|
112 |
+
iface.launch(debug=True)
|
create_maxim_model.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tensorflow import keras
|
2 |
+
|
3 |
+
from maxim import maxim
|
4 |
+
from maxim.configs import MAXIM_CONFIGS
|
5 |
+
|
6 |
+
|
7 |
+
def Model(variant=None, input_resolution=(256, 256), **kw) -> keras.Model:
|
8 |
+
"""Factory function to easily create a Model variant like "S".
|
9 |
+
|
10 |
+
Args:
|
11 |
+
variant: UNet model variants. Options: 'S-1' | 'S-2' | 'S-3'
|
12 |
+
| 'M-1' | 'M-2' | 'M-3'
|
13 |
+
input_resolution: Size of the input images.
|
14 |
+
**kw: Other UNet config dicts.
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
The MAXIM model.
|
18 |
+
"""
|
19 |
+
|
20 |
+
if variant is not None:
|
21 |
+
config = MAXIM_CONFIGS[variant]
|
22 |
+
for k, v in config.items():
|
23 |
+
kw.setdefault(k, v)
|
24 |
+
|
25 |
+
if "variant" in kw:
|
26 |
+
_ = kw.pop("variant")
|
27 |
+
if "input_resolution" in kw:
|
28 |
+
_ = kw.pop("input_resolution")
|
29 |
+
model_name = kw.pop("name")
|
30 |
+
|
31 |
+
maxim_model = maxim.MAXIM(**kw)
|
32 |
+
|
33 |
+
inputs = keras.Input((*input_resolution, 3))
|
34 |
+
outputs = maxim_model(inputs)
|
35 |
+
final_model = keras.Model(inputs, outputs, name=f"{model_name}_model")
|
36 |
+
|
37 |
+
return final_model
|
maxim/__init__.py
ADDED
File without changes
|
maxim/blocks/__init__.py
ADDED
File without changes
|
maxim/blocks/attentions.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import layers
|
5 |
+
|
6 |
+
from .others import MlpBlock
|
7 |
+
|
8 |
+
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
|
9 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
10 |
+
|
11 |
+
|
12 |
+
def CALayer(
|
13 |
+
num_channels: int,
|
14 |
+
reduction: int = 4,
|
15 |
+
use_bias: bool = True,
|
16 |
+
name: str = "channel_attention",
|
17 |
+
):
|
18 |
+
"""Squeeze-and-excitation block for channel attention.
|
19 |
+
|
20 |
+
ref: https://arxiv.org/abs/1709.01507
|
21 |
+
"""
|
22 |
+
|
23 |
+
def apply(x):
|
24 |
+
# 2D global average pooling
|
25 |
+
y = layers.GlobalAvgPool2D(keepdims=True)(x)
|
26 |
+
# Squeeze (in Squeeze-Excitation)
|
27 |
+
y = Conv1x1(
|
28 |
+
filters=num_channels // reduction, use_bias=use_bias, name=f"{name}_Conv_0"
|
29 |
+
)(y)
|
30 |
+
y = tf.nn.relu(y)
|
31 |
+
# Excitation (in Squeeze-Excitation)
|
32 |
+
y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
|
33 |
+
y = tf.nn.sigmoid(y)
|
34 |
+
return x * y
|
35 |
+
|
36 |
+
return apply
|
37 |
+
|
38 |
+
|
39 |
+
def RCAB(
|
40 |
+
num_channels: int,
|
41 |
+
reduction: int = 4,
|
42 |
+
lrelu_slope: float = 0.2,
|
43 |
+
use_bias: bool = True,
|
44 |
+
name: str = "residual_ca",
|
45 |
+
):
|
46 |
+
"""Residual channel attention block. Contains LN,Conv,lRelu,Conv,SELayer."""
|
47 |
+
|
48 |
+
def apply(x):
|
49 |
+
shortcut = x
|
50 |
+
x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
|
51 |
+
x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1")(x)
|
52 |
+
x = tf.nn.leaky_relu(x, alpha=lrelu_slope)
|
53 |
+
x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2")(x)
|
54 |
+
x = CALayer(
|
55 |
+
num_channels=num_channels,
|
56 |
+
reduction=reduction,
|
57 |
+
use_bias=use_bias,
|
58 |
+
name=f"{name}_channel_attention",
|
59 |
+
)(x)
|
60 |
+
return x + shortcut
|
61 |
+
|
62 |
+
return apply
|
63 |
+
|
64 |
+
|
65 |
+
def RDCAB(
|
66 |
+
num_channels: int,
|
67 |
+
reduction: int = 16,
|
68 |
+
use_bias: bool = True,
|
69 |
+
dropout_rate: float = 0.0,
|
70 |
+
name: str = "rdcab",
|
71 |
+
):
|
72 |
+
"""Residual dense channel attention block. Used in Bottlenecks."""
|
73 |
+
|
74 |
+
def apply(x):
|
75 |
+
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
|
76 |
+
y = MlpBlock(
|
77 |
+
mlp_dim=num_channels,
|
78 |
+
dropout_rate=dropout_rate,
|
79 |
+
use_bias=use_bias,
|
80 |
+
name=f"{name}_channel_mixing",
|
81 |
+
)(y)
|
82 |
+
y = CALayer(
|
83 |
+
num_channels=num_channels,
|
84 |
+
reduction=reduction,
|
85 |
+
use_bias=use_bias,
|
86 |
+
name=f"{name}_channel_attention",
|
87 |
+
)(y)
|
88 |
+
x = x + y
|
89 |
+
return x
|
90 |
+
|
91 |
+
return apply
|
92 |
+
|
93 |
+
|
94 |
+
def SAM(
|
95 |
+
num_channels: int,
|
96 |
+
output_channels: int = 3,
|
97 |
+
use_bias: bool = True,
|
98 |
+
name: str = "sam",
|
99 |
+
):
|
100 |
+
|
101 |
+
"""Supervised attention module for multi-stage training.
|
102 |
+
|
103 |
+
Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet
|
104 |
+
"""
|
105 |
+
|
106 |
+
def apply(x, x_image):
|
107 |
+
"""Apply the SAM module to the input and num_channels.
|
108 |
+
Args:
|
109 |
+
x: the output num_channels from UNet decoder with shape (h, w, c)
|
110 |
+
x_image: the input image with shape (h, w, 3)
|
111 |
+
Returns:
|
112 |
+
A tuple of tensors (x1, image) where (x1) is the sam num_channels used for the
|
113 |
+
next stage, and (image) is the output restored image at current stage.
|
114 |
+
"""
|
115 |
+
# Get num_channels
|
116 |
+
x1 = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
|
117 |
+
|
118 |
+
# Output restored image X_s
|
119 |
+
if output_channels == 3:
|
120 |
+
image = (
|
121 |
+
Conv3x3(
|
122 |
+
filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
|
123 |
+
)(x)
|
124 |
+
+ x_image
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
image = Conv3x3(
|
128 |
+
filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
|
129 |
+
)(x)
|
130 |
+
|
131 |
+
# Get attention maps for num_channels
|
132 |
+
x2 = tf.nn.sigmoid(
|
133 |
+
Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
|
134 |
+
)
|
135 |
+
|
136 |
+
# Get attended feature maps
|
137 |
+
x1 = x1 * x2
|
138 |
+
|
139 |
+
# Residual connection
|
140 |
+
x1 = x1 + x
|
141 |
+
return x1, image
|
142 |
+
|
143 |
+
return apply
|
maxim/blocks/block_gating.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import backend as K
|
3 |
+
from tensorflow.keras import layers
|
4 |
+
|
5 |
+
from ..layers import BlockImages, SwapAxes, UnblockImages
|
6 |
+
|
7 |
+
|
8 |
+
def BlockGatingUnit(use_bias: bool = True, name: str = "block_gating_unit"):
|
9 |
+
"""A SpatialGatingUnit as defined in the gMLP paper.
|
10 |
+
|
11 |
+
The 'spatial' dim is defined as the **second last**.
|
12 |
+
If applied on other dims, you should swapaxes first.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def apply(x):
|
16 |
+
u, v = tf.split(x, 2, axis=-1)
|
17 |
+
v = layers.LayerNormalization(
|
18 |
+
epsilon=1e-06, name=f"{name}_intermediate_layernorm"
|
19 |
+
)(v)
|
20 |
+
n = K.int_shape(x)[-2] # get spatial dim
|
21 |
+
v = SwapAxes()(v, -1, -2)
|
22 |
+
v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
|
23 |
+
v = SwapAxes()(v, -1, -2)
|
24 |
+
return u * (v + 1.0)
|
25 |
+
|
26 |
+
return apply
|
27 |
+
|
28 |
+
|
29 |
+
def BlockGmlpLayer(
|
30 |
+
block_size,
|
31 |
+
use_bias: bool = True,
|
32 |
+
factor: int = 2,
|
33 |
+
dropout_rate: float = 0.0,
|
34 |
+
name: str = "block_gmlp",
|
35 |
+
):
|
36 |
+
"""Block gMLP layer that performs local mixing of tokens."""
|
37 |
+
|
38 |
+
def apply(x):
|
39 |
+
n, h, w, num_channels = (
|
40 |
+
K.int_shape(x)[0],
|
41 |
+
K.int_shape(x)[1],
|
42 |
+
K.int_shape(x)[2],
|
43 |
+
K.int_shape(x)[3],
|
44 |
+
)
|
45 |
+
fh, fw = block_size
|
46 |
+
gh, gw = h // fh, w // fw
|
47 |
+
x = BlockImages()(x, patch_size=(fh, fw))
|
48 |
+
# MLP2: Local (block) mixing part, provides within-block communication.
|
49 |
+
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
|
50 |
+
y = layers.Dense(
|
51 |
+
num_channels * factor,
|
52 |
+
use_bias=use_bias,
|
53 |
+
name=f"{name}_in_project",
|
54 |
+
)(y)
|
55 |
+
y = tf.nn.gelu(y, approximate=True)
|
56 |
+
y = BlockGatingUnit(use_bias=use_bias, name=f"{name}_BlockGatingUnit")(y)
|
57 |
+
y = layers.Dense(
|
58 |
+
num_channels,
|
59 |
+
use_bias=use_bias,
|
60 |
+
name=f"{name}_out_project",
|
61 |
+
)(y)
|
62 |
+
y = layers.Dropout(dropout_rate)(y)
|
63 |
+
x = x + y
|
64 |
+
x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
|
65 |
+
return x
|
66 |
+
|
67 |
+
return apply
|
maxim/blocks/bottleneck.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
from tensorflow.keras import layers
|
4 |
+
|
5 |
+
from .attentions import RDCAB
|
6 |
+
from .misc_gating import ResidualSplitHeadMultiAxisGmlpLayer
|
7 |
+
|
8 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
9 |
+
|
10 |
+
|
11 |
+
def BottleneckBlock(
|
12 |
+
features: int,
|
13 |
+
block_size,
|
14 |
+
grid_size,
|
15 |
+
num_groups: int = 1,
|
16 |
+
block_gmlp_factor: int = 2,
|
17 |
+
grid_gmlp_factor: int = 2,
|
18 |
+
input_proj_factor: int = 2,
|
19 |
+
channels_reduction: int = 4,
|
20 |
+
dropout_rate: float = 0.0,
|
21 |
+
use_bias: bool = True,
|
22 |
+
name: str = "bottleneck_block",
|
23 |
+
):
|
24 |
+
"""The bottleneck block consisting of multi-axis gMLP block and RDCAB."""
|
25 |
+
|
26 |
+
def apply(x):
|
27 |
+
# input projection
|
28 |
+
x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_input_proj")(x)
|
29 |
+
shortcut_long = x
|
30 |
+
|
31 |
+
for i in range(num_groups):
|
32 |
+
x = ResidualSplitHeadMultiAxisGmlpLayer(
|
33 |
+
grid_size=grid_size,
|
34 |
+
block_size=block_size,
|
35 |
+
grid_gmlp_factor=grid_gmlp_factor,
|
36 |
+
block_gmlp_factor=block_gmlp_factor,
|
37 |
+
input_proj_factor=input_proj_factor,
|
38 |
+
use_bias=use_bias,
|
39 |
+
dropout_rate=dropout_rate,
|
40 |
+
name=f"{name}_SplitHeadMultiAxisGmlpLayer_{i}",
|
41 |
+
)(x)
|
42 |
+
# Channel-mixing part, which provides within-patch communication.
|
43 |
+
x = RDCAB(
|
44 |
+
num_channels=features,
|
45 |
+
reduction=channels_reduction,
|
46 |
+
use_bias=use_bias,
|
47 |
+
name=f"{name}_channel_attention_block_1_{i}",
|
48 |
+
)(x)
|
49 |
+
|
50 |
+
# long skip-connect
|
51 |
+
x = x + shortcut_long
|
52 |
+
return x
|
53 |
+
|
54 |
+
return apply
|
maxim/blocks/grid_gating.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import backend as K
|
3 |
+
from tensorflow.keras import layers
|
4 |
+
|
5 |
+
from ..layers import BlockImages, SwapAxes, UnblockImages
|
6 |
+
|
7 |
+
|
8 |
+
def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
|
9 |
+
"""A SpatialGatingUnit as defined in the gMLP paper.
|
10 |
+
|
11 |
+
The 'spatial' dim is defined as the second last.
|
12 |
+
If applied on other dims, you should swapaxes first.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def apply(x):
|
16 |
+
u, v = tf.split(x, 2, axis=-1)
|
17 |
+
v = layers.LayerNormalization(
|
18 |
+
epsilon=1e-06, name=f"{name}_intermediate_layernorm"
|
19 |
+
)(v)
|
20 |
+
n = K.int_shape(x)[-3] # get spatial dim
|
21 |
+
v = SwapAxes()(v, -1, -3)
|
22 |
+
v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
|
23 |
+
v = SwapAxes()(v, -1, -3)
|
24 |
+
return u * (v + 1.0)
|
25 |
+
|
26 |
+
return apply
|
27 |
+
|
28 |
+
|
29 |
+
def GridGmlpLayer(
|
30 |
+
grid_size,
|
31 |
+
use_bias: bool = True,
|
32 |
+
factor: int = 2,
|
33 |
+
dropout_rate: float = 0.0,
|
34 |
+
name: str = "grid_gmlp",
|
35 |
+
):
|
36 |
+
"""Grid gMLP layer that performs global mixing of tokens."""
|
37 |
+
|
38 |
+
def apply(x):
|
39 |
+
n, h, w, num_channels = (
|
40 |
+
K.int_shape(x)[0],
|
41 |
+
K.int_shape(x)[1],
|
42 |
+
K.int_shape(x)[2],
|
43 |
+
K.int_shape(x)[3],
|
44 |
+
)
|
45 |
+
gh, gw = grid_size
|
46 |
+
fh, fw = h // gh, w // gw
|
47 |
+
|
48 |
+
x = BlockImages()(x, patch_size=(fh, fw))
|
49 |
+
# gMLP1: Global (grid) mixing part, provides global grid communication.
|
50 |
+
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
|
51 |
+
y = layers.Dense(
|
52 |
+
num_channels * factor,
|
53 |
+
use_bias=use_bias,
|
54 |
+
name=f"{name}_in_project",
|
55 |
+
)(y)
|
56 |
+
y = tf.nn.gelu(y, approximate=True)
|
57 |
+
y = GridGatingUnit(use_bias=use_bias, name=f"{name}_GridGatingUnit")(y)
|
58 |
+
y = layers.Dense(
|
59 |
+
num_channels,
|
60 |
+
use_bias=use_bias,
|
61 |
+
name=f"{name}_out_project",
|
62 |
+
)(y)
|
63 |
+
y = layers.Dropout(dropout_rate)(y)
|
64 |
+
x = x + y
|
65 |
+
x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
|
66 |
+
return x
|
67 |
+
|
68 |
+
return apply
|
maxim/blocks/misc_gating.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras import layers
|
6 |
+
|
7 |
+
from ..layers import BlockImages, SwapAxes, UnblockImages
|
8 |
+
from .block_gating import BlockGmlpLayer
|
9 |
+
from .grid_gating import GridGmlpLayer
|
10 |
+
|
11 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
12 |
+
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
|
13 |
+
ConvT_up = functools.partial(
|
14 |
+
layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
|
15 |
+
)
|
16 |
+
Conv_down = functools.partial(
|
17 |
+
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
def ResidualSplitHeadMultiAxisGmlpLayer(
|
22 |
+
block_size,
|
23 |
+
grid_size,
|
24 |
+
block_gmlp_factor: int = 2,
|
25 |
+
grid_gmlp_factor: int = 2,
|
26 |
+
input_proj_factor: int = 2,
|
27 |
+
use_bias: bool = True,
|
28 |
+
dropout_rate: float = 0.0,
|
29 |
+
name: str = "residual_split_head_maxim",
|
30 |
+
):
|
31 |
+
"""The multi-axis gated MLP block."""
|
32 |
+
|
33 |
+
def apply(x):
|
34 |
+
shortcut = x
|
35 |
+
n, h, w, num_channels = (
|
36 |
+
K.int_shape(x)[0],
|
37 |
+
K.int_shape(x)[1],
|
38 |
+
K.int_shape(x)[2],
|
39 |
+
K.int_shape(x)[3],
|
40 |
+
)
|
41 |
+
x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
|
42 |
+
|
43 |
+
x = layers.Dense(
|
44 |
+
int(num_channels) * input_proj_factor,
|
45 |
+
use_bias=use_bias,
|
46 |
+
name=f"{name}_in_project",
|
47 |
+
)(x)
|
48 |
+
x = tf.nn.gelu(x, approximate=True)
|
49 |
+
|
50 |
+
u, v = tf.split(x, 2, axis=-1)
|
51 |
+
|
52 |
+
# GridGMLPLayer
|
53 |
+
u = GridGmlpLayer(
|
54 |
+
grid_size=grid_size,
|
55 |
+
factor=grid_gmlp_factor,
|
56 |
+
use_bias=use_bias,
|
57 |
+
dropout_rate=dropout_rate,
|
58 |
+
name=f"{name}_GridGmlpLayer",
|
59 |
+
)(u)
|
60 |
+
|
61 |
+
# BlockGMLPLayer
|
62 |
+
v = BlockGmlpLayer(
|
63 |
+
block_size=block_size,
|
64 |
+
factor=block_gmlp_factor,
|
65 |
+
use_bias=use_bias,
|
66 |
+
dropout_rate=dropout_rate,
|
67 |
+
name=f"{name}_BlockGmlpLayer",
|
68 |
+
)(v)
|
69 |
+
|
70 |
+
x = tf.concat([u, v], axis=-1)
|
71 |
+
|
72 |
+
x = layers.Dense(
|
73 |
+
num_channels,
|
74 |
+
use_bias=use_bias,
|
75 |
+
name=f"{name}_out_project",
|
76 |
+
)(x)
|
77 |
+
x = layers.Dropout(dropout_rate)(x)
|
78 |
+
x = x + shortcut
|
79 |
+
return x
|
80 |
+
|
81 |
+
return apply
|
82 |
+
|
83 |
+
|
84 |
+
def GetSpatialGatingWeights(
|
85 |
+
features: int,
|
86 |
+
block_size,
|
87 |
+
grid_size,
|
88 |
+
input_proj_factor: int = 2,
|
89 |
+
dropout_rate: float = 0.0,
|
90 |
+
use_bias: bool = True,
|
91 |
+
name: str = "spatial_gating",
|
92 |
+
):
|
93 |
+
|
94 |
+
"""Get gating weights for cross-gating MLP block."""
|
95 |
+
|
96 |
+
def apply(x):
|
97 |
+
n, h, w, num_channels = (
|
98 |
+
K.int_shape(x)[0],
|
99 |
+
K.int_shape(x)[1],
|
100 |
+
K.int_shape(x)[2],
|
101 |
+
K.int_shape(x)[3],
|
102 |
+
)
|
103 |
+
|
104 |
+
# input projection
|
105 |
+
x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
|
106 |
+
x = layers.Dense(
|
107 |
+
num_channels * input_proj_factor,
|
108 |
+
use_bias=use_bias,
|
109 |
+
name=f"{name}_in_project",
|
110 |
+
)(x)
|
111 |
+
x = tf.nn.gelu(x, approximate=True)
|
112 |
+
u, v = tf.split(x, 2, axis=-1)
|
113 |
+
|
114 |
+
# Get grid MLP weights
|
115 |
+
gh, gw = grid_size
|
116 |
+
fh, fw = h // gh, w // gw
|
117 |
+
u = BlockImages()(u, patch_size=(fh, fw))
|
118 |
+
dim_u = K.int_shape(u)[-3]
|
119 |
+
u = SwapAxes()(u, -1, -3)
|
120 |
+
u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u)
|
121 |
+
u = SwapAxes()(u, -1, -3)
|
122 |
+
u = UnblockImages()(u, grid_size=(gh, gw), patch_size=(fh, fw))
|
123 |
+
|
124 |
+
# Get Block MLP weights
|
125 |
+
fh, fw = block_size
|
126 |
+
gh, gw = h // fh, w // fw
|
127 |
+
v = BlockImages()(v, patch_size=(fh, fw))
|
128 |
+
dim_v = K.int_shape(v)[-2]
|
129 |
+
v = SwapAxes()(v, -1, -2)
|
130 |
+
v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v)
|
131 |
+
v = SwapAxes()(v, -1, -2)
|
132 |
+
v = UnblockImages()(v, grid_size=(gh, gw), patch_size=(fh, fw))
|
133 |
+
|
134 |
+
x = tf.concat([u, v], axis=-1)
|
135 |
+
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x)
|
136 |
+
x = layers.Dropout(dropout_rate)(x)
|
137 |
+
return x
|
138 |
+
|
139 |
+
return apply
|
140 |
+
|
141 |
+
|
142 |
+
def CrossGatingBlock(
|
143 |
+
features: int,
|
144 |
+
block_size,
|
145 |
+
grid_size,
|
146 |
+
dropout_rate: float = 0.0,
|
147 |
+
input_proj_factor: int = 2,
|
148 |
+
upsample_y: bool = True,
|
149 |
+
use_bias: bool = True,
|
150 |
+
name: str = "cross_gating",
|
151 |
+
):
|
152 |
+
|
153 |
+
"""Cross-gating MLP block."""
|
154 |
+
|
155 |
+
def apply(x, y):
|
156 |
+
# Upscale Y signal, y is the gating signal.
|
157 |
+
if upsample_y:
|
158 |
+
y = ConvT_up(
|
159 |
+
filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
|
160 |
+
)(y)
|
161 |
+
|
162 |
+
x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x)
|
163 |
+
n, h, w, num_channels = (
|
164 |
+
K.int_shape(x)[0],
|
165 |
+
K.int_shape(x)[1],
|
166 |
+
K.int_shape(x)[2],
|
167 |
+
K.int_shape(x)[3],
|
168 |
+
)
|
169 |
+
|
170 |
+
y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
|
171 |
+
|
172 |
+
shortcut_x = x
|
173 |
+
shortcut_y = y
|
174 |
+
|
175 |
+
# Get gating weights from X
|
176 |
+
x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x)
|
177 |
+
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(x)
|
178 |
+
x = tf.nn.gelu(x, approximate=True)
|
179 |
+
gx = GetSpatialGatingWeights(
|
180 |
+
features=num_channels,
|
181 |
+
block_size=block_size,
|
182 |
+
grid_size=grid_size,
|
183 |
+
dropout_rate=dropout_rate,
|
184 |
+
use_bias=use_bias,
|
185 |
+
name=f"{name}_SplitHeadMultiAxisGating_x",
|
186 |
+
)(x)
|
187 |
+
|
188 |
+
# Get gating weights from Y
|
189 |
+
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y)
|
190 |
+
y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(y)
|
191 |
+
y = tf.nn.gelu(y, approximate=True)
|
192 |
+
gy = GetSpatialGatingWeights(
|
193 |
+
features=num_channels,
|
194 |
+
block_size=block_size,
|
195 |
+
grid_size=grid_size,
|
196 |
+
dropout_rate=dropout_rate,
|
197 |
+
use_bias=use_bias,
|
198 |
+
name=f"{name}_SplitHeadMultiAxisGating_y",
|
199 |
+
)(y)
|
200 |
+
|
201 |
+
# Apply cross gating: X = X * GY, Y = Y * GX
|
202 |
+
y = y * gx
|
203 |
+
y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(y)
|
204 |
+
y = layers.Dropout(dropout_rate)(y)
|
205 |
+
y = y + shortcut_y
|
206 |
+
|
207 |
+
x = x * gy # gating x using y
|
208 |
+
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(x)
|
209 |
+
x = layers.Dropout(dropout_rate)(x)
|
210 |
+
x = x + y + shortcut_x # get all aggregated signals
|
211 |
+
return x, y
|
212 |
+
|
213 |
+
return apply
|
maxim/blocks/others.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras import layers
|
6 |
+
|
7 |
+
from ..layers import Resizing
|
8 |
+
|
9 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
10 |
+
|
11 |
+
|
12 |
+
def MlpBlock(
|
13 |
+
mlp_dim: int,
|
14 |
+
dropout_rate: float = 0.0,
|
15 |
+
use_bias: bool = True,
|
16 |
+
name: str = "mlp_block",
|
17 |
+
):
|
18 |
+
"""A 1-hidden-layer MLP block, applied over the last dimension."""
|
19 |
+
|
20 |
+
def apply(x):
|
21 |
+
d = K.int_shape(x)[-1]
|
22 |
+
x = layers.Dense(mlp_dim, use_bias=use_bias, name=f"{name}_Dense_0")(x)
|
23 |
+
x = tf.nn.gelu(x, approximate=True)
|
24 |
+
x = layers.Dropout(dropout_rate)(x)
|
25 |
+
x = layers.Dense(d, use_bias=use_bias, name=f"{name}_Dense_1")(x)
|
26 |
+
return x
|
27 |
+
|
28 |
+
return apply
|
29 |
+
|
30 |
+
|
31 |
+
def UpSampleRatio(
|
32 |
+
num_channels: int, ratio: float, use_bias: bool = True, name: str = "upsample"
|
33 |
+
):
|
34 |
+
"""Upsample features given a ratio > 0."""
|
35 |
+
|
36 |
+
def apply(x):
|
37 |
+
n, h, w, c = (
|
38 |
+
K.int_shape(x)[0],
|
39 |
+
K.int_shape(x)[1],
|
40 |
+
K.int_shape(x)[2],
|
41 |
+
K.int_shape(x)[3],
|
42 |
+
)
|
43 |
+
|
44 |
+
# Following `jax.image.resize()`
|
45 |
+
x = Resizing(
|
46 |
+
height=int(h * ratio),
|
47 |
+
width=int(w * ratio),
|
48 |
+
method="bilinear",
|
49 |
+
antialias=True,
|
50 |
+
name=f"{name}_resizing_{K.get_uid('Resizing')}",
|
51 |
+
)(x)
|
52 |
+
|
53 |
+
x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
|
54 |
+
return x
|
55 |
+
|
56 |
+
return apply
|
maxim/blocks/unet.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import layers
|
5 |
+
|
6 |
+
from .attentions import RCAB
|
7 |
+
from .misc_gating import CrossGatingBlock, ResidualSplitHeadMultiAxisGmlpLayer
|
8 |
+
|
9 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
10 |
+
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
|
11 |
+
ConvT_up = functools.partial(
|
12 |
+
layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
|
13 |
+
)
|
14 |
+
Conv_down = functools.partial(
|
15 |
+
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def UNetEncoderBlock(
|
20 |
+
num_channels: int,
|
21 |
+
block_size,
|
22 |
+
grid_size,
|
23 |
+
num_groups: int = 1,
|
24 |
+
lrelu_slope: float = 0.2,
|
25 |
+
block_gmlp_factor: int = 2,
|
26 |
+
grid_gmlp_factor: int = 2,
|
27 |
+
input_proj_factor: int = 2,
|
28 |
+
channels_reduction: int = 4,
|
29 |
+
dropout_rate: float = 0.0,
|
30 |
+
downsample: bool = True,
|
31 |
+
use_global_mlp: bool = True,
|
32 |
+
use_bias: bool = True,
|
33 |
+
use_cross_gating: bool = False,
|
34 |
+
name: str = "unet_encoder",
|
35 |
+
):
|
36 |
+
"""Encoder block in MAXIM."""
|
37 |
+
|
38 |
+
def apply(x, skip=None, enc=None, dec=None):
|
39 |
+
if skip is not None:
|
40 |
+
x = tf.concat([x, skip], axis=-1)
|
41 |
+
|
42 |
+
# convolution-in
|
43 |
+
x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
|
44 |
+
shortcut_long = x
|
45 |
+
|
46 |
+
for i in range(num_groups):
|
47 |
+
if use_global_mlp:
|
48 |
+
x = ResidualSplitHeadMultiAxisGmlpLayer(
|
49 |
+
grid_size=grid_size,
|
50 |
+
block_size=block_size,
|
51 |
+
grid_gmlp_factor=grid_gmlp_factor,
|
52 |
+
block_gmlp_factor=block_gmlp_factor,
|
53 |
+
input_proj_factor=input_proj_factor,
|
54 |
+
use_bias=use_bias,
|
55 |
+
dropout_rate=dropout_rate,
|
56 |
+
name=f"{name}_SplitHeadMultiAxisGmlpLayer_{i}",
|
57 |
+
)(x)
|
58 |
+
x = RCAB(
|
59 |
+
num_channels=num_channels,
|
60 |
+
reduction=channels_reduction,
|
61 |
+
lrelu_slope=lrelu_slope,
|
62 |
+
use_bias=use_bias,
|
63 |
+
name=f"{name}_channel_attention_block_1{i}",
|
64 |
+
)(x)
|
65 |
+
|
66 |
+
x = x + shortcut_long
|
67 |
+
|
68 |
+
if enc is not None and dec is not None:
|
69 |
+
assert use_cross_gating
|
70 |
+
x, _ = CrossGatingBlock(
|
71 |
+
features=num_channels,
|
72 |
+
block_size=block_size,
|
73 |
+
grid_size=grid_size,
|
74 |
+
dropout_rate=dropout_rate,
|
75 |
+
input_proj_factor=input_proj_factor,
|
76 |
+
upsample_y=False,
|
77 |
+
use_bias=use_bias,
|
78 |
+
name=f"{name}_cross_gating_block",
|
79 |
+
)(x, enc + dec)
|
80 |
+
|
81 |
+
if downsample:
|
82 |
+
x_down = Conv_down(
|
83 |
+
filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1"
|
84 |
+
)(x)
|
85 |
+
return x_down, x
|
86 |
+
else:
|
87 |
+
return x
|
88 |
+
|
89 |
+
return apply
|
90 |
+
|
91 |
+
|
92 |
+
def UNetDecoderBlock(
|
93 |
+
num_channels: int,
|
94 |
+
block_size,
|
95 |
+
grid_size,
|
96 |
+
num_groups: int = 1,
|
97 |
+
lrelu_slope: float = 0.2,
|
98 |
+
block_gmlp_factor: int = 2,
|
99 |
+
grid_gmlp_factor: int = 2,
|
100 |
+
input_proj_factor: int = 2,
|
101 |
+
channels_reduction: int = 4,
|
102 |
+
dropout_rate: float = 0.0,
|
103 |
+
downsample: bool = True,
|
104 |
+
use_global_mlp: bool = True,
|
105 |
+
use_bias: bool = True,
|
106 |
+
name: str = "unet_decoder",
|
107 |
+
):
|
108 |
+
|
109 |
+
"""Decoder block in MAXIM."""
|
110 |
+
|
111 |
+
def apply(x, bridge=None):
|
112 |
+
x = ConvT_up(
|
113 |
+
filters=num_channels, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
|
114 |
+
)(x)
|
115 |
+
x = UNetEncoderBlock(
|
116 |
+
num_channels=num_channels,
|
117 |
+
num_groups=num_groups,
|
118 |
+
lrelu_slope=lrelu_slope,
|
119 |
+
block_size=block_size,
|
120 |
+
grid_size=grid_size,
|
121 |
+
block_gmlp_factor=block_gmlp_factor,
|
122 |
+
grid_gmlp_factor=grid_gmlp_factor,
|
123 |
+
channels_reduction=channels_reduction,
|
124 |
+
use_global_mlp=use_global_mlp,
|
125 |
+
dropout_rate=dropout_rate,
|
126 |
+
downsample=False,
|
127 |
+
use_bias=use_bias,
|
128 |
+
name=f"{name}_UNetEncoderBlock_0",
|
129 |
+
)(x, skip=bridge)
|
130 |
+
|
131 |
+
return x
|
132 |
+
|
133 |
+
return apply
|
maxim/configs.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MAXIM_CONFIGS = {
|
2 |
+
# params: 6.108515000000001 M, GFLOPS: 93.163716608
|
3 |
+
"S-1": {
|
4 |
+
"features": 32,
|
5 |
+
"depth": 3,
|
6 |
+
"num_stages": 1,
|
7 |
+
"num_groups": 2,
|
8 |
+
"num_bottleneck_blocks": 2,
|
9 |
+
"block_gmlp_factor": 2,
|
10 |
+
"grid_gmlp_factor": 2,
|
11 |
+
"input_proj_factor": 2,
|
12 |
+
"channels_reduction": 4,
|
13 |
+
"name": "s1",
|
14 |
+
},
|
15 |
+
# params: 13.35383 M, GFLOPS: 206.743273472
|
16 |
+
"S-2": {
|
17 |
+
"features": 32,
|
18 |
+
"depth": 3,
|
19 |
+
"num_stages": 2,
|
20 |
+
"num_groups": 2,
|
21 |
+
"num_bottleneck_blocks": 2,
|
22 |
+
"block_gmlp_factor": 2,
|
23 |
+
"grid_gmlp_factor": 2,
|
24 |
+
"input_proj_factor": 2,
|
25 |
+
"channels_reduction": 4,
|
26 |
+
"name": "s2",
|
27 |
+
},
|
28 |
+
# params: 20.599145 M, GFLOPS: 320.32194560000005
|
29 |
+
"S-3": {
|
30 |
+
"features": 32,
|
31 |
+
"depth": 3,
|
32 |
+
"num_stages": 3,
|
33 |
+
"num_groups": 2,
|
34 |
+
"num_bottleneck_blocks": 2,
|
35 |
+
"block_gmlp_factor": 2,
|
36 |
+
"grid_gmlp_factor": 2,
|
37 |
+
"input_proj_factor": 2,
|
38 |
+
"channels_reduction": 4,
|
39 |
+
"name": "s3",
|
40 |
+
},
|
41 |
+
# params: 19.361219000000002 M, 308.495712256 GFLOPs
|
42 |
+
"M-1": {
|
43 |
+
"features": 64,
|
44 |
+
"depth": 3,
|
45 |
+
"num_stages": 1,
|
46 |
+
"num_groups": 2,
|
47 |
+
"num_bottleneck_blocks": 2,
|
48 |
+
"block_gmlp_factor": 2,
|
49 |
+
"grid_gmlp_factor": 2,
|
50 |
+
"input_proj_factor": 2,
|
51 |
+
"channels_reduction": 4,
|
52 |
+
"name": "m1",
|
53 |
+
},
|
54 |
+
# params: 40.83911 M, 675.25541888 GFLOPs
|
55 |
+
"M-2": {
|
56 |
+
"features": 64,
|
57 |
+
"depth": 3,
|
58 |
+
"num_stages": 2,
|
59 |
+
"num_groups": 2,
|
60 |
+
"num_bottleneck_blocks": 2,
|
61 |
+
"block_gmlp_factor": 2,
|
62 |
+
"grid_gmlp_factor": 2,
|
63 |
+
"input_proj_factor": 2,
|
64 |
+
"channels_reduction": 4,
|
65 |
+
"name": "m2",
|
66 |
+
},
|
67 |
+
# params: 62.317001 M, 1042.014666752 GFLOPs
|
68 |
+
"M-3": {
|
69 |
+
"features": 64,
|
70 |
+
"depth": 3,
|
71 |
+
"num_stages": 3,
|
72 |
+
"num_groups": 2,
|
73 |
+
"num_bottleneck_blocks": 2,
|
74 |
+
"block_gmlp_factor": 2,
|
75 |
+
"grid_gmlp_factor": 2,
|
76 |
+
"input_proj_factor": 2,
|
77 |
+
"channels_reduction": 4,
|
78 |
+
"name": "m3",
|
79 |
+
},
|
80 |
+
}
|
maxim/layers.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import einops
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow.experimental import numpy as tnp
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras import layers
|
6 |
+
|
7 |
+
|
8 |
+
@tf.keras.utils.register_keras_serializable("maxim")
|
9 |
+
class BlockImages(layers.Layer):
|
10 |
+
def __init__(self, **kwargs):
|
11 |
+
super().__init__(**kwargs)
|
12 |
+
|
13 |
+
def call(self, x, patch_size):
|
14 |
+
bs, h, w, num_channels = (
|
15 |
+
K.int_shape(x)[0],
|
16 |
+
K.int_shape(x)[1],
|
17 |
+
K.int_shape(x)[2],
|
18 |
+
K.int_shape(x)[3],
|
19 |
+
)
|
20 |
+
|
21 |
+
grid_height, grid_width = h // patch_size[0], w // patch_size[1]
|
22 |
+
|
23 |
+
x = einops.rearrange(
|
24 |
+
x,
|
25 |
+
"n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
|
26 |
+
gh=grid_height,
|
27 |
+
gw=grid_width,
|
28 |
+
fh=patch_size[0],
|
29 |
+
fw=patch_size[1],
|
30 |
+
)
|
31 |
+
|
32 |
+
return x
|
33 |
+
|
34 |
+
def get_config(self):
|
35 |
+
config = super().get_config().copy()
|
36 |
+
return config
|
37 |
+
|
38 |
+
|
39 |
+
@tf.keras.utils.register_keras_serializable("maxim")
|
40 |
+
class UnblockImages(layers.Layer):
|
41 |
+
def __init__(self, **kwargs):
|
42 |
+
super().__init__(**kwargs)
|
43 |
+
|
44 |
+
def call(self, x, grid_size, patch_size):
|
45 |
+
x = einops.rearrange(
|
46 |
+
x,
|
47 |
+
"n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c",
|
48 |
+
gh=grid_size[0],
|
49 |
+
gw=grid_size[1],
|
50 |
+
fh=patch_size[0],
|
51 |
+
fw=patch_size[1],
|
52 |
+
)
|
53 |
+
|
54 |
+
return x
|
55 |
+
|
56 |
+
def get_config(self):
|
57 |
+
config = super().get_config().copy()
|
58 |
+
return config
|
59 |
+
|
60 |
+
|
61 |
+
@tf.keras.utils.register_keras_serializable("maxim")
|
62 |
+
class SwapAxes(layers.Layer):
|
63 |
+
def __init__(self, **kwargs):
|
64 |
+
super().__init__(**kwargs)
|
65 |
+
|
66 |
+
def call(self, x, axis_one, axis_two):
|
67 |
+
return tnp.swapaxes(x, axis_one, axis_two)
|
68 |
+
|
69 |
+
def get_config(self):
|
70 |
+
config = super().get_config().copy()
|
71 |
+
return config
|
72 |
+
|
73 |
+
|
74 |
+
@tf.keras.utils.register_keras_serializable("maxim")
|
75 |
+
class Resizing(layers.Layer):
|
76 |
+
def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
|
77 |
+
super().__init__(**kwargs)
|
78 |
+
self.height = height
|
79 |
+
self.width = width
|
80 |
+
self.antialias = antialias
|
81 |
+
self.method = method
|
82 |
+
|
83 |
+
def call(self, x):
|
84 |
+
return tf.image.resize(
|
85 |
+
x,
|
86 |
+
size=(self.height, self.width),
|
87 |
+
antialias=self.antialias,
|
88 |
+
method=self.method,
|
89 |
+
)
|
90 |
+
|
91 |
+
def get_config(self):
|
92 |
+
config = super().get_config().copy()
|
93 |
+
config.update(
|
94 |
+
{
|
95 |
+
"height": self.height,
|
96 |
+
"width": self.width,
|
97 |
+
"antialias": self.antialias,
|
98 |
+
"method": self.method,
|
99 |
+
}
|
100 |
+
)
|
101 |
+
return config
|
maxim/maxim.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras import layers
|
6 |
+
|
7 |
+
from .blocks.attentions import SAM
|
8 |
+
from .blocks.bottleneck import BottleneckBlock
|
9 |
+
from .blocks.misc_gating import CrossGatingBlock
|
10 |
+
from .blocks.others import UpSampleRatio
|
11 |
+
from .blocks.unet import UNetDecoderBlock, UNetEncoderBlock
|
12 |
+
from .layers import Resizing
|
13 |
+
|
14 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
15 |
+
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
|
16 |
+
ConvT_up = functools.partial(
|
17 |
+
layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
|
18 |
+
)
|
19 |
+
Conv_down = functools.partial(
|
20 |
+
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
def MAXIM(
|
25 |
+
features: int = 64,
|
26 |
+
depth: int = 3,
|
27 |
+
num_stages: int = 2,
|
28 |
+
num_groups: int = 1,
|
29 |
+
use_bias: bool = True,
|
30 |
+
num_supervision_scales: int = 1,
|
31 |
+
lrelu_slope: float = 0.2,
|
32 |
+
use_global_mlp: bool = True,
|
33 |
+
use_cross_gating: bool = True,
|
34 |
+
high_res_stages: int = 2,
|
35 |
+
block_size_hr=(16, 16),
|
36 |
+
block_size_lr=(8, 8),
|
37 |
+
grid_size_hr=(16, 16),
|
38 |
+
grid_size_lr=(8, 8),
|
39 |
+
num_bottleneck_blocks: int = 1,
|
40 |
+
block_gmlp_factor: int = 2,
|
41 |
+
grid_gmlp_factor: int = 2,
|
42 |
+
input_proj_factor: int = 2,
|
43 |
+
channels_reduction: int = 4,
|
44 |
+
num_outputs: int = 3,
|
45 |
+
dropout_rate: float = 0.0,
|
46 |
+
):
|
47 |
+
"""The MAXIM model function with multi-stage and multi-scale supervision.
|
48 |
+
|
49 |
+
For more model details, please check the CVPR paper:
|
50 |
+
MAXIM: MUlti-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)
|
51 |
+
|
52 |
+
Attributes:
|
53 |
+
features: initial hidden dimension for the input resolution.
|
54 |
+
depth: the number of downsampling depth for the model.
|
55 |
+
num_stages: how many stages to use. It will also affects the output list.
|
56 |
+
num_groups: how many blocks each stage contains.
|
57 |
+
use_bias: whether to use bias in all the conv/mlp layers.
|
58 |
+
num_supervision_scales: the number of desired supervision scales.
|
59 |
+
lrelu_slope: the negative slope parameter in leaky_relu layers.
|
60 |
+
use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in each
|
61 |
+
layer.
|
62 |
+
use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
|
63 |
+
skip connections and multi-stage feature fusion layers.
|
64 |
+
high_res_stages: how many stages are specificied as high-res stages. The
|
65 |
+
rest (depth - high_res_stages) are called low_res_stages.
|
66 |
+
block_size_hr: the block_size parameter for high-res stages.
|
67 |
+
block_size_lr: the block_size parameter for low-res stages.
|
68 |
+
grid_size_hr: the grid_size parameter for high-res stages.
|
69 |
+
grid_size_lr: the grid_size parameter for low-res stages.
|
70 |
+
num_bottleneck_blocks: how many bottleneck blocks.
|
71 |
+
block_gmlp_factor: the input projection factor for block_gMLP layers.
|
72 |
+
grid_gmlp_factor: the input projection factor for grid_gMLP layers.
|
73 |
+
input_proj_factor: the input projection factor for the MAB block.
|
74 |
+
channels_reduction: the channel reduction factor for SE layer.
|
75 |
+
num_outputs: the output channels.
|
76 |
+
dropout_rate: Dropout rate.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
The output contains a list of arrays consisting of multi-stage multi-scale
|
80 |
+
outputs. For example, if num_stages = num_supervision_scales = 3 (the
|
81 |
+
model used in the paper), the output specs are: outputs =
|
82 |
+
[[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
|
83 |
+
[output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
|
84 |
+
[output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]
|
85 |
+
The final output can be retrieved by outputs[-1][-1].
|
86 |
+
"""
|
87 |
+
|
88 |
+
def apply(x):
|
89 |
+
n, h, w, c = (
|
90 |
+
K.int_shape(x)[0],
|
91 |
+
K.int_shape(x)[1],
|
92 |
+
K.int_shape(x)[2],
|
93 |
+
K.int_shape(x)[3],
|
94 |
+
) # input image shape
|
95 |
+
|
96 |
+
shortcuts = []
|
97 |
+
shortcuts.append(x)
|
98 |
+
|
99 |
+
# Get multi-scale input images
|
100 |
+
for i in range(1, num_supervision_scales):
|
101 |
+
resizing_layer = Resizing(
|
102 |
+
height=h // (2 ** i),
|
103 |
+
width=w // (2 ** i),
|
104 |
+
method="nearest",
|
105 |
+
antialias=True, # Following `jax.image.resize()`.
|
106 |
+
name=f"initial_resizing_{K.get_uid('Resizing')}",
|
107 |
+
)
|
108 |
+
shortcuts.append(resizing_layer(x))
|
109 |
+
|
110 |
+
# store outputs from all stages and all scales
|
111 |
+
# Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)], # Stage-1 outputs
|
112 |
+
# [(64, 64, 3), (128, 128, 3), (256, 256, 3)],] # Stage-2 outputs
|
113 |
+
outputs_all = []
|
114 |
+
sam_features, encs_prev, decs_prev = [], [], []
|
115 |
+
|
116 |
+
for idx_stage in range(num_stages):
|
117 |
+
# Input convolution, get multi-scale input features
|
118 |
+
x_scales = []
|
119 |
+
for i in range(num_supervision_scales):
|
120 |
+
x_scale = Conv3x3(
|
121 |
+
filters=(2 ** i) * features,
|
122 |
+
use_bias=use_bias,
|
123 |
+
name=f"stage_{idx_stage}_input_conv_{i}",
|
124 |
+
)(shortcuts[i])
|
125 |
+
|
126 |
+
# If later stages, fuse input features with SAM features from prev stage
|
127 |
+
if idx_stage > 0:
|
128 |
+
# use larger blocksize at high-res stages
|
129 |
+
if use_cross_gating:
|
130 |
+
block_size = (
|
131 |
+
block_size_hr if i < high_res_stages else block_size_lr
|
132 |
+
)
|
133 |
+
grid_size = grid_size_hr if i < high_res_stages else block_size_lr
|
134 |
+
x_scale, _ = CrossGatingBlock(
|
135 |
+
features=(2 ** i) * features,
|
136 |
+
block_size=block_size,
|
137 |
+
grid_size=grid_size,
|
138 |
+
dropout_rate=dropout_rate,
|
139 |
+
input_proj_factor=input_proj_factor,
|
140 |
+
upsample_y=False,
|
141 |
+
use_bias=use_bias,
|
142 |
+
name=f"stage_{idx_stage}_input_fuse_sam_{i}",
|
143 |
+
)(x_scale, sam_features.pop())
|
144 |
+
else:
|
145 |
+
x_scale = Conv1x1(
|
146 |
+
filters=(2 ** i) * features,
|
147 |
+
use_bias=use_bias,
|
148 |
+
name=f"stage_{idx_stage}_input_catconv_{i}",
|
149 |
+
)(tf.concat([x_scale, sam_features.pop()], axis=-1))
|
150 |
+
|
151 |
+
x_scales.append(x_scale)
|
152 |
+
|
153 |
+
# start encoder blocks
|
154 |
+
encs = []
|
155 |
+
x = x_scales[0] # First full-scale input feature
|
156 |
+
|
157 |
+
for i in range(depth): # 0, 1, 2
|
158 |
+
# use larger blocksize at high-res stages, vice versa.
|
159 |
+
block_size = block_size_hr if i < high_res_stages else block_size_lr
|
160 |
+
grid_size = grid_size_hr if i < high_res_stages else block_size_lr
|
161 |
+
use_cross_gating_layer = True if idx_stage > 0 else False
|
162 |
+
|
163 |
+
# Multi-scale input if multi-scale supervision
|
164 |
+
x_scale = x_scales[i] if i < num_supervision_scales else None
|
165 |
+
|
166 |
+
# UNet Encoder block
|
167 |
+
enc_prev = encs_prev.pop() if idx_stage > 0 else None
|
168 |
+
dec_prev = decs_prev.pop() if idx_stage > 0 else None
|
169 |
+
|
170 |
+
x, bridge = UNetEncoderBlock(
|
171 |
+
num_channels=(2 ** i) * features,
|
172 |
+
num_groups=num_groups,
|
173 |
+
downsample=True,
|
174 |
+
lrelu_slope=lrelu_slope,
|
175 |
+
block_size=block_size,
|
176 |
+
grid_size=grid_size,
|
177 |
+
block_gmlp_factor=block_gmlp_factor,
|
178 |
+
grid_gmlp_factor=grid_gmlp_factor,
|
179 |
+
input_proj_factor=input_proj_factor,
|
180 |
+
channels_reduction=channels_reduction,
|
181 |
+
use_global_mlp=use_global_mlp,
|
182 |
+
dropout_rate=dropout_rate,
|
183 |
+
use_bias=use_bias,
|
184 |
+
use_cross_gating=use_cross_gating_layer,
|
185 |
+
name=f"stage_{idx_stage}_encoder_block_{i}",
|
186 |
+
)(x, skip=x_scale, enc=enc_prev, dec=dec_prev)
|
187 |
+
|
188 |
+
# Cache skip signals
|
189 |
+
encs.append(bridge)
|
190 |
+
|
191 |
+
# Global MLP bottleneck blocks
|
192 |
+
for i in range(num_bottleneck_blocks):
|
193 |
+
x = BottleneckBlock(
|
194 |
+
block_size=block_size_lr,
|
195 |
+
grid_size=block_size_lr,
|
196 |
+
features=(2 ** (depth - 1)) * features,
|
197 |
+
num_groups=num_groups,
|
198 |
+
block_gmlp_factor=block_gmlp_factor,
|
199 |
+
grid_gmlp_factor=grid_gmlp_factor,
|
200 |
+
input_proj_factor=input_proj_factor,
|
201 |
+
dropout_rate=dropout_rate,
|
202 |
+
use_bias=use_bias,
|
203 |
+
channels_reduction=channels_reduction,
|
204 |
+
name=f"stage_{idx_stage}_global_block_{i}",
|
205 |
+
)(x)
|
206 |
+
# cache global feature for cross-gating
|
207 |
+
global_feature = x
|
208 |
+
|
209 |
+
# start cross gating. Use multi-scale feature fusion
|
210 |
+
skip_features = []
|
211 |
+
for i in reversed(range(depth)): # 2, 1, 0
|
212 |
+
# use larger blocksize at high-res stages
|
213 |
+
block_size = block_size_hr if i < high_res_stages else block_size_lr
|
214 |
+
grid_size = grid_size_hr if i < high_res_stages else block_size_lr
|
215 |
+
|
216 |
+
# get additional multi-scale signals
|
217 |
+
signal = tf.concat(
|
218 |
+
[
|
219 |
+
UpSampleRatio(
|
220 |
+
num_channels=(2 ** i) * features,
|
221 |
+
ratio=2 ** (j - i),
|
222 |
+
use_bias=use_bias,
|
223 |
+
name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
|
224 |
+
)(enc)
|
225 |
+
for j, enc in enumerate(encs)
|
226 |
+
],
|
227 |
+
axis=-1,
|
228 |
+
)
|
229 |
+
|
230 |
+
# Use cross-gating to cross modulate features
|
231 |
+
if use_cross_gating:
|
232 |
+
skips, global_feature = CrossGatingBlock(
|
233 |
+
features=(2 ** i) * features,
|
234 |
+
block_size=block_size,
|
235 |
+
grid_size=grid_size,
|
236 |
+
input_proj_factor=input_proj_factor,
|
237 |
+
dropout_rate=dropout_rate,
|
238 |
+
upsample_y=True,
|
239 |
+
use_bias=use_bias,
|
240 |
+
name=f"stage_{idx_stage}_cross_gating_block_{i}",
|
241 |
+
)(signal, global_feature)
|
242 |
+
else:
|
243 |
+
skips = Conv1x1(
|
244 |
+
filters=(2 ** i) * features, use_bias=use_bias, name="Conv_0"
|
245 |
+
)(signal)
|
246 |
+
skips = Conv3x3(
|
247 |
+
filters=(2 ** i) * features, use_bias=use_bias, name="Conv_1"
|
248 |
+
)(skips)
|
249 |
+
|
250 |
+
skip_features.append(skips)
|
251 |
+
|
252 |
+
# start decoder. Multi-scale feature fusion of cross-gated features
|
253 |
+
outputs, decs, sam_features = [], [], []
|
254 |
+
for i in reversed(range(depth)):
|
255 |
+
# use larger blocksize at high-res stages
|
256 |
+
block_size = block_size_hr if i < high_res_stages else block_size_lr
|
257 |
+
grid_size = grid_size_hr if i < high_res_stages else block_size_lr
|
258 |
+
|
259 |
+
# get multi-scale skip signals from cross-gating block
|
260 |
+
signal = tf.concat(
|
261 |
+
[
|
262 |
+
UpSampleRatio(
|
263 |
+
num_channels=(2 ** i) * features,
|
264 |
+
ratio=2 ** (depth - j - 1 - i),
|
265 |
+
use_bias=use_bias,
|
266 |
+
name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
|
267 |
+
)(skip)
|
268 |
+
for j, skip in enumerate(skip_features)
|
269 |
+
],
|
270 |
+
axis=-1,
|
271 |
+
)
|
272 |
+
|
273 |
+
# Decoder block
|
274 |
+
x = UNetDecoderBlock(
|
275 |
+
num_channels=(2 ** i) * features,
|
276 |
+
num_groups=num_groups,
|
277 |
+
lrelu_slope=lrelu_slope,
|
278 |
+
block_size=block_size,
|
279 |
+
grid_size=grid_size,
|
280 |
+
block_gmlp_factor=block_gmlp_factor,
|
281 |
+
grid_gmlp_factor=grid_gmlp_factor,
|
282 |
+
input_proj_factor=input_proj_factor,
|
283 |
+
channels_reduction=channels_reduction,
|
284 |
+
use_global_mlp=use_global_mlp,
|
285 |
+
dropout_rate=dropout_rate,
|
286 |
+
use_bias=use_bias,
|
287 |
+
name=f"stage_{idx_stage}_decoder_block_{i}",
|
288 |
+
)(x, bridge=signal)
|
289 |
+
|
290 |
+
# Cache decoder features for later-stage's usage
|
291 |
+
decs.append(x)
|
292 |
+
|
293 |
+
# output conv, if not final stage, use supervised-attention-block.
|
294 |
+
if i < num_supervision_scales:
|
295 |
+
if idx_stage < num_stages - 1: # not last stage, apply SAM
|
296 |
+
sam, output = SAM(
|
297 |
+
num_channels=(2 ** i) * features,
|
298 |
+
output_channels=num_outputs,
|
299 |
+
use_bias=use_bias,
|
300 |
+
name=f"stage_{idx_stage}_supervised_attention_module_{i}",
|
301 |
+
)(x, shortcuts[i])
|
302 |
+
outputs.append(output)
|
303 |
+
sam_features.append(sam)
|
304 |
+
else: # Last stage, apply output convolutions
|
305 |
+
output = Conv3x3(
|
306 |
+
num_outputs,
|
307 |
+
use_bias=use_bias,
|
308 |
+
name=f"stage_{idx_stage}_output_conv_{i}",
|
309 |
+
)(x)
|
310 |
+
output = output + shortcuts[i]
|
311 |
+
outputs.append(output)
|
312 |
+
# Cache encoder and decoder features for later-stage's usage
|
313 |
+
encs_prev = encs[::-1]
|
314 |
+
decs_prev = decs
|
315 |
+
|
316 |
+
# Store outputs
|
317 |
+
outputs_all.append(outputs)
|
318 |
+
return outputs_all
|
319 |
+
|
320 |
+
return apply
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
tensorflow==2.10.0
|
2 |
+
einops
|
3 |
+
numpy
|