Spaces:
Runtime error
Runtime error
File size: 3,776 Bytes
3dac99f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from typing import Dict
from torch import nn
from detectron2.config import configurable
from detectron2.layers import ShapeSpec
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from ..transformer_decoder.frozenseg_transformer_decoder import build_transformer_decoder
from ..pixel_decoder.msdeformattn import build_pixel_decoder
@SEM_SEG_HEADS_REGISTRY.register()
class FrozenSegHead(nn.Module):
    """Segmentation head: a pixel decoder followed by a transformer predictor.

    Registered in ``SEM_SEG_HEADS_REGISTRY`` so it can be selected by name from
    a detectron2 config. ``forward`` runs the pixel decoder's
    ``forward_features`` (which returns mask features, multi-scale features and
    a third output named ``sam_fpn``) and passes those to the transformer
    predictor. The predictor also receives the text-classifier and SAM entries
    carried in the ``features`` dict.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
        """
        super().__init__()
        # Keep feature names ordered from smallest to largest stride.
        ordered_shapes = sorted(input_shape.items(), key=lambda item: item[1].stride)
        self.in_features = [name for name, _ in ordered_shapes]

        self.ignore_value = ignore_value
        # Hard-coded output stride; not read anywhere in this class.
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature

        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Translate a detectron2 config into ``__init__`` keyword arguments.

        Args:
            cfg: detectron2 config node.
            input_shape: backbone output shapes keyed by feature name.

        Returns:
            dict of keyword arguments for ``__init__``.

        Raises:
            NotImplementedError: if ``MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE``
                is anything other than ``"multi_scale_pixel_decoder"``.
        """
        transformer_in_feature = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE
        # Only the multi-scale pixel-decoder output is supported; the
        # predictor's input channel count then comes from SEM_SEG_HEAD.CONVS_DIM.
        if transformer_in_feature != "multi_scale_pixel_decoder":
            raise NotImplementedError(
                f"Unsupported MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE: "
                f"{transformer_in_feature!r}"
            )
        transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        head_in_features = cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES

        # Entry order is kept as-is: the pixel decoder is built before the
        # transformer predictor.
        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in head_in_features
            },
            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            # NOTE(review): the pixel decoder gets the *unfiltered* input_shape;
            # presumably it applies its own feature filtering — confirm.
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
            "transformer_in_feature": transformer_in_feature,
            "transformer_predictor": build_transformer_decoder(
                cfg,
                transformer_predictor_in_channels,
                mask_classification=True,
            ),
        }

    def forward(self, features, mask=None):
        """Run the head; delegates entirely to :meth:`layers`."""
        return self.layers(features, mask)

    def layers(self, features, mask=None):
        """Run the pixel decoder, then the transformer predictor.

        Args:
            features: dict passed to ``pixel_decoder.forward_features``. It must
                also contain the keys ``"text_classifier"``,
                ``"num_templates"``, ``"sam_embedding"`` and ``"sam"``, whose
                values are forwarded to the predictor as keyword arguments.
            mask: optional value forwarded unchanged to the predictor.

        Returns:
            Whatever the transformer predictor returns.

        Raises:
            NotImplementedError: if ``self.transformer_in_feature`` is not
                ``"multi_scale_pixel_decoder"``.
        """
        mask_features, multi_scale_features, sam_fpn = self.pixel_decoder.forward_features(
            features
        )
        if self.transformer_in_feature != "multi_scale_pixel_decoder":
            raise NotImplementedError(
                f"Unsupported transformer_in_feature: {self.transformer_in_feature!r}"
            )
        return self.predictor(
            multi_scale_features,
            mask_features,
            mask,
            text_classifier=features["text_classifier"],
            num_templates=features["num_templates"],
            sam_embedding=features["sam_embedding"],
            sam=features["sam"],
            sam_fpn=sam_fpn,
        )
|