elia/modeling/transformer_decoder/maskformer_transformer_decoder.py
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from .position_encoding import PositionEmbeddingSine
from .transformer import Transformer


class StandardTransformerDecoder(nn.Module):
def __init__(
self,
in_channels,
num_classes,
mask_classification=True,
hidden_dim=256,
num_queries=100,
nheads=8,
dropout=0.0,
dim_feedforward=2048,
enc_layers=0,
dec_layers=10,
pre_norm=False,
deep_supervision=True,
mask_dim=256,
enforce_input_project=False
):
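        """
        Args:
            in_channels: number of channels of the input feature map
            num_classes: number of foreground classes (a "no object" class is added internally)
            mask_classification: whether to predict class logits in addition to masks
            hidden_dim: Transformer feature dimension
            num_queries: number of object queries
            nheads: number of attention heads
            dropout: dropout used inside the Transformer
            dim_feedforward: hidden dimension of the Transformer feed-forward layers
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            pre_norm: use pre-LayerNorm Transformer blocks
            deep_supervision: return intermediate decoder outputs for auxiliary losses
            mask_dim: dimension of the per-query mask embedding
            enforce_input_project: always apply a 1x1 input projection, even when
                in_channels already equals hidden_dim
        """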
super().__init__()
self.mask_classification = mask_classification
# positional encoding
N_steps = hidden_dim // 2
self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
transformer = Transformer(
d_model=hidden_dim,
dropout=dropout,
nhead=nheads,
dim_feedforward=dim_feedforward,
num_encoder_layers=enc_layers,
num_decoder_layers=dec_layers,
normalize_before=pre_norm,
return_intermediate_dec=deep_supervision,
)
self.num_queries = num_queries
self.transformer = transformer
hidden_dim = transformer.d_model
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if in_channels != hidden_dim or enforce_input_project:
            # 1x1 conv projecting the [B, C, H, W] input features to the Transformer dimension
            self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
weight_init.c2_xavier_fill(self.input_proj)
else:
self.input_proj = nn.Sequential()
self.aux_loss = deep_supervision
# output FFNs
if self.mask_classification:
self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, x, mask_features, mask=None):
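        """
        Args:
            x: input feature map of shape [B, C, H, W]
            mask_features: per-pixel embeddings of shape [B, mask_dim, H', W'] used to
                decode the final masks
            mask: optional padding mask, resized here to the spatial size of ``x``
        Returns:
            a dict with "pred_logits" (if mask classification is enabled), "pred_masks",
            and, when deep supervision is on, "aux_outputs" from intermediate decoder layers
        """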
if mask is not None:
mask = F.interpolate(mask[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
pos = self.pe_layer(x, mask)
src = x
hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
if self.mask_classification:
outputs_class = self.class_embed(hs)
out = {"pred_logits": outputs_class[-1]}
else:
out = {}
if self.aux_loss:
# [l, bs, queries, embed]
mask_embed = self.mask_embed(hs)
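            # dot product of every query embedding with every pixel embedding:
            # [l, b, q, c] x [b, c, h, w] -> [l, b, q, h, w]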
outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
out["pred_masks"] = outputs_seg_masks[-1]
out["aux_outputs"] = self._set_aux_loss(
outputs_class if self.mask_classification else None, outputs_seg_masks
)
else:
# FIXME h_boxes takes the last one computed, keep this in mind
# [bs, queries, embed]
mask_embed = self.mask_embed(hs[-1])
outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
out["pred_masks"] = outputs_seg_masks
return out

    @torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_seg_masks):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
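        # The last decoder layer's output is the main prediction, so only the
        # preceding layers' outputs are returned as auxiliary outputs.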
if self.mask_classification:
return [
{"pred_logits": a, "pred_masks": b}
for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
]
else:
return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)

    def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
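

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# The tensor shapes below are assumptions inferred from the forward pass above:
# `x` stands for backbone/encoder features of shape [B, C, H, W], and
# `mask_features` for the higher-resolution per-pixel embeddings of shape
# [B, mask_dim, H', W'] produced elsewhere in the model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = StandardTransformerDecoder(
        in_channels=256,
        num_classes=80,
        hidden_dim=256,
        num_queries=100,
        mask_dim=256,
    )
    x = torch.randn(2, 256, 32, 32)                # [B, C, H, W] features
    mask_features = torch.randn(2, 256, 128, 128)  # [B, mask_dim, H', W']
    out = decoder(x, mask_features)
    # out["pred_logits"]: [2, 100, 81]; out["pred_masks"]: [2, 100, 128, 128]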