rlawjdghek's picture
prep (#1)
61c2d32 verified
raw
history blame
5.71 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Dict
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import ShapeSpec, cat
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from .point_features import (
get_uncertain_point_coords_on_grid,
get_uncertain_point_coords_with_randomness,
point_sample,
)
from .point_head import build_point_head
def calculate_uncertainty(sem_seg_logits):
    """
    Estimate per-location uncertainty of `sem_seg_logits` as the gap between its top-2 logits.

    Args:
        sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and
            C is the number of classes. The values are logits. C must be at least 2, since the
            two largest logits along dim 1 are compared.

    Returns:
        scores (Tensor): A tensor of shape (N, 1, ...) holding `second_best - best` at each
            location. Scores are therefore <= 0, and the most uncertain locations (nearly tied
            top-2 logits) have the highest score (closest to 0).
    """
    # topk returns values sorted in descending order: index 0 is the best, index 1 the runner-up.
    top2_scores = torch.topk(sem_seg_logits, k=2, dim=1).values
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
@SEM_SEG_HEADS_REGISTRY.register()
class PointRendSemSegHead(nn.Module):
    """
    A semantic segmentation head that combines a head set in
    `MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME` and a point head set in `MODEL.POINT_HEAD.NAME`.

    Training: the coarse head's losses are returned together with a point-level cross-entropy
    computed on points chosen by `get_uncertain_point_coords_with_randomness`.
    Inference: the coarse logits are upsampled 2x `SUBDIVISION_STEPS` times, and after each
    upsampling the `SUBDIVISION_NUM_POINTS` most uncertain grid points are re-predicted by the
    point head.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()
        # Label value skipped by the point loss (passed as `ignore_index` to cross_entropy).
        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
        # The coarse head is looked up by name in the same registry this class is registered in.
        self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get(
            cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME
        )(cfg, input_shape)
        self._init_point_head(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """Read the point-sampling hyper-parameters from `cfg` and build the point head."""
        # Point logits are written back into the coarse logits per class, so the class
        # counts of the two heads must match.
        assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES, (
            "MODEL.SEM_SEG_HEAD.NUM_CLASSES must equal MODEL.POINT_HEAD.NUM_CLASSES"
        )
        feature_channels = {k: v.channels for k, v in input_shape.items()}
        self.in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        self.subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # The point head sees all `in_features` concatenated along the channel axis
        # (see `_point_sample_features`), so its input width is the sum of their channels.
        in_channels = int(sum(feature_channels[f] for f in self.in_features))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

    def forward(self, features, targets=None):
        """
        Args:
            features (dict[str, Tensor]): feature maps keyed by name; must contain every name in
                `self.in_features` plus whatever the coarse head's `layers` reads.
            targets (Tensor, optional): ground-truth labels, required in training mode.
                NOTE(review): presumably (N, H, W) integer class ids — confirm against caller.

        Returns:
            Training: `(None, losses)`, where `losses` is the coarse head's loss dict extended
                with "loss_sem_seg_point".
            Inference: `(sem_seg_logits, {})` with the refined, upsampled logits.
        """
        coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features)
        if self.training:
            return None, self._forward_train(features, coarse_sem_seg_logits, targets)
        return self._forward_inference(features, coarse_sem_seg_logits), {}

    def _point_sample_features(self, features, point_coords):
        """Sample each map in `self.in_features` at `point_coords`; concatenate along channels."""
        # dim=1 is the channel axis; it must match the `in_channels` sum used to build the
        # point head, in both training and inference.
        return cat(
            [
                point_sample(features[in_feature], point_coords, align_corners=False)
                for in_feature in self.in_features
            ],
            dim=1,
        )

    def _forward_train(self, features, coarse_sem_seg_logits, targets):
        """Return the coarse head's losses plus the point-head cross-entropy on sampled points."""
        losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets)
        # Point selection is not part of the autograd graph.
        with torch.no_grad():
            point_coords = get_uncertain_point_coords_with_randomness(
                coarse_sem_seg_logits,
                calculate_uncertainty,
                self.train_num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            )
        coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)
        fine_grained_features = self._point_sample_features(features, point_coords)
        point_logits = self.point_head(fine_grained_features, coarse_features)
        # Labels are sampled with nearest-neighbor mode so sampled values are existing label
        # values (never blends of two classes); they are cast to float for sampling and back to
        # long for the loss.
        point_targets = (
            point_sample(
                targets.unsqueeze(1).to(torch.float),
                point_coords,
                mode="nearest",
                align_corners=False,
            )
            .squeeze(1)
            .to(torch.long)
        )
        losses["loss_sem_seg_point"] = F.cross_entropy(
            point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value
        )
        return losses

    def _forward_inference(self, features, coarse_sem_seg_logits):
        """Repeatedly upsample the coarse logits 2x and re-predict the most uncertain points."""
        sem_seg_logits = coarse_sem_seg_logits.clone()
        for _ in range(self.subdivision_steps):
            sem_seg_logits = F.interpolate(
                sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False
            )
            uncertainty_map = calculate_uncertainty(sem_seg_logits)
            point_indices, point_coords = get_uncertain_point_coords_on_grid(
                uncertainty_map, self.subdivision_num_points
            )
            fine_grained_features = self._point_sample_features(features, point_coords)
            # Coarse features are always sampled from the original coarse prediction,
            # not from the current upsampled map.
            coarse_features = point_sample(
                coarse_sem_seg_logits, point_coords, align_corners=False
            )
            point_logits = self.point_head(fine_grained_features, coarse_features)
            # Put the point predictions at their flat (H * W) positions on the upsampled grid.
            N, C, H, W = sem_seg_logits.shape
            point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
            sem_seg_logits = (
                sem_seg_logits.reshape(N, C, H * W)
                .scatter_(2, point_indices, point_logits)
                .view(N, C, H, W)
            )
        return sem_seg_logits