gheinrich committed
Commit
00797c3
1 Parent(s): 49f3b67

Upload model

cls_token.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import torch
+ from torch import nn
+
+
+ class ClsToken(nn.Module):
+     def __init__(self, ndim: int,
+                  num_tokens: int = 1,
+                  enabled: bool = True,
+                  register_multiple: int = 0,
+     ):
+         super().__init__()
+
+         self.ndim = ndim
+         self.enabled = enabled
+         self.num_registers = 0
+         self.num_tokens = num_tokens
+         if enabled:
+             if register_multiple > 0:
+                 self.num_registers = register_multiple - (num_tokens % register_multiple)
+
+             scale = ndim ** -0.5
+             self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)
+         else:
+             self.token = None
+
+         self.num_patches = self.num_tokens + self.num_registers
+
+     def disable(self):
+         self.token = None
+         self.enabled = False
+
+     def forward(self, x: torch.Tensor):
+         if self.token is None:
+             return x
+
+         token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
+         x = torch.cat([
+             token,
+             x,
+         ], dim=1)
+
+         return x
+
+     def no_weight_decay(self):
+         return [
+             'token',
+         ]
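
For orientation, a minimal usage sketch of the module above (the batch size, token count, and embedding width are illustrative assumptions, not values taken from this checkpoint):

import torch
from cls_token import ClsToken  # assuming the file is on the import path

# 2 images, 196 patch tokens, 768-dim embeddings (illustrative sizes)
patches = torch.randn(2, 196, 768)

# With num_tokens=1 and register_multiple=16, 15 register tokens are added,
# so 16 learned tokens are prepended in total.
cls = ClsToken(ndim=768, num_tokens=1, register_multiple=16)
out = cls(patches)
print(out.shape)  # torch.Size([2, 212, 768]) -> 196 patches + 1 cls + 15 registers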
enable_cpe_support.py ADDED
@@ -0,0 +1,67 @@
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ from typing import Union, Tuple
+ from types import MethodType
+
+ import torch
+ from torch import nn
+
+ from timm.models import VisionTransformer, checkpoint_seq
+
+ from .vit_patch_generator import ViTPatchGenerator
+
+
+ def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
+     x = self.patch_generator(x)
+     if self.grad_checkpointing and not torch.jit.is_scripting():
+         x = checkpoint_seq(self.blocks, x)
+     else:
+         x = self.blocks(x)
+     x = self.norm(x)
+     return x
+
+
+ def enable_cpe(model: nn.Module,
+                max_img_size: Union[int, Tuple[int, int]] = 1024,
+                num_cls_tokens: int = 1,
+                pos_dropout: float = 0.1,
+                register_multiple: int = 0,
+ ):
+     if not isinstance(model, VisionTransformer):
+         raise ValueError("CPE is only supported for VisionTransformer models!")
+
+     patch_size = model.patch_embed.patch_size[0]
+     embed_dim = model.embed_dim
+     input_dims = model.patch_embed.img_size
+     normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)
+     cls_token = model.cls_token is not None
+
+     max_img_size = int(round(max_img_size / patch_size) * patch_size)
+
+     patch_generator = ViTPatchGenerator(
+         patch_size=patch_size,
+         embed_dim=embed_dim,
+         input_dims=input_dims,
+         normalize_patches=normalize_patches,
+         cls_token=cls_token,
+         max_input_dims=max_img_size,
+         pos_dropout=pos_dropout,
+         num_cls_tokens=num_cls_tokens,
+         register_multiple=register_multiple,
+     )
+
+     model.patch_generator = patch_generator
+     model.patch_embed = None
+     model.cls_token = None
+     model.pos_embed = None
+     model.pos_drop = None
+     model.num_cls_tokens = num_cls_tokens
+     model.num_registers = patch_generator.num_registers
+
+     model.forward_features = MethodType(_forward_cpe, model)
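
A rough usage sketch of enable_cpe. The timm model name is only an example, radio_pkg is a hypothetical package name for these files, and ViTPatchGenerator must be importable from the sibling vit_patch_generator.py, which is not part of this commit:

import torch
import timm
from radio_pkg.enable_cpe_support import enable_cpe  # hypothetical package name

vit = timm.create_model("vit_base_patch16_224", pretrained=False)

# Swap the fixed patch/position embedding for the CPE patch generator so the
# model can ingest resolutions other than 224, up to roughly max_img_size on a side.
enable_cpe(vit, max_img_size=1024, num_cls_tokens=1, register_multiple=16)

with torch.no_grad():
    tokens = vit.forward_features(torch.randn(1, 3, 512, 512))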
eradio_model.py ADDED
@@ -0,0 +1,1340 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ # Created by Pavlo Molchanov, LPR - DL Efficiency Research team
12
+ # based on Fastervit1 from LPR
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from timm.models.registry import register_model
17
+
18
+ from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
19
+ import numpy as np
20
+ import torch.nn.functional as F
21
+ from .block import C2f
22
+ TRT = False  # if True, use separate q/k/v projections, which should help for TRT
23
+
24
+ import pickle
25
+ global bias_indx
26
+ bias_indx = -1
27
+ DEBUG = False
28
+
29
+
30
+
31
+ def pixel_unshuffle(data, factor=2):
32
+ # performs nn.PixelShuffle(factor) in reverse; torch's native op has issues with ONNX and TRT export, so we do it manually
33
+ B, C, H, W = data.shape
34
+ return data.view(B, C, factor, H//factor, factor, W//factor).permute(0,1,2,4,3,5).reshape(B, -1, H//factor, W//factor)
35
+
36
+ class SwiGLU(nn.Module):
37
+ # a more elaborate variant should be possible, but it doesn't improve results so far
38
+ def forward(self, x):
39
+ x, gate = x.chunk(2, dim=-1)
40
+ return F.silu(gate) * x
41
+
42
+
43
+ def window_partition(x, window_size):
44
+ """
45
+ Args:
46
+ x: (B, C, H, W)
47
+ window_size: window size
48
+ Returns:
49
+ windows - local window features (num_windows*B, window_size*window_size, C)
50
+ (Hp, Wp) - the size of the padded image
51
+ """
52
+ B, C, H, W = x.shape
53
+
54
+ if window_size == 0 or (window_size==H and window_size==W):
55
+ windows = x.flatten(2).transpose(1, 2)
56
+ Hp, Wp = H, W
57
+ else:
58
+ pad_h = (window_size - H % window_size) % window_size
59
+ pad_w = (window_size - W % window_size) % window_size
60
+ if pad_h > 0 or pad_w > 0:
61
+ x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0, 0, 0))
62
+ Hp, Wp = H + pad_h, W + pad_w
63
+
64
+ x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
65
+ windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
66
+
67
+ return windows, (Hp, Wp)
68
+
69
+ class Conv2d_BN(nn.Module):
70
+ '''
71
+ Conv2d + BN layer with folding capability to speed up inference
72
+ '''
73
+ def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False):
74
+ super().__init__()
75
+ self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False)
76
+ if 1:
77
+ self.bn = torch.nn.BatchNorm2d(b)
78
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
79
+ torch.nn.init.constant_(self.bn.bias, 0)
80
+
81
+ def forward(self,x):
82
+ x = self.conv(x)
83
+ x = self.bn(x)
84
+ return x
85
+
86
+ @torch.no_grad()
87
+ def switch_to_deploy(self):
88
+
89
+ # return 1
90
+ if not isinstance(self.bn, nn.Identity):
91
+ c, bn = self.conv, self.bn
92
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
93
+ w = c.weight * w[:, None, None, None]
94
+ b = bn.bias - bn.running_mean * bn.weight / \
95
+ (bn.running_var + bn.eps)**0.5
96
+ self.conv.weight.data.copy_(w)
97
+ self.conv.bias = nn.Parameter(b)
98
+ self.bn = nn.Identity()
99
+
100
+
101
+
102
+ def window_reverse(windows, window_size, H, W, pad_hw):
103
+ """
104
+ Args:
105
+ windows: local window features (num_windows*B, window_size, window_size, C)
106
+ window_size: Window size
107
+ H: Height of image
108
+ W: Width of image
109
+ pad_hw: a tuple (Hp, Wp) with the padded image size from the windowing step
110
+ Returns:
111
+ x: (B, C, H, W)
112
+
113
+ """
114
+ # print(f"window_reverse, windows.shape {windows.shape}")
115
+ Hp, Wp = pad_hw
116
+ if window_size == 0 or (window_size==H and window_size==W):
117
+ B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
118
+ x = windows.transpose(1, 2).view(B, -1, H, W)
119
+ else:
120
+ B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
121
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
122
+ x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], Hp, Wp)
123
+
124
+ if Hp > H or Wp > W:
125
+ x = x[:, :, :H, :W, ].contiguous()
126
+
127
+ return x
128
+
129
+
130
+
131
+ class PosEmbMLPSwinv2D(nn.Module):
132
+ def __init__(self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False):
133
+ super().__init__()
134
+ self.window_size = window_size
135
+ self.num_heads = num_heads
136
+ # mlp to generate continuous relative position bias
137
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
138
+ nn.ReLU(inplace=True),
139
+ nn.Linear(512, num_heads, bias=False))
140
+
141
+ # get relative_coords_table
142
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
143
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
144
+ relative_coords_table = torch.stack(
145
+ torch.meshgrid([relative_coords_h,
146
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
147
+ if pretrained_window_size[0] > 0:
148
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
149
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
150
+ else:
151
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
152
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
153
+
154
+ if not no_log:
155
+ relative_coords_table *= 8 # normalize to -8, 8
156
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
157
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
158
+
159
+ self.register_buffer("relative_coords_table", relative_coords_table)
160
+
161
+ # get pair-wise relative position index for each token inside the window
162
+ coords_h = torch.arange(self.window_size[0])
163
+ coords_w = torch.arange(self.window_size[1])
164
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
165
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
166
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
167
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
168
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
169
+ relative_coords[:, :, 1] += self.window_size[1] - 1
170
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
171
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
172
+ self.register_buffer("relative_position_index", relative_position_index)
173
+
174
+ self.grid_exists = False
175
+
176
+ self.deploy = False
177
+
178
+ relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
179
+ self.seq_length = seq_length
180
+ self.register_buffer("relative_bias", relative_bias) #for EMA
181
+
182
+ def switch_to_deploy(self):
183
+ self.deploy = True
184
+ self.grid_exists = True
185
+
186
+ def forward(self, input_tensor):
187
+ # for efficiency, we want this forward to be folded into a single operation (sum)
188
+ # if the resolution stays the same, then we don't need to recompute the MLP layers
189
+ #
190
+ # to dynamically adjust patch size over the step
191
+ # if not (input_tensor.shape[1:] == self.relative_bias.shape[1:]):
192
+ # self.grid_exists = False
193
+
194
+ if self.training: self.grid_exists = False
195
+
196
+ if self.deploy and self.grid_exists:
197
+ input_tensor += self.relative_bias
198
+ return input_tensor
199
+
200
+ if not self.grid_exists:
201
+ self.grid_exists = True
202
+
203
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
204
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
205
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
206
+ -1) # Wh*Ww,Wh*Ww,nH
207
+
208
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
209
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
210
+
211
+ self.relative_bias = relative_position_bias.unsqueeze(0)
212
+
213
+ input_tensor += self.relative_bias
214
+ return input_tensor
215
+
216
+
217
+
218
+ class GRAAttentionBlock(nn.Module):
219
+ def __init__(self, window_size, dim_in, dim_out,
220
+ num_heads, drop_path=0., qk_scale=None, qkv_bias=False,
221
+ norm_layer=nn.LayerNorm, layer_scale=None,
222
+ use_swiglu=True,
223
+ subsample_ratio=1, dim_ratio=1, conv_base=False,
224
+ do_windowing=True, multi_query=False) -> None:
225
+ super().__init__()
226
+
227
+ dim = dim_in
228
+ # conv_base = True
229
+ SHUFFLE = True
230
+ SHUFFLE = False
231
+ self.do_windowing = do_windowing
232
+
233
+ if do_windowing:
234
+ if SHUFFLE:
235
+ self.downsample_op = torch.nn.PixelUnshuffle(subsample_ratio) if subsample_ratio>1 else torch.nn.Identity()
236
+ self.downsample_mixer = nn.Conv2d(dim_in * (subsample_ratio * subsample_ratio), dim_in * (dim_ratio), kernel_size=1, stride=1, padding=0, bias=False) if dim*dim_ratio != dim * subsample_ratio * subsample_ratio else torch.nn.Identity()
237
+ else:
238
+ if conv_base:
239
+ self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
240
+ self.downsample_mixer = nn.Identity()
241
+ else:
242
+ self.downsample_op = nn.AvgPool2d(kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
243
+ self.downsample_mixer = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1) if subsample_ratio > 1 else nn.Identity()
244
+
245
+
246
+ if do_windowing:
247
+ if SHUFFLE:
248
+ self.upsample_mixer =nn.Conv2d(dim_in * dim_ratio, dim_in * (subsample_ratio * subsample_ratio), kernel_size=1, stride=1, padding=0, bias=False) if dim*dim_ratio != dim * subsample_ratio * subsample_ratio else torch.nn.Identity()
249
+ self.upsample_op = torch.nn.PixelShuffle(subsample_ratio) if subsample_ratio>1 else torch.nn.Identity()
250
+ else:
251
+ if conv_base:
252
+ self.upsample_mixer = nn.Identity()
253
+ self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
254
+ else:
255
+ self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()
256
+ self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()
257
+
258
+ self.window_size = window_size
259
+
260
+ self.norm1 = norm_layer(dim_in)
261
+ if DEBUG:
262
+ print(f"GRAAttentionBlock: input_resolution: , window_size: {window_size}, dim_in: {dim_in}, dim_out: {dim_out}, num_heads: {num_heads}, drop_path: {drop_path}, qk_scale: {qk_scale}, qkv_bias: {qkv_bias}, layer_scale: {layer_scale}")
263
+
264
+
265
+ self.attn = WindowAttention(
266
+ dim_in,
267
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
268
+ resolution=window_size,
269
+ seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query)
270
+ if DEBUG:
271
+ print(f"Attention: dim_in: {dim_in}, num_heads: {num_heads}, qkv_bias: {qkv_bias}, qk_scale: {qk_scale}, resolution: {window_size}, seq_length: {window_size**2}, dim_out: {dim_in}")
272
+ print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
273
+
274
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
275
+
276
+ use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
277
+ self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1
278
+
279
+ ### mlp layer
280
+ mlp_ratio = 4
281
+ self.norm2 = norm_layer(dim_in)
282
+ mlp_hidden_dim = int(dim_in * mlp_ratio)
283
+
284
+ activation = nn.GELU if not use_swiglu else SwiGLU
285
+ mlp_hidden_dim = int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim
286
+
287
+ self.mlp = Mlp(in_features=dim_in, hidden_features=mlp_hidden_dim, act_layer=activation, use_swiglu=use_swiglu)
288
+
289
+ self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1
290
+ self.drop_path2=DropPath(drop_path) if drop_path > 0. else nn.Identity()
291
+ if DEBUG:
292
+ print(f"MLP layer: dim_in: {dim_in}, dim_out: {dim_in}, mlp_hidden_dim: {mlp_hidden_dim}")
293
+ print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
294
+
295
+
296
+ def forward(self, x):
297
+ skip_connection = x
298
+
299
+ if self.do_windowing:
300
+ # performing windowing if required
301
+ x = self.downsample_op(x)
302
+ x = self.downsample_mixer(x)
303
+
304
+ if self.window_size>0:
305
+ H, W = x.shape[2], x.shape[3]
306
+
307
+ x, pad_hw = window_partition(x, self.window_size)
308
+
309
+ # window attention
310
+ x = x + self.drop_path1(self.gamma1*self.attn(self.norm1(x)))
311
+ # mlp layer
312
+ x = x + self.drop_path2(self.gamma2*self.mlp(self.norm2(x)))
313
+
314
+ if self.do_windowing:
315
+ if self.window_size > 0:
316
+ x = window_reverse(x, self.window_size, H, W, pad_hw)
317
+
318
+ x = self.upsample_mixer(x)
319
+ x = self.upsample_op(x)
320
+
321
+
322
+ if x.shape[2] != skip_connection.shape[2] or x.shape[3] != skip_connection.shape[3]:
323
+ x = torch.nn.functional.pad(x, ( 0, -x.shape[3] + skip_connection.shape[3], 0, -x.shape[2] + skip_connection.shape[2]))
324
+ # need to add skip connection because downsampling and upsampling will break residual connection
325
+ # 0.5 is needed to make sure that the skip connection is not too strong
326
+ # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection
327
+ x = 0.5 * x + 0.5 * skip_connection
328
+
329
+ return x
330
+
331
+
332
+
333
+
334
+ class MultiResolutionAttention(nn.Module):
335
+ """
336
+ MultiResolutionAttention (MRA) module
337
+ The idea is to use multiple attention blocks with different resolution
338
+ Feature maps are downsampled / upsampled before and after each attention block
339
+ Every attention block supports
340
+
341
+ """
342
+
343
+ def __init__(self, window_size, sr_ratio,
344
+ dim, dim_ratio, num_heads,
345
+ do_windowing=True,
346
+ layer_scale=1e-5, norm_layer=nn.LayerNorm,
347
+ drop_path = 0, qkv_bias=False, qk_scale=1.0,
348
+ use_swiglu=True, multi_query=False, conv_base=False) -> None:
349
+ """
350
+ Args:
351
+ input_resolution: input image resolution
352
+ window_size: window size
353
+ compression_ratio: compression ratio
354
+ max_depth: maximum depth of the GRA module
355
+ """
356
+ super().__init__()
357
+
358
+ depth = len(sr_ratio)
359
+
360
+
361
+ self.attention_blocks = nn.ModuleList()
362
+
363
+
364
+ for i in range(depth):
365
+ subsample_ratio = sr_ratio[i]
366
+ if len(window_size) > i:
367
+ window_size_local = window_size[i]
368
+ else:
369
+ window_size_local = window_size[0]
370
+
371
+ self.attention_blocks.append(GRAAttentionBlock(window_size=window_size_local,
372
+ dim_in=dim, dim_out=dim, num_heads=num_heads,
373
+ qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer,
374
+ layer_scale=layer_scale, drop_path=drop_path,
375
+ use_swiglu=use_swiglu, subsample_ratio=subsample_ratio, dim_ratio=dim_ratio,
376
+ do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base),
377
+ )
378
+
379
+
380
+
381
+ def forward(self, x):
382
+
383
+ for attention_block in self.attention_blocks:
384
+ x = attention_block(x)
385
+
386
+ return x
387
+
388
+
389
+
390
+ class Mlp(nn.Module):
391
+ """
392
+ Multi-Layer Perceptron (MLP) block
393
+ """
394
+
395
+ def __init__(self,
396
+ in_features,
397
+ hidden_features=None,
398
+ out_features=None,
399
+ act_layer=nn.GELU,
400
+ use_swiglu=True,
401
+ drop=0.):
402
+ """
403
+ Args:
404
+ in_features: input features dimension.
405
+ hidden_features: hidden features dimension.
406
+ out_features: output features dimension.
407
+ act_layer: activation function.
408
+ drop: dropout rate.
409
+ """
410
+
411
+ super().__init__()
412
+ out_features = out_features or in_features
413
+ hidden_features = hidden_features or in_features
414
+ self.fc1 = nn.Linear(in_features, hidden_features * (2 if use_swiglu else 1), bias=False)
415
+ self.act = act_layer()
416
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
417
+ # self.drop = GaussianDropout(drop)
418
+
419
+ def forward(self, x):
420
+ x_size = x.size()
421
+ x = x.view(-1, x_size[-1])
422
+ x = self.fc1(x)
423
+ x = self.act(x)
424
+ # x = self.drop(x)
425
+ x = self.fc2(x)
426
+ # x = self.drop(x)
427
+ x = x.view(x_size)
428
+ return x
429
+
430
+ class Downsample(nn.Module):
431
+ """
432
+ Down-sampling block
433
+
434
+ Pixel Unshuffle is used for down-sampling; it works great accuracy-wise but takes ~10% more TRT time
435
+ """
436
+
437
+ def __init__(self,
438
+ dim,
439
+ shuffle = False,
440
+ ):
441
+ """
442
+ Args:
443
+ dim: feature size dimension.
444
+ shuffle: use pixel_unshuffle instead of a strided convolution for down-sampling.
445
+ keep_dim: bool argument for maintaining the resolution.
446
+ """
447
+
448
+ super().__init__()
449
+ dim_out = 2 * dim
450
+
451
+ if shuffle:
452
+ self.norm = lambda x: pixel_unshuffle(x, factor=2)
453
+ self.reduction = Conv2d_BN(dim*4, dim_out, 1, 1, 0, bias=False)
454
+ else:
455
+ # removed the layer norm here; in this formulation we are getting ~10% better speed
456
+ # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
457
+ self.norm = nn.Identity()
458
+ self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
459
+
460
+
461
+ def forward(self, x):
462
+ x = self.norm(x)
463
+ x = self.reduction(x)
464
+ return x
465
+
466
+
467
+ class PatchEmbed(nn.Module):
468
+ """
469
+ Patch embedding block
470
+ """
471
+
472
+ def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
473
+ """
474
+ Args:
475
+ in_chans: number of input channels.
476
+ in_dim: intermediate feature size dimension to speed up stem.
477
+ dim: final stem channel number
478
+ shuffle_down: use PixelUnshuffle for down-sampling, effectively increases the receptive field
479
+ """
480
+
481
+ super().__init__()
482
+ # shuffle_down = False
483
+ if not shuffle_down:
484
+ self.proj = nn.Identity()
485
+ self.conv_down = nn.Sequential(
486
+ Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),
487
+ nn.ReLU(),
488
+ Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),
489
+ nn.ReLU()
490
+ )
491
+ else:
492
+ self.proj = lambda x: pixel_unshuffle(x, factor=4)
493
+
494
+ # self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, in_dim, 3, 1, 1),
495
+ # nn.SiLU(),
496
+ # Conv2d_BN(in_dim, dim, 3, 1, 1),
497
+ # nn.SiLU(),
498
+ # )
499
+ self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, dim, 3, 1, 1),
500
+ nn.ReLU(),
501
+ )
502
+
503
+ def forward(self, x):
504
+ x = self.proj(x)
505
+ x = self.conv_down(x)
506
+ return x
507
+
508
+
509
+
510
+ class ConvBlock(nn.Module):
511
+ """
512
+ Convolutional block, used in first couple of stages
513
+ Experimented with plain resnet-18-like modules; they are the best in terms of throughput
514
+ Experimented with RepVGG; don't see a significant improvement in accuracy
515
+ Finally, the YOLOv8 idea seems to work fine (a resnet-18-like block with a squeezed feature dimension, and feature concatenation at the end)
516
+ """
517
+ def __init__(self, dim,
518
+ drop_path=0.,
519
+ layer_scale=None,
520
+ kernel_size=3,
521
+ rep_vgg=False):
522
+ super().__init__()
523
+ self.rep_vgg = rep_vgg
524
+ if not rep_vgg:
525
+ self.conv1 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
526
+ self.act1 = nn.GELU()
527
+ else:
528
+ self.conv1 = RepVGGBlock(dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1)
529
+
530
+
531
+ if not rep_vgg:
532
+ self.conv2 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
533
+ else:
534
+ self.conv2 = RepVGGBlock(dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1)
535
+
536
+ self.layer_scale = layer_scale
537
+ if layer_scale is not None and type(layer_scale) in [int, float]:
538
+ self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
539
+ self.layer_scale = True
540
+ else:
541
+ self.layer_scale = False
542
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
543
+
544
+ def forward(self, x):
545
+ input = x
546
+ if not self.rep_vgg:
547
+ x = self.conv1(x)
548
+ x = self.act1(x)
549
+ x = self.conv2(x)
550
+ else:
551
+ x = self.conv1(x)
552
+ x = self.conv2(x)
553
+ if self.layer_scale:
554
+ x = x * self.gamma.view(1, -1, 1, 1)
555
+ x = input + self.drop_path(x)
556
+ return x
557
+
558
+
559
+ class WindowAttention(nn.Module):
560
+ # Windowed Attention from SwinV2
561
+ # use an MLP trick to deal with various input image resolutions, then fold it to improve speed
562
+ # tested multi-query attention, but it is not as good as full attention:
563
+ # look into palm: https://github.com/lucidrains/PaLM-pytorch/blob/main/palm_pytorch/palm_pytorch.py
564
+ # single kv attention, mlp in parallel (didn't improve speed)
565
+
566
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0,
567
+ seq_length=0, dim_out=None, multi_query=False):
568
+ # taken from EdgeViT and tweaked with attention bias.
569
+ super().__init__()
570
+ if not dim_out: dim_out = dim
571
+ self.multi_query = multi_query
572
+ self.num_heads = num_heads
573
+ head_dim = dim // num_heads
574
+ self.head_dim = dim // num_heads
575
+
576
+ self.dim_internal = dim
577
+
578
+ self.scale = qk_scale or head_dim ** -0.5
579
+ if not multi_query:
580
+ if TRT:
581
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
582
+ self.k = nn.Linear(dim, dim, bias=qkv_bias)
583
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
584
+ else:
585
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
586
+ else:
587
+ self.qkv = nn.Linear(dim, dim + 2*self.head_dim, bias=qkv_bias)
588
+
589
+ self.proj = nn.Linear(dim, dim_out, bias=False)
590
+ # attention positional bias
591
+ self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
592
+ pretrained_window_size=[resolution, resolution],
593
+ num_heads=num_heads,
594
+ seq_length=seq_length)
595
+
596
+ self.resolution = resolution
597
+
598
+ def forward(self, x):
599
+ B, N, C = x.shape
600
+
601
+ if not self.multi_query:
602
+ if TRT:
603
+ q = self.q(x).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
604
+ k = self.k(x).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
605
+ v = self.v(x).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
606
+ else:
607
+ qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
608
+ q, k, v = qkv[0], qkv[1], qkv[2]
609
+ else:
610
+ qkv = self.qkv(x)
611
+ (q, k, v) = qkv.split([self.dim_internal, self.head_dim, self.head_dim], dim=2)
612
+
613
+ q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
614
+ k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
615
+ v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
616
+
617
+ attn = (q @ k.transpose(-2, -1)) * self.scale
618
+
619
+ attn = self.pos_emb_funct(attn)
620
+
621
+ attn = attn.softmax(dim=-1)
622
+ x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
623
+ x = self.proj(x)
624
+ return x
625
+
626
+
627
+
628
+ class FasterViTLayer(nn.Module):
629
+ """
630
+ FasterViT layer
631
+ """
632
+
633
+ def __init__(self,
634
+ dim,
635
+ depth,
636
+ num_heads,
637
+ window_size,
638
+ conv=False,
639
+ downsample=True,
640
+ mlp_ratio=4.,
641
+ qkv_bias=False,
642
+ qk_scale=None,
643
+ norm_layer=nn.LayerNorm,
644
+ drop_path=0.,
645
+ layer_scale=None,
646
+ layer_scale_conv=None,
647
+ sr_dim_ratio=1,
648
+ sr_ratio=1,
649
+ multi_query=False,
650
+ use_swiglu=True,
651
+ rep_vgg=False,
652
+ yolo_arch=False,
653
+ downsample_shuffle=False,
654
+ conv_base=False,
655
+
656
+ ):
657
+ """
658
+ Args:
659
+ dim: feature size dimension.
660
+ depth: number of layers in each stage.
661
+ input_resolution: input image resolution.
662
+ window_size: window size in each stage.
663
+ downsample: bool argument for down-sampling.
664
+ mlp_ratio: MLP ratio.
665
+ num_heads: number of heads in each stage.
666
+ qkv_bias: bool argument for query, key, value learnable bias.
667
+ qk_scale: bool argument to scaling query, key.
668
+ drop: dropout rate.
669
+ attn_drop: attention dropout rate.
670
+ drop_path: drop path rate.
671
+ norm_layer: normalization layer.
672
+ layer_scale: layer scaling coefficient.
673
+ """
674
+
675
+ super().__init__()
676
+ self.conv = conv
677
+ self.yolo_arch=False
678
+ if conv:
679
+ if not yolo_arch:
680
+ self.blocks = nn.ModuleList([
681
+ ConvBlock(dim=dim,
682
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
683
+ layer_scale=layer_scale_conv, rep_vgg=rep_vgg)
684
+ for i in range(depth)])
685
+ else:
686
+ self.blocks = C2f(dim,dim,n=depth,shortcut=True,e=0.5)
687
+ self.yolo_arch=True
688
+ else:
689
+ if not isinstance(window_size, list): window_size = [window_size]
690
+ self.window_size = window_size[0]
691
+ self.do_single_windowing = True
692
+ if not isinstance(sr_ratio, list): sr_ratio = [sr_ratio]
693
+ if any([sr!=1 for sr in sr_ratio]) or len(set(window_size))>1:
694
+ self.do_single_windowing = False
695
+ do_windowing = True
696
+ else:
697
+ self.do_single_windowing = True
698
+ do_windowing = False
699
+
700
+ self.blocks = nn.ModuleList()
701
+ for i in range(depth):
702
+
703
+ self.blocks.append(
704
+ MultiResolutionAttention(window_size=window_size,
705
+ sr_ratio=sr_ratio,
706
+ dim=dim,
707
+ dim_ratio = sr_dim_ratio,
708
+ num_heads=num_heads,
709
+ norm_layer=norm_layer,
710
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
711
+ layer_scale=layer_scale,
712
+ qkv_bias=qkv_bias,
713
+ qk_scale=qk_scale,
714
+ use_swiglu=use_swiglu,
715
+ do_windowing=do_windowing,
716
+ multi_query=multi_query,
717
+ conv_base=conv_base,
718
+ ))
719
+
720
+ self.transformer = not conv
721
+
722
+
723
+ self.downsample = None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
724
+
725
+
726
+
727
+
728
+ def forward(self, x):
729
+ B, C, H, W = x.shape
730
+
731
+ if self.transformer and self.do_single_windowing:
732
+ H, W = x.shape[2], x.shape[3]
733
+ x, pad_hw = window_partition(x, self.window_size)
734
+
735
+ if not self.yolo_arch:
736
+ for bn, blk in enumerate(self.blocks):
737
+ x = blk(x)
738
+ else:
739
+ x = self.blocks(x)
740
+
741
+ if self.transformer and self.do_single_windowing:
742
+ x = window_reverse(x, self.window_size, H, W, pad_hw)
743
+
744
+
745
+ if self.downsample is None:
746
+ return x, x
747
+
748
+ return self.downsample(x), x  # changed to also output the pre-downsample features
749
+
750
+
751
+ class FasterViT(nn.Module):
752
+ """
753
+ FasterViT
754
+ """
755
+
756
+ def __init__(self,
757
+ dim,
758
+ in_dim,
759
+ depths,
760
+ window_size,
761
+ mlp_ratio,
762
+ num_heads,
763
+ drop_path_rate=0.2,
764
+ in_chans=3,
765
+ num_classes=1000,
766
+ qkv_bias=False,
767
+ qk_scale=None,
768
+ layer_scale=None,
769
+ layer_scale_conv=None,
770
+ layer_norm_last=False,
771
+ sr_ratio = [1, 1, 1, 1],
772
+ max_depth = -1,
773
+ conv_base=False,
774
+ use_swiglu=False,
775
+ multi_query=False,
776
+ norm_layer=nn.LayerNorm,
777
+ rep_vgg=False,
778
+ drop_uniform=False,
779
+ yolo_arch=False,
780
+ shuffle_down=False,
781
+ downsample_shuffle=False,
782
+ return_full_features=False,
783
+ full_features_head_dim=128,
784
+ neck_start_stage=1,
785
+ use_neck=False,
786
+ **kwargs):
787
+ """
788
+ Args:
789
+ dim: feature size dimension.
790
+ depths: number of layers in each stage.
791
+ window_size: window size in each stage.
792
+ mlp_ratio: MLP ratio.
793
+ num_heads: number of heads in each stage.
794
+ drop_path_rate: drop path rate.
795
+ in_chans: number of input channels.
796
+ num_classes: number of classes.
797
+ qkv_bias: bool argument for query, key, value learnable bias.
798
+ qk_scale: bool argument to scaling query, key.
799
+ drop_rate: dropout rate.
800
+ attn_drop_rate: attention dropout rate.
801
+ norm_layer: normalization layer.
802
+ layer_scale: layer scaling coefficient.
803
+ return_full_features: output dense features as well as logits
804
+ full_features_head_dim: number of channels in the dense features head
805
+ neck_start_stage: the stage id at which the full-feature neck starts. The model has 4 stages; the index starts at 0
806
+ for 224 resolution, the output of the stage before downsample:
807
+ stage 0: 56x56, stage 1: 28x28, stage 2: 14x14, stage 3: 7x7
808
+ use_neck: use the neck even when only the summary embedding is returned
809
+ """
810
+ super().__init__()
811
+
812
+ num_features = int(dim * 2 ** (len(depths) - 1))
813
+ self.num_classes = num_classes
814
+ self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down)
815
+ # set return_full_features true if we want to return full features from all stages
816
+ self.return_full_features = return_full_features
817
+ self.use_neck = use_neck
818
+
819
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
820
+ if drop_uniform:
821
+ dpr = [drop_path_rate for x in range(sum(depths))]
822
+
823
+ if not isinstance(max_depth, list): max_depth = [max_depth] * len(depths)
824
+
825
+ self.levels = nn.ModuleList()
826
+ for i in range(len(depths)):
827
+ conv = True if (i == 0 or i == 1) else False
828
+
829
+ level = FasterViTLayer(dim=int(dim * 2 ** i),
830
+ depth=depths[i],
831
+ num_heads=num_heads[i],
832
+ window_size=window_size[i],
833
+ mlp_ratio=mlp_ratio,
834
+ qkv_bias=qkv_bias,
835
+ qk_scale=qk_scale,
836
+ conv=conv,
837
+ drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
838
+ downsample=(i < 3),
839
+ layer_scale=layer_scale,
840
+ layer_scale_conv=layer_scale_conv,
841
+ sr_ratio=sr_ratio[i],
842
+ use_swiglu=use_swiglu,
843
+ multi_query=multi_query,
844
+ norm_layer=norm_layer,
845
+ rep_vgg=rep_vgg,
846
+ yolo_arch=yolo_arch,
847
+ downsample_shuffle=downsample_shuffle,
848
+ conv_base=conv_base)
849
+
850
+ self.levels.append(level)
851
+
852
+ if self.return_full_features or self.use_neck:
853
+ # create feature projection layers for segmentation output
854
+ self.neck_features_proj = nn.ModuleList()
855
+ self.neck_start_stage = neck_start_stage
856
+ upsample_ratio = 1
857
+ for i in range(len(depths)):
858
+ level_n_features_output = int(dim * 2 ** i)
859
+
860
+ if self.neck_start_stage > i: continue
861
+
862
+ if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:
863
+ feature_projection = nn.Sequential()
864
+ # feature_projection.add_module("norm",LayerNorm2d(level_n_features_output)) #slow, but better
865
+
866
+
867
+ if 0 :
868
+ # Train: 0 [1900/10009 ( 19%)] Loss: 6.113 (6.57) Time: 0.548s, 233.40/s (0.549s, 233.04/s) LR: 1.000e-05 Data: 0.015 (0.013)
869
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
870
+ feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output,
871
+ full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))
872
+ else:
873
+ # pixel shuffle based upsampling
874
+ # Train: 0 [1950/10009 ( 19%)] Loss: 6.190 (6.55) Time: 0.540s, 236.85/s (0.548s, 233.38/s) LR: 1.000e-05 Data: 0.015 (0.013)
875
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
876
+ feature_projection.add_module("conv", nn.Conv2d(level_n_features_output,
877
+ full_features_head_dim*upsample_ratio*upsample_ratio, kernel_size=1, stride=1))
878
+ feature_projection.add_module("upsample_pixelshuffle", nn.PixelShuffle(upsample_ratio))
879
+
880
+ else:
881
+ feature_projection = nn.Sequential()
882
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output))
883
+
884
+
885
+ self.neck_features_proj.append(feature_projection)
886
+
887
+ if i>0 and self.levels[i-1].downsample is not None:
888
+ upsample_ratio *= 2
889
+
890
+
891
+ num_features = full_features_head_dim if (self.return_full_features or self.use_neck) else num_features
892
+
893
+ self.num_features = num_features
894
+
895
+ self.norm = LayerNorm2d(num_features) if layer_norm_last else nn.BatchNorm2d(num_features)
896
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
897
+ self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
898
+ self.apply(self._init_weights)
899
+ # pass
900
+
901
+ def _init_weights(self, m):
902
+ if isinstance(m, nn.Linear):
903
+ trunc_normal_(m.weight, std=.02)
904
+ if isinstance(m, nn.Linear) and m.bias is not None:
905
+ nn.init.constant_(m.bias, 0)
906
+ elif isinstance(m, nn.LayerNorm):
907
+ nn.init.constant_(m.bias, 0)
908
+ nn.init.constant_(m.weight, 1.0)
909
+ elif isinstance(m, LayerNorm2d):
910
+ nn.init.constant_(m.bias, 0)
911
+ nn.init.constant_(m.weight, 1.0)
912
+ elif isinstance(m, nn.BatchNorm2d):
913
+ nn.init.ones_(m.weight)
914
+ nn.init.zeros_(m.bias)
915
+
916
+ @torch.jit.ignore
917
+ def no_weight_decay_keywords(self):
918
+ return {'rpb'}
919
+
920
+ def forward_features(self, x):
921
+ x = self.patch_embed(x)
922
+ full_features = None
923
+ for il, level in enumerate(self.levels):
924
+ x, pre_downsample_x = level(x)
925
+
926
+ if self.return_full_features or self.use_neck:
927
+ if self.neck_start_stage > il: continue
928
+ if full_features is None:
929
+ full_features = self.neck_features_proj[il - self.neck_start_stage](pre_downsample_x)
930
+ else:
931
+ #upsample torch tensor x to match full_features size, and add to full_features
932
+ feature_projection = self.neck_features_proj[il - self.neck_start_stage](pre_downsample_x)
933
+ if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:
934
+ feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))
935
+ full_features += feature_projection
936
+
937
+ # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)
938
+ x = self.norm(x) # new version for
939
+ x = self.avgpool(x)
940
+ x = torch.flatten(x, 1)
941
+
942
+ if not self.return_full_features:
943
+ return x, None
944
+
945
+ return x, full_features
946
+
947
+ def forward(self, x):
948
+ x, full_features = self.forward_features(x)
949
+ x = self.head(x)
950
+ if full_features is not None:
951
+ return x, full_features
952
+ return x
953
+
954
+ def switch_to_deploy(self):
955
+ '''
956
+ A method to perform model self-compression
957
+ merges BN into conv layers
958
+ converts MLP relative positional bias into precomputed buffers
959
+ '''
960
+ for level in [self.patch_embed, self.levels, self.head]:
961
+ for module in level.modules():
962
+ if hasattr(module, 'switch_to_deploy'):
963
+ module.switch_to_deploy()
964
+
965
+ @register_model
966
+ def fastervit2_small(pretrained=False, **kwargs): #,
967
+ model = FasterViT(depths=[3, 3, 5, 5],
968
+ num_heads=[2, 4, 8, 16],
969
+ window_size=[8, 8, [7, 7], 7],
970
+ dim=96,
971
+ in_dim=64,
972
+ mlp_ratio=4,
973
+ drop_path_rate=0.2,
974
+ sr_ratio=[1, 1, [1, 2], 1],
975
+ use_swiglu=False,
976
+ downsample_shuffle=False,
977
+ yolo_arch=True,
978
+ shuffle_down=False,
979
+ **kwargs)
980
+ if pretrained:
981
+ model.load_state_dict(torch.load(pretrained))
982
+ return model
983
+
984
+ @register_model
985
+ def fastervit2_tiny(pretrained=False, **kwargs): #,
986
+ model = FasterViT(depths=[1, 3, 4, 5],
987
+ num_heads=[2, 4, 8, 16],
988
+ window_size=[8, 8, [7, 7], 7],
989
+ dim=80,
990
+ in_dim=64,
991
+ mlp_ratio=4,
992
+ drop_path_rate=0.2,
993
+ sr_ratio=[1, 1, [2, 1], 1],
994
+ use_swiglu=False,
995
+ downsample_shuffle=False,
996
+ yolo_arch=True,
997
+ shuffle_down=False,
998
+ **kwargs)
999
+ if pretrained:
1000
+ model.load_state_dict(torch.load(pretrained))
1001
+ return model
1002
+
1003
+ @register_model
1004
+ def fastervit2_base(pretrained=False, **kwargs):
1005
+ model = FasterViT(depths=[3, 3, 5, 5],
1006
+ num_heads=[2, 4, 8, 16],
1007
+ window_size=[8, 8, [7, 7], 7],
1008
+ dim=128,
1009
+ in_dim=64,
1010
+ mlp_ratio=4,
1011
+ drop_path_rate=0.2,
1012
+ sr_ratio=[1, 1, [2, 1], 1],
1013
+ use_swiglu=False,
1014
+ yolo_arch=True,
1015
+ shuffle_down=False,
1016
+ conv_base=True,
1017
+ **kwargs)
1018
+ if pretrained:
1019
+ model.load_state_dict(torch.load(pretrained))
1020
+ return model
1021
+
1022
+ @register_model
1023
+ def fastervit2_base_fullres1(pretrained=False, **kwargs):
1024
+ model = FasterViT(depths=[3, 3, 5, 5],
1025
+ num_heads=[2, 4, 8, 16],
1026
+ window_size=[8, 8, [7, 7], 7],
1027
+ dim=128,
1028
+ in_dim=64,
1029
+ mlp_ratio=4,
1030
+ drop_path_rate=0.2,
1031
+ sr_ratio=[1, 1, [2, 1], 1],
1032
+ use_swiglu=False,
1033
+ yolo_arch=True,
1034
+ shuffle_down=False,
1035
+ conv_base=True,
1036
+ use_neck=True,
1037
+ full_features_head_dim=1024,
1038
+ neck_start_stage=2,
1039
+ **kwargs)
1040
+ if pretrained:
1041
+ model.load_state_dict(torch.load(pretrained))
1042
+ return model
1043
+
1044
+ @register_model
1045
+ def fastervit2_base_fullres2(pretrained=False, **kwargs):
1046
+ model = FasterViT(depths=[3, 3, 5, 5],
1047
+ num_heads=[2, 4, 8, 16],
1048
+ window_size=[8, 8, [7, 7], 7],
1049
+ dim=128,
1050
+ in_dim=64,
1051
+ mlp_ratio=4,
1052
+ drop_path_rate=0.2,
1053
+ sr_ratio=[1, 1, [2, 1], 1],
1054
+ use_swiglu=False,
1055
+ yolo_arch=True,
1056
+ shuffle_down=False,
1057
+ conv_base=True,
1058
+ use_neck=True,
1059
+ full_features_head_dim=512,
1060
+ neck_start_stage=1,
1061
+ **kwargs)
1062
+ if pretrained:
1063
+ model.load_state_dict(torch.load(pretrained))
1064
+ return model
1065
+
1066
+ @register_model
1067
+ def fastervit2_base_fullres3(pretrained=False, **kwargs):
1068
+ model = FasterViT(depths=[3, 3, 5, 5],
1069
+ num_heads=[2, 4, 8, 16],
1070
+ window_size=[8, 8, [7, 7], 7],
1071
+ dim=128,
1072
+ in_dim=64,
1073
+ mlp_ratio=4,
1074
+ drop_path_rate=0.2,
1075
+ sr_ratio=[1, 1, [2, 1], 1],
1076
+ use_swiglu=False,
1077
+ yolo_arch=True,
1078
+ shuffle_down=False,
1079
+ conv_base=True,
1080
+ use_neck=True,
1081
+ full_features_head_dim=256,
1082
+ neck_start_stage=1,
1083
+ **kwargs)
1084
+ if pretrained:
1085
+ model.load_state_dict(torch.load(pretrained))
1086
+ return model
1087
+
1088
+ @register_model
1089
+ def fastervit2_base_fullres4(pretrained=False, **kwargs):
1090
+ model = FasterViT(depths=[3, 3, 5, 5],
1091
+ num_heads=[2, 4, 8, 16],
1092
+ window_size=[8, 8, [7, 7], 7],
1093
+ dim=128,
1094
+ in_dim=64,
1095
+ mlp_ratio=4,
1096
+ drop_path_rate=0.2,
1097
+ sr_ratio=[1, 1, [2, 1], 1],
1098
+ use_swiglu=False,
1099
+ yolo_arch=True,
1100
+ shuffle_down=False,
1101
+ conv_base=True,
1102
+ use_neck=True,
1103
+ full_features_head_dim=256,
1104
+ neck_start_stage=2,
1105
+ **kwargs)
1106
+ if pretrained:
1107
+ model.load_state_dict(torch.load(pretrained))
1108
+ return model
1109
+
1110
+ @register_model
1111
+ def fastervit2_base_fullres5(pretrained=False, **kwargs):
1112
+ model = FasterViT(depths=[3, 3, 5, 5],
1113
+ num_heads=[2, 4, 8, 16],
1114
+ window_size=[8, 8, [7, 7], 7],
1115
+ dim=128,
1116
+ in_dim=64,
1117
+ mlp_ratio=4,
1118
+ drop_path_rate=0.2,
1119
+ sr_ratio=[1, 1, [2, 1], 1],
1120
+ use_swiglu=False,
1121
+ yolo_arch=True,
1122
+ shuffle_down=False,
1123
+ conv_base=True,
1124
+ use_neck=True,
1125
+ full_features_head_dim=512,
1126
+ neck_start_stage=2,
1127
+ **kwargs)
1128
+ if pretrained:
1129
+ model.load_state_dict(torch.load(pretrained))
1130
+ return model
1131
+
1132
+ #pyt: 1934, 4202 TRT
1133
+ @register_model
1134
+ def fastervit2_large(pretrained=False, **kwargs):
1135
+ model = FasterViT(depths=[3, 3, 5, 5],
1136
+ num_heads=[2, 4, 8, 16],
1137
+ window_size=[8, 8, [7, 7], 7],
1138
+ dim=128+64,
1139
+ in_dim=64,
1140
+ mlp_ratio=4,
1141
+ drop_path_rate=0.2,
1142
+ sr_ratio=[1, 1, [2, 1], 1],
1143
+ use_swiglu=False,
1144
+ yolo_arch=True,
1145
+ shuffle_down=False,
1146
+ **kwargs)
1147
+ if pretrained:
1148
+ model.load_state_dict(torch.load(pretrained))
1149
+ return model
1150
+
1151
+ @register_model
1152
+ def fastervit2_large_fullres(pretrained=False, **kwargs):
1153
+ model = FasterViT(depths=[3, 3, 5, 5],
1154
+ num_heads=[2, 4, 8, 16],
1155
+ window_size=[None, None, [7, 7], 7],
1156
+ dim=192,
1157
+ in_dim=64,
1158
+ mlp_ratio=4,
1159
+ drop_path_rate=0.,
1160
+ sr_ratio=[1, 1, [2, 1], 1],
1161
+ use_swiglu=False,
1162
+ yolo_arch=True,
1163
+ shuffle_down=False,
1164
+ conv_base=True,
1165
+ use_neck=True,
1166
+ full_features_head_dim=1536,
1167
+ neck_start_stage=2,
1168
+ **kwargs)
1169
+ if pretrained:
1170
+ model.load_state_dict(torch.load(pretrained))
1171
+ return model
1172
+
1173
+ @register_model
1174
+ def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
1175
+ model = FasterViT(depths=[3, 3, 5, 5],
1176
+ num_heads=[2, 4, 8, 16],
1177
+ window_size=[None, None, [8, 8], 8],
1178
+ dim=192,
1179
+ in_dim=64,
1180
+ mlp_ratio=4,
1181
+ drop_path_rate=0.,
1182
+ sr_ratio=[1, 1, [2, 1], 1],
1183
+ use_swiglu=False,
1184
+ yolo_arch=True,
1185
+ shuffle_down=False,
1186
+ conv_base=True,
1187
+ use_neck=True,
1188
+ full_features_head_dim=1536,
1189
+ neck_start_stage=2,
1190
+ **kwargs)
1191
+ if pretrained:
1192
+ model.load_state_dict(torch.load(pretrained))
1193
+ return model
1194
+
1195
+ @register_model
1196
+ def fastervit2_large_fullres_ws16(pretrained=False, **kwargs):
1197
+ model = FasterViT(depths=[3, 3, 5, 5],
1198
+ num_heads=[2, 4, 8, 16],
1199
+ window_size=[None, None, [16, 16], 16],
1200
+ dim=192,
1201
+ in_dim=64,
1202
+ mlp_ratio=4,
1203
+ drop_path_rate=0.,
1204
+ sr_ratio=[1, 1, [2, 1], 1],
1205
+ use_swiglu=False,
1206
+ yolo_arch=True,
1207
+ shuffle_down=False,
1208
+ conv_base=True,
1209
+ use_neck=True,
1210
+ full_features_head_dim=1536,
1211
+ neck_start_stage=2,
1212
+ **kwargs)
1213
+ if pretrained:
1214
+ model.load_state_dict(torch.load(pretrained))
1215
+ return model
1216
+
1217
+ @register_model
1218
+ def fastervit2_large_fullres_ws32(pretrained=False, **kwargs):
1219
+ model = FasterViT(depths=[3, 3, 5, 5],
1220
+ num_heads=[2, 4, 8, 16],
1221
+ window_size=[None, None, [32, 32], 32],
1222
+ dim=192,
1223
+ in_dim=64,
1224
+ mlp_ratio=4,
1225
+ drop_path_rate=0.,
1226
+ sr_ratio=[1, 1, [2, 1], 1],
1227
+ use_swiglu=False,
1228
+ yolo_arch=True,
1229
+ shuffle_down=False,
1230
+ conv_base=True,
1231
+ use_neck=True,
1232
+ full_features_head_dim=1536,
1233
+ neck_start_stage=2,
1234
+ **kwargs)
1235
+ if pretrained:
1236
+ model.load_state_dict(torch.load(pretrained))
1237
+ return model
1238
+
1239
+ #pyt: 897
1240
+ @register_model
1241
+ def fastervit2_xlarge(pretrained=False, **kwargs):
1242
+ model = FasterViT(depths=[3, 3, 5, 5],
1243
+ num_heads=[2, 4, 8, 16],
1244
+ window_size=[8, 8, [7, 7], 7],
1245
+ dim=128+128+64,
1246
+ in_dim=64,
1247
+ mlp_ratio=4,
1248
+ drop_path_rate=0.2,
1249
+ sr_ratio=[1, 1, [2, 1], 1],
1250
+ use_swiglu=False,
1251
+ yolo_arch=True,
1252
+ shuffle_down=False,
1253
+ **kwargs)
1254
+ if pretrained:
1255
+ model.load_state_dict(torch.load(pretrained))
1256
+ return model
1257
+
1258
+
1259
+ #pyt:
1260
+ @register_model
1261
+ def fastervit2_huge(pretrained=False, **kwargs):
1262
+ model = FasterViT(depths=[3, 3, 5, 5],
1263
+ num_heads=[2, 4, 8, 16],
1264
+ window_size=[8, 8, [7, 7], 7],
1265
+ dim=128+128+128+64,
1266
+ in_dim=64,
1267
+ mlp_ratio=4,
1268
+ drop_path_rate=0.2,
1269
+ sr_ratio=[1, 1, [2, 1], 1],
1270
+ use_swiglu=False,
1271
+ yolo_arch=True,
1272
+ shuffle_down=False,
1273
+ **kwargs)
1274
+ if pretrained:
1275
+ model.load_state_dict(torch.load(pretrained))
1276
+ return model
1277
+
1278
+
1279
+ @register_model
1280
+ def fastervit2_xtiny(pretrained=False, **kwargs): #,
1281
+ model = FasterViT(depths=[1, 3, 4, 5],
1282
+ num_heads=[2, 4, 8, 16],
1283
+ window_size=[8, 8, [7, 7], 7],
1284
+ dim=64,
1285
+ in_dim=64,
1286
+ mlp_ratio=4,
1287
+ drop_path_rate=0.1,
1288
+ sr_ratio=[1, 1, [2, 1], 1],
1289
+ use_swiglu=False,
1290
+ downsample_shuffle=False,
1291
+ yolo_arch=True,
1292
+ shuffle_down=False,
1293
+ **kwargs)
1294
+ if pretrained:
1295
+ model.load_state_dict(torch.load(pretrained))
1296
+ return model
1297
+
1298
+
1299
+ @register_model
1300
+ def fastervit2_xxtiny_5(pretrained=False, **kwargs): #,
1301
+ model = FasterViT(depths=[1, 3, 4, 5],
1302
+ num_heads=[2, 4, 8, 16],
1303
+ window_size=[8, 8, [7, 7], 7],
1304
+ dim=48,
1305
+ in_dim=64,
1306
+ mlp_ratio=4,
1307
+ drop_path_rate=0.05,
1308
+ sr_ratio=[1, 1, [2, 1], 1],
1309
+ use_swiglu=False,
1310
+ downsample_shuffle=False,
1311
+ yolo_arch=True,
1312
+ shuffle_down=False,
1313
+ **kwargs)
1314
+ if pretrained:
1315
+ model.load_state_dict(torch.load(pretrained))
1316
+ return model
1317
+
1318
+ @register_model
1319
+ def fastervit2_xxxtiny(pretrained=False, **kwargs): #,
1320
+ model = FasterViT(depths=[1, 3, 4, 5],
1321
+ num_heads=[2, 4, 8, 16],
1322
+ window_size=[8, 8, [7, 7], 7],
1323
+ dim=32,
1324
+ in_dim=32,
1325
+ mlp_ratio=4,
1326
+ drop_path_rate=0.0,
1327
+ sr_ratio=[1, 1, [2, 1], 1],
1328
+ use_swiglu=False,
1329
+ downsample_shuffle=False,
1330
+ yolo_arch=True,
1331
+ shuffle_down=False,
1332
+ **kwargs)
1333
+ if pretrained:
1334
+ model.load_state_dict(torch.load(pretrained))
1335
+ return model
1336
+
1337
+
1338
+ @register_model
1339
+ def eradio(pretrained=False, **kwargs):
1340
+ return fastervit2_large_fullres(pretrained=pretrained, **kwargs)
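
Since the variants above are registered with timm, the model can be built with timm.create_model once this module has been imported. A minimal sketch, assuming the surrounding package (here called radio_pkg for illustration, including the block.py that provides C2f and is not shown in this commit) is importable; the input size is illustrative:

import torch
import timm
from radio_pkg import eradio_model  # hypothetical package name; importing registers the fastervit2_* / eradio entrypoints

model = timm.create_model("eradio", pretrained=False).eval()

x = torch.randn(1, 3, 224, 224)
summary, full_features = model.forward_features(x)  # full_features is None unless built with return_full_features=True
logits = model(x)                                    # classification head output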
hf_model.py CHANGED
@@ -15,12 +15,70 @@ from collections import namedtuple
  from typing import Optional
 
  from einops import rearrange
+ from timm.models import VisionTransformer
  import torch
  from transformers import PretrainedConfig, PreTrainedModel
 
- #from radio.model import create_model_from_args
- from radio.input_conditioner import get_default_conditioner, InputConditioner
- from .model import eradio
+
+ from .eradio_model import eradio
+ from .radio_model import create_model_from_args
+ from .radio_model import RADIOModel as RADIOModelBase
+ from .input_conditioner import get_default_conditioner, InputConditioner
+
+
+ class RADIOConfig(PretrainedConfig):
+     """Pretrained Hugging Face configuration for RADIO models."""
+
+     def __init__(
+         self,
+         args: Optional[dict] = None,
+         version: Optional[str] = "v1",
+         return_summary: Optional[bool] = True,
+         return_spatial_features: Optional[bool] = True,
+         **kwargs,
+     ):
+         self.args = args
+         self.version = version
+         self.return_summary = return_summary
+         self.return_spatial_features = return_spatial_features
+         super().__init__(**kwargs)
+
+
+ class RADIOModel(PreTrainedModel):
+     """Pretrained Hugging Face model for RADIO.
+
+     This class inherits from PreTrainedModel, which provides
+     HuggingFace's functionality for loading and saving models.
+     """
+
+     config_class = RADIOConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
+         args = RADIOArgs(**config.args)
+         self.config = config
+         model = create_model_from_args(args)
+         input_conditioner: InputConditioner = get_default_conditioner()
+
+         self.radio_model = RADIOModelBase(
+             model,
+             input_conditioner,
+             config.return_summary,
+             config.return_spatial_features,
+         )
+
+     @property
+     def model(self) -> VisionTransformer:
+         return self.radio_model.model
+
+     @property
+     def input_conditioner(self) -> InputConditioner:
+         return self.radio_model.input_conditioner
+
+     def forward(self, x: torch.Tensor):
+         return self.radio_model.forward(x)
 
 
  class ERADIOConfig(PretrainedConfig):
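
For reference, a sketch of how the new RADIOConfig / RADIOModel pair might be constructed directly. The args dictionary mirrors the fields that create_model_from_args reads; every value below is a placeholder assumption (in practice the dictionary ships in the repository's config.json), radio_pkg is a hypothetical package name, and CPE is left disabled so the sketch does not depend on vit_patch_generator.py:

import torch
from radio_pkg.hf_model import RADIOConfig, RADIOModel  # hypothetical package name

# Placeholder arguments; the keys follow create_model_from_args(), the values are assumptions.
args = dict(
    model="vit_base_patch16_224",   # assumed timm entrypoint
    pretrained=False,
    in_chans=3,
    input_size=None,
    num_classes=0,
    drop=0.0,
    drop_path=None,
    drop_block=None,
    gp=None,
    bn_momentum=None,
    bn_eps=None,
    torchscript=False,
    initial_checkpoint="",
    model_kwargs={},
    teachers=[],
    cls_token_per_teacher=False,
    cpe_max_size=None,              # keep CPE disabled in this sketch
)

config = RADIOConfig(args=args, return_summary=True, return_spatial_features=True)
model = RADIOModel(config).eval()

summary, spatial_features = model(torch.randn(1, 3, 224, 224))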
input_conditioner.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ from typing import Union, Tuple
+
+ import torch
+ from torch import nn
+
+
+ norm_t = Union[Tuple[float, float, float], torch.Tensor]
+
+ class InputConditioner(nn.Module):
+     def __init__(self,
+                  input_scale: float,
+                  norm_mean: norm_t,
+                  norm_std: norm_t,
+                  dtype: torch.dtype = torch.float32,
+     ):
+         super().__init__()
+
+         self.dtype = dtype
+
+         # self.input_scale = input_scale
+         self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
+         self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
+
+     def forward(self, x: torch.Tensor):
+         # x = x * self.input_scale
+         y = (x - self.norm_mean) / self.norm_std
+         return y.to(self.dtype)
+
+
+ def get_default_conditioner():
+     from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+
+     return InputConditioner(
+         input_scale=1.0,
+         norm_mean=OPENAI_CLIP_MEAN,
+         norm_std=OPENAI_CLIP_STD,
+     )
+
+
+ def _to_tensor(v: norm_t):
+     return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)
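
A short sketch of the conditioner in isolation; the import path is an assumption (in the repository it is a relative import), and the batch shape is arbitrary:

    import torch
    from input_conditioner import get_default_conditioner  # assumed top-level import

    conditioner = get_default_conditioner()

    # Inputs are expected in [0, 1]; the buffers hold the OpenAI CLIP mean/std.
    images = torch.rand(2, 3, 224, 224)
    normalized = conditioner(images)  # roughly zero-centered, unit-scale float32
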
radio_model.py ADDED
@@ -0,0 +1,100 @@
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import torch
+ from torch import nn
+
+ from timm.models import create_model, VisionTransformer
+
+ from .enable_cpe_support import enable_cpe
+ from .input_conditioner import InputConditioner
+
+
+ class RADIOModel(nn.Module):
+     def __init__(
+         self,
+         model: nn.Module,
+         input_conditioner: InputConditioner,
+         return_summary: bool,
+         return_spatial_features: bool,
+     ):
+         super().__init__()
+
+         self.model = model
+         self.input_conditioner = input_conditioner
+         self.return_summary = return_summary
+         self.return_spatial_features = return_spatial_features
+
+     def forward(self, x: torch.Tensor):
+         x = self.input_conditioner(x)
+
+         y = self.model.forward_features(x)
+
+         if isinstance(y, (list, tuple)):
+             summary, all_feat = y
+         elif isinstance(self.model, VisionTransformer):
+             patch_gen = getattr(self.model, "patch_generator", None)
+             if patch_gen is not None:
+                 summary = y[:, : patch_gen.num_cls_tokens].flatten(1)
+                 all_feat = y[:, patch_gen.num_skip :]
+             elif self.model.global_pool == "avg":
+                 summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
+                 all_feat = y
+             else:
+                 summary = y[:, 0]
+                 all_feat = y[:, 1:]
+         else:
+             raise ValueError("Unsupported model type")
+
+         if self.return_summary and self.return_spatial_features:
+             return summary, all_feat
+         elif self.return_summary:
+             return summary
+         return all_feat
+
+
+ def create_model_from_args(args) -> nn.Module:
+     in_chans = 3
+     if args.in_chans is not None:
+         in_chans = args.in_chans
+     elif args.input_size is not None:
+         in_chans = args.input_size[0]
+
+     # Skip weight initialization unless it's explicitly requested.
+     weight_init = args.model_kwargs.pop("weight_init", "skip")
+
+     model = create_model(
+         args.model,
+         pretrained=args.pretrained,
+         in_chans=in_chans,
+         num_classes=args.num_classes,
+         drop_rate=args.drop,
+         drop_path_rate=args.drop_path,
+         drop_block_rate=args.drop_block,
+         global_pool=args.gp,
+         bn_momentum=args.bn_momentum,
+         bn_eps=args.bn_eps,
+         scriptable=args.torchscript,
+         checkpoint_path=args.initial_checkpoint,
+         weight_init=weight_init,
+         **args.model_kwargs,
+     )
+
+     assert (
+         not args.cls_token_per_teacher or args.cpe_max_size is not None
+     ), "CPE must be enabled for multiple CLS tokens!"
+
+     if args.cpe_max_size is not None:
+         enable_cpe(
+             model,
+             args.cpe_max_size,
+             num_cls_tokens=len(args.teachers) if args.cls_token_per_teacher else 1,
+             register_multiple=args.register_multiple,
+         )
+
+     return model
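
To illustrate the branch taken for a plain timm ViT (no patch_generator, "token" pooling), a sketch that wraps an off-the-shelf backbone; the package name `radio` is an assumed import layout, and the RADIO checkpoints themselves are built via create_model_from_args instead:

    import torch
    from timm.models import create_model

    from radio.radio_model import RADIOModel              # assumed package layout
    from radio.input_conditioner import get_default_conditioner

    vit = create_model("vit_small_patch16_224", pretrained=False, num_classes=0)
    wrapper = RADIOModel(vit, get_default_conditioner(),
                         return_summary=True, return_spatial_features=True)

    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        summary, spatial = wrapper(x)
    print(summary.shape)  # (1, embed_dim): the CLS token
    print(spatial.shape)  # (1, num_patches, embed_dim)
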
vit_patch_generator.py ADDED
@@ -0,0 +1,299 @@
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import math
+ from typing import Union, Tuple, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from einops import rearrange
+
+ from .cls_token import ClsToken
+
+ input_dim_t = Union[int, Tuple[int, int]]
+
+ try:
+     # raise ImportError()
+     from indirect_grid_sample import indirect_grid_sample
+ except ImportError:
+     indirect_grid_sample = None
+
+ class ViTPatchGenerator(nn.Module):
+     def __init__(self,
+                  patch_size: int,
+                  embed_dim: int,
+                  input_dims: input_dim_t,
+                  abs_pos: bool = True,
+                  normalize_patches: bool = False,
+                  cls_token: bool = False,
+                  max_input_dims: Optional[input_dim_t] = None,
+                  pos_dropout: float = 0.0,
+                  return_pos_enc: bool = False,
+                  num_cls_tokens: int = 1,
+                  register_multiple: int = 0,
+                  device=None, dtype=None,
+     ):
+         super().__init__()
+
+         if isinstance(input_dims, int):
+             input_dims = (input_dims, input_dims)
+
+         if max_input_dims is None:
+             max_input_dims = input_dims
+         if isinstance(max_input_dims, int):
+             max_input_dims = (max_input_dims, max_input_dims)
+
+         max_input_dims = tuple(
+             int(math.ceil(d / patch_size) * patch_size)
+             for d in max_input_dims
+         )
+
+         self.cpe_mode = max_input_dims != input_dims
+         self.pos_dropout = pos_dropout
+         self.return_pos_enc = return_pos_enc
+
+         factory = dict(device=device, dtype=dtype)
+
+         self.patch_size = patch_size
+         self.abs_pos = abs_pos
+         self.embed_dim = embed_dim
+
+         self.num_rows = max_input_dims[0] // patch_size
+         self.num_cols = max_input_dims[1] // patch_size
+         self.input_dims = tuple(d // patch_size for d in input_dims)
+         self.num_patches = self.num_rows * self.num_cols
+         self.max_input_dims = max_input_dims
+
+         self.im_to_patches = Im2Patches(patch_size)
+         self.embedder = ViTPatchLinear(patch_size, embed_dim, **factory)
+
+         if abs_pos:
+             scale = embed_dim ** -0.5
+             self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim, **factory) * scale)
+
+         self.cls_token = ClsToken(
+             embed_dim,
+             num_tokens=num_cls_tokens,
+             enabled=cls_token,
+             register_multiple=register_multiple,
+         )
+
+         self.patch_normalizer = nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         patches = self.embed_patches(x)
+         patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
+         patches = self.cls_token(patches)
+         patches = self.patch_normalizer(patches)
+         if self.return_pos_enc:
+             return patches, pos_enc
+         return patches
+
+     @property
+     def apply_cls_token(self):
+         return self.cls_token.enabled
+
+     @property
+     def num_cls_tokens(self):
+         return self.cls_token.num_tokens
+
+     @property
+     def num_registers(self):
+         return self.cls_token.num_registers
+
+     @property
+     def num_skip(self):
+         return self.num_cls_tokens + self.num_registers
+
+     def no_weight_decay(self):
+         return [
+             'pos_embed',
+         ]
+
+     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+         if self.abs_pos:
+             self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)
+
+     def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
+         if src_embed.shape != targ_embed.shape:
+             src_size = int(math.sqrt(src_embed.shape[1]))
+
+             assert src_size ** 2 == src_embed.shape[1], 'Unable to interpolate non-square embedding'
+
+             src_embed = rearrange(src_embed, 'b (h w) c -> b c h w', h=src_size, w=src_size)
+             src_embed = F.interpolate(src_embed, size=(self.num_rows, self.num_cols), mode='bicubic', align_corners=True, antialias=False)
+             src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')
+         targ_embed.data.copy_(src_embed)
+
+     def _load_projection(self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor):
+         if src_proj_weight.shape != targ_proj_weight.shape:
+             src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))
+
+             assert (src_patch_size ** 2) * 3 == src_proj_weight.shape[1], 'Unable to interpolate non-square patch size'
+
+             src_proj_weight = rearrange(src_proj_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
+             src_proj_weight = F.interpolate(src_proj_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
+             src_proj_weight = rearrange(src_proj_weight, 'b c h w -> b (c h w)')
+         targ_proj_weight.data.copy_(src_proj_weight)
+
+     def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
+         patches = self.im_to_patches(x)
+         patches = self.embedder(patches)
+         return patches
+
+     def apply_pos_enc(self,
+                       patches: torch.Tensor,
+                       patch_idxs: Optional[torch.Tensor] = None,
+                       input_size: Optional[Tuple[int, int]] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         if not self.abs_pos:
+             # Return a 2-tuple so callers can always unpack (patches, pos_enc).
+             return patches, None
+
+         pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
+
+         if self.training and self.pos_dropout > 0:
+             keeps = torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device) > self.pos_dropout
+             pos_enc_drop = torch.where(keeps, pos_enc, 0)
+         else:
+             pos_enc_drop = pos_enc
+
+         return patches + pos_enc_drop, pos_enc
+
+     def get_pos_enc(self,
+                     batch_size: int,
+                     patch_idxs: Optional[torch.Tensor] = None,
+                     input_size: Optional[Tuple[int, int]] = None,
+     ) -> torch.Tensor:
+         if input_size is None:
+             input_dims = self.input_dims
+         else:
+             input_dims = tuple(d // self.patch_size for d in input_size)
+
+         pos_embed = self._get_pos_embeddings(batch_size, input_dims)
+
+         if patch_idxs is None:
+             return pos_embed
+
+         exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])
+
+         pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs)
+         return pos_embed
+
+
+     def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):
+         if (self.num_rows, self.num_cols) == input_dims:
+             return self.pos_embed
+
+         pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(0, 3, 1, 2)
+
+         def window_select(pos_embed):
+             if input_dims[0] < pos_embed.shape[-2]:
+                 pos_embed = pos_embed[..., :input_dims[0], :]
+             if input_dims[1] < pos_embed.shape[-1]:
+                 pos_embed = pos_embed[..., :, :input_dims[1]]
+             return pos_embed
+
+         if self.cpe_mode:
+             if self.training:
+                 min_scale = math.sqrt(0.1)
+                 scale = torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale) + min_scale
+                 aspect_min = math.log(3 / 4)
+                 aspect_max = -aspect_min
+                 aspect = torch.exp(torch.rand(batch_size, 1, 1, device=pos_embed.device) * (aspect_max - aspect_min) + aspect_min)
+
+                 scale_x = scale * aspect
+                 scale_y = scale * (1 / aspect)
+                 scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
+
+                 pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)
+
+                 lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[None, None].expand(batch_size, input_dims[0], -1)
+                 lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[None, :, None].expand(batch_size, -1, input_dims[1])
+
+                 lin_xy = torch.stack([lin_x, lin_y], dim=-1)
+
+                 grid_xy = lin_xy * scale_xy + pos_xy
+
+                 # Convert to [-1, 1] range
+                 grid_xy.mul_(2).sub_(1)
+
+                 pos_embed = F.grid_sample(
+                     pos_embed.expand(batch_size, -1, -1, -1),
+                     grid=grid_xy,
+                     mode='bilinear',
+                     padding_mode='zeros',
+                     align_corners=True,
+                 )
+             else:
+                 # i_rows, i_cols = input_dims
+                 # p_rows, p_cols = pos_embed.shape[2:]
+                 # if i_rows <= p_rows and i_cols <= p_cols:
+                 #     left = (p_cols - i_cols) // 2
+                 #     top = (p_rows - i_rows) // 2
+                 #     pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]
+                 # else:
+                 max_dim = max(input_dims)
+                 pos_embed = F.interpolate(pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear').to(pos_embed.dtype)
+
+                 pos_embed = window_select(pos_embed)
+         else:
+             pos_embed = window_select(pos_embed)
+
+         if pos_embed.shape[-2:] != input_dims:
+             pos_embed = F.interpolate(pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear').to(pos_embed.dtype)
+
+         pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
+
+         return pos_embed
+
+
+ class Im2Patches(nn.Module):
+     def __init__(self, patch_size: int):
+         super().__init__()
+         self.patch_size = patch_size
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.patch_size == 1:
+             patches = x.flatten(2)
+             patches = patches.permute(0, 2, 1)
+             return patches
+
+         py = x.shape[-2] // self.patch_size
+         px = x.shape[-1] // self.patch_size
+         patches = rearrange(x, 'b c (py yy) (px xx) -> b (py px) (c yy xx)',
+                             py=py, yy=self.patch_size,
+                             px=px, xx=self.patch_size,
+         )
+         return patches
+
+
+ class ViTPatchLinear(nn.Linear):
+     def __init__(self, patch_size: int, embed_dim: int, **factory):
+         super().__init__(
+             3 * (patch_size ** 2),
+             embed_dim,
+             bias=False,
+             **factory
+         )
+         self.patch_size = patch_size
+
+     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+         if self.bias is not None:
+             self.bias.data.copy_(state_dict[f'{prefix}bias'])
+
+         chk_weight = state_dict[f'{prefix}weight']
+         if chk_weight.shape != self.weight.shape:
+             src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))
+
+             assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'
+
+             chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
+             chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
+             chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
+         self.weight.data.copy_(chk_weight)
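
Finally, a sketch of the patch generator's token layout on its own; the 16-pixel patches, 768-dim embedding, and 256×256 input are arbitrary example values, and the import path is again an assumption:

    import torch
    from radio.vit_patch_generator import ViTPatchGenerator  # assumed package layout

    gen = ViTPatchGenerator(
        patch_size=16,
        embed_dim=768,
        input_dims=224,        # nominal training resolution
        max_input_dims=1024,   # larger max size activates CPE mode
        cls_token=True,
        num_cls_tokens=1,
        register_multiple=8,   # pads the CLS token out to 8 tokens with registers
    )
    gen.eval()

    x = torch.rand(2, 3, 256, 256)
    tokens = gen(x)
    # num_skip (CLS + registers) leading tokens, then (256 // 16) ** 2 patch tokens
    print(tokens.shape, gen.num_skip)  # torch.Size([2, 264, 768]) 8
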