fffiloni committed on
Commit bb3a5a1
1 Parent(s): 09789c8

Upload 4 files

xdecoder/body/build.py ADDED
@@ -0,0 +1,13 @@
+ from .registry import model_entrypoints
+ from .registry import is_model
+
+ from .xdecoder_head import *
+
+
+ def build_xdecoder_head(config, *args, **kwargs):
+     model_name = config['MODEL']['HEAD']
+     if not is_model(model_name):
+         raise ValueError(f'Unknown model: {model_name}')
+
+     body = model_entrypoints(model_name)(config, *args, **kwargs)
+     return body
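
A minimal call-site sketch (not part of the commit; `input_shape`, `lang_encoder`, and `extra` stand in for whatever the caller supplies): `build_xdecoder_head` only needs `config['MODEL']['HEAD']` to name a body registered in the registry below, e.g. the `xdecoder_head` module imported above.

    config = {'MODEL': {'HEAD': 'xdecoder_head'}}   # plus the ENCODER/DECODER sub-configs used in xdecoder_head.py below
    head = build_xdecoder_head(config, input_shape, lang_encoder, extra={})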
xdecoder/body/registry.py ADDED
@@ -0,0 +1,14 @@
+ _model_entrypoints = {}
+
+
+ def register_body(fn):
+     module_name_split = fn.__module__.split('.')
+     model_name = module_name_split[-1]
+     _model_entrypoints[model_name] = fn
+     return fn
+
+ def model_entrypoints(model_name):
+     return _model_entrypoints[model_name]
+
+ def is_model(model_name):
+     return model_name in _model_entrypoints
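
How the registry is meant to be consumed, as a hedged sketch with hypothetical names (`my_head` and `MyHead` are placeholders, not files in this commit): the registry key is the defining module's basename, so `config['MODEL']['HEAD']` must match the filename of the module that registers the factory.

    # xdecoder/body/my_head.py (hypothetical)
    from .registry import register_body

    @register_body                            # stored under the key 'my_head'
    def get_my_head(cfg, *args, **kwargs):
        return MyHead(cfg, *args, **kwargs)   # MyHead: a placeholder nn.Module

    # elsewhere:
    #   is_model('my_head')                  -> True
    #   model_entrypoints('my_head')(cfg)    -> a MyHead instance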
xdecoder/body/transformer_blocks.py ADDED
@@ -0,0 +1,370 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
+ """
+ Transformer class.
+
+ Copy-paste from torch.nn.Transformer with modifications:
+     * positional encodings are passed in MHattention
+     * extra LN at the end of encoder is removed
+     * decoder returns a stack of activations from all decoding layers
+ """
+ import copy
+ from typing import List, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         d_model=512,
+         nhead=8,
+         num_encoder_layers=6,
+         num_decoder_layers=6,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation="relu",
+         normalize_before=False,
+         return_intermediate_dec=False,
+     ):
+         super().__init__()
+
+         encoder_layer = TransformerEncoderLayer(
+             d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+         )
+         encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+         self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+         decoder_layer = TransformerDecoderLayer(
+             d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+         )
+         decoder_norm = nn.LayerNorm(d_model)
+         self.decoder = TransformerDecoder(
+             decoder_layer,
+             num_decoder_layers,
+             decoder_norm,
+             return_intermediate=return_intermediate_dec,
+         )
+
+         self._reset_parameters()
+
+         self.d_model = d_model
+         self.nhead = nhead
+
+     def _reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, src, mask, query_embed, pos_embed):
+         # flatten NxCxHxW to HWxNxC
+         bs, c, h, w = src.shape
+         src = src.flatten(2).permute(2, 0, 1)
+         pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+         query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+         if mask is not None:
+             mask = mask.flatten(1)
+
+         tgt = torch.zeros_like(query_embed)
+         memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+         hs = self.decoder(
+             tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
+         )
+         return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+ class TransformerEncoder(nn.Module):
+     def __init__(self, encoder_layer, num_layers, norm=None):
+         super().__init__()
+         self.layers = _get_clones(encoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+
+     def forward(
+         self,
+         src,
+         mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         output = src
+
+         for layer in self.layers:
+             output = layer(
+                 output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
+             )
+
+         if self.norm is not None:
+             output = self.norm(output)
+
+         return output
+
+
+ class TransformerDecoder(nn.Module):
+     def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+         super().__init__()
+         self.layers = _get_clones(decoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+         self.return_intermediate = return_intermediate
+
+     def forward(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         output = tgt
+
+         intermediate = []
+
+         for layer in self.layers:
+             output = layer(
+                 output,
+                 memory,
+                 tgt_mask=tgt_mask,
+                 memory_mask=memory_mask,
+                 tgt_key_padding_mask=tgt_key_padding_mask,
+                 memory_key_padding_mask=memory_key_padding_mask,
+                 pos=pos,
+                 query_pos=query_pos,
+             )
+             if self.return_intermediate:
+                 intermediate.append(self.norm(output))
+
+         if self.norm is not None:
+             output = self.norm(output)
+             if self.return_intermediate:
+                 intermediate.pop()
+                 intermediate.append(output)
+
+         if self.return_intermediate:
+             return torch.stack(intermediate)
+
+         return output.unsqueeze(0)
+
+
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         nhead,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation="relu",
+         normalize_before=False,
+     ):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         q = k = self.with_pos_embed(src, pos)
+         src2 = self.self_attn(
+             q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+         )[0]
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+         src = src + self.dropout2(src2)
+         src = self.norm2(src)
+         return src
+
+     def forward_pre(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         src2 = self.norm1(src)
+         q = k = self.with_pos_embed(src2, pos)
+         src2 = self.self_attn(
+             q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+         )[0]
+         src = src + self.dropout1(src2)
+         src2 = self.norm2(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+         src = src + self.dropout2(src2)
+         return src
+
+     def forward(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         if self.normalize_before:
+             return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+         return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+ class TransformerDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         nhead,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation="relu",
+         normalize_before=False,
+     ):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         q = k = self.with_pos_embed(tgt, query_pos)
+         tgt2 = self.self_attn(
+             q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+         )[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+         tgt2 = self.multihead_attn(
+             query=self.with_pos_embed(tgt, query_pos),
+             key=self.with_pos_embed(memory, pos),
+             value=memory,
+             attn_mask=memory_mask,
+             key_padding_mask=memory_key_padding_mask,
+         )[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt = self.norm2(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt
+
+     def forward_pre(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         tgt2 = self.norm1(tgt)
+         q = k = self.with_pos_embed(tgt2, query_pos)
+         tgt2 = self.self_attn(
+             q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+         )[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt2 = self.norm2(tgt)
+         tgt2 = self.multihead_attn(
+             query=self.with_pos_embed(tgt2, query_pos),
+             key=self.with_pos_embed(memory, pos),
+             value=memory,
+             attn_mask=memory_mask,
+             key_padding_mask=memory_key_padding_mask,
+         )[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt2 = self.norm3(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+         tgt = tgt + self.dropout3(tgt2)
+         return tgt
+
+     def forward(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         if self.normalize_before:
+             return self.forward_pre(
+                 tgt,
+                 memory,
+                 tgt_mask,
+                 memory_mask,
+                 tgt_key_padding_mask,
+                 memory_key_padding_mask,
+                 pos,
+                 query_pos,
+             )
+         return self.forward_post(
+             tgt,
+             memory,
+             tgt_mask,
+             memory_mask,
+             tgt_key_padding_mask,
+             memory_key_padding_mask,
+             pos,
+             query_pos,
+         )
+
+
+ def _get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+ def _get_activation_fn(activation):
+     """Return an activation function given a string"""
+     if activation == "relu":
+         return F.relu
+     if activation == "gelu":
+         return F.gelu
+     if activation == "glu":
+         return F.glu
+     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
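
A quick shape check for the DETR-style `Transformer` above; a sketch, not part of the commit, assuming the file is importable as `xdecoder.body.transformer_blocks`:

    import torch
    from xdecoder.body.transformer_blocks import Transformer

    model = Transformer(d_model=256, nhead=8, num_encoder_layers=2,
                        num_decoder_layers=2, return_intermediate_dec=True)
    bs, c, h, w = 2, 256, 16, 16
    src = torch.randn(bs, c, h, w)                  # backbone features, NxCxHxW
    pos_embed = torch.randn(bs, c, h, w)            # positional encodings, same shape
    mask = torch.zeros(bs, h, w, dtype=torch.bool)  # key padding mask (True = padded)
    query_embed = torch.randn(100, c)               # 100 learned query embeddings

    hs, memory = model(src, mask, query_embed, pos_embed)
    print(hs.shape)      # (num_decoder_layers, bs, num_queries, d_model) = (2, 2, 100, 256)
    print(memory.shape)  # (bs, c, h, w) = (2, 256, 16, 16)

With `return_intermediate_dec=True` the decoder stacks the per-layer outputs, which is what downstream auxiliary losses rely on.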
xdecoder/body/xdecoder_head.py ADDED
@@ -0,0 +1,123 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # --------------------------------------------------------
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+ # Copyright (c) 2022 Microsoft
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Jianwei Yang (jianwyan@microsoft.com), Xueyan Zou (xueyan@cs.wisc.edu)
+ # --------------------------------------------------------
+
+ from typing import Dict
+
+ from torch import nn
+
+ from detectron2.layers import ShapeSpec
+
+ from .registry import register_body
+ from .encoder import build_encoder
+ from .decoder import build_decoder
+ from ..utils import configurable
+
+
+ class XDecoderHead(nn.Module):
+
+     @configurable
+     def __init__(
+         self,
+         input_shape: Dict[str, ShapeSpec],
+         *,
+         num_classes: int,
+         pixel_decoder: nn.Module,
+         loss_weight: float = 1.0,
+         ignore_value: int = -1,
+         # extra parameters
+         transformer_predictor: nn.Module,
+         transformer_in_feature: str,
+     ):
+         """
+         NOTE: this interface is experimental.
+         Args:
+             input_shape: shapes (channels and stride) of the input features
+             num_classes: number of classes to predict
+             pixel_decoder: the pixel decoder module
+             loss_weight: loss weight
+             ignore_value: category id to be ignored during training.
+             transformer_predictor: the transformer decoder that makes prediction
+             transformer_in_feature: input feature name to the transformer_predictor
+         """
+         super().__init__()
+
+         input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+         self.in_features = [k for k, v in input_shape]
+         feature_strides = [v.stride for k, v in input_shape]
+         feature_channels = [v.channels for k, v in input_shape]
+
+         self.ignore_value = ignore_value
+         self.common_stride = 4
+         self.loss_weight = loss_weight
+
+         self.pixel_decoder = pixel_decoder
+         self.predictor = transformer_predictor
+         self.transformer_in_feature = transformer_in_feature
+
+         self.num_classes = num_classes
+
+     @classmethod
+     def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):
+
+         in_features_type = cfg['MODEL']['DECODER']['TRANSFORMER_IN_FEATURE']
+         enc_cfg = cfg['MODEL']['ENCODER']
+         dec_cfg = cfg['MODEL']['DECODER']
+
+         # figure out in_channels to transformer predictor
+         if in_features_type == "transformer_encoder":
+             transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
+         elif in_features_type == "pixel_embedding":
+             transformer_predictor_in_channels = enc_cfg['MASK_DIM']
+         elif in_features_type == "multi_scale_pixel_decoder":  # for maskformer2
+             transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
+         else:
+             transformer_predictor_in_channels = input_shape[dec_cfg['TRANSFORMER_IN_FEATURE']].channels
+
+         return {
+             "input_shape": {
+                 k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
+             },
+             "ignore_value": enc_cfg['IGNORE_VALUE'],
+             "num_classes": enc_cfg.get('NUM_CLASSES', None),
+             "pixel_decoder": build_encoder(cfg, input_shape),
+             "loss_weight": enc_cfg['LOSS_WEIGHT'],
+             "transformer_in_feature": dec_cfg['TRANSFORMER_IN_FEATURE'],
+             "transformer_predictor": build_decoder(
+                 cfg,
+                 transformer_predictor_in_channels,
+                 lang_encoder,
+                 mask_classification=True,
+                 extra=extra,
+             ),
+         }
+
+     def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
+         return self.layers(features, mask, target_queries, target_vlp, task, extra)
+
+     def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
+         mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)
+
+         if self.transformer_in_feature == "multi_scale_pixel_decoder":
+             predictions = self.predictor(multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra)
+         else:
+             if self.transformer_in_feature == "transformer_encoder":
+                 assert (
+                     transformer_encoder_features is not None
+                 ), "Please use the TransformerEncoderPixelDecoder."
+                 predictions = self.predictor(transformer_encoder_features, mask_features, mask)
+             elif self.transformer_in_feature == "pixel_embedding":
+                 predictions = self.predictor(mask_features, mask_features, mask)
+             else:
+                 predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
+         return predictions
+
+
+ @register_body
+ def get_xdecoder_head(cfg, input_shape, lang_encoder, extra):
+     return XDecoderHead(cfg, input_shape, lang_encoder, extra)
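
For reference, a sketch of the config keys read by `from_config` above and by `build_xdecoder_head`; the values are illustrative placeholders, not the repository's defaults:

    cfg = {
        'MODEL': {
            'HEAD': 'xdecoder_head',   # must name a module registered via @register_body
            'ENCODER': {               # also consumed by build_encoder
                'IN_FEATURES': ['res2', 'res3', 'res4', 'res5'],
                'IGNORE_VALUE': 255,
                'NUM_CLASSES': 133,
                'LOSS_WEIGHT': 1.0,
                'CONVS_DIM': 512,
                'MASK_DIM': 512,
            },
            'DECODER': {               # also consumed by build_decoder
                'TRANSFORMER_IN_FEATURE': 'multi_scale_pixel_decoder',
            },
        },
    }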