Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
import math | |
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone | |
from detectron2.structures import Boxes, Instances | |
from ..utils import load_class_freq, get_fed_loss_inds | |
from models.backbone import Joiner | |
from models.deformable_detr import DeformableDETR, SetCriterion, MLP | |
from models.deformable_detr import _get_clones | |
from models.matcher import HungarianMatcher | |
from models.position_encoding import PositionEmbeddingSine | |
from models.deformable_transformer import DeformableTransformer | |
from models.segmentation import sigmoid_focal_loss | |
from util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh | |
from util.misc import NestedTensor, accuracy | |
__all__ = ["DeformableDetr"] | |
class CustomSetCriterion(SetCriterion):
    """Set criterion whose classification term can be restricted to a sampled
    subset of categories (federated loss).

    The box and cardinality losses are inherited unchanged from
    ``SetCriterion``; only ``loss_labels`` is overridden here.
    """

    def __init__(self, num_classes, matcher, weight_dict, losses,
                 focal_alpha=0.25, use_fed_loss=False):
        """Forward the standard arguments to ``SetCriterion`` and, when
        ``use_fed_loss`` is set, register the class-frequency weights used to
        sample categories as the ``fed_loss_weight`` buffer.
        """
        super().__init__(num_classes, matcher, weight_dict, losses, focal_alpha)
        self.use_fed_loss = use_fed_loss
        if use_fed_loss:
            class_freq_weight = load_class_freq(freq_weight=0.5)
            self.register_buffer('fed_loss_weight', class_freq_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Sigmoid focal classification loss over the matched queries.

        Each entry of ``targets`` must provide a ``"labels"`` tensor of
        dim [nb_target_boxes]. ``indices`` holds the per-image
        (query index, target index) pairs produced by the matcher.

        Returns a dict with ``loss_ce`` and, when ``log`` is true,
        ``class_error`` (100 minus the accuracy on matched queries).
        """
        assert 'pred_logits' in outputs
        logits = outputs['pred_logits']
        batch_size, num_queries, num_logit_classes = logits.shape

        matched_idx = self._get_src_permutation_idx(indices)
        matched_labels = torch.cat(
            [tgt["labels"][tgt_idx] for tgt, (_, tgt_idx) in zip(targets, indices)])

        # Every query starts with the label ``self.num_classes``; matched
        # queries are then overwritten with their ground-truth class.
        query_labels = torch.full(
            (batch_size, num_queries), self.num_classes,
            dtype=torch.int64, device=logits.device)
        query_labels[matched_idx] = matched_labels

        # One-hot encode with one extra trailing column, then drop that column
        # so the targets have the same width as the logits.
        onehot = torch.zeros(
            [batch_size, num_queries, num_logit_classes + 1],
            dtype=logits.dtype, layout=logits.layout, device=logits.device)
        onehot.scatter_(2, query_labels.unsqueeze(-1), 1)
        onehot = onehot[:, :, :-1]  # B x N x C

        if self.use_fed_loss:
            # Only the sampled category columns contribute to the loss.
            inds = get_fed_loss_inds(
                gt_classes=matched_labels,
                num_sample_cats=50,
                weight=self.fed_loss_weight,
                C=onehot.shape[2])
            focal_inputs = logits[:, :, inds]
            focal_targets = onehot[:, :, inds]
        else:
            focal_inputs = logits
            focal_targets = onehot

        loss_ce = sigmoid_focal_loss(
            focal_inputs, focal_targets, num_boxes,
            alpha=self.focal_alpha, gamma=2) * num_queries

        losses = {'loss_ce': loss_ce}
        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(logits[matched_idx], matched_labels)[0]
        return losses
class MaskedBackbone(nn.Module):
    """Thin wrapper around D2's backbone that attaches a padding mask to every
    feature map it returns."""

    def __init__(self, cfg):
        """Build the backbone from ``cfg`` and record per-feature strides and
        channel counts from its output shape."""
        super().__init__()
        self.backbone = build_backbone(cfg)
        shape_specs = list(self.backbone.output_shape().values())
        self.feature_strides = [spec.stride for spec in shape_specs]
        # Same values as ``feature_strides`` under a second attribute name,
        # kept as an independent list.
        self.strides = [spec.stride for spec in shape_specs]
        self.num_channels = [spec.channels for spec in shape_specs]

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone on ``tensor_list.tensors`` and return a dict mapping
        each feature name to a ``NestedTensor`` of (feature, resized mask)."""
        features = self.backbone(tensor_list.tensors)
        masked_features = {}
        for name, feature in features.items():
            padding_mask = tensor_list.mask
            assert padding_mask is not None
            # Resize the image-level mask to this feature map's spatial size.
            resized = F.interpolate(padding_mask[None].float(), size=feature.shape[-2:])
            masked_features[name] = NestedTensor(feature, resized.to(torch.bool)[0])
        return masked_features
# NOTE(review): META_ARCH_REGISTRY is imported by this file but this class is
# not registered in the visible source -- confirm registration happens elsewhere.
class DeformableDetr(nn.Module):
    """
    Implement Deformable Detr

    Wraps ``DeformableDETR`` so it consumes Detectron2-style batched inputs:
    training returns a dict of weighted losses, inference returns a list of
    ``{"instances": Instances}`` dicts. Optionally adds an image-level
    ("weak") classification loss for image/caption-tag annotated batches.
    """

    def __init__(self, cfg):
        """Build backbone, transformer, detector and criterion from ``cfg``.

        Raises:
            NotImplementedError: if ``cfg.MODEL.MASK_ON`` is set.
        """
        super().__init__()
        self.with_image_labels = cfg.WITH_IMAGE_LABELS
        self.weak_weight = cfg.MODEL.DETR.WEAK_WEIGHT

        self.device = torch.device(cfg.MODEL.DEVICE)
        self.test_topk = cfg.TEST.DETECTIONS_PER_IMAGE
        self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
        self.mask_on = cfg.MODEL.MASK_ON
        hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM
        num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES

        # Transformer parameters:
        nheads = cfg.MODEL.DETR.NHEADS
        dropout = cfg.MODEL.DETR.DROPOUT
        dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD
        enc_layers = cfg.MODEL.DETR.ENC_LAYERS
        dec_layers = cfg.MODEL.DETR.DEC_LAYERS
        num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS
        two_stage = cfg.MODEL.DETR.TWO_STAGE
        with_box_refine = cfg.MODEL.DETR.WITH_BOX_REFINE

        # Loss parameters:
        giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT
        l1_weight = cfg.MODEL.DETR.L1_WEIGHT
        deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION
        cls_weight = cfg.MODEL.DETR.CLS_WEIGHT
        focal_alpha = cfg.MODEL.DETR.FOCAL_ALPHA

        N_steps = hidden_dim // 2
        d2_backbone = MaskedBackbone(cfg)
        backbone = Joiner(d2_backbone, PositionEmbeddingSine(N_steps, normalize=True))

        transformer = DeformableTransformer(
            d_model=hidden_dim,
            nhead=nheads,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="relu",
            return_intermediate_dec=True,
            num_feature_levels=num_feature_levels,
            dec_n_points=4,
            enc_n_points=4,
            two_stage=two_stage,
            two_stage_num_proposals=num_queries)

        self.detr = DeformableDETR(
            backbone, transformer, num_classes=self.num_classes,
            num_queries=num_queries,
            num_feature_levels=num_feature_levels,
            aux_loss=deep_supervision,
            with_box_refine=with_box_refine,
            two_stage=two_stage,
        )

        if self.mask_on:
            # Explicit raise instead of ``assert 0`` so the guard survives -O.
            raise NotImplementedError('Mask is not supported yet :(')

        matcher = HungarianMatcher(
            cost_class=cls_weight, cost_bbox=l1_weight, cost_giou=giou_weight)
        weight_dict = {"loss_ce": cls_weight, "loss_bbox": l1_weight}
        weight_dict["loss_giou"] = giou_weight
        if deep_supervision:
            # Each auxiliary decoder layer gets its own "<loss>_<i>" weights.
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)
        print('weight_dict', weight_dict)

        # No "masks" loss: mask support is rejected above.
        losses = ["labels", "boxes", "cardinality"]
        self.criterion = CustomSetCriterion(
            self.num_classes, matcher=matcher, weight_dict=weight_dict,
            focal_alpha=focal_alpha,
            losses=losses,
            use_fed_loss=cfg.MODEL.DETR.USE_FED_LOSS
        )

        # The normalization constants are created on ``self.device`` once and
        # captured by the closure; they are not module buffers.
        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: list of per-image dicts. Each has an "image"
                tensor. Training additionally reads "instances" and, when
                image-label training is enabled, "ann_type" (first item only)
                and "pos_category_ids". Inference reads "height" and "width".
        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a tensor storing the loss. Used during training only.
            In inference, a list of ``{"instances": Instances}`` dicts instead.
        """
        images = self.preprocess_image(batched_inputs)
        output = self.detr(images)

        if not self.training:
            image_sizes = output["pred_boxes"].new_tensor(
                [(t["height"], t["width"]) for t in batched_inputs])
            return self.post_process(output, image_sizes)

        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        targets = self.prepare_targets(gt_instances)
        loss_dict = self.criterion(output, targets)
        weight_dict = self.criterion.weight_dict
        for k in loss_dict.keys():
            if k in weight_dict:
                loss_dict[k] *= weight_dict[k]
        if self.with_image_labels:
            # The annotation type of the first item decides for the whole batch.
            if batched_inputs[0]['ann_type'] in ['image', 'captiontag']:
                loss_dict['loss_image'] = self.weak_weight * self._weak_loss(
                    output, batched_inputs)
            else:
                # Keep the key present with a zero scalar for box-annotated batches.
                loss_dict['loss_image'] = images[0].new_zeros(
                    [1], dtype=torch.float32)[0]
        return loss_dict

    def prepare_targets(self, targets):
        """Convert per-image ground-truth instances into criterion targets.

        Boxes are divided by (w, h, w, h) of the image and converted from
        xyxy to cxcywh. Returns a list of ``{"labels", "boxes"}`` dicts.

        Raises:
            NotImplementedError: if masks are enabled and an item has ``gt_masks``.
        """
        new_targets = []
        for targets_per_image in targets:
            if self.mask_on and hasattr(targets_per_image, 'gt_masks'):
                raise NotImplementedError('Mask is not supported yet :(')
            h, w = targets_per_image.image_size
            image_size_xyxy = torch.as_tensor(
                [w, h, w, h], dtype=torch.float, device=self.device)
            gt_classes = targets_per_image.gt_classes
            gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
            gt_boxes = box_xyxy_to_cxcywh(gt_boxes)
            new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
        return new_targets

    def post_process(self, outputs, target_sizes):
        """Turn raw detector outputs into per-image ``Instances``.

        Takes the ``self.test_topk`` highest sigmoid scores over all
        (query, class) pairs of each image, and rescales the matching boxes
        by the (height, width) rows of ``target_sizes``.

        Returns:
            list of ``{"instances": Instances}`` dicts, one per image.
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()
        # Flatten (query, class) so top-k ranks every pair jointly.
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1), self.test_topk, dim=1)
        scores = topk_values
        # Recover the query index and class index from the flat index.
        topk_boxes = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        boxes = box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute pixel coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = []
        for score, label, box, size in zip(scores, labels, boxes, target_sizes):
            r = Instances((size[0], size[1]))
            r.pred_boxes = Boxes(box)
            r.scores = score
            r.pred_classes = label
            results.append({'instances': r})
        return results

    def preprocess_image(self, batched_inputs):
        """
        Move each input image to ``self.device`` and normalize it.

        Returns a plain list of per-image tensors; no padding or batching is
        done here.
        """
        images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
        return images

    def _weak_loss(self, outputs, batched_inputs):
        """Image-level classification loss averaged over labels and images.

        For every image, the final-layer predictions and any auxiliary layer
        predictions are stacked, and ``_max_size_loss`` is evaluated for each
        id in the item's ``pos_category_ids``.

        Always returns a scalar tensor, including when no image has positive
        categories, and works when ``outputs`` has no ``aux_outputs`` entry.
        """
        # Absent when the detector was built without auxiliary outputs.
        aux_outputs = outputs.get('aux_outputs', [])
        # Tensor accumulator so the result is a tensor even with no labels.
        loss = outputs['pred_logits'].new_zeros(())
        for b, x in enumerate(batched_inputs):
            labels = x['pos_category_ids']
            pred_logits = torch.stack(
                [outputs['pred_logits'][b]]
                + [aux['pred_logits'][b] for aux in aux_outputs],
                dim=0)  # L x N x C
            pred_boxes = torch.stack(
                [outputs['pred_boxes'][b]]
                + [aux['pred_boxes'][b] for aux in aux_outputs],
                dim=0)  # L x N x 4
            for label in labels:
                loss = loss + self._max_size_loss(
                    pred_logits, pred_boxes, label) / len(labels)
        loss = loss / len(batched_inputs)
        return loss

    def _max_size_loss(self, logits, boxes, label):
        '''
        Binary cross-entropy on the largest predicted box of every layer.

        Inputs:
          logits: L x N x C
          boxes: L x N x 4
          label: class index whose target is set to 1 (all others 0)
        Returns:
          scalar tensor, summed over layers and classes
        '''
        target = logits.new_zeros((logits.shape[0], logits.shape[2]))
        target[:, label] = 1.
        # Last two box components are multiplied to give each box's area.
        sizes = boxes[..., 2] * boxes[..., 3]  # L x N
        ind = sizes.argmax(dim=1)  # L
        loss = F.binary_cross_entropy_with_logits(
            logits[range(len(ind)), ind], target, reduction='sum')
        return loss