Spaces:
Runtime error
Runtime error
File size: 4,090 Bytes
3f9d71f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import functools
import tensorflow as tf
from tensorflow.keras import layers
from .others import MlpBlock
# Pre-configured "same"-padded Conv2D factories with a fixed kernel size.
# Callers still supply `filters`, `use_bias`, `name`, ... when instantiating.
Conv3x3 = functools.partial(layers.Conv2D, padding="same", kernel_size=(3, 3))
Conv1x1 = functools.partial(layers.Conv2D, padding="same", kernel_size=(1, 1))
def CALayer(
    num_channels: int,
    reduction: int = 4,
    use_bias: bool = True,
    name: str = "channel_attention",
):
    """Squeeze-and-excitation block for channel attention.

    ref: https://arxiv.org/abs/1709.01507

    Args:
        num_channels: Number of filters of the excitation (second) 1x1 conv.
            Its output is multiplied elementwise with the input, so this
            should match the input's channel count.
        reduction: Divisor applied to ``num_channels`` to obtain the number of
            filters of the squeeze (first) 1x1 conv.
        use_bias: Whether the two 1x1 convolutions use a bias term.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` that returns ``x`` rescaled by sigmoid
        channel-attention weights computed from ``x``.

    Raises:
        ValueError: If ``reduction`` is not positive, or if
            ``num_channels // reduction`` is not positive (the squeeze conv
            would have no filters).
    """
    # Validate eagerly so a bad configuration fails at construction time
    # instead of deep inside the Keras layer build.
    if reduction <= 0:
        raise ValueError(f"reduction must be positive, got {reduction}.")
    squeeze_channels = num_channels // reduction
    if squeeze_channels <= 0:
        raise ValueError(
            f"num_channels // reduction must be positive, got "
            f"{num_channels} // {reduction} = {squeeze_channels}."
        )

    def apply(x):
        # 2D global average pooling; keepdims=True keeps the pooled spatial
        # axes (size 1) so the attention weights broadcast against x below.
        y = layers.GlobalAvgPool2D(keepdims=True)(x)
        # Squeeze (in Squeeze-Excitation)
        y = Conv1x1(
            filters=squeeze_channels, use_bias=use_bias, name=f"{name}_Conv_0"
        )(y)
        y = tf.nn.relu(y)
        # Excitation (in Squeeze-Excitation)
        y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
        y = tf.nn.sigmoid(y)
        # Rescale the input channel-wise by the attention weights.
        return x * y

    return apply
def RCAB(
    num_channels: int,
    reduction: int = 4,
    lrelu_slope: float = 0.2,
    use_bias: bool = True,
    name: str = "residual_ca",
):
    """Residual channel attention block.

    Pipeline: LayerNorm -> Conv3x3 -> LeakyReLU -> Conv3x3 -> channel
    attention (squeeze-and-excitation), wrapped in an identity skip.

    Args:
        num_channels: Filters of both 3x3 convolutions and of the attention
            layer; the skip connection adds the result to the input, so this
            should match the input's channel count.
        reduction: Channel reduction factor forwarded to ``CALayer``.
        lrelu_slope: Negative slope of the leaky ReLU.
        use_bias: Whether the convolutions use a bias term.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` computing ``block(x) + x``.
    """

    def _conv(suffix):
        # Fresh 3x3 conv layer with this block's width, bias and name prefix.
        return Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_{suffix}")

    def apply(x):
        h = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        h = _conv("conv1")(h)
        h = tf.nn.leaky_relu(h, alpha=lrelu_slope)
        h = _conv("conv2")(h)
        attention = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )
        # Identity skip around the whole block.
        return attention(h) + x

    return apply
def RDCAB(
    num_channels: int,
    reduction: int = 16,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "rdcab",
):
    """Residual dense channel attention block, used in the bottlenecks.

    Pipeline: LayerNorm -> MlpBlock (channel mixing) -> channel attention,
    with the result added back onto the input.

    Args:
        num_channels: Hidden width passed to ``MlpBlock`` as ``mlp_dim`` and
            channel count of the attention layer.
        reduction: Channel reduction factor forwarded to ``CALayer``.
        use_bias: Whether the inner layers use a bias term.
        dropout_rate: Dropout rate forwarded to ``MlpBlock``.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` computing ``x + block(x)``.
    """

    def apply(x):
        normed = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        mixing = MlpBlock(
            mlp_dim=num_channels,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_channel_mixing",
        )
        attention = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )
        return x + attention(mixing(normed))

    return apply
def SAM(
    num_channels: int,
    output_channels: int = 3,
    use_bias: bool = True,
    name: str = "sam",
):
    """Supervised attention module for multi-stage training.

    Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet

    Args:
        num_channels: Number of feature channels produced by the feature and
            attention convolutions. The attended features are added back to
            ``x``, so this should equal the channel count of ``x``.
        output_channels: Channels of the restored image for this stage. When
            it is 3, the input image is added to the predicted image as a
            residual.
        use_bias: Whether the 3x3 convolutions use a bias term.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x, x_image)`` returning ``(features, image)``.
    """

    def apply(x, x_image):
        """Apply the SAM module to the decoder features and the input image.

        Args:
            x: the output features from the UNet decoder with shape (h, w, c).
            x_image: the input image with shape (h, w, 3).

        Returns:
            A tuple ``(x1, image)`` where ``x1`` holds the SAM features used for
            the next stage, and ``image`` is the restored image at the current
            stage.
        """
        # Features forwarded to the next stage (before attention).
        x1 = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)

        # Restored image X_s for this stage.
        image = Conv3x3(
            filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
        )(x)
        if output_channels == 3:
            # Residual image prediction: only valid when the channel counts match.
            image = image + x_image

        # Attention maps derived from the restored image.
        x2 = tf.nn.sigmoid(
            Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
        )

        # Attended features plus a residual connection to the decoder output.
        x1 = x1 * x2
        x1 = x1 + x
        return x1, image

    return apply
|