# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union
from einops import rearrange
import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from ..transformer.cat_seg_predictor import CATSegPredictor


@SEM_SEG_HEADS_REGISTRY.register()
class CATSegHead(nn.Module):
@configurable
def __init__(
self,
input_shape: Dict[str, ShapeSpec],
*,
num_classes: int,
ignore_value: int = -1,
# extra parameters
feature_resolution: list,
transformer_predictor: nn.Module,
):
"""
NOTE: this interface is experimental.
Args:
input_shape: shapes (channels and stride) of the input features
num_classes: number of classes to predict
pixel_decoder: the pixel decoder module
loss_weight: loss weight
ignore_value: category id to be ignored during training.
transformer_predictor: the transformer decoder that makes prediction
transformer_in_feature: input feature name to the transformer_predictor
"""
super().__init__()
input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
self.in_features = [k for k, v in input_shape]
self.ignore_value = ignore_value
self.predictor = transformer_predictor
self.num_classes = num_classes
self.feature_resolution = feature_resolution
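
    # Example (hypothetical feature names and shapes, for illustration only): given
    #   input_shape = {"res4": ShapeSpec(channels=1024, stride=16),
    #                  "res2": ShapeSpec(channels=256, stride=4)}
    # the stride-based sort in __init__ above yields self.in_features == ["res2", "res4"].
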
@classmethod
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
return {
"input_shape": {
k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
},
"ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
"num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
"feature_resolution": cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION,
"transformer_predictor": CATSegPredictor(
cfg,
),
        }

    def forward(self, features, guidance_features):
        """
        Arguments:
            features: flattened image feature tokens of shape (B, 1 + HW, C);
                the leading (class) token is dropped before reshaping.
            guidance_features: guidance features passed through to the predictor.
        """
        img_feat = rearrange(
            features[:, 1:, :], "b (h w) c -> b c h w",
            h=self.feature_resolution[0], w=self.feature_resolution[1],
        )
return self.predictor(img_feat, guidance_features)
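

# A minimal, self-contained sketch (not part of the module) that only exercises the
# token-to-feature-map reshape performed in ``forward`` above: the leading class
# token is dropped and the remaining (B, HW, C) tokens are rearranged into a
# (B, C, H, W) map. The sizes below (H = W = 24, C = 512) are hypothetical; the
# predictor itself is config-dependent and is not constructed here.
if __name__ == "__main__":
    import torch

    B, H, W, C = 2, 24, 24, 512             # hypothetical batch/grid/channel sizes
    tokens = torch.randn(B, 1 + H * W, C)    # class token followed by H*W patch tokens
    img_feat = rearrange(tokens[:, 1:, :], "b (h w) c -> b c h w", h=H, w=W)
    assert img_feat.shape == (B, C, H, W)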