import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from layers.window_attention import WindowAttention
from utils.drop_path import DropPath
from utils.swin_window import window_partition
from utils.swin_window import window_reverse


class SwinTransformer(layers.Layer):
    """Swin Transformer block: (shifted-)window attention followed by an MLP,
    each wrapped in a pre-norm residual connection."""

    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim  # number of input channels
        self.num_patch = num_patch  # (height, width) of the patch grid
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size  # 0 = regular windows, >0 = shifted windows
        self.num_mlp = num_mlp  # hidden units in the MLP

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = DropPath(dropout_rate) if dropout_rate > 0.0 else tf.identity
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

        # If the patch grid is smaller than the window, fall back to a single
        # unshifted window covering the whole grid.
        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

    def build(self, input_shape):
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            # Build the attention mask for shifted windows: after the cyclic
            # shift, patches that came from different regions of the image end
            # up in the same window and must not attend to each other.
            height, width = self.num_patch
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            # Plain NumPy here (not jax.numpy): the mask is built once, eagerly,
            # via in-place assignment, which immutable JAX arrays do not support.
            mask_array = np.zeros((1, height, width, 1), dtype=np.float32)
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = tf.convert_to_tensor(mask_array)

            # Partition the region-id mask into windows.
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, shape=[-1, self.window_size * self.window_size]
            )
            # Pairwise differences are non-zero where two patches belong to
            # different regions; those positions get a large negative bias.
            attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)

    def call(self, x):
        height, width = self.num_patch
        _, _, channels = x.shape
        x_skip = x

        # Window (or shifted-window) multi-head self-attention.
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, height, width, channels))
        if self.shift_size > 0:
            # Cyclically shift so that windows straddle the old window borders.
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            # Undo the cyclic shift.
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        # First residual connection.
        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)

        # MLP with the second residual connection.
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
        return x
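

# --- Usage sketch (illustrative only) ---
# A minimal smoke test, assuming a 32x32 patch grid, an embedding dimension of
# 96, and that WindowAttention / DropPath / the window utilities behave as the
# calls above require. The shapes and hyperparameters are assumptions for
# illustration, not part of this module's contract.
if __name__ == "__main__":
    block = SwinTransformer(
        dim=96,
        num_patch=(32, 32),  # window_size should divide both grid dimensions
        num_heads=3,
        window_size=8,
        shift_size=4,  # >0 gives a shifted-window block; 0 gives a regular one
    )
    dummy = tf.random.normal((1, 32 * 32, 96))  # (batch, num_patches, channels)
    print(block(dummy).shape)  # expected: (1, 1024, 96)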