# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import math
import json
import numpy as np
from typing import Dict, Union
import torch
from fvcore.nn import giou_loss, smooth_l1_loss
from torch import nn
from torch.nn import functional as F
import fvcore.nn.weight_init as weight_init

import detectron2.utils.comm as comm
from detectron2.config import configurable
from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats
from torch.cuda.amp import autocast

from ..utils import load_class_freq, get_fed_loss_inds
from .zero_shot_classifier import ZeroShotClassifier

__all__ = ["DeticFastRCNNOutputLayers"]


class DeticFastRCNNOutputLayers(FastRCNNOutputLayers):
    """
    Detic box-head output layers. Extends FastRCNNOutputLayers with sigmoid /
    federated classification losses, an optional zero-shot (embedding-based)
    classifier, and weak supervision from image-level labels and captions.
    """

    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        mult_proposal_score=False,
        cls_score=None,
        sync_caption_batch=False,
        use_sigmoid_ce=False,
        use_fed_loss=False,
        ignore_zero_cats=False,
        fed_loss_num_cat=50,
        dynamic_classifier=False,
        image_label_loss='',
        use_zeroshot_cls=False,
        image_loss_weight=0.1,
        with_softmax_prop=False,
        caption_weight=1.0,
        neg_cap_weight=1.0,
        add_image_box=False,
        debug=False,
        prior_prob=0.01,
        cat_freq_path='',
        fed_loss_freq_weight=0.5,
        softmax_weak_loss=False,
        **kwargs,
    ):
        super().__init__(
            input_shape=input_shape,
            **kwargs,
        )
        self.mult_proposal_score = mult_proposal_score
        self.sync_caption_batch = sync_caption_batch
        self.use_sigmoid_ce = use_sigmoid_ce
        self.use_fed_loss = use_fed_loss
        self.ignore_zero_cats = ignore_zero_cats
        self.fed_loss_num_cat = fed_loss_num_cat
        self.dynamic_classifier = dynamic_classifier
        self.image_label_loss = image_label_loss
        self.use_zeroshot_cls = use_zeroshot_cls
        self.image_loss_weight = image_loss_weight
        self.with_softmax_prop = with_softmax_prop
        self.caption_weight = caption_weight
        self.neg_cap_weight = neg_cap_weight
        self.add_image_box = add_image_box
        self.softmax_weak_loss = softmax_weak_loss
        self.debug = debug

        if softmax_weak_loss:
            assert image_label_loss in ['max_size']

        if self.use_sigmoid_ce:
            # Initialize the classifier bias so sigmoid outputs start near prior_prob.
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            nn.init.constant_(self.cls_score.bias, bias_value)

        if self.use_fed_loss or self.ignore_zero_cats:
            freq_weight = load_class_freq(cat_freq_path, fed_loss_freq_weight)
            self.register_buffer('freq_weight', freq_weight)
        else:
            self.freq_weight = None

        if self.use_fed_loss and len(self.freq_weight) < self.num_classes:
            print('Extending federated loss weight')
            self.freq_weight = torch.cat(
                [self.freq_weight,
                 self.freq_weight.new_zeros(self.num_classes - len(self.freq_weight))]
            )

        assert (not self.dynamic_classifier) or (not self.use_fed_loss)
        input_size = input_shape.channels * \
            (input_shape.width or 1) * (input_shape.height or 1)

        if self.use_zeroshot_cls:
            # Replace the linear classifier and box regressor with the provided
            # zero-shot classifier and a deeper class-agnostic regression head.
            del self.cls_score
            del self.bbox_pred
            assert cls_score is not None
            self.cls_score = cls_score
            self.bbox_pred = nn.Sequential(
                nn.Linear(input_size, input_size),
                nn.ReLU(inplace=True),
                nn.Linear(input_size, 4),
            )
            weight_init.c2_xavier_fill(self.bbox_pred[0])
            nn.init.normal_(self.bbox_pred[-1].weight, std=0.001)
            nn.init.constant_(self.bbox_pred[-1].bias, 0)

        if self.with_softmax_prop:
            self.prop_score = nn.Sequential(
                nn.Linear(input_size, input_size),
                nn.ReLU(inplace=True),
                nn.Linear(input_size, self.num_classes + 1),
            )
            weight_init.c2_xavier_fill(self.prop_score[0])
            nn.init.normal_(self.prop_score[-1].weight, mean=0, std=0.001)
            nn.init.constant_(self.prop_score[-1].bias, 0)

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        ret.update({
            'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
            'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
            'use_sigmoid_ce': cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE,
            'use_fed_loss': cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS,
            'ignore_zero_cats': cfg.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS,
            'fed_loss_num_cat': cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT,
            'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
            'image_label_loss': cfg.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS,
            'use_zeroshot_cls': cfg.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS,
            'image_loss_weight': cfg.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT,
            'with_softmax_prop': cfg.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP,
            'caption_weight': cfg.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT,
            'neg_cap_weight': cfg.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT,
            'add_image_box': cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX,
            'debug': cfg.DEBUG or cfg.SAVE_DEBUG or cfg.IS_DEBUG,
            'prior_prob': cfg.MODEL.ROI_BOX_HEAD.PRIOR_PROB,
            'cat_freq_path': cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
            'fed_loss_freq_weight': cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT,
            'softmax_weak_loss': cfg.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS,
        })
        if ret['use_zeroshot_cls']:
            ret['cls_score'] = ZeroShotClassifier(cfg, input_shape)
        return ret

    def losses(self, predictions, proposals,
               use_advanced_loss=True, classifier_info=(None, None, None)):
        """
        Classification and box-regression losses. Compared to the base class,
        this supports sigmoid cross-entropy, federated loss, and a dynamic
        classifier whose class ids are remapped via ``classifier_info``.
        """
        scores, proposal_deltas = predictions
        gt_classes = (
            cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
        )
        num_classes = self.num_classes
        if self.dynamic_classifier:
            _, cls_id_map = classifier_info[1]
            gt_classes = cls_id_map[gt_classes]
            num_classes = scores.shape[1] - 1
            assert cls_id_map[self.num_classes] == num_classes
        _log_classification_stats(scores, gt_classes)

        if len(proposals):
            proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
            assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
            gt_boxes = cat(
                [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
                dim=0,
            )
        else:
            proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

        if self.use_sigmoid_ce:
            loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes)
        else:
            loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes)
        return {
            "loss_cls": loss_cls,
            "loss_box_reg": self.box_reg_loss(
                proposal_boxes, gt_boxes, proposal_deltas, gt_classes,
                num_classes=num_classes),
        }

    def sigmoid_cross_entropy_loss(self, pred_class_logits, gt_classes):
        """
        Binary cross-entropy over the C foreground classes, optionally masked
        by federated-loss sampling and/or zero-frequency categories.
        """
        if pred_class_logits.numel() == 0:
            return pred_class_logits.new_zeros([1])[0]  # more robust than .sum() * 0.
        B = pred_class_logits.shape[0]
        C = pred_class_logits.shape[1] - 1

        target = pred_class_logits.new_zeros(B, C + 1)
        target[range(len(gt_classes)), gt_classes] = 1  # B x (C + 1)
        target = target[:, :C]  # B x C

        weight = 1
        if self.use_fed_loss and (self.freq_weight is not None):  # federated loss
            appeared = get_fed_loss_inds(
                gt_classes,
                num_sample_cats=self.fed_loss_num_cat,
                C=C,
                weight=self.freq_weight)
            appeared_mask = appeared.new_zeros(C + 1)
            appeared_mask[appeared] = 1  # C + 1
            appeared_mask = appeared_mask[:C]
            fed_w = appeared_mask.view(1, C).expand(B, C)
            weight = weight * fed_w.float()
        if self.ignore_zero_cats and (self.freq_weight is not None):
            w = (self.freq_weight.view(-1) > 1e-4).float()
            weight = weight * w.view(1, C).expand(B, C)

        cls_loss = F.binary_cross_entropy_with_logits(
            pred_class_logits[:, :-1], target, reduction='none')  # B x C
        loss = torch.sum(cls_loss * weight) / B
        return loss

    def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes):
        """
        Softmax cross-entropy with changed no-instance handling and optional
        class weighting (ignore zero-frequency categories or federated loss).
        """
        if pred_class_logits.numel() == 0:
            return pred_class_logits.new_zeros([1])[0]

        if self.ignore_zero_cats and (self.freq_weight is not None):
            zero_weight = torch.cat([
                (self.freq_weight.view(-1) > 1e-4).float(),
                self.freq_weight.new_ones(1)])  # C + 1
            loss = F.cross_entropy(
                pred_class_logits, gt_classes,
                weight=zero_weight, reduction="mean")
        elif self.use_fed_loss and (self.freq_weight is not None):  # federated loss
            C = pred_class_logits.shape[1] - 1
            appeared = get_fed_loss_inds(
                gt_classes,
                num_sample_cats=self.fed_loss_num_cat,
                C=C,
                weight=self.freq_weight)
            appeared_mask = appeared.new_zeros(C + 1).float()
            appeared_mask[appeared] = 1.  # C + 1
            appeared_mask[C] = 1.
            loss = F.cross_entropy(
                pred_class_logits, gt_classes,
                weight=appeared_mask, reduction="mean")
        else:
            loss = F.cross_entropy(
                pred_class_logits, gt_classes, reduction="mean")
        return loss

    def box_reg_loss(
            self, proposal_boxes, gt_boxes, pred_deltas, gt_classes,
            num_classes=-1):
        """
        Same as the base class, but allows a custom background index
        (``num_classes``) for the dynamic classifier.
        """
        num_classes = num_classes if num_classes > 0 else self.num_classes
        box_dim = proposal_boxes.shape[1]  # 4 or 5
        fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < num_classes))[0]
        if pred_deltas.shape[1] == box_dim:  # cls-agnostic regression
            fg_pred_deltas = pred_deltas[fg_inds]
        else:
            fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[
                fg_inds, gt_classes[fg_inds]
            ]

        if self.box_reg_loss_type == "smooth_l1":
            gt_pred_deltas = self.box2box_transform.get_deltas(
                proposal_boxes[fg_inds],
                gt_boxes[fg_inds],
            )
            loss_box_reg = smooth_l1_loss(
                fg_pred_deltas, gt_pred_deltas, self.smooth_l1_beta, reduction="sum"
            )
        elif self.box_reg_loss_type == "giou":
            fg_pred_boxes = self.box2box_transform.apply_deltas(
                fg_pred_deltas, proposal_boxes[fg_inds]
            )
            loss_box_reg = giou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
        else:
            raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")
        return loss_box_reg / max(gt_classes.numel(), 1.0)

    def inference(self, predictions, proposals):
        """
        Standard Fast R-CNN inference; optionally multiplies class scores by
        the proposal objectness scores (geometric mean).
        """
        predictions = (predictions[0], predictions[1])
        boxes = self.predict_boxes(predictions, proposals)
        scores = self.predict_probs(predictions, proposals)
        if self.mult_proposal_score:
            proposal_scores = [p.get('objectness_logits') for p in proposals]
            scores = [(s * ps[:, None]) ** 0.5
                      for s, ps in zip(scores, proposal_scores)]
        image_shapes = [x.image_size for x in proposals]
        return fast_rcnn_inference(
            boxes,
            scores,
            image_shapes,
            self.test_score_thresh,
            self.test_nms_thresh,
            self.test_topk_per_image,
        )

    def predict_probs(self, predictions, proposals):
        """
        Per-image class probabilities; uses sigmoid instead of softmax when
        the model was trained with sigmoid cross-entropy.
        """
        scores = predictions[0]
        num_inst_per_image = [len(p) for p in proposals]
        if self.use_sigmoid_ce:
            probs = scores.sigmoid()
        else:
            probs = F.softmax(scores, dim=-1)
        return probs.split(num_inst_per_image, dim=0)

    def image_label_losses(self, predictions, proposals, image_labels,
                           classifier_info=(None, None, None), ann_type='image'):
        """
        Weak-supervision losses from image-level labels and/or captions.

        Inputs:
            predictions[0]: scores, N x (C + 1)
            image_labels: a length-B list of per-image label lists
        """
        num_inst_per_image = [len(p) for p in proposals]
        scores = predictions[0]
        scores = scores.split(num_inst_per_image, dim=0)  # B x n x (C + 1)
        if self.with_softmax_prop:
            prop_scores = predictions[2].split(num_inst_per_image, dim=0)
        else:
            prop_scores = [None for _ in num_inst_per_image]
        B = len(scores)
        img_box_count = 0
        select_size_count = 0
        select_x_count = 0
        select_y_count = 0
        max_score_count = 0
        storage = get_event_storage()
        loss = scores[0].new_zeros([1])[0]
        caption_loss = scores[0].new_zeros([1])[0]
        for idx, (score, labels, prop_score, p) in enumerate(zip(
                scores, image_labels, prop_scores, proposals)):
            if score.shape[0] == 0:
                loss += score.new_zeros([1])[0]
                continue

            if 'caption' in ann_type:
                score, caption_loss_img = self._caption_loss(
                    score, classifier_info, idx, B)
                caption_loss += self.caption_weight * caption_loss_img
                if ann_type == 'caption':
                    continue

            if self.debug:
                p.selected = score.new_zeros((len(p),), dtype=torch.long) - 1
            for i_l, label in enumerate(labels):
                if self.dynamic_classifier:
                    if idx == 0 and i_l == 0 and comm.is_main_process():
                        storage.put_scalar('stats_label', label)
                    label = classifier_info[1][1][label]
                    assert label < score.shape[1]
                if self.image_label_loss in ['wsod', 'wsddn']:
                    loss_i, ind = self._wsddn_loss(score, prop_score, label)
                elif self.image_label_loss == 'max_score':
                    loss_i, ind = self._max_score_loss(score, label)
                elif self.image_label_loss == 'max_size':
                    loss_i, ind = self._max_size_loss(score, label, p)
                elif self.image_label_loss == 'first':
                    loss_i, ind = self._first_loss(score, label)
                elif self.image_label_loss == 'image':
                    loss_i, ind = self._image_loss(score, label)
                elif self.image_label_loss == 'min_loss':
                    loss_i, ind = self._min_loss_loss(score, label)
                else:
                    raise ValueError(
                        f"Unknown image_label_loss '{self.image_label_loss}'")
                loss += loss_i / len(labels)
                if isinstance(ind, list):
                    img_box_count = sum(ind) / len(ind)
                    if self.debug:
                        for ind_i in ind:
                            p.selected[ind_i] = label
                else:
                    img_box_count = ind
                    select_size_count = p[ind].proposal_boxes.area() / \
                        (p.image_size[0] * p.image_size[1])
                    max_score_count = score[ind, label].sigmoid()
                    select_x_count = (p.proposal_boxes.tensor[ind, 0] +
                                      p.proposal_boxes.tensor[ind, 2]) / 2 / p.image_size[1]
                    select_y_count = (p.proposal_boxes.tensor[ind, 1] +
                                      p.proposal_boxes.tensor[ind, 3]) / 2 / p.image_size[0]
                    if self.debug:
                        p.selected[ind] = label

        loss = loss / B
        storage.put_scalar('stats_l_image', loss.item())
        if 'caption' in ann_type:
            caption_loss = caption_loss / B
            loss = loss + caption_loss
            storage.put_scalar('stats_l_caption', caption_loss.item())
        if comm.is_main_process():
            storage.put_scalar('pool_stats', img_box_count)
            storage.put_scalar('stats_select_size', select_size_count)
            storage.put_scalar('stats_select_x', select_x_count)
            storage.put_scalar('stats_select_y', select_y_count)
            storage.put_scalar('stats_max_label_score', max_score_count)

        return {
            'image_loss': loss * self.image_loss_weight,
            'loss_cls': score.new_zeros([1])[0],
            'loss_box_reg': score.new_zeros([1])[0]}

    def forward(self, x, classifier_info=(None, None, None)):
        """
        Run the classifier, box regressor, and (optionally) the proposal-score
        branch. ``classifier_info`` can carry an externally provided classifier
        and caption embeddings whose scores are concatenated to the class scores.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        scores = []

        if classifier_info[0] is not None:
            cls_scores = self.cls_score(x, classifier=classifier_info[0])
            scores.append(cls_scores)
        else:
            cls_scores = self.cls_score(x)
            scores.append(cls_scores)

        if classifier_info[2] is not None:
            cap_cls = classifier_info[2]
            if self.sync_caption_batch:
                caption_scores = self.cls_score(x, classifier=cap_cls[:, :-1])
            else:
                caption_scores = self.cls_score(x, classifier=cap_cls)
            scores.append(caption_scores)
        scores = torch.cat(scores, dim=1)  # B x C' or B x N or B x (C' + N)

        proposal_deltas = self.bbox_pred(x)
        if self.with_softmax_prop:
            prop_score = self.prop_score(x)
            return scores, proposal_deltas, prop_score
        else:
            return scores, proposal_deltas

    def _caption_loss(self, score, classifier_info, idx, B):
        """
        Split caption scores from class scores and compute a binary
        cross-entropy loss matching the image-level box to its own caption.
        """
        assert (classifier_info[2] is not None)
        assert self.add_image_box
        cls_and_cap_num = score.shape[1]
        cap_num = classifier_info[2].shape[0]
        score, caption_score = score.split(
            [cls_and_cap_num - cap_num, cap_num], dim=1)
        # n x (C + 1), n x B
        caption_score = caption_score[-1:]  # 1 x B; -1: image-level box
        caption_target = caption_score.new_zeros(
            caption_score.shape)  # 1 x B or 1 x MB, M: num machines
        if self.sync_caption_batch:
            # caption_target: 1 x MB
            rank = comm.get_rank()
            global_idx = B * rank + idx
            assert (classifier_info[2][global_idx, -1] - rank) ** 2 < 1e-8, \
                '{} {} {} {} {}'.format(
                    rank, global_idx,
                    classifier_info[2][global_idx, -1],
                    classifier_info[2].shape,
                    classifier_info[2][:, -1])
            caption_target[:, global_idx] = 1.
        else:
            assert caption_score.shape[1] == B
            caption_target[:, idx] = 1.
        caption_loss_img = F.binary_cross_entropy_with_logits(
            caption_score, caption_target, reduction='none')
        if self.sync_caption_batch:
            fg_mask = (caption_target > 0.5).float()
            assert (fg_mask.sum().item() - 1.) ** 2 < 1e-8, '{} {}'.format(
                fg_mask.shape, fg_mask)
            pos_loss = (caption_loss_img * fg_mask).sum()
            neg_loss = (caption_loss_img * (1. - fg_mask)).sum()
            caption_loss_img = pos_loss + self.neg_cap_weight * neg_loss
        else:
            caption_loss_img = caption_loss_img.sum()
        return score, caption_loss_img

    def _wsddn_loss(self, score, prop_score, label):
        """WSDDN-style loss: detection scores weighted by proposal scores."""
        assert prop_score is not None
        loss = 0
        final_score = score.sigmoid() * \
            F.softmax(prop_score, dim=0)  # B x (C + 1)
        img_score = torch.clamp(
            torch.sum(final_score, dim=0),
            min=1e-10, max=1 - 1e-10)  # (C + 1)
        target = img_score.new_zeros(img_score.shape)  # (C + 1)
        target[label] = 1.
        loss += F.binary_cross_entropy(img_score, target)
        ind = final_score[:, label].argmax()
        return loss, ind

    def _max_score_loss(self, score, label):
        """Apply the image label to the proposal with the highest score for that class."""
        loss = 0
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        ind = score[:, label].argmax().item()
        loss += F.binary_cross_entropy_with_logits(
            score[ind], target, reduction='sum')
        return loss, ind

    def _min_loss_loss(self, score, label):
        """Apply the image label to the proposal with the lowest per-box loss."""
        loss = 0
        target = score.new_zeros(score.shape)
        target[:, label] = 1.
        with torch.no_grad():
            x = F.binary_cross_entropy_with_logits(
                score, target, reduction='none').sum(dim=1)  # n
        ind = x.argmin().item()
        loss += F.binary_cross_entropy_with_logits(
            score[ind], target[0], reduction='sum')
        return loss, ind

    def _first_loss(self, score, label):
        """Apply the image label to the first proposal."""
        loss = 0
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        ind = 0
        loss += F.binary_cross_entropy_with_logits(
            score[ind], target, reduction='sum')
        return loss, ind

    def _image_loss(self, score, label):
        """Apply the image label to the whole-image box (the last proposal)."""
        assert self.add_image_box
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        ind = score.shape[0] - 1
        loss = F.binary_cross_entropy_with_logits(
            score[ind], target, reduction='sum')
        return loss, ind

    def _max_size_loss(self, score, label, p):
        """Apply the image label to the largest proposal box."""
        loss = 0
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        sizes = p.proposal_boxes.area()
        # Exclude the last (image-level) box when picking the largest proposal.
        ind = sizes[:-1].argmax().item() if len(sizes) > 1 else 0
        if self.softmax_weak_loss:
            loss += F.cross_entropy(
                score[ind:ind + 1],
                score.new_tensor(label, dtype=torch.long).view(1),
                reduction='sum')
        else:
            loss += F.binary_cross_entropy_with_logits(
                score[ind], target, reduction='sum')
        return loss, ind


def put_label_distribution(storage, hist_name, hist_counts, num_classes):
    """
    Write a per-class label-count histogram to the event storage so that it
    can be exported to TensorBoard.
    """
    ht_min, ht_max = 0, num_classes
    hist_edges = torch.linspace(
        start=ht_min, end=ht_max, steps=num_classes + 1, dtype=torch.float32)
    hist_params = dict(
        tag=hist_name,
        min=ht_min,
        max=ht_max,
        num=float(hist_counts.sum()),
        sum=float((hist_counts * torch.arange(len(hist_counts))).sum()),
        sum_squares=float(((hist_counts * torch.arange(len(hist_counts))) ** 2).sum()),
        bucket_limits=hist_edges[1:].tolist(),
        bucket_counts=hist_counts.tolist(),
        global_step=storage._iter,
    )
    storage._histograms.append(hist_params)
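
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library API): a minimal example
# of logging a label histogram with `put_label_distribution`. The class count,
# tag name, and counts below are made-up values, and a temporary EventStorage
# is opened in place of a real training loop. Because this module uses
# relative imports, the sketch assumes it is executed as part of its package
# (e.g. via `python -m <package>.<module>`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from detectron2.utils.events import EventStorage

    num_classes = 5
    # Hypothetical per-class label counts gathered over a batch.
    hist_counts = torch.tensor([3, 0, 7, 1, 4])
    with EventStorage(start_iter=0) as storage:
        put_label_distribution(storage, "detic/label_distribution", hist_counts, num_classes)
        # The histogram is buffered in storage._histograms and would be flushed
        # by a TensorBoard writer in a real training run.
        print(storage._histograms[0]["bucket_counts"])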