import copy
import math
from typing import Type, Tuple

import einops
import torch
import torch.nn as nn
from einops import rearrange
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks.transformer import build_transformer_layer
from torch import Tensor
from mmdet.models import SinePositionalEncoding
from mmpl.registry import MODELS
import torch.nn.functional as F


@MODELS.register_module()  # NOTE: registration with the mmpl MODELS registry is assumed; the import above is otherwise unused
class SAMTransformerPromptGenNeck(nn.Module):
    def __init__(
        self,
        prompt_shape=(100, 6),
        in_channels=[1280] * 16,
        out_channels=256,
        positional_encoding=dict(num_feats=128, normalize=True),
        n_classes=2,
        kernel_size=3,
        stride=1,
        norm_cfg=None,
        act_cfg=dict(type='ReLU')
    ):
        super(SAMTransformerPromptGenNeck, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.out_put_channels = out_channels
        self.n_classes = n_classes
        self.stride = stride

        self.prompt_shape = prompt_shape
        self.num_queries = prompt_shape[0]
        self.per_query_point = prompt_shape[1]

        if isinstance(in_channels, list):
            # per-level projection: C_in -> 32 -> 64 -> 32
            self.pre_layers = nn.ModuleList()
            inner_channel = 32
            for idx, channel in enumerate(in_channels):
                self.pre_layers.append(
                    nn.Sequential(
                        ConvModule(
                            channel,
                            inner_channel,
                            kernel_size=1,
                            norm_cfg=self.norm_cfg,
                            act_cfg=self.act_cfg
                        ),
                        ConvModule(
                            inner_channel,
                            inner_channel * 2,
                            kernel_size=kernel_size,
                            padding=kernel_size // 2,
                            stride=self.stride,
                            norm_cfg=self.norm_cfg,
                            act_cfg=self.act_cfg
                        ),
                        ConvModule(
                            inner_channel * 2,
                            inner_channel,
                            kernel_size=kernel_size,
                            padding=kernel_size // 2,
                            norm_cfg=self.norm_cfg,
                            act_cfg=self.act_cfg
                        ),
                    )
                )
            # fuse the concatenated per-level features into `out_channels`
            self.pre_layers.append(
                nn.Sequential(
                    ConvModule(
                        inner_channel * len(in_channels),
                        out_channels,
                        kernel_size=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg
                    ),
                    ConvModule(
                        out_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg
                    ),
                )
            )

        self.generator_pe = SinePositionalEncoding(**positional_encoding)
        self.transformer = self.build_transformer()

        self.query_feat = nn.Embedding(self.num_queries, out_channels)
        self.query_emb = nn.Embedding(self.num_queries, out_channels)

        self.output_upscaling = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels, out_channels // 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels // 4),
            nn.GELU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels // 4, out_channels // 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels // 8),
            nn.GELU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels // 8, out_channels // 8, kernel_size=3, padding=1),
        )

        self.cls_head = nn.Sequential(
            nn.Linear(out_channels, out_channels // 2),
            nn.ReLU(),
            nn.Linear(out_channels // 2, n_classes)
        )

        # self.point_emb = nn.Sequential(
        #     nn.Linear(out_channels, out_channels),
        #     nn.ReLU(),
        #     nn.Linear(out_channels, out_channels),
        #     nn.ReLU(),
        #     nn.Linear(out_channels, self.per_query_point * out_channels)
        # )
        self.output_hypernetworks_mlps = MLP(out_channels, out_channels, out_channels // 8, 3)

    def build_transformer(
            self, num_encoder_layers=2, num_decoder_layers=3, embed_dims=256, num_heads=8,
            mlp_ratio=2, dropout_rate=0.0, act_cfg=dict(type="gelu")):
        """Build transformer decoder."""
        # transformer = nn.Transformer(
        #     d_model=embed_dims, nhead=num_heads, num_encoder_layers=num_encoder_layers,
        #     num_decoder_layers=num_decoder_layers, dim_feedforward=mlp_ratio * embed_dims,
        #     dropout=dropout_rate, activation=act_cfg['type'], batch_first=True, norm_first=True,
        # )
        transformer = Transformer(depth=2)
        return transformer

    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, inputs, prompt_encoder, mask_decoder):
        img_embs, inner_states = inputs
        if hasattr(self, 'pre_layers'):
            inner_states = inner_states[-len(self.in_channels):]
            inner_states = [einops.rearrange(x, 'b h w c -> b c h w') for x in inner_states]
            inner_states = [layer(x) for layer, x in zip(self.pre_layers[:-1], inner_states)]
            img_feats = self.pre_layers[-1](torch.cat(inner_states, dim=1))

        bs, c, h, w = img_feats.shape
        mask_pe = torch.zeros((bs, h, w), device=img_feats.device, dtype=torch.bool)
        img_feats_pe = self.generator_pe(mask_pe)

        query_feat = self.query_feat.weight.unsqueeze(0).expand(bs, -1, -1)  # B x num_queries x C
        query_emb = self.query_emb.weight.unsqueeze(0).expand(bs, -1, -1)
        img_feats, query_feats = self.transformer(
            image_embedding=img_feats,
            image_pe=img_feats_pe,
            point_embedding=query_feat,
            point_pe=query_emb)
        cls_logits = self.cls_head(query_feats)

        # point_embs = self.point_emb(query_feats)
        # point_embs = rearrange(point_embs, 'b n (t c) -> b n t c', t=self.per_query_point)  # B x num_queries x t x C

        src = img_feats.transpose(1, 2).view(bs, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in = self.output_hypernetworks_mlps(query_feats)
        b, c, h, w = upscaled_embedding.shape
        l1_masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # dense_masks = einops.rearrange(l1_masks, 'b (n t) h w -> (b n) t h w', t=1)
        # sparse, dense = prompt_encoder(points=None, boxes=None, masks=dense_masks)
        # dense = einops.rearrange(dense, '(b n) t h w -> b n t h w', n=self.num_queries)
        # l2_masks = []
        # iou_preds = []
        # for curr_embedding, sparse_embeddings, dense_embeddings in zip(img_embs, point_embs, dense):
        #     low_res_masks, iou_predictions = mask_decoder(
        #         image_embeddings=curr_embedding.unsqueeze(0),
        #         image_pe=prompt_encoder.get_dense_pe(),
        #         sparse_prompt_embeddings=sparse_embeddings,
        #         dense_prompt_embeddings=dense_embeddings
        #     )
        #     l2_masks.append(low_res_masks[:, 0])
        #     iou_preds.append(iou_predictions[:, 0])
        # l2_masks = torch.stack(l2_masks, dim=0)
        # iou_preds = torch.stack(iou_preds, dim=0)
        l2_masks = None
        iou_preds = None
        return cls_logits, l1_masks, l2_masks, iou_preds


@MODELS.register_module()  # assumed registration, as above
class SAMPromptConvNeck(nn.Module):
    def __init__(
        self,
        prompt_shape=(100, 5),
        img_feat_channels=1280,
        out_put_channels=256,
        num_img_feat_level=16,
        n_cls=2,
    ):
        super(SAMPromptConvNeck, self).__init__()
        self.prompt_shape = prompt_shape
        self.num_queries = prompt_shape[0]
        self.per_query_point = prompt_shape[1]
        self.point_size = int(math.sqrt(prompt_shape[0]))

        self.img_feat_channels = img_feat_channels
        self.out_put_channels = out_put_channels
        self.num_img_feat_level = num_img_feat_level
        self.n_cls = n_cls

        # decoder_embed_dims = img_feat_channels // 32
        decoder_embed_dims = 32
        self.decoder_input_projs = nn.ModuleList()
        # from low resolution to high resolution
        for _ in range(num_img_feat_level):
            self.decoder_input_projs.append(
                nn.Sequential(
                    nn.Conv2d(img_feat_channels, decoder_embed_dims, kernel_size=1),
                    # nn.BatchNorm2d(decoder_embed_dims),
                    nn.ReLU(),
                    nn.Conv2d(decoder_embed_dims, decoder_embed_dims, kernel_size=3, padding=1),
                    # nn.BatchNorm2d(decoder_embed_dims),
                    nn.ReLU(),
                ))
        self.level_embed = nn.Embedding(self.num_img_feat_level, decoder_embed_dims)

        self.gather_img_feats = nn.Sequential(
            nn.Conv2d(num_img_feat_level * decoder_embed_dims, out_put_channels, kernel_size=1),
            # nn.BatchNorm2d(out_put_channels),
            nn.ReLU(),
            nn.Conv2d(out_put_channels, out_put_channels, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_put_channels, out_put_channels * 2, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_put_channels * 2, out_put_channels * 2, 3, padding=1),
        )

        self.img_feats_pe = nn.Parameter(torch.zeros(1, out_put_channels * 2, self.point_size, self.point_size))

        self.cls_head = nn.Sequential(
            nn.Conv2d(out_put_channels * 2, out_put_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_put_channels, n_cls, 1)
        )

        self.point_emb = nn.Sequential(
            nn.Conv2d(out_put_channels * 2, out_put_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_put_channels, out_put_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_put_channels, self.per_query_point * out_put_channels, 1)
        )

    def forward(self, inputs):
        # inner states are channel-last (B, H, W, C); convert to (B, C, H, W)
        inner_states = [x.permute(0, 3, 1, 2) for x in inputs]  # ordered from low to high level
        bs = inner_states[0].shape[0]

        # inner_states: list([B, C, H, W])
        num_layers = len(inputs)
        # select the feature maps of the last `num_img_feat_level` layers
        layer_start_id = num_layers - self.num_img_feat_level
        decoder_inputs = []
        for i in range(self.num_img_feat_level):
            decoder_input = self.decoder_input_projs[i](inner_states[i + layer_start_id])  # B x 32 x H x W
            level_embed = self.level_embed.weight[i].unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(bs, -1, -1, -1)
            decoder_input = decoder_input + level_embed
            decoder_inputs.append(decoder_input)
        decoder_inputs = torch.cat(decoder_inputs, dim=1)  # B x (num_levels * 32) x H x W
        decoder_inputs = self.gather_img_feats(decoder_inputs)
        decoder_inputs = torch.nn.functional.interpolate(
            decoder_inputs, size=(self.point_size, self.point_size), mode='bilinear', align_corners=True)
        img_pe = self.img_feats_pe.expand(bs, -1, -1, -1)
        decoder_inputs = decoder_inputs + img_pe

        cls_logits = self.cls_head(decoder_inputs)  # B x n_cls x h x w
        cls_logits = rearrange(cls_logits, 'b c h w -> b (h w) c')

        point_embs = self.point_emb(decoder_inputs)  # B x (t * C) x h x w
        point_embs = rearrange(point_embs, 'b (t c) h w -> b (h w) t c', t=self.per_query_point)  # B x num_queries x t x C
        return point_embs, cls_logits


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


class Transformer(nn.Module):
    def __init__(
        self,
        depth: int = 2,
        embedding_dim: int = 256,
        num_heads: int = 8,
        mlp_dim: int = 1024,
        activation: Type[nn.Module] = nn.GELU,
        attention_downsample_rate: int = 2,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()
        for i in range(depth):
            self.layers.append(
                AttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate
                )
            )
        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
        point_pe: Tensor,
    ) -> Tuple[Tensor, Tensor]:
""" | |
Args: | |
image_embedding (torch.Tensor): image to attend to. Should be shape | |
B x embedding_dim x h x w for any h and w. | |
image_pe (torch.Tensor): the positional encoding to add to the image. Must | |
have the same shape as image_embedding. | |
point_embedding (torch.Tensor): the embedding to add to the query points. | |
Must have shape B x N_points x embedding_dim for any N_points. | |
Returns: | |
torch.Tensor: the processed point_embedding | |
torch.Tensor: the processed image_embedding | |
""" | |
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Apply transformer blocks and final layernorm, carrying the updated
        # tokens through every block
        queries = image_embedding
        keys = point_embedding
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                query_pe=image_pe,
                keys=keys,
                key_pe=point_pe,
            )

        # Apply the final attention layer from the image tokens to the points
        q = queries + image_pe
        k = keys + point_embedding
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)
        return queries, keys
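

# Minimal shape sketch (illustrative only, not part of the original module): it
# builds a tiny Transformer and checks the tensor shapes documented in forward().
# The batch size, spatial size and point count below are arbitrary assumptions.
def _transformer_shape_demo():
    """Sketch: trace shapes through Transformer.forward with random inputs."""
    bs, c, h, w, n_points = 2, 256, 8, 8, 10
    model = Transformer(depth=2, embedding_dim=c, num_heads=8)
    image_embedding = torch.randn(bs, c, h, w)
    image_pe = torch.randn(bs, c, h, w)
    point_embedding = torch.randn(bs, n_points, c)
    point_pe = torch.randn(bs, n_points, c)
    img_tokens, point_tokens = model(image_embedding, image_pe, point_embedding, point_pe)
    assert img_tokens.shape == (bs, h * w, c)       # flattened image tokens
    assert point_tokens.shape == (bs, n_points, c)  # processed point tokens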


class AttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention on the queries,
        (2) cross-attention of the queries to the keys, (3) an MLP block on the
        queries, and (4) cross-attention of the keys back to the queries.

        Arguments:
            embedding_dim (int): the channel dimension of the embeddings
            num_heads (int): the number of heads in the attention layers
            mlp_dim (int): the hidden dimension of the mlp block
            activation (nn.Module): the activation of the mlp block
            skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block on the queries
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, queries attending to keys
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, keys attending to queries
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
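

# Minimal sketch (not original code): exercise the four-step AttentionBlock
# described above. The token counts and embedding width are hypothetical.
def _attention_block_demo():
    """Sketch: one AttentionBlock pass over two token sets of different lengths."""
    bs, n_queries, n_keys, dim = 2, 64, 10, 256
    block = AttentionBlock(embedding_dim=dim, num_heads=8)
    queries = torch.randn(bs, n_queries, dim)
    keys = torch.randn(bs, n_keys, dim)
    query_pe = torch.randn(bs, n_queries, dim)
    key_pe = torch.randn(bs, n_keys, dim)
    q_out, k_out = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
    assert q_out.shape == queries.shape and k_out.shape == keys.shape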


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)
        return out
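

# Minimal sketch (not original code): downsample_rate shrinks the internal
# projection width while the output is projected back to embedding_dim.
def _attention_downsample_demo():
    """Sketch: Attention with downsample_rate=2 uses a 128-dim internal space."""
    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    x = torch.randn(2, 16, 256)
    out = attn(q=x, k=x, v=x)
    assert attn.internal_dim == 128   # 256 // 2
    assert out.shape == (2, 16, 256)  # out_proj restores embedding_dim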


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)
        return x
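

# Minimal sketch (not original code): the hypernetwork MLP instantiated earlier as
# MLP(out_channels, out_channels, out_channels // 8, 3) maps 256-dim query
# features to 32-dim mask weights; the batch and query counts here are assumptions.
def _hypernetwork_mlp_demo():
    """Sketch: MLP(256, 256, 32, 3) applied to a batch of query features."""
    mlp = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
    queries = torch.randn(2, 100, 256)  # B x num_queries x C
    out = mlp(queries)
    assert out.shape == (2, 100, 32)    # matches out_channels // 8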


@MODELS.register_module()  # assumed registration, as above
class SAMTransformerEDPromptGenNeck(nn.Module):
    def __init__(
        self,
        prompt_shape=(100, 5),
        in_channels=[1280] * 16,
        inner_channels=128,
        selected_channels: list = None,
        num_encoders=2,
        num_decoders=2,
        out_channels=256,
        positional_encoding=dict(num_feats=128, normalize=True),
        kernel_size=3,
        stride=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        act_cfg=dict(type='ReLU', inplace=True),
        init_cfg=None,
        **kwargs
    ):
        super().__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.out_channels = out_channels
        self.stride = stride
        self.selected_channels = selected_channels

        self.prompt_shape = prompt_shape
        self.num_queries = prompt_shape[0]
        self.per_query_point = prompt_shape[1]

        self.down_sample_layers = nn.ModuleList()
        for idx in self.selected_channels:
            self.down_sample_layers.append(
                nn.Sequential(
                    ConvModule(
                        in_channels[idx],
                        inner_channels,
                        kernel_size=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg
                    ),
                    ConvModule(
                        inner_channels,
                        inner_channels,
                        kernel_size=3,
                        padding=1,
                        stride=2,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg
                    ),
                )
            )
        self.fusion_layers = nn.ModuleList()
        for idx in self.selected_channels:
            self.fusion_layers.append(
                ConvModule(
                    inner_channels,
                    inner_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                )
            )
        self.up_layers = nn.ModuleList()
        self.up_layers.append(
            nn.Sequential(
                ConvModule(
                    inner_channels,
                    inner_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                ),
                ConvModule(
                    inner_channels,
                    inner_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                )
            )
        )
        self.up_layers.append(
            ConvModule(
                inner_channels,
                out_channels,
                kernel_size=1,
                norm_cfg=self.norm_cfg,
                act_cfg=None
            )
        )

        self.generator_pe = SinePositionalEncoding(**positional_encoding)

        self.en_layers = nn.ModuleList()
        self.de_layers = nn.ModuleList()
        self.build_transformer(num_encoders=num_encoders, num_decoders=num_decoders)
        self.embed_dims = self.en_layers[0].embed_dims
        self.pre_norm = self.en_layers[0].pre_norm

        self.query_feat = nn.Embedding(self.num_queries, out_channels)
        self.query_embed = nn.Embedding(self.num_queries, out_channels)

        # self.output_upscaling = nn.Sequential(
        #     nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        #     nn.BatchNorm2d(out_channels),
        #     nn.GELU(),
        #     nn.UpsamplingBilinear2d(scale_factor=2),
        #     nn.Conv2d(out_channels, out_channels // 4, kernel_size=3, padding=1),
        #     nn.BatchNorm2d(out_channels // 4),
        #     nn.GELU(),
        #     nn.UpsamplingBilinear2d(scale_factor=2),
        #     nn.Conv2d(out_channels // 4, out_channels // 8, kernel_size=3, padding=1),
        #     nn.BatchNorm2d(out_channels // 8),
        #     nn.GELU(),
        #     nn.UpsamplingBilinear2d(scale_factor=2),
        #     nn.Conv2d(out_channels // 8, out_channels // 8, kernel_size=3, padding=1),
        # )
        # self.output_hypernetworks_mlps = MLP(out_channels, out_channels, out_channels // 8, 3)

        self.init_weights()

    def build_transformer(self, num_encoders=2, num_decoders=2, embed_dims=256, num_heads=8, mlp_ratio=4):
        transformer_encoder_layer = dict(
            type='BaseTransformerLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    attn_drop=0.1,
                    proj_drop=0.1,
                    dropout_layer=dict(type='Dropout', drop_prob=0.1)
                ),
            ],
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=embed_dims,
                feedforward_channels=embed_dims * mlp_ratio,
                num_fcs=2,
                act_cfg=dict(type='GELU'),
                ffn_drop=0.1,
                add_identity=True),
            operation_order=('norm', 'self_attn', 'norm', 'ffn'),
            norm_cfg=dict(type='LN'),
            batch_first=True
        )
        transformer_decoder_layer = dict(
            type='BaseTransformerLayer',
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=embed_dims,
                num_heads=num_heads,
                attn_drop=0.1,
                proj_drop=0.1,
                dropout_layer=dict(type='Dropout', drop_prob=0.1)
            ),
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=embed_dims,
                feedforward_channels=embed_dims * mlp_ratio,
                num_fcs=2,
                act_cfg=dict(type='GELU'),
                ffn_drop=0.1,
                add_identity=True),
            operation_order=('norm', 'self_attn', 'norm', 'cross_attn', 'norm', 'ffn'),
            norm_cfg=dict(type='LN'),
            batch_first=True
        )
        transformer_en_layers = [
            copy.deepcopy(transformer_encoder_layer) for _ in range(num_encoders)
        ]
        transformer_de_layers = [
            copy.deepcopy(transformer_decoder_layer) for _ in range(num_decoders)
        ]
        for i in range(num_encoders):
            self.en_layers.append(build_transformer_layer(transformer_en_layers[i]))
        for i in range(num_decoders):
            self.de_layers.append(build_transformer_layer(transformer_de_layers[i]))

    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, inputs):
        _, inner_states = inputs
        inner_states = [einops.rearrange(inner_states[idx], 'b h w c -> b c h w') for idx in self.selected_channels]
        inner_states = [layer(x) for layer, x in zip(self.down_sample_layers, inner_states)]

        x = None
        for inner_state, layer in zip(inner_states, self.fusion_layers):
            if x is not None:
                inner_state = x + inner_state
            x = inner_state + layer(inner_state)
        x = self.up_layers[0](x) + x
        img_feats = self.up_layers[1](x)

        bs, c, h, w = img_feats.shape
        mask_pe = torch.zeros((bs, h, w), device=img_feats.device, dtype=torch.bool)
        img_feats_pe = self.generator_pe(mask_pe)

        query_feat = self.query_feat.weight.unsqueeze(0).repeat((bs, 1, 1))
        query_embed = self.query_embed.weight.unsqueeze(0).repeat((bs, 1, 1))

        encoder_inputs = rearrange(img_feats, 'b c h w -> b (h w) c')
        img_feats_pe = img_feats_pe.flatten(2).permute(0, 2, 1)

        # encoder memory: (batch_size, h * w, c)
        memory = encoder_inputs
        for layer in self.en_layers:
            memory = layer(
                query=memory,
                query_pos=img_feats_pe
            )

        # decoder queries: (batch_size, num_queries, c)
        query_feat_list = []
        for layer in self.de_layers:
            query_feat = layer(
                query=query_feat,
                key=memory,
                value=memory,
                query_pos=query_embed,
                key_pos=img_feats_pe
            )
            query_feat_list.append(query_feat)

        img_feat = rearrange(memory, 'b (h w) c -> b c h w', h=h, w=w)
        return query_feat, query_feat_list, img_feat


@MODELS.register_module()  # assumed registration, as above
class SAMAggregatorNeck(nn.Module):
    def __init__(
        self,
        in_channels=[1280] * 16,
        inner_channels=128,
        selected_channels: list = None,
        out_channels=256,
        kernel_size=3,
        stride=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        act_cfg=dict(type='ReLU', inplace=True),
        up_sample_scale=4,
        init_cfg=None,
        **kwargs
    ):
        super().__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.out_channels = out_channels
        self.stride = stride
        self.selected_channels = selected_channels
        self.up_sample_scale = up_sample_scale

        self.down_sample_layers = nn.ModuleList()
        for idx in self.selected_channels:
            self.down_sample_layers.append(
                nn.Sequential(
                    ConvModule(
                        in_channels[idx],
                        inner_channels,
                        kernel_size=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg
                    ),
                    ConvModule(
                        inner_channels,
                        inner_channels,
                        kernel_size=3,
                        padding=1,
                        stride=2,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg
                    ),
                )
            )
        self.fusion_layers = nn.ModuleList()
        for idx in self.selected_channels:
            self.fusion_layers.append(
                ConvModule(
                    inner_channels,
                    inner_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                )
            )
        self.up_layers = nn.ModuleList()
        self.up_layers.append(
            nn.Sequential(
                ConvModule(
                    inner_channels,
                    inner_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                ),
                ConvModule(
                    inner_channels,
                    inner_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                )
            )
        )
        self.up_layers.append(
            ConvModule(
                inner_channels,
                out_channels,
                kernel_size=1,
                norm_cfg=self.norm_cfg,
                act_cfg=None
            )
        )

        self.up_sample_layers = nn.ModuleList()
        assert up_sample_scale == 4
        self.up_sample_layers.append(
            nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                ),
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                )
            )
        )
        self.up_sample_layers.append(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        )
        self.up_sample_layers.append(
            nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                ),
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg
                )
            )
        )
        self.up_sample_layers.append(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        )

    def forward(self, inputs):
        _, inner_states = inputs
        inner_states = [einops.rearrange(inner_states[idx], 'b h w c -> b c h w') for idx in self.selected_channels]
        inner_states = [layer(x) for layer, x in zip(self.down_sample_layers, inner_states)]

        x = None
        for inner_state, layer in zip(inner_states, self.fusion_layers):
            if x is not None:
                inner_state = x + inner_state
            x = inner_state + layer(inner_state)
        x = self.up_layers[0](x) + x
        img_feats_0 = self.up_layers[1](x)

        img_feats_1 = self.up_sample_layers[0](img_feats_0) + self.up_sample_layers[1](img_feats_0)
        img_feats_2 = self.up_sample_layers[2](img_feats_1) + self.up_sample_layers[3](img_feats_1)
        return img_feats_2, img_feats_1, img_feats_0
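

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative only): the channel counts, spatial
    # size and selected_channels below are assumptions, not values from any
    # released config. The backbone is assumed to return a tuple of
    # (image_embeddings, inner_states) with channel-last inner states.
    neck = SAMAggregatorNeck(
        in_channels=[64] * 4,
        inner_channels=32,
        selected_channels=[0, 1, 2, 3],
        out_channels=64,
    )
    img_embs = torch.randn(2, 64, 16, 16)
    inner_states = [torch.randn(2, 16, 16, 64) for _ in range(4)]
    feats_4x, feats_2x, feats_1x = neck((img_embs, inner_states))
    print(feats_1x.shape, feats_2x.shape, feats_4x.shape)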