Detic / detic /modeling /meta_arch /d2_deformable_detr.py
AK391
files
159f437
raw
history blame
12.6 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.nn.functional as F
from torch import nn
import math
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone
from detectron2.structures import Boxes, Instances
from ..utils import load_class_freq, get_fed_loss_inds
from models.backbone import Joiner
from models.deformable_detr import DeformableDETR, SetCriterion, MLP
from models.deformable_detr import _get_clones
from models.matcher import HungarianMatcher
from models.position_encoding import PositionEmbeddingSine
from models.deformable_transformer import DeformableTransformer
from models.segmentation import sigmoid_focal_loss
from util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
from util.misc import NestedTensor, accuracy
__all__ = ["DeformableDetr"]
class CustomSetCriterion(SetCriterion):
def __init__(self, num_classes, matcher, weight_dict, losses, \
focal_alpha=0.25, use_fed_loss=False):
super().__init__(num_classes, matcher, weight_dict, losses, focal_alpha)
self.use_fed_loss = use_fed_loss
if self.use_fed_loss:
self.register_buffer(
'fed_loss_weight', load_class_freq(freq_weight=0.5))
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert 'pred_logits' in outputs
src_logits = outputs['pred_logits']
idx = self._get_src_permutation_idx(indices)
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
dtype=torch.int64, device=src_logits.device)
target_classes[idx] = target_classes_o
target_classes_onehot = torch.zeros(
[src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
dtype=src_logits.dtype, layout=src_logits.layout,
device=src_logits.device)
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
target_classes_onehot = target_classes_onehot[:,:,:-1] # B x N x C
if self.use_fed_loss:
inds = get_fed_loss_inds(
gt_classes=target_classes_o,
num_sample_cats=50,
weight=self.fed_loss_weight,
C=target_classes_onehot.shape[2])
loss_ce = sigmoid_focal_loss(
src_logits[:, :, inds],
target_classes_onehot[:, :, inds],
num_boxes,
alpha=self.focal_alpha,
gamma=2) * src_logits.shape[1]
else:
loss_ce = sigmoid_focal_loss(
src_logits, target_classes_onehot, num_boxes,
alpha=self.focal_alpha,
gamma=2) * src_logits.shape[1]
losses = {'loss_ce': loss_ce}
if log:
# TODO this should probably be a separate loss, not hacked in this one here
losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
return losses
class MaskedBackbone(nn.Module):
""" This is a thin wrapper around D2's backbone to provide padding masking"""
def __init__(self, cfg):
super().__init__()
self.backbone = build_backbone(cfg)
backbone_shape = self.backbone.output_shape()
self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
self.strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
self.num_channels = [backbone_shape[x].channels for x in backbone_shape.keys()]
def forward(self, tensor_list: NestedTensor):
xs = self.backbone(tensor_list.tensors)
out = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out
@META_ARCH_REGISTRY.register()
class DeformableDetr(nn.Module):
"""
Implement Deformable Detr
"""
def __init__(self, cfg):
super().__init__()
self.with_image_labels = cfg.WITH_IMAGE_LABELS
self.weak_weight = cfg.MODEL.DETR.WEAK_WEIGHT
self.device = torch.device(cfg.MODEL.DEVICE)
self.test_topk = cfg.TEST.DETECTIONS_PER_IMAGE
self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
self.mask_on = cfg.MODEL.MASK_ON
hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM
num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES
# Transformer parameters:
nheads = cfg.MODEL.DETR.NHEADS
dropout = cfg.MODEL.DETR.DROPOUT
dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD
enc_layers = cfg.MODEL.DETR.ENC_LAYERS
dec_layers = cfg.MODEL.DETR.DEC_LAYERS
num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS
two_stage = cfg.MODEL.DETR.TWO_STAGE
with_box_refine = cfg.MODEL.DETR.WITH_BOX_REFINE
# Loss parameters:
giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT
l1_weight = cfg.MODEL.DETR.L1_WEIGHT
deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION
cls_weight = cfg.MODEL.DETR.CLS_WEIGHT
focal_alpha = cfg.MODEL.DETR.FOCAL_ALPHA
N_steps = hidden_dim // 2
d2_backbone = MaskedBackbone(cfg)
backbone = Joiner(d2_backbone, PositionEmbeddingSine(N_steps, normalize=True))
transformer = DeformableTransformer(
d_model=hidden_dim,
nhead=nheads,
num_encoder_layers=enc_layers,
num_decoder_layers=dec_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation="relu",
return_intermediate_dec=True,
num_feature_levels=num_feature_levels,
dec_n_points=4,
enc_n_points=4,
two_stage=two_stage,
two_stage_num_proposals=num_queries)
self.detr = DeformableDETR(
backbone, transformer, num_classes=self.num_classes,
num_queries=num_queries,
num_feature_levels=num_feature_levels,
aux_loss=deep_supervision,
with_box_refine=with_box_refine,
two_stage=two_stage,
)
if self.mask_on:
assert 0, 'Mask is not supported yet :('
matcher = HungarianMatcher(
cost_class=cls_weight, cost_bbox=l1_weight, cost_giou=giou_weight)
weight_dict = {"loss_ce": cls_weight, "loss_bbox": l1_weight}
weight_dict["loss_giou"] = giou_weight
if deep_supervision:
aux_weight_dict = {}
for i in range(dec_layers - 1):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
print('weight_dict', weight_dict)
losses = ["labels", "boxes", "cardinality"]
if self.mask_on:
losses += ["masks"]
self.criterion = CustomSetCriterion(
self.num_classes, matcher=matcher, weight_dict=weight_dict,
focal_alpha=focal_alpha,
losses=losses,
use_fed_loss=cfg.MODEL.DETR.USE_FED_LOSS
)
pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
self.normalizer = lambda x: (x - pixel_mean) / pixel_std
def forward(self, batched_inputs):
"""
Args:
Returns:
dict[str: Tensor]:
mapping from a named loss to a tensor storing the loss. Used during training only.
"""
images = self.preprocess_image(batched_inputs)
output = self.detr(images)
if self.training:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
targets = self.prepare_targets(gt_instances)
loss_dict = self.criterion(output, targets)
weight_dict = self.criterion.weight_dict
for k in loss_dict.keys():
if k in weight_dict:
loss_dict[k] *= weight_dict[k]
if self.with_image_labels:
if batched_inputs[0]['ann_type'] in ['image', 'captiontag']:
loss_dict['loss_image'] = self.weak_weight * self._weak_loss(
output, batched_inputs)
else:
loss_dict['loss_image'] = images[0].new_zeros(
[1], dtype=torch.float32)[0]
# import pdb; pdb.set_trace()
return loss_dict
else:
image_sizes = output["pred_boxes"].new_tensor(
[(t["height"], t["width"]) for t in batched_inputs])
results = self.post_process(output, image_sizes)
return results
def prepare_targets(self, targets):
new_targets = []
for targets_per_image in targets:
h, w = targets_per_image.image_size
image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
gt_classes = targets_per_image.gt_classes
gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
gt_boxes = box_xyxy_to_cxcywh(gt_boxes)
new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
if self.mask_on and hasattr(targets_per_image, 'gt_masks'):
assert 0, 'Mask is not supported yet :('
gt_masks = targets_per_image.gt_masks
gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
new_targets[-1].update({'masks': gt_masks})
return new_targets
def post_process(self, outputs, target_sizes):
"""
"""
out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2
prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(
prob.view(out_logits.shape[0], -1), self.test_topk, dim=1)
scores = topk_values
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
boxes = box_cxcywh_to_xyxy(out_bbox)
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
# and from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
results = []
for s, l, b, size in zip(scores, labels, boxes, target_sizes):
r = Instances((size[0], size[1]))
r.pred_boxes = Boxes(b)
r.scores = s
r.pred_classes = l
results.append({'instances': r})
return results
def preprocess_image(self, batched_inputs):
"""
Normalize, pad and batch the input images.
"""
images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
return images
def _weak_loss(self, outputs, batched_inputs):
loss = 0
for b, x in enumerate(batched_inputs):
labels = x['pos_category_ids']
pred_logits = [outputs['pred_logits'][b]]
pred_boxes = [outputs['pred_boxes'][b]]
for xx in outputs['aux_outputs']:
pred_logits.append(xx['pred_logits'][b])
pred_boxes.append(xx['pred_boxes'][b])
pred_logits = torch.stack(pred_logits, dim=0) # L x N x C
pred_boxes = torch.stack(pred_boxes, dim=0) # L x N x 4
for label in labels:
loss += self._max_size_loss(
pred_logits, pred_boxes, label) / len(labels)
loss = loss / len(batched_inputs)
return loss
def _max_size_loss(self, logits, boxes, label):
'''
Inputs:
logits: L x N x C
boxes: L x N x 4
'''
target = logits.new_zeros((logits.shape[0], logits.shape[2]))
target[:, label] = 1.
sizes = boxes[..., 2] * boxes[..., 3] # L x N
ind = sizes.argmax(dim=1) # L
loss = F.binary_cross_entropy_with_logits(
logits[range(len(ind)), ind], target, reduction='sum')
return loss