HybridModel-GradCAM / layers /swin_blocks.py
innat
init
0f09377
try:
from jax import numpy as jnp
except ModuleNotFoundError:
# jax doesn't support windows os yet.
import numpy as jnp
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from layers.window_attention import WindowAttention
from utils.drop_path import DropPath
from utils.swin_window import window_partition
from utils.swin_window import window_reverse
class SwinTransformer(layers.Layer):
    """Swin Transformer block: (shifted-)window attention followed by an MLP.

    The layer takes a sequence of patch tokens, lays it out on the
    ``num_patch`` grid, optionally rolls the grid by ``shift_size``, applies
    ``WindowAttention`` inside non-overlapping ``window_size`` x ``window_size``
    windows, undoes the roll, and adds the result to the input. A second
    residual branch applies layer normalization and a two-layer GELU MLP.

    Args:
        dim: Channel width handed to ``WindowAttention`` and used as the
            output width of the MLP. The residual additions in ``call``
            require the input channel count to match it.
        num_patch: ``(height, width)`` pair giving the patch grid size.
        num_heads: Number of attention heads passed to ``WindowAttention``.
        window_size: Side length of the square attention window. Clamped to
            ``min(num_patch)`` when the grid is smaller than the window.
        shift_size: Number of grid positions to roll before windowing;
            ``0`` disables shifting. Forced to ``0`` when the window size is
            clamped.
        num_mlp: Hidden width of the MLP.
        qkv_bias: Forwarded to ``WindowAttention``.
        dropout_rate: Rate used for the attention layer, the MLP ``Dropout``
            layers and ``DropPath``.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_patch = num_patch
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_mlp = num_mlp

        # Clamp *before* building the attention layer so that it is configured
        # with the window size that call() actually partitions with. A grid
        # smaller than the window is covered by one window, so no shift.
        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        # With a zero rate the stochastic-depth step is a plain pass-through.
        self.drop_path = DropPath(dropout_rate) if dropout_rate > 0.0 else tf.identity
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

    def build(self, input_shape):
        """Create the attention mask used when the window grid is shifted.

        Sets ``self.attn_mask`` to ``None`` when ``shift_size`` is ``0``.
        Otherwise it is a non-trainable variable holding, for every window,
        a ``(window_size**2, window_size**2)`` matrix that is ``0`` where two
        positions carry the same region label and ``-100`` where they differ.

        Args:
            input_shape: Shape of the layer input (unused).
        """
        if self.shift_size == 0:
            self.attn_mask = None
            return

        # NumPy is used here instead of the module-level ``jnp`` alias: when
        # that alias is backed by JAX its arrays are immutable and reject the
        # in-place slice assignment below.
        import numpy as np

        height, width = self.num_patch
        # The same three bands are used along both axes: everything before
        # the last window, the last window minus the shifted strip, and the
        # shifted strip itself.
        region_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )

        # Label each of the 3 x 3 regions with a distinct id (0..8).
        mask_array = np.zeros((1, height, width, 1), dtype="float32")
        count = 0
        for h in region_slices:
            for w in region_slices:
                mask_array[:, h, w, :] = count
                count += 1
        mask_array = tf.convert_to_tensor(mask_array)

        # Cut the label map into windows and flatten each window.
        mask_windows = window_partition(mask_array, self.window_size)
        mask_windows = tf.reshape(
            mask_windows, shape=[-1, self.window_size * self.window_size]
        )

        # Pairwise label difference inside each window: non-zero means the
        # two positions belong to different regions and get the -100 penalty;
        # zero entries are kept as they are.
        region_diff = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
            mask_windows, axis=2
        )
        attn_mask = tf.where(region_diff != 0, -100.0, region_diff)
        self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)

    def call(self, x):
        """Apply windowed attention and the MLP, each with a residual add.

        Args:
            x: Rank-3 tensor whose second axis has ``height * width`` entries
                (it is reshaped onto the ``num_patch`` grid).

        Returns:
            A ``float32`` tensor of shape ``(batch, height * width, channels)``.
        """
        height, width = self.num_patch
        _, _, channels = x.shape

        # --- Attention branch -------------------------------------------
        x_skip = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, height, width, channels))

        if self.shift_size > 0:
            # Roll the grid so the windows straddle the previous boundaries.
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # Stitch the windows back into the full grid.
        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )

        if self.shift_size > 0:
            # Undo the roll applied before windowing.
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        # Both operands are cast to float32 before the residual add.
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)

        # --- MLP branch -------------------------------------------------
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
        return x