# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import json
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from detectron2.utils.events import get_event_storage
from detectron2.config import configurable
from detectron2.structures import ImageList, Instances, Boxes
import detectron2.utils.comm as comm
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.utils.visualizer import Visualizer, _create_text_labels
from detectron2.data.detection_utils import convert_image_to_rgb

from ..text.text_encoder import build_text_encoder
from ..utils import load_class_freq, get_fed_loss_inds


@META_ARCH_REGISTRY.register()
class CustomRCNN(GeneralizedRCNN):
    '''
    GeneralizedRCNN variant that supports training with image-level labels
    in addition to boxes, plus optional caption supervision and a
    dynamically sampled (federated) classifier. Instantiated through
    detectron2's build_model when cfg.MODEL.META_ARCHITECTURE is
    "CustomRCNN".
    '''
    @configurable
    def __init__(
        self,
        with_image_labels=False,
        dataset_loss_weight=[],
        fp16=False,
        sync_caption_batch=False,
        roi_head_name='',
        cap_batch_ratio=4,
        with_caption=False,
        dynamic_classifier=False,
        **kwargs):
""" | |
""" | |
        self.with_image_labels = with_image_labels
        self.dataset_loss_weight = dataset_loss_weight
        self.fp16 = fp16
        self.with_caption = with_caption
        self.sync_caption_batch = sync_caption_batch
        self.roi_head_name = roi_head_name
        self.cap_batch_ratio = cap_batch_ratio
        self.dynamic_classifier = dynamic_classifier
        self.return_proposal = False
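        # The federated-loss parameters below are only added to the kwargs
        # by from_config when the dynamic classifier is enabled, so they
        # must be popped conditionally before calling the parent class.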
        if self.dynamic_classifier:
            self.freq_weight = kwargs.pop('freq_weight')
            self.num_classes = kwargs.pop('num_classes')
            self.num_sample_cats = kwargs.pop('num_sample_cats')
        super().__init__(**kwargs)
        assert self.proposal_generator is not None
        if self.with_caption:
            assert not self.dynamic_classifier
            self.text_encoder = build_text_encoder(pretrain=True)
            for v in self.text_encoder.parameters():
                v.requires_grad = False  # keep the pretrained text encoder frozen

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        ret.update({
            'with_image_labels': cfg.WITH_IMAGE_LABELS,
            'dataset_loss_weight': cfg.MODEL.DATASET_LOSS_WEIGHT,
            'fp16': cfg.FP16,
            'with_caption': cfg.MODEL.WITH_CAPTION,
            'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
            'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
            'roi_head_name': cfg.MODEL.ROI_HEADS.NAME,
            'cap_batch_ratio': cfg.MODEL.CAP_BATCH_RATIO,
        })
        if ret['dynamic_classifier']:
            ret['freq_weight'] = load_class_freq(
                cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
                cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT)
            ret['num_classes'] = cfg.MODEL.ROI_HEADS.NUM_CLASSES
            ret['num_sample_cats'] = cfg.MODEL.NUM_SAMPLE_CATS
        return ret

    def inference(
        self,
        batched_inputs: Tuple[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
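        """
        Same flow as GeneralizedRCNN.inference, except that passing
        pre-computed detected_instances is not supported: proposals are
        always generated and run through the ROI heads.
        """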
        assert not self.training
        assert detected_instances is None
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(images, features, proposals)
        if do_postprocess:
            assert not torch.jit.is_scripting(), \
                "Scripting is not supported for postprocess."
            return CustomRCNN._postprocess(
                results, batched_inputs, images.image_sizes)
        else:
            return results

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Same as GeneralizedRCNN.forward, with two changes: an ann_type
        shared by the whole batch ('box', 'prop', 'proptag', caption-style
        types, ...) is read from the inputs, and the proposal loss is
        ignored when training on data without box annotations.
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)

        ann_type = 'box'
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        if self.with_image_labels:
            for inst, x in zip(gt_instances, batched_inputs):
                inst._ann_type = x['ann_type']
                inst._pos_category_ids = x['pos_category_ids']
            ann_types = [x['ann_type'] for x in batched_inputs]
            # a batch must come from a single annotation type
            assert len(set(ann_types)) == 1
            ann_type = ann_types[0]
        if ann_type in ['prop', 'proptag']:
            for t in gt_instances:
                # proposal-style annotations carry no class labels;
                # collapse every class id to 0
                t.gt_classes *= 0
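        # Mixed-precision backbone pass: compute features in fp16 under
        # autocast, then cast them back to fp32 so the downstream heads
        # run in full precision.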
        if self.fp16:  # TODO (zhouxy): improve
            with autocast():
                features = self.backbone(images.tensor.half())
            features = {k: v.float() for k, v in features.items()}
        else:
            features = self.backbone(images.tensor)
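        # Caption supervision: embed one randomly chosen caption per image
        # with the frozen text encoder, optionally all-gathering the
        # features across GPUs so every worker sees the full caption batch.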
        cls_features, cls_inds, caption_features = None, None, None
        if self.with_caption and 'caption' in ann_type:
            inds = [torch.randint(len(x['captions']), (1,))[0].item()
                    for x in batched_inputs]
            caps = [x['captions'][ind] for ind, x in zip(inds, batched_inputs)]
            caption_features = self.text_encoder(caps).float()
        if self.sync_caption_batch:
            caption_features = self._sync_caption_features(
                caption_features, ann_type, len(batched_inputs))
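        # Dynamic (federated) classifier: sample a subset of classes for
        # this iteration and slice the zero-shot classifier weight down to
        # the sampled classes plus background (index -1).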
        if self.dynamic_classifier and ann_type != 'caption':
            cls_inds = self._sample_cls_inds(gt_instances, ann_type)  # inds, inv_inds
            ind_with_bg = cls_inds[0].tolist() + [-1]
            cls_features = self.roi_heads.box_predictor[
                0].cls_score.zs_weight[:, ind_with_bg].permute(1, 0).contiguous()

        classifier_info = cls_features, cls_inds, caption_features
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)
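        # Stock detectron2 ROI heads do not accept the extra keyword
        # arguments, so ann_type/classifier_info are only passed to
        # custom head classes.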
        if self.roi_head_name in ['StandardROIHeads', 'CascadeROIHeads']:
            proposals, detector_losses = self.roi_heads(
                images, features, proposals, gt_instances)
        else:
            proposals, detector_losses = self.roi_heads(
                images, features, proposals, gt_instances,
                ann_type=ann_type, classifier_info=classifier_info)

        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)
        losses = {}
        losses.update(detector_losses)
        if self.with_image_labels:
            if ann_type in ['box', 'prop', 'proptag']:
                losses.update(proposal_losses)
            else:  # ignore proposal loss for non-bbox data
                losses.update({k: v * 0 for k, v in proposal_losses.items()})
        else:
            losses.update(proposal_losses)
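        # Optional per-dataset loss re-weighting; every sample in a batch
        # must come from the same dataset source.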
        if len(self.dataset_loss_weight) > 0:
            dataset_sources = [x['dataset_source'] for x in batched_inputs]
            assert len(set(dataset_sources)) == 1
            dataset_source = dataset_sources[0]
            for k in losses:
                losses[k] *= self.dataset_loss_weight[dataset_source]

        if self.return_proposal:
            return proposals, losses
        else:
            return losses

    def _sync_caption_features(self, caption_features, ann_type, BS):
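        """
        All-gather caption features across workers. Workers on box-only
        batches contribute zero vectors (with their batch size scaled by
        cap_batch_ratio); a rank column is appended to each row so the
        consumer can tell which GPU produced it. Returns the gathered
        (sum of batch sizes) x (D + 1) tensor, or None if this worker had
        no caption features.
        """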
        has_caption_feature = (caption_features is not None)
        BS = (BS * self.cap_batch_ratio) if (ann_type == 'box') else BS
        rank = torch.full(
            (BS, 1), comm.get_rank(), dtype=torch.float32,
            device=self.device)
        if not has_caption_feature:
            caption_features = rank.new_zeros((BS, 512))
        caption_features = torch.cat([caption_features, rank], dim=1)
        global_caption_features = comm.all_gather(caption_features)
        caption_features = torch.cat(
            [x.to(self.device) for x in global_caption_features], dim=0) \
            if has_caption_feature else None  # (NB) x (D + 1)
        return caption_features

    def _sample_cls_inds(self, gt_instances, ann_type='box'):
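        """
        Sample the classes used by the dynamic classifier this iteration
        via get_fed_loss_inds. For box annotations the sampling is biased
        by class frequency; for image-label annotations the positive
        category ids are used with uniform weight. Returns the sampled
        class indices and an inverse map from original class id to its
        position among the sampled classes (unsampled ids map to
        len(inds), i.e. background).
        """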
        if ann_type == 'box':
            gt_classes = torch.cat(
                [x.gt_classes for x in gt_instances])
            C = len(self.freq_weight)
            freq_weight = self.freq_weight
        else:
            gt_classes = torch.cat(
                [torch.tensor(
                    x._pos_category_ids,
                    dtype=torch.long, device=x.gt_classes.device)
                    for x in gt_instances])
            C = self.num_classes
            freq_weight = None
        assert gt_classes.max() < C, '{} {}'.format(gt_classes.max(), C)
        inds = get_fed_loss_inds(
            gt_classes, self.num_sample_cats, C,
            weight=freq_weight)
        cls_id_map = gt_classes.new_full(
            (self.num_classes + 1,), len(inds))
        cls_id_map[inds] = torch.arange(len(inds), device=cls_id_map.device)
        return inds, cls_id_map