# RSPrompter/mmpl/models/necks/sam_prompt_generator.py
import copy
import math
from typing import Type, Tuple
import einops
import torch
import torch.nn as nn
from einops import rearrange
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks.transformer import build_transformer_layer
from torch import Tensor
from mmdet.models import SinePositionalEncoding
from mmpl.registry import MODELS
import torch.nn.functional as F
@MODELS.register_module()
class SAMTransformerPromptGenNeck(nn.Module):
def __init__(
self,
prompt_shape=(100, 6),
in_channels=[1280]*16,
out_channels=256,
positional_encoding=dict(num_feats=128, normalize=True),
n_classes=2,
kernel_size=3,
stride=1,
norm_cfg=None,
act_cfg=dict(type='ReLU')
):
super(SAMTransformerPromptGenNeck, self).__init__()
self.in_channels = in_channels
self.kernel_size = kernel_size
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.out_put_channels = out_channels
self.n_classes = n_classes
self.stride = stride
self.prompt_shape = prompt_shape
self.num_queries = prompt_shape[0]
self.per_query_point = prompt_shape[1]
if isinstance(in_channels, list):
self.pre_layers = nn.ModuleList()
inner_channel = 32
for idx, channel in enumerate(in_channels):
self.pre_layers.append(
nn.Sequential(
ConvModule(
channel,
inner_channel,
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
ConvModule(
inner_channel,
inner_channel*2,
kernel_size=kernel_size,
padding=kernel_size // 2,
stride=self.stride,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
ConvModule(
inner_channel*2,
inner_channel,
kernel_size=kernel_size,
padding=kernel_size // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
)
)
self.pre_layers.append(
nn.Sequential(
ConvModule(
inner_channel * len(in_channels),
out_channels,
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
ConvModule(
out_channels,
out_channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
)
)
self.generator_pe = SinePositionalEncoding(**positional_encoding)
self.transformer = self.build_transformer()
self.query_feat = nn.Embedding(self.num_queries, out_channels)
self.query_emb = nn.Embedding(self.num_queries, out_channels)
self.output_upscaling = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.GELU(),
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(out_channels, out_channels // 4, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels // 4),
nn.GELU(),
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(out_channels // 4, out_channels // 8, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels // 8),
nn.GELU(),
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(out_channels // 8, out_channels // 8, kernel_size=3, padding=1),
)
self.cls_head = nn.Sequential(
nn.Linear(out_channels, out_channels//2),
nn.ReLU(),
nn.Linear(out_channels//2, n_classes)
)
# self.point_emb = nn.Sequential(
# nn.Linear(out_channels, out_channels),
# nn.ReLU(),
# nn.Linear(out_channels, out_channels),
# nn.ReLU(),
# nn.Linear(out_channels, self.per_query_point * out_channels)
# )
self.output_hypernetworks_mlps = MLP(out_channels, out_channels, out_channels // 8, 3)
def build_transformer(
self, num_encoder_layers=2, num_decoder_layers=3, embed_dims=256, num_heads=8,
mlp_ratio=2, dropout_rate=0.0, act_cfg=dict(type="gelu")):
"""Build transformer decoder."""
# transformer = nn.Transformer(
# d_model=embed_dims, nhead=num_heads, num_encoder_layers=num_encoder_layers,
# num_decoder_layers=num_decoder_layers, dim_feedforward=mlp_ratio * embed_dims,
# dropout=dropout_rate, activation=act_cfg['type'], batch_first=True, norm_first=True,
# )
transformer = Transformer(depth=2)
return transformer
def init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, inputs, prompt_encoder, mask_decoder):
img_embs, inner_states = inputs
if hasattr(self, 'pre_layers'):
inner_states = inner_states[-len(self.in_channels):]
inner_states = [einops.rearrange(x, 'b h w c -> b c h w') for x in inner_states]
inner_states = [layer(x) for layer, x in zip(self.pre_layers[:-1], inner_states)]
img_feats = self.pre_layers[-1](torch.cat(inner_states, dim=1))
bs, c, h, w = img_feats.shape
        mask_pe = torch.zeros((bs, h, w), device=img_feats.device, dtype=torch.bool)
img_feats_pe = self.generator_pe(mask_pe)
        query_feat = self.query_feat.weight.unsqueeze(0).expand(bs, -1, -1)  # B x num_queries x C
query_emb = self.query_emb.weight.unsqueeze(0).expand(bs, -1, -1)
img_feats, query_feats = self.transformer(
image_embedding=img_feats,
image_pe=img_feats_pe,
point_embedding=query_feat,
point_pe=query_emb)
cls_logits = self.cls_head(query_feats)
# point_embs = self.point_emb(query_feats)
# point_embs = rearrange(point_embs, 'b n (t c) -> b n t c', t=self.per_query_point) # Bx100x6x256
src = img_feats.transpose(1, 2).view(bs, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in = self.output_hypernetworks_mlps(query_feats)
b, c, h, w = upscaled_embedding.shape
l1_masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# dense_masks = einops.rearrange(l1_masks, 'b (n t) h w -> (b n) t h w', t=1)
# sparse, dense = prompt_encoder(points=None, boxes=None, masks=dense_masks)
# dense = einops.rearrange(dense, '(b n) t h w -> b n t h w', n=self.num_queries)
# l2_masks = []
# iou_preds = []
# for curr_embedding, sparse_embeddings, dense_embeddings in zip(img_embs, point_embs, dense):
# low_res_masks, iou_predictions = mask_decoder(
# image_embeddings=curr_embedding.unsqueeze(0),
# image_pe=prompt_encoder.get_dense_pe(),
# sparse_prompt_embeddings=sparse_embeddings,
# dense_prompt_embeddings=dense_embeddings
# )
# l2_masks.append(low_res_masks[:, 0])
# iou_preds.append(iou_predictions[:, 0])
# l2_masks = torch.stack(l2_masks, dim=0)
# iou_preds = torch.stack(iou_preds, dim=0)
l2_masks = None
iou_preds = None
return cls_logits, l1_masks, l2_masks, iou_preds
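
# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal shape check for SAMTransformerPromptGenNeck. The two 64-channel
# 32x32 dummy stages below are assumptions for illustration; the real inputs
# are SAM ViT inner states of shape B x H x W x C. prompt_encoder and
# mask_decoder are only needed by the commented-out SAM decoding path, so
# None is passed here.
def _demo_sam_transformer_prompt_gen_neck():
    neck = SAMTransformerPromptGenNeck(in_channels=[64, 64], out_channels=256)
    inner_states = [torch.randn(2, 32, 32, 64) for _ in range(2)]  # b h w c
    cls_logits, l1_masks, l2_masks, iou_preds = neck(
        (None, inner_states), prompt_encoder=None, mask_decoder=None)
    # cls_logits: (2, 100, 2), l1_masks: (2, 100, 256, 256),
    # l2_masks and iou_preds are None with the SAM decoding path disabled
    return cls_logits, l1_masks
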
@MODELS.register_module()
class SAMPromptConvNeck(nn.Module):
def __init__(
self,
prompt_shape=(100, 5),
img_feat_channels=1280,
out_put_channels=256,
num_img_feat_level=16,
n_cls=2,
):
super(SAMPromptConvNeck, self).__init__()
self.prompt_shape = prompt_shape
self.num_queries = prompt_shape[0]
self.per_query_point = prompt_shape[1]
self.point_size = int(math.sqrt(prompt_shape[0]))
self.img_feat_channels = img_feat_channels
self.out_put_channels = out_put_channels
self.num_img_feat_level = num_img_feat_level
self.n_cls = n_cls
# decoder_embed_dims = img_feat_channels // 32
decoder_embed_dims = 32
self.decoder_input_projs = nn.ModuleList()
# from low resolution to high resolution
for _ in range(num_img_feat_level):
self.decoder_input_projs.append(
nn.Sequential(
nn.Conv2d(img_feat_channels, decoder_embed_dims, kernel_size=1),
# nn.BatchNorm2d(decoder_embed_dims),
nn.ReLU(),
nn.Conv2d(decoder_embed_dims, decoder_embed_dims, kernel_size=3, padding=1),
# nn.BatchNorm2d(decoder_embed_dims),
nn.ReLU(),
))
self.level_embed = nn.Embedding(self.num_img_feat_level, decoder_embed_dims)
self.gather_img_feats = nn.Sequential(
nn.Conv2d(num_img_feat_level * decoder_embed_dims, out_put_channels, kernel_size=1),
# nn.BatchNorm2d(out_put_channels),
nn.ReLU(),
nn.Conv2d(out_put_channels, out_put_channels, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(out_put_channels, out_put_channels*2, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(out_put_channels * 2, out_put_channels * 2, 3, padding=1),
)
self.img_feats_pe = nn.Parameter(torch.zeros(1, out_put_channels*2, self.point_size, self.point_size))
self.cls_head = nn.Sequential(
nn.Conv2d(out_put_channels * 2, out_put_channels, 3, padding=1),
nn.ReLU(),
nn.Conv2d(out_put_channels, n_cls, 1)
)
self.point_emb = nn.Sequential(
nn.Conv2d(out_put_channels * 2, out_put_channels, 3, padding=1),
nn.ReLU(),
nn.Conv2d(out_put_channels, out_put_channels, 3, padding=1),
nn.ReLU(),
nn.Conv2d(out_put_channels, self.per_query_point * out_put_channels, 1)
)
def forward(self, inputs):
        inner_states = [x.permute(0, 3, 1, 2) for x in inputs]  # b h w c -> b c h w, from low-level to high-level stages
bs = inner_states[0].shape[0]
        # inputs: list of [B, H, W, C] feature maps
num_layers = len(inputs)
# select the feature maps from the selected layers
layer_start_id = num_layers - self.num_img_feat_level
decoder_inputs = []
for i in range(self.num_img_feat_level):
            decoder_input = self.decoder_input_projs[i](inner_states[i + layer_start_id])  # B x decoder_embed_dims x H x W
level_embed = self.level_embed.weight[i].unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(bs, -1, -1, -1)
decoder_input = decoder_input + level_embed
decoder_inputs.append(decoder_input)
        decoder_inputs = torch.cat(decoder_inputs, dim=1)  # B x (num_img_feat_level * decoder_embed_dims) x H x W
decoder_inputs = self.gather_img_feats(decoder_inputs)
decoder_inputs = torch.nn.functional.interpolate(decoder_inputs, size=(self.point_size, self.point_size), mode='bilinear', align_corners=True)
        img_pe = self.img_feats_pe.expand(bs, -1, -1, -1)  # B x (out_put_channels * 2) x point_size x point_size
decoder_inputs = decoder_inputs + img_pe
cls_logits = self.cls_head(decoder_inputs) # b c h w
cls_logits = rearrange(cls_logits, 'b c h w -> b (h w) c')
point_embs = self.point_emb(decoder_inputs) # b c h w
        point_embs = rearrange(point_embs, 'b (t c) h w -> b (h w) t c', t=self.per_query_point)  # B x (h*w) x per_query_point x C
return point_embs, cls_logits
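
# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal shape check for SAMPromptConvNeck. The small channel count and the
# two 16x16 dummy stages are assumptions for illustration; real inputs are
# SAM ViT inner states of shape B x H x W x C.
def _demo_sam_prompt_conv_neck():
    neck = SAMPromptConvNeck(
        prompt_shape=(100, 5), img_feat_channels=64, num_img_feat_level=2)
    inputs = [torch.randn(2, 16, 16, 64) for _ in range(2)]  # b h w c
    point_embs, cls_logits = neck(inputs)
    # point_embs: (2, 100, 5, 256), cls_logits: (2, 100, 2)
    return point_embs, cls_logits
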
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
class Transformer(nn.Module):
def __init__(
self,
depth: int = 2,
embedding_dim: int = 256,
num_heads: int = 8,
mlp_dim: int = 1024,
activation: Type[nn.Module] = nn.GELU,
attention_downsample_rate: int = 2,
) -> None:
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
AttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate
)
)
self.final_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
point_pe: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=image_embedding,
query_pe=image_pe,
keys=point_embedding,
key_pe=point_pe,
)
        # Apply the final attention layer from the image tokens to the point tokens
q = queries + image_pe
k = keys + point_embedding
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
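
# --- Illustrative usage sketch (not part of the original module) -------------
# Expected tensor shapes for the two-way Transformer above. The 32x32 image
# grid and the 100 query tokens are assumptions for illustration.
def _demo_two_way_transformer():
    transformer = Transformer(depth=2, embedding_dim=256, num_heads=8)
    image_embedding = torch.randn(2, 256, 32, 32)   # B x C x H x W
    image_pe = torch.randn(2, 256, 32, 32)
    point_embedding = torch.randn(2, 100, 256)      # B x N_points x C
    point_pe = torch.randn(2, 100, 256)
    queries, keys = transformer(image_embedding, image_pe, point_embedding, point_pe)
    # queries: processed image tokens, (2, 1024, 256)
    # keys: processed point tokens, (2, 100, 256)
    return queries, keys
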
class AttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
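
# --- Illustrative usage sketch (not part of the original module) -------------
# A single pass through one AttentionBlock; the token counts are assumptions.
def _demo_attention_block():
    block = AttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=1024)
    queries = torch.randn(2, 1024, 256)   # e.g. flattened image tokens
    keys = torch.randn(2, 100, 256)       # e.g. prompt/query tokens
    query_pe = torch.randn(2, 1024, 256)
    key_pe = torch.randn(2, 100, 256)
    queries, keys = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
    # queries: (2, 1024, 256), keys: (2, 100, 256)
    return queries, keys
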
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide internal_dim (embedding_dim // downsample_rate)."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
def _recombine_heads(self, x: Tensor) -> Tensor:
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Attention
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
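
# --- Illustrative usage sketch (not part of the original module) -------------
# With downsample_rate=2 the queries/keys/values are projected to 128 dims
# internally (8 heads x 16 dims each) and the output is projected back to
# embedding_dim; the batch/token sizes are assumptions for illustration.
def _demo_downsampled_attention():
    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    q = torch.randn(2, 100, 256)
    k = torch.randn(2, 1024, 256)
    v = torch.randn(2, 1024, 256)
    out = attn(q, k, v)
    # out: (2, 100, 256)
    return out
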
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
            x = torch.sigmoid(x)
return x
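
# --- Illustrative usage sketch (not part of the original module) -------------
# The hypernetwork MLP used above maps 256-dim query tokens to 32-dim mask
# weights (out_channels // 8), with ReLU between layers and no activation on
# the final layer; the batch/token sizes are assumptions for illustration.
def _demo_hypernetwork_mlp():
    mlp = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
    tokens = torch.randn(2, 100, 256)
    out = mlp(tokens)
    # out: (2, 100, 32)
    return out
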
@MODELS.register_module()
class SAMTransformerEDPromptGenNeck(nn.Module):
def __init__(
self,
prompt_shape=(100, 5),
in_channels=[1280]*16,
inner_channels=128,
selected_channels: list=None,
num_encoders=2,
num_decoders=2,
out_channels=256,
positional_encoding=dict(num_feats=128, normalize=True),
kernel_size=3,
stride=1,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='ReLU', inplace=True),
init_cfg=None,
**kwargs
):
super().__init__()
self.in_channels = in_channels
self.kernel_size = kernel_size
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.out_channels = out_channels
self.stride = stride
self.selected_channels = selected_channels
self.prompt_shape = prompt_shape
self.num_queries = prompt_shape[0]
self.per_query_point = prompt_shape[1]
self.down_sample_layers = nn.ModuleList()
for idx in self.selected_channels:
self.down_sample_layers.append(
nn.Sequential(
ConvModule(
in_channels[idx],
inner_channels,
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
ConvModule(
inner_channels,
inner_channels,
kernel_size=3,
padding=1,
stride=2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
)
)
self.fusion_layers = nn.ModuleList()
for idx in self.selected_channels:
self.fusion_layers.append(
ConvModule(
inner_channels,
inner_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
)
)
self.up_layers = nn.ModuleList()
self.up_layers.append(
nn.Sequential(
ConvModule(
inner_channels,
inner_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
ConvModule(
inner_channels,
inner_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
)
)
)
self.up_layers.append(
ConvModule(
inner_channels,
out_channels,
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=None
)
)
self.generator_pe = SinePositionalEncoding(**positional_encoding)
self.en_layers = nn.ModuleList()
self.de_layers = nn.ModuleList()
self.build_transformer(num_encoders=num_encoders, num_decoders=num_decoders)
self.embed_dims = self.en_layers[0].embed_dims
self.pre_norm = self.en_layers[0].pre_norm
self.query_feat = nn.Embedding(self.num_queries, out_channels)
self.query_embed = nn.Embedding(self.num_queries, out_channels)
# self.output_upscaling = nn.Sequential(
# nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
# nn.BatchNorm2d(out_channels),
# nn.GELU(),
# nn.UpsamplingBilinear2d(scale_factor=2),
# nn.Conv2d(out_channels, out_channels // 4, kernel_size=3, padding=1),
# nn.BatchNorm2d(out_channels // 4),
# nn.GELU(),
# nn.UpsamplingBilinear2d(scale_factor=2),
# nn.Conv2d(out_channels // 4, out_channels // 8, kernel_size=3, padding=1),
# nn.BatchNorm2d(out_channels // 8),
# nn.GELU(),
# nn.UpsamplingBilinear2d(scale_factor=2),
# nn.Conv2d(out_channels // 8, out_channels // 8, kernel_size=3, padding=1),
# )
# self.output_hypernetworks_mlps = MLP(out_channels, out_channels, out_channels // 8, 3)
self.init_weights()
def build_transformer(self, num_encoders=2, num_decoders=2, embed_dims=256, num_heads=8, mlp_ratio=4):
transformer_encoder_layer = dict(
type='BaseTransformerLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=0.1,
proj_drop=0.1,
dropout_layer=dict(type='Dropout', drop_prob=0.1)
),
],
ffn_cfgs=dict(
type='FFN',
embed_dims=embed_dims,
feedforward_channels=embed_dims * mlp_ratio,
num_fcs=2,
act_cfg=dict(type='GELU'),
ffn_drop=0.1,
add_identity=True),
operation_order=('norm', 'self_attn', 'norm', 'ffn'),
norm_cfg=dict(type='LN'),
batch_first=True
)
transformer_decoder_layer = dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=0.1,
proj_drop=0.1,
dropout_layer=dict(type='Dropout', drop_prob=0.1)
),
ffn_cfgs=dict(
type='FFN',
embed_dims=embed_dims,
feedforward_channels=embed_dims * mlp_ratio,
num_fcs=2,
act_cfg=dict(type='GELU'),
ffn_drop=0.1,
add_identity=True),
operation_order=('norm', 'self_attn', 'norm', 'cross_attn', 'norm', 'ffn'),
norm_cfg=dict(type='LN'),
batch_first=True
)
transformer_en_layers = [
copy.deepcopy(transformer_encoder_layer) for _ in range(num_encoders)
]
transformer_de_layers = [
copy.deepcopy(transformer_decoder_layer) for _ in range(num_decoders)
]
for i in range(num_encoders):
self.en_layers.append(build_transformer_layer(transformer_en_layers[i]))
for i in range(num_decoders):
self.de_layers.append(build_transformer_layer(transformer_de_layers[i]))
def init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, inputs):
_, inner_states = inputs
inner_states = [einops.rearrange(inner_states[idx], 'b h w c -> b c h w') for idx in self.selected_channels]
inner_states = [layer(x) for layer, x in zip(self.down_sample_layers, inner_states)]
x = None
for inner_state, layer in zip(inner_states, self.fusion_layers):
if x is not None:
inner_state = x + inner_state
x = inner_state + layer(inner_state)
x = self.up_layers[0](x) + x
img_feats = self.up_layers[1](x)
bs, c, h, w = img_feats.shape
mask_pe = torch.zeros((bs, h, w), device=img_feats.device, dtype=torch.bool)
img_feats_pe = self.generator_pe(mask_pe)
        query_feat = self.query_feat.weight.unsqueeze(0).repeat((bs, 1, 1))
        query_embed = self.query_embed.weight.unsqueeze(0).repeat((bs, 1, 1))
encoder_inputs = rearrange(img_feats, 'b c h w -> b (h w) c')
img_feats_pe = img_feats_pe.flatten(2).permute(0, 2, 1)
# shape (batch_size, num_total_queries, c)
memory = encoder_inputs
for layer in self.en_layers:
memory = layer(
query=memory,
query_pos=img_feats_pe
)
# (batch_size, num_total_queries, c)
query_feat_list = []
for layer in self.de_layers:
query_feat = layer(
query=query_feat,
key=memory,
value=memory,
query_pos=query_embed,
key_pos=img_feats_pe
)
query_feat_list.append(query_feat)
img_feat = rearrange(memory, 'b (h w) c -> b c h w', h=h, w=w)
return query_feat, query_feat_list, img_feat
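
# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal shape check for SAMTransformerEDPromptGenNeck. The four dummy
# 64-channel 32x32 stages and selected_channels=[1, 2, 3] are assumptions;
# real inputs are SAM ViT inner states of shape B x H x W x C.
def _demo_ed_prompt_gen_neck():
    neck = SAMTransformerEDPromptGenNeck(
        in_channels=[64] * 4, selected_channels=[1, 2, 3], out_channels=256)
    inner_states = [torch.randn(2, 32, 32, 64) for _ in range(4)]  # b h w c
    query_feat, query_feat_list, img_feat = neck((None, inner_states))
    # query_feat: (2, 100, 256), img_feat: (2, 256, 16, 16),
    # query_feat_list: one (2, 100, 256) tensor per decoder layer
    return query_feat, query_feat_list, img_feat
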
@MODELS.register_module()
class SAMAggregatorNeck(nn.Module):
def __init__(
self,
in_channels=[1280]*16,
inner_channels=128,
selected_channels: list=None,
out_channels=256,
kernel_size=3,
stride=1,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='ReLU', inplace=True),
up_sample_scale=4,
init_cfg=None,
**kwargs
):
super().__init__()
self.in_channels = in_channels
self.kernel_size = kernel_size
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.out_channels = out_channels
self.stride = stride
self.selected_channels = selected_channels
self.up_sample_scale = up_sample_scale
self.down_sample_layers = nn.ModuleList()
for idx in self.selected_channels:
self.down_sample_layers.append(
nn.Sequential(
ConvModule(
in_channels[idx],
inner_channels,
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
ConvModule(
inner_channels,
inner_channels,
kernel_size=3,
padding=1,
stride=2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
)
)
self.fusion_layers = nn.ModuleList()
for idx in self.selected_channels:
self.fusion_layers.append(
ConvModule(
inner_channels,
inner_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
)
)
self.up_layers = nn.ModuleList()
self.up_layers.append(
nn.Sequential(
ConvModule(
inner_channels,
inner_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
ConvModule(
inner_channels,
inner_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
)
)
)
self.up_layers.append(
ConvModule(
inner_channels,
out_channels,
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=None
)
)
self.up_sample_layers = nn.ModuleList()
assert up_sample_scale == 4
self.up_sample_layers.append(
nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
ConvModule(
out_channels,
out_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
ConvModule(
out_channels,
out_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
)
)
)
self.up_sample_layers.append(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
)
self.up_sample_layers.append(
nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
ConvModule(
out_channels,
out_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
),
ConvModule(
out_channels,
out_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg
)
)
)
self.up_sample_layers.append(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
)
def forward(self, inputs):
_, inner_states = inputs
inner_states = [einops.rearrange(inner_states[idx], 'b h w c -> b c h w') for idx in self.selected_channels]
inner_states = [layer(x) for layer, x in zip(self.down_sample_layers, inner_states)]
x = None
for inner_state, layer in zip(inner_states, self.fusion_layers):
if x is not None:
inner_state = x + inner_state
x = inner_state + layer(inner_state)
x = self.up_layers[0](x) + x
img_feats_0 = self.up_layers[1](x)
img_feats_1 = self.up_sample_layers[0](img_feats_0) + self.up_sample_layers[1](img_feats_0)
img_feats_2 = self.up_sample_layers[2](img_feats_1) + self.up_sample_layers[3](img_feats_1)
return img_feats_2, img_feats_1, img_feats_0
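
# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal shape check for SAMAggregatorNeck. The four dummy 64-channel
# 32x32 stages and selected_channels=[1, 2, 3] are assumptions; real inputs
# are SAM ViT inner states of shape B x H x W x C. The neck returns three
# feature maps at 2x, 1x and 0.5x of the input stage resolution.
def _demo_sam_aggregator_neck():
    neck = SAMAggregatorNeck(
        in_channels=[64] * 4, selected_channels=[1, 2, 3], out_channels=256)
    inner_states = [torch.randn(2, 32, 32, 64) for _ in range(4)]  # b h w c
    feats_2x, feats_1x, feats_05x = neck((None, inner_states))
    # feats_2x: (2, 256, 64, 64), feats_1x: (2, 256, 32, 32), feats_05x: (2, 256, 16, 16)
    return feats_2x, feats_1x, feats_05x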