Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
import logging | |
from typing import Callable, Dict, Optional, Union | |
import fvcore.nn.weight_init as weight_init | |
from detectron2.config import configurable | |
from detectron2.layers import Conv2d, ShapeSpec, get_norm | |
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY | |
from torch import nn | |
from torch.nn import functional as F | |
from ..transformer.position_encoding import PositionEmbeddingSine | |
from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer | |
class MegaBigPixelDecoder(nn.Module): | |
def __init__( | |
self, | |
input_shape: Dict[str, ShapeSpec], | |
*, | |
conv_dim: int, | |
mask_dim: int, | |
norm: Optional[Union[str, Callable]] = None, | |
): | |
""" | |
NOTE: this interface is experimental. | |
Args: | |
input_shape: shapes (channels and stride) of the input features | |
conv_dims: number of output channels for the intermediate conv layers. | |
mask_dim: number of output channels for the final conv layer. | |
norm (str or callable): normalization for all conv layers | |
""" | |
super().__init__() | |
input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) | |
self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" | |
feature_channels = [v.channels for k, v in input_shape] | |
lateral_convs = [] | |
output_convs = [] | |
use_bias = norm == "" | |
for idx, in_channels in enumerate(feature_channels): | |
if idx == len(self.in_features) - 1: | |
output_norm = get_norm(norm, conv_dim) | |
output_conv = Conv2d( | |
in_channels, | |
conv_dim, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
bias=use_bias, | |
norm=output_norm, | |
activation=F.relu, | |
) | |
weight_init.c2_xavier_fill(output_conv) | |
self.add_module("layer_{}".format(idx + 1), output_conv) | |
lateral_convs.append(None) | |
output_convs.append(output_conv) | |
else: | |
lateral_norm = get_norm(norm, conv_dim) | |
output_norm = get_norm(norm, conv_dim) | |
lateral_conv = Conv2d( | |
in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm | |
) | |
output_conv = Conv2d( | |
conv_dim, | |
conv_dim, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
bias=use_bias, | |
norm=output_norm, | |
activation=F.relu, | |
) | |
weight_init.c2_xavier_fill(lateral_conv) | |
weight_init.c2_xavier_fill(output_conv) | |
self.add_module("adapter_{}".format(idx + 1), lateral_conv) | |
self.add_module("layer_{}".format(idx + 1), output_conv) | |
lateral_convs.append(lateral_conv) | |
output_convs.append(output_conv) | |
# Place convs into top-down order (from low to high resolution) | |
# to make the top-down computation in forward clearer. | |
self.lateral_convs = lateral_convs[::-1] | |
self.output_convs = output_convs[::-1] | |
self.mask_dim = mask_dim | |
# self.mask_features = Conv2d( | |
# conv_dim, | |
# mask_dim, | |
# kernel_size=3, | |
# stride=1, | |
# padding=1, | |
# ) | |
# weight_init.c2_xavier_fill(self.mask_features) | |
self.mask_features = nn.Sequential( | |
Conv2d( | |
conv_dim, | |
conv_dim, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
bias=use_bias, | |
norm=output_norm, | |
activation=F.relu, | |
), | |
nn.UpsamplingNearest2d(scale_factor=2), | |
Conv2d( | |
conv_dim, | |
conv_dim, | |
kernel_size=1, | |
stride=1, | |
padding=1, | |
bias=use_bias, | |
norm=output_norm, | |
activation=F.relu, | |
), | |
Conv2d( | |
conv_dim, | |
conv_dim, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
bias=use_bias, | |
norm=output_norm, | |
activation=F.relu, | |
), | |
nn.UpsamplingNearest2d(scale_factor=2), | |
Conv2d( | |
conv_dim, | |
conv_dim, | |
kernel_size=1, | |
stride=1, | |
padding=1, | |
bias=use_bias, | |
norm=output_norm, | |
activation=F.relu, | |
), | |
Conv2d( | |
conv_dim, | |
conv_dim, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
bias=use_bias, | |
norm=output_norm, | |
activation=F.relu, | |
), | |
nn.UpsamplingNearest2d(scale_factor=2), | |
Conv2d( | |
conv_dim, | |
conv_dim, | |
kernel_size=1, | |
stride=1, | |
padding=1, | |
bias=use_bias, | |
norm=output_norm, | |
activation=F.relu, | |
), | |
Conv2d( | |
conv_dim, | |
mask_dim, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) | |
) | |
for name, module in self.mask_features.named_modules(): | |
if 'Conv2d' in name: | |
weight_init.c2_xavier_fill(module) | |
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): | |
ret = {} | |
ret["input_shape"] = { | |
k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES | |
} | |
ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM | |
ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM | |
ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM | |
return ret | |
def forward_features(self, features): | |
# Reverse feature maps into top-down order (from low to high resolution) | |
for idx, f in enumerate(self.in_features[::-1]): | |
x = features[f] | |
lateral_conv = self.lateral_convs[idx] | |
output_conv = self.output_convs[idx] | |
if lateral_conv is None: | |
y = output_conv(x) | |
else: | |
cur_fpn = lateral_conv(x) | |
# Following FPN implementation, we use nearest upsampling here | |
y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") | |
y = output_conv(y) | |
return self.mask_features(y), None | |
def forward(self, features, targets=None): | |
logger = logging.getLogger(__name__) | |
logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") | |
return self.forward_features(features) | |
class TransformerEncoderOnly(nn.Module): | |
def __init__( | |
self, | |
d_model=512, | |
nhead=8, | |
num_encoder_layers=6, | |
dim_feedforward=2048, | |
dropout=0.1, | |
activation="relu", | |
normalize_before=False, | |
): | |
super().__init__() | |
encoder_layer = TransformerEncoderLayer( | |
d_model, nhead, dim_feedforward, dropout, activation, normalize_before | |
) | |
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None | |
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) | |
self._reset_parameters() | |
self.d_model = d_model | |
self.nhead = nhead | |
def _reset_parameters(self): | |
for p in self.parameters(): | |
if p.dim() > 1: | |
nn.init.xavier_uniform_(p) | |
def forward(self, src, mask, pos_embed): | |
# flatten NxCxHxW to HWxNxC | |
bs, c, h, w = src.shape | |
src = src.flatten(2).permute(2, 0, 1) | |
pos_embed = pos_embed.flatten(2).permute(2, 0, 1) | |
if mask is not None: | |
mask = mask.flatten(1) | |
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) | |
return memory.permute(1, 2, 0).view(bs, c, h, w) | |