awsaf49 commited on
Commit
4092407
1 Parent(s): 31b7180

fix conflict

Browse files
Files changed (1) hide show
  1. gcvit/models/gcvit.py +1 -24
gcvit/models/gcvit.py CHANGED
@@ -2,25 +2,12 @@ import numpy as np
2
  import tensorflow as tf
3
 
4
  from ..layers import Stem, GCViTLevel, Identity
5
- from ..layers import Stem, GCViTLevel, Identity
6
 
7
 
8
 
9
  BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
10
  TAG = 'v1.1.1'
11
  NAME2CONFIG = {
12
- 'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
13
- 'dim': 64,
14
- 'depths': (2, 2, 6, 2),
15
- 'num_heads': (2, 4, 8, 16),
16
- 'mlp_ratio': 3.,
17
- 'path_drop': 0.2},
18
- 'gcvit_xtiny': {'window_size': (7, 7, 14, 7),
19
- 'dim': 64,
20
- 'depths': (3, 4, 6, 5),
21
- 'num_heads': (2, 4, 8, 16),
22
- 'mlp_ratio': 3.,
23
- 'path_drop': 0.2},
24
  'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
25
  'dim': 64,
26
  'depths': (2, 2, 6, 2),
@@ -94,7 +81,6 @@ class GCViT(tf.keras.Model):
94
  self.levels = []
95
  for i in range(len(depths)):
96
  path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
97
- level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
98
  level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
99
  downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
100
  drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
@@ -110,17 +96,14 @@ class GCViT(tf.keras.Model):
110
  else:
111
  raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
112
  self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
113
- self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
114
 
115
- def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
116
  def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
117
  self.num_classes = num_classes
118
  if global_pool is not None:
119
  self.global_pool = global_pool
120
  self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
121
  super().build((1, 224, 224, in_channels)) # for head we only need info from the input channel
122
- self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
123
- super().build((1, 224, 224, in_channels)) # for head we only need info from the input channel
124
 
125
  def forward_features(self, inputs):
126
  x = self.patch_embed(inputs)
@@ -137,7 +120,6 @@ class GCViT(tf.keras.Model):
137
  x = self.pool(x)
138
  if not pre_logits:
139
  x = self.head(x)
140
- x = self.head(x)
141
  return x
142
 
143
  def call(self, inputs, **kwargs):
@@ -153,8 +135,6 @@ class GCViT(tf.keras.Model):
153
  def summary(self, input_shape=(224, 224, 3)):
154
  return self.build_graph(input_shape).summary()
155
 
156
- def summary(self, input_shape=(224, 224, 3)):
157
- return self.build_graph(input_shape).summary()
158
 
159
  # load standard models
160
  def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
@@ -179,7 +159,6 @@ def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **
179
  model.load_weights(ckpt_path)
180
  return model
181
 
182
- def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
183
  def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
184
  name = 'gcvit_xxtiny'
185
  config = NAME2CONFIG[name]
@@ -215,7 +194,6 @@ def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **k
215
  model.load_weights(ckpt_path)
216
  return model
217
 
218
- def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
219
  def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
220
  name = 'gcvit_small'
221
  config = NAME2CONFIG[name]
@@ -229,7 +207,6 @@ def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **
229
  model.load_weights(ckpt_path)
230
  return model
231
 
232
- def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
233
  def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
234
  name = 'gcvit_base'
235
  config = NAME2CONFIG[name]
 
2
  import tensorflow as tf
3
 
4
  from ..layers import Stem, GCViTLevel, Identity
 
5
 
6
 
7
 
8
  BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
9
  TAG = 'v1.1.1'
10
  NAME2CONFIG = {
 
 
 
 
 
 
 
 
 
 
 
 
11
  'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
12
  'dim': 64,
13
  'depths': (2, 2, 6, 2),
 
81
  self.levels = []
82
  for i in range(len(depths)):
83
  path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
 
84
  level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
85
  downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
86
  drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
 
96
  else:
97
  raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
98
  self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
 
99
 
100
+
101
  def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
102
  self.num_classes = num_classes
103
  if global_pool is not None:
104
  self.global_pool = global_pool
105
  self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
106
  super().build((1, 224, 224, in_channels)) # for head we only need info from the input channel
 
 
107
 
108
  def forward_features(self, inputs):
109
  x = self.patch_embed(inputs)
 
120
  x = self.pool(x)
121
  if not pre_logits:
122
  x = self.head(x)
 
123
  return x
124
 
125
  def call(self, inputs, **kwargs):
 
135
  def summary(self, input_shape=(224, 224, 3)):
136
  return self.build_graph(input_shape).summary()
137
 
 
 
138
 
139
  # load standard models
140
  def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
 
159
  model.load_weights(ckpt_path)
160
  return model
161
 
 
162
  def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
163
  name = 'gcvit_xxtiny'
164
  config = NAME2CONFIG[name]
 
194
  model.load_weights(ckpt_path)
195
  return model
196
 
 
197
  def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
198
  name = 'gcvit_small'
199
  config = NAME2CONFIG[name]
 
207
  model.load_weights(ckpt_path)
208
  return model
209
 
 
210
  def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
211
  name = 'gcvit_base'
212
  config = NAME2CONFIG[name]