import tensorflow as tf

from .attention import WindowAttention
from .drop import DropPath
from .window import window_partition, window_reverse
from .feature import Mlp, FeatExtract


@tf.keras.utils.register_keras_serializable(package="gcvit")
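class GCViTBlock(tf.keras.layers.Layer):
    """Global Context ViT (GC ViT) transformer block.

    Applies LayerNorm -> window attention -> residual, then
    LayerNorm -> MLP -> residual. Each residual branch is optionally scaled
    by a learnable per-channel ``gamma`` (when ``layer_scale`` is set) and
    regularized with stochastic depth (``DropPath``).

    Inputs are passed as a list: ``[x]`` for local window attention, or
    ``[x, q_global]`` when ``global_query=True``, where ``x`` has shape
    ``(B, H, W, C)``. The output has the same shape as ``x``.
    """

    def __init__(self, window_size, num_heads, global_query, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop=0., attn_drop=0., path_drop=0., act_layer='gelu', layer_scale=None, **kwargs):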
        super().__init__(**kwargs)
        self.window_size = window_size
        self.num_heads = num_heads
        self.global_query = global_query
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop = drop
        self.attn_drop = attn_drop
        self.path_drop = path_drop
        self.act_layer = act_layer
        self.layer_scale = layer_scale

    def build(self, input_shape):
        # `input_shape` is a list of shapes ([x] or [x, q_global]); only x's shape is needed here.
        B, H, W, C = input_shape[0]
        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')
        self.attn = WindowAttention(window_size=self.window_size,
                                   num_heads=self.num_heads,
                                   global_query=self.global_query,
                                   qkv_bias=self.qkv_bias,
                                   qk_scale=self.qk_scale,
                                   attn_dropout=self.attn_drop,
                                   proj_dropout=self.drop,
                                   name='attn')
        self.drop_path1 = DropPath(self.path_drop)
        self.drop_path2 = DropPath(self.path_drop)
        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
        self.mlp = Mlp(hidden_features=int(C * self.mlp_ratio), dropout=self.drop, act_layer=self.act_layer, name='mlp')
        # Optional LayerScale: learnable per-channel multipliers for the two residual branches.
        if self.layer_scale is not None:
            self.gamma1 = self.add_weight(
                'gamma1',
                shape=[C],
                initializer=tf.keras.initializers.Constant(self.layer_scale),
                trainable=True,
                dtype=self.dtype)
            self.gamma2 = self.add_weight(
                'gamma2',
                shape=[C],
                initializer=tf.keras.initializers.Constant(self.layer_scale),
                trainable=True,
                dtype=self.dtype)
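        else:
            # Without layer scale, both residual branches are added unscaled.
            self.gamma1 = 1.0
            self.gamma2 = 1.0
        # Number of windows per image; assumes H and W are static and divisible by window_size.
        self.num_windows = int(H // self.window_size) * int(W // self.window_size)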
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        # Unpack the feature map (and the global query tokens, when used).
        if self.global_query:
            inputs, q_global = inputs
        else:
            inputs = inputs[0]
        B, H, W, C = tf.unstack(tf.shape(inputs), num=4)
        x = self.norm1(inputs)
        # Partition into non-overlapping windows stacked on the batch axis: (B * num_windows, ws, ws, C)
        x = window_partition(x, self.window_size)
        # Flatten each window into a token sequence: (B * num_windows, ws * ws, C)
        x = tf.reshape(x, shape=[-1, self.window_size * self.window_size, C])
        # attention
        if self.global_query:
            x = self.attn([x, q_global])
        else:
            x = self.attn([x])
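        # Merge the windows back into a (B, H, W, C) feature map.
        x = window_reverse(x, self.window_size, H, W, C)
        # Residual around attention, with optional layer scale and stochastic depth.
        x = inputs + self.drop_path1(self.gamma1 * x)
        # FFN: LayerNorm -> MLP, then a second residual connection.
        x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))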
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'window_size': self.window_size,
            'num_heads': self.num_heads,
            'global_query': self.global_query,
            'mlp_ratio': self.mlp_ratio,
            'qkv_bias': self.qkv_bias,
            'qk_scale': self.qk_scale,
            'drop': self.drop,
            'attn_drop': self.attn_drop,
            'path_drop': self.path_drop,
            'act_layer': self.act_layer,
            'layer_scale': self.layer_scale,
        })
        return config
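

# Shape bookkeeping sketch (hypothetical helper, not part of the original GC ViT code).
# It mirrors, in plain Python, the shape arithmetic that GCViTBlock.build/call perform,
# assuming (as the reshape in `call` and `num_windows` in `build` suggest) that
# `window_partition` splits each (H, W) map into non-overlapping window_size x window_size
# tiles and that H and W are multiples of window_size.


def _gcvit_block_shapes(batch, height, width, channels, window_size, mlp_ratio=4.0):
    """Return the intermediate tensor shapes of one GCViTBlock forward pass (illustrative)."""
    if height % window_size or width % window_size:
        raise ValueError("height and width must be multiples of window_size")
    # Windows per image, computed as in GCViTBlock.build.
    num_windows = (height // window_size) * (width // window_size)
    return {
        "num_windows": num_windows,
        # After window_partition: windows are stacked along the batch axis.
        "windows": (batch * num_windows, window_size, window_size, channels),
        # After the reshape in call: one token sequence per window.
        "tokens": (batch * num_windows, window_size * window_size, channels),
        # Hidden width of the MLP, as passed to Mlp(hidden_features=...).
        "mlp_hidden": int(channels * mlp_ratio),
        # window_reverse restores the input layout; the residual additions keep it.
        "output": (batch, height, width, channels),
    }


# Example: a 14x14 map with 64 channels and 7x7 windows gives 4 windows per image,
# i.e. 8 sequences of 49 tokens for a batch of 2:
# _gcvit_block_shapes(2, 14, 14, 64, 7)["tokens"]  -> (8, 49, 64)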