import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from .position_encoding import PositionEmbeddingSine
from .transformer import Transformer


class StandardTransformerDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        num_classes,
        mask_classification=True,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dropout=0.0,
        dim_feedforward=2048,
        enc_layers=0,
        dec_layers=10,
        pre_norm=False,
        deep_supervision=True,
        mask_dim=256,
        enforce_input_project=False,
    ):
        super().__init__()
        self.mask_classification = mask_classification

        # sine positional encoding over the 2D feature map
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            normalize_before=pre_norm,
            return_intermediate_dec=deep_supervision,
        )

        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model

        # learnable query embeddings, one per object query
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # project input features to the transformer dimension when they differ
        # (or when explicitly requested); the input is a 2D feature map
        # (N, C, H, W), so a 1x1 Conv2d is the appropriate projection
        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
            weight_init.c2_xavier_fill(self.input_proj)
        else:
            self.input_proj = nn.Sequential()
        self.aux_loss = deep_supervision

        # output FFNs: per-query classification head and mask-embedding head
        if self.mask_classification:
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, x, mask_features, mask=None):
        if mask is not None:
            # resize the padding mask to the spatial size of the feature map
            mask = F.interpolate(mask[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
        pos = self.pe_layer(x, mask)

        src = x
        # hs: decoder outputs, shape (L, N, Q, C) when deep supervision is enabled
        hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)

        if self.mask_classification:
            outputs_class = self.class_embed(hs)
            out = {"pred_logits": outputs_class[-1]}
        else:
            out = {}

        if self.aux_loss:
            # predict a mask embedding per query for every decoder layer and
            # dot it with the per-pixel mask features to get mask logits
            mask_embed = self.mask_embed(hs)
            outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks[-1]
            out["aux_outputs"] = self._set_aux_loss(
                outputs_class if self.mask_classification else None, outputs_seg_masks
            )
        else:
            # without deep supervision, only the last decoder layer is used
            mask_embed = self.mask_embed(hs[-1])
            outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionaries with non-homogeneous values, such
        # as a dict having both a Tensor and a list
        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b}
                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
            ]
        else:
            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
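

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It assumes this
# file sits inside its package so the relative imports above resolve, e.g. run
# it as `python -m <package>.<this_module>`; the shapes and hyperparameters
# below are illustrative assumptions, not values prescribed by the module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = StandardTransformerDecoder(
        in_channels=256,  # assumed channel count of the input feature map
        num_classes=80,   # assumed number of semantic classes
    )
    x = torch.randn(2, 256, 32, 32)                # transformer input features
    mask_features = torch.randn(2, 256, 128, 128)  # per-pixel features, C must equal mask_dim
    out = decoder(x, mask_features)
    print(out["pred_logits"].shape)  # expected (2, 100, 81): (N, num_queries, num_classes + 1)
    print(out["pred_masks"].shape)   # expected (2, 100, 128, 128)
    print(len(out["aux_outputs"]))   # expected dec_layers - 1 intermediate predictions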