# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
import json

from detectron2.utils.events import get_event_storage
from detectron2.config import configurable
from detectron2.structures import ImageList, Instances, Boxes
import detectron2.utils.comm as comm

from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.utils.visualizer import Visualizer, _create_text_labels
from detectron2.data.detection_utils import convert_image_to_rgb

from torch.cuda.amp import autocast
from ..text.text_encoder import build_text_encoder
from ..utils import load_class_freq, get_fed_loss_inds


@META_ARCH_REGISTRY.register()
class CustomRCNN(GeneralizedRCNN):
    '''
    Add image labels
    '''
    @configurable
    def __init__(
        self,
        with_image_labels=False,
        dataset_loss_weight=[],
        fp16=False,
        sync_caption_batch=False,
        roi_head_name='',
        cap_batch_ratio=4,
        with_caption=False,
        dynamic_classifier=False,
        **kwargs):
        """
        Args:
            with_image_labels: train with image-level labels in addition to
                boxes (cfg.WITH_IMAGE_LABELS).
            dataset_loss_weight: per-dataset loss multipliers; an empty list
                disables reweighting.
            fp16: run the backbone under autocast with half-precision input.
            sync_caption_batch: all-gather caption features across GPUs.
            roi_head_name: cfg.MODEL.ROI_HEADS.NAME; decides the ROI-head
                call signature.
            cap_batch_ratio: ratio of caption batch size to box batch size,
                used when padding for the caption-feature sync.
            with_caption: train with image-caption pairs using a frozen
                text encoder.
            dynamic_classifier: sample a per-batch subset of classifier
                weights (federated loss).
        """
        self.with_image_labels = with_image_labels
        self.dataset_loss_weight = dataset_loss_weight
        self.fp16 = fp16
        self.with_caption = with_caption
        self.sync_caption_batch = sync_caption_batch
        self.roi_head_name = roi_head_name
        self.cap_batch_ratio = cap_batch_ratio
        self.dynamic_classifier = dynamic_classifier
        self.return_proposal = False
        if self.dynamic_classifier:
            self.freq_weight = kwargs.pop('freq_weight')
            self.num_classes = kwargs.pop('num_classes')
            self.num_sample_cats = kwargs.pop('num_sample_cats')
        super().__init__(**kwargs)
        assert self.proposal_generator is not None
        if self.with_caption:
            assert not self.dynamic_classifier
            self.text_encoder = build_text_encoder(pretrain=True)
            # The text encoder is kept frozen; only detector weights train.
            for v in self.text_encoder.parameters():
                v.requires_grad = False

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        ret.update({
            'with_image_labels': cfg.WITH_IMAGE_LABELS,
            'dataset_loss_weight': cfg.MODEL.DATASET_LOSS_WEIGHT,
            'fp16': cfg.FP16,
            'with_caption': cfg.MODEL.WITH_CAPTION,
            'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
            'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
            'roi_head_name': cfg.MODEL.ROI_HEADS.NAME,
            'cap_batch_ratio': cfg.MODEL.CAP_BATCH_RATIO,
        })
        if ret['dynamic_classifier']:
            ret['freq_weight'] = load_class_freq(
                cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
                cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT)
            ret['num_classes'] = cfg.MODEL.ROI_HEADS.NUM_CLASSES
            ret['num_sample_cats'] = cfg.MODEL.NUM_SAMPLE_CATS
        return ret

    def inference(
        self,
        batched_inputs: Tuple[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        assert not self.training
        assert detected_instances is None

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(images, features, proposals)
        if do_postprocess:
            assert not torch.jit.is_scripting(), \
                "Scripting is not supported for postprocess."
            return CustomRCNN._postprocess(
                results, batched_inputs, images.image_sizes)
        else:
            return results
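
    # For reference, a sketch of the per-image input dict that `forward`
    # consumes when WITH_IMAGE_LABELS is on. The field values below are
    # hypothetical; the real dicts come from the dataset mapper:
    #   {
    #       'image': ...,                 # CHW image tensor
    #       'instances': ...,             # gt Instances for this image
    #       'ann_type': 'box',            # 'box' / 'prop' / 'proptag' / a
    #                                     # caption-style type
    #       'pos_category_ids': [3, 17],  # image-level positive classes
    #       'captions': ['a dog ...'],    # only for caption-style data
    #       'dataset_source': 0,          # index into dataset_loss_weight
    #   }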

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Add ann_type. Ignore the proposal loss when training with
        image-level labels.
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)

        ann_type = 'box'
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        if self.with_image_labels:
            for inst, x in zip(gt_instances, batched_inputs):
                inst._ann_type = x['ann_type']
                inst._pos_category_ids = x['pos_category_ids']
            # All images in a batch must share the same annotation type.
            ann_types = [x['ann_type'] for x in batched_inputs]
            assert len(set(ann_types)) == 1
            ann_type = ann_types[0]
            if ann_type in ['prop', 'proptag']:
                # Proposal-style data is class-agnostic; zero out the classes.
                for t in gt_instances:
                    t.gt_classes *= 0

        if self.fp16:  # TODO (zhouxy): improve
            with autocast():
                features = self.backbone(images.tensor.half())
            features = {k: v.float() for k, v in features.items()}
        else:
            features = self.backbone(images.tensor)

        cls_features, cls_inds, caption_features = None, None, None
        if self.with_caption and 'caption' in ann_type:
            # Sample one caption per image and encode it with the frozen
            # text encoder.
            inds = [torch.randint(len(x['captions']), (1,))[0].item()
                    for x in batched_inputs]
            caps = [x['captions'][ind] for ind, x in zip(inds, batched_inputs)]
            caption_features = self.text_encoder(caps).float()
        if self.sync_caption_batch:
            caption_features = self._sync_caption_features(
                caption_features, ann_type, len(batched_inputs))

        if self.dynamic_classifier and ann_type != 'caption':
            cls_inds = self._sample_cls_inds(gt_instances, ann_type)  # inds, inv_inds
            ind_with_bg = cls_inds[0].tolist() + [-1]
            cls_features = self.roi_heads.box_predictor[
                0].cls_score.zs_weight[:, ind_with_bg].permute(1, 0).contiguous()

        classifier_info = cls_features, cls_inds, caption_features
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)

        if self.roi_head_name in ['StandardROIHeads', 'CascadeROIHeads']:
            proposals, detector_losses = self.roi_heads(
                images, features, proposals, gt_instances)
        else:
            proposals, detector_losses = self.roi_heads(
                images, features, proposals, gt_instances,
                ann_type=ann_type, classifier_info=classifier_info)

        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        if self.with_image_labels:
            if ann_type in ['box', 'prop', 'proptag']:
                losses.update(proposal_losses)
            else:  # ignore proposal loss
                losses.update({k: v * 0 for k, v in proposal_losses.items()})
        else:
            losses.update(proposal_losses)

        if len(self.dataset_loss_weight) > 0:
            dataset_sources = [x['dataset_source'] for x in batched_inputs]
            assert len(set(dataset_sources)) == 1
            dataset_source = dataset_sources[0]
            for k in losses:
                losses[k] *= self.dataset_loss_weight[dataset_source]

        if self.return_proposal:
            return proposals, losses
        else:
            return losses

    def _sync_caption_features(self, caption_features, ann_type, BS):
        has_caption_feature = (caption_features is not None)
        BS = (BS * self.cap_batch_ratio) if (ann_type == 'box') else BS
        rank = torch.full(
            (BS, 1), comm.get_rank(), dtype=torch.float32,
            device=self.device)
        if not has_caption_feature:
            caption_features = rank.new_zeros((BS, 512))
        caption_features = torch.cat([caption_features, rank], dim=1)
        global_caption_features = comm.all_gather(caption_features)
        caption_features = torch.cat(
            [x.to(self.device) for x in global_caption_features], dim=0) \
                if has_caption_feature else None  # (NB) x (D + 1)
        return caption_features
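
    # Shape bookkeeping for the sync above: each rank appends its rank id as
    # a 513th column and contributes a (BS, 513) block to all_gather. Ranks
    # holding no caption features (e.g. box batches) contribute zeros, with
    # BS scaled by cap_batch_ratio so their blocks match the larger caption
    # batches. Ranks without captions then return None; the rest keep the
    # gathered (num_ranks * BS) x 513 tensor.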

    def _sample_cls_inds(self, gt_instances, ann_type='box'):
        if ann_type == 'box':
            gt_classes = torch.cat(
                [x.gt_classes for x in gt_instances])
            C = len(self.freq_weight)
            freq_weight = self.freq_weight
        else:
            gt_classes = torch.cat(
                [torch.tensor(
                    x._pos_category_ids,
                    dtype=torch.long, device=x.gt_classes.device)
                    for x in gt_instances])
            C = self.num_classes
            freq_weight = None
        assert gt_classes.max() < C, '{} {}'.format(gt_classes.max(), C)
        inds = get_fed_loss_inds(
            gt_classes, self.num_sample_cats, C,
            weight=freq_weight)
        cls_id_map = gt_classes.new_full(
            (self.num_classes + 1,), len(inds))
        cls_id_map[inds] = torch.arange(len(inds), device=cls_id_map.device)
        return inds, cls_id_map
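

# A minimal, self-contained sketch (hypothetical values) of the inverse index
# map built at the end of `_sample_cls_inds`: `cls_id_map` sends each sampled
# class id to its position within `inds`, and every unsampled class (plus the
# background slot at index num_classes) to len(inds). In the real code `inds`
# comes from `get_fed_loss_inds`.
if __name__ == '__main__':
    num_classes = 10
    inds = torch.tensor([2, 5, 7])  # hypothetical sampled class ids
    cls_id_map = torch.full((num_classes + 1,), len(inds), dtype=torch.long)
    cls_id_map[inds] = torch.arange(len(inds))
    print(cls_id_map)  # -> tensor([3, 3, 0, 3, 3, 1, 3, 2, 3, 3, 3])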