# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Callable, Dict, List, Union
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.layers import Conv2d, DepthwiseSeparableConv2d, ShapeSpec, get_norm
from detectron2.modeling import (
META_ARCH_REGISTRY,
SEM_SEG_HEADS_REGISTRY,
build_backbone,
build_sem_seg_head,
)
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.projects.deeplab import DeepLabV3PlusHead
from detectron2.projects.deeplab.loss import DeepLabCE
from detectron2.structures import BitMasks, ImageList, Instances
from detectron2.utils.registry import Registry
from .post_processing import get_panoptic_segmentation

__all__ = ["PanopticDeepLab", "INS_EMBED_BRANCHES_REGISTRY", "build_ins_embed_branch"]


INS_EMBED_BRANCHES_REGISTRY = Registry("INS_EMBED_BRANCHES")
INS_EMBED_BRANCHES_REGISTRY.__doc__ = """
Registry for instance embedding branches, which make instance embedding
predictions from feature maps.
"""


@META_ARCH_REGISTRY.register()
class PanopticDeepLab(nn.Module):
"""
Main class for panoptic segmentation architectures.
"""
def __init__(self, cfg):
super().__init__()
self.backbone = build_backbone(cfg)
self.sem_seg_head = build_sem_seg_head(cfg, self.backbone.output_shape())
self.ins_embed_head = build_ins_embed_branch(cfg, self.backbone.output_shape())
self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False)
self.meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
self.stuff_area = cfg.MODEL.PANOPTIC_DEEPLAB.STUFF_AREA
self.threshold = cfg.MODEL.PANOPTIC_DEEPLAB.CENTER_THRESHOLD
self.nms_kernel = cfg.MODEL.PANOPTIC_DEEPLAB.NMS_KERNEL
self.top_k = cfg.MODEL.PANOPTIC_DEEPLAB.TOP_K_INSTANCE
self.predict_instances = cfg.MODEL.PANOPTIC_DEEPLAB.PREDICT_INSTANCES
self.use_depthwise_separable_conv = cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
assert (
cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV
== cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
)
self.size_divisibility = cfg.MODEL.PANOPTIC_DEEPLAB.SIZE_DIVISIBILITY
self.benchmark_network_speed = cfg.MODEL.PANOPTIC_DEEPLAB.BENCHMARK_NETWORK_SPEED
@property
def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
"""
Args:
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
Each item in the list contains the inputs for one image.
For now, each item in the list is a dict that contains:
* "image": Tensor, image in (C, H, W) format.
* "sem_seg": semantic segmentation ground truth
* "center": center points heatmap ground truth
* "offset": pixel offsets to center points ground truth
                * Other information that's included in the original dicts, such as
                  "height", "width" (int): the output resolution of the model (may be different
                  from input resolution), used in inference.
Returns:
list[dict]:
each dict is the results for one image. The dict contains the following keys:
* "panoptic_seg", "sem_seg": see documentation
:doc:`/tutorials/models` for the standard output format
* "instances": available if ``predict_instances is True``. see documentation
:doc:`/tutorials/models` for the standard output format
"""
images = [x["image"].to(self.device) for x in batched_inputs]
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
# To avoid error in ASPP layer when input has different size.
size_divisibility = (
self.size_divisibility
if self.size_divisibility > 0
else self.backbone.size_divisibility
)
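        # ImageList.from_tensors zero-pads each image so that H and W become multiples
        # of size_divisibility; e.g. (illustrative numbers), with size_divisibility=32 a
        # 1025x2049 image is padded to 1056x2080.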
images = ImageList.from_tensors(images, size_divisibility)
features = self.backbone(images.tensor)
losses = {}
if "sem_seg" in batched_inputs[0]:
targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
targets = ImageList.from_tensors(
targets, size_divisibility, self.sem_seg_head.ignore_value
).tensor
if "sem_seg_weights" in batched_inputs[0]:
# The default D2 DatasetMapper may not contain "sem_seg_weights"
# Avoid error in testing when default DatasetMapper is used.
weights = [x["sem_seg_weights"].to(self.device) for x in batched_inputs]
weights = ImageList.from_tensors(weights, size_divisibility).tensor
else:
weights = None
else:
targets = None
weights = None
sem_seg_results, sem_seg_losses = self.sem_seg_head(features, targets, weights)
losses.update(sem_seg_losses)
if "center" in batched_inputs[0] and "offset" in batched_inputs[0]:
center_targets = [x["center"].to(self.device) for x in batched_inputs]
center_targets = ImageList.from_tensors(
center_targets, size_divisibility
).tensor.unsqueeze(1)
center_weights = [x["center_weights"].to(self.device) for x in batched_inputs]
center_weights = ImageList.from_tensors(center_weights, size_divisibility).tensor
offset_targets = [x["offset"].to(self.device) for x in batched_inputs]
offset_targets = ImageList.from_tensors(offset_targets, size_divisibility).tensor
offset_weights = [x["offset_weights"].to(self.device) for x in batched_inputs]
offset_weights = ImageList.from_tensors(offset_weights, size_divisibility).tensor
else:
center_targets = None
center_weights = None
offset_targets = None
offset_weights = None
center_results, offset_results, center_losses, offset_losses = self.ins_embed_head(
features, center_targets, center_weights, offset_targets, offset_weights
)
losses.update(center_losses)
losses.update(offset_losses)
if self.training:
return losses
if self.benchmark_network_speed:
return []
processed_results = []
for sem_seg_result, center_result, offset_result, input_per_image, image_size in zip(
sem_seg_results, center_results, offset_results, batched_inputs, images.image_sizes
):
height = input_per_image.get("height")
width = input_per_image.get("width")
r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
c = sem_seg_postprocess(center_result, image_size, height, width)
o = sem_seg_postprocess(offset_result, image_size, height, width)
# Post-processing to get panoptic segmentation.
panoptic_image, _ = get_panoptic_segmentation(
r.argmax(dim=0, keepdim=True),
c,
o,
thing_ids=self.meta.thing_dataset_id_to_contiguous_id.values(),
label_divisor=self.meta.label_divisor,
stuff_area=self.stuff_area,
void_label=-1,
threshold=self.threshold,
nms_kernel=self.nms_kernel,
top_k=self.top_k,
)
# For semantic segmentation evaluation.
processed_results.append({"sem_seg": r})
panoptic_image = panoptic_image.squeeze(0)
semantic_prob = F.softmax(r, dim=0)
# For panoptic segmentation evaluation.
processed_results[-1]["panoptic_seg"] = (panoptic_image, None)
# For instance segmentation evaluation.
if self.predict_instances:
instances = []
panoptic_image_cpu = panoptic_image.cpu().numpy()
for panoptic_label in np.unique(panoptic_image_cpu):
if panoptic_label == -1:
continue
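                    # Panoptic ids encode (class, instance) as
                    # class_id * label_divisor + instance_id; e.g., with label_divisor=1000
                    # (illustrative), id 26003 decodes to class 26, instance 3.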
pred_class = panoptic_label // self.meta.label_divisor
isthing = pred_class in list(
self.meta.thing_dataset_id_to_contiguous_id.values()
)
# Get instance segmentation results.
if isthing:
instance = Instances((height, width))
                        # Evaluation code takes contiguous ids starting from 0
instance.pred_classes = torch.tensor(
[pred_class], device=panoptic_image.device
)
mask = panoptic_image == panoptic_label
instance.pred_masks = mask.unsqueeze(0)
# Average semantic probability
sem_scores = semantic_prob[pred_class, ...]
sem_scores = torch.mean(sem_scores[mask])
# Center point probability
mask_indices = torch.nonzero(mask).float()
center_y, center_x = (
torch.mean(mask_indices[:, 0]),
torch.mean(mask_indices[:, 1]),
)
center_scores = c[0, int(center_y.item()), int(center_x.item())]
# Confidence score is semantic prob * center prob.
instance.scores = torch.tensor(
[sem_scores * center_scores], device=panoptic_image.device
)
# Get bounding boxes
instance.pred_boxes = BitMasks(instance.pred_masks).get_bounding_boxes()
instances.append(instance)
if len(instances) > 0:
processed_results[-1]["instances"] = Instances.cat(instances)
return processed_results


@SEM_SEG_HEADS_REGISTRY.register()
class PanopticDeepLabSemSegHead(DeepLabV3PlusHead):
"""
A semantic segmentation head described in :paper:`Panoptic-DeepLab`.
"""
@configurable
def __init__(
self,
input_shape: Dict[str, ShapeSpec],
*,
decoder_channels: List[int],
norm: Union[str, Callable],
head_channels: int,
loss_weight: float,
loss_type: str,
loss_top_k: float,
ignore_value: int,
num_classes: int,
**kwargs,
):
"""
NOTE: this interface is experimental.
Args:
            input_shape (dict[str, ShapeSpec]): shapes of the input features, keyed by name.
decoder_channels (list[int]): a list of output channels of each
decoder stage. It should have the same length as "input_shape"
(each element in "input_shape" corresponds to one decoder stage).
norm (str or callable): normalization for all conv layers.
head_channels (int): the output channels of extra convolutions
between decoder and predictor.
loss_weight (float): loss weight.
            loss_top_k (float): fraction in (0, 1] of the hardest pixels that is kept
                when computing the "hard_pixel_mining" loss.
loss_type, ignore_value, num_classes: the same as the base class.
"""
super().__init__(
input_shape,
decoder_channels=decoder_channels,
norm=norm,
ignore_value=ignore_value,
**kwargs,
)
assert self.decoder_only
self.loss_weight = loss_weight
use_bias = norm == ""
# `head` is additional transform before predictor
if self.use_depthwise_separable_conv:
# We use a single 5x5 DepthwiseSeparableConv2d to replace
# 2 3x3 Conv2d since they have the same receptive field.
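            # (Two stacked 3x3 convs also cover a 5x5 window: 3 + (3 - 1) = 5.)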
self.head = DepthwiseSeparableConv2d(
decoder_channels[0],
head_channels,
kernel_size=5,
padding=2,
norm1=norm,
activation1=F.relu,
norm2=norm,
activation2=F.relu,
)
else:
self.head = nn.Sequential(
Conv2d(
decoder_channels[0],
decoder_channels[0],
kernel_size=3,
padding=1,
bias=use_bias,
norm=get_norm(norm, decoder_channels[0]),
activation=F.relu,
),
Conv2d(
decoder_channels[0],
head_channels,
kernel_size=3,
padding=1,
bias=use_bias,
norm=get_norm(norm, head_channels),
activation=F.relu,
),
)
weight_init.c2_xavier_fill(self.head[0])
weight_init.c2_xavier_fill(self.head[1])
self.predictor = Conv2d(head_channels, num_classes, kernel_size=1)
nn.init.normal_(self.predictor.weight, 0, 0.001)
nn.init.constant_(self.predictor.bias, 0)
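        # Semantic loss choice: plain per-pixel cross-entropy, or bootstrapped
        # cross-entropy ("hard_pixel_mining"), where DeepLabCE averages the loss over
        # only the hardest loss_top_k fraction of pixels (e.g., the worst 20% of pixels
        # when loss_top_k=0.2; illustrative value).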
if loss_type == "cross_entropy":
self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=ignore_value)
elif loss_type == "hard_pixel_mining":
self.loss = DeepLabCE(ignore_label=ignore_value, top_k_percent_pixels=loss_top_k)
else:
raise ValueError("Unexpected loss type: %s" % loss_type)
@classmethod
def from_config(cls, cfg, input_shape):
ret = super().from_config(cfg, input_shape)
ret["head_channels"] = cfg.MODEL.SEM_SEG_HEAD.HEAD_CHANNELS
ret["loss_top_k"] = cfg.MODEL.SEM_SEG_HEAD.LOSS_TOP_K
return ret
def forward(self, features, targets=None, weights=None):
"""
Returns:
In training, returns (None, dict of losses)
In inference, returns (CxHxW logits, {})
"""
y = self.layers(features)
if self.training:
return None, self.losses(y, targets, weights)
else:
y = F.interpolate(
y, scale_factor=self.common_stride, mode="bilinear", align_corners=False
)
return y, {}
def layers(self, features):
assert self.decoder_only
y = super().layers(features)
y = self.head(y)
y = self.predictor(y)
return y
def losses(self, predictions, targets, weights=None):
predictions = F.interpolate(
predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
)
loss = self.loss(predictions, targets, weights)
losses = {"loss_sem_seg": loss * self.loss_weight}
return losses


def build_ins_embed_branch(cfg, input_shape):
"""
    Build an instance embedding branch from `cfg.MODEL.INS_EMBED_HEAD.NAME`.
"""
name = cfg.MODEL.INS_EMBED_HEAD.NAME
return INS_EMBED_BRANCHES_REGISTRY.get(name)(cfg, input_shape)


@INS_EMBED_BRANCHES_REGISTRY.register()
class PanopticDeepLabInsEmbedHead(DeepLabV3PlusHead):
"""
    An instance embedding head described in :paper:`Panoptic-DeepLab`.
"""
@configurable
def __init__(
self,
input_shape: Dict[str, ShapeSpec],
*,
decoder_channels: List[int],
norm: Union[str, Callable],
head_channels: int,
center_loss_weight: float,
offset_loss_weight: float,
**kwargs,
):
"""
NOTE: this interface is experimental.
Args:
            input_shape (dict[str, ShapeSpec]): shapes of the input features, keyed by name.
decoder_channels (list[int]): a list of output channels of each
decoder stage. It should have the same length as "input_shape"
(each element in "input_shape" corresponds to one decoder stage).
norm (str or callable): normalization for all conv layers.
head_channels (int): the output channels of extra convolutions
between decoder and predictor.
center_loss_weight (float): loss weight for center point prediction.
offset_loss_weight (float): loss weight for center offset prediction.
"""
super().__init__(input_shape, decoder_channels=decoder_channels, norm=norm, **kwargs)
assert self.decoder_only
self.center_loss_weight = center_loss_weight
self.offset_loss_weight = offset_loss_weight
use_bias = norm == ""
# center prediction
# `head` is additional transform before predictor
self.center_head = nn.Sequential(
Conv2d(
decoder_channels[0],
decoder_channels[0],
kernel_size=3,
padding=1,
bias=use_bias,
norm=get_norm(norm, decoder_channels[0]),
activation=F.relu,
),
Conv2d(
decoder_channels[0],
head_channels,
kernel_size=3,
padding=1,
bias=use_bias,
norm=get_norm(norm, head_channels),
activation=F.relu,
),
)
weight_init.c2_xavier_fill(self.center_head[0])
weight_init.c2_xavier_fill(self.center_head[1])
self.center_predictor = Conv2d(head_channels, 1, kernel_size=1)
nn.init.normal_(self.center_predictor.weight, 0, 0.001)
nn.init.constant_(self.center_predictor.bias, 0)
# offset prediction
# `head` is additional transform before predictor
if self.use_depthwise_separable_conv:
# We use a single 5x5 DepthwiseSeparableConv2d to replace
# 2 3x3 Conv2d since they have the same receptive field.
self.offset_head = DepthwiseSeparableConv2d(
decoder_channels[0],
head_channels,
kernel_size=5,
padding=2,
norm1=norm,
activation1=F.relu,
norm2=norm,
activation2=F.relu,
)
else:
self.offset_head = nn.Sequential(
Conv2d(
decoder_channels[0],
decoder_channels[0],
kernel_size=3,
padding=1,
bias=use_bias,
norm=get_norm(norm, decoder_channels[0]),
activation=F.relu,
),
Conv2d(
decoder_channels[0],
head_channels,
kernel_size=3,
padding=1,
bias=use_bias,
norm=get_norm(norm, head_channels),
activation=F.relu,
),
)
weight_init.c2_xavier_fill(self.offset_head[0])
weight_init.c2_xavier_fill(self.offset_head[1])
self.offset_predictor = Conv2d(head_channels, 2, kernel_size=1)
nn.init.normal_(self.offset_predictor.weight, 0, 0.001)
nn.init.constant_(self.offset_predictor.bias, 0)
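        # Per-pixel regression losses as in Panoptic-DeepLab: MSE against the
        # (typically Gaussian-encoded) ground-truth center heatmap, L1 against the
        # ground-truth offsets. reduction="none" keeps per-pixel values so that
        # center_losses/offset_losses below can apply the per-pixel weight masks.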
self.center_loss = nn.MSELoss(reduction="none")
self.offset_loss = nn.L1Loss(reduction="none")
@classmethod
def from_config(cls, cfg, input_shape):
if cfg.INPUT.CROP.ENABLED:
assert cfg.INPUT.CROP.TYPE == "absolute"
train_size = cfg.INPUT.CROP.SIZE
else:
train_size = None
decoder_channels = [cfg.MODEL.INS_EMBED_HEAD.CONVS_DIM] * (
len(cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES) - 1
) + [cfg.MODEL.INS_EMBED_HEAD.ASPP_CHANNELS]
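        # For example (illustrative values): with 3 input features, CONVS_DIM=128 and
        # ASPP_CHANNELS=256, decoder_channels becomes [128, 128, 256].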
ret = dict(
input_shape={
k: v for k, v in input_shape.items() if k in cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES
},
project_channels=cfg.MODEL.INS_EMBED_HEAD.PROJECT_CHANNELS,
aspp_dilations=cfg.MODEL.INS_EMBED_HEAD.ASPP_DILATIONS,
aspp_dropout=cfg.MODEL.INS_EMBED_HEAD.ASPP_DROPOUT,
decoder_channels=decoder_channels,
common_stride=cfg.MODEL.INS_EMBED_HEAD.COMMON_STRIDE,
norm=cfg.MODEL.INS_EMBED_HEAD.NORM,
train_size=train_size,
head_channels=cfg.MODEL.INS_EMBED_HEAD.HEAD_CHANNELS,
center_loss_weight=cfg.MODEL.INS_EMBED_HEAD.CENTER_LOSS_WEIGHT,
offset_loss_weight=cfg.MODEL.INS_EMBED_HEAD.OFFSET_LOSS_WEIGHT,
use_depthwise_separable_conv=cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV,
)
return ret
def forward(
self,
features,
center_targets=None,
center_weights=None,
offset_targets=None,
offset_weights=None,
):
"""
Returns:
            In training, returns (None, None, dict of center losses, dict of offset losses)
            In inference, returns (1xHxW center heatmap, 2xHxW offsets, {}, {})
"""
center, offset = self.layers(features)
if self.training:
return (
None,
None,
self.center_losses(center, center_targets, center_weights),
self.offset_losses(offset, offset_targets, offset_weights),
)
else:
center = F.interpolate(
center, scale_factor=self.common_stride, mode="bilinear", align_corners=False
)
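            # Offsets are predicted in units of the downsampled grid; after bilinear
            # upsampling they are multiplied by common_stride to convert them to
            # full-resolution pixels (e.g., with common_stride=4, a raw offset of 2.5
            # corresponds to 10 pixels).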
offset = (
F.interpolate(
offset, scale_factor=self.common_stride, mode="bilinear", align_corners=False
)
* self.common_stride
)
return center, offset, {}, {}
def layers(self, features):
assert self.decoder_only
y = super().layers(features)
# center
center = self.center_head(y)
center = self.center_predictor(center)
# offset
offset = self.offset_head(y)
offset = self.offset_predictor(offset)
return center, offset
def center_losses(self, predictions, targets, weights):
predictions = F.interpolate(
predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
)
loss = self.center_loss(predictions, targets) * weights
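        # Average only over pixels with nonzero weight; if nothing is weighted,
        # fall back to a zero loss that still keeps the computation graph connected.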
if weights.sum() > 0:
loss = loss.sum() / weights.sum()
else:
loss = loss.sum() * 0
losses = {"loss_center": loss * self.center_loss_weight}
return losses
def offset_losses(self, predictions, targets, weights):
predictions = (
F.interpolate(
predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
)
* self.common_stride
)
loss = self.offset_loss(predictions, targets) * weights
if weights.sum() > 0:
loss = loss.sum() / weights.sum()
else:
loss = loss.sum() * 0
losses = {"loss_offset": loss * self.offset_loss_weight}
return losses
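

# ---------------------------------------------------------------------------
# Usage sketch (commented out; illustrative only). It assumes detectron2 is
# installed with its projects so that ``add_panoptic_deeplab_config`` and a
# Panoptic-DeepLab yaml config are available; names and paths below are
# placeholders, adjust them to your setup.
#
#   import torch
#   from detectron2.config import get_cfg
#   from detectron2.modeling import build_model
#   from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config
#
#   cfg = get_cfg()
#   add_panoptic_deeplab_config(cfg)
#   cfg.merge_from_file("path/to/panoptic_deeplab_config.yaml")  # hypothetical path
#   model = build_model(cfg)  # instantiates PanopticDeepLab via META_ARCH_REGISTRY
#   model.eval()
#
#   # One dict per image, in the format documented in PanopticDeepLab.forward().
#   inputs = [{"image": torch.randint(0, 256, (3, 512, 1024), dtype=torch.uint8),
#              "height": 512, "width": 1024}]
#   with torch.no_grad():
#       outputs = model(inputs)
#   panoptic_seg, _ = outputs[0]["panoptic_seg"]
# ---------------------------------------------------------------------------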