xichen98cn's picture
Upload 135 files
3dac99f verified
raw
history blame
18.4 kB
import numpy as np
from typing import Callable, Dict, List, Optional, Union
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import normal_
from torch.cuda.amp import autocast
from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from ..transformer_decoder.position_encoding import PositionEmbeddingSine
from .ops.modules import MSDeformAttn
import copy
from timm.models.layers import trunc_normal_
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
def build_pixel_decoder(cfg, input_shape):
"""
Build a pixel decoder from `cfg.MODEL.ONE_FORMER.PIXEL_DECODER_NAME`.
"""
name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
forward_features = getattr(model, "forward_features", None)
if not callable(forward_features):
raise ValueError(
"Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
f"Please implement forward_features for {name} to only return mask features."
)
return model
# MSDeformAttn Transformer encoder in deformable detr
class MSDeformAttnTransformerEncoderOnly(nn.Module):
def __init__(self, d_model=256, nhead=8,
num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
activation="relu",
num_feature_levels=4, enc_n_points=4,
):
super().__init__()
self.d_model = d_model
self.nhead = nhead
encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,
dropout, activation,
num_feature_levels, nhead, enc_n_points)
self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers)
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
normal_(self.level_embed)
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def forward(self, srcs, pos_embeds):
masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# encoder
memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
return memory, spatial_shapes, level_start_index
class MSDeformAttnTransformerEncoderLayer(nn.Module):
def __init__(self,
d_model=256, d_ffn=1024,
dropout=0.1, activation="relu",
n_levels=4, n_heads=8, n_points=4):
super().__init__()
# self attention
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
# self attention
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class Norm2d(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.ln = nn.LayerNorm(embed_dim, eps=1e-6)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
x = self.ln(x)
x = x.permute(0, 3, 1, 2).contiguous()
return x
class MSDeformAttnTransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
output = src
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
for _, layer in enumerate(self.layers):
output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
return output
@SEM_SEG_HEADS_REGISTRY.register()
class MSDeformAttnPixelDecoder(nn.Module):
@configurable
def __init__(
self,
input_shape: Dict[str, ShapeSpec],
*,
transformer_dropout: float,
transformer_nheads: int,
transformer_dim_feedforward: int,
transformer_enc_layers: int,
conv_dim: int,
mask_dim: int,
norm: Optional[Union[str, Callable]] = None,
# deformable transformer encoder args
transformer_in_features: List[str],
common_stride: int,
):
"""
NOTE: this interface is experimental.
Args:
input_shape: shapes (channels and stride) of the input features
transformer_dropout: dropout probability in transformer
transformer_nheads: number of heads in transformer
transformer_dim_feedforward: dimension of feedforward network
transformer_enc_layers: number of transformer encoder layers
conv_dims: number of output channels for the intermediate conv layers.
mask_dim: number of output channels for the final conv layer.
norm (str or callable): normalization for all conv layers
"""
super().__init__()
transformer_input_shape = {
k: v for k, v in input_shape.items() if k in transformer_in_features
}
# this is the input shape of pixel decoder
input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
self.feature_strides = [v.stride for k, v in input_shape]
self.feature_channels = [v.channels for k, v in input_shape]
# this is the input shape of transformer encoder (could use less features than pixel decoder
transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)
self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5"
transformer_in_channels = [v.channels for k, v in transformer_input_shape]
self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers
self.transformer_num_feature_levels = len(self.transformer_in_features)
if self.transformer_num_feature_levels > 1:
input_proj_list = []
# from low resolution to high resolution (res5 -> res2)
for in_channels in transformer_in_channels[::-1]:
input_proj_list.append(nn.Sequential(
nn.Conv2d(in_channels, conv_dim, kernel_size=1),
nn.GroupNorm(32, conv_dim),
))
self.input_proj = nn.ModuleList(input_proj_list)
else:
self.input_proj = nn.ModuleList([
nn.Sequential(
nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),
nn.GroupNorm(32, conv_dim),
)])
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
self.transformer = MSDeformAttnTransformerEncoderOnly(
d_model=conv_dim,
dropout=transformer_dropout,
nhead=transformer_nheads,
dim_feedforward=transformer_dim_feedforward,
num_encoder_layers=transformer_enc_layers,
num_feature_levels=self.transformer_num_feature_levels,
)
N_steps = conv_dim // 2
self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
self.mask_dim = mask_dim
# use 1x1 conv instead
self.mask_features = Conv2d(
conv_dim,
mask_dim,
kernel_size=1,
stride=1,
padding=0,
)
weight_init.c2_xavier_fill(self.mask_features)
self.maskformer_num_feature_levels = 3 # always use 3 scales
self.common_stride = common_stride
# extra fpn levels
stride = min(self.transformer_feature_strides)
self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))
lateral_convs = []
output_convs = []
use_bias = norm == ""
for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
lateral_norm = get_norm(norm, conv_dim)
output_norm = get_norm(norm, conv_dim)
lateral_conv = Conv2d(
in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
)
output_conv = Conv2d(
conv_dim,
conv_dim,
kernel_size=3,
stride=1,
padding=1,
bias=use_bias,
norm=output_norm,
activation=F.relu,
)
weight_init.c2_xavier_fill(lateral_conv)
weight_init.c2_xavier_fill(output_conv)
self.add_module("adapter_{}".format(idx + 1), lateral_conv)
self.add_module("layer_{}".format(idx + 1), output_conv)
lateral_convs.append(lateral_conv)
output_convs.append(output_conv)
# Place convs into top-down order (from low to high resolution)
# to make the top-down computation in forward clearer.
self.lateral_convs = lateral_convs[::-1]
self.output_convs = output_convs[::-1]
############ SAM pixel decoder ##############
self.sam_transformer = MSDeformAttnTransformerEncoderOnly(
d_model=conv_dim,
dropout=transformer_dropout,
nhead=transformer_nheads,
dim_feedforward=transformer_dim_feedforward,
num_encoder_layers=transformer_enc_layers,
num_feature_levels=self.transformer_num_feature_levels,
)
self.sam_pe_layer = None
############ FPN ##############
self.fpn3 = nn.Sequential(
nn.Conv2d(conv_dim, conv_dim, kernel_size=3, stride=2, padding=1,),
get_norm('GN', conv_dim),
nn.ReLU(inplace=True))
self.ops =[self.fpn3]
for op in self.ops:
if isinstance (op, nn.ConvTranspose2d) or isinstance(op, nn.Conv2d):
weight_init.c2_xavier_fill(op[0])
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@classmethod
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
ret = {}
ret["input_shape"] = {
k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
}
ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
ret["transformer_dim_feedforward"] = 1024 # use 1024 for deformable transformer encoder
ret[
"transformer_enc_layers"
] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
ret["transformer_in_features"] = cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES
ret["common_stride"] = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE
return ret
def resize_feat(self, x, resize_shape):
x = F.interpolate(
x,
size=(resize_shape[0], resize_shape[1]),
mode="bilinear",
align_corners=False,
)
return x
@autocast(enabled=False)
def forward_features(self, features):
srcs = []
pos = []
# Reverse feature maps into top-down order (from low to high resolution)
for idx, f in enumerate(self.transformer_in_features[::-1]):
x = features[f].float() # deformable detr does not support half precision
srcs.append(self.input_proj[idx](x))
pos.append(self.pe_layer(x))
y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
bs = y.shape[0]
split_size_or_sections = [None] * self.transformer_num_feature_levels
for i in range(self.transformer_num_feature_levels):
if i < self.transformer_num_feature_levels - 1:
split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
else:
split_size_or_sections[i] = y.shape[1] - level_start_index[i]
y = torch.split(y, split_size_or_sections, dim=1)
out = []
multi_scale_features = []
num_cur_levels = 0
for i, z in enumerate(y):
out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
# append `out` with extra FPN levels
# Reverse feature maps into top-down order (from low to high resolution)
for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
x = features[f].float()
lateral_conv = self.lateral_convs[idx]
output_conv = self.output_convs[idx]
cur_fpn = lateral_conv(x)
# Following FPN implementation, we use nearest upsampling here
y = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
y = output_conv(y)
out.append(y)
for o in out:
if num_cur_levels < self.maskformer_num_feature_levels:
multi_scale_features.append(o)
num_cur_levels += 1
############### SAM FPN ###############
sam_intern_feats = features['sam_embedding'][0]
sam_srcs = []
sam_pos = []
for idx, ops in enumerate(self.ops[::-1]):
feat = ops(sam_intern_feats)
sam_srcs.append(feat)
sam_pos.append(self.pe_layer(feat))
y, spatial_shapes, level_start_index = self.sam_transformer(sam_srcs, sam_pos)
sam_outs = []
sam_outs.append(y.transpose(1, 2).view(bs, -1, spatial_shapes[0][0], spatial_shapes[0][1]))
return self.mask_features(out[-1]), multi_scale_features, sam_outs