Spaces:
Running
on
Zero
Running
on
Zero
File size: 23,513 Bytes
61c2d32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 |
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Callable, Dict, List, Union
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.layers import Conv2d, DepthwiseSeparableConv2d, ShapeSpec, get_norm
from detectron2.modeling import (
META_ARCH_REGISTRY,
SEM_SEG_HEADS_REGISTRY,
build_backbone,
build_sem_seg_head,
)
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.projects.deeplab import DeepLabV3PlusHead
from detectron2.projects.deeplab.loss import DeepLabCE
from detectron2.structures import BitMasks, ImageList, Instances
from detectron2.utils.registry import Registry
from .post_processing import get_panoptic_segmentation
# Names exported by `from <this module> import *`.
__all__ = ["PanopticDeepLab", "INS_EMBED_BRANCHES_REGISTRY", "build_ins_embed_branch"]
# Registry of instance embedding branch classes, looked up by the string in
# cfg.MODEL.INS_EMBED_HEAD.NAME (see build_ins_embed_branch).
INS_EMBED_BRANCHES_REGISTRY = Registry("INS_EMBED_BRANCHES")
INS_EMBED_BRANCHES_REGISTRY.__doc__ = """
Registry for instance embedding branches, which make instance embedding
predictions from feature maps.
"""
@META_ARCH_REGISTRY.register()
class PanopticDeepLab(nn.Module):
    """
    Main class for panoptic segmentation architectures.

    Wires a backbone to a semantic segmentation head and an instance embedding
    head (center heatmap + center offsets). In training mode :meth:`forward`
    returns a dict of losses; in inference mode it converts the head outputs
    into per-image semantic, panoptic and (optionally) instance predictions.
    """

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        self.sem_seg_head = build_sem_seg_head(cfg, self.backbone.output_shape())
        self.ins_embed_head = build_ins_embed_branch(cfg, self.backbone.output_shape())
        # Third argument False => non-persistent buffers: they follow the module's
        # device but are not written to the state dict.
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False)
        # Dataset metadata (thing ids, label_divisor) comes from the first train set.
        self.meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
        self.stuff_area = cfg.MODEL.PANOPTIC_DEEPLAB.STUFF_AREA
        self.threshold = cfg.MODEL.PANOPTIC_DEEPLAB.CENTER_THRESHOLD
        self.nms_kernel = cfg.MODEL.PANOPTIC_DEEPLAB.NMS_KERNEL
        self.top_k = cfg.MODEL.PANOPTIC_DEEPLAB.TOP_K_INSTANCE
        self.predict_instances = cfg.MODEL.PANOPTIC_DEEPLAB.PREDICT_INSTANCES
        self.use_depthwise_separable_conv = cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
        # The semantic head's conv type must match the meta-architecture setting.
        assert (
            cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV
            == cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
        )
        self.size_divisibility = cfg.MODEL.PANOPTIC_DEEPLAB.SIZE_DIVISIBILITY
        self.benchmark_network_speed = cfg.MODEL.PANOPTIC_DEEPLAB.BENCHMARK_NETWORK_SPEED

    @property
    def device(self):
        """Device of the model, taken from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "sem_seg": semantic segmentation ground truth
                   * "center": center points heatmap ground truth
                   * "offset": pixel offsets to center points ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference. When absent, the
                     (unpadded) input image size is used.
        Returns:
            In training mode: dict of losses.
            In inference mode with ``benchmark_network_speed``: an empty list.
            Otherwise list[dict]:
                each dict is the results for one image. The dict contains the following keys:
                * "panoptic_seg", "sem_seg": see documentation
                    :doc:`/tutorials/models` for the standard output format
                * "instances": available if ``predict_instances is True`` and at least
                    one thing segment was found. see documentation
                    :doc:`/tutorials/models` for the standard output format
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        # To avoid error in ASPP layer when input has different size.
        size_divisibility = (
            self.size_divisibility
            if self.size_divisibility > 0
            else self.backbone.size_divisibility
        )
        images = ImageList.from_tensors(images, size_divisibility)
        features = self.backbone(images.tensor)

        losses = {}
        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            # Pad label maps with ignore_value so padding does not contribute to the loss.
            targets = ImageList.from_tensors(
                targets, size_divisibility, self.sem_seg_head.ignore_value
            ).tensor
            if "sem_seg_weights" in batched_inputs[0]:
                # The default D2 DatasetMapper may not contain "sem_seg_weights"
                # Avoid error in testing when default DatasetMapper is used.
                weights = [x["sem_seg_weights"].to(self.device) for x in batched_inputs]
                weights = ImageList.from_tensors(weights, size_divisibility).tensor
            else:
                weights = None
        else:
            targets = None
            weights = None
        sem_seg_results, sem_seg_losses = self.sem_seg_head(features, targets, weights)
        losses.update(sem_seg_losses)

        if "center" in batched_inputs[0] and "offset" in batched_inputs[0]:
            center_targets = [x["center"].to(self.device) for x in batched_inputs]
            center_targets = ImageList.from_tensors(
                center_targets, size_divisibility
            ).tensor.unsqueeze(1)
            center_weights = [x["center_weights"].to(self.device) for x in batched_inputs]
            center_weights = ImageList.from_tensors(center_weights, size_divisibility).tensor
            offset_targets = [x["offset"].to(self.device) for x in batched_inputs]
            offset_targets = ImageList.from_tensors(offset_targets, size_divisibility).tensor
            offset_weights = [x["offset_weights"].to(self.device) for x in batched_inputs]
            offset_weights = ImageList.from_tensors(offset_weights, size_divisibility).tensor
        else:
            center_targets = None
            center_weights = None
            offset_targets = None
            offset_weights = None
        center_results, offset_results, center_losses, offset_losses = self.ins_embed_head(
            features, center_targets, center_weights, offset_targets, offset_weights
        )
        losses.update(center_losses)
        losses.update(offset_losses)

        if self.training:
            return losses
        if self.benchmark_network_speed:
            return []

        processed_results = []
        for sem_seg_result, center_result, offset_result, input_per_image, image_size in zip(
            sem_seg_results, center_results, offset_results, batched_inputs, images.image_sizes
        ):
            # sem_seg_postprocess resizes to (height, width), so both must be ints;
            # fall back to the unpadded input size when the caller did not set them.
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
            c = sem_seg_postprocess(center_result, image_size, height, width)
            # NOTE(review): the instance head scales offsets by common_stride so they
            # are in input-resolution pixels, but they are resized here without
            # rescaling their values; if (height, width) != image_size this may skew
            # center grouping — confirm.
            o = sem_seg_postprocess(offset_result, image_size, height, width)
            # Post-processing to get panoptic segmentation.
            panoptic_image, _ = get_panoptic_segmentation(
                r.argmax(dim=0, keepdim=True),
                c,
                o,
                thing_ids=self.meta.thing_dataset_id_to_contiguous_id.values(),
                label_divisor=self.meta.label_divisor,
                stuff_area=self.stuff_area,
                void_label=-1,
                threshold=self.threshold,
                nms_kernel=self.nms_kernel,
                top_k=self.top_k,
            )
            panoptic_image = panoptic_image.squeeze(0)
            result = {
                # For semantic segmentation evaluation.
                "sem_seg": r,
                # For panoptic segmentation evaluation.
                "panoptic_seg": (panoptic_image, None),
            }
            # For instance segmentation evaluation.
            if self.predict_instances:
                instances = self._instances_from_panoptic(panoptic_image, r, c, height, width)
                if instances is not None:
                    result["instances"] = instances
            processed_results.append(result)
        return processed_results

    def _instances_from_panoptic(self, panoptic_image, sem_seg_logits, center_heatmap, height, width):
        """
        Convert the thing segments of one panoptic label map into ``Instances``.

        Args:
            panoptic_image: label map where each value is
                ``class * label_divisor + instance`` and -1 is void.
            sem_seg_logits: per-class semantic logits (C x H x W) for the same image.
            center_heatmap: center heatmap indexed as ``[0, y, x]``.
            height, width: image size recorded on the returned ``Instances``.

        Returns:
            ``Instances`` holding one entry per thing segment, with
            pred_classes / pred_masks / scores / pred_boxes, or None if the image
            contains no thing segment.
        """
        # Built once per image; membership test is O(1) inside the loop.
        thing_ids = set(self.meta.thing_dataset_id_to_contiguous_id.values())
        semantic_prob = F.softmax(sem_seg_logits, dim=0)
        device = panoptic_image.device
        instances = []
        for panoptic_label in np.unique(panoptic_image.cpu().numpy()):
            if panoptic_label == -1:
                continue
            pred_class = panoptic_label // self.meta.label_divisor
            if pred_class not in thing_ids:
                continue
            instance = Instances((height, width))
            # Evaluation code takes continuous id starting from 0
            instance.pred_classes = torch.tensor([pred_class], device=device)
            mask = panoptic_image == panoptic_label
            instance.pred_masks = mask.unsqueeze(0)
            # Average semantic probability of the predicted class over the mask.
            sem_scores = torch.mean(semantic_prob[pred_class, ...][mask])
            # Center point probability, sampled at the mask's centroid.
            mask_indices = torch.nonzero(mask).float()
            center_y = torch.mean(mask_indices[:, 0])
            center_x = torch.mean(mask_indices[:, 1])
            center_scores = center_heatmap[0, int(center_y.item()), int(center_x.item())]
            # Confidence score is semantic prob * center prob.
            instance.scores = torch.tensor([sem_scores * center_scores], device=device)
            # Get bounding boxes
            instance.pred_boxes = BitMasks(instance.pred_masks).get_bounding_boxes()
            instances.append(instance)
        if not instances:
            return None
        return Instances.cat(instances)
@SEM_SEG_HEADS_REGISTRY.register()
class PanopticDeepLabSemSegHead(DeepLabV3PlusHead):
    """
    A semantic segmentation head described in :paper:`Panoptic-DeepLab`.

    Runs the DeepLabV3+ decoder (decoder-only mode), followed by an extra
    ``head`` transform and a 1x1 ``predictor`` conv producing per-class logits,
    which are upsampled by ``common_stride`` both for the loss and at inference.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        decoder_channels: List[int],
        norm: Union[str, Callable],
        head_channels: int,
        loss_weight: float,
        loss_type: str,
        loss_top_k: float,
        ignore_value: int,
        num_classes: int,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            decoder_channels (list[int]): a list of output channels of each
                decoder stage. It should have the same length as "input_shape"
                (each element in "input_shape" corresponds to one decoder stage).
            norm (str or callable): normalization for all conv layers.
            head_channels (int): the output channels of extra convolutions
                between decoder and predictor.
            loss_weight (float): loss weight.
            loss_top_k (float): setting the top k% hardest pixels for
                "hard_pixel_mining" loss.
            loss_type, ignore_value, num_classes: the same as the base class.
                loss_type must be "cross_entropy" or "hard_pixel_mining".

        Raises:
            ValueError: if ``loss_type`` is not one of the supported values.
        """
        super().__init__(
            input_shape,
            decoder_channels=decoder_channels,
            norm=norm,
            ignore_value=ignore_value,
            **kwargs,
        )
        assert self.decoder_only

        self.loss_weight = loss_weight
        # Conv bias is only useful when no normalization layer follows.
        use_bias = norm == ""
        # `head` is additional transform before predictor
        if self.use_depthwise_separable_conv:
            # We use a single 5x5 DepthwiseSeparableConv2d to replace
            # 2 3x3 Conv2d since they have the same receptive field.
            self.head = DepthwiseSeparableConv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=5,
                padding=2,
                norm1=norm,
                activation1=F.relu,
                norm2=norm,
                activation2=F.relu,
            )
        else:
            self.head = nn.Sequential(
                Conv2d(
                    decoder_channels[0],
                    decoder_channels[0],
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, decoder_channels[0]),
                    activation=F.relu,
                ),
                Conv2d(
                    decoder_channels[0],
                    head_channels,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, head_channels),
                    activation=F.relu,
                ),
            )
            weight_init.c2_xavier_fill(self.head[0])
            weight_init.c2_xavier_fill(self.head[1])
        self.predictor = Conv2d(head_channels, num_classes, kernel_size=1)
        nn.init.normal_(self.predictor.weight, 0, 0.001)
        nn.init.constant_(self.predictor.bias, 0)

        if loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=ignore_value)
        elif loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(ignore_label=ignore_value, top_k_percent_pixels=loss_top_k)
        else:
            raise ValueError("Unexpected loss type: %s" % loss_type)

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        ret["head_channels"] = cfg.MODEL.SEM_SEG_HEAD.HEAD_CHANNELS
        ret["loss_top_k"] = cfg.MODEL.SEM_SEG_HEAD.LOSS_TOP_K
        return ret

    def forward(self, features, targets=None, weights=None):
        """
        Args:
            features: backbone feature maps consumed by the DeepLabV3+ decoder.
            targets: semantic label maps used for the loss (training only).
            weights: optional per-pixel loss weights (training only).

        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (N x C x H x W logits upsampled by common_stride, {})
        """
        y = self.layers(features)
        if self.training:
            return None, self.losses(y, targets, weights)
        else:
            y = F.interpolate(
                y, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return y, {}

    def layers(self, features):
        """Decoder -> extra head transform -> 1x1 class predictor."""
        assert self.decoder_only
        y = super().layers(features)
        y = self.head(y)
        y = self.predictor(y)
        return y

    def losses(self, predictions, targets, weights=None):
        """
        Upsample the logits by ``common_stride`` and compute the weighted
        semantic segmentation loss.

        Returns:
            dict with key "loss_sem_seg".
        """
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        if isinstance(self.loss, nn.CrossEntropyLoss):
            # nn.CrossEntropyLoss.forward only accepts (input, target); passing the
            # per-pixel weights positionally raised TypeError. Per-pixel weights
            # are therefore not applied for the "cross_entropy" loss type.
            loss = self.loss(predictions, targets)
        else:
            loss = self.loss(predictions, targets, weights)
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses
def build_ins_embed_branch(cfg, input_shape):
    """
    Build an instance embedding branch from `cfg.MODEL.INS_EMBED_HEAD.NAME`.

    The name is resolved through ``INS_EMBED_BRANCHES_REGISTRY`` and the
    registered class is instantiated as ``branch_cls(cfg, input_shape)``.
    """
    branch_cls = INS_EMBED_BRANCHES_REGISTRY.get(cfg.MODEL.INS_EMBED_HEAD.NAME)
    return branch_cls(cfg, input_shape)
@INS_EMBED_BRANCHES_REGISTRY.register()
class PanopticDeepLabInsEmbedHead(DeepLabV3PlusHead):
    """
    An instance embedding head described in :paper:`Panoptic-DeepLab`.

    On top of the DeepLabV3+ decoder it predicts a 1-channel center heatmap
    (``center_predictor``) and a 2-channel offset map (``offset_predictor``),
    trained with a weighted MSE loss and a weighted L1 loss respectively.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        decoder_channels: List[int],
        norm: Union[str, Callable],
        head_channels: int,
        center_loss_weight: float,
        offset_loss_weight: float,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            decoder_channels (list[int]): a list of output channels of each
                decoder stage. It should have the same length as "input_shape"
                (each element in "input_shape" corresponds to one decoder stage).
            norm (str or callable): normalization for all conv layers.
            head_channels (int): the output channels of extra convolutions
                between decoder and predictor.
            center_loss_weight (float): loss weight for center point prediction.
            offset_loss_weight (float): loss weight for center offset prediction.
        """
        super().__init__(input_shape, decoder_channels=decoder_channels, norm=norm, **kwargs)
        assert self.decoder_only

        self.center_loss_weight = center_loss_weight
        self.offset_loss_weight = offset_loss_weight
        # Conv bias is only useful when no normalization layer follows.
        use_bias = norm == ""

        # center prediction
        # `head` is additional transform before predictor. The center branch always
        # uses two 3x3 convs, independent of use_depthwise_separable_conv.
        self.center_head = self._conv_head(decoder_channels[0], head_channels, norm, use_bias)
        self.center_predictor = Conv2d(head_channels, 1, kernel_size=1)
        nn.init.normal_(self.center_predictor.weight, 0, 0.001)
        nn.init.constant_(self.center_predictor.bias, 0)

        # offset prediction
        # `head` is additional transform before predictor
        if self.use_depthwise_separable_conv:
            # We use a single 5x5 DepthwiseSeparableConv2d to replace
            # 2 3x3 Conv2d since they have the same receptive field.
            self.offset_head = DepthwiseSeparableConv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=5,
                padding=2,
                norm1=norm,
                activation1=F.relu,
                norm2=norm,
                activation2=F.relu,
            )
        else:
            self.offset_head = self._conv_head(decoder_channels[0], head_channels, norm, use_bias)
        self.offset_predictor = Conv2d(head_channels, 2, kernel_size=1)
        nn.init.normal_(self.offset_predictor.weight, 0, 0.001)
        nn.init.constant_(self.offset_predictor.bias, 0)

        # Per-pixel losses; weighting and normalization happen in _weighted_mean.
        self.center_loss = nn.MSELoss(reduction="none")
        self.offset_loss = nn.L1Loss(reduction="none")

    @staticmethod
    def _conv_head(in_channels, out_channels, norm, use_bias):
        """Two 3x3 Conv2d+norm+ReLU layers (in -> in -> out), Xavier-initialized."""
        head = nn.Sequential(
            Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm=get_norm(norm, in_channels),
                activation=F.relu,
            ),
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm=get_norm(norm, out_channels),
                activation=F.relu,
            ),
        )
        weight_init.c2_xavier_fill(head[0])
        weight_init.c2_xavier_fill(head[1])
        return head

    @classmethod
    def from_config(cls, cfg, input_shape):
        if cfg.INPUT.CROP.ENABLED:
            assert cfg.INPUT.CROP.TYPE == "absolute"
            train_size = cfg.INPUT.CROP.SIZE
        else:
            train_size = None
        # One CONVS_DIM stage per extra input feature, then the ASPP stage.
        decoder_channels = [cfg.MODEL.INS_EMBED_HEAD.CONVS_DIM] * (
            len(cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES) - 1
        ) + [cfg.MODEL.INS_EMBED_HEAD.ASPP_CHANNELS]
        ret = dict(
            input_shape={
                k: v for k, v in input_shape.items() if k in cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES
            },
            project_channels=cfg.MODEL.INS_EMBED_HEAD.PROJECT_CHANNELS,
            aspp_dilations=cfg.MODEL.INS_EMBED_HEAD.ASPP_DILATIONS,
            aspp_dropout=cfg.MODEL.INS_EMBED_HEAD.ASPP_DROPOUT,
            decoder_channels=decoder_channels,
            common_stride=cfg.MODEL.INS_EMBED_HEAD.COMMON_STRIDE,
            norm=cfg.MODEL.INS_EMBED_HEAD.NORM,
            train_size=train_size,
            head_channels=cfg.MODEL.INS_EMBED_HEAD.HEAD_CHANNELS,
            center_loss_weight=cfg.MODEL.INS_EMBED_HEAD.CENTER_LOSS_WEIGHT,
            offset_loss_weight=cfg.MODEL.INS_EMBED_HEAD.OFFSET_LOSS_WEIGHT,
            # NOTE: reads the SEM_SEG_HEAD flag; PanopticDeepLab asserts it equals
            # MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV.
            use_depthwise_separable_conv=cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV,
        )
        return ret

    def forward(
        self,
        features,
        center_targets=None,
        center_weights=None,
        offset_targets=None,
        offset_weights=None,
    ):
        """
        Returns:
            In training, returns
                (None, None, dict of center losses, dict of offset losses).
            In inference, returns (center, offset, {}, {}) where ``center`` is the
                N x 1 x H x W center map and ``offset`` the N x 2 x H x W offset
                map, both upsampled by ``common_stride``. Offsets are also
                multiplied by ``common_stride``, presumably to express them in
                upsampled-resolution pixels.
        """
        center, offset = self.layers(features)
        if self.training:
            return (
                None,
                None,
                self.center_losses(center, center_targets, center_weights),
                self.offset_losses(offset, offset_targets, offset_weights),
            )
        else:
            center = F.interpolate(
                center, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            offset = (
                F.interpolate(
                    offset, scale_factor=self.common_stride, mode="bilinear", align_corners=False
                )
                * self.common_stride
            )
            return center, offset, {}, {}

    def layers(self, features):
        """Shared decoder followed by the center and offset heads/predictors."""
        assert self.decoder_only
        y = super().layers(features)
        # center
        center = self.center_head(y)
        center = self.center_predictor(center)
        # offset
        offset = self.offset_head(y)
        offset = self.offset_predictor(offset)
        return center, offset

    @staticmethod
    def _weighted_mean(pixel_losses, weights):
        """
        Sum of ``pixel_losses * weights`` divided by the total weight.

        When all weights are zero, returns ``sum * 0`` rather than a fresh
        constant, presumably so the result stays attached to the autograd graph.
        NOTE(review): assumes ``weights`` broadcasts against ``pixel_losses`` as
        intended — confirm the weight map shapes produced by the data mapper.
        """
        weighted = pixel_losses * weights
        total_weight = weights.sum()
        if total_weight > 0:
            return weighted.sum() / total_weight
        return weighted.sum() * 0

    def center_losses(self, predictions, targets, weights):
        """Weighted MSE between the upsampled center map and its target."""
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = self._weighted_mean(self.center_loss(predictions, targets), weights)
        return {"loss_center": loss * self.center_loss_weight}

    def offset_losses(self, predictions, targets, weights):
        """Weighted L1 between the upsampled (and rescaled) offsets and their target."""
        predictions = (
            F.interpolate(
                predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            * self.common_stride
        )
        loss = self._weighted_mean(self.offset_loss(predictions, targets), weights)
        return {"loss_offset": loss * self.offset_loss_weight}
|