# HybridModel-GradCAM / layers / window_attention.py
# (innat, commit 0f09377 "init")
import tensorflow as tf
from tensorflow.keras import layers
class WindowAttention(layers.Layer):
    """Window-based multi-head self-attention with a learned relative position bias.

    Tokens inside a window attend to each other. A trainable table holding one
    value per head for every possible relative (row, column) offset inside a
    ``window_height x window_width`` window is gathered into an
    ``(N, N, num_heads)`` bias (``N = window_height * window_width``) and added
    to the attention logits before the softmax.

    Args:
        dim: Channel width of the fused qkv projection (``3 * dim``) and of the
            output projection. ``dim`` should be divisible by ``num_heads``.
        window_size: ``(window_height, window_width)`` of each attention window.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value Dense layer uses a bias.
        dropout_rate: Rate of the Dropout layer applied both to the attention
            probabilities and to the projected output.
        return_attention_scores: If True, ``call`` returns ``(output, attn)``
            instead of just ``output``.
        scale: Optional multiplier applied to the queries. Defaults to
            ``(dim // num_heads) ** -0.5``. Accepting it here also lets configs
            produced by ``get_config`` (which contain ``"scale"``) be restored.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        dropout_rate=0.0,
        return_attention_scores=False,
        scale=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        # Stored so get_config() can serialize every constructor argument.
        self.qkv_bias = qkv_bias
        self.dropout_rate = dropout_rate
        self.scale = scale if scale is not None else (dim // num_heads) ** -0.5
        self.return_attention_scores = return_attention_scores
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        # A single Dropout layer is shared by the attention probabilities and
        # the output projection.
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

    def build(self, input_shape):
        """Create the relative-position bias table and its constant lookup index."""
        window_height = self.window_size[0]
        window_width = self.window_size[1]
        # One learned bias per head for each of the (2*Wh - 1) * (2*Ww - 1)
        # possible relative offsets between two tokens of a window.
        self.relative_position_bias_table = self.add_weight(
            shape=(
                (2 * window_height - 1) * (2 * window_width - 1),
                self.num_heads,
            ),
            initializer="zeros",
            trainable=True,
            name="relative_position_bias_table",
        )
        # Constant (N, N) integer tensor of row indices into the table above.
        self.relative_position_index = self.get_relative_position_index(
            window_height, window_width
        )
        super().build(input_shape)

    def get_relative_position_index(self, window_height, window_width):
        """Return the ``(N, N)`` index into the relative-position bias table.

        Tokens are numbered row-major (height first, then width), with
        ``N = window_height * window_width``. Entry ``[i, j]`` equals
        ``(dr + window_height - 1) * (2 * window_width - 1) + (dc + window_width - 1)``,
        where ``(dr, dc)`` is the row/column offset of token ``i`` minus token
        ``j``. Every value lies in
        ``[0, (2 * window_height - 1) * (2 * window_width - 1))``.
        """
        # "ij" indexing yields (Wh, Ww) grids of row and column coordinates.
        # The default "xy" indexing transposes the grid, which swaps the two
        # offset axes and breaks the index for non-square windows.
        rows, cols = tf.meshgrid(
            tf.range(window_height), tf.range(window_width), indexing="ij"
        )
        coords_flatten = tf.reshape(tf.stack([rows, cols], axis=0), [2, -1])
        # Pairwise coordinate differences, rearranged to shape (N, N, 2).
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])
        # Shift both offsets to be non-negative, then combine them row-major.
        row_offset = relative_coords[:, :, 0] + window_height - 1
        col_offset = relative_coords[:, :, 1] + window_width - 1
        return row_offset * (2 * window_width - 1) + col_offset

    def call(self, x, mask=None):
        """Apply windowed multi-head self-attention.

        Args:
            x: Rank-3 tokens ``(B_, N, C)``. ``N`` must equal
                ``window_height * window_width`` for the bias to broadcast.
                ``N`` and ``C`` must be statically known, because they are read
                from ``x.shape``.
            mask: Optional additive mask of shape ``(nW, N, N)``, added to the
                logits before the softmax. ``B_`` must be a multiple of ``nW``
                (presumably ``batch * nW`` windows; confirm against the caller).

        Returns:
            The projected output ``(B_, N, C)``. If ``return_attention_scores``
            is set, returns the tuple ``(output, attn)`` instead, where ``attn``
            has shape ``(B_, num_heads, N, N)``.
            NOTE: ``attn`` is taken after dropout, so in training mode its
            values are zeroed/rescaled.
        """
        _, size, channels = x.shape
        head_dim = channels // self.num_heads

        # Fused projection, split into q/k/v of shape (B_, heads, N, head_dim).
        x_qkv = self.qkv(x)
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
        x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]

        q = q * self.scale
        k = tf.transpose(k, perm=(0, 1, 3, 2))
        attn = q @ k  # (B_, heads, N, N)

        relative_position_bias = tf.gather(
            self.relative_position_bias_table,
            self.relative_position_index,
            axis=0,
        )  # (N, N, heads)
        relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.get_shape()[0]
            # (nW, N, N) -> (1, nW, 1, N, N) so it broadcasts over images and
            # heads. Cast to the logits' dtype so mixed-precision policies work.
            mask_float = tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), attn.dtype
            )
            attn = (
                tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
                + mask_float
            )
            attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))

        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, heads merged back into the channel axis.
        x_qkv = attn @ v
        x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)

        if self.return_attention_scores:
            return x_qkv, attn
        return x_qkv

    def get_config(self):
        """Return a config that ``from_config`` can pass back to ``__init__``."""
        config = super().get_config()
        config.update(
            {
                "dim": self.dim,
                "window_size": self.window_size,
                "num_heads": self.num_heads,
                "qkv_bias": self.qkv_bias,
                "dropout_rate": self.dropout_rate,
                "return_attention_scores": self.return_attention_scores,
                "scale": self.scale,
            }
        )
        return config