import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package="gcvit")
class WindowAttention(tf.keras.layers.Layer):
    """Local window attention with a learned relative position bias; optionally
    driven by a global query (GCViT) instead of a query computed from the window tokens."""

    def __init__(self, window_size, num_heads, global_query, qkv_bias=True, qk_scale=None,
                 attn_dropout=0., proj_dropout=0., **kwargs):
        super().__init__(**kwargs)
        # Accept an int or an (h, w) pair so the layer round-trips through get_config(),
        # which stores the tuple form.
        window_size = (window_size, window_size) if isinstance(window_size, int) else tuple(window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        self.global_query = global_query
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
    def build(self, input_shape):
        dim = input_shape[0][-1]
        head_dim = dim // self.num_heads
        self.scale = self.qk_scale or head_dim ** -0.5
        # With a global query, only k and v are projected here (q comes from q_global),
        # so the fused projection produces 2 tensors instead of 3.
        self.qkv_size = 3 - int(self.global_query)
        self.qkv = tf.keras.layers.Dense(dim * self.qkv_size, use_bias=self.qkv_bias, name='qkv')
        # One learnable bias per relative (dy, dx) offset within the window, per head.
        self.relative_position_bias_table = self.add_weight(
            name='relative_position_bias_table',
            shape=[(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads],
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            trainable=True,
            dtype=self.dtype)
        self.attn_drop = tf.keras.layers.Dropout(self.attn_dropout, name='attn_drop')
        self.proj = tf.keras.layers.Dense(dim, name='proj')
        self.proj_drop = tf.keras.layers.Dropout(self.proj_dropout, name='proj_drop')
        self.softmax = tf.keras.layers.Activation('softmax', name='softmax')
        self.relative_position_index = self.get_relative_position_index()
        super().build(input_shape)
    def get_relative_position_index(self):
        # Pairwise relative (dy, dx) offsets between all token positions in a window,
        # linearized into indices into relative_position_bias_table.
        coords_h = tf.range(self.window_size[0])
        coords_w = tf.range(self.window_size[1])
        coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing='ij'), axis=0)
        coords_flatten = tf.reshape(coords, [2, -1])
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])
        # Shift offsets so they start at 0, then flatten the (dy, dx) pair into one index.
        relative_coords_xx = relative_coords[:, :, 0] + self.window_size[0] - 1
        relative_coords_yy = relative_coords[:, :, 1] + self.window_size[1] - 1
        relative_coords_xx = relative_coords_xx * (2 * self.window_size[1] - 1)
        relative_position_index = relative_coords_xx + relative_coords_yy
        return relative_position_index
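
    # Worked example (illustrative, not part of the original file): with
    # window_size=(2, 2) the 4 tokens give relative offsets dy, dx in {-1, 0, 1};
    # after shifting by 1 and linearizing with stride 2*2 - 1 = 3, indices fall in
    # [0, 8], matching the (2*2-1) * (2*2-1) = 9 rows of the bias table.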
    def call(self, inputs, **kwargs):
        if self.global_query:
            inputs, q_global = inputs
            B = tf.shape(q_global)[0]  # q_global: (B, N, C), one query per image
        else:
            inputs = inputs[0]
        B_, N, C = tf.unstack(tf.shape(inputs), num=3)  # B*num_windows, num_tokens, channels
        qkv = self.qkv(inputs)
        qkv = tf.reshape(qkv, [B_, N, self.qkv_size, self.num_heads, C // self.num_heads])
        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
        if self.global_query:
            k, v = tf.unstack(qkv, num=2, axis=0)  # num must be static; num=None fails for unknown shapes
            # num_windows = B_ // B, so the same q_global is repeated for every window of an image.
            q_global = tf.repeat(q_global, repeats=B_ // B, axis=0)
            q = tf.reshape(q_global, shape=[B_, N, self.num_heads, C // self.num_heads])
            q = tf.transpose(q, perm=[0, 2, 1, 3])
        else:
            q, k, v = tf.unstack(qkv, num=3, axis=0)
        q = q * self.scale
        attn = q @ tf.transpose(k, perm=[0, 1, 3, 2])
        # Gather the learned bias for every (query, key) position pair and add it to the logits.
        relative_position_bias = tf.gather(self.relative_position_bias_table,
                                           tf.reshape(self.relative_position_index, shape=[-1]))
        relative_position_bias = tf.reshape(
            relative_position_bias,
            shape=[self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1])
        relative_position_bias = tf.transpose(relative_position_bias, perm=[2, 0, 1])
        attn = attn + relative_position_bias[tf.newaxis]
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = tf.transpose(attn @ v, perm=[0, 2, 1, 3])  # B_, num_tokens, num_heads, channels_per_head
        x = tf.reshape(x, shape=[B_, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    def get_config(self):
        config = super().get_config()
        config.update({
            'window_size': self.window_size,
            'num_heads': self.num_heads,
            'global_query': self.global_query,
            'qkv_bias': self.qkv_bias,
            'qk_scale': self.qk_scale,
            'attn_dropout': self.attn_dropout,
            'proj_dropout': self.proj_dropout,
        })
        return config
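

# A minimal smoke test, a sketch only: the shapes below (14x14 windows, 96 channels,
# 3 heads, 4 windows per image) are illustrative assumptions, not values fixed by this file.
if __name__ == "__main__":
    windows = tf.random.normal([8, 14 * 14, 96])  # (B*num_windows, num_tokens, C)

    # Plain window attention: q, k, v all come from the window tokens.
    attn = WindowAttention(window_size=14, num_heads=3, global_query=False)
    print(attn([windows]).shape)  # (8, 196, 96)

    # Global-query variant: q is supplied per image and broadcast to its windows.
    q_global = tf.random.normal([2, 14 * 14, 96])  # (B, N, C); here num_windows = 8 // 2 = 4
    g_attn = WindowAttention(window_size=14, num_heads=3, global_query=True)
    print(g_attn([windows, q_global]).shape)  # (8, 196, 96)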