awsaf49 commited on
Commit
31b7180
2 Parent(s): 28f10f6 21fa3d7

Merge branch 'main' of https://huggingface.co/spaces/awsaf49/gcvit-tf into main

Browse files
.gitattributes CHANGED
@@ -1,31 +1,31 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ftz filter=lfs diff=lfs merge=lfs -text
6
- *.gz filter=lfs diff=lfs merge=lfs -text
7
- *.h5 filter=lfs diff=lfs merge=lfs -text
8
- *.joblib filter=lfs diff=lfs merge=lfs -text
9
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
- *.model filter=lfs diff=lfs merge=lfs -text
11
- *.msgpack filter=lfs diff=lfs merge=lfs -text
12
- *.npy filter=lfs diff=lfs merge=lfs -text
13
- *.npz filter=lfs diff=lfs merge=lfs -text
14
- *.onnx filter=lfs diff=lfs merge=lfs -text
15
- *.ot filter=lfs diff=lfs merge=lfs -text
16
- *.parquet filter=lfs diff=lfs merge=lfs -text
17
- *.pickle filter=lfs diff=lfs merge=lfs -text
18
- *.pkl filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pt filter=lfs diff=lfs merge=lfs -text
21
- *.pth filter=lfs diff=lfs merge=lfs -text
22
- *.rar filter=lfs diff=lfs merge=lfs -text
23
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
- *.tar.* filter=lfs diff=lfs merge=lfs -text
25
- *.tflite filter=lfs diff=lfs merge=lfs -text
26
- *.tgz filter=lfs diff=lfs merge=lfs -text
27
- *.wasm filter=lfs diff=lfs merge=lfs -text
28
- *.xz filter=lfs diff=lfs merge=lfs -text
29
- *.zip filter=lfs diff=lfs merge=lfs -text
30
- *.zstandard filter=lfs diff=lfs merge=lfs -text
31
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
1
- ---
2
- title: Gcvit Tf
3
- emoji: 📈
4
- colorFrom: yellow
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 3.1.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Gcvit Tf
3
+ emoji: 📈
4
+ colorFrom: yellow
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.1.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
example/Standing_jaguar.jpg ADDED
gcvit/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (228 Bytes)
 
gcvit/layers/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
- from .window import window_partition, window_reverse
2
- from .attention import WindowAttention
3
- from .drop import DropPath, Identity
4
- from .embedding import PatchEmbed
5
- from .feature import Mlp, FeatExtract, ReduceSize, SE, Resizing
6
- from .block import GCViTBlock
7
- from .level import GCViTLayer
 
1
+ from .window import window_partition, window_reverse
2
+ from .attention import WindowAttention
3
+ from .drop import DropPath, Identity
4
+ from .embedding import Stem
5
+ from .feature import Mlp, FeatExtract, ReduceSize, SE, Resizing
6
+ from .block import GCViTBlock
7
+ from .level import GCViTLevel
gcvit/layers/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (530 Bytes)
 
gcvit/layers/__pycache__/attention.cpython-38.pyc DELETED
Binary file (3.58 kB)
 
gcvit/layers/__pycache__/block.cpython-38.pyc DELETED
Binary file (3 kB)
 
gcvit/layers/__pycache__/drop.cpython-38.pyc DELETED
Binary file (1.8 kB)
 
gcvit/layers/__pycache__/embedding.cpython-38.pyc DELETED
Binary file (1.39 kB)
 
gcvit/layers/__pycache__/feature.cpython-38.pyc DELETED
Binary file (5.5 kB)
 
gcvit/layers/__pycache__/level.cpython-38.pyc DELETED
Binary file (3 kB)
 
gcvit/layers/__pycache__/window.cpython-38.pyc DELETED
Binary file (801 Bytes)
 
gcvit/layers/block.py CHANGED
@@ -1,99 +1,99 @@
1
- import tensorflow as tf
2
-
3
- from .attention import WindowAttention
4
- from .drop import DropPath
5
- from .window import window_partition, window_reverse
6
- from .feature import Mlp, FeatExtract
7
-
8
-
9
- @tf.keras.utils.register_keras_serializable(package="gcvit")
10
- class GCViTBlock(tf.keras.layers.Layer):
11
- def __init__(self, window_size, num_heads, global_query, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0.,
12
- attn_drop=0., path_drop=0., act_layer='gelu', layer_scale=None, **kwargs):
13
- super().__init__(**kwargs)
14
- self.window_size = window_size
15
- self.num_heads = num_heads
16
- self.global_query = global_query
17
- self.mlp_ratio = mlp_ratio
18
- self.qkv_bias = qkv_bias
19
- self.qk_scale = qk_scale
20
- self.drop = drop
21
- self.attn_drop = attn_drop
22
- self.path_drop = path_drop
23
- self.act_layer = act_layer
24
- self.layer_scale = layer_scale
25
-
26
- def build(self, input_shape):
27
- B, H, W, C = input_shape[0]
28
- self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')
29
- self.attn = WindowAttention(window_size=self.window_size,
30
- num_heads=self.num_heads,
31
- global_query=self.global_query,
32
- qkv_bias=self.qkv_bias,
33
- qk_scale=self.qk_scale,
34
- attn_dropout=self.attn_drop,
35
- proj_dropout=self.drop,
36
- name='attn')
37
- self.drop_path1 = DropPath(self.path_drop)
38
- self.drop_path2 = DropPath(self.path_drop)
39
- self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
40
- self.mlp = Mlp(hidden_features=int(C * self.mlp_ratio), dropout=self.drop, act_layer=self.act_layer, name='mlp')
41
- if self.layer_scale is not None:
42
- self.gamma1 = self.add_weight(
43
- 'gamma1',
44
- shape=[C],
45
- initializer=tf.keras.initializers.Constant(self.layer_scale),
46
- trainable=True,
47
- dtype=self.dtype)
48
- self.gamma2 = self.add_weight(
49
- 'gamma2',
50
- shape=[C],
51
- initializer=tf.keras.initializers.Constant(self.layer_scale),
52
- trainable=True,
53
- dtype=self.dtype)
54
- else:
55
- self.gamma1 = 1.0
56
- self.gamma2 = 1.0
57
- self.num_windows = int(H // self.window_size) * int(W // self.window_size)
58
- super().build(input_shape)
59
-
60
- def call(self, inputs, **kwargs):
61
- if self.global_query:
62
- inputs, q_global = inputs
63
- else:
64
- inputs = inputs[0]
65
- B, H, W, C = tf.unstack(tf.shape(inputs), num=4)
66
- x = self.norm1(inputs)
67
- # create windows and concat them in batch axis
68
- x = window_partition(x, self.window_size) # (B_, win_h, win_w, C)
69
- # flatten patch
70
- x = tf.reshape(x, shape=[-1, self.window_size * self.window_size, C]) # (B_, N, C) => (batch*num_win, num_token, feature)
71
- # attention
72
- if self.global_query:
73
- x = self.attn([x, q_global])
74
- else:
75
- x = self.attn([x])
76
- # reverse window partition
77
- x = window_reverse(x, self.window_size, H, W, C)
78
- # FFN
79
- x = inputs + self.drop_path1(x * self.gamma1)
80
- x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
81
- return x
82
-
83
- def get_config(self):
84
- config = super().get_config()
85
- config.update({
86
- 'window_size': self.window_size,
87
- 'num_heads': self.num_heads,
88
- 'global_query': self.global_query,
89
- 'mlp_ratio': self.mlp_ratio,
90
- 'qkv_bias': self.qkv_bias,
91
- 'qk_scale': self.qk_scale,
92
- 'drop': self.drop,
93
- 'attn_drop': self.attn_drop,
94
- 'path_drop': self.path_drop,
95
- 'act_layer': self.act_layer,
96
- 'layer_scale': self.layer_scale,
97
- 'num_windows': self.num_windows,
98
- })
99
  return config
 
1
+ import tensorflow as tf
2
+
3
+ from .attention import WindowAttention
4
+ from .drop import DropPath
5
+ from .window import window_partition, window_reverse
6
+ from .feature import Mlp, FeatExtract
7
+
8
+
9
+ @tf.keras.utils.register_keras_serializable(package="gcvit")
10
+ class GCViTBlock(tf.keras.layers.Layer):
11
+ def __init__(self, window_size, num_heads, global_query, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0.,
12
+ attn_drop=0., path_drop=0., act_layer='gelu', layer_scale=None, **kwargs):
13
+ super().__init__(**kwargs)
14
+ self.window_size = window_size
15
+ self.num_heads = num_heads
16
+ self.global_query = global_query
17
+ self.mlp_ratio = mlp_ratio
18
+ self.qkv_bias = qkv_bias
19
+ self.qk_scale = qk_scale
20
+ self.drop = drop
21
+ self.attn_drop = attn_drop
22
+ self.path_drop = path_drop
23
+ self.act_layer = act_layer
24
+ self.layer_scale = layer_scale
25
+
26
+ def build(self, input_shape):
27
+ B, H, W, C = input_shape[0]
28
+ self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')
29
+ self.attn = WindowAttention(window_size=self.window_size,
30
+ num_heads=self.num_heads,
31
+ global_query=self.global_query,
32
+ qkv_bias=self.qkv_bias,
33
+ qk_scale=self.qk_scale,
34
+ attn_dropout=self.attn_drop,
35
+ proj_dropout=self.drop,
36
+ name='attn')
37
+ self.drop_path1 = DropPath(self.path_drop)
38
+ self.drop_path2 = DropPath(self.path_drop)
39
+ self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
40
+ self.mlp = Mlp(hidden_features=int(C * self.mlp_ratio), dropout=self.drop, act_layer=self.act_layer, name='mlp')
41
+ if self.layer_scale is not None:
42
+ self.gamma1 = self.add_weight(
43
+ 'gamma1',
44
+ shape=[C],
45
+ initializer=tf.keras.initializers.Constant(self.layer_scale),
46
+ trainable=True,
47
+ dtype=self.dtype)
48
+ self.gamma2 = self.add_weight(
49
+ 'gamma2',
50
+ shape=[C],
51
+ initializer=tf.keras.initializers.Constant(self.layer_scale),
52
+ trainable=True,
53
+ dtype=self.dtype)
54
+ else:
55
+ self.gamma1 = 1.0
56
+ self.gamma2 = 1.0
57
+ self.num_windows = int(H // self.window_size) * int(W // self.window_size)
58
+ super().build(input_shape)
59
+
60
+ def call(self, inputs, **kwargs):
61
+ if self.global_query:
62
+ inputs, q_global = inputs
63
+ else:
64
+ inputs = inputs[0]
65
+ B, H, W, C = tf.unstack(tf.shape(inputs), num=4)
66
+ x = self.norm1(inputs)
67
+ # create windows and concat them in batch axis
68
+ x = window_partition(x, self.window_size) # (B_, win_h, win_w, C)
69
+ # flatten patch
70
+ x = tf.reshape(x, shape=[-1, self.window_size * self.window_size, C]) # (B_, N, C) => (batch*num_win, num_token, feature)
71
+ # attention
72
+ if self.global_query:
73
+ x = self.attn([x, q_global])
74
+ else:
75
+ x = self.attn([x])
76
+ # reverse window partition
77
+ x = window_reverse(x, self.window_size, H, W, C)
78
+ # FFN
79
+ x = inputs + self.drop_path1(x * self.gamma1)
80
+ x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
81
+ return x
82
+
83
+ def get_config(self):
84
+ config = super().get_config()
85
+ config.update({
86
+ 'window_size': self.window_size,
87
+ 'num_heads': self.num_heads,
88
+ 'global_query': self.global_query,
89
+ 'mlp_ratio': self.mlp_ratio,
90
+ 'qkv_bias': self.qkv_bias,
91
+ 'qk_scale': self.qk_scale,
92
+ 'drop': self.drop,
93
+ 'attn_drop': self.attn_drop,
94
+ 'path_drop': self.path_drop,
95
+ 'act_layer': self.act_layer,
96
+ 'layer_scale': self.layer_scale,
97
+ 'num_windows': self.num_windows,
98
+ })
99
  return config
gcvit/layers/embedding.py CHANGED
@@ -4,7 +4,7 @@ from .feature import ReduceSize
4
 
5
 
6
  @tf.keras.utils.register_keras_serializable(package="gcvit")
7
- class PatchEmbed(tf.keras.layers.Layer):
8
  def __init__(self, dim, **kwargs):
9
  super().__init__(**kwargs)
10
  self.dim = dim
 
4
 
5
 
6
  @tf.keras.utils.register_keras_serializable(package="gcvit")
7
+ class Stem(tf.keras.layers.Layer):
8
  def __init__(self, dim, **kwargs):
9
  super().__init__(**kwargs)
10
  self.dim = dim
gcvit/layers/feature.py CHANGED
@@ -1,202 +1,255 @@
1
- import tensorflow as tf
2
- import tensorflow_addons as tfa
3
-
4
- H_AXIS = -3
5
- W_AXIS = -2
6
-
7
- @tf.keras.utils.register_keras_serializable(package="gcvit")
8
- class Mlp(tf.keras.layers.Layer):
9
- def __init__(self, hidden_features=None, out_features=None, act_layer='gelu', dropout=0., **kwargs):
10
- super().__init__(**kwargs)
11
- self.hidden_features = hidden_features
12
- self.out_features = out_features
13
- self.act_layer = act_layer
14
- self.dropout = dropout
15
-
16
- def build(self, input_shape):
17
- self.in_features = input_shape[-1]
18
- self.hidden_features = self.hidden_features or self.in_features
19
- self.out_features = self.out_features or self.in_features
20
- self.fc1 = tf.keras.layers.Dense(self.hidden_features, name="fc1")
21
- self.act = tf.keras.layers.Activation(self.act_layer, name="act")
22
- self.fc2 = tf.keras.layers.Dense(self.out_features, name="fc2")
23
- self.drop1 = tf.keras.layers.Dropout(self.dropout, name="drop1")
24
- self.drop2 = tf.keras.layers.Dropout(self.dropout, name="drop2")
25
- super().build(input_shape)
26
-
27
- def call(self, inputs, **kwargs):
28
- x = self.fc1(inputs)
29
- x = self.act(x)
30
- x = self.drop1(x)
31
- x = self.fc2(x)
32
- x = self.drop2(x)
33
- return x
34
-
35
- def get_config(self):
36
- config = super().get_config()
37
- config.update({
38
- "hidden_features":self.hidden_features,
39
- "out_features":self.out_features,
40
- "act_layer":self.act_layer,
41
- "dropout":self.dropout
42
- })
43
- return config
44
-
45
- @tf.keras.utils.register_keras_serializable(package="gcvit")
46
- class SE(tf.keras.layers.Layer):
47
- def __init__(self, oup=None, expansion=0.25, **kwargs):
48
- super().__init__(**kwargs)
49
- self.expansion = expansion
50
- self.oup = oup
51
-
52
- def build(self, input_shape):
53
- inp = input_shape[-1]
54
- self.oup = self.oup or inp
55
- self.avg_pool = tfa.layers.AdaptiveAveragePooling2D(1, name="avg_pool")
56
- self.fc = [
57
- tf.keras.layers.Dense(int(inp * self.expansion), use_bias=False, name='fc/0'),
58
- tf.keras.layers.Activation('gelu', name='fc/1'),
59
- tf.keras.layers.Dense(self.oup, use_bias=False, name='fc/2'),
60
- tf.keras.layers.Activation('sigmoid', name='fc/3')
61
- ]
62
- super().build(input_shape)
63
-
64
- def call(self, inputs, **kwargs):
65
- b, _, _, c = tf.unstack(tf.shape(inputs), num=4)
66
- x = tf.reshape(self.avg_pool(inputs), (b, c))
67
- for layer in self.fc:
68
- x = layer(x)
69
- x = tf.reshape(x, (b, 1, 1, c))
70
- return x*inputs
71
-
72
- def get_config(self):
73
- config = super().get_config()
74
- config.update({
75
- 'expansion': self.expansion,
76
- 'oup': self.oup,
77
- })
78
- return config
79
-
80
- @tf.keras.utils.register_keras_serializable(package="gcvit")
81
- class ReduceSize(tf.keras.layers.Layer):
82
- def __init__(self, keep_dim=False, **kwargs):
83
- super().__init__(**kwargs)
84
- self.keep_dim = keep_dim
85
-
86
- def build(self, input_shape):
87
- dim = input_shape[-1]
88
- dim_out = dim if self.keep_dim else 2*dim
89
- self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
90
- self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
91
- self.conv = [
92
- tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
93
- tf.keras.layers.Activation('gelu', name='conv/1'),
94
- SE(name='conv/2'),
95
- tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
96
- ]
97
- self.reduction = tf.keras.layers.Conv2D(dim_out, kernel_size=3, strides=2, padding='valid', use_bias=False,
98
- name='reduction')
99
- self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1') # eps like PyTorch
100
- self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
101
- super().build(input_shape)
102
-
103
- def call(self, inputs, **kwargs):
104
- x = self.norm1(inputs)
105
- xr = self.pad1(x) # if pad had weights it would've thrown error with .save_weights()
106
- for layer in self.conv:
107
- xr = layer(xr)
108
- x = x + xr
109
- x = self.pad2(x)
110
- x = self.reduction(x)
111
- x = self.norm2(x)
112
- return x
113
-
114
- def get_config(self):
115
- config = super().get_config()
116
- config.update({
117
- "keep_dim":self.keep_dim,
118
- })
119
- return config
120
-
121
- @tf.keras.utils.register_keras_serializable(package="gcvit")
122
- class FeatExtract(tf.keras.layers.Layer):
123
- def __init__(self, keep_dim=False, **kwargs):
124
- super().__init__(**kwargs)
125
- self.keep_dim = keep_dim
126
-
127
- def build(self, input_shape):
128
- dim = input_shape[-1]
129
- self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
130
- self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
131
- self.conv = [
132
- tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
133
- tf.keras.layers.Activation('gelu', name='conv/1'),
134
- SE(name='conv/2'),
135
- tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
136
- ]
137
- if not self.keep_dim:
138
- self.pool = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='valid', name='pool')
139
- # else:
140
- # self.pool = tf.keras.layers.Activation('linear', name='identity') # hack for PyTorch nn.Identity layer ;)
141
- super().build(input_shape)
142
-
143
- def call(self, inputs, **kwargs):
144
- x = inputs
145
- xr = self.pad1(x)
146
- for layer in self.conv:
147
- xr = layer(xr)
148
- x = x + xr # if pad had weights it would've thrown error with .save_weights()
149
- if not self.keep_dim:
150
- x = self.pad2(x)
151
- x = self.pool(x)
152
- return x
153
-
154
- def get_config(self):
155
- config = super().get_config()
156
- config.update({
157
- "keep_dim":self.keep_dim,
158
- })
159
- return config
160
-
161
- @tf.keras.utils.register_keras_serializable(package="gcvit")
162
- class Resizing(tf.keras.layers.Layer):
163
- def __init__(self,
164
- height,
165
- width,
166
- interpolation='bilinear',
167
- **kwargs):
168
- self.height = height
169
- self.width = width
170
- self.interpolation = interpolation
171
- super().__init__(**kwargs)
172
-
173
- def call(self, inputs):
174
- # tf.image.resize will always output float32 and operate more efficiently on
175
- # float32 unless interpolation is nearest, in which case ouput type matches
176
- # input type.
177
- if self.interpolation == 'nearest':
178
- input_dtype = self.compute_dtype
179
- else:
180
- input_dtype = tf.float32
181
- inputs = tf.cast(inputs, dtype=input_dtype)
182
- size = [self.height, self.width]
183
- outputs = tf.image.resize(
184
- inputs,
185
- size=size,
186
- method=self.interpolation)
187
- return tf.cast(outputs, self.compute_dtype)
188
-
189
- def compute_output_shape(self, input_shape):
190
- input_shape = tf.TensorShape(input_shape).as_list()
191
- input_shape[H_AXIS] = self.height
192
- input_shape[W_AXIS] = self.width
193
- return tf.TensorShape(input_shape)
194
-
195
- def get_config(self):
196
- config = super().get_config()
197
- config.update({
198
- 'height': self.height,
199
- 'width': self.width,
200
- 'interpolation': self.interpolation,
201
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  return config
 
1
+ import tensorflow as tf
2
+ import tensorflow_addons as tfa
3
+
4
+ H_AXIS = -3
5
+ W_AXIS = -2
6
+
7
+ @tf.keras.utils.register_keras_serializable(package="gcvit")
8
+ class Mlp(tf.keras.layers.Layer):
9
+ def __init__(self, hidden_features=None, out_features=None, act_layer='gelu', dropout=0., **kwargs):
10
+ super().__init__(**kwargs)
11
+ self.hidden_features = hidden_features
12
+ self.out_features = out_features
13
+ self.act_layer = act_layer
14
+ self.dropout = dropout
15
+
16
+ def build(self, input_shape):
17
+ self.in_features = input_shape[-1]
18
+ self.hidden_features = self.hidden_features or self.in_features
19
+ self.out_features = self.out_features or self.in_features
20
+ self.fc1 = tf.keras.layers.Dense(self.hidden_features, name="fc1")
21
+ self.act = tf.keras.layers.Activation(self.act_layer, name="act")
22
+ self.fc2 = tf.keras.layers.Dense(self.out_features, name="fc2")
23
+ self.drop1 = tf.keras.layers.Dropout(self.dropout, name="drop1")
24
+ self.drop2 = tf.keras.layers.Dropout(self.dropout, name="drop2")
25
+ super().build(input_shape)
26
+
27
+ def call(self, inputs, **kwargs):
28
+ x = self.fc1(inputs)
29
+ x = self.act(x)
30
+ x = self.drop1(x)
31
+ x = self.fc2(x)
32
+ x = self.drop2(x)
33
+ return x
34
+
35
+ def get_config(self):
36
+ config = super().get_config()
37
+ config.update({
38
+ "hidden_features":self.hidden_features,
39
+ "out_features":self.out_features,
40
+ "act_layer":self.act_layer,
41
+ "dropout":self.dropout
42
+ })
43
+ return config
44
+
45
+ @tf.keras.utils.register_keras_serializable(package="gcvit")
46
+ class SE(tf.keras.layers.Layer):
47
+ def __init__(self, oup=None, expansion=0.25, **kwargs):
48
+ super().__init__(**kwargs)
49
+ self.expansion = expansion
50
+ self.oup = oup
51
+
52
+ def build(self, input_shape):
53
+ inp = input_shape[-1]
54
+ self.oup = self.oup or inp
55
+ self.avg_pool = tfa.layers.AdaptiveAveragePooling2D(1, name="avg_pool")
56
+ self.fc = [
57
+ tf.keras.layers.Dense(int(inp * self.expansion), use_bias=False, name='fc/0'),
58
+ tf.keras.layers.Activation('gelu', name='fc/1'),
59
+ tf.keras.layers.Dense(self.oup, use_bias=False, name='fc/2'),
60
+ tf.keras.layers.Activation('sigmoid', name='fc/3')
61
+ ]
62
+ super().build(input_shape)
63
+
64
+ def call(self, inputs, **kwargs):
65
+ b, _, _, c = tf.unstack(tf.shape(inputs), num=4)
66
+ x = tf.reshape(self.avg_pool(inputs), (b, c))
67
+ for layer in self.fc:
68
+ x = layer(x)
69
+ x = tf.reshape(x, (b, 1, 1, c))
70
+ return x*inputs
71
+
72
+ def get_config(self):
73
+ config = super().get_config()
74
+ config.update({
75
+ 'expansion': self.expansion,
76
+ 'oup': self.oup,
77
+ })
78
+ return config
79
+
80
+ @tf.keras.utils.register_keras_serializable(package="gcvit")
81
+ class ReduceSize(tf.keras.layers.Layer):
82
+ def __init__(self, keep_dim=False, **kwargs):
83
+ super().__init__(**kwargs)
84
+ self.keep_dim = keep_dim
85
+
86
+ def build(self, input_shape):
87
+ dim = input_shape[-1]
88
+ dim_out = dim if self.keep_dim else 2*dim
89
+ self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
90
+ self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
91
+ self.conv = [
92
+ tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
93
+ tf.keras.layers.Activation('gelu', name='conv/1'),
94
+ SE(name='conv/2'),
95
+ tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
96
+ ]
97
+ self.reduction = tf.keras.layers.Conv2D(dim_out, kernel_size=3, strides=2, padding='valid', use_bias=False,
98
+ name='reduction')
99
+ self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1') # eps like PyTorch
100
+ self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
101
+ super().build(input_shape)
102
+
103
+ def call(self, inputs, **kwargs):
104
+ x = self.norm1(inputs)
105
+ xr = self.pad1(x) # if pad had weights it would've thrown error with .save_weights()
106
+ for layer in self.conv:
107
+ xr = layer(xr)
108
+ x = x + xr
109
+ x = self.pad2(x)
110
+ x = self.reduction(x)
111
+ x = self.norm2(x)
112
+ return x
113
+
114
+ def get_config(self):
115
+ config = super().get_config()
116
+ config.update({
117
+ "keep_dim":self.keep_dim,
118
+ })
119
+ return config
120
+
121
+ @tf.keras.utils.register_keras_serializable(package="gcvit")
122
+ class FeatExtract(tf.keras.layers.Layer):
123
+ def __init__(self, keep_dim=False, **kwargs):
124
+ super().__init__(**kwargs)
125
+ self.keep_dim = keep_dim
126
+
127
+ def build(self, input_shape):
128
+ dim = input_shape[-1]
129
+ self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
130
+ self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
131
+ self.conv = [
132
+ tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
133
+ tf.keras.layers.Activation('gelu', name='conv/1'),
134
+ SE(name='conv/2'),
135
+ tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
136
+ ]
137
+ if not self.keep_dim:
138
+ self.pool = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='valid', name='pool')
139
+ # else:
140
+ # self.pool = tf.keras.layers.Activation('linear', name='identity') # hack for PyTorch nn.Identity layer ;)
141
+ super().build(input_shape)
142
+
143
+ def call(self, inputs, **kwargs):
144
+ x = inputs
145
+ xr = self.pad1(x)
146
+ for layer in self.conv:
147
+ xr = layer(xr)
148
+ x = x + xr # if pad had weights it would've thrown error with .save_weights()
149
+ if not self.keep_dim:
150
+ x = self.pad2(x)
151
+ x = self.pool(x)
152
+ return x
153
+
154
+ def get_config(self):
155
+ config = super().get_config()
156
+ config.update({
157
+ "keep_dim":self.keep_dim,
158
+ })
159
+ return config
160
+
161
+ @tf.keras.utils.register_keras_serializable(package="gcvit")
162
+ class GlobalQueryGen(tf.keras.layers.Layer):
163
+ """
164
+ Global query generator based on: "Hatamizadeh et al.,
165
+ Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
166
+ """
167
+ def __init__(self, keep_dims=False, **kwargs):
168
+ super().__init__(**kwargs)
169
+ self.keep_dims = keep_dims
170
+
171
+ def build(self, input_shape):
172
+ self.to_q_global = [FeatExtract(keep_dim, name=f'to_q_global/{i}') \
173
+ for i, keep_dim in enumerate(self.keep_dims)]
174
+ super().build(input_shape)
175
+
176
+ def call(self, inputs, **kwargs):
177
+ x = inputs
178
+ for layer in self.to_q_global:
179
+ x = layer(x)
180
+ return x
181
+
182
+ def get_config(self):
183
+ config = super().get_config()
184
+ config.update({
185
+ "keep_dims":self.keep_dims,
186
+ })
187
+ return config
188
+
189
+ @tf.keras.utils.register_keras_serializable(package="gcvit")
190
+ class Resizing(tf.keras.layers.Layer):
191
+ def __init__(self,
192
+ height,
193
+ width,
194
+ interpolation='bilinear',
195
+ **kwargs):
196
+ self.height = height
197
+ self.width = width
198
+ self.interpolation = interpolation
199
+ super().__init__(**kwargs)
200
+
201
+ def call(self, inputs):
202
+ # tf.image.resize will always output float32 and operate more efficiently on
203
+ # float32 unless interpolation is nearest, in which case ouput type matches
204
+ # input type.
205
+ if self.interpolation == 'nearest':
206
+ input_dtype = self.compute_dtype
207
+ else:
208
+ input_dtype = tf.float32
209
+ inputs = tf.cast(inputs, dtype=input_dtype)
210
+ size = [self.height, self.width]
211
+ outputs = tf.image.resize(
212
+ inputs,
213
+ size=size,
214
+ method=self.interpolation)
215
+ return tf.cast(outputs, self.compute_dtype)
216
+
217
+ def compute_output_shape(self, input_shape):
218
+ input_shape = tf.TensorShape(input_shape).as_list()
219
+ input_shape[H_AXIS] = self.height
220
+ input_shape[W_AXIS] = self.width
221
+ return tf.TensorShape(input_shape)
222
+
223
+ def get_config(self):
224
+ config = super().get_config()
225
+ config.update({
226
+ 'height': self.height,
227
+ 'width': self.width,
228
+ 'interpolation': self.interpolation,
229
+ })
230
+ return config
231
+
232
+ @tf.keras.utils.register_keras_serializable(package="gcvit")
233
+ class FitWindow(tf.keras.layers.Layer):
234
+ "Pad feature to fit window"
235
+ def __init__(self, window_size, **kwargs):
236
+ super().__init__(**kwargs)
237
+ self.window_size = window_size
238
+
239
+ def call(self, inputs):
240
+ B, H, W, C = tf.unstack(tf.shape(inputs), num=4)
241
+ # pad to multiple of window_size
242
+ h_pad = (self.window_size - H % self.window_size) % self.window_size
243
+ w_pad = (self.window_size - W % self.window_size) % self.window_size
244
+ x = tf.pad(inputs, [[0, 0],
245
+ [h_pad//2, (h_pad//2 + h_pad%2)], # padding in both directions unlike tfgcvit
246
+ [w_pad//2, (w_pad//2 + w_pad%2)],
247
+ [0, 0]])
248
+ return x
249
+
250
+ def get_config(self):
251
+ config = super().get_config()
252
+ config.update({
253
+ 'window_size': self.window_size,
254
+ })
255
  return config
gcvit/layers/level.py CHANGED
@@ -1,93 +1,85 @@
1
- import tensorflow as tf
2
-
3
- from .feature import FeatExtract, ReduceSize, Resizing
4
- from .block import GCViTBlock
5
-
6
- @tf.keras.utils.register_keras_serializable(package="gcvit")
7
- class GCViTLayer(tf.keras.layers.Layer):
8
- def __init__(self, depth, num_heads, window_size, keep_dims, downsample=True, mlp_ratio=4., qkv_bias=True,
9
- qk_scale=None, drop=0., attn_drop=0., path_drop=0., layer_scale=None, resize_query=False, **kwargs):
10
- super().__init__(**kwargs)
11
- self.depth = depth
12
- self.num_heads = num_heads
13
- self.window_size = window_size
14
- self.keep_dims = keep_dims
15
- self.downsample = downsample
16
- self.mlp_ratio = mlp_ratio
17
- self.qkv_bias = qkv_bias
18
- self.qk_scale = qk_scale
19
- self.drop = drop
20
- self.attn_drop = attn_drop
21
- self.path_drop = path_drop
22
- self.layer_scale = layer_scale
23
- self.resize_query = resize_query
24
-
25
- def build(self, input_shape):
26
- path_drop = [self.path_drop] * self.depth if not isinstance(self.path_drop, list) else self.path_drop
27
- self.blocks = [
28
- GCViTBlock(window_size=self.window_size,
29
- num_heads=self.num_heads,
30
- global_query=bool(i % 2),
31
- mlp_ratio=self.mlp_ratio,
32
- qkv_bias=self.qkv_bias,
33
- qk_scale=self.qk_scale,
34
- drop=self.drop,
35
- attn_drop=self.attn_drop,
36
- path_drop=path_drop[i],
37
- layer_scale=self.layer_scale,
38
- name=f'blocks/{i}')
39
- for i in range(self.depth)]
40
- self.down = ReduceSize(keep_dim=False, name='downsample')
41
- self.to_q_global = [
42
- FeatExtract(keep_dim, name=f'to_q_global/{i}')
43
- for i, keep_dim in enumerate(self.keep_dims)]
44
- self.resize = Resizing(self.window_size, self.window_size, interpolation='bicubic')
45
- super().build(input_shape)
46
-
47
- def call(self, inputs, **kwargs):
48
- height, width = tf.unstack(tf.shape(inputs)[1:3], num=2)
49
- # pad to multiple of window_size
50
- h_pad = (self.window_size - height % self.window_size) % self.window_size
51
- w_pad = (self.window_size - width % self.window_size) % self.window_size
52
- x = tf.pad(inputs, [[0, 0],
53
- [h_pad//2, (h_pad//2 + h_pad%2)], # padding in both directions unlike tfgcvit
54
- [w_pad//2, (w_pad//2 + w_pad%2)],
55
- [0, 0]])
56
- # generate global query
57
- q_global = x # (B, H, W, C)
58
- for layer in self.to_q_global:
59
- q_global = layer(q_global) # official impl issue: https://github.com/NVlabs/GCVit/issues/13
60
- # resize query to fit key-value, but result in poor score with official weights?
61
- if self.resize_query:
62
- q_global = self.resize(q_global) # to avoid mismatch between feat_map and q_global: https://github.com/NVlabs/GCVit/issues/9
63
- # feature_map -> windows -> window_attention -> feature_map
64
- for i, blk in enumerate(self.blocks):
65
- if i % 2:
66
- x = blk([x, q_global])
67
- else:
68
- x = blk([x])
69
- x = x[:, :height, :width, :] # https://github.com/NVlabs/GCVit/issues/9
70
- # set shape for [B, ?, ?, C]
71
- x.set_shape(inputs.shape) # `tf.reshape` creates new tensor with new_shape
72
- # downsample
73
- if self.downsample:
74
- x = self.down(x)
75
- return x
76
-
77
- def get_config(self):
78
- config = super().get_config()
79
- config.update({
80
- 'depth': self.depth,
81
- 'num_heads': self.num_heads,
82
- 'window_size': self.window_size,
83
- 'keep_dims': self.keep_dims,
84
- 'downsample': self.downsample,
85
- 'mlp_ratio': self.mlp_ratio,
86
- 'qkv_bias': self.qkv_bias,
87
- 'qk_scale': self.qk_scale,
88
- 'drop': self.drop,
89
- 'attn_drop': self.attn_drop,
90
- 'path_drop': self.path_drop,
91
- 'layer_scale': self.layer_scale
92
- })
93
  return config
 
1
+ import tensorflow as tf
2
+
3
+ from .feature import GlobalQueryGen, ReduceSize, Resizing, FitWindow
4
+ from .block import GCViTBlock
5
+
6
+ @tf.keras.utils.register_keras_serializable(package="gcvit")
7
+ class GCViTLevel(tf.keras.layers.Layer):
8
+ def __init__(self, depth, num_heads, window_size, keep_dims, downsample=True, mlp_ratio=4., qkv_bias=True,
9
+ qk_scale=None, drop=0., attn_drop=0., path_drop=0., layer_scale=None, resize_query=False, **kwargs):
10
+ super().__init__(**kwargs)
11
+ self.depth = depth
12
+ self.num_heads = num_heads
13
+ self.window_size = window_size
14
+ self.keep_dims = keep_dims
15
+ self.downsample = downsample
16
+ self.mlp_ratio = mlp_ratio
17
+ self.qkv_bias = qkv_bias
18
+ self.qk_scale = qk_scale
19
+ self.drop = drop
20
+ self.attn_drop = attn_drop
21
+ self.path_drop = path_drop
22
+ self.layer_scale = layer_scale
23
+ self.resize_query = resize_query
24
+
25
+ def build(self, input_shape):
26
+ path_drop = [self.path_drop] * self.depth if not isinstance(self.path_drop, list) else self.path_drop
27
+ self.blocks = [
28
+ GCViTBlock(window_size=self.window_size,
29
+ num_heads=self.num_heads,
30
+ global_query=bool(i % 2),
31
+ mlp_ratio=self.mlp_ratio,
32
+ qkv_bias=self.qkv_bias,
33
+ qk_scale=self.qk_scale,
34
+ drop=self.drop,
35
+ attn_drop=self.attn_drop,
36
+ path_drop=path_drop[i],
37
+ layer_scale=self.layer_scale,
38
+ name=f'blocks/{i}')
39
+ for i in range(self.depth)]
40
+ self.down = ReduceSize(keep_dim=False, name='downsample')
41
+ self.q_global_gen = GlobalQueryGen(self.keep_dims, name='q_global_gen')
42
+ self.resize = Resizing(self.window_size, self.window_size, interpolation='bicubic')
43
+ self.fit_window = FitWindow(self.window_size)
44
+ super().build(input_shape)
45
+
46
+ def call(self, inputs, **kwargs):
47
+ H, W = tf.unstack(tf.shape(inputs)[1:3], num=2)
48
+ # pad to fit window_size
49
+ x = self.fit_window(inputs)
50
+ # generate global query
51
+ q_global = self.q_global_gen(x) # (B, H, W, C) # official impl issue: https://github.com/NVlabs/GCVit/issues/13
52
+ # resize query to fit key-value, but result in poor score with official weights?
53
+ if self.resize_query:
54
+ q_global = self.resize(q_global) # to avoid mismatch between feat_map and q_global: https://github.com/NVlabs/GCVit/issues/9
55
+ # feature_map -> windows -> window_attention -> feature_map
56
+ for i, blk in enumerate(self.blocks):
57
+ if i % 2:
58
+ x = blk([x, q_global])
59
+ else:
60
+ x = blk([x])
61
+ x = x[:, :H, :W, :] # https://github.com/NVlabs/GCVit/issues/9
62
+ # set shape for [B, ?, ?, C]
63
+ x.set_shape(inputs.shape) # `tf.reshape` creates new tensor with new_shape
64
+ # downsample
65
+ if self.downsample:
66
+ x = self.down(x)
67
+ return x
68
+
69
+ def get_config(self):
70
+ config = super().get_config()
71
+ config.update({
72
+ 'depth': self.depth,
73
+ 'num_heads': self.num_heads,
74
+ 'window_size': self.window_size,
75
+ 'keep_dims': self.keep_dims,
76
+ 'downsample': self.downsample,
77
+ 'mlp_ratio': self.mlp_ratio,
78
+ 'qkv_bias': self.qkv_bias,
79
+ 'qk_scale': self.qk_scale,
80
+ 'drop': self.drop,
81
+ 'attn_drop': self.attn_drop,
82
+ 'path_drop': self.path_drop,
83
+ 'layer_scale': self.layer_scale
84
+ })
 
 
 
 
 
 
 
 
85
  return config
gcvit/models/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (234 Bytes)
 
gcvit/models/__pycache__/gcvit.cpython-38.pyc DELETED
Binary file (4.08 kB)
 
gcvit/models/gcvit.py CHANGED
@@ -2,11 +2,25 @@ import numpy as np
2
  import tensorflow as tf
3
 
4
  from ..layers import Stem, GCViTLevel, Identity
 
5
 
6
 
 
7
  BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
8
  TAG = 'v1.1.1'
9
  NAME2CONFIG = {
 
 
 
 
 
 
 
 
 
 
 
 
10
  'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
11
  'dim': 64,
12
  'depths': (2, 2, 6, 2),
@@ -24,6 +38,8 @@ NAME2CONFIG = {
24
  'depths': (3, 4, 19, 5),
25
  'num_heads': (2, 4, 8, 16),
26
  'mlp_ratio': 3.,
 
 
27
  'path_drop': 0.2,},
28
  'gcvit_small': {'window_size': (7, 7, 14, 7),
29
  'dim': 96,
@@ -70,6 +86,7 @@ class GCViT(tf.keras.Model):
70
  self.num_classes = num_classes
71
  self.head_act = head_act
72
 
 
73
  self.patch_embed = Stem(dim=dim, name='patch_embed')
74
  self.pos_drop = tf.keras.layers.Dropout(drop_rate, name='pos_drop')
75
  path_drops = np.linspace(0., path_drop, sum(depths))
@@ -77,6 +94,7 @@ class GCViT(tf.keras.Model):
77
  self.levels = []
78
  for i in range(len(depths)):
79
  path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
 
80
  level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
81
  downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
82
  drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
@@ -92,13 +110,17 @@ class GCViT(tf.keras.Model):
92
  else:
93
  raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
94
  self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
 
95
 
 
96
  def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
97
  self.num_classes = num_classes
98
  if global_pool is not None:
99
  self.global_pool = global_pool
100
  self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
101
  super().build((1, 224, 224, in_channels)) # for head we only need info from the input channel
 
 
102
 
103
  def forward_features(self, inputs):
104
  x = self.patch_embed(inputs)
@@ -115,6 +137,7 @@ class GCViT(tf.keras.Model):
115
  x = self.pool(x)
116
  if not pre_logits:
117
  x = self.head(x)
 
118
  return x
119
 
120
  def call(self, inputs, **kwargs):
@@ -130,6 +153,9 @@ class GCViT(tf.keras.Model):
130
  def summary(self, input_shape=(224, 224, 3)):
131
  return self.build_graph(input_shape).summary()
132
 
 
 
 
133
  # load standard models
134
  def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
135
  name = 'gcvit_xxtiny'
@@ -153,28 +179,57 @@ def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **
153
  model.load_weights(ckpt_path)
154
  return model
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
157
  name = 'gcvit_tiny'
158
  config = NAME2CONFIG[name]
159
  ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
160
  model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
161
  model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
 
 
162
  if pretrain:
163
  ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
164
  model.load_weights(ckpt_path)
165
  return model
166
 
 
167
  def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
168
  name = 'gcvit_small'
169
  config = NAME2CONFIG[name]
170
  ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
171
  model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
172
  model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
 
 
173
  if pretrain:
174
  ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
175
  model.load_weights(ckpt_path)
176
  return model
177
 
 
178
  def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
179
  name = 'gcvit_base'
180
  config = NAME2CONFIG[name]
 
2
  import tensorflow as tf
3
 
4
  from ..layers import Stem, GCViTLevel, Identity
5
+ from ..layers import Stem, GCViTLevel, Identity
6
 
7
 
8
+
9
  BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
10
  TAG = 'v1.1.1'
11
  NAME2CONFIG = {
12
+ 'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
13
+ 'dim': 64,
14
+ 'depths': (2, 2, 6, 2),
15
+ 'num_heads': (2, 4, 8, 16),
16
+ 'mlp_ratio': 3.,
17
+ 'path_drop': 0.2},
18
+ 'gcvit_xtiny': {'window_size': (7, 7, 14, 7),
19
+ 'dim': 64,
20
+ 'depths': (3, 4, 6, 5),
21
+ 'num_heads': (2, 4, 8, 16),
22
+ 'mlp_ratio': 3.,
23
+ 'path_drop': 0.2},
24
  'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
25
  'dim': 64,
26
  'depths': (2, 2, 6, 2),
 
38
  'depths': (3, 4, 19, 5),
39
  'num_heads': (2, 4, 8, 16),
40
  'mlp_ratio': 3.,
41
+ 'num_heads': (2, 4, 8, 16),
42
+ 'mlp_ratio': 3.,
43
  'path_drop': 0.2,},
44
  'gcvit_small': {'window_size': (7, 7, 14, 7),
45
  'dim': 96,
 
86
  self.num_classes = num_classes
87
  self.head_act = head_act
88
 
89
+ self.patch_embed = Stem(dim=dim, name='patch_embed')
90
  self.patch_embed = Stem(dim=dim, name='patch_embed')
91
  self.pos_drop = tf.keras.layers.Dropout(drop_rate, name='pos_drop')
92
  path_drops = np.linspace(0., path_drop, sum(depths))
 
94
  self.levels = []
95
  for i in range(len(depths)):
96
  path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
97
+ level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
98
  level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
99
  downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
100
  drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
 
110
  else:
111
  raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
112
  self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
113
+ self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
114
 
115
+ def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
116
  def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
117
  self.num_classes = num_classes
118
  if global_pool is not None:
119
  self.global_pool = global_pool
120
  self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
121
  super().build((1, 224, 224, in_channels)) # for head we only need info from the input channel
122
+ self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
123
+ super().build((1, 224, 224, in_channels)) # for head we only need info from the input channel
124
 
125
  def forward_features(self, inputs):
126
  x = self.patch_embed(inputs)
 
137
  x = self.pool(x)
138
  if not pre_logits:
139
  x = self.head(x)
140
+ x = self.head(x)
141
  return x
142
 
143
  def call(self, inputs, **kwargs):
 
153
  def summary(self, input_shape=(224, 224, 3)):
154
  return self.build_graph(input_shape).summary()
155
 
156
+ def summary(self, input_shape=(224, 224, 3)):
157
+ return self.build_graph(input_shape).summary()
158
+
159
  # load standard models
160
  def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
161
  name = 'gcvit_xxtiny'
 
179
  model.load_weights(ckpt_path)
180
  return model
181
 
182
+ def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
183
+ def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
184
+ name = 'gcvit_xxtiny'
185
+ config = NAME2CONFIG[name]
186
+ ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
187
+ model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
188
+ model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
189
+ if pretrain:
190
+ ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
191
+ model.load_weights(ckpt_path)
192
+ return model
193
+
194
+ def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
195
+ name = 'gcvit_xtiny'
196
+ config = NAME2CONFIG[name]
197
+ ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
198
+ model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
199
+ model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
200
+ if pretrain:
201
+ ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
202
+ model.load_weights(ckpt_path)
203
+ return model
204
+
205
  def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
206
  name = 'gcvit_tiny'
207
  config = NAME2CONFIG[name]
208
  ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
209
  model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
210
  model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
211
+ model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
212
+ model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
213
  if pretrain:
214
  ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
215
  model.load_weights(ckpt_path)
216
  return model
217
 
218
+ def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
219
  def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
220
  name = 'gcvit_small'
221
  config = NAME2CONFIG[name]
222
  ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
223
  model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
224
  model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
225
+ model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
226
+ model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
227
  if pretrain:
228
  ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
229
  model.load_weights(ckpt_path)
230
  return model
231
 
232
+ def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
233
  def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
234
  name = 'gcvit_base'
235
  config = NAME2CONFIG[name]
gcvit/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.0.3"
 
1
+ __version__ = "1.0.9"
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- tensorflow==2.4.1
2
- tensorflow_addons==0.14.0
3
- gradio==3.1.0
4
- numpy
5
  matplotlib
 
1
+ tensorflow==2.4.1
2
+ tensorflow_addons==0.14.0
3
+ gradio==3.1.0
4
+ numpy
5
  matplotlib
setup.py DELETED
@@ -1,50 +0,0 @@
1
- from setuptools import setup, find_packages
2
- from codecs import open
3
- from os import path
4
-
5
- here = path.abspath(path.dirname(__file__))
6
-
7
- # Get the long description from the README file
8
- with open(path.join(here, "README.md"), encoding="utf-8") as f:
9
- long_description = f.read()
10
-
11
- with open(path.join(here, 'requirements.txt')) as f:
12
- install_requires = [x for x in f.read().splitlines() if len(x)]
13
-
14
- exec(open("gcvit/version.py").read())
15
-
16
- setup(
17
- name="gcvit",
18
- version=__version__,
19
- description="Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf",
20
- long_description=long_description,
21
- long_description_content_type="text/markdown",
22
- url="https://github.com/awsaf49/gcvit-tf",
23
- author="Awsaf",
24
- author_email="awsaf49@gmail.com",
25
- classifiers=[
26
- # How mature is this project? Common values are
27
- # 3 - Alpha
28
- # 4 - Beta
29
- # 5 - Production/Stable
30
- "Development Status :: 3 - Alpha",
31
- "Intended Audience :: Developers",
32
- "Intended Audience :: Science/Research",
33
- "License :: OSI Approved :: Apache Software License",
34
- "Programming Language :: Python :: 3.6",
35
- "Programming Language :: Python :: 3.7",
36
- "Programming Language :: Python :: 3.8",
37
- "Topic :: Scientific/Engineering",
38
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
39
- "Topic :: Software Development",
40
- "Topic :: Software Development :: Libraries",
41
- "Topic :: Software Development :: Libraries :: Python Modules",
42
- ],
43
- # Note that this is a string of words separated by whitespace, not a list.
44
- keywords="tensorflow computer_vision image classification transformer",
45
- packages=find_packages(exclude=["tests"]),
46
- include_package_data=True,
47
- install_requires=install_requires,
48
- python_requires=">=3.6",
49
- license="MIT",
50
- )