Spaces:
Running
on
Zero
Running
on
Zero
import logging | |
from copy import deepcopy | |
from typing import Callable, Dict, List, Optional, Tuple, Union | |
import fvcore.nn.weight_init as weight_init | |
from torch import nn | |
from torch.nn import functional as F | |
import torch | |
from detectron2.config import configurable | |
from detectron2.layers import Conv2d, ShapeSpec, get_norm | |
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY | |
import torch.utils.checkpoint as cp | |
from .convnext import ConvNextBlock | |
from einops import rearrange,repeat | |
@SEM_SEG_HEADS_REGISTRY.register()
class MASKAdapterHead(nn.Module):
    """Mask-adapter head that fuses a CLIP feature map with per-mask embeddings.

    Each of the ``N`` masks of an image is embedded by ``mask_downscaling`` to
    the CLIP channel width at the CLIP feature resolution. The embedding is
    added to a copy of that image's CLIP features. The sum is refined by
    ``fuse`` (1x1 conv), three ``ConvNextBlock``s and a channel LayerNorm, and
    is then projected by ``final`` to ``num_output_maps`` maps per mask.

    The class is registered in ``SEM_SEG_HEADS_REGISTRY``, and ``__init__`` is
    ``@configurable``. ``MASKAdapterHead(cfg)`` (as done by
    ``build_mask_adapter``) therefore builds it from a config via
    ``from_config``. Explicit keyword construction still works.
    """

    @configurable
    def __init__(
        self,
        clip_model_name: str,
        mask_in_chans: int,
        num_channels: int,
        use_checkpoint: bool,
        num_output_maps: int,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            clip_model_name: name of the CLIP backbone. It must contain
                ``'_base'`` (640-channel features) or ``'_large'``
                (768-channel features); ``'_base'`` is checked first.
            mask_in_chans: channel width of the mask-embedding convolutions.
                The first conv uses ``mask_in_chans // 4`` channels.
            num_channels: channel width of the fused ConvNeXt trunk.
            use_checkpoint: if True, the trunk is run under
                ``torch.utils.checkpoint`` while ``self.training`` is True.
            num_output_maps: number of output maps produced per mask.

        Raises:
            ValueError: if ``clip_model_name`` contains neither ``'_base'``
                nor ``'_large'``. Previously this surfaced as an unbound-local
                error on ``clip_dim``.
        """
        super().__init__()
        self.use_checkpoint = use_checkpoint
        if '_base' in clip_model_name:
            clip_dim = 640
        elif '_large' in clip_model_name:
            clip_dim = 768
        else:
            raise ValueError(
                f"Unsupported clip_model_name {clip_model_name!r}: expected a "
                "name containing '_base' or '_large'."
            )
        # Submodules are created in the original order so that state_dict
        # keys and parameter-initialization order are unchanged.
        self.fuse = nn.Conv2d(clip_dim, num_channels, 1)
        # NOTE(review): ConvNextBlock is defined in .convnext (not visible here);
        # assumed to map (B, num_channels, H, W) -> same shape.
        self.cnext1 = ConvNextBlock(num_channels)
        self.cnext2 = ConvNextBlock(num_channels)
        self.cnext3 = ConvNextBlock(num_channels)
        self.norm = nn.LayerNorm(num_channels)
        self.final = nn.Conv2d(num_channels, num_output_maps, 1)
        # Two stride-2 convs shrink the mask by 4x and lift it to clip_dim
        # channels, so that it can be added to the CLIP feature map.
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=3, stride=2, padding=1),
            LayerNorm2d(mask_in_chans // 4),
            nn.GELU(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=3, stride=2, padding=1),
            LayerNorm2d(mask_in_chans),
            nn.GELU(),
            nn.Conv2d(mask_in_chans, clip_dim, kernel_size=1),
        )

    @classmethod
    def from_config(cls, cfg):
        """Map a detectron2 config node to ``__init__`` keyword arguments."""
        return {
            "clip_model_name": cfg.MODEL.FC_CLIP.CLIP_MODEL_NAME,
            "mask_in_chans": cfg.MODEL.MASK_ADAPTER.MASK_IN_CHANNELS,
            "num_channels": cfg.MODEL.MASK_ADAPTER.NUM_CHANNELS,
            "use_checkpoint": cfg.MODEL.MASK_ADAPTER.USE_CHECKPOINT,
            "num_output_maps": cfg.MODEL.MASK_ADAPTER.NUM_OUTPUT_MAPS,
        }

    def forward(self, clip_feature, masks):
        """Predict ``num_output_maps`` maps for every mask.

        Args:
            clip_feature: 4-D CLIP feature map ``(B, C, H, W)``. ``C`` must
                equal the CLIP width chosen in ``__init__`` (640 or 768).
            masks: 4-D masks ``(B, N, h, w)``. They are resized to
                ``(4H, 4W)`` before embedding, so ``(h, w)`` may differ from
                ``(H, W)``.

        Returns:
            Tensor ``(B, N * num_output_maps, H', W')``. The maps of mask ``i``
            occupy channels ``[i * num_output_maps, (i + 1) * num_output_maps)``.
            ``(H', W')`` equals ``(H, W)`` provided the ConvNeXt blocks keep the
            spatial size (assumed; ConvNextBlock is not visible here).
        """
        N = masks.size(1)
        # One single-channel image per (image, mask) pair, ordered (B outer, N inner).
        masks = rearrange(masks, 'B N H W -> (B N) H W').unsqueeze(dim=1)
        # Tile the CLIP features with the same (B N) ordering as the masks.
        clip_feature = repeat(clip_feature, "B C H W -> (B N) C H W", N=N)
        H, W = clip_feature.shape[-2:]
        # Upsample to 4x the CLIP resolution. The two stride-2 convs in
        # mask_downscaling bring it back to exactly (H, W).
        masks = F.interpolate(masks, size=(H * 4, W * 4),
                              mode='bilinear', align_corners=False)
        masks = self.mask_downscaling(masks)
        outputs = clip_feature + masks

        def _inner_forward(outputs):
            outputs = self.fuse(outputs)
            outputs = self.cnext1(outputs)
            outputs = self.cnext2(outputs)
            outputs = self.cnext3(outputs)
            # nn.LayerNorm normalizes the last dim: move channels last, then back.
            outputs = outputs.permute(0, 2, 3, 1)
            outputs = self.norm(outputs.contiguous())
            outputs = outputs.permute(0, 3, 1, 2)
            outputs = self.final(outputs.contiguous())
            # Regroup per image, concatenating each mask's maps along channels.
            outputs = rearrange(outputs, '(B N) C H W -> B (N C) H W', N=N)
            return outputs

        if self.use_checkpoint and self.training:
            # Recompute the trunk in the backward pass instead of storing its
            # activations.
            outputs = cp.checkpoint(_inner_forward, outputs, use_reentrant=False)
        else:
            outputs = _inner_forward(outputs)
        return outputs
def build_mask_adapter(cfg, name):
    """Build the semantic-segmentation head registered under ``name``.

    Args:
        cfg: config object passed as the sole argument to the registered class.
        name: key looked up in ``SEM_SEG_HEADS_REGISTRY``.

    Returns:
        Whatever calling the registered class with ``cfg`` produces.
    """
    head_cls = SEM_SEG_HEADS_REGISTRY.get(name)
    head = head_cls(cfg)
    return head
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa | |
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa | |
class LayerNorm2d(nn.Module):
    """Channel-wise layer normalization for ``(N, C, H, W)`` tensors.

    Mean and biased variance are taken over dim 1. Each spatial location is
    therefore normalized across its channels. A learnable per-channel scale
    (``weight``, initialized to 1) and shift (``bias``, initialized to 0) are
    then applied, broadcast over the last two dims.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Registration order (weight, then bias) is kept for state_dict parity.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # Reshape (C,) -> (C, 1, 1) so the affine broadcasts over H and W.
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return scale * normalized + shift