import fvcore.nn.weight_init as weight_init
from typing import Optional

import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from einops import repeat

from detectron2.config import configurable
from detectron2.layers import Conv2d
from detectron2.utils.registry import Registry

from .position_encoding import PositionEmbeddingSine

TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
TRANSFORMER_DECODER_REGISTRY.__doc__ = """
Registry for transformer module in MaskFormer.
"""
def build_transformer_decoder(cfg, in_channels, mask_classification=True):
    """
    Build a transformer decoder from `cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME`.
    """
    name = cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME
    return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, in_channels, mask_classification)
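
# Usage sketch (illustrative, not part of this module): a decoder opts into
# `build_transformer_decoder` by registering itself and exposing a `from_config`
# classmethod; the registered name is then referenced from the config. The class
# name `MyDecoder` below is hypothetical.
#
#   @TRANSFORMER_DECODER_REGISTRY.register()
#   class MyDecoder(nn.Module):
#       @configurable
#       def __init__(self, in_channels, mask_classification=True, *, hidden_dim: int):
#           ...
#       @classmethod
#       def from_config(cls, cfg, in_channels, mask_classification):
#           return {"in_channels": in_channels,
#                   "mask_classification": mask_classification,
#                   "hidden_dim": cfg.MODEL.MASK_FORMER.HIDDEN_DIM}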

def get_classification_logits(x, text_classifier, logit_scale, num_templates=None):
    # x in shape of [B, *, C]
    # text_classifier in shape of [num_classes, C]
    # logit_scale is a learnable scalar https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/model.py#L201
    # return: [B, *, num_classes]
    x = F.normalize(x, dim=-1)
    logit_scale = torch.clamp(logit_scale.exp(), max=100)
    pred_logits = logit_scale * x @ text_classifier.T  # B, *, N + 1
    # max ensemble over prompt templates, as in OpenSeg/ODISE
    final_pred_logits = []
    cur_idx = 0
    for num_t in num_templates:
        final_pred_logits.append(pred_logits[:, :, cur_idx: cur_idx + num_t].max(-1).values)
        cur_idx += num_t
    final_pred_logits.append(pred_logits[:, :, -1])  # the last classifier is for void
    final_pred_logits = torch.stack(final_pred_logits, dim=-1)
    return final_pred_logits
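
# Shape sketch (values assumed, for illustration): with two classes whose prompts use
# 2 and 3 templates respectively, plus one trailing "void" embedding:
#   x               : [B, Q, C]        (e.g. per-query class embeddings)
#   text_classifier : [2 + 3 + 1, C]   (template embeddings, void last)
#   num_templates   : [2, 3]
#   output          : [B, Q, 3]        (max over each class's templates, then void)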

class MaskPooling(nn.Module):
    def __init__(
        self,
    ):
        super().__init__()

    def forward(self, x, mask):
        """
        Args:
            x: [B, C, H, W]
            mask: [B, Q, H, W]
        """
        if not x.shape[-2:] == mask.shape[-2:]:
            # reshape mask to x
            mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
        with torch.no_grad():
            mask = mask.detach()
            mask = (mask > 0).to(mask.dtype)
            denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8

        mask_pooled_x = torch.einsum(
            "bchw,bqhw->bqc",
            x,
            mask / denorm,
        )
        return mask_pooled_x
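
# Equivalence sketch (illustrative): the einsum above averages the feature map over
# each binarized query mask, i.e. for query q
#   pooled[b, q] = sum_{h,w} x[b, :, h, w] * mask[b, q, h, w] / (sum_{h,w} mask[b, q, h, w] + 1e-8)
# Quick shape check with random tensors (shapes assumed):
#   x = torch.randn(2, 256, 32, 32); m = torch.randn(2, 100, 32, 32)
#   MaskPooling()(x, m).shape  # -> torch.Size([2, 100, 256])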

class SelfAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)
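
# Note (sketch): these layers use nn.MultiheadAttention in its default seq-first
# layout, so queries flow through as [num_queries, batch, d_model]. With
# normalize_before=True the LayerNorm runs before attention (pre-norm residual
# block); otherwise it runs after the residual addition (post-norm).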

class CrossAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)
    def get_attention(self, tgt, memory,
                      memory_mask: Optional[Tensor] = None,
                      memory_key_padding_mask: Optional[Tensor] = None,
                      pos: Optional[Tensor] = None,
                      query_pos: Optional[Tensor] = None):
        # Return only the cross-attention weights; the attended values are discarded.
        tgt2 = self.norm(tgt)
        tgt2, attn_weight = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                                key=self.with_pos_embed(memory, pos),
                                                value=memory, attn_mask=memory_mask,
                                                key_padding_mask=memory_key_padding_mask,
                                                average_attn_weights=False)
        return attn_weight
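
# Note (sketch): with average_attn_weights=False, nn.MultiheadAttention returns the
# per-head attention maps, so get_attention yields weights shaped
# [batch, nhead, num_queries, H*W]. It always pre-normalizes tgt, regardless of
# normalize_before.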

class CrossAttentionLayer_MINI(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False, downsample_ratio=1, kernel_size=3):
        super().__init__()
        self.downsample_ratio = downsample_ratio
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        # positional encoding
        N_steps = d_model // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def reshape_and_downsample(self, x, down_sam_size, curr_size):
        # (h*w, b, c) -> (b, c, h, w) -> bilinear resize -> (h'*w', b, c)
        n, b, c = x.shape
        h, w = curr_size[0], curr_size[1]
        x = x.reshape(h, w, b, c)    # reshape to (height, width, batch, channels)
        x = x.permute(2, 3, 0, 1)    # reorder to (batch, channels, height, width)
        x = F.interpolate(
            x,
            size=down_sam_size,
            mode="bilinear",
            align_corners=False,
        )
        x = x.permute(2, 3, 0, 1).reshape(down_sam_size[0] * down_sam_size[1], b, c)  # back to (n', b, c)
        return x

    def upsample_and_reshape(self, x, upsample_size, curr_size):
        # (h'*w', b, c) -> (b, c, h', w') -> bilinear resize -> (h*w, b, c)
        n, b, c = x.shape
        h, w = curr_size[0], curr_size[1]
        x = x.reshape(h, w, b, c)    # reshape to (height, width, batch, channels)
        x = x.permute(2, 3, 0, 1)    # reorder to (batch, channels, height, width)
        x = F.interpolate(
            x,
            size=upsample_size,
            mode="bilinear",
            align_corners=False,
        )
        x = x.permute(2, 3, 0, 1).reshape(upsample_size[0] * upsample_size[1], b, c)  # back to (n, b, c)
        return x, None

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,
                down_sample_ratio=None,
                min_size=None,
                curr_size=None,
                ):
        min_size_clip = min_size[0]
        min_size_sam = min_size[1]
        curr_size_clip = curr_size[0]
        curr_size_sam = curr_size[1]
        tgt = self.reshape_and_downsample(tgt, min_size_clip, curr_size_clip)
        memory = self.reshape_and_downsample(memory, min_size_sam, curr_size_sam)
        query_pos = self.reshape_and_downsample(query_pos, min_size_clip, curr_size_clip)
        pos = self.reshape_and_downsample(pos, min_size_sam, curr_size_sam)
        if self.normalize_before:
            tgt = self.forward_pre(tgt, memory, memory_mask,
                                   memory_key_padding_mask, pos, query_pos)
            tgt, new_pos = self.upsample_and_reshape(tgt, curr_size_clip, min_size_clip)
            return tgt, new_pos
        tgt = self.forward_post(tgt, memory, memory_mask,
                                memory_key_padding_mask, pos, query_pos)
        tgt, new_pos = self.upsample_and_reshape(tgt, curr_size_clip, min_size_clip)
        return tgt, new_pos
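
# Shape sketch for the feature injector (values assumed, for illustration): with CLIP
# features at 32x32 and SAM features at 64x64, batch 2, d_model 256:
#   tgt    : [32*32, 2, 256]   curr_size_clip = (32, 32), min_size_clip = (32, 32)
#   memory : [64*64, 2, 256]   curr_size_sam  = (64, 64), min_size_sam  = (64, 64)
# Both streams are resampled to their `min_size` before attention, attended over, and
# the result is resampled back to `curr_size_clip`. The second return value (new_pos)
# is currently None, so the caller attends without a positional encoding afterwards.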

class FFNLayer(nn.Module):

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
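
# Usage sketch (shapes assumed): MLP(256, 256, 256, 3) stacks three Linear(256, 256)
# layers with ReLU after all but the last, so a [B, Q, 256] query tensor maps to a
# [B, Q, 256] mask/class embedding:
#   mlp = MLP(input_dim=256, hidden_dim=256, output_dim=256, num_layers=3)
#   mlp(torch.randn(2, 100, 256)).shape  # -> torch.Size([2, 100, 256])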

@TRANSFORMER_DECODER_REGISTRY.register()
class MultiScaleMaskedTransformerDecoder(nn.Module):

    @configurable
    def __init__(
        self,
        in_channels,
        mask_classification=True,
        *,
        hidden_dim: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        enforce_input_project: bool,
        clip_embedding_dim: int,
        sam_query_fuse_layer: int = 0,
        sam_feature_fuse_layer: int = 0,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
            clip_embedding_dim: dimension of the CLIP embedding space used for classification
            sam_query_fuse_layer: feature level at which SAM mask-pooled embeddings are added to the queries
            sam_feature_fuse_layer: feature level at which SAM features are injected into the CLIP features
        """
        super().__init__()

        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers_sam = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()
        self.atten_sam_layers = 50
        self.num_feature_levels = 3

        for i in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            level_index = i % self.num_feature_levels
            if level_index == 0:
                self.transformer_cross_attention_layers_sam.append(
                    CrossAttentionLayer_MINI(
                        d_model=hidden_dim,
                        nhead=nheads,
                        dropout=0.0,
                        normalize_before=pre_norm,
                        downsample_ratio=int(i % 3),
                    )
                )
            else:
                self.transformer_cross_attention_layers_sam.append(
                    nn.Identity()
                )
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)

        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                self.input_proj.append(nn.Sequential())
        self.input_proj_sam = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj_sam.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj_sam[-1])
            else:
                self.input_proj_sam.append(nn.Sequential())

        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

        self.mask_pooling = MaskPooling()
        self._mask_pooling_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim))
        self.class_embed = MLP(hidden_dim, hidden_dim, clip_embedding_dim, 3)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.sam_query_fuse_layer = sam_query_fuse_layer
        self.sam_feature_fuse_layer = sam_feature_fuse_layer
    @classmethod
    def from_config(cls, cfg, in_channels, mask_classification):
        ret = {}
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        # Transformer parameters:
        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD

        # NOTE: because we add learnable query features which require supervision,
        # we subtract 1 from the decoder layers to be consistent with our loss
        # implementation: that is, the number of auxiliary losses is always
        # equal to the number of decoder layers. With learnable query features, the
        # number of auxiliary losses equals the number of decoder layers plus 1.
        assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ

        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
        ret["clip_embedding_dim"] = cfg.MODEL.FROZEN_SEG.EMBED_DIM
        ret["sam_query_fuse_layer"] = cfg.MODEL.MASK_FORMER.SAM_QUERY_FUSE_LAYER
        ret["sam_feature_fuse_layer"] = cfg.MODEL.MASK_FORMER.SAM_FEATURE_FUSE_LAYER

        return ret
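
    # Worked example (config values assumed, for illustration): with
    # cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10, `dec_layers` becomes 9, so the decoder
    # runs 9 refinement rounds over self.num_feature_levels = 3 scales (each scale
    # visited 3 times) and produces 9 + 1 = 10 sets of predictions: one from the
    # learnable queries before any decoder layer, plus one per layer. That extra
    # initial prediction is what the "minus 1" above accounts for.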
    def resize_feat(self, x, resize_shape):
        x = F.interpolate(
            x,
            size=(resize_shape[0], resize_shape[1]),
            mode="bilinear",
            align_corners=False,
        )
        return x
    def forward(self, x, mask_features, mask=None, text_classifier=None, num_templates=None, sam_embedding=None, sam=None, sam_fpn=None):
        # x is a list of multi-scale features
        visualize_attention = False
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        # disable mask, it does not affect performance
        del mask

        src_sam = []
        pos_sam = []
        size_list_sam = []

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        for i in range(len(sam_fpn)):
            sam_src_curr = sam_fpn[i]
            if sam_src_curr.shape[-2:] != x[i].shape[-2:] and not self.training:
                sam_src_curr = self.resize_feat(sam_src_curr, x[i].shape[-2:])
                size_list_sam.append(sam_src_curr.shape[-2:])
            else:
                size_list_sam.append(sam_src_curr.shape[-2:])
            pos_sam.append(self.pe_layer(sam_src_curr, None).flatten(2))
            src_sam.append(self.input_proj_sam[i](sam_src_curr).flatten(2) + self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos_sam[-1] = pos_sam[-1].permute(2, 0, 1)
            src_sam[-1] = src_sam[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        predictions_class = []
        predictions_mask = []

        outputs_class, outputs_mask, attn_mask, sam_pool_emb = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0],
            text_classifier=text_classifier, num_templates=num_templates, sam_embedding=sam_embedding)
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)

        min_size_clip = size_list[0]
        min_size_sam = size_list_sam[0]
        assert len(size_list_sam) == 1, "Only support one scale for sam"
        size_list_sam = size_list_sam * self.num_feature_levels

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            ############# Feature Injector ##############
            if level_index == self.sam_feature_fuse_layer:
                clip_size_curr = size_list[level_index]
                sam_size_curr = size_list_sam[level_index]
                clip_sam, new_pos = self.transformer_cross_attention_layers_sam[i](
                    src[level_index], src_sam[0],
                    memory_key_padding_mask=None,
                    pos=pos_sam[0], query_pos=pos[level_index],
                    down_sample_ratio=level_index,
                    min_size=(min_size_clip, min_size_sam),
                    curr_size=(clip_size_curr, sam_size_curr),
                )
                cross_pos = new_pos
            else:
                clip_sam = src[level_index]
                cross_pos = pos[level_index]
            ############# Feature Injector ##############

            output = self.transformer_cross_attention_layers[i](
                output, clip_sam,
                memory_mask=attn_mask,
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=cross_pos, query_pos=query_embed
            )
            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )
            output = self.transformer_ffn_layers[i](
                output
            )

            ############# Query Injector ##############
            outputs_class, outputs_mask, attn_mask, sam_pool_emb = self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels],
                text_classifier=text_classifier, num_templates=num_templates, sam_embedding=sam_embedding, sam=sam)
            if level_index == self.sam_query_fuse_layer:
                output = output + sam_pool_emb
            predictions_class.append(outputs_class)
            predictions_mask.append(outputs_mask)

        assert len(predictions_class) == self.num_layers + 1

        out = {
            'pred_logits': predictions_class[-1],
            'pred_masks': predictions_mask[-1],
            'aux_outputs': self._set_aux_loss(
                predictions_class if self.mask_classification else None, predictions_mask
            ),
        }
        return out
    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, text_classifier, num_templates, sam_embedding=None, sam=None):
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)  # b q 256 256

        maskpool_embeddings = self.mask_pooling(x=mask_features, mask=outputs_mask)  # [B, Q, C]
        maskpool_embeddings = self._mask_pooling_proj(maskpool_embeddings)

        sam_maskpool_embeddings = self.mask_pooling(x=sam_embedding[0], mask=outputs_mask)  # [B, Q, C]
        sam_maskpool_embeddings = self._mask_pooling_proj(sam_maskpool_embeddings)
        sam_maskpool_embeddings = sam_maskpool_embeddings.transpose(0, 1)

        class_embed = self.class_embed(maskpool_embeddings + decoder_output)
        outputs_class = get_classification_logits(class_embed, text_classifier, self.logit_scale, num_templates)

        # NOTE: prediction is of higher-resolution
        # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
        # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        return outputs_class, outputs_mask, attn_mask, sam_maskpool_embeddings
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b}
                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
            ]
        else:
            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
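
# End-to-end usage sketch (all shapes and variable names assumed, not part of the
# module): given three CLIP pyramid levels x = [P5, P4, P3] with channels equal to
# hidden_dim, per-pixel mask_features of shape [B, C_mask, H, W], a text_classifier
# with its num_templates list, SAM per-pixel embeddings and a single-scale sam_fpn
# feature:
#
#   out = decoder(x, mask_features,
#                 text_classifier=text_classifier, num_templates=num_templates,
#                 sam_embedding=[sam_pixel_embedding], sam=None, sam_fpn=[sam_fpn_feat])
#   out["pred_logits"]  # [B, num_queries, num_classes + 1]  (void class last)
#   out["pred_masks"]   # [B, num_queries, H, W]
#   out["aux_outputs"]  # one dict per intermediate prediction, for auxiliary losses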