# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d

from .position_encoding import PositionEmbeddingSine
from .transformer import Transformer


class TransformerPredictor(nn.Module):
@configurable
def __init__(
self,
in_channels,
mask_classification=True,
*,
num_classes: int,
hidden_dim: int,
num_queries: int,
nheads: int,
dropout: float,
dim_feedforward: int,
enc_layers: int,
dec_layers: int,
pre_norm: bool,
deep_supervision: bool,
mask_dim: int,
enforce_input_project: bool,
):
"""
NOTE: this interface is experimental.
Args:
in_channels: channels of the input features
mask_classification: whether to add mask classifier or not
num_classes: number of classes
hidden_dim: Transformer feature dimension
num_queries: number of queries
nheads: number of heads
dropout: dropout in Transformer
dim_feedforward: feature dimension in feedforward network
enc_layers: number of Transformer encoder layers
dec_layers: number of Transformer decoder layers
pre_norm: whether to use pre-LayerNorm or not
            deep_supervision: whether to add supervision to every decoder layer
            mask_dim: mask feature dimension
            enforce_input_project: add an input projection (1x1 conv) even if the
                input channels and hidden dim are identical
"""
super().__init__()
self.mask_classification = mask_classification
# positional encoding
N_steps = hidden_dim // 2
self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
transformer = Transformer(
d_model=hidden_dim,
dropout=dropout,
nhead=nheads,
dim_feedforward=dim_feedforward,
num_encoder_layers=enc_layers,
num_decoder_layers=dec_layers,
normalize_before=pre_norm,
return_intermediate_dec=deep_supervision,
)
self.num_queries = num_queries
self.transformer = transformer
hidden_dim = transformer.d_model
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if in_channels != hidden_dim or enforce_input_project:
self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
weight_init.c2_xavier_fill(self.input_proj)
else:
self.input_proj = nn.Sequential()
self.aux_loss = deep_supervision
# output FFNs
if self.mask_classification:
self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    @classmethod
def from_config(cls, cfg, in_channels, mask_classification):
ret = {}
ret["in_channels"] = in_channels
ret["mask_classification"] = mask_classification
ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
# Transformer parameters:
ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
return ret
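
    # Illustrative only: the keys read above come from a MaskFormer-style yacs
    # config. Values along the lines of the following are typical, but the
    # authoritative defaults live in the project's config files:
    #
    #   MODEL.MASK_FORMER.HIDDEN_DIM         = 256
    #   MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
    #   MODEL.MASK_FORMER.NHEADS             = 8
    #   MODEL.MASK_FORMER.DROPOUT            = 0.1
    #   MODEL.MASK_FORMER.DIM_FEEDFORWARD    = 2048
    #   MODEL.MASK_FORMER.ENC_LAYERS         = 0
    #   MODEL.MASK_FORMER.DEC_LAYERS         = 6
    #   MODEL.MASK_FORMER.PRE_NORM           = False
    #   MODEL.MASK_FORMER.DEEP_SUPERVISION   = True
    #   MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
    #   MODEL.SEM_SEG_HEAD.MASK_DIM          = 256
    #   MODEL.SEM_SEG_HEAD.NUM_CLASSES       = 171   # dataset-dependent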

    def forward(self, x, mask_features):
        """
        Args:
            x: feature map fed to the transformer, of shape [B, C, H, W]
            mask_features: per-pixel embeddings of shape [B, mask_dim, H, W],
                combined with the query embeddings to produce mask predictions
        Returns:
            a dict with "pred_masks", plus "pred_logits" if mask classification
            is enabled and "aux_outputs" if deep supervision is enabled
        """
        pos = self.pe_layer(x)

        src = x
        mask = None
        # hs: per-layer decoder outputs, [num_layers, B, num_queries, hidden_dim]
        # memory: encoder output reshaped back to [B, hidden_dim, H, W]
        hs, memory = self.transformer(
            self.input_proj(src), mask, self.query_embed.weight, pos
        )
if self.mask_classification:
outputs_class = self.class_embed(hs)
out = {"pred_logits": outputs_class[-1]}
else:
out = {}
if self.aux_loss:
# [l, bs, queries, embed]
mask_embed = self.mask_embed(hs)
            # dot-product each query's mask embedding with the per-pixel
            # features -> [l, bs, queries, H, W]
            outputs_seg_masks = torch.einsum(
                "lbqc,bchw->lbqhw", mask_embed, mask_features
            )
out["pred_masks"] = outputs_seg_masks[-1]
out["aux_outputs"] = self._set_aux_loss(
outputs_class if self.mask_classification else None, outputs_seg_masks
)
else:
# FIXME h_boxes takes the last one computed, keep this in mind
# [bs, queries, embed]
mask_embed = self.mask_embed(hs[-1])
outputs_seg_masks = torch.einsum(
"bqc,bchw->bqhw", mask_embed, mask_features
)
out["pred_masks"] = outputs_seg_masks
return out

    @torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_seg_masks):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
if self.mask_classification:
return [
{"pred_logits": a, "pred_masks": b}
for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
]
else:
return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)

    def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
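

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It must
# be run with the package importable (the relative imports above require e.g.
# `python -m open_vocab_seg.modeling.transformer.transformer_predictor`), and
# all hyperparameters / shapes below are assumed values, not project defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    predictor = TransformerPredictor(
        in_channels=256,
        mask_classification=True,
        num_classes=171,  # assumed; in practice taken from the dataset metadata
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dropout=0.1,
        dim_feedforward=2048,
        enc_layers=0,
        dec_layers=6,
        pre_norm=False,
        deep_supervision=True,
        mask_dim=256,
        enforce_input_project=False,
    )
    x = torch.randn(2, 256, 32, 32)  # transformer input features [B, C, H, W]
    mask_features = torch.randn(2, 256, 128, 128)  # per-pixel embeddings [B, mask_dim, H, W]
    out = predictor(x, mask_features)
    print(out["pred_logits"].shape)  # torch.Size([2, 100, 172]) == [B, num_queries, num_classes + 1]
    print(out["pred_masks"].shape)   # torch.Size([2, 100, 128, 128]) == [B, num_queries, H, W]
    print(len(out["aux_outputs"]))   # dec_layers - 1 intermediate predictions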