File size: 7,634 Bytes
3126b1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import tensorflow as tf
import tensorflow_addons as tfa
# Height / width axis indices for channels-last image tensors. They are
# counted from the end so the same values index both batched (N, H, W, C)
# and unbatched (H, W, C) shapes.
H_AXIS, W_AXIS = -3, -2
@tf.keras.utils.register_keras_serializable(package="gcvit")
class Mlp(tf.keras.layers.Layer):
    """Two-layer feed-forward block: Dense -> activation -> Dropout -> Dense -> Dropout.

    Args:
        hidden_features: Units of the first Dense layer. When falsy it is
            resolved to the input's last dimension in ``build``.
        out_features: Units of the second Dense layer. When falsy it is
            resolved to the input's last dimension in ``build``.
        act_layer: Activation identifier handed to ``tf.keras.layers.Activation``.
        dropout: Rate shared by both Dropout layers.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, hidden_features=None, out_features=None, act_layer='gelu', dropout=0., **kwargs):
        super().__init__(**kwargs)
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.act_layer = act_layer
        self.dropout = dropout

    def build(self, input_shape):
        in_dim = input_shape[-1]
        self.in_features = in_dim
        # `or` makes both None and 0 fall back to the input width. The resolved
        # sizes are stored back on the layer, so get_config() reports them.
        self.hidden_features = self.hidden_features or in_dim
        self.out_features = self.out_features or in_dim
        # NOTE(review): sublayer assignment order likely determines the order of
        # self.weights (relevant when loading weight files); keep it stable.
        self.fc1 = tf.keras.layers.Dense(self.hidden_features, name="fc1")
        self.act = tf.keras.layers.Activation(self.act_layer, name="act")
        self.fc2 = tf.keras.layers.Dense(self.out_features, name="fc2")
        self.drop1 = tf.keras.layers.Dropout(self.dropout, name="drop1")
        self.drop2 = tf.keras.layers.Dropout(self.dropout, name="drop2")
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        # Stages run strictly in this order: fc1, act, drop1, fc2, drop2.
        out = inputs
        for stage in (self.fc1, self.act, self.drop1, self.fc2, self.drop2):
            out = stage(out)
        return out

    def get_config(self):
        return {
            **super().get_config(),
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "act_layer": self.act_layer,
            "dropout": self.dropout,
        }
@tf.keras.utils.register_keras_serializable(package="gcvit")
class SE(tf.keras.layers.Layer):
    """Squeeze-and-excitation gate for channels-last 4-D tensors.

    The layer average-pools every channel down to a single value. It passes
    the pooled vector through Dense -> GELU -> Dense -> sigmoid, then
    multiplies the input by the resulting per-channel weights.

    Args:
        oup: Units of the last Dense layer; defaults to the input channel
            count in ``build``. ``call`` reshapes the gate back to the input
            channel count, so in practice this must equal that count.
        expansion: Bottleneck ratio; the first Dense layer has
            ``int(channels * expansion)`` units.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, oup=None, expansion=0.25, **kwargs):
        super().__init__(**kwargs)
        self.expansion = expansion
        self.oup = oup

    def build(self, input_shape):
        channels = input_shape[-1]
        self.oup = self.oup or channels
        squeezed = int(channels * self.expansion)
        self.avg_pool = tfa.layers.AdaptiveAveragePooling2D(1, name="avg_pool")
        self.fc = [
            tf.keras.layers.Dense(squeezed, use_bias=False, name='fc/0'),
            tf.keras.layers.Activation('gelu', name='fc/1'),
            tf.keras.layers.Dense(self.oup, use_bias=False, name='fc/2'),
            tf.keras.layers.Activation('sigmoid', name='fc/3'),
        ]
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        batch, _, _, channels = tf.unstack(tf.shape(inputs), num=4)
        # Pooling to a 1x1 map, then flattening to (batch, channels).
        gate = tf.reshape(self.avg_pool(inputs), (batch, channels))
        for layer in self.fc:
            gate = layer(gate)
        # Restore singleton spatial dims so the gate broadcasts over H and W.
        gate = tf.reshape(gate, (batch, 1, 1, channels))
        return gate * inputs

    def get_config(self):
        return {
            **super().get_config(),
            'expansion': self.expansion,
            'oup': self.oup,
        }
@tf.keras.utils.register_keras_serializable(package="gcvit")
class ReduceSize(tf.keras.layers.Layer):
    """Downsampling block for channels-last feature maps.

    The steps are:
    1. LayerNorm.
    2. A residual branch: depthwise 3x3 conv, GELU, SE, pointwise conv.
    3. A 3x3 stride-2 conv on the 1-pixel zero-padded sum.
    4. LayerNorm.

    Spatial size goes from (H, W) to (ceil(H/2), ceil(W/2)).

    Args:
        keep_dim: If True the output keeps the input channel count; if False
            (default) the channel count is doubled.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, keep_dim=False, **kwargs):
        super().__init__(**kwargs)
        self.keep_dim = keep_dim

    def build(self, input_shape):
        dim = input_shape[-1]
        if self.keep_dim:
            dim_out = dim
        else:
            dim_out = 2 * dim
        # 3x3 convs with 1-pixel padding are expressed as explicit zero
        # padding followed by 'valid' convolution.
        self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
        self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
        self.conv = [
            tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
            tf.keras.layers.Activation('gelu', name='conv/1'),
            SE(name='conv/2'),
            tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3'),
        ]
        self.reduction = tf.keras.layers.Conv2D(dim_out, kernel_size=3, strides=2, padding='valid', use_bias=False,
                                                name='reduction')
        # epsilon=1e-05 mirrors PyTorch's LayerNorm default.
        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')
        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        normed = self.norm1(inputs)
        branch = self.pad1(normed)
        for layer in self.conv:
            branch = layer(branch)
        merged = normed + branch
        reduced = self.reduction(self.pad2(merged))
        return self.norm2(reduced)

    def get_config(self):
        return {
            **super().get_config(),
            "keep_dim": self.keep_dim,
        }
@tf.keras.utils.register_keras_serializable(package="gcvit")
class FeatExtract(tf.keras.layers.Layer):
    """Feature-extraction block for channels-last 4-D feature maps.

    The block first applies a residual branch: depthwise 3x3 conv, GELU,
    SE, pointwise conv, added back to the input. Unless ``keep_dim`` is set,
    it then downsamples with a 3x3 / stride-2 max-pool over 1-pixel padding.
    That takes spatial size (H, W) to (ceil(H/2), ceil(W/2)).

    Args:
        keep_dim: If True, skip the pooling step so the spatial size is
            unchanged. If False (default), downsample as described above.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, keep_dim=False, **kwargs):
        super().__init__(**kwargs)
        self.keep_dim = keep_dim

    def build(self, input_shape):
        dim = input_shape[-1]
        # Zero padding + 'valid' depthwise conv keeps the branch output the
        # same spatial size as its input, so the residual add lines up.
        self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
        self.conv = [
            tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
            tf.keras.layers.Activation('gelu', name='conv/1'),
            SE(name='conv/2'),
            tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3'),
        ]
        if not self.keep_dim:
            self.pool = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='valid', name='pool')
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = inputs
        xr = self.pad1(x)
        for layer in self.conv:
            xr = layer(xr)
        x = x + xr
        if not self.keep_dim:
            # Pad with the dtype's most negative finite value, not zero, so a
            # padded cell can never be a window's maximum. With zero padding,
            # a border window whose real values were all negative returned 0.
            # This matches max-pool padding semantics (e.g. PyTorch
            # MaxPool2d(3, 2, padding=1) pads with -inf).
            x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], constant_values=x.dtype.min)
            x = self.pool(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "keep_dim": self.keep_dim,
        })
        return config
@tf.keras.utils.register_keras_serializable(package="gcvit")
class Resizing(tf.keras.layers.Layer):
    """Resize images to a fixed ``(height, width)`` with ``tf.image.resize``.

    The result is cast to the layer's compute dtype.

    Args:
        height: Target height in pixels.
        width: Target width in pixels.
        interpolation: ``method`` string passed to ``tf.image.resize``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self,
                 height,
                 width,
                 interpolation='bilinear',
                 **kwargs):
        self.height = height
        self.width = width
        self.interpolation = interpolation
        super().__init__(**kwargs)

    def call(self, inputs):
        # tf.image.resize always returns float32 (and runs most efficiently on
        # float32). The exception is 'nearest', whose output dtype follows its
        # input dtype. Only 'nearest' is therefore fed the compute dtype directly.
        work_dtype = self.compute_dtype if self.interpolation == 'nearest' else tf.float32
        resized = tf.image.resize(
            tf.cast(inputs, dtype=work_dtype),
            size=[self.height, self.width],
            method=self.interpolation)
        return tf.cast(resized, self.compute_dtype)

    def compute_output_shape(self, input_shape):
        dims = tf.TensorShape(input_shape).as_list()
        dims[H_AXIS], dims[W_AXIS] = self.height, self.width
        return tf.TensorShape(dims)

    def get_config(self):
        return {
            **super().get_config(),
            'height': self.height,
            'width': self.width,
            'interpolation': self.interpolation,
        }