gheinrich committed
Commit 49f3b67
1 Parent(s): 2fa687c

Upload model

Files changed (6)
  1. block.py +309 -0
  2. config.json +198 -0
  3. conv.py +339 -0
  4. hf_model.py +84 -0
  5. model.py +1341 -0
  6. pytorch_model.bin +3 -0
block.py ADDED
@@ -0,0 +1,309 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Block modules
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from timm.models.layers import DropPath
10
+
11
+ from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
12
+ # from .transformer import TransformerBlock
13
+
14
+ __all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
15
+ 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3')
16
+
17
+
18
+ class DFL(nn.Module):
19
+ """
20
+ Integral module of Distribution Focal Loss (DFL).
21
+ Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
22
+ """
23
+
24
+ def __init__(self, c1=16):
25
+ """Initialize a convolutional layer with a given number of input channels."""
26
+ super().__init__()
27
+ self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
28
+ x = torch.arange(c1, dtype=torch.float)
29
+ self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
30
+ self.c1 = c1
31
+
32
+ def forward(self, x):
33
+ """Apply the DFL integral: softmax over the distribution bins, then their expectation via the fixed 1x1 conv."""
34
+ b, c, a = x.shape # batch, channels, anchors
35
+ return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
36
+ # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
37
+
38
+
39
+ class Proto(nn.Module):
40
+ """YOLOv8 mask Proto module for segmentation models."""
41
+
42
+ def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
43
+ super().__init__()
44
+ self.cv1 = Conv(c1, c_, k=3)
45
+ self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
46
+ self.cv2 = Conv(c_, c_, k=3)
47
+ self.cv3 = Conv(c_, c2)
48
+
49
+ def forward(self, x):
50
+ """Performs a forward pass through layers using an upsampled input image."""
51
+ return self.cv3(self.cv2(self.upsample(self.cv1(x))))
52
+
53
+
54
+ class HGStem(nn.Module):
55
+ """StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
56
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
57
+ """
58
+
59
+ def __init__(self, c1, cm, c2):
60
+ super().__init__()
61
+ self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())
62
+ self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())
63
+ self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())
64
+ self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())
65
+ self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())
66
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)
67
+
68
+ def forward(self, x):
69
+ """Forward pass of a PPHGNetV2 backbone layer."""
70
+ x = self.stem1(x)
71
+ x = F.pad(x, [0, 1, 0, 1])
72
+ x2 = self.stem2a(x)
73
+ x2 = F.pad(x2, [0, 1, 0, 1])
74
+ x2 = self.stem2b(x2)
75
+ x1 = self.pool(x)
76
+ x = torch.cat([x1, x2], dim=1)
77
+ x = self.stem3(x)
78
+ x = self.stem4(x)
79
+ return x
80
+
81
+
82
+ class HGBlock(nn.Module):
83
+ """HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
84
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
85
+ """
86
+
87
+ def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
88
+ super().__init__()
89
+ block = LightConv if lightconv else Conv
90
+ self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
91
+ self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv
92
+ self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv
93
+ self.add = shortcut and c1 == c2
94
+
95
+ def forward(self, x):
96
+ """Forward pass of a PPHGNetV2 backbone layer."""
97
+ y = [x]
98
+ y.extend(m(y[-1]) for m in self.m)
99
+ y = self.ec(self.sc(torch.cat(y, 1)))
100
+ return y + x if self.add else y
101
+
102
+
103
+ class SPP(nn.Module):
104
+ """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
105
+
106
+ def __init__(self, c1, c2, k=(5, 9, 13)):
107
+ """Initialize the SPP layer with input/output channels and pooling kernel sizes."""
108
+ super().__init__()
109
+ c_ = c1 // 2 # hidden channels
110
+ self.cv1 = Conv(c1, c_, 1, 1)
111
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
112
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
113
+
114
+ def forward(self, x):
115
+ """Forward pass of the SPP layer, performing spatial pyramid pooling."""
116
+ x = self.cv1(x)
117
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
118
+
119
+
120
+ class SPPF(nn.Module):
121
+ """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
122
+
123
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
124
+ super().__init__()
125
+ c_ = c1 // 2 # hidden channels
126
+ self.cv1 = Conv(c1, c_, 1, 1)
127
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
128
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
129
+
130
+ def forward(self, x):
131
+ """Forward pass through the SPPF layer using sequential max pooling."""
132
+ x = self.cv1(x)
133
+ y1 = self.m(x)
134
+ y2 = self.m(y1)
135
+ return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
136
+
137
+
138
+ class C1(nn.Module):
139
+ """CSP Bottleneck with 1 convolution."""
140
+
141
+ def __init__(self, c1, c2, n=1): # ch_in, ch_out, number
142
+ super().__init__()
143
+ self.cv1 = Conv(c1, c2, 1, 1)
144
+ self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))
145
+
146
+ def forward(self, x):
147
+ """Forward pass through the C1 module with a residual connection."""
148
+ y = self.cv1(x)
149
+ return self.m(y) + y
150
+
151
+
152
+ class C2(nn.Module):
153
+ """CSP Bottleneck with 2 convolutions."""
154
+
155
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
156
+ super().__init__()
157
+ self.c = int(c2 * e) # hidden channels
158
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
159
+ self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)
160
+ # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()
161
+ self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
162
+
163
+ def forward(self, x):
164
+ """Forward pass through the CSP bottleneck with 2 convolutions."""
165
+ a, b = self.cv1(x).chunk(2, 1)
166
+ return self.cv2(torch.cat((self.m(a), b), 1))
167
+
168
+
169
+ class C2f(nn.Module):
170
+ """Faster Implementation of CSP Bottleneck with 2 convolutions."""
171
+
172
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None): # ch_in, ch_out, number, shortcut, groups, expansion
173
+ super().__init__()
174
+ if drop_path is None:
175
+ drop_path = [0.0] * n
176
+
177
+ self.c = int(c2 * e) # hidden channels
178
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
179
+ self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
180
+ self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n))
181
+
182
+ def forward(self, x):
183
+ """Forward pass through C2f layer."""
184
+ y = list(self.cv1(x).chunk(2, 1))
185
+ y.extend(m(y[-1]) for m in self.m)
186
+ return self.cv2(torch.cat(y, 1))
187
+
188
+ def forward_split(self, x):
189
+ """Forward pass using split() instead of chunk()."""
190
+ y = list(self.cv1(x).split((self.c, self.c), 1))
191
+ y.extend(m(y[-1]) for m in self.m)
192
+ return self.cv2(torch.cat(y, 1))
193
+
194
+
195
+ class C3(nn.Module):
196
+ """CSP Bottleneck with 3 convolutions."""
197
+
198
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
199
+ super().__init__()
200
+ c_ = int(c2 * e) # hidden channels
201
+ self.cv1 = Conv(c1, c_, 1, 1)
202
+ self.cv2 = Conv(c1, c_, 1, 1)
203
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
204
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
205
+
206
+ def forward(self, x):
207
+ """Forward pass through the CSP bottleneck with 3 convolutions."""
208
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
209
+
210
+
211
+ class C3x(C3):
212
+ """C3 module with cross-convolutions."""
213
+
214
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
215
+ """Initialize C3x instance with cross-convolution kernels."""
216
+ super().__init__(c1, c2, n, shortcut, g, e)
217
+ self.c_ = int(c2 * e)
218
+ self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))
219
+
220
+
221
+ class RepC3(nn.Module):
222
+ """Rep C3."""
223
+
224
+ def __init__(self, c1, c2, n=3, e=1.0):
225
+ super().__init__()
226
+ c_ = int(c2 * e) # hidden channels
227
+ self.cv1 = Conv(c1, c2, 1, 1)
228
+ self.cv2 = Conv(c1, c2, 1, 1)
229
+ self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])
230
+ self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()
231
+
232
+ def forward(self, x):
233
+ """Forward pass of RT-DETR neck layer."""
234
+ return self.cv3(self.m(self.cv1(x)) + self.cv2(x))
235
+
236
+
237
+ class C3TR(C3):
238
+ """C3 module with TransformerBlock()."""
239
+
240
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
241
+ """Initialize C3TR module with TransformerBlock()."""
242
+ super().__init__(c1, c2, n, shortcut, g, e)
243
+ c_ = int(c2 * e)
244
+ self.m = TransformerBlock(c_, c_, 4, n)  # NOTE: requires TransformerBlock, whose import is commented out above
245
+
246
+
247
+ class C3Ghost(C3):
248
+ """C3 module with GhostBottleneck()."""
249
+
250
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
251
+ """Initialize C3Ghost module with GhostBottleneck()."""
252
+ super().__init__(c1, c2, n, shortcut, g, e)
253
+ c_ = int(c2 * e) # hidden channels
254
+ self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
255
+
256
+
257
+ class GhostBottleneck(nn.Module):
258
+ """Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
259
+
260
+ def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
261
+ super().__init__()
262
+ c_ = c2 // 2
263
+ self.conv = nn.Sequential(
264
+ GhostConv(c1, c_, 1, 1), # pw
265
+ DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
266
+ GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
267
+ self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
268
+ act=False)) if s == 2 else nn.Identity()
269
+
270
+ def forward(self, x):
271
+ """Applies the ghost convolutions and adds the shortcut branch."""
272
+ return self.conv(x) + self.shortcut(x)
273
+
274
+
275
+ class Bottleneck(nn.Module):
276
+ """Standard bottleneck."""
277
+
278
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0): # ch_in, ch_out, shortcut, groups, kernels, expand
279
+ super().__init__()
280
+ c_ = int(c2 * e) # hidden channels
281
+ self.cv1 = Conv(c1, c_, k[0], 1)
282
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
283
+ self.add = shortcut and c1 == c2
284
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
285
+
286
+ def forward(self, x):
287
+ """Forward pass through the bottleneck, applying the shortcut and DropPath when enabled."""
288
+ return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))
289
+
290
+
291
+ class BottleneckCSP(nn.Module):
292
+ """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
293
+
294
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
295
+ super().__init__()
296
+ c_ = int(c2 * e) # hidden channels
297
+ self.cv1 = Conv(c1, c_, 1, 1)
298
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
299
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
300
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
301
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
302
+ self.act = nn.SiLU()
303
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
304
+
305
+ def forward(self, x):
306
+ """Forward pass through the CSP bottleneck."""
307
+ y1 = self.cv3(self.m(self.cv1(x)))
308
+ y2 = self.cv2(x)
309
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
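The DFL module above implements the integral of Distribution Focal Loss: each box side is predicted as a softmax over c1 bins, and the frozen 1x1 convolution (weights 0..c1-1) takes the expectation. A minimal illustrative sketch, not part of the uploaded file, assuming the DFL class defined above is in scope:

import torch

dfl = DFL(c1=16)                  # 16 distribution bins per box side
x = torch.randn(2, 4 * 16, 8400)  # (batch, 4 sides x 16 bins, anchors)
out = dfl(x)                      # softmax over the bins, then their expectation
print(out.shape)                  # torch.Size([2, 4, 8400])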
config.json ADDED
@@ -0,0 +1,198 @@
1
+ {
2
+ "architectures": [
3
+ "ERADIOModel"
4
+ ],
5
+ "args": {
6
+ "aa": null,
7
+ "amp": true,
8
+ "amp_dtype": "bfloat16",
9
+ "amp_impl": "native",
10
+ "aug_repeats": 0,
11
+ "aug_splits": 0,
12
+ "batch_size": 32,
13
+ "bn_eps": null,
14
+ "bn_momentum": null,
15
+ "cache": "/lustre/fs3/portfolios/llmservice/users/gheinrich/cache/",
16
+ "cache_dir": null,
17
+ "channels_last": false,
18
+ "checkpoint_hist": 3,
19
+ "class_map": "",
20
+ "clip_grad": null,
21
+ "clip_mode": "norm",
22
+ "coco_annotations_file": "/datasets/coco2017-adlsa/annotations/captions_val2017.json",
23
+ "coco_image_dir": "/datasets/coco2017-adlsa/val2017",
24
+ "color_jitter": 0.4,
25
+ "cooldown_epochs": 0,
26
+ "cpe_max_size": null,
27
+ "crd_loss": false,
28
+ "crd_loss_weight": 0.8,
29
+ "crop_pct": null,
30
+ "cutmix": 0.0,
31
+ "cutmix_minmax": null,
32
+ "data_dir": "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/classification/imagenet-21k/webdataset",
33
+ "dataset": "nvgpt4",
34
+ "dataset_download": false,
35
+ "debug_full_knn": false,
36
+ "decay_epochs": 90,
37
+ "decay_milestones": [
38
+ 90,
39
+ 180,
40
+ 270
41
+ ],
42
+ "decay_rate": 0.1,
43
+ "device": "cuda:0",
44
+ "dist_bn": "reduce",
45
+ "distributed": true,
46
+ "drop": 0.0,
47
+ "drop_block": null,
48
+ "drop_connect": null,
49
+ "drop_path": null,
50
+ "epoch_repeats": 0.0,
51
+ "epochs": 300,
52
+ "eval_metric": "knn_top1",
53
+ "eval_teacher": false,
54
+ "eval_teacher_only": false,
55
+ "eval_throughput": false,
56
+ "experiment": "checkpoints",
57
+ "fast_norm": false,
58
+ "feature_summarizer": "cls_token",
59
+ "feature_upscale_factor": null,
60
+ "fuser": "",
61
+ "gp": null,
62
+ "grad_accum_steps": 1,
63
+ "grad_checkpointing": false,
64
+ "head_init_bias": null,
65
+ "head_init_scale": null,
66
+ "hflip": 0.5,
67
+ "img_size": null,
68
+ "in_chans": 3,
69
+ "initial_checkpoint": "",
70
+ "input_size": null,
71
+ "interpolation": "",
72
+ "layer_decay": null,
73
+ "local_rank": 0,
74
+ "log_interval": 50,
75
+ "log_mlflow": false,
76
+ "log_wandb": true,
77
+ "loss": "cosine",
78
+ "lr": 0.001,
79
+ "lr_base": 0.1,
80
+ "lr_base_scale": "",
81
+ "lr_base_size": 256,
82
+ "lr_cycle_decay": 0.5,
83
+ "lr_cycle_limit": 1,
84
+ "lr_cycle_mul": 1.0,
85
+ "lr_k_decay": 1.0,
86
+ "lr_noise": null,
87
+ "lr_noise_pct": 0.67,
88
+ "lr_noise_std": 1.0,
89
+ "mean": null,
90
+ "min_lr": 0,
91
+ "mixup": 0.0,
92
+ "mixup_mode": "batch",
93
+ "mixup_off_epoch": 0,
94
+ "mixup_prob": 1.0,
95
+ "mixup_switch_prob": 0.5,
96
+ "mlp_hidden_size": 1024,
97
+ "model": "fastervit2_large_fullres",
98
+ "model_ema": false,
99
+ "model_ema_decay": 0.9998,
100
+ "model_ema_force_cpu": false,
101
+ "model_kwargs": {
102
+ "return_full_features": true
103
+ },
104
+ "momentum": 0.9,
105
+ "no_aug": false,
106
+ "no_ddp_bb": false,
107
+ "no_prefetcher": false,
108
+ "no_resume_opt": false,
109
+ "num_classes": 0,
110
+ "opt": "fusedlamb",
111
+ "opt_betas": null,
112
+ "opt_eps": null,
113
+ "opt_kwargs": {},
114
+ "output": "/lustre/fs3/portfolios/llmservice/users/gheinrich/results/evfm/19-11-23-fastervit2-l-fullres",
115
+ "patience_epochs": 10,
116
+ "pin_mem": false,
117
+ "prefetcher": false,
118
+ "pretrained": false,
119
+ "rank": 0,
120
+ "ratio": [
121
+ 0.75,
122
+ 1.3333333333333333
123
+ ],
124
+ "recount": 1,
125
+ "recovery_interval": 0,
126
+ "remode": "pixel",
127
+ "reprob": 0.0,
128
+ "resplit": false,
129
+ "resume": "/lustre/fs3/portfolios/llmservice/users/gheinrich/results/evfm/19-11-23-fastervit2-l-fullres/checkpoints/last.pth.tar",
130
+ "return_full_features": true,
131
+ "save_images": false,
132
+ "scale": [
133
+ 0.5,
134
+ 1.0
135
+ ],
136
+ "sched": "cosine",
137
+ "sched_on_updates": true,
138
+ "seed": 42,
139
+ "smoothing": 0.1,
140
+ "split_bn": false,
141
+ "start_epoch": null,
142
+ "std": null,
143
+ "steps_per_epoch": 2000,
144
+ "sync_bn": false,
145
+ "synchronize_step": false,
146
+ "teachers": [
147
+ {
148
+ "batch_size": 32,
149
+ "config": "open_clip_vit-h-14_res224.yaml",
150
+ "fd_loss_weight": 1.0,
151
+ "feature_distillation": true,
152
+ "sample_rate": 8,
153
+ "summary_loss_weight": 1.0
154
+ },
155
+ {
156
+ "batch_size": 32,
157
+ "config": "dinov2_vit-g-14_res224.yaml",
158
+ "fd_loss_weight": 4.0,
159
+ "feature_distillation": true,
160
+ "sample_rate": 8,
161
+ "summary_loss_weight": 1.0
162
+ }
163
+ ],
164
+ "torchcompile": null,
165
+ "torchscript": false,
166
+ "train_interpolation": "random",
167
+ "train_split": "train",
168
+ "tta": 0,
169
+ "use_coco": false,
170
+ "use_multi_epochs_loader": false,
171
+ "val_data_dir": "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/classification/imagenet-1k/webdataset",
172
+ "val_img_size": 224,
173
+ "val_split": "val",
174
+ "validation_batch_size": 32,
175
+ "vflip": 0.0,
176
+ "wandb_entity": "",
177
+ "wandb_group": "backbones",
178
+ "wandb_job_type": "",
179
+ "wandb_name": "",
180
+ "wandb_project": "",
181
+ "warmup_epochs": 2.5,
182
+ "warmup_lr": 1e-05,
183
+ "warmup_prefix": false,
184
+ "weight_decay": 2e-05,
185
+ "worker_seeding": "all",
186
+ "workers": 4,
187
+ "world_size": 32
188
+ },
189
+ "auto_map": {
190
+ "AutoConfig": "hf_model.ERADIOConfig",
191
+ "AutoModel": "hf_model.ERADIOModel"
192
+ },
193
+ "return_spatial_features": true,
194
+ "return_summary": true,
195
+ "torch_dtype": "float32",
196
+ "transformers_version": "4.29.0",
197
+ "version": "v1"
198
+ }
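config.json stores the full training-argument blob under "args" together with an auto_map that points AutoConfig/AutoModel at the hf_model.py shipped in this commit. A short illustrative sketch, not part of the upload; the repository id is a placeholder:

from transformers import AutoConfig

# trust_remote_code lets transformers resolve the auto_map entries to hf_model.ERADIOConfig
cfg = AutoConfig.from_pretrained("<namespace>/<eradio-repo>", trust_remote_code=True)
print(cfg.args["model"])                                # "fastervit2_large_fullres"
print(cfg.return_summary, cfg.return_spatial_features)  # True True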
conv.py ADDED
@@ -0,0 +1,339 @@
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Convolution modules
4
+ """
5
+
6
+ import math
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ __all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
13
+ 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
14
+
15
+
16
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
17
+ """Pad to 'same' shape outputs."""
18
+ if d > 1:
19
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
20
+ if p is None:
21
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
22
+ return p
23
+
24
+ # Pavlo's implementation with switch to deploy
25
+ class Conv(nn.Module):
26
+ default_act = nn.SiLU() # default activation
27
+
28
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, bn_weight_init=1, bias=False, act=True):  # short arg names (k, s, p, d) match the callers in block.py (e.g. Proto, HGBlock) and DWConv/RepConv below
29
+ super().__init__()
30
+
31
+ self.conv = torch.nn.Conv2d(c1, c2, k, s, autopad(k, p, d), d, g, bias=False)
32
+ if 1:  # always-on BN branch; folded into the conv by switch_to_deploy()
33
+ self.bn = torch.nn.BatchNorm2d(b)
34
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
35
+ torch.nn.init.constant_(self.bn.bias, 0)
36
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
37
+
38
+
39
+ def forward(self,x):
40
+ x = self.conv(x)
41
+ x = self.bn(x)
42
+ x = self.act(x)
43
+ return x
44
+
45
+ @torch.no_grad()
46
+ def switch_to_deploy(self):
47
+ if not isinstance(self.bn, nn.Identity):
48
+ # return 1
49
+ c, bn = self.conv, self.bn
50
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
51
+ w = c.weight * w[:, None, None, None]
52
+ b = bn.bias - bn.running_mean * bn.weight / \
53
+ (bn.running_var + bn.eps)**0.5
54
+ # m = torch.nn.Conv2d(w.size(1) * c.groups,
55
+ # w.size(0),
56
+ # w.shape[2:],
57
+ # stride=c.stride,
58
+ # padding=c.padding,
59
+ # dilation=c.dilation,
60
+ # groups=c.groups)
61
+ self.conv.weight.data.copy_(w)
62
+ self.conv.bias = nn.Parameter(b)
63
+ # self.conv.bias.data.copy_(b)
64
+ # self.conv = m.to(c.weight.device)
65
+ self.bn = nn.Identity()
66
+
67
+ # class Conv(nn.Module):
68
+ # """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
69
+ # default_act = nn.SiLU() # default activation
70
+
71
+ # def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
72
+ # """Initialize Conv layer with given arguments including activation."""
73
+ # super().__init__()
74
+ # self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
75
+ # self.bn = nn.BatchNorm2d(c2)
76
+ # self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
77
+
78
+ # def forward(self, x):
79
+ # """Apply convolution, batch normalization and activation to input tensor."""
80
+ # return self.act(self.bn(self.conv(x)))
81
+
82
+ # def forward_fuse(self, x):
83
+ # """Perform transposed convolution of 2D data."""
84
+ # return self.act(self.conv(x))
85
+
86
+
87
+ class Conv2(Conv):
88
+ """Simplified RepConv module with Conv fusing."""
89
+
90
+ def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
91
+ """Initialize Conv layer with given arguments including activation."""
92
+ super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
93
+ self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
94
+
95
+ def forward(self, x):
96
+ """Apply convolution, batch normalization and activation to input tensor."""
97
+ return self.act(self.bn(self.conv(x) + self.cv2(x)))
98
+
99
+ def fuse_convs(self):
100
+ """Fuse parallel convolutions."""
101
+ w = torch.zeros_like(self.conv.weight.data)
102
+ i = [x // 2 for x in w.shape[2:]]
103
+ w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
104
+ self.conv.weight.data += w
105
+ self.__delattr__('cv2')
106
+
107
+
108
+ class LightConv(nn.Module):
109
+ """Light convolution with args(ch_in, ch_out, kernel).
110
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
111
+ """
112
+
113
+ def __init__(self, c1, c2, k=1, act=nn.ReLU()):
114
+ """Initialize Conv layer with given arguments including activation."""
115
+ super().__init__()
116
+ self.conv1 = Conv(c1, c2, 1, act=False)
117
+ self.conv2 = DWConv(c2, c2, k, act=act)
118
+
119
+ def forward(self, x):
120
+ """Apply 2 convolutions to input tensor."""
121
+ return self.conv2(self.conv1(x))
122
+
123
+
124
+ class DWConv(Conv):
125
+ """Depth-wise convolution."""
126
+
127
+ def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
128
+ super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
129
+
130
+
131
+ class DWConvTranspose2d(nn.ConvTranspose2d):
132
+ """Depth-wise transpose convolution."""
133
+
134
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
135
+ super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
136
+
137
+
138
+ class ConvTranspose(nn.Module):
139
+ """Convolution transpose 2d layer."""
140
+ default_act = nn.SiLU() # default activation
141
+
142
+ def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
143
+ """Initialize ConvTranspose2d layer with batch normalization and activation function."""
144
+ super().__init__()
145
+ self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
146
+ self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
147
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
148
+
149
+ def forward(self, x):
150
+ """Applies transposed convolutions, batch normalization and activation to input."""
151
+ return self.act(self.bn(self.conv_transpose(x)))
152
+
153
+ def forward_fuse(self, x):
154
+ """Applies activation and convolution transpose operation to input."""
155
+ return self.act(self.conv_transpose(x))
156
+
157
+
158
+ class Focus(nn.Module):
159
+ """Focus wh information into c-space."""
160
+
161
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
162
+ super().__init__()
163
+ self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
164
+ # self.contract = Contract(gain=2)
165
+
166
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
167
+ return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
168
+ # return self.conv(self.contract(x))
169
+
170
+
171
+ class GhostConv(nn.Module):
172
+ """Ghost Convolution https://github.com/huawei-noah/ghostnet."""
173
+
174
+ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
175
+ super().__init__()
176
+ c_ = c2 // 2 # hidden channels
177
+ self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
178
+ self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
179
+
180
+ def forward(self, x):
181
+ """Forward propagation through a Ghost Bottleneck layer with skip connection."""
182
+ y = self.cv1(x)
183
+ return torch.cat((y, self.cv2(y)), 1)
184
+
185
+
186
+ class RepConv(nn.Module):
187
+ """RepConv is a basic rep-style block, including training and deploy status
188
+ This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
189
+ """
190
+ default_act = nn.SiLU() # default activation
191
+
192
+ def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
193
+ super().__init__()
194
+ assert k == 3 and p == 1
195
+ self.g = g
196
+ self.c1 = c1
197
+ self.c2 = c2
198
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
199
+
200
+ self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
201
+ self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
202
+ self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
203
+
204
+ def forward_fuse(self, x):
205
+ """Forward process"""
206
+ return self.act(self.conv(x))
207
+
208
+ def forward(self, x):
209
+ """Forward process"""
210
+ id_out = 0 if self.bn is None else self.bn(x)
211
+ return self.act(self.conv1(x) + self.conv2(x) + id_out)
212
+
213
+ def get_equivalent_kernel_bias(self):
214
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
215
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
216
+ kernelid, biasid = self._fuse_bn_tensor(self.bn)
217
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
218
+
219
+ def _avg_to_3x3_tensor(self, avgp):
220
+ channels = self.c1
221
+ groups = self.g
222
+ kernel_size = avgp.kernel_size
223
+ input_dim = channels // groups
224
+ k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
225
+ k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
226
+ return k
227
+
228
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
229
+ if kernel1x1 is None:
230
+ return 0
231
+ else:
232
+ return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
233
+
234
+ def _fuse_bn_tensor(self, branch):
235
+ if branch is None:
236
+ return 0, 0
237
+ if isinstance(branch, Conv):
238
+ kernel = branch.conv.weight
239
+ running_mean = branch.bn.running_mean
240
+ running_var = branch.bn.running_var
241
+ gamma = branch.bn.weight
242
+ beta = branch.bn.bias
243
+ eps = branch.bn.eps
244
+ elif isinstance(branch, nn.BatchNorm2d):
245
+ if not hasattr(self, 'id_tensor'):
246
+ input_dim = self.c1 // self.g
247
+ kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
248
+ for i in range(self.c1):
249
+ kernel_value[i, i % input_dim, 1, 1] = 1
250
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
251
+ kernel = self.id_tensor
252
+ running_mean = branch.running_mean
253
+ running_var = branch.running_var
254
+ gamma = branch.weight
255
+ beta = branch.bias
256
+ eps = branch.eps
257
+ std = (running_var + eps).sqrt()
258
+ t = (gamma / std).reshape(-1, 1, 1, 1)
259
+ return kernel * t, beta - running_mean * gamma / std
260
+
261
+ def fuse_convs(self):
262
+ if hasattr(self, 'conv'):
263
+ return
264
+ kernel, bias = self.get_equivalent_kernel_bias()
265
+ self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
266
+ out_channels=self.conv1.conv.out_channels,
267
+ kernel_size=self.conv1.conv.kernel_size,
268
+ stride=self.conv1.conv.stride,
269
+ padding=self.conv1.conv.padding,
270
+ dilation=self.conv1.conv.dilation,
271
+ groups=self.conv1.conv.groups,
272
+ bias=True).requires_grad_(False)
273
+ self.conv.weight.data = kernel
274
+ self.conv.bias.data = bias
275
+ for para in self.parameters():
276
+ para.detach_()
277
+ self.__delattr__('conv1')
278
+ self.__delattr__('conv2')
279
+ if hasattr(self, 'nm'):
280
+ self.__delattr__('nm')
281
+ if hasattr(self, 'bn'):
282
+ self.__delattr__('bn')
283
+ if hasattr(self, 'id_tensor'):
284
+ self.__delattr__('id_tensor')
285
+
286
+
287
+ class ChannelAttention(nn.Module):
288
+ """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
289
+
290
+ def __init__(self, channels: int) -> None:
291
+ super().__init__()
292
+ self.pool = nn.AdaptiveAvgPool2d(1)
293
+ self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
294
+ self.act = nn.Sigmoid()
295
+
296
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
297
+ return x * self.act(self.fc(self.pool(x)))
298
+
299
+
300
+ class SpatialAttention(nn.Module):
301
+ """Spatial-attention module."""
302
+
303
+ def __init__(self, kernel_size=7):
304
+ """Initialize Spatial-attention module with kernel size argument."""
305
+ super().__init__()
306
+ assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
307
+ padding = 3 if kernel_size == 7 else 1
308
+ self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
309
+ self.act = nn.Sigmoid()
310
+
311
+ def forward(self, x):
312
+ """Apply channel and spatial attention on input for feature recalibration."""
313
+ return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
314
+
315
+
316
+ class CBAM(nn.Module):
317
+ """Convolutional Block Attention Module."""
318
+
319
+ def __init__(self, c1, kernel_size=7): # ch_in, kernels
320
+ super().__init__()
321
+ self.channel_attention = ChannelAttention(c1)
322
+ self.spatial_attention = SpatialAttention(kernel_size)
323
+
324
+ def forward(self, x):
325
+ """Applies the forward pass through C1 module."""
326
+ return self.spatial_attention(self.channel_attention(x))
327
+
328
+
329
+ class Concat(nn.Module):
330
+ """Concatenate a list of tensors along dimension."""
331
+
332
+ def __init__(self, dimension=1):
333
+ """Concatenates a list of tensors along a specified dimension."""
334
+ super().__init__()
335
+ self.d = dimension
336
+
337
+ def forward(self, x):
338
+ """Forward pass for the YOLOv8 mask Proto module."""
339
+ return torch.cat(x, self.d)
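Conv above keeps a separate BatchNorm during training and folds it into the convolution at inference time via switch_to_deploy(). A small illustrative sketch, not part of the uploaded file, assuming the Conv class above is in scope:

import torch

m = Conv(8, 16, 3, 1)             # conv + BN + SiLU
m.eval()                          # eval mode so BN uses its running statistics
x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    y_ref = m(x)
    m.switch_to_deploy()          # BN folded into conv weight/bias, self.bn -> Identity
    y_fused = m(x)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True, up to numerical tolerance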
hf_model.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from collections import namedtuple
15
+ from typing import Optional
16
+
17
+ from einops import rearrange
18
+ import torch
19
+ from transformers import PretrainedConfig, PreTrainedModel
20
+
21
+ #from radio.model import create_model_from_args
22
+ from radio.input_conditioner import get_default_conditioner, InputConditioner
23
+ from .model import eradio
24
+
25
+
26
+ class ERADIOConfig(PretrainedConfig):
27
+ """Pretrained Hugging Face configuration for ERADIO models."""
28
+
29
+ def __init__(
30
+ self,
31
+ args: Optional[dict] = None,
32
+ version: Optional[str] = "v1",
33
+ return_summary: Optional[bool] = True,
34
+ return_spatial_features: Optional[bool] = True,
35
+ **kwargs,
36
+ ):
37
+ self.args = args
38
+ self.version = version
39
+ self.return_summary = return_summary
40
+ self.return_spatial_features = return_spatial_features
41
+ super().__init__(**kwargs)
42
+
43
+
44
+ class ERADIOModel(PreTrainedModel):
45
+ """Pretrained Hugging Face model for ERADIO.
46
+
47
+ This class inherits from PreTrainedModel, which provides
48
+ HuggingFace's functionality for loading and saving models.
49
+ """
50
+
51
+ config_class = ERADIOConfig
52
+
53
+ def __init__(self, config):
54
+ super().__init__(config)
55
+
56
+ config.args["in_chans"] = 3
57
+ config.args["num_classes"] = 0
58
+ config.args["return_full_features"] = config.return_spatial_features
59
+
60
+ self.config = config
61
+ model = eradio(**config.args)
62
+ self.input_conditioner: InputConditioner = get_default_conditioner()
63
+ self.return_summary = config.return_summary
64
+ self.return_spatial_features = config.return_spatial_features
65
+ self.model = model
66
+
67
+ def forward(self, x: torch.Tensor):
68
+ x = self.input_conditioner(x)
69
+ y = self.model.forward_features(x)
71
+
72
+ if isinstance(y, tuple):
73
+ summary, features = y
74
+ # ERADIO features are spatial tokens.
75
+ features = rearrange(features, 'b c h w -> b (h w) c')
76
+ else:
77
+ summary = y
78
+ features = None
79
+
80
+ if self.return_summary and self.return_spatial_features:
81
+ return summary, features
82
+ elif self.return_summary:
83
+ return summary
84
+ return features
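ERADIOModel wires the checkpoint into the transformers auto classes; with the default config it returns both a summary vector and flattened spatial tokens. A usage sketch, not part of the uploaded file; the repository id is a placeholder and trust_remote_code is required so that this hf_model.py is used to build the network:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<namespace>/<eradio-repo>", trust_remote_code=True)
model.eval()

x = torch.rand(1, 3, 224, 224)        # the InputConditioner handles normalization
with torch.no_grad():
    summary, features = model(x)      # features: (B, H*W, C) spatial tokens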
model.py ADDED
@@ -0,0 +1,1341 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ # Created by Pavlo Molchanov, LPR - DL Efficiency Research team
12
+ # based on Fastervit1 from LPR
13
+
14
+ import timm
15
+ import torch
16
+ import torch.nn as nn
17
+ from timm.models.registry import register_model
18
+
19
+ from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
20
+ import numpy as np
21
+ import torch.nn.functional as F
22
+ from .block import C2f
23
+ TRT = False # should help for TRT
24
+
25
+ import pickle
26
+ global bias_indx
27
+ bias_indx = -1
28
+ DEBUG = False
29
+
30
+
31
+
32
+ def pixel_unshuffle(data, factor=2):
33
+ # manual stand-in for nn.PixelUnshuffle (the torch op has ONNX/TRT export issues); moves spatial detail into channels
34
+ B, C, H, W = data.shape
35
+ return data.view(B, C, factor, H//factor, factor, W//factor).permute(0,1,2,4,3,5).reshape(B, -1, H//factor, W//factor)
36
+
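A quick shape check for pixel_unshuffle above (illustrative, not part of the uploaded file; the function is assumed to be in scope): it trades spatial resolution for channels.

import torch

x = torch.randn(2, 32, 64, 48)
y = pixel_unshuffle(x, factor=2)
print(y.shape)  # torch.Size([2, 128, 32, 24]) -- 4x the channels, half the resolution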
37
+ class SwiGLU(nn.Module):
38
+ # could be more elaborate, but that doesn't improve results so far
39
+ def forward(self, x):
40
+ x, gate = x.chunk(2, dim=-1)
41
+ return F.silu(gate) * x
42
+
43
+
44
+ def window_partition(x, window_size):
45
+ """
46
+ Args:
47
+ x: (B, C, H, W)
48
+ window_size: window size
49
+ Returns:
50
+ windows - local window features (num_windows*B, window_size*window_size, C)
51
+ (Hp, Wp) - the size of the padded image
52
+ """
53
+ B, C, H, W = x.shape
54
+
55
+ if window_size == 0 or (window_size==H and window_size==W):
56
+ windows = x.flatten(2).transpose(1, 2)
57
+ Hp, Wp = H, W
58
+ else:
59
+ pad_h = (window_size - H % window_size) % window_size
60
+ pad_w = (window_size - W % window_size) % window_size
61
+ if pad_h > 0 or pad_w > 0:
62
+ x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0, 0, 0))
63
+ Hp, Wp = H + pad_h, W + pad_w
64
+
65
+ x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
66
+ windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
67
+
68
+ return windows, (Hp, Wp)
69
+
70
+ class Conv2d_BN(nn.Module):
71
+ '''
72
+ Conv2d + BN layer with folding capability to speed up inference
73
+ '''
74
+ def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False):
75
+ super().__init__()
76
+ self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False)
77
+ if 1:  # always-on BN branch; folded into the conv by switch_to_deploy()
78
+ self.bn = torch.nn.BatchNorm2d(b)
79
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
80
+ torch.nn.init.constant_(self.bn.bias, 0)
81
+
82
+ def forward(self,x):
83
+ x = self.conv(x)
84
+ x = self.bn(x)
85
+ return x
86
+
87
+ @torch.no_grad()
88
+ def switch_to_deploy(self):
89
+
90
+ # return 1
91
+ if not isinstance(self.bn, nn.Identity):
92
+ c, bn = self.conv, self.bn
93
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
94
+ w = c.weight * w[:, None, None, None]
95
+ b = bn.bias - bn.running_mean * bn.weight / \
96
+ (bn.running_var + bn.eps)**0.5
97
+ self.conv.weight.data.copy_(w)
98
+ self.conv.bias = nn.Parameter(b)
99
+ self.bn = nn.Identity()
100
+
101
+
102
+
103
+ def window_reverse(windows, window_size, H, W, pad_hw):
104
+ """
105
+ Args:
106
+ windows: local window features (num_windows*B, window_size*window_size, C)
107
+ window_size: Window size
108
+ H: Height of image
109
+ W: Width of image
110
+ pad_hw: a tuple (Hp, Wp) with the padded image size from the windowing step
111
+ Returns:
112
+ x: (B, C, H, W)
113
+
114
+ """
115
+ # print(f"window_reverse, windows.shape {windows.shape}")
116
+ Hp, Wp = pad_hw
117
+ if window_size == 0 or (window_size==H and window_size==W):
118
+ B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
119
+ x = windows.transpose(1, 2).view(B, -1, H, W)
120
+ else:
121
+ B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
122
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
123
+ x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], Hp, Wp)
124
+
125
+ if Hp > H or Wp > W:
126
+ x = x[:, :, :H, :W, ].contiguous()
127
+
128
+ return x
129
+
130
+
131
+
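window_partition and window_reverse above are exact inverses when used with the padded size returned by window_partition. An illustrative round-trip sketch, not part of the uploaded file, assuming both functions are in scope:

import torch

x = torch.randn(2, 96, 28, 28)
windows, pad_hw = window_partition(x, window_size=7)  # (2*16 windows, 49 tokens, 96 channels)
print(windows.shape)                                  # torch.Size([32, 49, 96])
x_back = window_reverse(windows, 7, 28, 28, pad_hw)
print(torch.equal(x, x_back))                         # True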
132
+ class PosEmbMLPSwinv2D(nn.Module):
133
+ def __init__(self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False):
134
+ super().__init__()
135
+ self.window_size = window_size
136
+ self.num_heads = num_heads
137
+ # mlp to generate continuous relative position bias
138
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
139
+ nn.ReLU(inplace=True),
140
+ nn.Linear(512, num_heads, bias=False))
141
+
142
+ # get relative_coords_table
143
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
144
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
145
+ relative_coords_table = torch.stack(
146
+ torch.meshgrid([relative_coords_h,
147
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
148
+ if pretrained_window_size[0] > 0:
149
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
150
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
151
+ else:
152
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
153
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
154
+
155
+ if not no_log:
156
+ relative_coords_table *= 8 # normalize to -8, 8
157
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
158
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
159
+
160
+ self.register_buffer("relative_coords_table", relative_coords_table)
161
+
162
+ # get pair-wise relative position index for each token inside the window
163
+ coords_h = torch.arange(self.window_size[0])
164
+ coords_w = torch.arange(self.window_size[1])
165
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
166
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
167
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
168
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
169
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
170
+ relative_coords[:, :, 1] += self.window_size[1] - 1
171
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
172
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
173
+ self.register_buffer("relative_position_index", relative_position_index)
174
+
175
+ self.grid_exists = False
176
+
177
+ self.deploy = False
178
+
179
+ relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
180
+ self.seq_length = seq_length
181
+ self.register_buffer("relative_bias", relative_bias) #for EMA
182
+
183
+ def switch_to_deploy(self):
184
+ self.deploy = True
185
+ self.grid_exists = True
186
+
187
+ def forward(self, input_tensor):
188
+ # for efficiency, we want this forward to be folded into a single operation (sum)
189
+ # if resolution stays the same, then we don't need to recompute the MLP layers
190
+ #
191
+ # to dynamically adjust patch size over the step
192
+ # if not (input_tensor.shape[1:] == self.relative_bias.shape[1:]):
193
+ # self.grid_exists = False
194
+
195
+ if self.training: self.grid_exists = False
196
+
197
+ if self.deploy and self.grid_exists:
198
+ input_tensor += self.relative_bias
199
+ return input_tensor
200
+
201
+ if not self.grid_exists:
202
+ self.grid_exists = True
203
+
204
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
205
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
206
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
207
+ -1) # Wh*Ww,Wh*Ww,nH
208
+
209
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
210
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
211
+
212
+ self.relative_bias = relative_position_bias.unsqueeze(0)
213
+
214
+ input_tensor += self.relative_bias
215
+ return input_tensor
216
+
217
+
218
+
219
+ class GRAAttentionBlock(nn.Module):
220
+ def __init__(self, window_size, dim_in, dim_out,
221
+ num_heads, drop_path=0., qk_scale=None, qkv_bias=False,
222
+ norm_layer=nn.LayerNorm, layer_scale=None,
223
+ use_swiglu=True,
224
+ subsample_ratio=1, dim_ratio=1, conv_base=False,
225
+ do_windowing=True, multi_query=False) -> None:
226
+ super().__init__()
227
+
228
+ dim = dim_in
229
+ # conv_base = True
230
+ SHUFFLE = False  # pixel shuffle/unshuffle resampling path disabled; pooling / conv resampling is used instead
232
+ self.do_windowing = do_windowing
233
+
234
+ if do_windowing:
235
+ if SHUFFLE:
236
+ self.downsample_op = torch.nn.PixelUnshuffle(subsample_ratio) if subsample_ratio>1 else torch.nn.Identity()
237
+ self.downsample_mixer = nn.Conv2d(dim_in * (subsample_ratio * subsample_ratio), dim_in * (dim_ratio), kernel_size=1, stride=1, padding=0, bias=False) if dim*dim_ratio != dim * subsample_ratio * subsample_ratio else torch.nn.Identity()
238
+ else:
239
+ if conv_base:
240
+ self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
241
+ self.downsample_mixer = nn.Identity()
242
+ else:
243
+ self.downsample_op = nn.AvgPool2d(kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
244
+ self.downsample_mixer = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1) if subsample_ratio > 1 else nn.Identity()
245
+
246
+
247
+ if do_windowing:
248
+ if SHUFFLE:
249
+ self.upsample_mixer =nn.Conv2d(dim_in * dim_ratio, dim_in * (subsample_ratio * subsample_ratio), kernel_size=1, stride=1, padding=0, bias=False) if dim*dim_ratio != dim * subsample_ratio * subsample_ratio else torch.nn.Identity()
250
+ self.upsample_op = torch.nn.PixelShuffle(subsample_ratio) if subsample_ratio>1 else torch.nn.Identity()
251
+ else:
252
+ if conv_base:
253
+ self.upsample_mixer = nn.Identity()
254
+ self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
255
+ else:
256
+ self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()
257
+ self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()
258
+
259
+ self.window_size = window_size
260
+
261
+ self.norm1 = norm_layer(dim_in)
262
+ if DEBUG:
263
+ print(f"GRAAttentionBlock: input_resolution: , window_size: {window_size}, dim_in: {dim_in}, dim_out: {dim_out}, num_heads: {num_heads}, drop_path: {drop_path}, qk_scale: {qk_scale}, qkv_bias: {qkv_bias}, layer_scale: {layer_scale}")
264
+
265
+
266
+ self.attn = WindowAttention(
267
+ dim_in,
268
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
269
+ resolution=window_size,
270
+ seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query)
271
+ if DEBUG:
272
+ print(f"Attention: dim_in: {dim_in}, num_heads: {num_heads}, qkv_bias: {qkv_bias}, qk_scale: {qk_scale}, resolution: {window_size}, seq_length: {window_size**2}, dim_out: {dim_in}")
273
+ print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
274
+
275
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
276
+
277
+ use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
278
+ self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1
279
+
280
+ ### mlp layer
281
+ mlp_ratio = 4
282
+ self.norm2 = norm_layer(dim_in)
283
+ mlp_hidden_dim = int(dim_in * mlp_ratio)
284
+
285
+ activation = nn.GELU if not use_swiglu else SwiGLU
286
+ mlp_hidden_dim = int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim
287
+
288
+ self.mlp = Mlp(in_features=dim_in, hidden_features=mlp_hidden_dim, act_layer=activation, use_swiglu=use_swiglu)
289
+
290
+ self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1
291
+ self.drop_path2=DropPath(drop_path) if drop_path > 0. else nn.Identity()
292
+ if DEBUG:
293
+ print(f"MLP layer: dim_in: {dim_in}, dim_out: {dim_in}, mlp_hidden_dim: {mlp_hidden_dim}")
294
+ print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
295
+
296
+
297
+ def forward(self, x):
298
+ skip_connection = x
299
+
300
+ if self.do_windowing:
301
+ # performing windowing if required
302
+ x = self.downsample_op(x)
303
+ x = self.downsample_mixer(x)
304
+
305
+ if self.window_size>0:
306
+ H, W = x.shape[2], x.shape[3]
307
+
308
+ x, pad_hw = window_partition(x, self.window_size)
309
+
310
+ # window attention
311
+ x = x + self.drop_path1(self.gamma1*self.attn(self.norm1(x)))
312
+ # mlp layer
313
+ x = x + self.drop_path2(self.gamma2*self.mlp(self.norm2(x)))
314
+
315
+ if self.do_windowing:
316
+ if self.window_size > 0:
317
+ x = window_reverse(x, self.window_size, H, W, pad_hw)
318
+
319
+ x = self.upsample_mixer(x)
320
+ x = self.upsample_op(x)
321
+
322
+
323
+ if x.shape[2] != skip_connection.shape[2] or x.shape[3] != skip_connection.shape[3]:
324
+ x = torch.nn.functional.pad(x, ( 0, -x.shape[3] + skip_connection.shape[3], 0, -x.shape[2] + skip_connection.shape[2]))
325
+ # need to add skip connection because downsampling and upsampling will break residual connection
326
+ # 0.5 is needed to make sure that the skip connection is not too strong
327
+ # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection
328
+ x = 0.5 * x + 0.5 * skip_connection
329
+
330
+ return x
331
+
332
+
333
+
334
+
335
+ class MultiResolutionAttention(nn.Module):
336
+ """
337
+ MultiResolutionAttention (MRA) module
338
+ The idea is to use multiple attention blocks with different resolution
339
+ Feature maps are downsampled / upsampled before and after each attention block.
340
+ Every attention block supports optional windowing and a configurable subsample ratio.
341
+
342
+ """
343
+
344
+ def __init__(self, window_size, sr_ratio,
345
+ dim, dim_ratio, num_heads,
346
+ do_windowing=True,
347
+ layer_scale=1e-5, norm_layer=nn.LayerNorm,
348
+ drop_path = 0, qkv_bias=False, qk_scale=1.0,
349
+ use_swiglu=True, multi_query=False, conv_base=False) -> None:
350
+ """
351
+ Args:
352
+ window_size: window size(s), one per attention block
353
+ sr_ratio: subsample ratio(s); one GRAAttentionBlock is created per entry
354
+ dim: feature dimension
355
+ num_heads: number of attention heads
356
+ """
357
+ super().__init__()
358
+
359
+ depth = len(sr_ratio)
360
+
361
+
362
+ self.attention_blocks = nn.ModuleList()
363
+
364
+
365
+ for i in range(depth):
366
+ subsample_ratio = sr_ratio[i]
367
+ if len(window_size) > i:
368
+ window_size_local = window_size[i]
369
+ else:
370
+ window_size_local = window_size[0]
371
+
372
+ self.attention_blocks.append(GRAAttentionBlock(window_size=window_size_local,
373
+ dim_in=dim, dim_out=dim, num_heads=num_heads,
374
+ qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer,
375
+ layer_scale=layer_scale, drop_path=drop_path,
376
+ use_swiglu=use_swiglu, subsample_ratio=subsample_ratio, dim_ratio=dim_ratio,
377
+ do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base),
378
+ )
379
+
380
+
381
+
382
+ def forward(self, x):
383
+
384
+ for attention_block in self.attention_blocks:
385
+ x = attention_block(x)
386
+
387
+ return x
388
+
389
+
390
+
391
+ class Mlp(nn.Module):
392
+ """
393
+ Multi-Layer Perceptron (MLP) block
394
+ """
395
+
396
+ def __init__(self,
397
+ in_features,
398
+ hidden_features=None,
399
+ out_features=None,
400
+ act_layer=nn.GELU,
401
+ use_swiglu=True,
402
+ drop=0.):
403
+ """
404
+ Args:
405
+ in_features: input features dimension.
406
+ hidden_features: hidden features dimension.
407
+ out_features: output features dimension.
408
+ act_layer: activation function.
409
+ drop: dropout rate.
410
+ """
411
+
412
+ super().__init__()
413
+ out_features = out_features or in_features
414
+ hidden_features = hidden_features or in_features
415
+ self.fc1 = nn.Linear(in_features, hidden_features * (2 if use_swiglu else 1), bias=False)
416
+ self.act = act_layer()
417
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
418
+ # self.drop = GaussianDropout(drop)
419
+
420
+ def forward(self, x):
421
+ x_size = x.size()
422
+ x = x.view(-1, x_size[-1])
423
+ x = self.fc1(x)
424
+ x = self.act(x)
425
+ # x = self.drop(x)
426
+ x = self.fc2(x)
427
+ # x = self.drop(x)
428
+ x = x.view(x_size)
429
+ return x
430
+
431
+ class Downsample(nn.Module):
432
+ """
433
+ Down-sampling block
434
+
435
+ Pixel-unshuffle down-sampling works well accuracy-wise but takes ~10% more TRT time.
436
+ """
437
+
438
+ def __init__(self,
439
+ dim,
440
+ shuffle = False,
441
+ ):
442
+ """
443
+ Args:
444
+ dim: feature size dimension.
445
+ shuffle: use pixel-unshuffle (plus a 1x1 conv) for down-sampling instead of a strided 3x3 convolution.
447
+ """
448
+
449
+ super().__init__()
450
+ dim_out = 2 * dim
451
+
452
+ if shuffle:
453
+ self.norm = lambda x: pixel_unshuffle(x, factor=2)
454
+ self.reduction = Conv2d_BN(dim*4, dim_out, 1, 1, 0, bias=False)
455
+ else:
456
+ # removed LayerNorm for speed; this formulation is about 10% faster
457
+ # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
458
+ self.norm = nn.Identity()
459
+ self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
460
+
461
+
462
+ def forward(self, x):
463
+ x = self.norm(x)
464
+ x = self.reduction(x)
465
+ return x
466
+
467
+
468
+ class PatchEmbed(nn.Module):
469
+ """
470
+ Patch embedding block
471
+ """
472
+
473
+ def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
474
+ """
475
+ Args:
476
+ in_chans: number of input channels.
477
+ in_dim: intermediate feature size dimension to speed up stem.
478
+ dim: final stem channel number
479
+ shuffle_down: use PixelUnshuffle for down-sampling, effectively increases the receptive field
480
+ """
481
+
482
+ super().__init__()
483
+ # shuffle_down = False
484
+ if not shuffle_down:
485
+ self.proj = nn.Identity()
486
+ self.conv_down = nn.Sequential(
487
+ Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),
488
+ nn.ReLU(),
489
+ Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),
490
+ nn.ReLU()
491
+ )
492
+ else:
493
+ self.proj = lambda x: pixel_unshuffle(x, factor=4)
494
+
495
+ # self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, in_dim, 3, 1, 1),
496
+ # nn.SiLU(),
497
+ # Conv2d_BN(in_dim, dim, 3, 1, 1),
498
+ # nn.SiLU(),
499
+ # )
500
+ self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, dim, 3, 1, 1),
501
+ nn.ReLU(),
502
+ )
503
+
504
+ def forward(self, x):
505
+ x = self.proj(x)
506
+ x = self.conv_down(x)
507
+ return x
508
+
509
+
510
+
511
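Both stem variants above reduce the spatial resolution by 4x before the first stage. A rough shape check for the default (non-shuffle) path, using a plain Conv2d + BatchNorm stand-in for `Conv2d_BN` (an assumption made only for this sketch):

```python
import torch
import torch.nn as nn

def conv_bn(c_in, c_out, k, s, p):
    """Stand-in for Conv2d_BN, used here only to check shapes."""
    return nn.Sequential(nn.Conv2d(c_in, c_out, k, s, p, bias=False), nn.BatchNorm2d(c_out))

stem = nn.Sequential(conv_bn(3, 64, 3, 2, 1), nn.ReLU(),
                     conv_bn(64, 96, 3, 2, 1), nn.ReLU())
print(stem(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 96, 56, 56]) -> 4x reduction
```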
+ class ConvBlock(nn.Module):
512
+ """
513
+ Convolutional block, used in the first couple of stages.
514
+ Experimented with plain ResNet-18-like modules; they are the best in terms of throughput.
515
+ Experimented with RepVGG; no significant improvement in accuracy.
516
+ Finally, the YOLOv8-style idea seems to work fine (a ResNet-18-like block with a squeezed feature dimension and feature concatenation at the end).
517
+ """
518
+ def __init__(self, dim,
519
+ drop_path=0.,
520
+ layer_scale=None,
521
+ kernel_size=3,
522
+ rep_vgg=False):
523
+ super().__init__()
524
+ self.rep_vgg = rep_vgg
525
+ if not rep_vgg:
526
+ self.conv1 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
527
+ self.act1 = nn.GELU()
528
+ else:
529
+ self.conv1 = RepVGGBlock(dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1)
530
+
531
+
532
+ if not rep_vgg:
533
+ self.conv2 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
534
+ else:
535
+ self.conv2 = RepVGGBlock(dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1)
536
+
537
+ self.layer_scale = layer_scale
538
+ if layer_scale is not None and type(layer_scale) in [int, float]:
539
+ self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
540
+ self.layer_scale = True
541
+ else:
542
+ self.layer_scale = False
543
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
544
+
545
+ def forward(self, x):
546
+ input = x
547
+ if not self.rep_vgg:
548
+ x = self.conv1(x)
549
+ x = self.act1(x)
550
+ x = self.conv2(x)
551
+ else:
552
+ x = self.conv1(x)
553
+ x = self.conv2(x)
554
+ if self.layer_scale:
555
+ x = x * self.gamma.view(1, -1, 1, 1)
556
+ x = input + self.drop_path(x)
557
+ return x
558
+
559
+
560
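The block above follows the usual residual pattern with optional layer scale and stochastic depth: `x = x + drop_path(gamma * f(x))`. A minimal standalone version of that pattern (names and sizes are illustrative only):

```python
import torch
import torch.nn as nn

class TinyResidual(nn.Module):
    """Residual block with a learnable per-channel layer scale, as in ConvBlock above."""
    def __init__(self, dim, layer_scale=1e-5):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1), nn.GELU(),
                                  nn.Conv2d(dim, dim, 3, 1, 1))
        self.gamma = nn.Parameter(layer_scale * torch.ones(dim))

    def forward(self, x):
        return x + self.body(x) * self.gamma.view(1, -1, 1, 1)

print(TinyResidual(8)(torch.randn(2, 8, 16, 16)).shape)  # torch.Size([2, 8, 16, 16])
```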
+ class WindowAttention(nn.Module):
561
+ # Windowed Attention from SwinV2
562
+ # uses an MLP trick to deal with various input image resolutions, then folds it to improve speed
563
+ # tested multi-query attention, but it is not as good as full attention;
564
+ # see PaLM: https://github.com/lucidrains/PaLM-pytorch/blob/main/palm_pytorch/palm_pytorch.py
565
+ # (single-KV attention with the MLP in parallel did not improve speed)
566
+
567
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0,
568
+ seq_length=0, dim_out=None, multi_query=False):
569
+ # taken from EdgeViT and tweaked with attention bias.
570
+ super().__init__()
571
+ if not dim_out: dim_out = dim
572
+ self.multi_query = multi_query
573
+ self.num_heads = num_heads
574
+ head_dim = dim // num_heads
575
+ self.head_dim = dim // num_heads
576
+
577
+ self.dim_internal = dim
578
+
579
+ self.scale = qk_scale or head_dim ** -0.5
580
+ if not multi_query:
581
+ if TRT:
582
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
583
+ self.k = nn.Linear(dim, dim, bias=qkv_bias)
584
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
585
+ else:
586
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
587
+ else:
588
+ self.qkv = nn.Linear(dim, dim + 2*self.head_dim, bias=qkv_bias)
589
+
590
+ self.proj = nn.Linear(dim, dim_out, bias=False)
591
+ # attention positional bias
592
+ self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
593
+ pretrained_window_size=[resolution, resolution],
594
+ num_heads=num_heads,
595
+ seq_length=seq_length)
596
+
597
+ self.resolution = resolution
598
+
599
+ def forward(self, x):
600
+ B, N, C = x.shape
601
+
602
+ if not self.multi_query:
603
+ if TRT:
604
+ q = self.q(x).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
605
+ k = self.k(x).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
606
+ v = self.v(x).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
607
+ else:
608
+ qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
609
+ q, k, v = qkv[0], qkv[1], qkv[2]
610
+ else:
611
+ qkv = self.qkv(x)
612
+ (q, k, v) = qkv.split([self.dim_internal, self.head_dim, self.head_dim], dim=2)
613
+
614
+ q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
615
+ k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
616
+ v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
617
+
618
+ attn = (q @ k.transpose(-2, -1)) * self.scale
619
+
620
+ attn = self.pos_emb_funct(attn)
621
+
622
+ attn = attn.softmax(dim=-1)
623
+ x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
624
+ x = self.proj(x)
625
+ return x
626
+
627
+
628
+
629
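In the multi-query branch above, queries keep `num_heads` heads while a single key/value head is shared across all of them via broadcasting, which shrinks the KV projection. A standalone shape sketch with made-up sizes:

```python
import torch

B, N, num_heads, head_dim = 2, 64, 8, 32
C = num_heads * head_dim

q = torch.randn(B, num_heads, N, head_dim)   # per-head queries
k = torch.randn(B, 1, N, head_dim)           # one shared key head
v = torch.randn(B, 1, N, head_dim)           # one shared value head

attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5   # broadcasts over the head dimension
out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
print(out.shape)  # torch.Size([2, 64, 256])
```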
+ class FasterViTLayer(nn.Module):
630
+ """
631
+ FasterViT layer: one stage of convolutional or attention blocks, optionally followed by down-sampling.
632
+ """
633
+
634
+ def __init__(self,
635
+ dim,
636
+ depth,
637
+ num_heads,
638
+ window_size,
639
+ conv=False,
640
+ downsample=True,
641
+ mlp_ratio=4.,
642
+ qkv_bias=False,
643
+ qk_scale=None,
644
+ norm_layer=nn.LayerNorm,
645
+ drop_path=0.,
646
+ layer_scale=None,
647
+ layer_scale_conv=None,
648
+ sr_dim_ratio=1,
649
+ sr_ratio=1,
650
+ multi_query=False,
651
+ use_swiglu=True,
652
+ rep_vgg=False,
653
+ yolo_arch=False,
654
+ downsample_shuffle=False,
655
+ conv_base=False,
656
+
657
+ ):
658
+ """
659
+ Args:
660
+ dim: feature size dimension.
661
+ depth: number of layers in each stage.
662
+ conv: if True, use convolutional blocks instead of attention blocks for this stage.
663
+ window_size: window size in each stage.
664
+ downsample: bool argument for down-sampling.
665
+ mlp_ratio: MLP ratio.
666
+ num_heads: number of heads in each stage.
667
+ qkv_bias: bool argument for query, key, value learnable bias.
668
+ qk_scale: bool argument to scaling query, key.
669
+ sr_ratio: subsample ratio(s) used for multi-resolution attention.
670
+ use_swiglu: use a gated (SwiGLU-style) MLP inside the attention blocks.
671
+ drop_path: drop path rate.
672
+ norm_layer: normalization layer.
673
+ layer_scale: layer scaling coefficient.
674
+ """
675
+
676
+ super().__init__()
677
+ self.conv = conv
678
+ self.yolo_arch=False
679
+ if conv:
680
+ if not yolo_arch:
681
+ self.blocks = nn.ModuleList([
682
+ ConvBlock(dim=dim,
683
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
684
+ layer_scale=layer_scale_conv, rep_vgg=rep_vgg)
685
+ for i in range(depth)])
686
+ else:
687
+ self.blocks = C2f(dim,dim,n=depth,shortcut=True,e=0.5)
688
+ self.yolo_arch=True
689
+ else:
690
+ if not isinstance(window_size, list): window_size = [window_size]
691
+ self.window_size = window_size[0]
692
+ self.do_single_windowing = True
693
+ if not isinstance(sr_ratio, list): sr_ratio = [sr_ratio]
694
+ if any([sr!=1 for sr in sr_ratio]) or len(set(window_size))>1:
695
+ self.do_single_windowing = False
696
+ do_windowing = True
697
+ else:
698
+ self.do_single_windowing = True
699
+ do_windowing = False
700
+
701
+ self.blocks = nn.ModuleList()
702
+ for i in range(depth):
703
+
704
+ self.blocks.append(
705
+ MultiResolutionAttention(window_size=window_size,
706
+ sr_ratio=sr_ratio,
707
+ dim=dim,
708
+ dim_ratio = sr_dim_ratio,
709
+ num_heads=num_heads,
710
+ norm_layer=norm_layer,
711
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
712
+ layer_scale=layer_scale,
713
+ qkv_bias=qkv_bias,
714
+ qk_scale=qk_scale,
715
+ use_swiglu=use_swiglu,
716
+ do_windowing=do_windowing,
717
+ multi_query=multi_query,
718
+ conv_base=conv_base,
719
+ ))
720
+
721
+ self.transformer = not conv
722
+
723
+
724
+ self.downsample = None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
725
+
726
+
727
+
728
+
729
+ def forward(self, x):
730
+ B, C, H, W = x.shape
731
+
732
+ if self.transformer and self.do_single_windowing:
733
+ H, W = x.shape[2], x.shape[3]
734
+ x, pad_hw = window_partition(x, self.window_size)
735
+
736
+ if not self.yolo_arch:
737
+ for bn, blk in enumerate(self.blocks):
738
+ x = blk(x)
739
+ else:
740
+ x = self.blocks(x)
741
+
742
+ if self.transformer and self.do_single_windowing:
743
+ x = window_reverse(x, self.window_size, H, W, pad_hw)
744
+
745
+
746
+ if self.downsample is None:
747
+ return x, x
748
+
749
+ return self.downsample(x), x  # also return the pre-downsample features
750
+
751
+
752
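When `do_single_windowing` is set, the whole stage runs on non-overlapping windows via `window_partition` / `window_reverse` (defined elsewhere in this file). A simplified round-trip sketch of what such helpers do, assuming divisible spatial sizes and a `(B * windows, tokens, C)` layout; the real helpers also handle padding and may use a different layout:

```python
import torch

def window_partition_simple(x, ws):
    """(B, C, H, W) -> (B * num_windows, ws * ws, C); assumes H and W are divisible by ws."""
    B, C, H, W = x.shape
    x = x.view(B, C, H // ws, ws, W // ws, ws)
    return x.permute(0, 2, 4, 3, 5, 1).reshape(-1, ws * ws, C)

def window_reverse_simple(windows, ws, H, W):
    """Inverse of window_partition_simple."""
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.view(B, H // ws, W // ws, ws, ws, -1)
    return x.permute(0, 5, 1, 3, 2, 4).reshape(B, -1, H, W)

x = torch.randn(1, 16, 28, 28)
assert torch.equal(window_reverse_simple(window_partition_simple(x, 7), 7, 28, 28), x)
```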
+ class FasterViT(nn.Module):
753
+ """
754
+ FasterViT
755
+ """
756
+
757
+ def __init__(self,
758
+ dim,
759
+ in_dim,
760
+ depths,
761
+ window_size,
762
+ mlp_ratio,
763
+ num_heads,
764
+ drop_path_rate=0.2,
765
+ in_chans=3,
766
+ num_classes=1000,
767
+ qkv_bias=False,
768
+ qk_scale=None,
769
+ layer_scale=None,
770
+ layer_scale_conv=None,
771
+ layer_norm_last=False,
772
+ sr_ratio = [1, 1, 1, 1],
773
+ max_depth = -1,
774
+ conv_base=False,
775
+ use_swiglu=False,
776
+ multi_query=False,
777
+ norm_layer=nn.LayerNorm,
778
+ rep_vgg=False,
779
+ drop_uniform=False,
780
+ yolo_arch=False,
781
+ shuffle_down=False,
782
+ downsample_shuffle=False,
783
+ return_full_features=False,
784
+ full_features_head_dim=128,
785
+ neck_start_stage=1,
786
+ use_neck=False,
787
+ **kwargs):
788
+ """
789
+ Args:
790
+ dim: feature size dimension.
791
+ depths: number of layers in each stage.
792
+ window_size: window size in each stage.
793
+ mlp_ratio: MLP ratio.
794
+ num_heads: number of heads in each stage.
795
+ drop_path_rate: drop path rate.
796
+ in_chans: number of input channels.
797
+ num_classes: number of classes.
798
+ qkv_bias: bool argument for query, key, value learnable bias.
799
+ qk_scale: bool argument to scaling query, key.
800
+ layer_norm_last: use LayerNorm2d instead of BatchNorm2d for the final normalization.
801
+ shuffle_down: use PixelUnshuffle in the patch-embedding stem.
802
+ norm_layer: normalization layer.
803
+ layer_scale: layer scaling coefficient.
804
+ return_full_features: output dense features as well as logits
805
+ full_features_head_dim: number of channels in the dense features head
806
+ neck_start_stage: stage id at which the full-feature neck starts. The model has 4 stages; the index starts at 0.
807
+ for 224 resolution, the output of the stage before downsample:
808
+ stage 0: 56x56, stage 1: 28x28, stage 2: 14x14, stage 3: 7x7
809
+ use_neck: build the neck even if only the summarization embedding is used.
810
+ """
811
+ super().__init__()
812
+
813
+ num_features = int(dim * 2 ** (len(depths) - 1))
814
+ self.num_classes = num_classes
815
+ self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down)
816
+ # set return_full_features to True to also return dense features from all stages
817
+ self.return_full_features = return_full_features
818
+ self.use_neck = use_neck
819
+
820
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
821
+ if drop_uniform:
822
+ dpr = [drop_path_rate for x in range(sum(depths))]
823
+
824
+ if not isinstance(max_depth, list): max_depth = [max_depth] * len(depths)
825
+
826
+ self.levels = nn.ModuleList()
827
+ for i in range(len(depths)):
828
+ conv = True if (i == 0 or i == 1) else False
829
+
830
+ level = FasterViTLayer(dim=int(dim * 2 ** i),
831
+ depth=depths[i],
832
+ num_heads=num_heads[i],
833
+ window_size=window_size[i],
834
+ mlp_ratio=mlp_ratio,
835
+ qkv_bias=qkv_bias,
836
+ qk_scale=qk_scale,
837
+ conv=conv,
838
+ drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
839
+ downsample=(i < 3),
840
+ layer_scale=layer_scale,
841
+ layer_scale_conv=layer_scale_conv,
842
+ sr_ratio=sr_ratio[i],
843
+ use_swiglu=use_swiglu,
844
+ multi_query=multi_query,
845
+ norm_layer=norm_layer,
846
+ rep_vgg=rep_vgg,
847
+ yolo_arch=yolo_arch,
848
+ downsample_shuffle=downsample_shuffle,
849
+ conv_base=conv_base)
850
+
851
+ self.levels.append(level)
852
+
853
+ if self.return_full_features or self.use_neck:
854
+ # create feature projection layers for segmentation output
855
+ self.neck_features_proj = nn.ModuleList()
856
+ self.neck_start_stage = neck_start_stage
857
+ upsample_ratio = 1
858
+ for i in range(len(depths)):
859
+ level_n_features_output = int(dim * 2 ** i)
860
+
861
+ if self.neck_start_stage > i: continue
862
+
863
+ if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:
864
+ feature_projection = nn.Sequential()
865
+ # feature_projection.add_module("norm",LayerNorm2d(level_n_features_output)) #slow, but better
866
+
867
+
868
+ if 0:  # ConvTranspose2d-based upsampling (kept for reference, disabled in favor of the PixelShuffle branch below)
869
+ # Train: 0 [1900/10009 ( 19%)] Loss: 6.113 (6.57) Time: 0.548s, 233.40/s (0.549s, 233.04/s) LR: 1.000e-05 Data: 0.015 (0.013)
870
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
871
+ feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output,
872
+ full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))
873
+ else:
874
+ # pixel shuffle based upsampling
875
+ # Train: 0 [1950/10009 ( 19%)] Loss: 6.190 (6.55) Time: 0.540s, 236.85/s (0.548s, 233.38/s) LR: 1.000e-05 Data: 0.015 (0.013)
876
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
877
+ feature_projection.add_module("conv", nn.Conv2d(level_n_features_output,
878
+ full_features_head_dim*upsample_ratio*upsample_ratio, kernel_size=1, stride=1))
879
+ feature_projection.add_module("upsample_pixelshuffle", nn.PixelShuffle(upsample_ratio))
880
+
881
+ else:
882
+ feature_projection = nn.Sequential()
883
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output))
884
+
885
+
886
+ self.neck_features_proj.append(feature_projection)
887
+
888
+ if i>0 and self.levels[i-1].downsample is not None:
889
+ upsample_ratio *= 2
890
+
891
+
892
+ num_features = full_features_head_dim if (self.return_full_features or self.use_neck) else num_features
893
+
894
+ self.num_features = num_features
895
+
896
+ self.norm = LayerNorm2d(num_features) if layer_norm_last else nn.BatchNorm2d(num_features)
897
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
898
+ self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
899
+ self.apply(self._init_weights)
900
+ # pass
901
+
902
+ def _init_weights(self, m):
903
+ if isinstance(m, nn.Linear):
904
+ trunc_normal_(m.weight, std=.02)
905
+ if isinstance(m, nn.Linear) and m.bias is not None:
906
+ nn.init.constant_(m.bias, 0)
907
+ elif isinstance(m, nn.LayerNorm):
908
+ nn.init.constant_(m.bias, 0)
909
+ nn.init.constant_(m.weight, 1.0)
910
+ elif isinstance(m, LayerNorm2d):
911
+ nn.init.constant_(m.bias, 0)
912
+ nn.init.constant_(m.weight, 1.0)
913
+ elif isinstance(m, nn.BatchNorm2d):
914
+ nn.init.ones_(m.weight)
915
+ nn.init.zeros_(m.bias)
916
+
917
+ @torch.jit.ignore
918
+ def no_weight_decay_keywords(self):
919
+ return {'rpb'}
920
+
921
+ def forward_features(self, x):
922
+ x = self.patch_embed(x)
923
+ full_features = None
924
+ for il, level in enumerate(self.levels):
925
+ x, pre_downsample_x = level(x)
926
+
927
+ if self.return_full_features or self.use_neck:
928
+ if self.neck_start_stage > il: continue
929
+ if full_features is None:
930
+ full_features = self.neck_features_proj[il - self.neck_start_stage](pre_downsample_x)
931
+ else:
932
+ # project the current stage's features, pad/crop to match full_features, and accumulate
933
+ feature_projection = self.neck_features_proj[il - self.neck_start_stage](pre_downsample_x)
934
+ if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:
935
+ feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))
936
+ full_features += feature_projection
937
+
938
+ # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)
939
+ x = self.norm(x)  # normalize only the pooled summary path; full features go through the neck projections
940
+ x = self.avgpool(x)
941
+ x = torch.flatten(x, 1)
942
+
943
+ if not self.return_full_features:
944
+ return x, None
945
+
946
+ return x, full_features
947
+
948
+ def forward(self, x):
949
+ x, full_features = self.forward_features(x)
950
+ x = self.head(x)
951
+ if full_features is not None:
952
+ return x, full_features
953
+ return x
954
+
955
+ def switch_to_deploy(self):
956
+ '''
957
+ Perform model self-compression for deployment:
958
+ merges BatchNorm layers into the preceding convolutions and
959
+ converts the MLP relative positional bias into precomputed buffers.
960
+ '''
961
+ for level in [self.patch_embed, self.levels, self.head]:
962
+ for module in level.modules():
963
+ if hasattr(module, 'switch_to_deploy'):
964
+ module.switch_to_deploy()
965
+
966
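For orientation, a hedged usage sketch of the class: the configuration mirrors the `fastervit2_xxxtiny` variant registered below, and it assumes this file is importable as `model`. With the defaults (`return_full_features=False`), `forward` returns classification logits only; with `return_full_features=True` it returns `(logits, full_features)`.

```python
import torch
from model import FasterViT  # assumption: this file is importable as `model`

net = FasterViT(dim=32, in_dim=32, depths=[1, 3, 4, 5], window_size=[8, 8, [7, 7], 7],
                mlp_ratio=4, num_heads=[2, 4, 8, 16], drop_path_rate=0.0,
                sr_ratio=[1, 1, [2, 1], 1], use_swiglu=False, yolo_arch=True,
                num_classes=1000).eval()

with torch.no_grad():
    logits = net(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```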
+ @register_model
967
+ def fastervit2_small(pretrained=False, **kwargs):
968
+ model = FasterViT(depths=[3, 3, 5, 5],
969
+ num_heads=[2, 4, 8, 16],
970
+ window_size=[8, 8, [7, 7], 7],
971
+ dim=96,
972
+ in_dim=64,
973
+ mlp_ratio=4,
974
+ drop_path_rate=0.2,
975
+ sr_ratio=[1, 1, [1, 2], 1],
976
+ use_swiglu=False,
977
+ downsample_shuffle=False,
978
+ yolo_arch=True,
979
+ shuffle_down=False,
980
+ **kwargs)
981
+ if pretrained:
982
+ model.load_state_dict(torch.load(pretrained))
983
+ return model
984
+
985
+ @register_model
986
+ def fastervit2_tiny(pretrained=False, **kwargs):
987
+ model = FasterViT(depths=[1, 3, 4, 5],
988
+ num_heads=[2, 4, 8, 16],
989
+ window_size=[8, 8, [7, 7], 7],
990
+ dim=80,
991
+ in_dim=64,
992
+ mlp_ratio=4,
993
+ drop_path_rate=0.2,
994
+ sr_ratio=[1, 1, [2, 1], 1],
995
+ use_swiglu=False,
996
+ downsample_shuffle=False,
997
+ yolo_arch=True,
998
+ shuffle_down=False,
999
+ **kwargs)
1000
+ if pretrained:
1001
+ model.load_state_dict(torch.load(pretrained))
1002
+ return model
1003
+
1004
+ @register_model
1005
+ def fastervit2_base(pretrained=False, **kwargs):
1006
+ model = FasterViT(depths=[3, 3, 5, 5],
1007
+ num_heads=[2, 4, 8, 16],
1008
+ window_size=[8, 8, [7, 7], 7],
1009
+ dim=128,
1010
+ in_dim=64,
1011
+ mlp_ratio=4,
1012
+ drop_path_rate=0.2,
1013
+ sr_ratio=[1, 1, [2, 1], 1],
1014
+ use_swiglu=False,
1015
+ yolo_arch=True,
1016
+ shuffle_down=False,
1017
+ conv_base=True,
1018
+ **kwargs)
1019
+ if pretrained:
1020
+ model.load_state_dict(torch.load(pretrained))
1021
+ return model
1022
+
1023
+ @register_model
1024
+ def fastervit2_base_fullres1(pretrained=False, **kwargs):
1025
+ model = FasterViT(depths=[3, 3, 5, 5],
1026
+ num_heads=[2, 4, 8, 16],
1027
+ window_size=[8, 8, [7, 7], 7],
1028
+ dim=128,
1029
+ in_dim=64,
1030
+ mlp_ratio=4,
1031
+ drop_path_rate=0.2,
1032
+ sr_ratio=[1, 1, [2, 1], 1],
1033
+ use_swiglu=False,
1034
+ yolo_arch=True,
1035
+ shuffle_down=False,
1036
+ conv_base=True,
1037
+ use_neck=True,
1038
+ full_features_head_dim=1024,
1039
+ neck_start_stage=2,
1040
+ **kwargs)
1041
+ if pretrained:
1042
+ model.load_state_dict(torch.load(pretrained))
1043
+ return model
1044
+
1045
+ @register_model
1046
+ def fastervit2_base_fullres2(pretrained=False, **kwargs):
1047
+ model = FasterViT(depths=[3, 3, 5, 5],
1048
+ num_heads=[2, 4, 8, 16],
1049
+ window_size=[8, 8, [7, 7], 7],
1050
+ dim=128,
1051
+ in_dim=64,
1052
+ mlp_ratio=4,
1053
+ drop_path_rate=0.2,
1054
+ sr_ratio=[1, 1, [2, 1], 1],
1055
+ use_swiglu=False,
1056
+ yolo_arch=True,
1057
+ shuffle_down=False,
1058
+ conv_base=True,
1059
+ use_neck=True,
1060
+ full_features_head_dim=512,
1061
+ neck_start_stage=1,
1062
+ **kwargs)
1063
+ if pretrained:
1064
+ model.load_state_dict(torch.load(pretrained))
1065
+ return model
1066
+
1067
+ @register_model
1068
+ def fastervit2_base_fullres3(pretrained=False, **kwargs):
1069
+ model = FasterViT(depths=[3, 3, 5, 5],
1070
+ num_heads=[2, 4, 8, 16],
1071
+ window_size=[8, 8, [7, 7], 7],
1072
+ dim=128,
1073
+ in_dim=64,
1074
+ mlp_ratio=4,
1075
+ drop_path_rate=0.2,
1076
+ sr_ratio=[1, 1, [2, 1], 1],
1077
+ use_swiglu=False,
1078
+ yolo_arch=True,
1079
+ shuffle_down=False,
1080
+ conv_base=True,
1081
+ use_neck=True,
1082
+ full_features_head_dim=256,
1083
+ neck_start_stage=1,
1084
+ **kwargs)
1085
+ if pretrained:
1086
+ model.load_state_dict(torch.load(pretrained))
1087
+ return model
1088
+
1089
+ @register_model
1090
+ def fastervit2_base_fullres4(pretrained=False, **kwargs):
1091
+ model = FasterViT(depths=[3, 3, 5, 5],
1092
+ num_heads=[2, 4, 8, 16],
1093
+ window_size=[8, 8, [7, 7], 7],
1094
+ dim=128,
1095
+ in_dim=64,
1096
+ mlp_ratio=4,
1097
+ drop_path_rate=0.2,
1098
+ sr_ratio=[1, 1, [2, 1], 1],
1099
+ use_swiglu=False,
1100
+ yolo_arch=True,
1101
+ shuffle_down=False,
1102
+ conv_base=True,
1103
+ use_neck=True,
1104
+ full_features_head_dim=256,
1105
+ neck_start_stage=2,
1106
+ **kwargs)
1107
+ if pretrained:
1108
+ model.load_state_dict(torch.load(pretrained))
1109
+ return model
1110
+
1111
+ @register_model
1112
+ def fastervit2_base_fullres5(pretrained=False, **kwargs):
1113
+ model = FasterViT(depths=[3, 3, 5, 5],
1114
+ num_heads=[2, 4, 8, 16],
1115
+ window_size=[8, 8, [7, 7], 7],
1116
+ dim=128,
1117
+ in_dim=64,
1118
+ mlp_ratio=4,
1119
+ drop_path_rate=0.2,
1120
+ sr_ratio=[1, 1, [2, 1], 1],
1121
+ use_swiglu=False,
1122
+ yolo_arch=True,
1123
+ shuffle_down=False,
1124
+ conv_base=True,
1125
+ use_neck=True,
1126
+ full_features_head_dim=512,
1127
+ neck_start_stage=2,
1128
+ **kwargs)
1129
+ if pretrained:
1130
+ model.load_state_dict(torch.load(pretrained))
1131
+ return model
1132
+
1133
+ # PyTorch: 1934, TRT: 4202
1134
+ @register_model
1135
+ def fastervit2_large(pretrained=False, **kwargs):
1136
+ model = FasterViT(depths=[3, 3, 5, 5],
1137
+ num_heads=[2, 4, 8, 16],
1138
+ window_size=[8, 8, [7, 7], 7],
1139
+ dim=128+64,
1140
+ in_dim=64,
1141
+ mlp_ratio=4,
1142
+ drop_path_rate=0.2,
1143
+ sr_ratio=[1, 1, [2, 1], 1],
1144
+ use_swiglu=False,
1145
+ yolo_arch=True,
1146
+ shuffle_down=False,
1147
+ **kwargs)
1148
+ if pretrained:
1149
+ model.load_state_dict(torch.load(pretrained))
1150
+ return model
1151
+
1152
+ @register_model
1153
+ def fastervit2_large_fullres(pretrained=False, **kwargs):
1154
+ model = FasterViT(depths=[3, 3, 5, 5],
1155
+ num_heads=[2, 4, 8, 16],
1156
+ window_size=[None, None, [7, 7], 7],
1157
+ dim=192,
1158
+ in_dim=64,
1159
+ mlp_ratio=4,
1160
+ drop_path_rate=0.,
1161
+ sr_ratio=[1, 1, [2, 1], 1],
1162
+ use_swiglu=False,
1163
+ yolo_arch=True,
1164
+ shuffle_down=False,
1165
+ conv_base=True,
1166
+ use_neck=True,
1167
+ full_features_head_dim=1536,
1168
+ neck_start_stage=2,
1169
+ **kwargs)
1170
+ if pretrained:
1171
+ model.load_state_dict(torch.load(pretrained))
1172
+ return model
1173
+
1174
+ @register_model
1175
+ def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
1176
+ model = FasterViT(depths=[3, 3, 5, 5],
1177
+ num_heads=[2, 4, 8, 16],
1178
+ window_size=[None, None, [8, 8], 8],
1179
+ dim=192,
1180
+ in_dim=64,
1181
+ mlp_ratio=4,
1182
+ drop_path_rate=0.,
1183
+ sr_ratio=[1, 1, [2, 1], 1],
1184
+ use_swiglu=False,
1185
+ yolo_arch=True,
1186
+ shuffle_down=False,
1187
+ conv_base=True,
1188
+ use_neck=True,
1189
+ full_features_head_dim=1536,
1190
+ neck_start_stage=2,
1191
+ **kwargs)
1192
+ if pretrained:
1193
+ model.load_state_dict(torch.load(pretrained))
1194
+ return model
1195
+
1196
+ @register_model
1197
+ def fastervit2_large_fullres_ws16(pretrained=False, **kwargs):
1198
+ model = FasterViT(depths=[3, 3, 5, 5],
1199
+ num_heads=[2, 4, 8, 16],
1200
+ window_size=[None, None, [16, 16], 16],
1201
+ dim=192,
1202
+ in_dim=64,
1203
+ mlp_ratio=4,
1204
+ drop_path_rate=0.,
1205
+ sr_ratio=[1, 1, [2, 1], 1],
1206
+ use_swiglu=False,
1207
+ yolo_arch=True,
1208
+ shuffle_down=False,
1209
+ conv_base=True,
1210
+ use_neck=True,
1211
+ full_features_head_dim=1536,
1212
+ neck_start_stage=2,
1213
+ **kwargs)
1214
+ if pretrained:
1215
+ model.load_state_dict(torch.load(pretrained))
1216
+ return model
1217
+
1218
+ @register_model
1219
+ def fastervit2_large_fullres_ws32(pretrained=False, **kwargs):
1220
+ model = FasterViT(depths=[3, 3, 5, 5],
1221
+ num_heads=[2, 4, 8, 16],
1222
+ window_size=[None, None, [32, 32], 32],
1223
+ dim=192,
1224
+ in_dim=64,
1225
+ mlp_ratio=4,
1226
+ drop_path_rate=0.,
1227
+ sr_ratio=[1, 1, [2, 1], 1],
1228
+ use_swiglu=False,
1229
+ yolo_arch=True,
1230
+ shuffle_down=False,
1231
+ conv_base=True,
1232
+ use_neck=True,
1233
+ full_features_head_dim=1536,
1234
+ neck_start_stage=2,
1235
+ **kwargs)
1236
+ if pretrained:
1237
+ model.load_state_dict(torch.load(pretrained))
1238
+ return model
1239
+
1240
+ # PyTorch: 897
1241
+ @register_model
1242
+ def fastervit2_xlarge(pretrained=False, **kwargs):
1243
+ model = FasterViT(depths=[3, 3, 5, 5],
1244
+ num_heads=[2, 4, 8, 16],
1245
+ window_size=[8, 8, [7, 7], 7],
1246
+ dim=128+128+64,
1247
+ in_dim=64,
1248
+ mlp_ratio=4,
1249
+ drop_path_rate=0.2,
1250
+ sr_ratio=[1, 1, [2, 1], 1],
1251
+ use_swiglu=False,
1252
+ yolo_arch=True,
1253
+ shuffle_down=False,
1254
+ **kwargs)
1255
+ if pretrained:
1256
+ model.load_state_dict(torch.load(pretrained))
1257
+ return model
1258
+
1259
+
1260
+ #pyt:
1261
+ @register_model
1262
+ def fastervit2_huge(pretrained=False, **kwargs):
1263
+ model = FasterViT(depths=[3, 3, 5, 5],
1264
+ num_heads=[2, 4, 8, 16],
1265
+ window_size=[8, 8, [7, 7], 7],
1266
+ dim=128+128+128+64,
1267
+ in_dim=64,
1268
+ mlp_ratio=4,
1269
+ drop_path_rate=0.2,
1270
+ sr_ratio=[1, 1, [2, 1], 1],
1271
+ use_swiglu=False,
1272
+ yolo_arch=True,
1273
+ shuffle_down=False,
1274
+ **kwargs)
1275
+ if pretrained:
1276
+ model.load_state_dict(torch.load(pretrained))
1277
+ return model
1278
+
1279
+
1280
+ @register_model
1281
+ def fastervit2_xtiny(pretrained=False, **kwargs):
1282
+ model = FasterViT(depths=[1, 3, 4, 5],
1283
+ num_heads=[2, 4, 8, 16],
1284
+ window_size=[8, 8, [7, 7], 7],
1285
+ dim=64,
1286
+ in_dim=64,
1287
+ mlp_ratio=4,
1288
+ drop_path_rate=0.1,
1289
+ sr_ratio=[1, 1, [2, 1], 1],
1290
+ use_swiglu=False,
1291
+ downsample_shuffle=False,
1292
+ yolo_arch=True,
1293
+ shuffle_down=False,
1294
+ **kwargs)
1295
+ if pretrained:
1296
+ model.load_state_dict(torch.load(pretrained))
1297
+ return model
1298
+
1299
+
1300
+ @register_model
1301
+ def fastervit2_xxtiny_5(pretrained=False, **kwargs):
1302
+ model = FasterViT(depths=[1, 3, 4, 5],
1303
+ num_heads=[2, 4, 8, 16],
1304
+ window_size=[8, 8, [7, 7], 7],
1305
+ dim=48,
1306
+ in_dim=64,
1307
+ mlp_ratio=4,
1308
+ drop_path_rate=0.05,
1309
+ sr_ratio=[1, 1, [2, 1], 1],
1310
+ use_swiglu=False,
1311
+ downsample_shuffle=False,
1312
+ yolo_arch=True,
1313
+ shuffle_down=False,
1314
+ **kwargs)
1315
+ if pretrained:
1316
+ model.load_state_dict(torch.load(pretrained))
1317
+ return model
1318
+
1319
+ @register_model
1320
+ def fastervit2_xxxtiny(pretrained=False, **kwargs):
1321
+ model = FasterViT(depths=[1, 3, 4, 5],
1322
+ num_heads=[2, 4, 8, 16],
1323
+ window_size=[8, 8, [7, 7], 7],
1324
+ dim=32,
1325
+ in_dim=32,
1326
+ mlp_ratio=4,
1327
+ drop_path_rate=0.0,
1328
+ sr_ratio=[1, 1, [2, 1], 1],
1329
+ use_swiglu=False,
1330
+ downsample_shuffle=False,
1331
+ yolo_arch=True,
1332
+ shuffle_down=False,
1333
+ **kwargs)
1334
+ if pretrained:
1335
+ model.load_state_dict(torch.load(pretrained))
1336
+ return model
1337
+
1338
+
1339
+ @register_model
1340
+ def eradio(pretrained=False, **kwargs):
1341
+ return fastervit2_large_fullres(pretrained=pretrained, **kwargs)
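Note that in these factory functions `pretrained` is passed straight to `torch.load`, so in practice it acts as a checkpoint path rather than a boolean flag. A hedged sketch of calling the `eradio` alias directly, again assuming the file is importable as `model`:

```python
import torch
from model import eradio  # assumption: this file is importable as `model`

net = eradio(pretrained=False).eval()        # pass a checkpoint path to load weights instead
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))   # classification logits with the default head
print(out.shape)  # torch.Size([1, 1000])
```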
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:115b8f54d0d4999c180718ce138f8078127af9815b6cb507b253e5db10a5723c
3
+ size 1057766065