fffiloni committed on
Commit b7930b8
1 Parent(s): 440df0e

Upload 3 files

xdecoder/body/encoder/build.py ADDED
@@ -0,0 +1,12 @@
+ from .registry import model_entrypoints
+ from .registry import is_model
+
+ from .transformer_encoder_fpn import *
+
+ def build_encoder(config, *args, **kwargs):
+     model_name = config['MODEL']['ENCODER']['NAME']
+
+     if not is_model(model_name):
+         raise ValueError(f'Unknown model: {model_name}')
+
+     return model_entrypoints(model_name)(config, *args, **kwargs)
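
Usage note (not part of the commit): build_encoder reads only config['MODEL']['ENCODER']['NAME'] before dispatching to the matching registered entrypoint; all remaining arguments are forwarded untouched. A minimal sketch, assuming the entrypoint name taken from the module uploaded in this commit:

    # Hedged sketch: NAME must equal the module name of a registered entrypoint.
    # The entrypoint itself reads further MODEL.ENCODER / MODEL.DECODER keys
    # (see transformer_encoder_fpn.py below), so a real config needs those too.
    config = {'MODEL': {'ENCODER': {'NAME': 'transformer_encoder_fpn'}}}
    # encoder = build_encoder(config, input_shape)  # *args/**kwargs pass through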
xdecoder/body/encoder/registry.py ADDED
@@ -0,0 +1,13 @@
+ _model_entrypoints = {}
+
+ def register_encoder(fn):
+     module_name_split = fn.__module__.split('.')
+     model_name = module_name_split[-1]
+     _model_entrypoints[model_name] = fn
+     return fn
+
+ def model_entrypoints(model_name):
+     return _model_entrypoints[model_name]
+
+ def is_model(model_name):
+     return model_name in _model_entrypoints
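
For context, a sketch (not part of the commit) of how a new encoder would plug into this registry: register_encoder keys the registry on the defining module's name (fn.__module__.split('.')[-1]), so the hypothetical file name below is exactly what MODEL.ENCODER.NAME would have to be set to. Note that build.py would also need to import the new module so the decorator actually runs.

    # xdecoder/body/encoder/my_encoder.py (hypothetical)
    from .registry import register_encoder

    @register_encoder
    def get_my_encoder(config, input_shape):
        # Importing this module executes the decorator, which registers this
        # function under the key 'my_encoder' (the module's file name).
        ...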
xdecoder/body/encoder/transformer_encoder_fpn.py ADDED
@@ -0,0 +1,324 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import logging
+ import numpy as np
+ from typing import Callable, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+ from torch.cuda.amp import autocast
+
+ import fvcore.nn.weight_init as weight_init
+ from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm
+
+ from .registry import register_encoder
+ from ..transformer_blocks import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn
+ from ...modules import PositionEmbeddingSine
+ from ...utils import configurable
+
+ # from ..layers import Conv2d, DeformConv, ShapeSpec, get_norm
+
+ # This is a modified FPN decoder.
+ class BasePixelDecoder(nn.Module):
+     def __init__(
+         self,
+         input_shape: Dict[str, ShapeSpec],
+         *,
+         conv_dim: int,
+         mask_dim: int,
+         mask_on: bool,
+         norm: Optional[Union[str, Callable]] = None,
+     ):
+         """
+         NOTE: this interface is experimental.
+         Args:
+             input_shape: shapes (channels and stride) of the input features
+             conv_dim: number of output channels for the intermediate conv layers.
+             mask_dim: number of output channels for the final conv layer.
+             norm (str or callable): normalization for all conv layers
+         """
+         super().__init__()
+
+         input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+         self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
+         feature_channels = [v.channels for k, v in input_shape]
+
+         lateral_convs = []
+         output_convs = []
+
+         use_bias = norm == ""
+         for idx, in_channels in enumerate(feature_channels):
+             if idx == len(self.in_features) - 1:
+                 output_norm = get_norm(norm, conv_dim)
+                 output_conv = Conv2d(
+                     in_channels,
+                     conv_dim,
+                     kernel_size=3,
+                     stride=1,
+                     padding=1,
+                     bias=use_bias,
+                     norm=output_norm,
+                     activation=F.relu,
+                 )
+                 weight_init.c2_xavier_fill(output_conv)
+                 self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                 lateral_convs.append(None)
+                 output_convs.append(output_conv)
+             else:
+                 lateral_norm = get_norm(norm, conv_dim)
+                 output_norm = get_norm(norm, conv_dim)
+
+                 lateral_conv = Conv2d(
+                     in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
+                 )
+                 output_conv = Conv2d(
+                     conv_dim,
+                     conv_dim,
+                     kernel_size=3,
+                     stride=1,
+                     padding=1,
+                     bias=use_bias,
+                     norm=output_norm,
+                     activation=F.relu,
+                 )
+                 weight_init.c2_xavier_fill(lateral_conv)
+                 weight_init.c2_xavier_fill(output_conv)
+                 self.add_module("adapter_{}".format(idx + 1), lateral_conv)
+                 self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                 lateral_convs.append(lateral_conv)
+                 output_convs.append(output_conv)
+         # Place convs into top-down order (from low to high resolution)
+         # to make the top-down computation in forward clearer.
+         self.lateral_convs = lateral_convs[::-1]
+         self.output_convs = output_convs[::-1]
+
+         self.mask_on = mask_on
+         if self.mask_on:
+             self.mask_dim = mask_dim
+             self.mask_features = Conv2d(
+                 conv_dim,
+                 mask_dim,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+             )
+             weight_init.c2_xavier_fill(self.mask_features)
+
+         self.maskformer_num_feature_levels = 3  # always use 3 scales
+
+     @classmethod
+     def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+         enc_cfg = cfg['MODEL']['ENCODER']
+         ret = {}
+         ret["input_shape"] = {
+             k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
+         }
+         ret["conv_dim"] = enc_cfg['CONVS_DIM']
+         ret["mask_dim"] = enc_cfg['MASK_DIM']
+         ret["norm"] = enc_cfg['NORM']
+         return ret
+
+     def forward_features(self, features):
+         multi_scale_features = []
+         num_cur_levels = 0
+         # Reverse feature maps into top-down order (from low to high resolution)
+         for idx, f in enumerate(self.in_features[::-1]):
+             x = features[f]
+             lateral_conv = self.lateral_convs[idx]
+             output_conv = self.output_convs[idx]
+             if lateral_conv is None:
+                 y = output_conv(x)
+             else:
+                 cur_fpn = lateral_conv(x)
+                 # Following FPN implementation, we use nearest upsampling here
+                 y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+                 y = output_conv(y)
+             if num_cur_levels < self.maskformer_num_feature_levels:
+                 multi_scale_features.append(y)
+                 num_cur_levels += 1
+
+         mask_features = self.mask_features(y) if self.mask_on else None
+         return mask_features, None, multi_scale_features
+
+     def forward(self, features, targets=None):
+         logger = logging.getLogger(__name__)
+         logger.warning("Calling forward() may cause unpredictable behavior of PixelDecoder module.")
+         return self.forward_features(features)
+
+
+ class TransformerEncoderOnly(nn.Module):
+     def __init__(
+         self,
+         d_model=512,
+         nhead=8,
+         num_encoder_layers=6,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation="relu",
+         normalize_before=False,
+     ):
+         super().__init__()
+
+         encoder_layer = TransformerEncoderLayer(
+             d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+         )
+         encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+         self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+         self._reset_parameters()
+
+         self.d_model = d_model
+         self.nhead = nhead
+
+     def _reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, src, mask, pos_embed):
+         # flatten NxCxHxW to HWxNxC
+         bs, c, h, w = src.shape
+         src = src.flatten(2).permute(2, 0, 1)
+         pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+         if mask is not None:
+             mask = mask.flatten(1)
+
+         memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+         return memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+ # This is a modified FPN decoder with an extra Transformer encoder that processes the lowest-resolution feature map.
+ class TransformerEncoderPixelDecoder(BasePixelDecoder):
+     @configurable
+     def __init__(
+         self,
+         input_shape: Dict[str, ShapeSpec],
+         *,
+         transformer_dropout: float,
+         transformer_nheads: int,
+         transformer_dim_feedforward: int,
+         transformer_enc_layers: int,
+         transformer_pre_norm: bool,
+         conv_dim: int,
+         mask_dim: int,
+         mask_on: bool,
+         norm: Optional[Union[str, Callable]] = None,
+     ):
+         """
+         NOTE: this interface is experimental.
+         Args:
+             input_shape: shapes (channels and stride) of the input features
+             transformer_dropout: dropout probability in transformer
+             transformer_nheads: number of heads in transformer
+             transformer_dim_feedforward: dimension of feedforward network
+             transformer_enc_layers: number of transformer encoder layers
+             transformer_pre_norm: whether to use pre-layernorm or not
+             conv_dim: number of output channels for the intermediate conv layers.
+             mask_dim: number of output channels for the final conv layer.
+             norm (str or callable): normalization for all conv layers
+         """
+         super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on)
+
+         input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+         self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
+         feature_strides = [v.stride for k, v in input_shape]
+         feature_channels = [v.channels for k, v in input_shape]
+
+         in_channels = feature_channels[len(self.in_features) - 1]
+         self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
+         weight_init.c2_xavier_fill(self.input_proj)
+         self.transformer = TransformerEncoderOnly(
+             d_model=conv_dim,
+             dropout=transformer_dropout,
+             nhead=transformer_nheads,
+             dim_feedforward=transformer_dim_feedforward,
+             num_encoder_layers=transformer_enc_layers,
+             normalize_before=transformer_pre_norm,
+         )
+         N_steps = conv_dim // 2
+         self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+         # update layer
+         use_bias = norm == ""
+         output_norm = get_norm(norm, conv_dim)
+         output_conv = Conv2d(
+             conv_dim,
+             conv_dim,
+             kernel_size=3,
+             stride=1,
+             padding=1,
+             bias=use_bias,
+             norm=output_norm,
+             activation=F.relu,
+         )
+         weight_init.c2_xavier_fill(output_conv)
+         delattr(self, "layer_{}".format(len(self.in_features)))
+         self.add_module("layer_{}".format(len(self.in_features)), output_conv)
+         self.output_convs[0] = output_conv
+
+     @classmethod
+     def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+         enc_cfg = cfg['MODEL']['ENCODER']
+         dec_cfg = cfg['MODEL']['DECODER']
+
+         ret = super().from_config(cfg, input_shape)
+         ret["transformer_dropout"] = dec_cfg['DROPOUT']
+         ret["transformer_nheads"] = dec_cfg['NHEADS']
+         ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
+         ret["transformer_enc_layers"] = enc_cfg['TRANSFORMER_ENC_LAYERS']  # a separate config
+         ret["transformer_pre_norm"] = dec_cfg['PRE_NORM']
+
+         ret['mask_on'] = cfg['MODEL']['DECODER']['MASK']
+         return ret
+
+     def forward_features(self, features):
+         multi_scale_features = []
+         num_cur_levels = 0
+
+         # Reverse feature maps into top-down order (from low to high resolution)
+         for idx, f in enumerate(self.in_features[::-1]):
+             x = features[f]
+             lateral_conv = self.lateral_convs[idx]
+             output_conv = self.output_convs[idx]
+             if lateral_conv is None:
+                 transformer = self.input_proj(x)
+                 pos = self.pe_layer(x)
+                 transformer = self.transformer(transformer, None, pos)
+                 y = output_conv(transformer)
+                 # save intermediate feature as input to Transformer decoder
+                 transformer_encoder_features = transformer
+             else:
+                 cur_fpn = lateral_conv(x)
+                 # Following FPN implementation, we use nearest upsampling here
+                 y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+                 y = output_conv(y)
+             if num_cur_levels < self.maskformer_num_feature_levels:
+                 multi_scale_features.append(y)
+                 num_cur_levels += 1
+
+         mask_features = self.mask_features(y) if self.mask_on else None
+         return mask_features, transformer_encoder_features, multi_scale_features
+
+     def forward(self, features, targets=None):
+         logger = logging.getLogger(__name__)
+         logger.warning("Calling forward() may cause unpredictable behavior of PixelDecoder module.")
+         return self.forward_features(features)
+
+
+
+ @register_encoder
+ def get_transformer_encoder_fpn(cfg, input_shape):
+     """
+     Build a pixel decoder from `cfg['MODEL']['ENCODER']['NAME']`.
+     """
+     model = TransformerEncoderPixelDecoder(cfg, input_shape)
+     forward_features = getattr(model, "forward_features", None)
+     if not callable(forward_features):
+         raise ValueError(
+             "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
+             f"Please implement forward_features for {type(model).__name__} to only return mask features."
+         )
+     return model
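
For reference, a hedged sketch of invoking this entrypoint end to end. The config key layout mirrors from_config above; every concrete value (dims, norm, feature names, strides) is an illustrative assumption, not something fixed by this commit:

    import torch
    from detectron2.layers import ShapeSpec

    # Key layout follows from_config above; the values are illustrative only.
    cfg = {
        'MODEL': {
            'ENCODER': {
                'NAME': 'transformer_encoder_fpn',
                'IN_FEATURES': ['res2', 'res3', 'res4', 'res5'],
                'CONVS_DIM': 256,
                'MASK_DIM': 256,
                'NORM': 'GN',
                'TRANSFORMER_ENC_LAYERS': 6,
            },
            'DECODER': {
                'DROPOUT': 0.1,
                'NHEADS': 8,
                'DIM_FEEDFORWARD': 2048,
                'PRE_NORM': False,
                'MASK': True,
            },
        }
    }

    # Backbone output shapes; channels/strides here are typical ResNet-50 values.
    input_shape = {
        'res2': ShapeSpec(channels=256, stride=4),
        'res3': ShapeSpec(channels=512, stride=8),
        'res4': ShapeSpec(channels=1024, stride=16),
        'res5': ShapeSpec(channels=2048, stride=32),
    }

    encoder = get_transformer_encoder_fpn(cfg, input_shape)
    # Dummy backbone features for a 512x512 input, just to exercise the decoder.
    features = {
        k: torch.zeros(1, s.channels, 512 // s.stride, 512 // s.stride)
        for k, s in input_shape.items()
    }
    mask_features, transformer_features, multi_scale_features = encoder.forward_features(features)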