# mask_adapter/modeling/meta_arch/mask_adapter_head.py
import torch
import torch.utils.checkpoint as cp
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from einops import rearrange, repeat

from .convnext import ConvNextBlock


@SEM_SEG_HEADS_REGISTRY.register()
class MASKAdapterHead(nn.Module):
@configurable
def __init__(
self,
        clip_model_name: str,
mask_in_chans: int,
num_channels: int,
use_checkpoint: bool,
num_output_maps: int,
):
"""
NOTE: this interface is experimental.
Args:
input_shape: shapes (channels and stride) of the input features
num_classes: number of classes to predict
pixel_decoder: the pixel decoder module
loss_weight: loss weight
ignore_value: category id to be ignored during training.
transformer_predictor: the transformer decoder that makes prediction
transformer_in_feature: input feature name to the transformer_predictor
"""
super().__init__()
self.use_checkpoint = use_checkpoint
        # Infer the CLIP feature dimension from the backbone name.
        if '_base' in clip_model_name:
            clip_dim = 640
        elif '_large' in clip_model_name:
            clip_dim = 768
        else:
            raise ValueError(f"Unsupported CLIP model name: {clip_model_name!r}")
        # Fuse the (CLIP feature + mask embedding) sum into the working width,
        # then refine with three ConvNeXt blocks before predicting the maps.
        self.fuse = nn.Conv2d(clip_dim, num_channels, 1)
        self.cnext1 = ConvNextBlock(num_channels)
        self.cnext2 = ConvNextBlock(num_channels)
        self.cnext3 = ConvNextBlock(num_channels)
        self.norm = nn.LayerNorm(num_channels)
        self.final = nn.Conv2d(num_channels, num_output_maps, 1)

        # 4x-downscaling stem that embeds a single-channel mask into the
        # CLIP feature space so it can be added to the CLIP features.
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=3, stride=2, padding=1),
            LayerNorm2d(mask_in_chans // 4),
            nn.GELU(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=3, stride=2, padding=1),
            LayerNorm2d(mask_in_chans),
            nn.GELU(),
            nn.Conv2d(mask_in_chans, clip_dim, kernel_size=1),
        )

    @classmethod
def from_config(cls, cfg):
return {
"clip_model_name": cfg.MODEL.FC_CLIP.CLIP_MODEL_NAME,
"mask_in_chans": cfg.MODEL.MASK_ADAPTER.MASK_IN_CHANNELS,
"num_channels": cfg.MODEL.MASK_ADAPTER.NUM_CHANNELS,
"use_checkpoint": cfg.MODEL.MASK_ADAPTER.USE_CHECKPOINT,
"num_output_maps": cfg.MODEL.MASK_ADAPTER.NUM_OUTPUT_MAPS,
}

    def forward(self, clip_feature, masks):
        """Fuse CLIP features (B, C, H, W) with mask proposals (B, N, H_m, W_m)."""
        N = masks.size(1)
        # Fold the N masks into the batch dimension and add a channel axis.
        masks = rearrange(masks, 'B N H W -> (B N) H W').unsqueeze(dim=1)
        clip_feature = repeat(clip_feature, "B C H W -> (B N) C H W", N=N)
        H, W = clip_feature.shape[-2:]
        # Upsample masks to 4x the feature resolution; mask_downscaling then
        # strides them back down to (H, W) in the CLIP feature space.
        masks = F.interpolate(masks, size=(H * 4, W * 4),
                              mode='bilinear', align_corners=False)
        masks = self.mask_downscaling(masks)
        outputs = clip_feature + masks
        def _inner_forward(outputs):
            outputs = self.fuse(outputs)
            outputs = self.cnext1(outputs)
            outputs = self.cnext2(outputs)
            outputs = self.cnext3(outputs)
            # LayerNorm expects channel-last; permute, normalize, permute back.
            outputs = outputs.permute(0, 2, 3, 1)
            outputs = self.norm(outputs.contiguous())
            outputs = outputs.permute(0, 3, 1, 2)
            outputs = self.final(outputs.contiguous())
            # Unfold the masks back out of the batch dimension.
            outputs = rearrange(outputs, '(B N) C H W -> B (N C) H W', N=N)
            return outputs
        if self.use_checkpoint and self.training:
            # Trade compute for memory with non-reentrant gradient checkpointing.
            outputs = cp.checkpoint(_inner_forward, outputs, use_reentrant=False)
        else:
            outputs = _inner_forward(outputs)
        return outputs


def build_mask_adapter(cfg, name):
    return SEM_SEG_HEADS_REGISTRY.get(name)(cfg)
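
# Example usage (a minimal sketch; assumes a detectron2 CfgNode carrying the
# MODEL.FC_CLIP.CLIP_MODEL_NAME and MODEL.MASK_ADAPTER.* keys read by
# from_config above):
#
#     head = build_mask_adapter(cfg, "MASKAdapterHead")
#     maps = head(clip_feature, masks)  # (B, N * num_output_maps, H, W)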


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py  # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of an NCHW tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each spatial position across channels, then apply the
        # learned per-channel affine parameters.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
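

# A minimal smoke test (an illustrative sketch, not part of the original
# module). It bypasses the detectron2 config path, builds the head directly
# with made-up hyperparameters, and checks output shapes on random tensors
# (assumes ConvNextBlock preserves tensor shape, as standard ConvNeXt blocks
# do). Because of the relative import of ConvNextBlock, run it as a module,
# e.g. `python -m mask_adapter.modeling.meta_arch.mask_adapter_head`.
if __name__ == "__main__":
    head = MASKAdapterHead(
        clip_model_name="convnext_large_d_320",  # hypothetical; '_large' -> clip_dim = 768
        mask_in_chans=16,
        num_channels=768,
        use_checkpoint=False,
        num_output_maps=16,
    )
    clip_feature = torch.randn(2, 768, 24, 24)       # (B, C, H, W)
    masks = (torch.randn(2, 5, 96, 96) > 0).float()  # (B, N, H_m, W_m) binary masks
    out = head(clip_feature, masks)
    assert out.shape == (2, 5 * 16, 24, 24), out.shape  # (B, N * num_output_maps, H, W)

    # Sanity check: LayerNorm2d matches nn.LayerNorm applied channel-last.
    x = torch.randn(2, 8, 4, 4)
    ref = nn.LayerNorm(8, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(LayerNorm2d(8)(x), ref, atol=1e-5)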