import tensorflow as tf
from tensorflow.keras import layers


class WindowAttention(layers.Layer):
    """Multi-head self-attention within a local window, with a learned
    relative position bias (as used in Swin Transformer blocks)."""

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        dropout_rate=0.0,
        return_attention_scores=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.qkv_bias = qkv_bias
        self.dropout_rate = dropout_rate
        self.scale = (dim // num_heads) ** -0.5
        self.return_attention_scores = return_attention_scores
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

    def build(self, input_shape):
        # One learnable bias per relative (row, column) offset and per head.
        self.relative_position_bias_table = self.add_weight(
            shape=(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                self.num_heads,
            ),
            initializer="zeros",
            trainable=True,
            name="relative_position_bias_table",
        )
        self.relative_position_index = self.get_relative_position_index(
            self.window_size[0], self.window_size[1]
        )
        super().build(input_shape)

    def get_relative_position_index(self, window_height, window_width):
        # Map every (token, token) pair inside the window to an index into
        # the relative position bias table.
        x_x, y_y = tf.meshgrid(range(window_height), range(window_width))
        coords = tf.stack([y_y, x_x], axis=0)
        coords_flatten = tf.reshape(coords, [2, -1])
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])
        x_x = (relative_coords[:, :, 0] + window_height - 1) * (2 * window_width - 1)
        y_y = relative_coords[:, :, 1] + window_width - 1
        relative_coords = tf.stack([x_x, y_y], axis=-1)
        return tf.reduce_sum(relative_coords, axis=-1)

    def call(self, x, mask=None):
        _, size, channels = x.shape
        head_dim = channels // self.num_heads
        # Project to queries, keys, and values with a single dense layer, then
        # split into heads: (3, batch * num_windows, num_heads, size, head_dim).
        x_qkv = self.qkv(x)
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
        x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = tf.transpose(k, perm=(0, 1, 3, 2))
        attn = q @ k

        # Add the learned relative position bias to the attention logits.
        relative_position_bias = tf.gather(
            self.relative_position_bias_table,
            self.relative_position_index,
            axis=0,
        )
        relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            # Mask out attention across window boundaries (used with shifted windows).
            nW = mask.get_shape()[0]
            mask_float = tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32
            )
            attn = (
                tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
                + mask_float
            )
            attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))
        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, merge heads, and apply the output projection.
        x_qkv = attn @ v
        x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)
        if self.return_attention_scores:
            return x_qkv, attn
        return x_qkv

    def get_config(self):
        # Serialize only constructor arguments so the layer can be rebuilt via
        # `from_config` ("scale" is derived and is not a valid constructor argument).
        config = super().get_config()
        config.update(
            {
                "dim": self.dim,
                "window_size": self.window_size,
                "num_heads": self.num_heads,
                "qkv_bias": self.qkv_bias,
                "dropout_rate": self.dropout_rate,
                "return_attention_scores": self.return_attention_scores,
            }
        )
        return config
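

# --- Minimal usage sketch (not part of the original example) ---
# The window size, embedding dimension, head count, and batch size below are
# illustrative assumptions chosen only to exercise the layer's shapes.
window_attention = WindowAttention(
    dim=64, window_size=(2, 2), num_heads=4, dropout_rate=0.1
)
# Eight windows of 2 x 2 = 4 tokens each, embedded into 64 channels.
dummy_windows = tf.random.normal((8, 4, 64))
output = window_attention(dummy_windows)
print(output.shape)  # expected: (8, 4, 64)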