Spaces:
Paused
Paused
Upload 3 files
Browse files
xdecoder/body/encoder/build.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .registry import model_entrypoints
|
2 |
+
from .registry import is_model
|
3 |
+
|
4 |
+
from .transformer_encoder_fpn import *
|
5 |
+
|
6 |
+
def build_encoder(config, *args, **kwargs):
    """Instantiate the encoder selected by the configuration.

    Looks up the entrypoint registered under
    ``config['MODEL']['ENCODER']['NAME']`` and calls it with the full config.

    Args:
        config: nested configuration dict containing MODEL/ENCODER/NAME.
        *args, **kwargs: forwarded verbatim to the encoder entrypoint.

    Returns:
        The encoder module produced by the registered entrypoint.

    Raises:
        ValueError: if no encoder is registered under the configured name.
    """
    model_name = config['MODEL']['ENCODER']['NAME']

    if not is_model(model_name):
        # Bug fix: error message previously read "Unkown".
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, *args, **kwargs)
|
xdecoder/body/encoder/registry.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Registry mapping an encoder's module basename to its factory function.
_model_entrypoints = {}


def register_encoder(fn):
    """Decorator: register *fn* under the last component of its module path."""
    key = fn.__module__.rsplit('.', 1)[-1]
    _model_entrypoints[key] = fn
    return fn


def model_entrypoints(model_name):
    """Return the factory registered for *model_name* (KeyError if absent)."""
    return _model_entrypoints[model_name]


def is_model(model_name):
    """Return True when an encoder factory is registered for *model_name*."""
    return model_name in _model_entrypoints
|
xdecoder/body/encoder/transformer_encoder_fpn.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import logging
|
3 |
+
import numpy as np
|
4 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
|
10 |
+
from torch.cuda.amp import autocast
|
11 |
+
|
12 |
+
import fvcore.nn.weight_init as weight_init
|
13 |
+
from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm
|
14 |
+
|
15 |
+
from .registry import register_encoder
|
16 |
+
from ..transformer_blocks import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn
|
17 |
+
from ...modules import PositionEmbeddingSine
|
18 |
+
from ...utils import configurable
|
19 |
+
|
20 |
+
# from ..layers import Conv2d, DeformConv, ShapeSpec, get_norm
|
21 |
+
|
22 |
+
# This is a modified FPN decoder.
|
23 |
+
# This FPN-style pixel decoder fuses backbone features top-down ("res5" -> "res2")
# and optionally projects the finest map into mask features.
class BasePixelDecoder(nn.Module):
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        conv_dim: int,
        mask_dim: int,
        mask_on: bool,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            conv_dims: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            mask_on: whether to build the final mask-feature projection conv.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__()

        # Sort features by stride so in_features goes from high to low resolution.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_channels = [v.channels for k, v in input_shape]

        lateral_convs = []
        output_convs = []

        # Conv bias is only needed when no normalization layer follows.
        use_bias = norm == ""
        for idx, in_channels in enumerate(feature_channels):
            if idx == len(self.in_features) - 1:
                # Deepest level: no lateral conv, the backbone feature is used directly.
                output_norm = get_norm(norm, conv_dim)
                output_conv = Conv2d(
                    in_channels,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(output_conv)
                # Registered as "layer_<k>" so state dicts keep per-level names.
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(None)
                output_convs.append(output_conv)
            else:
                lateral_norm = get_norm(norm, conv_dim)
                output_norm = get_norm(norm, conv_dim)

                # 1x1 lateral conv aligns channel count before the top-down sum.
                lateral_conv = Conv2d(
                    in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
                )
                output_conv = Conv2d(
                    conv_dim,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(lateral_conv)
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(lateral_conv)
                output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

        self.mask_on = mask_on
        if self.mask_on:
            self.mask_dim = mask_dim
            # Final projection from FPN output to mask-embedding channels.
            self.mask_features = Conv2d(
                conv_dim,
                mask_dim,
                kernel_size=3,
                stride=1,
                padding=1,
            )
            weight_init.c2_xavier_fill(self.mask_features)

        self.maskformer_num_feature_levels = 3  # always use 3 scales

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Translate the nested config dict into __init__ keyword arguments."""
        enc_cfg = cfg['MODEL']['ENCODER']
        ret = {}
        # Keep only the backbone features the encoder is configured to consume.
        ret["input_shape"] = {
            k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
        }
        ret["conv_dim"] = enc_cfg['CONVS_DIM']
        ret["mask_dim"] = enc_cfg['MASK_DIM']
        ret["norm"] = enc_cfg['NORM']
        return ret

    def forward_features(self, features):
        """Fuse backbone features top-down.

        Returns a tuple ``(mask_features, None, multi_scale_features)`` where
        mask_features is None when mask_on is False, and multi_scale_features
        holds the first `maskformer_num_feature_levels` fused maps
        (coarsest first).
        """
        multi_scale_features = []
        num_cur_levels = 0
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # Deepest level (first iteration by construction), so `y` is
                # always initialized before the else-branch reads it.
                y = output_conv(x)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
            if num_cur_levels < self.maskformer_num_feature_levels:
                multi_scale_features.append(y)
                num_cur_levels += 1

        mask_features = self.mask_features(y) if self.mask_on else None
        return mask_features, None, multi_scale_features

    def forward(self, features, targets=None):
        # Kept for nn.Module compatibility; callers should use forward_features.
        logger = logging.getLogger(__name__)
        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
        return self.forward_features(features)
|
150 |
+
|
151 |
+
|
152 |
+
class TransformerEncoderOnly(nn.Module):
    """A stack of Transformer encoder layers applied to flattened 2D feature maps.

    The forward pass accepts NxCxHxW tensors, runs the encoder in HWxNxC
    layout, and restores the NxCxHxW layout on output.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()

        layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        # Pre-norm variants need a final LayerNorm; post-norm does not.
        final_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(layer, num_encoder_layers, final_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        # Xavier-initialize weight matrices; 1-D params (biases, norms) keep defaults.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, src, mask, pos_embed):
        """Encode *src* with positional embeddings and an optional padding mask."""
        batch, channels, height, width = src.shape
        # flatten NxCxHxW to HWxNxC
        flat_src = src.flatten(2).permute(2, 0, 1)
        flat_pos = pos_embed.flatten(2).permute(2, 0, 1)
        flat_mask = mask.flatten(1) if mask is not None else None

        memory = self.encoder(flat_src, src_key_padding_mask=flat_mask, pos=flat_pos)
        return memory.permute(1, 2, 0).view(batch, channels, height, width)
|
191 |
+
|
192 |
+
|
193 |
+
# This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map.
|
194 |
+
class TransformerEncoderPixelDecoder(BasePixelDecoder):
    # FPN pixel decoder whose deepest (lowest-resolution) feature map is first
    # refined by a Transformer encoder before entering the top-down pathway.
    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        transformer_dropout: float,
        transformer_nheads: int,
        transformer_dim_feedforward: int,
        transformer_enc_layers: int,
        transformer_pre_norm: bool,
        conv_dim: int,
        mask_dim: int,
        mask_on: bool,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            transformer_dropout: dropout probability in transformer
            transformer_nheads: number of heads in transformer
            transformer_dim_feedforward: dimension of feedforward network
            transformer_enc_layers: number of transformer encoder layers
            transformer_pre_norm: whether to use pre-layernorm or not
            conv_dims: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on)

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        # 1x1 projection of the deepest feature map into the transformer width.
        in_channels = feature_channels[len(self.in_features) - 1]
        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
        weight_init.c2_xavier_fill(self.input_proj)
        self.transformer = TransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            normalize_before=transformer_pre_norm,
        )
        # Sine position embedding; half the channels encode x, half encode y.
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # update layer
        use_bias = norm == ""
        output_norm = get_norm(norm, conv_dim)
        output_conv = Conv2d(
            conv_dim,
            conv_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
            norm=output_norm,
            activation=F.relu,
        )
        weight_init.c2_xavier_fill(output_conv)
        # Replace the deepest output conv built by BasePixelDecoder: its input is
        # now the transformer output (conv_dim channels), not the raw backbone
        # feature. delattr first so add_module can re-register the same name.
        delattr(self, "layer_{}".format(len(self.in_features)))
        self.add_module("layer_{}".format(len(self.in_features)), output_conv)
        self.output_convs[0] = output_conv

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Extend BasePixelDecoder's config mapping with transformer settings."""
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        ret = super().from_config(cfg, input_shape)
        ret["transformer_dropout"] = dec_cfg['DROPOUT']
        ret["transformer_nheads"] = dec_cfg['NHEADS']
        ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
        ret["transformer_enc_layers"] = enc_cfg['TRANSFORMER_ENC_LAYERS']  # a separate config
        ret["transformer_pre_norm"] = dec_cfg['PRE_NORM']

        ret['mask_on'] = cfg['MODEL']['DECODER']['MASK']
        return ret

    def forward_features(self, features):
        """Like BasePixelDecoder.forward_features, but also returns the
        transformer-refined deepest feature map as the second tuple element."""
        multi_scale_features = []
        num_cur_levels = 0

        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # Deepest level (first iteration by construction): refine with the
                # transformer encoder, so `y` and `transformer_encoder_features`
                # are always bound before later iterations read them.
                transformer = self.input_proj(x)
                pos = self.pe_layer(x)
                transformer = self.transformer(transformer, None, pos)
                y = output_conv(transformer)
                # save intermediate feature as input to Transformer decoder
                transformer_encoder_features = transformer
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
            if num_cur_levels < self.maskformer_num_feature_levels:
                multi_scale_features.append(y)
                num_cur_levels += 1

        mask_features = self.mask_features(y) if self.mask_on else None
        return mask_features, transformer_encoder_features, multi_scale_features

    def forward(self, features, targets=None):
        # Kept for nn.Module compatibility; callers should use forward_features.
        logger = logging.getLogger(__name__)
        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
        return self.forward_features(features)
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
@register_encoder
def get_transformer_encoder_fpn(cfg, input_shape):
    """Build a TransformerEncoderPixelDecoder pixel decoder.

    Args:
        cfg: nested configuration dict (see from_config of the decoder).
        input_shape: mapping from feature name to ShapeSpec.

    Returns:
        The constructed TransformerEncoderPixelDecoder.

    Raises:
        ValueError: if the constructed model lacks a callable
            ``forward_features`` method.
    """
    model = TransformerEncoderPixelDecoder(cfg, input_shape)
    forward_features = getattr(model, "forward_features", None)
    if not callable(forward_features):
        # Bug fix: the original message interpolated an undefined `name`,
        # which would raise NameError instead of the intended ValueError.
        raise ValueError(
            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
            f"Please implement forward_features for {type(model).__name__} to only return mask features."
        )
    return model
|