fffiloni committed on
Commit
00249c6
1 Parent(s): 9df559f

Upload 7 files

xdecoder/backbone/backbone.py ADDED
@@ -0,0 +1,51 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import torch.nn as nn
+
+ from detectron2.modeling import ShapeSpec
+
+ __all__ = ["Backbone"]
+
+
+ class Backbone(nn.Module):
+     """
+     Abstract base class for network backbones.
+     """
+
+     def __init__(self):
+         """
+         The `__init__` method of any subclass can specify its own set of arguments.
+         """
+         super().__init__()
+
+     def forward(self):
+         """
+         Subclasses must override this method, but adhere to the same return type.
+
+         Returns:
+             dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
+         """
+         pass
+
+     @property
+     def size_divisibility(self) -> int:
+         """
+         Some backbones require the input height and width to be divisible by a
+         specific integer. This is typically true for encoder / decoder type networks
+         with lateral connection (e.g., FPN) for which feature maps need to match
+         dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
+         input size divisibility is required.
+         """
+         return 0
+
+     def output_shape(self):
+         """
+         Returns:
+             dict[str->ShapeSpec]
+         """
+         # this is a backward-compatible default
+         return {
+             name: ShapeSpec(
+                 channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+             )
+             for name in self._out_features
+         }
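For orientation, a minimal sketch (not part of this upload) of how a concrete subclass fills in the attributes that the default output_shape() reads; ToyBackbone, its single stem layer, and its channel/stride values are purely illustrative, and the import path is assumed from the package layout above:

    import torch
    import torch.nn as nn
    # from xdecoder.backbone.backbone import Backbone   # import path assumed from this upload

    class ToyBackbone(Backbone):
        def __init__(self):
            super().__init__()
            # a single stride-4 stem standing in for a real feature extractor
            self.stem = nn.Conv2d(3, 64, kernel_size=4, stride=4)
            # the three attributes consumed by the default output_shape()
            self._out_features = ["res2"]
            self._out_feature_channels = {"res2": 64}
            self._out_feature_strides = {"res2": 4}

        def forward(self, x):
            # return a dict of feature name -> tensor, as the base class requires
            return {"res2": self.stem(x)}

    # ToyBackbone()(torch.randn(1, 3, 64, 64))["res2"].shape -> (1, 64, 16, 16)
    # ToyBackbone().output_shape() -> {"res2": ShapeSpec(channels=64, stride=4)}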
xdecoder/backbone/build.py ADDED
@@ -0,0 +1,11 @@
+ from .registry import model_entrypoints
+ from .registry import is_model
+
+ from .backbone import *
+
+ def build_backbone(config, **kwargs):
+     model_name = config['MODEL']['BACKBONE']['NAME']
+     if not is_model(model_name):
+         raise ValueError(f'Unknown model: {model_name}')
+
+     return model_entrypoints(model_name)(config, **kwargs)
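build_backbone() is a thin dispatcher: it looks up MODEL.BACKBONE.NAME in the registry defined in registry.py (further down) and calls the matching entrypoint with the full config. A minimal sketch, assuming the focal backbone module has been imported so its entrypoint is registered; a complete FOCAL sub-config is sketched after focal.py below:

    cfg = {'MODEL': {'BACKBONE': {'NAME': 'focal'}}}   # real configs also carry the BACKBONE.FOCAL sub-tree
    if is_model(cfg['MODEL']['BACKBONE']['NAME']):
        backbone = build_backbone(cfg)                 # calls model_entrypoints('focal')(cfg)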
xdecoder/backbone/focal.py ADDED
@@ -0,0 +1,692 @@
1
+ # --------------------------------------------------------
2
+ # FocalNet for Semantic Segmentation
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Jianwei Yang
6
+ # --------------------------------------------------------
7
+ import math
8
+ import time
9
+ import numpy as np
10
+ import logging
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
16
+
17
+ from detectron2.utils.file_io import PathManager
18
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
19
+
20
+ from .registry import register_backbone
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ class Mlp(nn.Module):
25
+ """ Multilayer perceptron."""
26
+
27
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
28
+ super().__init__()
29
+ out_features = out_features or in_features
30
+ hidden_features = hidden_features or in_features
31
+ self.fc1 = nn.Linear(in_features, hidden_features)
32
+ self.act = act_layer()
33
+ self.fc2 = nn.Linear(hidden_features, out_features)
34
+ self.drop = nn.Dropout(drop)
35
+
36
+ def forward(self, x):
37
+ x = self.fc1(x)
38
+ x = self.act(x)
39
+ x = self.drop(x)
40
+ x = self.fc2(x)
41
+ x = self.drop(x)
42
+ return x
43
+
44
+ class FocalModulation(nn.Module):
45
+ """ Focal Modulation
46
+
47
+ Args:
48
+ dim (int): Number of input channels.
49
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
50
+ focal_level (int): Number of focal levels
51
+ focal_window (int): Focal window size at focal level 1
52
+ focal_factor (int, default=2): Step to increase the focal window
53
+ use_postln (bool, default=False): Whether to use post-modulation layernorm
54
+ """
55
+
56
+ def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):
57
+
58
+ super().__init__()
59
+ self.dim = dim
60
+
61
+ # specific args for focalv3
62
+ self.focal_level = focal_level
63
+ self.focal_window = focal_window
64
+ self.focal_factor = focal_factor
65
+ self.use_postln_in_modulation = use_postln_in_modulation
66
+ self.scaling_modulator = scaling_modulator
67
+
68
+ self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)
69
+ self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
70
+
71
+ self.act = nn.GELU()
72
+ self.proj = nn.Linear(dim, dim)
73
+ self.proj_drop = nn.Dropout(proj_drop)
74
+ self.focal_layers = nn.ModuleList()
75
+
76
+ if self.use_postln_in_modulation:
77
+ self.ln = nn.LayerNorm(dim)
78
+
79
+ for k in range(self.focal_level):
80
+ kernel_size = self.focal_factor*k + self.focal_window
81
+ self.focal_layers.append(
82
+ nn.Sequential(
83
+ nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim,
84
+ padding=kernel_size//2, bias=False),
85
+ nn.GELU(),
86
+ )
87
+ )
88
+
89
+ def forward(self, x):
90
+ """ Forward function.
91
+
92
+ Args:
93
+ x: input features with shape of (B, H, W, C)
94
+ """
95
+ B, nH, nW, C = x.shape
96
+ x = self.f(x)
97
+ x = x.permute(0, 3, 1, 2).contiguous()
98
+ q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)
99
+
100
+ ctx_all = 0
101
+ for l in range(self.focal_level):
102
+ ctx = self.focal_layers[l](ctx)
103
+ ctx_all = ctx_all + ctx*gates[:, l:l+1]
104
+ ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
105
+ ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]
106
+
107
+ if self.scaling_modulator:
108
+ ctx_all = ctx_all / (self.focal_level + 1)
109
+
110
+ x_out = q * self.h(ctx_all)
111
+ x_out = x_out.permute(0, 2, 3, 1).contiguous()
112
+ if self.use_postln_in_modulation:
113
+ x_out = self.ln(x_out)
114
+ x_out = self.proj(x_out)
115
+ x_out = self.proj_drop(x_out)
116
+ return x_out
117
+
118
+ class FocalModulationBlock(nn.Module):
119
+ """ Focal Modulation Block.
120
+
121
+ Args:
122
+ dim (int): Number of input channels.
123
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
124
+ drop (float, optional): Dropout rate. Default: 0.0
125
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
126
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
127
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
128
+ focal_level (int): number of focal levels
129
+ focal_window (int): focal kernel size at level 1
130
+ """
131
+
132
+ def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
133
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm,
134
+ focal_level=2, focal_window=9,
135
+ use_postln=False, use_postln_in_modulation=False,
136
+ scaling_modulator=False,
137
+ use_layerscale=False,
138
+ layerscale_value=1e-4):
139
+ super().__init__()
140
+ self.dim = dim
141
+ self.mlp_ratio = mlp_ratio
142
+ self.focal_window = focal_window
143
+ self.focal_level = focal_level
144
+ self.use_postln = use_postln
145
+ self.use_layerscale = use_layerscale
146
+
147
+ self.norm1 = norm_layer(dim)
148
+ self.modulation = FocalModulation(
149
+ dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator
150
+ )
151
+
152
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
153
+ self.norm2 = norm_layer(dim)
154
+ mlp_hidden_dim = int(dim * mlp_ratio)
155
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
156
+
157
+ self.H = None
158
+ self.W = None
159
+
160
+ self.gamma_1 = 1.0
161
+ self.gamma_2 = 1.0
162
+ if self.use_layerscale:
163
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
164
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
165
+
166
+ def forward(self, x):
167
+ """ Forward function.
168
+
169
+ Args:
170
+ x: Input feature, tensor size (B, H*W, C).
171
+ H, W: Spatial resolution of the input feature.
172
+ """
173
+ B, L, C = x.shape
174
+ H, W = self.H, self.W
175
+ assert L == H * W, "input feature has wrong size"
176
+
177
+ shortcut = x
178
+ if not self.use_postln:
179
+ x = self.norm1(x)
180
+ x = x.view(B, H, W, C)
181
+
182
+ # FM
183
+ x = self.modulation(x).view(B, H * W, C)
184
+ if self.use_postln:
185
+ x = self.norm1(x)
186
+
187
+ # FFN
188
+ x = shortcut + self.drop_path(self.gamma_1 * x)
189
+
190
+ if self.use_postln:
191
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
192
+ else:
193
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
194
+
195
+ return x
196
+
197
+ class BasicLayer(nn.Module):
198
+ """ A basic focal modulation layer for one stage.
199
+
200
+ Args:
201
+ dim (int): Number of feature channels
202
+ depth (int): Depths of this stage.
203
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
204
+ drop (float, optional): Dropout rate. Default: 0.0
205
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
206
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
207
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
208
+ focal_level (int): Number of focal levels
209
+ focal_window (int): Focal window size at focal level 1
210
+ use_conv_embed (bool): Use overlapped convolution for patch embedding or not. Default: False
211
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
212
+ """
213
+
214
+ def __init__(self,
215
+ dim,
216
+ depth,
217
+ mlp_ratio=4.,
218
+ drop=0.,
219
+ drop_path=0.,
220
+ norm_layer=nn.LayerNorm,
221
+ downsample=None,
222
+ focal_window=9,
223
+ focal_level=2,
224
+ use_conv_embed=False,
225
+ use_postln=False,
226
+ use_postln_in_modulation=False,
227
+ scaling_modulator=False,
228
+ use_layerscale=False,
229
+ use_checkpoint=False
230
+ ):
231
+ super().__init__()
232
+ self.depth = depth
233
+ self.use_checkpoint = use_checkpoint
234
+
235
+ # build blocks
236
+ self.blocks = nn.ModuleList([
237
+ FocalModulationBlock(
238
+ dim=dim,
239
+ mlp_ratio=mlp_ratio,
240
+ drop=drop,
241
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
242
+ focal_window=focal_window,
243
+ focal_level=focal_level,
244
+ use_postln=use_postln,
245
+ use_postln_in_modulation=use_postln_in_modulation,
246
+ scaling_modulator=scaling_modulator,
247
+ use_layerscale=use_layerscale,
248
+ norm_layer=norm_layer)
249
+ for i in range(depth)])
250
+
251
+ # patch merging layer
252
+ if downsample is not None:
253
+ self.downsample = downsample(
254
+ patch_size=2,
255
+ in_chans=dim, embed_dim=2*dim,
256
+ use_conv_embed=use_conv_embed,
257
+ norm_layer=norm_layer,
258
+ is_stem=False
259
+ )
260
+
261
+ else:
262
+ self.downsample = None
263
+
264
+ def forward(self, x, H, W):
265
+ """ Forward function.
266
+
267
+ Args:
268
+ x: Input feature, tensor size (B, H*W, C).
269
+ H, W: Spatial resolution of the input feature.
270
+ """
271
+ for blk in self.blocks:
272
+ blk.H, blk.W = H, W
273
+ if self.use_checkpoint:
274
+ x = checkpoint.checkpoint(blk, x)
275
+ else:
276
+ x = blk(x)
277
+ if self.downsample is not None:
278
+ x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
279
+ x_down = self.downsample(x_reshaped)
280
+ x_down = x_down.flatten(2).transpose(1, 2)
281
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
282
+ return x, H, W, x_down, Wh, Ww
283
+ else:
284
+ return x, H, W, x, H, W
285
+
286
+
287
+ class PatchEmbed(nn.Module):
288
+ """ Image to Patch Embedding
289
+
290
+ Args:
291
+ patch_size (int): Patch token size. Default: 4.
292
+ in_chans (int): Number of input image channels. Default: 3.
293
+ embed_dim (int): Number of linear projection output channels. Default: 96.
294
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
295
+ use_conv_embed (bool): Whether to use overlapped convolution for patch embedding. Default: False
296
+ is_stem (bool): Is the stem block or not.
297
+ """
298
+
299
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False):
300
+ super().__init__()
301
+ patch_size = to_2tuple(patch_size)
302
+ self.patch_size = patch_size
303
+
304
+ self.in_chans = in_chans
305
+ self.embed_dim = embed_dim
306
+
307
+ if use_conv_embed:
308
+ # if we choose to use conv embedding, then we treat the stem and non-stem differently
309
+ if is_stem:
310
+ kernel_size = 7; padding = 2; stride = 4
311
+ else:
312
+ kernel_size = 3; padding = 1; stride = 2
313
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
314
+ else:
315
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
316
+
317
+ if norm_layer is not None:
318
+ self.norm = norm_layer(embed_dim)
319
+ else:
320
+ self.norm = None
321
+
322
+ def forward(self, x):
323
+ """Forward function."""
324
+ _, _, H, W = x.size()
325
+ if W % self.patch_size[1] != 0:
326
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
327
+ if H % self.patch_size[0] != 0:
328
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
329
+
330
+ x = self.proj(x) # B C Wh Ww
331
+ if self.norm is not None:
332
+ Wh, Ww = x.size(2), x.size(3)
333
+ x = x.flatten(2).transpose(1, 2)
334
+ x = self.norm(x)
335
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
336
+
337
+ return x
338
+
339
+
340
+ class FocalNet(nn.Module):
341
+ """ FocalNet backbone.
342
+
343
+ Args:
344
+ pretrain_img_size (int): Input image size for training the pretrained model,
345
+ used in absolute position embedding. Default 224.
346
+ patch_size (int | tuple(int)): Patch size. Default: 4.
347
+ in_chans (int): Number of input image channels. Default: 3.
348
+ embed_dim (int): Number of linear projection output channels. Default: 96.
349
+ depths (tuple[int]): Depths of each FocalNet stage.
350
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
351
+ drop_rate (float): Dropout rate.
352
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
353
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
354
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
355
+ out_indices (Sequence[int]): Output from which stages.
356
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
357
+ -1 means not freezing any parameters.
358
+ focal_levels (Sequence[int]): Number of focal levels at four stages
359
+ focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
360
+ use_conv_embed (bool): Whether to use overlapped convolution for patch embedding
361
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
362
+ """
363
+
364
+ def __init__(self,
365
+ pretrain_img_size=1600,
366
+ patch_size=4,
367
+ in_chans=3,
368
+ embed_dim=96,
369
+ depths=[2, 2, 6, 2],
370
+ mlp_ratio=4.,
371
+ drop_rate=0.,
372
+ drop_path_rate=0.2,
373
+ norm_layer=nn.LayerNorm,
374
+ patch_norm=True,
375
+ out_indices=[0, 1, 2, 3],
376
+ frozen_stages=-1,
377
+ focal_levels=[2,2,2,2],
378
+ focal_windows=[9,9,9,9],
379
+ use_conv_embed=False,
380
+ use_postln=False,
381
+ use_postln_in_modulation=False,
382
+ scaling_modulator=False,
383
+ use_layerscale=False,
384
+ use_checkpoint=False,
385
+ ):
386
+ super().__init__()
387
+
388
+ self.pretrain_img_size = pretrain_img_size
389
+ self.num_layers = len(depths)
390
+ self.embed_dim = embed_dim
391
+ self.patch_norm = patch_norm
392
+ self.out_indices = out_indices
393
+ self.frozen_stages = frozen_stages
394
+
395
+ # split image into non-overlapping patches
396
+ self.patch_embed = PatchEmbed(
397
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
398
+ norm_layer=norm_layer if self.patch_norm else None,
399
+ use_conv_embed=use_conv_embed, is_stem=True)
400
+
401
+ self.pos_drop = nn.Dropout(p=drop_rate)
402
+
403
+ # stochastic depth
404
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
405
+
406
+ # build layers
407
+ self.layers = nn.ModuleList()
408
+ for i_layer in range(self.num_layers):
409
+ layer = BasicLayer(
410
+ dim=int(embed_dim * 2 ** i_layer),
411
+ depth=depths[i_layer],
412
+ mlp_ratio=mlp_ratio,
413
+ drop=drop_rate,
414
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
415
+ norm_layer=norm_layer,
416
+ downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
417
+ focal_window=focal_windows[i_layer],
418
+ focal_level=focal_levels[i_layer],
419
+ use_conv_embed=use_conv_embed,
420
+ use_postln=use_postln,
421
+ use_postln_in_modulation=use_postln_in_modulation,
422
+ scaling_modulator=scaling_modulator,
423
+ use_layerscale=use_layerscale,
424
+ use_checkpoint=use_checkpoint)
425
+ self.layers.append(layer)
426
+
427
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
428
+ self.num_features = num_features
429
+
430
+ # add a norm layer for each output
431
+ for i_layer in out_indices:
432
+ layer = norm_layer(num_features[i_layer])
433
+ layer_name = f'norm{i_layer}'
434
+ self.add_module(layer_name, layer)
435
+
436
+ self._freeze_stages()
437
+
438
+ def _freeze_stages(self):
439
+ if self.frozen_stages >= 0:
440
+ self.patch_embed.eval()
441
+ for param in self.patch_embed.parameters():
442
+ param.requires_grad = False
443
+
444
+ if self.frozen_stages >= 2:
445
+ self.pos_drop.eval()
446
+ for i in range(0, self.frozen_stages - 1):
447
+ m = self.layers[i]
448
+ m.eval()
449
+ for param in m.parameters():
450
+ param.requires_grad = False
451
+
452
+ def init_weights(self, pretrained=None):
453
+ """Initialize the weights in backbone.
454
+
455
+ Args:
456
+ pretrained (str, optional): Path to pre-trained weights.
457
+ Defaults to None.
458
+ """
459
+
460
+ def _init_weights(m):
461
+ if isinstance(m, nn.Linear):
462
+ trunc_normal_(m.weight, std=.02)
463
+ if isinstance(m, nn.Linear) and m.bias is not None:
464
+ nn.init.constant_(m.bias, 0)
465
+ elif isinstance(m, nn.LayerNorm):
466
+ nn.init.constant_(m.bias, 0)
467
+ nn.init.constant_(m.weight, 1.0)
468
+
469
+ if isinstance(pretrained, str):
470
+ self.apply(_init_weights)
471
+ logger = get_root_logger()
472
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
473
+ elif pretrained is None:
474
+ self.apply(_init_weights)
475
+ else:
476
+ raise TypeError('pretrained must be a str or None')
477
+
478
+ def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
479
+ model_dict = self.state_dict()
480
+
481
+ missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]
482
+ logger.info(f'=> Missed keys {missed_dict}')
483
+ unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]
484
+ logger.info(f'=> Unexpected keys {unexpected_dict}')
485
+
486
+ pretrained_dict = {
487
+ k: v for k, v in pretrained_dict.items()
488
+ if k in model_dict.keys()
489
+ }
490
+
491
+ need_init_state_dict = {}
492
+ for k, v in pretrained_dict.items():
493
+ need_init = (
494
+ (
495
+ k.split('.')[0] in pretrained_layers
496
+ or pretrained_layers[0] == '*'
497
+ )
498
+ and 'relative_position_index' not in k
499
+ and 'attn_mask' not in k
500
+ )
501
+
502
+ if need_init:
503
+ # if verbose:
504
+ # logger.info(f'=> init {k} from {pretrained}')
505
+
506
+ if (('pool_layers' in k) or ('focal_layers' in k)) and v.size() != model_dict[k].size():
507
+ table_pretrained = v
508
+ table_current = model_dict[k]
509
+ fsize1 = table_pretrained.shape[2]
510
+ fsize2 = table_current.shape[2]
511
+
512
+ # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv
513
+ if fsize1 < fsize2:
514
+ table_pretrained_resized = torch.zeros(table_current.shape)
515
+ table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained
516
+ v = table_pretrained_resized
517
+ elif fsize1 > fsize2:
518
+ table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2]
519
+ v = table_pretrained_resized
520
+
521
+
522
+ if ("modulation.f" in k or "pre_conv" in k):
523
+ table_pretrained = v
524
+ table_current = model_dict[k]
525
+ if table_pretrained.shape != table_current.shape:
526
+ if len(table_pretrained.shape) == 2:
527
+ dim = table_pretrained.shape[1]
528
+ assert table_current.shape[1] == dim
529
+ L1 = table_pretrained.shape[0]
530
+ L2 = table_current.shape[0]
531
+
532
+ if L1 < L2:
533
+ table_pretrained_resized = torch.zeros(table_current.shape)
534
+ # copy for linear project
535
+ table_pretrained_resized[:2*dim] = table_pretrained[:2*dim]
536
+ # copy for global token gating
537
+ table_pretrained_resized[-1] = table_pretrained[-1]
538
+ # copy for first multiple focal levels
539
+ table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]
540
+ # reassign pretrained weights
541
+ v = table_pretrained_resized
542
+ elif L1 > L2:
543
+ raise NotImplementedError
544
+ elif len(table_pretrained.shape) == 1:
545
+ dim = table_pretrained.shape[0]
546
+ L1 = table_pretrained.shape[0]
547
+ L2 = table_current.shape[0]
548
+ if L1 < L2:
549
+ table_pretrained_resized = torch.zeros(table_current.shape)
550
+ # copy for linear project
551
+ table_pretrained_resized[:dim] = table_pretrained[:dim]
552
+ # copy for global token gating
553
+ table_pretrained_resized[-1] = table_pretrained[-1]
554
+ # copy for first multiple focal levels
555
+ # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]
556
+ # reassign pretrained weights
557
+ v = table_pretrained_resized
558
+ elif L1 > L2:
559
+ raise NotImplementedError
560
+
561
+ need_init_state_dict[k] = v
562
+
563
+ self.load_state_dict(need_init_state_dict, strict=False)
564
+
565
+
566
+ def forward(self, x):
567
+ """Forward function."""
568
+ tic = time.time()
569
+ x = self.patch_embed(x)
570
+ Wh, Ww = x.size(2), x.size(3)
571
+
572
+ x = x.flatten(2).transpose(1, 2)
573
+ x = self.pos_drop(x)
574
+
575
+ outs = {}
576
+ for i in range(self.num_layers):
577
+ layer = self.layers[i]
578
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
579
+ if i in self.out_indices:
580
+ norm_layer = getattr(self, f'norm{i}')
581
+ x_out = norm_layer(x_out)
582
+
583
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
584
+ outs["res{}".format(i + 2)] = out
585
+
586
+ if len(self.out_indices) == 0:
587
+ outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
588
+
589
+ toc = time.time()
590
+ return outs
591
+
592
+ def train(self, mode=True):
593
+ """Convert the model into training mode while keep layers freezed."""
594
+ super(FocalNet, self).train(mode)
595
+ self._freeze_stages()
596
+
597
+
598
+ class D2FocalNet(FocalNet, Backbone):
599
+ def __init__(self, cfg, input_shape):
600
+
601
+ pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE']
602
+ patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE']
603
+ in_chans = 3
604
+ embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM']
605
+ depths = cfg['BACKBONE']['FOCAL']['DEPTHS']
606
+ mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO']
607
+ drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE']
608
+ drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE']
609
+ norm_layer = nn.LayerNorm
610
+ patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM']
611
+ use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT']
612
+ out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES']
613
+ scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False)
614
+
615
+ super().__init__(
616
+ pretrain_img_size,
617
+ patch_size,
618
+ in_chans,
619
+ embed_dim,
620
+ depths,
621
+ mlp_ratio,
622
+ drop_rate,
623
+ drop_path_rate,
624
+ norm_layer,
625
+ patch_norm,
626
+ out_indices,
627
+ focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'],
628
+ focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'],
629
+ use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'],
630
+ use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'],
631
+ use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'],
632
+ scaling_modulator=scaling_modulator,
633
+ use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'],
634
+ use_checkpoint=use_checkpoint,
635
+ )
636
+
637
+ self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES']
638
+
639
+ self._out_feature_strides = {
640
+ "res2": 4,
641
+ "res3": 8,
642
+ "res4": 16,
643
+ "res5": 32,
644
+ }
645
+ self._out_feature_channels = {
646
+ "res2": self.num_features[0],
647
+ "res3": self.num_features[1],
648
+ "res4": self.num_features[2],
649
+ "res5": self.num_features[3],
650
+ }
651
+
652
+ def forward(self, x):
653
+ """
654
+ Args:
655
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
656
+ Returns:
657
+ dict[str->Tensor]: names and the corresponding features
658
+ """
659
+ assert (
660
+ x.dim() == 4
661
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
662
+ outputs = {}
663
+ y = super().forward(x)
664
+ for k in y.keys():
665
+ if k in self._out_features:
666
+ outputs[k] = y[k]
667
+ return outputs
668
+
669
+ def output_shape(self):
670
+ return {
671
+ name: ShapeSpec(
672
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
673
+ )
674
+ for name in self._out_features
675
+ }
676
+
677
+ @property
678
+ def size_divisibility(self):
679
+ return 32
680
+
681
+ @register_backbone
682
+ def get_focal_backbone(cfg):
683
+ focal = D2FocalNet(cfg['MODEL'], 224)
684
+
685
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
686
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
687
+ logger.info(f'=> init from {filename}')
688
+ with PathManager.open(filename, "rb") as f:
689
+ ckpt = torch.load(f)['model']
690
+ focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])
691
+
692
+ return focal
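A hedged end-to-end sketch of driving this file through build.py: every key below is one that build.py, D2FocalNet, or get_focal_backbone actually reads, but the values are only illustrative (FocalNet-Tiny-like) and are not a config shipped with this commit; the import paths are likewise assumed from the package layout:

    import torch
    # from xdecoder.backbone.build import build_backbone   # import path assumed

    config = {
        'VERBOSE': False,                        # only consulted when LOAD_PRETRAINED is True
        'MODEL': {
            'BACKBONE': {
                'NAME': 'focal',                 # registry key = defining module name (focal.py)
                'LOAD_PRETRAINED': False,        # True additionally needs 'PRETRAINED': <ckpt path>
                'FOCAL': {
                    'PRETRAIN_IMG_SIZE': 224,
                    'PATCH_SIZE': 4,
                    'EMBED_DIM': 96,
                    'DEPTHS': [2, 2, 6, 2],
                    'MLP_RATIO': 4.0,
                    'DROP_RATE': 0.0,
                    'DROP_PATH_RATE': 0.2,
                    'PATCH_NORM': True,
                    'USE_CHECKPOINT': False,
                    'OUT_INDICES': [0, 1, 2, 3],
                    'FOCAL_LEVELS': [2, 2, 2, 2],
                    'FOCAL_WINDOWS': [9, 9, 9, 9],
                    'USE_CONV_EMBED': False,
                    'USE_POSTLN': False,
                    'USE_POSTLN_IN_MODULATION': False,
                    'USE_LAYERSCALE': False,
                    'OUT_FEATURES': ['res2', 'res3', 'res4', 'res5'],
                },
            },
        },
    }

    backbone = get_focal_backbone(config)          # equivalently build_backbone(config) once registered
    feats = backbone(torch.randn(1, 3, 224, 224))  # H and W must be multiples of size_divisibility (32)
    # feats['res2']..feats['res5']: strides 4/8/16/32, channels 96/192/384/768 for EMBED_DIM=96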
xdecoder/backbone/focal_dw.py ADDED
@@ -0,0 +1,789 @@
1
+ # --------------------------------------------------------
2
+ # FocalNet for Semantic Segmentation
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Jianwei Yang
6
+ # --------------------------------------------------------
7
+ import math
8
+ import time
9
+ import numpy as np
10
+ import logging
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
16
+
17
+ from detectron2.utils.file_io import PathManager
18
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
19
+
20
+ from .registry import register_backbone
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ class Mlp(nn.Module):
25
+ """ Multilayer perceptron."""
26
+
27
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
28
+ super().__init__()
29
+ out_features = out_features or in_features
30
+ hidden_features = hidden_features or in_features
31
+ self.fc1 = nn.Linear(in_features, hidden_features)
32
+ self.act = act_layer()
33
+ self.fc2 = nn.Linear(hidden_features, out_features)
34
+ self.drop = nn.Dropout(drop)
35
+
36
+ def forward(self, x):
37
+ x = self.fc1(x)
38
+ x = self.act(x)
39
+ x = self.drop(x)
40
+ x = self.fc2(x)
41
+ x = self.drop(x)
42
+ return x
43
+
44
+ class FocalModulation(nn.Module):
45
+ """ Focal Modulation
46
+
47
+ Args:
48
+ dim (int): Number of input channels.
49
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
50
+ focal_level (int): Number of focal levels
51
+ focal_window (int): Focal window size at focal level 1
52
+ focal_factor (int, default=2): Step to increase the focal window
53
+ use_postln (bool, default=False): Whether to use post-modulation layernorm
54
+ """
55
+
56
+ def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):
57
+
58
+ super().__init__()
59
+ self.dim = dim
60
+
61
+ # specific args for focalv3
62
+ self.focal_level = focal_level
63
+ self.focal_window = focal_window
64
+ self.focal_factor = focal_factor
65
+ self.use_postln_in_modulation = use_postln_in_modulation
66
+ self.scaling_modulator = scaling_modulator
67
+
68
+ self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)
69
+ self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
70
+
71
+ self.act = nn.GELU()
72
+ self.proj = nn.Linear(dim, dim)
73
+ self.proj_drop = nn.Dropout(proj_drop)
74
+ self.focal_layers = nn.ModuleList()
75
+
76
+ if self.use_postln_in_modulation:
77
+ self.ln = nn.LayerNorm(dim)
78
+
79
+ for k in range(self.focal_level):
80
+ kernel_size = self.focal_factor*k + self.focal_window
81
+ self.focal_layers.append(
82
+ nn.Sequential(
83
+ nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim,
84
+ padding=kernel_size//2, bias=False),
85
+ nn.GELU(),
86
+ )
87
+ )
88
+
89
+ def forward(self, x):
90
+ """ Forward function.
91
+
92
+ Args:
93
+ x: input features with shape of (B, H, W, C)
94
+ """
95
+ B, nH, nW, C = x.shape
96
+ x = self.f(x)
97
+ x = x.permute(0, 3, 1, 2).contiguous()
98
+ q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)
99
+
100
+ ctx_all = 0
101
+ for l in range(self.focal_level):
102
+ ctx = self.focal_layers[l](ctx)
103
+ ctx_all = ctx_all + ctx*gates[:, l:l+1]
104
+ ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
105
+ ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]
106
+
107
+ if self.scaling_modulator:
108
+ ctx_all = ctx_all / (self.focal_level + 1)
109
+
110
+ x_out = q * self.h(ctx_all)
111
+ x_out = x_out.permute(0, 2, 3, 1).contiguous()
112
+ if self.use_postln_in_modulation:
113
+ x_out = self.ln(x_out)
114
+ x_out = self.proj(x_out)
115
+ x_out = self.proj_drop(x_out)
116
+ return x_out
117
+
118
+ class FocalModulationBlock(nn.Module):
119
+ """ Focal Modulation Block.
120
+
121
+ Args:
122
+ dim (int): Number of input channels.
123
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
124
+ drop (float, optional): Dropout rate. Default: 0.0
125
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
126
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
127
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
128
+ focal_level (int): number of focal levels
129
+ focal_window (int): focal kernel size at level 1
130
+ """
131
+
132
+ def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
133
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm,
134
+ focal_level=2, focal_window=9,
135
+ use_postln=False, use_postln_in_modulation=False,
136
+ scaling_modulator=False,
137
+ use_layerscale=False,
138
+ layerscale_value=1e-4):
139
+ super().__init__()
140
+ self.dim = dim
141
+ self.mlp_ratio = mlp_ratio
142
+ self.focal_window = focal_window
143
+ self.focal_level = focal_level
144
+ self.use_postln = use_postln
145
+ self.use_layerscale = use_layerscale
146
+
147
+ self.dw1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
148
+ self.norm1 = norm_layer(dim)
149
+ self.modulation = FocalModulation(
150
+ dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator
151
+ )
152
+
153
+ self.dw2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
154
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
155
+ self.norm2 = norm_layer(dim)
156
+ mlp_hidden_dim = int(dim * mlp_ratio)
157
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
158
+
159
+ self.H = None
160
+ self.W = None
161
+
162
+ self.gamma_1 = 1.0
163
+ self.gamma_2 = 1.0
164
+ if self.use_layerscale:
165
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
166
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
167
+
168
+ def forward(self, x):
169
+ """ Forward function.
170
+
171
+ Args:
172
+ x: Input feature, tensor size (B, H*W, C).
173
+ H, W: Spatial resolution of the input feature.
174
+ """
175
+ B, L, C = x.shape
176
+ H, W = self.H, self.W
177
+ assert L == H * W, "input feature has wrong size"
178
+
179
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
180
+ x = x + self.dw1(x)
181
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)
182
+
183
+ shortcut = x
184
+ if not self.use_postln:
185
+ x = self.norm1(x)
186
+ x = x.view(B, H, W, C)
187
+
188
+ # FM
189
+ x = self.modulation(x).view(B, H * W, C)
190
+ x = shortcut + self.drop_path(self.gamma_1 * x)
191
+ if self.use_postln:
192
+ x = self.norm1(x)
193
+
194
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
195
+ x = x + self.dw2(x)
196
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)
197
+
198
+ if not self.use_postln:
199
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
200
+ else:
201
+ x = x + self.drop_path(self.gamma_2 * self.mlp(x))
202
+ x = self.norm2(x)
203
+
204
+ return x
205
+
206
+ class BasicLayer(nn.Module):
207
+ """ A basic focal modulation layer for one stage.
208
+
209
+ Args:
210
+ dim (int): Number of feature channels
211
+ depth (int): Depths of this stage.
212
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
213
+ drop (float, optional): Dropout rate. Default: 0.0
214
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
215
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
216
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
217
+ focal_level (int): Number of focal levels
218
+ focal_window (int): Focal window size at focal level 1
219
+ use_conv_embed (bool): Use overlapped convolution for patch embedding or not. Default: False
220
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
221
+ """
222
+
223
+ def __init__(self,
224
+ dim,
225
+ depth,
226
+ mlp_ratio=4.,
227
+ drop=0.,
228
+ drop_path=0.,
229
+ norm_layer=nn.LayerNorm,
230
+ downsample=None,
231
+ focal_window=9,
232
+ focal_level=2,
233
+ use_conv_embed=False,
234
+ use_postln=False,
235
+ use_postln_in_modulation=False,
236
+ scaling_modulator=False,
237
+ use_layerscale=False,
238
+ use_checkpoint=False,
239
+ use_pre_norm=False,
240
+ ):
241
+ super().__init__()
242
+ self.depth = depth
243
+ self.use_checkpoint = use_checkpoint
244
+
245
+ # build blocks
246
+ self.blocks = nn.ModuleList([
247
+ FocalModulationBlock(
248
+ dim=dim,
249
+ mlp_ratio=mlp_ratio,
250
+ drop=drop,
251
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
252
+ focal_window=focal_window,
253
+ focal_level=focal_level,
254
+ use_postln=use_postln,
255
+ use_postln_in_modulation=use_postln_in_modulation,
256
+ scaling_modulator=scaling_modulator,
257
+ use_layerscale=use_layerscale,
258
+ norm_layer=norm_layer)
259
+ for i in range(depth)])
260
+
261
+ # patch merging layer
262
+ if downsample is not None:
263
+ self.downsample = downsample(
264
+ patch_size=2,
265
+ in_chans=dim, embed_dim=2*dim,
266
+ use_conv_embed=use_conv_embed,
267
+ norm_layer=norm_layer,
268
+ is_stem=False,
269
+ use_pre_norm=use_pre_norm
270
+ )
271
+
272
+ else:
273
+ self.downsample = None
274
+
275
+ def forward(self, x, H, W):
276
+ """ Forward function.
277
+
278
+ Args:
279
+ x: Input feature, tensor size (B, H*W, C).
280
+ H, W: Spatial resolution of the input feature.
281
+ """
282
+ for blk in self.blocks:
283
+ blk.H, blk.W = H, W
284
+ if self.use_checkpoint:
285
+ x = checkpoint.checkpoint(blk, x)
286
+ else:
287
+ x = blk(x)
288
+ if self.downsample is not None:
289
+ x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
290
+ x_down = self.downsample(x_reshaped)
291
+ x_down = x_down.flatten(2).transpose(1, 2)
292
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
293
+ return x, H, W, x_down, Wh, Ww
294
+ else:
295
+ return x, H, W, x, H, W
296
+
297
+
298
+ # class PatchEmbed(nn.Module):
299
+ # r""" Image to Patch Embedding
300
+
301
+ # Args:
302
+ # img_size (int): Image size. Default: 224.
303
+ # patch_size (int): Patch token size. Default: 4.
304
+ # in_chans (int): Number of input image channels. Default: 3.
305
+ # embed_dim (int): Number of linear projection output channels. Default: 96.
306
+ # norm_layer (nn.Module, optional): Normalization layer. Default: None
307
+ # """
308
+
309
+ # def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96,
310
+ # use_conv_embed=False, norm_layer=None, is_stem=False, use_pre_norm=False):
311
+ # super().__init__()
312
+ # patch_size = to_2tuple(patch_size)
313
+ # patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
314
+ # self.img_size = img_size
315
+ # self.patch_size = patch_size
316
+ # self.patches_resolution = patches_resolution
317
+ # self.num_patches = patches_resolution[0] * patches_resolution[1]
318
+
319
+ # self.in_chans = in_chans
320
+ # self.embed_dim = embed_dim
321
+ # self.use_pre_norm = use_pre_norm
322
+
323
+ # if use_conv_embed:
324
+ # # if we choose to use conv embedding, then we treat the stem and non-stem differently
325
+ # if is_stem:
326
+ # kernel_size = 7; padding = 3; stride = 4
327
+ # else:
328
+ # kernel_size = 3; padding = 1; stride = 2
329
+ # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
330
+ # else:
331
+ # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
332
+
333
+ # if self.use_pre_norm:
334
+ # if norm_layer is not None:
335
+ # self.norm = norm_layer(in_chans)
336
+ # else:
337
+ # self.norm = None
338
+ # else:
339
+ # if norm_layer is not None:
340
+ # self.norm = norm_layer(embed_dim)
341
+ # else:
342
+ # self.norm = None
343
+
344
+ # def forward(self, x):
345
+ # B, C, H, W = x.shape
346
+ # # FIXME look at relaxing size constraints
347
+ # assert H == self.img_size[0] and W == self.img_size[1], \
348
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
349
+
350
+ # if self.use_pre_norm:
351
+ # if self.norm is not None:
352
+ # x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
353
+ # x = self.norm(x).transpose(1, 2).view(B, C, H, W)
354
+ # x = self.proj(x).flatten(2).transpose(1, 2)
355
+ # else:
356
+ # x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
357
+ # if self.norm is not None:
358
+ # x = self.norm(x)
359
+ # return x
360
+
361
+ # def flops(self):
362
+ # Ho, Wo = self.patches_resolution
363
+ # flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
364
+ # if self.norm is not None:
365
+ # flops += Ho * Wo * self.embed_dim
366
+ # return flops
367
+
368
+ class PatchEmbed(nn.Module):
369
+ """ Image to Patch Embedding
370
+
371
+ Args:
372
+ patch_size (int): Patch token size. Default: 4.
373
+ in_chans (int): Number of input image channels. Default: 3.
374
+ embed_dim (int): Number of linear projection output channels. Default: 96.
375
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
376
+ use_conv_embed (bool): Whether to use overlapped convolution for patch embedding. Default: False
377
+ is_stem (bool): Is the stem block or not.
378
+ """
379
+
380
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False, use_pre_norm=False):
381
+ super().__init__()
382
+ patch_size = to_2tuple(patch_size)
383
+ self.patch_size = patch_size
384
+
385
+ self.in_chans = in_chans
386
+ self.embed_dim = embed_dim
387
+ self.use_pre_norm = use_pre_norm
388
+
389
+ if use_conv_embed:
390
+ # if we choose to use conv embedding, then we treat the stem and non-stem differently
391
+ if is_stem:
392
+ kernel_size = 7; padding = 3; stride = 4
393
+ else:
394
+ kernel_size = 3; padding = 1; stride = 2
395
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
396
+ else:
397
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
398
+
399
+ if self.use_pre_norm:
400
+ if norm_layer is not None:
401
+ self.norm = norm_layer(in_chans)
402
+ else:
403
+ self.norm = None
404
+ else:
405
+ if norm_layer is not None:
406
+ self.norm = norm_layer(embed_dim)
407
+ else:
408
+ self.norm = None
409
+
410
+ def forward(self, x):
411
+ """Forward function."""
412
+ B, C, H, W = x.size()
413
+ if W % self.patch_size[1] != 0:
414
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
415
+ if H % self.patch_size[0] != 0:
416
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
417
+
418
+ if self.use_pre_norm:
419
+ if self.norm is not None:
420
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
421
+ x = self.norm(x).transpose(1, 2).view(B, C, H, W)
422
+ x = self.proj(x)
423
+ else:
424
+ x = self.proj(x) # B C Wh Ww
425
+ if self.norm is not None:
426
+ Wh, Ww = x.size(2), x.size(3)
427
+ x = x.flatten(2).transpose(1, 2)
428
+ x = self.norm(x)
429
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
430
+
431
+ return x
432
+
433
+
434
+ class FocalNet(nn.Module):
435
+ """ FocalNet backbone.
436
+
437
+ Args:
438
+ pretrain_img_size (int): Input image size for training the pretrained model,
439
+ used in absolute position embedding. Default 224.
440
+ patch_size (int | tuple(int)): Patch size. Default: 4.
441
+ in_chans (int): Number of input image channels. Default: 3.
442
+ embed_dim (int): Number of linear projection output channels. Default: 96.
443
+ depths (tuple[int]): Depths of each FocalNet stage.
444
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
445
+ drop_rate (float): Dropout rate.
446
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
447
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
448
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
449
+ out_indices (Sequence[int]): Output from which stages.
450
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
451
+ -1 means not freezing any parameters.
452
+ focal_levels (Sequence[int]): Number of focal levels at four stages
453
+ focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
454
+ use_conv_embed (bool): Whether to use overlapped convolution for patch embedding
455
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
456
+ """
457
+
458
+ def __init__(self,
459
+ pretrain_img_size=1600,
460
+ patch_size=4,
461
+ in_chans=3,
462
+ embed_dim=96,
463
+ depths=[2, 2, 6, 2],
464
+ mlp_ratio=4.,
465
+ drop_rate=0.,
466
+ drop_path_rate=0.2,
467
+ norm_layer=nn.LayerNorm,
468
+ patch_norm=True,
469
+ out_indices=[0, 1, 2, 3],
470
+ frozen_stages=-1,
471
+ focal_levels=[2,2,2,2],
472
+ focal_windows=[9,9,9,9],
473
+ use_pre_norms=[False, False, False, False],
474
+ use_conv_embed=False,
475
+ use_postln=False,
476
+ use_postln_in_modulation=False,
477
+ scaling_modulator=False,
478
+ use_layerscale=False,
479
+ use_checkpoint=False,
480
+ ):
481
+ super().__init__()
482
+
483
+ self.pretrain_img_size = pretrain_img_size
484
+ self.num_layers = len(depths)
485
+ self.embed_dim = embed_dim
486
+ self.patch_norm = patch_norm
487
+ self.out_indices = out_indices
488
+ self.frozen_stages = frozen_stages
489
+
490
+ # split image into non-overlapping patches
491
+ self.patch_embed = PatchEmbed(
492
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
493
+ norm_layer=norm_layer if self.patch_norm else None,
494
+ use_conv_embed=use_conv_embed, is_stem=True, use_pre_norm=False)
495
+
496
+ self.pos_drop = nn.Dropout(p=drop_rate)
497
+
498
+ # stochastic depth
499
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
500
+
501
+ # build layers
502
+ self.layers = nn.ModuleList()
503
+ for i_layer in range(self.num_layers):
504
+ layer = BasicLayer(
505
+ dim=int(embed_dim * 2 ** i_layer),
506
+ depth=depths[i_layer],
507
+ mlp_ratio=mlp_ratio,
508
+ drop=drop_rate,
509
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
510
+ norm_layer=norm_layer,
511
+ downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
512
+ focal_window=focal_windows[i_layer],
513
+ focal_level=focal_levels[i_layer],
514
+ use_pre_norm=use_pre_norms[i_layer],
515
+ use_conv_embed=use_conv_embed,
516
+ use_postln=use_postln,
517
+ use_postln_in_modulation=use_postln_in_modulation,
518
+ scaling_modulator=scaling_modulator,
519
+ use_layerscale=use_layerscale,
520
+ use_checkpoint=use_checkpoint)
521
+ self.layers.append(layer)
522
+
523
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
524
+ self.num_features = num_features
525
+ # self.norm = norm_layer(num_features[-1])
526
+
527
+ # add a norm layer for each output
528
+ for i_layer in self.out_indices:
529
+ layer = norm_layer(num_features[i_layer])
530
+ layer_name = f'norm{i_layer}'
531
+ self.add_module(layer_name, layer)
532
+
533
+ self._freeze_stages()
534
+
535
+ def _freeze_stages(self):
536
+ if self.frozen_stages >= 0:
537
+ self.patch_embed.eval()
538
+ for param in self.patch_embed.parameters():
539
+ param.requires_grad = False
540
+
541
+ if self.frozen_stages >= 2:
542
+ self.pos_drop.eval()
543
+ for i in range(0, self.frozen_stages - 1):
544
+ m = self.layers[i]
545
+ m.eval()
546
+ for param in m.parameters():
547
+ param.requires_grad = False
548
+
549
+ def init_weights(self, pretrained=None):
550
+ """Initialize the weights in backbone.
551
+
552
+ Args:
553
+ pretrained (str, optional): Path to pre-trained weights.
554
+ Defaults to None.
555
+ """
556
+
557
+ def _init_weights(m):
558
+ if isinstance(m, nn.Linear):
559
+ trunc_normal_(m.weight, std=.02)
560
+ if isinstance(m, nn.Linear) and m.bias is not None:
561
+ nn.init.constant_(m.bias, 0)
562
+ elif isinstance(m, nn.LayerNorm):
563
+ nn.init.constant_(m.bias, 0)
564
+ nn.init.constant_(m.weight, 1.0)
565
+
566
+ if isinstance(pretrained, str):
567
+ self.apply(_init_weights)
568
+ logger = get_root_logger()
569
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
570
+ elif pretrained is None:
571
+ self.apply(_init_weights)
572
+ else:
573
+ raise TypeError('pretrained must be a str or None')
574
+
575
+ def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
576
+ model_dict = self.state_dict()
577
+
578
+ missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]
579
+ logger.info(f'=> Missed keys {missed_dict}')
580
+ unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]
581
+ logger.info(f'=> Unexpected keys {unexpected_dict}')
582
+
583
+ pretrained_dict = {
584
+ k: v for k, v in pretrained_dict.items()
585
+ if k in model_dict.keys()
586
+ }
587
+
588
+ need_init_state_dict = {}
589
+ for k, v in pretrained_dict.items():
590
+ need_init = (
591
+ (
592
+ k.split('.')[0] in pretrained_layers
593
+ or pretrained_layers[0] == '*'
594
+ )
595
+ and 'relative_position_index' not in k
596
+ and 'attn_mask' not in k
597
+ )
598
+
599
+ if need_init:
600
+ # if verbose:
601
+ # logger.info(f'=> init {k} from {pretrained}')
602
+
603
+ if (('pool_layers' in k) or ('focal_layers' in k)) and v.size() != model_dict[k].size():
604
+ table_pretrained = v
605
+ table_current = model_dict[k]
606
+ fsize1 = table_pretrained.shape[2]
607
+ fsize2 = table_current.shape[2]
608
+
609
+ # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv
610
+ if fsize1 < fsize2:
611
+ table_pretrained_resized = torch.zeros(table_current.shape)
612
+ table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained
613
+ v = table_pretrained_resized
614
+ elif fsize1 > fsize2:
615
+ table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2]
616
+ v = table_pretrained_resized
617
+
618
+
619
+ if ("modulation.f" in k or "pre_conv" in k):
620
+ table_pretrained = v
621
+ table_current = model_dict[k]
622
+ if table_pretrained.shape != table_current.shape:
623
+ if len(table_pretrained.shape) == 2:
624
+ dim = table_pretrained.shape[1]
625
+ assert table_current.shape[1] == dim
626
+ L1 = table_pretrained.shape[0]
627
+ L2 = table_current.shape[0]
628
+
629
+ if L1 < L2:
630
+ table_pretrained_resized = torch.zeros(table_current.shape)
631
+ # copy for linear project
632
+ table_pretrained_resized[:2*dim] = table_pretrained[:2*dim]
633
+ # copy for global token gating
634
+ table_pretrained_resized[-1] = table_pretrained[-1]
635
+ # copy for first multiple focal levels
636
+ table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]
637
+ # reassign pretrained weights
638
+ v = table_pretrained_resized
639
+ elif L1 > L2:
640
+ raise NotImplementedError
641
+ elif len(table_pretrained.shape) == 1:
642
+ dim = table_pretrained.shape[0]
643
+ L1 = table_pretrained.shape[0]
644
+ L2 = table_current.shape[0]
645
+ if L1 < L2:
646
+ table_pretrained_resized = torch.zeros(table_current.shape)
647
+ # copy for linear project
648
+ table_pretrained_resized[:dim] = table_pretrained[:dim]
649
+ # copy for global token gating
650
+ table_pretrained_resized[-1] = table_pretrained[-1]
651
+ # copy for first multiple focal levels
652
+ # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]
653
+ # reassign pretrained weights
654
+ v = table_pretrained_resized
655
+ elif L1 > L2:
656
+ raise NotImplementedError
657
+
658
+ need_init_state_dict[k] = v
659
+
660
+ self.load_state_dict(need_init_state_dict, strict=False)
661
+
662
+
663
+ def forward(self, x):
664
+ """Forward function."""
665
+ tic = time.time()
666
+ x = self.patch_embed(x)
667
+ Wh, Ww = x.size(2), x.size(3)
668
+
669
+ x = x.flatten(2).transpose(1, 2)
670
+ x = self.pos_drop(x)
671
+
672
+ outs = {}
673
+ for i in range(self.num_layers):
674
+ layer = self.layers[i]
675
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
676
+ if i in self.out_indices:
677
+ norm_layer = getattr(self, f'norm{i}')
678
+ x_out = norm_layer(x_out)
679
+
680
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
681
+ outs["res{}".format(i + 2)] = out
682
+
683
+ if len(self.out_indices) == 0:
684
+ outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
685
+
686
+ toc = time.time()
687
+ return outs
688
+
689
+ def train(self, mode=True):
690
+ """Convert the model into training mode while keep layers freezed."""
691
+ super(FocalNet, self).train(mode)
692
+ self._freeze_stages()
693
+
694
+
695
+ class D2FocalNet(FocalNet, Backbone):
696
+ def __init__(self, cfg, input_shape):
697
+
698
+ pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE']
699
+ patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE']
700
+ in_chans = 3
701
+ embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM']
702
+ depths = cfg['BACKBONE']['FOCAL']['DEPTHS']
703
+ mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO']
704
+ drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE']
705
+ drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE']
706
+ norm_layer = nn.LayerNorm
707
+ patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM']
708
+ use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT']
709
+ out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES']
710
+ scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False)
711
+
712
+ super().__init__(
713
+ pretrain_img_size,
714
+ patch_size,
715
+ in_chans,
716
+ embed_dim,
717
+ depths,
718
+ mlp_ratio,
719
+ drop_rate,
720
+ drop_path_rate,
721
+ norm_layer,
722
+ patch_norm,
723
+ out_indices,
724
+ focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'],
725
+ focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'],
726
+ use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'],
727
+ use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'],
728
+ use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'],
729
+ scaling_modulator=scaling_modulator,
730
+ use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'],
731
+ use_checkpoint=use_checkpoint,
732
+ )
733
+
734
+ self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES']
735
+
736
+ self._out_feature_strides = {
737
+ "res2": 4,
738
+ "res3": 8,
739
+ "res4": 16,
740
+ "res5": 32,
741
+ }
742
+ self._out_feature_channels = {
743
+ "res2": self.num_features[0],
744
+ "res3": self.num_features[1],
745
+ "res4": self.num_features[2],
746
+ "res5": self.num_features[3],
747
+ }
748
+
749
+ def forward(self, x):
750
+ """
751
+ Args:
752
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
753
+ Returns:
754
+ dict[str->Tensor]: names and the corresponding features
755
+ """
756
+ assert (
757
+ x.dim() == 4
758
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
759
+ outputs = {}
760
+ y = super().forward(x)
761
+ for k in y.keys():
762
+ if k in self._out_features:
763
+ outputs[k] = y[k]
764
+ return outputs
765
+
766
+ def output_shape(self):
767
+ return {
768
+ name: ShapeSpec(
769
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
770
+ )
771
+ for name in self._out_features
772
+ }
773
+
774
+ @property
775
+ def size_divisibility(self):
776
+ return 32
777
+
778
+ @register_backbone
779
+ def get_focal_backbone(cfg):
780
+ focal = D2FocalNet(cfg['MODEL'], 224)
781
+
782
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
783
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
784
+ logger.info(f'=> init from {filename}')
785
+ with PathManager.open(filename, "rb") as f:
786
+ ckpt = torch.load(f)['model']
787
+ focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])
788
+
789
+ return focal
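
Note: `get_focal_backbone` and `D2FocalNet` above are driven entirely by a nested config dict. As a hedged sketch of the shape they expect (every key below is read somewhere in the file above; the values are illustrative, FocalNet-Tiny-like, not the configs shipped with this repo):

# Sketch, not part of the committed files: keys consumed by get_focal_backbone /
# D2FocalNet above; values are illustrative only.
focal_config = {
    'VERBOSE': True,                      # only used when loading pretrained weights
    'MODEL': {
        'BACKBONE': {
            'NAME': 'focal',              # registry key = module name (focal.py)
            'LOAD_PRETRAINED': False,
            'PRETRAINED': '',             # checkpoint path, used when LOAD_PRETRAINED is True
            'FOCAL': {
                'PRETRAIN_IMG_SIZE': 224,
                'PATCH_SIZE': 4,
                'EMBED_DIM': 96,
                'DEPTHS': [2, 2, 6, 2],
                'MLP_RATIO': 4.0,
                'DROP_RATE': 0.0,
                'DROP_PATH_RATE': 0.3,
                'PATCH_NORM': True,
                'USE_CHECKPOINT': False,
                'OUT_INDICES': [0, 1, 2, 3],
                'OUT_FEATURES': ['res2', 'res3', 'res4', 'res5'],
                'FOCAL_LEVELS': [3, 3, 3, 3],
                'FOCAL_WINDOWS': [3, 3, 3, 3],
                'USE_CONV_EMBED': False,
                'USE_POSTLN': False,
                'USE_POSTLN_IN_MODULATION': False,
                'USE_LAYERSCALE': False,
                'SCALING_MODULATOR': False,   # optional; .get() falls back to False
                'PRETRAINED_LAYERS': ['*'],   # optional; only read when loading weights
            },
        },
    },
}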
xdecoder/backbone/registry.py ADDED
@@ -0,0 +1,14 @@
1
+ _model_entrypoints = {}
2
+
3
+
4
+ def register_backbone(fn):
5
+ module_name_split = fn.__module__.split('.')
6
+ model_name = module_name_split[-1]
7
+ _model_entrypoints[model_name] = fn
8
+ return fn
9
+
10
+ def model_entrypoints(model_name):
11
+ return _model_entrypoints[model_name]
12
+
13
+ def is_model(model_name):
14
+ return model_name in _model_entrypoints
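
The registry keys entries by the defining module's name rather than by the function name, which is easy to miss. A short sketch of the intended lookup flow (assuming the `xdecoder` package is importable; registration happens as a side effect of importing a backbone module):

# Sketch, not part of the committed files. register_backbone stores the factory
# under fn.__module__.split('.')[-1], so resnet.py registers "resnet", swin.py
# registers "swin", and focal.py registers "focal" simply by being imported.
from xdecoder.backbone.registry import is_model, model_entrypoints
import xdecoder.backbone.resnet  # noqa: F401 -- importing runs @register_backbone

assert is_model('resnet')
factory = model_entrypoints('resnet')   # -> get_resnet_backbone
# backbone = factory(config)            # config: a nested dict like the sketches in this commit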
xdecoder/backbone/resnet.py ADDED
@@ -0,0 +1,731 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import pickle
3
+ import numpy as np
4
+ from typing import Any, Dict
5
+ import fvcore.nn.weight_init as weight_init
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+
11
+ from .backbone import Backbone
12
+ from .registry import register_backbone
13
+
14
+ from detectron2.layers import (
15
+ CNNBlockBase,
16
+ Conv2d,
17
+ DeformConv,
18
+ ModulatedDeformConv,
19
+ ShapeSpec,
20
+ get_norm,
21
+ )
22
+ from detectron2.utils.file_io import PathManager
23
+
24
+ __all__ = [
25
+ "ResNetBlockBase",
26
+ "BasicBlock",
27
+ "BottleneckBlock",
28
+ "DeformBottleneckBlock",
29
+ "BasicStem",
30
+ "ResNet",
31
+ "make_stage",
32
+ "get_resnet_backbone",
33
+ ]
34
+
35
+
36
+ class BasicBlock(CNNBlockBase):
37
+ """
38
+ The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`,
39
+ with two 3x3 conv layers and a projection shortcut if needed.
40
+ """
41
+
42
+ def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
43
+ """
44
+ Args:
45
+ in_channels (int): Number of input channels.
46
+ out_channels (int): Number of output channels.
47
+ stride (int): Stride for the first conv.
48
+ norm (str or callable): normalization for all conv layers.
49
+ See :func:`layers.get_norm` for supported format.
50
+ """
51
+ super().__init__(in_channels, out_channels, stride)
52
+
53
+ if in_channels != out_channels:
54
+ self.shortcut = Conv2d(
55
+ in_channels,
56
+ out_channels,
57
+ kernel_size=1,
58
+ stride=stride,
59
+ bias=False,
60
+ norm=get_norm(norm, out_channels),
61
+ )
62
+ else:
63
+ self.shortcut = None
64
+
65
+ self.conv1 = Conv2d(
66
+ in_channels,
67
+ out_channels,
68
+ kernel_size=3,
69
+ stride=stride,
70
+ padding=1,
71
+ bias=False,
72
+ norm=get_norm(norm, out_channels),
73
+ )
74
+
75
+ self.conv2 = Conv2d(
76
+ out_channels,
77
+ out_channels,
78
+ kernel_size=3,
79
+ stride=1,
80
+ padding=1,
81
+ bias=False,
82
+ norm=get_norm(norm, out_channels),
83
+ )
84
+
85
+ for layer in [self.conv1, self.conv2, self.shortcut]:
86
+ if layer is not None: # shortcut can be None
87
+ weight_init.c2_msra_fill(layer)
88
+
89
+ def forward(self, x):
90
+ out = self.conv1(x)
91
+ out = F.relu_(out)
92
+ out = self.conv2(out)
93
+
94
+ if self.shortcut is not None:
95
+ shortcut = self.shortcut(x)
96
+ else:
97
+ shortcut = x
98
+
99
+ out += shortcut
100
+ out = F.relu_(out)
101
+ return out
102
+
103
+
104
+ class BottleneckBlock(CNNBlockBase):
105
+ """
106
+ The standard bottleneck residual block used by ResNet-50, 101 and 152
107
+ defined in :paper:`ResNet`. It contains 3 conv layers with kernels
108
+ 1x1, 3x3, 1x1, and a projection shortcut if needed.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ in_channels,
114
+ out_channels,
115
+ *,
116
+ bottleneck_channels,
117
+ stride=1,
118
+ num_groups=1,
119
+ norm="BN",
120
+ stride_in_1x1=False,
121
+ dilation=1,
122
+ ):
123
+ """
124
+ Args:
125
+ bottleneck_channels (int): number of output channels for the 3x3
126
+ "bottleneck" conv layers.
127
+ num_groups (int): number of groups for the 3x3 conv layer.
128
+ norm (str or callable): normalization for all conv layers.
129
+ See :func:`layers.get_norm` for supported format.
130
+ stride_in_1x1 (bool): when stride>1, whether to put stride in the
131
+ first 1x1 convolution or the bottleneck 3x3 convolution.
132
+ dilation (int): the dilation rate of the 3x3 conv layer.
133
+ """
134
+ super().__init__(in_channels, out_channels, stride)
135
+
136
+ if in_channels != out_channels:
137
+ self.shortcut = Conv2d(
138
+ in_channels,
139
+ out_channels,
140
+ kernel_size=1,
141
+ stride=stride,
142
+ bias=False,
143
+ norm=get_norm(norm, out_channels),
144
+ )
145
+ else:
146
+ self.shortcut = None
147
+
148
+ # The original MSRA ResNet models have stride in the first 1x1 conv
149
+ # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
150
+ # stride in the 3x3 conv
151
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
152
+
153
+ self.conv1 = Conv2d(
154
+ in_channels,
155
+ bottleneck_channels,
156
+ kernel_size=1,
157
+ stride=stride_1x1,
158
+ bias=False,
159
+ norm=get_norm(norm, bottleneck_channels),
160
+ )
161
+
162
+ self.conv2 = Conv2d(
163
+ bottleneck_channels,
164
+ bottleneck_channels,
165
+ kernel_size=3,
166
+ stride=stride_3x3,
167
+ padding=1 * dilation,
168
+ bias=False,
169
+ groups=num_groups,
170
+ dilation=dilation,
171
+ norm=get_norm(norm, bottleneck_channels),
172
+ )
173
+
174
+ self.conv3 = Conv2d(
175
+ bottleneck_channels,
176
+ out_channels,
177
+ kernel_size=1,
178
+ bias=False,
179
+ norm=get_norm(norm, out_channels),
180
+ )
181
+
182
+ for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
183
+ if layer is not None: # shortcut can be None
184
+ weight_init.c2_msra_fill(layer)
185
+
186
+ # Zero-initialize the last normalization in each residual branch,
187
+ # so that at the beginning, the residual branch starts with zeros,
188
+ # and each residual block behaves like an identity.
189
+ # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
190
+ # "For BN layers, the learnable scaling coefficient γ is initialized
191
+ # to be 1, except for each residual block's last BN
192
+ # where γ is initialized to be 0."
193
+
194
+ # nn.init.constant_(self.conv3.norm.weight, 0)
195
+ # TODO this somehow hurts performance when training GN models from scratch.
196
+ # Add it as an option when we need to use this code to train a backbone.
197
+
198
+ def forward(self, x):
199
+ out = self.conv1(x)
200
+ out = F.relu_(out)
201
+
202
+ out = self.conv2(out)
203
+ out = F.relu_(out)
204
+
205
+ out = self.conv3(out)
206
+
207
+ if self.shortcut is not None:
208
+ shortcut = self.shortcut(x)
209
+ else:
210
+ shortcut = x
211
+
212
+ out += shortcut
213
+ out = F.relu_(out)
214
+ return out
215
+
216
+
217
+ class DeformBottleneckBlock(CNNBlockBase):
218
+ """
219
+ Similar to :class:`BottleneckBlock`, but with :paper:`deformable conv <deformconv>`
220
+ in the 3x3 convolution.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ in_channels,
226
+ out_channels,
227
+ *,
228
+ bottleneck_channels,
229
+ stride=1,
230
+ num_groups=1,
231
+ norm="BN",
232
+ stride_in_1x1=False,
233
+ dilation=1,
234
+ deform_modulated=False,
235
+ deform_num_groups=1,
236
+ ):
237
+ super().__init__(in_channels, out_channels, stride)
238
+ self.deform_modulated = deform_modulated
239
+
240
+ if in_channels != out_channels:
241
+ self.shortcut = Conv2d(
242
+ in_channels,
243
+ out_channels,
244
+ kernel_size=1,
245
+ stride=stride,
246
+ bias=False,
247
+ norm=get_norm(norm, out_channels),
248
+ )
249
+ else:
250
+ self.shortcut = None
251
+
252
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
253
+
254
+ self.conv1 = Conv2d(
255
+ in_channels,
256
+ bottleneck_channels,
257
+ kernel_size=1,
258
+ stride=stride_1x1,
259
+ bias=False,
260
+ norm=get_norm(norm, bottleneck_channels),
261
+ )
262
+
263
+ if deform_modulated:
264
+ deform_conv_op = ModulatedDeformConv
265
+ # offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size
266
+ offset_channels = 27
267
+ else:
268
+ deform_conv_op = DeformConv
269
+ offset_channels = 18
270
+
271
+ self.conv2_offset = Conv2d(
272
+ bottleneck_channels,
273
+ offset_channels * deform_num_groups,
274
+ kernel_size=3,
275
+ stride=stride_3x3,
276
+ padding=1 * dilation,
277
+ dilation=dilation,
278
+ )
279
+ self.conv2 = deform_conv_op(
280
+ bottleneck_channels,
281
+ bottleneck_channels,
282
+ kernel_size=3,
283
+ stride=stride_3x3,
284
+ padding=1 * dilation,
285
+ bias=False,
286
+ groups=num_groups,
287
+ dilation=dilation,
288
+ deformable_groups=deform_num_groups,
289
+ norm=get_norm(norm, bottleneck_channels),
290
+ )
291
+
292
+ self.conv3 = Conv2d(
293
+ bottleneck_channels,
294
+ out_channels,
295
+ kernel_size=1,
296
+ bias=False,
297
+ norm=get_norm(norm, out_channels),
298
+ )
299
+
300
+ for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
301
+ if layer is not None: # shortcut can be None
302
+ weight_init.c2_msra_fill(layer)
303
+
304
+ nn.init.constant_(self.conv2_offset.weight, 0)
305
+ nn.init.constant_(self.conv2_offset.bias, 0)
306
+
307
+ def forward(self, x):
308
+ out = self.conv1(x)
309
+ out = F.relu_(out)
310
+
311
+ if self.deform_modulated:
312
+ offset_mask = self.conv2_offset(out)
313
+ offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
314
+ offset = torch.cat((offset_x, offset_y), dim=1)
315
+ mask = mask.sigmoid()
316
+ out = self.conv2(out, offset, mask)
317
+ else:
318
+ offset = self.conv2_offset(out)
319
+ out = self.conv2(out, offset)
320
+ out = F.relu_(out)
321
+
322
+ out = self.conv3(out)
323
+
324
+ if self.shortcut is not None:
325
+ shortcut = self.shortcut(x)
326
+ else:
327
+ shortcut = x
328
+
329
+ out += shortcut
330
+ out = F.relu_(out)
331
+ return out
332
+
333
+
334
+ class BasicStem(CNNBlockBase):
335
+ """
336
+ The standard ResNet stem (layers before the first residual block),
337
+ with a conv, relu and max_pool.
338
+ """
339
+
340
+ def __init__(self, in_channels=3, out_channels=64, norm="BN"):
341
+ """
342
+ Args:
343
+ norm (str or callable): norm after the first conv layer.
344
+ See :func:`layers.get_norm` for supported format.
345
+ """
346
+ super().__init__(in_channels, out_channels, 4)
347
+ self.in_channels = in_channels
348
+ self.conv1 = Conv2d(
349
+ in_channels,
350
+ out_channels,
351
+ kernel_size=7,
352
+ stride=2,
353
+ padding=3,
354
+ bias=False,
355
+ norm=get_norm(norm, out_channels),
356
+ )
357
+ weight_init.c2_msra_fill(self.conv1)
358
+
359
+ def forward(self, x):
360
+ x = self.conv1(x)
361
+ x = F.relu_(x)
362
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
363
+ return x
364
+
365
+
366
+ class ResNet(Backbone):
367
+ """
368
+ Implement :paper:`ResNet`.
369
+ """
370
+
371
+ def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
372
+ """
373
+ Args:
374
+ stem (nn.Module): a stem module
375
+ stages (list[list[CNNBlockBase]]): several (typically 4) stages,
376
+ each contains multiple :class:`CNNBlockBase`.
377
+ num_classes (None or int): if None, will not perform classification.
378
+ Otherwise, will create a linear layer.
379
+ out_features (list[str]): name of the layers whose outputs should
380
+ be returned in forward. Can be anything in "stem", "linear", or "res2" ...
381
+ If None, will return the output of the last layer.
382
+ freeze_at (int): The number of stages at the beginning to freeze.
383
+ see :meth:`freeze` for detailed explanation.
384
+ """
385
+ super().__init__()
386
+ self.stem = stem
387
+ self.num_classes = num_classes
388
+
389
+ current_stride = self.stem.stride
390
+ self._out_feature_strides = {"stem": current_stride}
391
+ self._out_feature_channels = {"stem": self.stem.out_channels}
392
+
393
+ self.stage_names, self.stages = [], []
394
+
395
+ if out_features is not None:
396
+ # Avoid keeping unused layers in this module. They consume extra memory
397
+ # and may cause allreduce to fail
398
+ num_stages = max(
399
+ [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
400
+ )
401
+ stages = stages[:num_stages]
402
+ for i, blocks in enumerate(stages):
403
+ assert len(blocks) > 0, len(blocks)
404
+ for block in blocks:
405
+ assert isinstance(block, CNNBlockBase), block
406
+
407
+ name = "res" + str(i + 2)
408
+ stage = nn.Sequential(*blocks)
409
+
410
+ self.add_module(name, stage)
411
+ self.stage_names.append(name)
412
+ self.stages.append(stage)
413
+
414
+ self._out_feature_strides[name] = current_stride = int(
415
+ current_stride * np.prod([k.stride for k in blocks])
416
+ )
417
+ self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
418
+ self.stage_names = tuple(self.stage_names) # Make it static for scripting
419
+
420
+ if num_classes is not None:
421
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
422
+ self.linear = nn.Linear(curr_channels, num_classes)
423
+
424
+ # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
425
+ # "The 1000-way fully-connected layer is initialized by
426
+ # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
427
+ nn.init.normal_(self.linear.weight, std=0.01)
428
+ name = "linear"
429
+
430
+ if out_features is None:
431
+ out_features = [name]
432
+ self._out_features = out_features
433
+ assert len(self._out_features)
434
+ children = [x[0] for x in self.named_children()]
435
+ for out_feature in self._out_features:
436
+ assert out_feature in children, "Available children: {}".format(", ".join(children))
437
+ self.freeze(freeze_at)
438
+
439
+ def forward(self, x):
440
+ """
441
+ Args:
442
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
443
+
444
+ Returns:
445
+ dict[str->Tensor]: names and the corresponding features
446
+ """
447
+ assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
448
+ outputs = {}
449
+ x = self.stem(x)
450
+ if "stem" in self._out_features:
451
+ outputs["stem"] = x
452
+ for name, stage in zip(self.stage_names, self.stages):
453
+ x = stage(x)
454
+ if name in self._out_features:
455
+ outputs[name] = x
456
+ if self.num_classes is not None:
457
+ x = self.avgpool(x)
458
+ x = torch.flatten(x, 1)
459
+ x = self.linear(x)
460
+ if "linear" in self._out_features:
461
+ outputs["linear"] = x
462
+ return outputs
463
+
464
+ def output_shape(self):
465
+ return {
466
+ name: ShapeSpec(
467
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
468
+ )
469
+ for name in self._out_features
470
+ }
471
+
472
+ def freeze(self, freeze_at=0):
473
+ """
474
+ Freeze the first several stages of the ResNet. Commonly used in
475
+ fine-tuning.
476
+
477
+ Layers that produce the same feature map spatial size are defined as one
478
+ "stage" by :paper:`FPN`.
479
+
480
+ Args:
481
+ freeze_at (int): number of stages to freeze.
482
+ `1` means freezing the stem. `2` means freezing the stem and
483
+ one residual stage, etc.
484
+
485
+ Returns:
486
+ nn.Module: this ResNet itself
487
+ """
488
+ if freeze_at >= 1:
489
+ self.stem.freeze()
490
+ for idx, stage in enumerate(self.stages, start=2):
491
+ if freeze_at >= idx:
492
+ for block in stage.children():
493
+ block.freeze()
494
+ return self
495
+
496
+ @staticmethod
497
+ def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
498
+ """
499
+ Create a list of blocks of the same type that forms one ResNet stage.
500
+
501
+ Args:
502
+ block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
503
+ stage. A module of this type must not change spatial resolution of inputs unless its
504
+ stride != 1.
505
+ num_blocks (int): number of blocks in this stage
506
+ in_channels (int): input channels of the entire stage.
507
+ out_channels (int): output channels of **every block** in the stage.
508
+ kwargs: other arguments passed to the constructor of
509
+ `block_class`. If the argument name is "xx_per_block", the
510
+ argument is a list of values to be passed to each block in the
511
+ stage. Otherwise, the same argument is passed to every block
512
+ in the stage.
513
+
514
+ Returns:
515
+ list[CNNBlockBase]: a list of block module.
516
+
517
+ Examples:
518
+ ::
519
+ stage = ResNet.make_stage(
520
+ BottleneckBlock, 3, in_channels=16, out_channels=64,
521
+ bottleneck_channels=16, num_groups=1,
522
+ stride_per_block=[2, 1, 1],
523
+ dilations_per_block=[1, 1, 2]
524
+ )
525
+
526
+ Usually, layers that produce the same feature map spatial size are defined as one
527
+ "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
528
+ all be 1.
529
+ """
530
+ blocks = []
531
+ for i in range(num_blocks):
532
+ curr_kwargs = {}
533
+ for k, v in kwargs.items():
534
+ if k.endswith("_per_block"):
535
+ assert len(v) == num_blocks, (
536
+ f"Argument '{k}' of make_stage should have the "
537
+ f"same length as num_blocks={num_blocks}."
538
+ )
539
+ newk = k[: -len("_per_block")]
540
+ assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
541
+ curr_kwargs[newk] = v[i]
542
+ else:
543
+ curr_kwargs[k] = v
544
+
545
+ blocks.append(
546
+ block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
547
+ )
548
+ in_channels = out_channels
549
+ return blocks
550
+
551
+ @staticmethod
552
+ def make_default_stages(depth, block_class=None, **kwargs):
553
+ """
554
+ Create a list of ResNet stages from a pre-defined depth (one of 18, 34, 50, 101, 152).
555
+ If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
556
+ instead for fine-grained customization.
557
+
558
+ Args:
559
+ depth (int): depth of ResNet
560
+ block_class (type): the CNN block class. Has to accept
561
+ `bottleneck_channels` argument for depth > 50.
562
+ By default it is BasicBlock or BottleneckBlock, based on the
563
+ depth.
564
+ kwargs:
565
+ other arguments to pass to `make_stage`. Should not contain
566
+ stride and channels, as they are predefined for each depth.
567
+
568
+ Returns:
569
+ list[list[CNNBlockBase]]: modules in all stages; see arguments of
570
+ :class:`ResNet.__init__`.
571
+ """
572
+ num_blocks_per_stage = {
573
+ 18: [2, 2, 2, 2],
574
+ 34: [3, 4, 6, 3],
575
+ 50: [3, 4, 6, 3],
576
+ 101: [3, 4, 23, 3],
577
+ 152: [3, 8, 36, 3],
578
+ }[depth]
579
+ if block_class is None:
580
+ block_class = BasicBlock if depth < 50 else BottleneckBlock
581
+ if depth < 50:
582
+ in_channels = [64, 64, 128, 256]
583
+ out_channels = [64, 128, 256, 512]
584
+ else:
585
+ in_channels = [64, 256, 512, 1024]
586
+ out_channels = [256, 512, 1024, 2048]
587
+ ret = []
588
+ for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
589
+ if depth >= 50:
590
+ kwargs["bottleneck_channels"] = o // 4
591
+ ret.append(
592
+ ResNet.make_stage(
593
+ block_class=block_class,
594
+ num_blocks=n,
595
+ stride_per_block=[s] + [1] * (n - 1),
596
+ in_channels=i,
597
+ out_channels=o,
598
+ **kwargs,
599
+ )
600
+ )
601
+ return ret
602
+
603
+
604
+ ResNetBlockBase = CNNBlockBase
605
+ """
606
+ Alias for backward compatibility.
607
+ """
608
+
609
+
610
+ def make_stage(*args, **kwargs):
611
+ """
612
+ Deprecated alias for backward compatibility.
613
+ """
614
+ return ResNet.make_stage(*args, **kwargs)
615
+
616
+
617
+ def _convert_ndarray_to_tensor(state_dict: Dict[str, Any]) -> None:
618
+ """
619
+ In-place convert all numpy arrays in the state_dict to torch tensor.
620
+ Args:
621
+ state_dict (dict): a state-dict to be loaded to the model.
622
+ Will be modified.
623
+ """
624
+ # model could be an OrderedDict with _metadata attribute
625
+ # (as returned by Pytorch's state_dict()). We should preserve these
626
+ # properties.
627
+ for k in list(state_dict.keys()):
628
+ v = state_dict[k]
629
+ if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
630
+ raise ValueError(
631
+ "Unsupported type found in checkpoint! {}: {}".format(k, type(v))
632
+ )
633
+ if not isinstance(v, torch.Tensor):
634
+ state_dict[k] = torch.from_numpy(v)
635
+
636
+
637
+ @register_backbone
638
+ def get_resnet_backbone(cfg):
639
+ """
640
+ Create a ResNet instance from config.
641
+
642
+ Returns:
643
+ ResNet: a :class:`ResNet` instance.
644
+ """
645
+ res_cfg = cfg['MODEL']['BACKBONE']['RESNETS']
646
+
647
+ # need registration of new blocks/stems?
648
+ norm = res_cfg['NORM']
649
+ stem = BasicStem(
650
+ in_channels=res_cfg['STEM_IN_CHANNELS'],
651
+ out_channels=res_cfg['STEM_OUT_CHANNELS'],
652
+ norm=norm,
653
+ )
654
+
655
+ # fmt: off
656
+ freeze_at = res_cfg['FREEZE_AT']
657
+ out_features = res_cfg['OUT_FEATURES']
658
+ depth = res_cfg['DEPTH']
659
+ num_groups = res_cfg['NUM_GROUPS']
660
+ width_per_group = res_cfg['WIDTH_PER_GROUP']
661
+ bottleneck_channels = num_groups * width_per_group
662
+ in_channels = res_cfg['STEM_OUT_CHANNELS']
663
+ out_channels = res_cfg['RES2_OUT_CHANNELS']
664
+ stride_in_1x1 = res_cfg['STRIDE_IN_1X1']
665
+ res5_dilation = res_cfg['RES5_DILATION']
666
+ deform_on_per_stage = res_cfg['DEFORM_ON_PER_STAGE']
667
+ deform_modulated = res_cfg['DEFORM_MODULATED']
668
+ deform_num_groups = res_cfg['DEFORM_NUM_GROUPS']
669
+ # fmt: on
670
+ assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
671
+
672
+ num_blocks_per_stage = {
673
+ 18: [2, 2, 2, 2],
674
+ 34: [3, 4, 6, 3],
675
+ 50: [3, 4, 6, 3],
676
+ 101: [3, 4, 23, 3],
677
+ 152: [3, 8, 36, 3],
678
+ }[depth]
679
+
680
+ if depth in [18, 34]:
681
+ assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
682
+ assert not any(
683
+ deform_on_per_stage
684
+ ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
685
+ assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
686
+ assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"
687
+
688
+ stages = []
689
+
690
+ for idx, stage_idx in enumerate(range(2, 6)):
691
+ # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
692
+ dilation = res5_dilation if stage_idx == 5 else 1
693
+ first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
694
+ stage_kargs = {
695
+ "num_blocks": num_blocks_per_stage[idx],
696
+ "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1),
697
+ "in_channels": in_channels,
698
+ "out_channels": out_channels,
699
+ "norm": norm,
700
+ }
701
+ # Use BasicBlock for R18 and R34.
702
+ if depth in [18, 34]:
703
+ stage_kargs["block_class"] = BasicBlock
704
+ else:
705
+ stage_kargs["bottleneck_channels"] = bottleneck_channels
706
+ stage_kargs["stride_in_1x1"] = stride_in_1x1
707
+ stage_kargs["dilation"] = dilation
708
+ stage_kargs["num_groups"] = num_groups
709
+ if deform_on_per_stage[idx]:
710
+ stage_kargs["block_class"] = DeformBottleneckBlock
711
+ stage_kargs["deform_modulated"] = deform_modulated
712
+ stage_kargs["deform_num_groups"] = deform_num_groups
713
+ else:
714
+ stage_kargs["block_class"] = BottleneckBlock
715
+ blocks = ResNet.make_stage(**stage_kargs)
716
+ in_channels = out_channels
717
+ out_channels *= 2
718
+ bottleneck_channels *= 2
719
+ stages.append(blocks)
720
+ backbone = ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)
721
+
722
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
723
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
724
+ with PathManager.open(filename, "rb") as f:
725
+ ckpt = pickle.load(f, encoding="latin1")['model']
726
+ _convert_ndarray_to_tensor(ckpt)
727
+ ckpt.pop('stem.fc.weight')
728
+ ckpt.pop('stem.fc.bias')
729
+ backbone.load_state_dict(ckpt)
730
+
731
+ return backbone
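
Since `get_resnet_backbone` is driven entirely by the config dict, it can help to see the same network assembled directly from the classes above. A minimal sketch, assuming detectron2 and fvcore are installed; it mirrors the DEPTH=50, no-deformable-conv path of the builder:

# Sketch, not part of the committed files: build a plain ResNet-50 without the
# registry, matching what get_resnet_backbone does for DEPTH=50 and no deform stages.
import torch
from xdecoder.backbone.resnet import BasicStem, ResNet

stem = BasicStem(in_channels=3, out_channels=64, norm="FrozenBN")
stages = ResNet.make_default_stages(depth=50, norm="FrozenBN", stride_in_1x1=False)
backbone = ResNet(stem, stages, out_features=["res2", "res3", "res4", "res5"], freeze_at=0)

feats = backbone(torch.randn(1, 3, 224, 224))
for name, feat in feats.items():
    print(name, tuple(feat.shape))   # res2 at stride 4 ... res5 at stride 32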
xdecoder/backbone/swin.py ADDED
@@ -0,0 +1,892 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
10
+ import logging
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
17
+
18
+ from detectron2.modeling import Backbone, ShapeSpec
19
+ from detectron2.utils.file_io import PathManager
20
+
21
+ from .registry import register_backbone
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class Mlp(nn.Module):
27
+ """Multilayer perceptron."""
28
+
29
+ def __init__(
30
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
31
+ ):
32
+ super().__init__()
33
+ out_features = out_features or in_features
34
+ hidden_features = hidden_features or in_features
35
+ self.fc1 = nn.Linear(in_features, hidden_features)
36
+ self.act = act_layer()
37
+ self.fc2 = nn.Linear(hidden_features, out_features)
38
+ self.drop = nn.Dropout(drop)
39
+
40
+ def forward(self, x):
41
+ x = self.fc1(x)
42
+ x = self.act(x)
43
+ x = self.drop(x)
44
+ x = self.fc2(x)
45
+ x = self.drop(x)
46
+ return x
47
+
48
+
49
+ def window_partition(x, window_size):
50
+ """
51
+ Args:
52
+ x: (B, H, W, C)
53
+ window_size (int): window size
54
+ Returns:
55
+ windows: (num_windows*B, window_size, window_size, C)
56
+ """
57
+ B, H, W, C = x.shape
58
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
59
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
60
+ return windows
61
+
62
+
63
+ def window_reverse(windows, window_size, H, W):
64
+ """
65
+ Args:
66
+ windows: (num_windows*B, window_size, window_size, C)
67
+ window_size (int): Window size
68
+ H (int): Height of image
69
+ W (int): Width of image
70
+ Returns:
71
+ x: (B, H, W, C)
72
+ """
73
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
74
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
75
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
76
+ return x
77
+
78
+
79
+ class WindowAttention(nn.Module):
80
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
81
+ It supports both shifted and non-shifted windows.
82
+ Args:
83
+ dim (int): Number of input channels.
84
+ window_size (tuple[int]): The height and width of the window.
85
+ num_heads (int): Number of attention heads.
86
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
87
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
88
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
89
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
90
+ """
91
+
92
+ def __init__(
93
+ self,
94
+ dim,
95
+ window_size,
96
+ num_heads,
97
+ qkv_bias=True,
98
+ qk_scale=None,
99
+ attn_drop=0.0,
100
+ proj_drop=0.0,
101
+ ):
102
+
103
+ super().__init__()
104
+ self.dim = dim
105
+ self.window_size = window_size # Wh, Ww
106
+ self.num_heads = num_heads
107
+ head_dim = dim // num_heads
108
+ self.scale = qk_scale or head_dim ** -0.5
109
+
110
+ # define a parameter table of relative position bias
111
+ self.relative_position_bias_table = nn.Parameter(
112
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
113
+ ) # 2*Wh-1 * 2*Ww-1, nH
114
+
115
+ # get pair-wise relative position index for each token inside the window
116
+ coords_h = torch.arange(self.window_size[0])
117
+ coords_w = torch.arange(self.window_size[1])
118
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
119
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
120
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
121
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
122
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
123
+ relative_coords[:, :, 1] += self.window_size[1] - 1
124
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
125
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
126
+ self.register_buffer("relative_position_index", relative_position_index)
127
+
128
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
129
+ self.attn_drop = nn.Dropout(attn_drop)
130
+ self.proj = nn.Linear(dim, dim)
131
+ self.proj_drop = nn.Dropout(proj_drop)
132
+
133
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
134
+ self.softmax = nn.Softmax(dim=-1)
135
+
136
+ def forward(self, x, mask=None):
137
+ """Forward function.
138
+ Args:
139
+ x: input features with shape of (num_windows*B, N, C)
140
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
141
+ """
142
+ B_, N, C = x.shape
143
+ qkv = (
144
+ self.qkv(x)
145
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
146
+ .permute(2, 0, 3, 1, 4)
147
+ )
148
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
149
+
150
+ q = q * self.scale
151
+ attn = q @ k.transpose(-2, -1)
152
+
153
+ relative_position_bias = self.relative_position_bias_table[
154
+ self.relative_position_index.view(-1)
155
+ ].view(
156
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
157
+ ) # Wh*Ww,Wh*Ww,nH
158
+ relative_position_bias = relative_position_bias.permute(
159
+ 2, 0, 1
160
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
161
+ attn = attn + relative_position_bias.unsqueeze(0)
162
+
163
+ if mask is not None:
164
+ nW = mask.shape[0]
165
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
166
+ attn = attn.view(-1, self.num_heads, N, N)
167
+ attn = self.softmax(attn)
168
+ else:
169
+ attn = self.softmax(attn)
170
+
171
+ attn = self.attn_drop(attn)
172
+
173
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
174
+ x = self.proj(x)
175
+ x = self.proj_drop(x)
176
+
177
+ return x
178
+
179
+
180
+ class SwinTransformerBlock(nn.Module):
181
+ """Swin Transformer Block.
182
+ Args:
183
+ dim (int): Number of input channels.
184
+ num_heads (int): Number of attention heads.
185
+ window_size (int): Window size.
186
+ shift_size (int): Shift size for SW-MSA.
187
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
188
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
189
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
190
+ drop (float, optional): Dropout rate. Default: 0.0
191
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
192
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
193
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
194
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
195
+ """
196
+
197
+ def __init__(
198
+ self,
199
+ dim,
200
+ num_heads,
201
+ window_size=7,
202
+ shift_size=0,
203
+ mlp_ratio=4.0,
204
+ qkv_bias=True,
205
+ qk_scale=None,
206
+ drop=0.0,
207
+ attn_drop=0.0,
208
+ drop_path=0.0,
209
+ act_layer=nn.GELU,
210
+ norm_layer=nn.LayerNorm,
211
+ ):
212
+ super().__init__()
213
+ self.dim = dim
214
+ self.num_heads = num_heads
215
+ self.window_size = window_size
216
+ self.shift_size = shift_size
217
+ self.mlp_ratio = mlp_ratio
218
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
219
+
220
+ self.norm1 = norm_layer(dim)
221
+ self.attn = WindowAttention(
222
+ dim,
223
+ window_size=to_2tuple(self.window_size),
224
+ num_heads=num_heads,
225
+ qkv_bias=qkv_bias,
226
+ qk_scale=qk_scale,
227
+ attn_drop=attn_drop,
228
+ proj_drop=drop,
229
+ )
230
+
231
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
232
+ self.norm2 = norm_layer(dim)
233
+ mlp_hidden_dim = int(dim * mlp_ratio)
234
+ self.mlp = Mlp(
235
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
236
+ )
237
+
238
+ self.H = None
239
+ self.W = None
240
+
241
+ def forward(self, x, mask_matrix):
242
+ """Forward function.
243
+ Args:
244
+ x: Input feature, tensor size (B, H*W, C).
245
+ H, W: Spatial resolution of the input feature.
246
+ mask_matrix: Attention mask for cyclic shift.
247
+ """
248
+ B, L, C = x.shape
249
+ H, W = self.H, self.W
250
+ assert L == H * W, "input feature has wrong size"
251
+
252
+ # HACK model will not upsampling
253
+ # if min([H, W]) <= self.window_size:
254
+ # if window size is larger than input resolution, we don't partition windows
255
+ # self.shift_size = 0
256
+ # self.window_size = min([H,W])
257
+
258
+ shortcut = x
259
+ x = self.norm1(x)
260
+ x = x.view(B, H, W, C)
261
+
262
+ # pad feature maps to multiples of window size
263
+ pad_l = pad_t = 0
264
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
265
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
266
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
267
+ _, Hp, Wp, _ = x.shape
268
+
269
+ # cyclic shift
270
+ if self.shift_size > 0:
271
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
272
+ attn_mask = mask_matrix
273
+ else:
274
+ shifted_x = x
275
+ attn_mask = None
276
+
277
+ # partition windows
278
+ x_windows = window_partition(
279
+ shifted_x, self.window_size
280
+ ) # nW*B, window_size, window_size, C
281
+ x_windows = x_windows.view(
282
+ -1, self.window_size * self.window_size, C
283
+ ) # nW*B, window_size*window_size, C
284
+
285
+ # W-MSA/SW-MSA
286
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
287
+
288
+ # merge windows
289
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
290
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
291
+
292
+ # reverse cyclic shift
293
+ if self.shift_size > 0:
294
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
295
+ else:
296
+ x = shifted_x
297
+
298
+ if pad_r > 0 or pad_b > 0:
299
+ x = x[:, :H, :W, :].contiguous()
300
+
301
+ x = x.view(B, H * W, C)
302
+
303
+ # FFN
304
+ x = shortcut + self.drop_path(x)
305
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
306
+ return x
307
+
308
+
309
+ class PatchMerging(nn.Module):
310
+ """Patch Merging Layer
311
+ Args:
312
+ dim (int): Number of input channels.
313
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
314
+ """
315
+
316
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
317
+ super().__init__()
318
+ self.dim = dim
319
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
320
+ self.norm = norm_layer(4 * dim)
321
+
322
+ def forward(self, x, H, W):
323
+ """Forward function.
324
+ Args:
325
+ x: Input feature, tensor size (B, H*W, C).
326
+ H, W: Spatial resolution of the input feature.
327
+ """
328
+ B, L, C = x.shape
329
+ assert L == H * W, "input feature has wrong size"
330
+
331
+ x = x.view(B, H, W, C)
332
+
333
+ # padding
334
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
335
+ if pad_input:
336
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
337
+
338
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
339
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
340
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
341
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
342
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
343
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
344
+
345
+ x = self.norm(x)
346
+ x = self.reduction(x)
347
+
348
+ return x
349
+
350
+
351
+ class BasicLayer(nn.Module):
352
+ """A basic Swin Transformer layer for one stage.
353
+ Args:
354
+ dim (int): Number of feature channels
355
+ depth (int): Depth of this stage.
356
+ num_heads (int): Number of attention heads.
357
+ window_size (int): Local window size. Default: 7.
358
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
359
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
361
+ drop (float, optional): Dropout rate. Default: 0.0
362
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
363
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
364
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
365
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
366
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
367
+ """
368
+
369
+ def __init__(
370
+ self,
371
+ dim,
372
+ depth,
373
+ num_heads,
374
+ window_size=7,
375
+ mlp_ratio=4.0,
376
+ qkv_bias=True,
377
+ qk_scale=None,
378
+ drop=0.0,
379
+ attn_drop=0.0,
380
+ drop_path=0.0,
381
+ norm_layer=nn.LayerNorm,
382
+ downsample=None,
383
+ use_checkpoint=False,
384
+ ):
385
+ super().__init__()
386
+ self.window_size = window_size
387
+ self.shift_size = window_size // 2
388
+ self.depth = depth
389
+ self.use_checkpoint = use_checkpoint
390
+
391
+ # build blocks
392
+ self.blocks = nn.ModuleList(
393
+ [
394
+ SwinTransformerBlock(
395
+ dim=dim,
396
+ num_heads=num_heads,
397
+ window_size=window_size,
398
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
399
+ mlp_ratio=mlp_ratio,
400
+ qkv_bias=qkv_bias,
401
+ qk_scale=qk_scale,
402
+ drop=drop,
403
+ attn_drop=attn_drop,
404
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
405
+ norm_layer=norm_layer,
406
+ )
407
+ for i in range(depth)
408
+ ]
409
+ )
410
+
411
+ # patch merging layer
412
+ if downsample is not None:
413
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
414
+ else:
415
+ self.downsample = None
416
+
417
+ def forward(self, x, H, W):
418
+ """Forward function.
419
+ Args:
420
+ x: Input feature, tensor size (B, H*W, C).
421
+ H, W: Spatial resolution of the input feature.
422
+ """
423
+
424
+ # calculate attention mask for SW-MSA
425
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
426
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
427
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
428
+ h_slices = (
429
+ slice(0, -self.window_size),
430
+ slice(-self.window_size, -self.shift_size),
431
+ slice(-self.shift_size, None),
432
+ )
433
+ w_slices = (
434
+ slice(0, -self.window_size),
435
+ slice(-self.window_size, -self.shift_size),
436
+ slice(-self.shift_size, None),
437
+ )
438
+ cnt = 0
439
+ for h in h_slices:
440
+ for w in w_slices:
441
+ img_mask[:, h, w, :] = cnt
442
+ cnt += 1
443
+
444
+ mask_windows = window_partition(
445
+ img_mask, self.window_size
446
+ ) # nW, window_size, window_size, 1
447
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
448
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
449
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
450
+ attn_mask == 0, float(0.0)
451
+ ).type(x.dtype)
452
+
453
+ for blk in self.blocks:
454
+ blk.H, blk.W = H, W
455
+ if self.use_checkpoint:
456
+ x = checkpoint.checkpoint(blk, x, attn_mask)
457
+ else:
458
+ x = blk(x, attn_mask)
459
+ if self.downsample is not None:
460
+ x_down = self.downsample(x, H, W)
461
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
462
+ return x, H, W, x_down, Wh, Ww
463
+ else:
464
+ return x, H, W, x, H, W
465
+
466
+
467
+ class PatchEmbed(nn.Module):
468
+ """Image to Patch Embedding
469
+ Args:
470
+ patch_size (int): Patch token size. Default: 4.
471
+ in_chans (int): Number of input image channels. Default: 3.
472
+ embed_dim (int): Number of linear projection output channels. Default: 96.
473
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
474
+ """
475
+
476
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
477
+ super().__init__()
478
+ patch_size = to_2tuple(patch_size)
479
+ self.patch_size = patch_size
480
+
481
+ self.in_chans = in_chans
482
+ self.embed_dim = embed_dim
483
+
484
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
485
+ if norm_layer is not None:
486
+ self.norm = norm_layer(embed_dim)
487
+ else:
488
+ self.norm = None
489
+
490
+ def forward(self, x):
491
+ """Forward function."""
492
+ # padding
493
+ _, _, H, W = x.size()
494
+ if W % self.patch_size[1] != 0:
495
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
496
+ if H % self.patch_size[0] != 0:
497
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
498
+
499
+ x = self.proj(x) # B C Wh Ww
500
+ if self.norm is not None:
501
+ Wh, Ww = x.size(2), x.size(3)
502
+ x = x.flatten(2).transpose(1, 2)
503
+ x = self.norm(x)
504
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
505
+
506
+ return x
507
+
508
+
509
+ class SwinTransformer(nn.Module):
510
+ """Swin Transformer backbone.
511
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
512
+ https://arxiv.org/pdf/2103.14030
513
+ Args:
514
+ pretrain_img_size (int): Input image size for training the pretrained model,
515
+ used in absolute position embedding. Default 224.
516
+ patch_size (int | tuple(int)): Patch size. Default: 4.
517
+ in_chans (int): Number of input image channels. Default: 3.
518
+ embed_dim (int): Number of linear projection output channels. Default: 96.
519
+ depths (tuple[int]): Depths of each Swin Transformer stage.
520
+ num_heads (tuple[int]): Number of attention heads of each stage.
521
+ window_size (int): Window size. Default: 7.
522
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
523
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
524
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
525
+ drop_rate (float): Dropout rate.
526
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
527
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
528
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
529
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
530
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
531
+ out_indices (Sequence[int]): Output from which stages.
532
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
533
+ -1 means not freezing any parameters.
534
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
535
+ """
536
+
537
+ def __init__(
538
+ self,
539
+ pretrain_img_size=224,
540
+ patch_size=4,
541
+ in_chans=3,
542
+ embed_dim=96,
543
+ depths=[2, 2, 6, 2],
544
+ num_heads=[3, 6, 12, 24],
545
+ window_size=7,
546
+ mlp_ratio=4.0,
547
+ qkv_bias=True,
548
+ qk_scale=None,
549
+ drop_rate=0.0,
550
+ attn_drop_rate=0.0,
551
+ drop_path_rate=0.2,
552
+ norm_layer=nn.LayerNorm,
553
+ ape=False,
554
+ patch_norm=True,
555
+ out_indices=(0, 1, 2, 3),
556
+ frozen_stages=-1,
557
+ use_checkpoint=False,
558
+ ):
559
+ super().__init__()
560
+
561
+ self.pretrain_img_size = pretrain_img_size
562
+ self.num_layers = len(depths)
563
+ self.embed_dim = embed_dim
564
+ self.ape = ape
565
+ self.patch_norm = patch_norm
566
+ self.out_indices = out_indices
567
+ self.frozen_stages = frozen_stages
568
+
569
+ # split image into non-overlapping patches
570
+ self.patch_embed = PatchEmbed(
571
+ patch_size=patch_size,
572
+ in_chans=in_chans,
573
+ embed_dim=embed_dim,
574
+ norm_layer=norm_layer if self.patch_norm else None,
575
+ )
576
+
577
+ # absolute position embedding
578
+ if self.ape:
579
+ pretrain_img_size = to_2tuple(pretrain_img_size)
580
+ patch_size = to_2tuple(patch_size)
581
+ patches_resolution = [
582
+ pretrain_img_size[0] // patch_size[0],
583
+ pretrain_img_size[1] // patch_size[1],
584
+ ]
585
+
586
+ self.absolute_pos_embed = nn.Parameter(
587
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
588
+ )
589
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
590
+
591
+ self.pos_drop = nn.Dropout(p=drop_rate)
592
+
593
+ # stochastic depth
594
+ dpr = [
595
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
596
+ ] # stochastic depth decay rule
597
+
598
+ # build layers
599
+ self.layers = nn.ModuleList()
600
+ for i_layer in range(self.num_layers):
601
+ layer = BasicLayer(
602
+ dim=int(embed_dim * 2 ** i_layer),
603
+ depth=depths[i_layer],
604
+ num_heads=num_heads[i_layer],
605
+ window_size=window_size,
606
+ mlp_ratio=mlp_ratio,
607
+ qkv_bias=qkv_bias,
608
+ qk_scale=qk_scale,
609
+ drop=drop_rate,
610
+ attn_drop=attn_drop_rate,
611
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
612
+ norm_layer=norm_layer,
613
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
614
+ use_checkpoint=use_checkpoint,
615
+ )
616
+ self.layers.append(layer)
617
+
618
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
619
+ self.num_features = num_features
620
+
621
+ # add a norm layer for each output
622
+ for i_layer in out_indices:
623
+ layer = norm_layer(num_features[i_layer])
624
+ layer_name = f"norm{i_layer}"
625
+ self.add_module(layer_name, layer)
626
+
627
+ self._freeze_stages()
628
+
629
+ def _freeze_stages(self):
630
+ if self.frozen_stages >= 0:
631
+ self.patch_embed.eval()
632
+ for param in self.patch_embed.parameters():
633
+ param.requires_grad = False
634
+
635
+ if self.frozen_stages >= 1 and self.ape:
636
+ self.absolute_pos_embed.requires_grad = False
637
+
638
+ if self.frozen_stages >= 2:
639
+ self.pos_drop.eval()
640
+ for i in range(0, self.frozen_stages - 1):
641
+ m = self.layers[i]
642
+ m.eval()
643
+ for param in m.parameters():
644
+ param.requires_grad = False
645
+
646
+ def init_weights(self, pretrained=None):
647
+ """Initialize the weights in backbone.
648
+ Args:
649
+ pretrained (str, optional): Path to pre-trained weights.
650
+ Defaults to None.
651
+ """
652
+
653
+ def _init_weights(m):
654
+ if isinstance(m, nn.Linear):
655
+ trunc_normal_(m.weight, std=0.02)
656
+ if isinstance(m, nn.Linear) and m.bias is not None:
657
+ nn.init.constant_(m.bias, 0)
658
+ elif isinstance(m, nn.LayerNorm):
659
+ nn.init.constant_(m.bias, 0)
660
+ nn.init.constant_(m.weight, 1.0)
661
+
662
+
663
+ def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
664
+ model_dict = self.state_dict()
665
+ pretrained_dict = {
666
+ k: v for k, v in pretrained_dict.items()
667
+ if k in model_dict.keys()
668
+ }
669
+ need_init_state_dict = {}
670
+ for k, v in pretrained_dict.items():
671
+ need_init = (
672
+ (
673
+ k.split('.')[0] in pretrained_layers
674
+ or pretrained_layers[0] == '*'
675
+ )
676
+ and 'relative_position_index' not in k
677
+ and 'attn_mask' not in k
678
+ )
679
+
680
+ if need_init:
681
+ # if verbose:
682
+ # logger.info(f'=> init {k} from {pretrained}')
683
+
684
+ if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
685
+ relative_position_bias_table_pretrained = v
686
+ relative_position_bias_table_current = model_dict[k]
687
+ L1, nH1 = relative_position_bias_table_pretrained.size()
688
+ L2, nH2 = relative_position_bias_table_current.size()
689
+ if nH1 != nH2:
690
+ logger.info(f"Error in loading {k}, passing")
691
+ else:
692
+ if L1 != L2:
693
+ logger.info(
694
+ '=> load_pretrained: resized variant: {} to {}'
695
+ .format((L1, nH1), (L2, nH2))
696
+ )
697
+ S1 = int(L1 ** 0.5)
698
+ S2 = int(L2 ** 0.5)
699
+ relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
700
+ relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
701
+ size=(S2, S2),
702
+ mode='bicubic')
703
+ v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
704
+
705
+ if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
706
+ absolute_pos_embed_pretrained = v
707
+ absolute_pos_embed_current = model_dict[k]
708
+ _, L1, C1 = absolute_pos_embed_pretrained.size()
709
+ _, L2, C2 = absolute_pos_embed_current.size()
710
+ if C1 != C2:
711
+ logger.info(f"Error in loading {k}, passing")
712
+ else:
713
+ if L1 != L2:
714
+ logger.info(
715
+ '=> load_pretrained: resized variant: {} to {}'
716
+ .format((1, L1, C1), (1, L2, C2))
717
+ )
718
+ S1 = int(L1 ** 0.5)
719
+ S2 = int(L2 ** 0.5)
720
+ absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
721
+ absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
722
+ absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
723
+ absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
724
+ v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)
725
+
726
+ need_init_state_dict[k] = v
727
+ self.load_state_dict(need_init_state_dict, strict=False)
728
+
729
+
730
+ def forward(self, x):
731
+ """Forward function."""
732
+ x = self.patch_embed(x)
733
+
734
+ Wh, Ww = x.size(2), x.size(3)
735
+ if self.ape:
736
+ # interpolate the position embedding to the corresponding size
737
+ absolute_pos_embed = F.interpolate(
738
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
739
+ )
740
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
741
+ else:
742
+ x = x.flatten(2).transpose(1, 2)
743
+ x = self.pos_drop(x)
744
+
745
+ outs = {}
746
+ for i in range(self.num_layers):
747
+ layer = self.layers[i]
748
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
749
+
750
+ if i in self.out_indices:
751
+ norm_layer = getattr(self, f"norm{i}")
752
+ x_out = norm_layer(x_out)
753
+
754
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
755
+ outs["res{}".format(i + 2)] = out
756
+
757
+ if len(self.out_indices) == 0:
758
+ outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
759
+
760
+
761
+ return outs
762
+
763
+ def train(self, mode=True):
764
+ """Convert the model into training mode while keep layers freezed."""
765
+ super(SwinTransformer, self).train(mode)
766
+ self._freeze_stages()
767
+
768
+
769
+ class D2SwinTransformer(SwinTransformer, Backbone):
770
+ def __init__(self, cfg, pretrain_img_size, patch_size, in_chans, embed_dim,
771
+ depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale,
772
+ drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape,
773
+ patch_norm, out_indices, use_checkpoint):
774
+ super().__init__(
775
+ pretrain_img_size,
776
+ patch_size,
777
+ in_chans,
778
+ embed_dim,
779
+ depths,
780
+ num_heads,
781
+ window_size,
782
+ mlp_ratio,
783
+ qkv_bias,
784
+ qk_scale,
785
+ drop_rate,
786
+ attn_drop_rate,
787
+ drop_path_rate,
788
+ norm_layer,
789
+ ape,
790
+ patch_norm,
791
+ out_indices,
792
+ use_checkpoint=use_checkpoint,
793
+ )
794
+
795
+ self._out_features = cfg['OUT_FEATURES']
796
+
797
+ self._out_feature_strides = {
798
+ "res2": 4,
799
+ "res3": 8,
800
+ "res4": 16,
801
+ "res5": 32,
802
+ }
803
+ self._out_feature_channels = {
804
+ "res2": self.num_features[0],
805
+ "res3": self.num_features[1],
806
+ "res4": self.num_features[2],
807
+ "res5": self.num_features[3],
808
+ }
809
+
810
+ def forward(self, x):
811
+ """
812
+ Args:
813
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
814
+ Returns:
815
+ dict[str->Tensor]: names and the corresponding features
816
+ """
817
+ assert (
818
+ x.dim() == 4
819
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
820
+ outputs = {}
821
+ y = super().forward(x)
822
+ for k in y.keys():
823
+ if k in self._out_features:
824
+ outputs[k] = y[k]
825
+ return outputs
826
+
827
+ def output_shape(self):
828
+ feature_names = list(set(self._out_feature_strides.keys()) & set(self._out_features))
829
+ return {
830
+ name: ShapeSpec(
831
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
832
+ )
833
+ for name in feature_names
834
+ }
835
+
836
+ @property
837
+ def size_divisibility(self):
838
+ return 32
839
+
840
+
841
+ @register_backbone
842
+ def get_swin_backbone(cfg):
843
+ swin_cfg = cfg['MODEL']['BACKBONE']['SWIN']
844
+
845
+ pretrain_img_size = swin_cfg['PRETRAIN_IMG_SIZE']
846
+ patch_size = swin_cfg['PATCH_SIZE']
847
+ in_chans = 3
848
+ embed_dim = swin_cfg['EMBED_DIM']
849
+ depths = swin_cfg['DEPTHS']
850
+ num_heads = swin_cfg['NUM_HEADS']
851
+ window_size = swin_cfg['WINDOW_SIZE']
852
+ mlp_ratio = swin_cfg['MLP_RATIO']
853
+ qkv_bias = swin_cfg['QKV_BIAS']
854
+ qk_scale = swin_cfg['QK_SCALE']
855
+ drop_rate = swin_cfg['DROP_RATE']
856
+ attn_drop_rate = swin_cfg['ATTN_DROP_RATE']
857
+ drop_path_rate = swin_cfg['DROP_PATH_RATE']
858
+ norm_layer = nn.LayerNorm
859
+ ape = swin_cfg['APE']
860
+ patch_norm = swin_cfg['PATCH_NORM']
861
+ use_checkpoint = swin_cfg['USE_CHECKPOINT']
862
+ out_indices = swin_cfg.get('OUT_INDICES', [0,1,2,3])
863
+
864
+ swin = D2SwinTransformer(
865
+ swin_cfg,
866
+ pretrain_img_size,
867
+ patch_size,
868
+ in_chans,
869
+ embed_dim,
870
+ depths,
871
+ num_heads,
872
+ window_size,
873
+ mlp_ratio,
874
+ qkv_bias,
875
+ qk_scale,
876
+ drop_rate,
877
+ attn_drop_rate,
878
+ drop_path_rate,
879
+ norm_layer,
880
+ ape,
881
+ patch_norm,
882
+ out_indices,
883
+ use_checkpoint=use_checkpoint,
884
+ )
885
+
886
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
887
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
888
+ with PathManager.open(filename, "rb") as f:
889
+ ckpt = torch.load(f, map_location=cfg['device'])['model']
890
+ swin.load_weights(ckpt, swin_cfg.get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])
891
+
892
+ return swin
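
Finally, a hedged sketch of the SWIN config block that `get_swin_backbone` reads above. The values are Swin-T-like and illustrative, not the config shipped with X-Decoder; `device`, `VERBOSE`, `PRETRAINED`, and `PRETRAINED_LAYERS` are only consulted when LOAD_PRETRAINED is True.

# Sketch, not part of the committed files: keys consumed by get_swin_backbone /
# D2SwinTransformer above; values are illustrative only.
from xdecoder.backbone.swin import get_swin_backbone

swin_config = {
    'VERBOSE': True,
    'device': 'cpu',                     # used as map_location when loading weights
    'MODEL': {
        'BACKBONE': {
            'NAME': 'swin',              # registry key = module name (swin.py)
            'LOAD_PRETRAINED': False,
            'PRETRAINED': '',
            'SWIN': {
                'PRETRAIN_IMG_SIZE': 224,
                'PATCH_SIZE': 4,
                'EMBED_DIM': 96,
                'DEPTHS': [2, 2, 6, 2],
                'NUM_HEADS': [3, 6, 12, 24],
                'WINDOW_SIZE': 7,
                'MLP_RATIO': 4.0,
                'QKV_BIAS': True,
                'QK_SCALE': None,
                'DROP_RATE': 0.0,
                'ATTN_DROP_RATE': 0.0,
                'DROP_PATH_RATE': 0.3,
                'APE': False,
                'PATCH_NORM': True,
                'USE_CHECKPOINT': False,
                'OUT_INDICES': [0, 1, 2, 3],
                'OUT_FEATURES': ['res2', 'res3', 'res4', 'res5'],
            },
        },
    },
}

swin = get_swin_backbone(swin_config)
print(swin.size_divisibility)        # 32
print(sorted(swin.output_shape()))   # ['res2', 'res3', 'res4', 'res5']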