Spaces:
Runtime error
Runtime error
File size: 4,684 Bytes
0f09377 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
try:
from jax import numpy as jnp
except ModuleNotFoundError:
# JAX may be unavailable on this platform (e.g. some Windows setups); fall back to NumPy under the same alias.
import numpy as jnp
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from layers.window_attention import WindowAttention
from utils.drop_path import DropPath
from utils.swin_window import window_partition
from utils.swin_window import window_reverse
class SwinTransformer(layers.Layer):
    """Swin Transformer block: (shifted) window attention followed by an MLP.

    Each of the two sub-blocks is preceded by layer normalization and wrapped
    in a residual connection with optional stochastic depth (``DropPath``).

    Args:
        dim: Feature size; forwarded to ``WindowAttention`` and used as the
            output width of the MLP.
        num_patch: ``(height, width)`` pair giving the patch grid size.
        num_heads: Number of attention heads, forwarded to ``WindowAttention``.
        window_size: Side length of the square attention window. Clamped to
            ``min(num_patch)`` when the grid is smaller than the window.
        shift_size: Offset by which the grid is rolled before windowing.
            ``0`` disables shifting (and is forced when the window is clamped).
        num_mlp: Width of the hidden ``Dense`` layer in the MLP.
        qkv_bias: Forwarded to ``WindowAttention``.
        dropout_rate: Rate used for the attention layer, the MLP dropouts and
            ``DropPath``. ``DropPath`` is replaced by identity when the rate
            is not positive.
        **kwargs: Passed through to ``layers.Layer``.
    """

    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_patch = num_patch
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_mlp = num_mlp

        # Clamp BEFORE building the attention layer: when the patch grid is
        # smaller than the requested window, a single unshifted window covers
        # it, and the attention layer must be configured for that same size.
        smallest_side = min(self.num_patch)
        if smallest_side < self.window_size:
            self.shift_size = 0
            self.window_size = smallest_side

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = DropPath(dropout_rate) if dropout_rate > 0.0 else tf.identity
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

    def build(self, input_shape):
        """Create ``self.attn_mask`` (``None`` when no shift is applied).

        For a shifted block, every grid position is labelled with the id of
        the region it falls in (3 row bands x 3 column bands). Pairs of
        positions inside one window that carry different labels get ``-100.0``
        in the mask, pairs with the same label get ``0.0``.
        """
        if self.shift_size == 0:
            self.attn_mask = None
            return

        # Real NumPy is required here: the label grid is filled by in-place
        # slice assignment, which immutable JAX arrays (``jnp``) reject.
        import numpy as np

        height, width = self.num_patch
        bands = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        # float32 matches the default dtype ``jnp.zeros`` would have produced.
        region_labels = np.zeros((1, height, width, 1), dtype=np.float32)
        region_id = 0
        for row_band in bands:
            for col_band in bands:
                region_labels[:, row_band, col_band, :] = region_id
                region_id += 1
        region_labels = tf.convert_to_tensor(region_labels)

        # Split the label grid into windows, one flattened row per window.
        mask_windows = window_partition(region_labels, self.window_size)
        mask_windows = tf.reshape(
            mask_windows, shape=[-1, self.window_size * self.window_size]
        )
        # Pairwise label difference inside each window: non-zero means the two
        # positions come from different regions.
        attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
            mask_windows, axis=2
        )
        # NOTE(review): -100.0 is presumably added to the attention logits by
        # WindowAttention to suppress cross-region pairs -- confirm there.
        attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
        attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
        self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)

    def call(self, x):
        """Apply the block to ``x``.

        Args:
            x: Rank-3 tensor whose second axis holds ``height * width``
                patches (it is reshaped to ``(-1, height, width, channels)``).

        Returns:
            A float32 tensor of shape ``(-1, height * width, channels)``.
        """
        height, width = self.num_patch
        _, _, channels = x.shape
        window_area = self.window_size * self.window_size

        # --- attention sub-block ---
        x_skip = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, height, width, channels))
        if self.shift_size > 0:
            # Roll the grid so that windows straddle the previous block's
            # window boundaries.
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(x_windows, shape=(-1, window_area, channels))
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )

        if self.shift_size > 0:
            # Undo the roll applied before windowing.
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        # Both operands are cast to float32 before the residual add.
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)

        # --- MLP sub-block ---
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
        return x
|