# Detic / detic/modeling/roi_heads/res5_roi_heads.py
# (file-viewer residue: AK391 · files · 159f437 · raw · history · blame · 7.33 kB)
# Copyright (c) Facebook, Inc. and its affiliates.
import inspect
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from detectron2.config import configurable
from detectron2.layers import ShapeSpec, nonzero_tuple
from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
from detectron2.modeling.roi_heads.box_head import build_box_head
from .detic_fast_rcnn import DeticFastRCNNOutputLayers
from ..debug import debug_second_stage
from torch.cuda.amp import autocast
@ROI_HEADS_REGISTRY.register()
class CustomRes5ROIHeads(Res5ROIHeads):
    """Res5 ROI heads extended for Detic.

    Compared with detectron2's ``Res5ROIHeads`` this head:

    * uses ``DeticFastRCNNOutputLayers`` as the box predictor;
    * for training batches with ``ann_type != 'box'`` keeps the first
      ``WS_NUM_PROPS`` proposals per image (optionally plus an image-level
      box) and computes ``image_label_losses`` instead of the box losses;
    * can attach each proposal's pooled ROI feature as ``p.feat``;
    * can call ``debug_second_stage`` for visualization when ``SAVE_DEBUG``
      is set.
    """

    @configurable
    def __init__(self, **kwargs):
        """Build the head.

        ``kwargs`` are the arguments produced by :meth:`from_config`. The
        extra ``'cfg'`` entry is popped here; the remaining keyword arguments
        are forwarded to ``Res5ROIHeads.__init__``.
        """
        cfg = kwargs.pop('cfg')
        super().__init__(**kwargs)
        # res5 is 8x the channel width of res2.
        stage_channel_factor = 2 ** 3
        out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor

        self.with_image_labels = cfg.WITH_IMAGE_LABELS
        self.ws_num_props = cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS
        self.add_image_box = cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX
        self.add_feature_to_prop = cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP
        self.image_box_size = cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE
        # forward() average-pools ROI features to 1x1 before the predictor.
        self.box_predictor = DeticFastRCNNOutputLayers(
            cfg, ShapeSpec(channels=out_channels, height=1, width=1)
        )

        self.save_debug = cfg.SAVE_DEBUG
        self.save_debug_path = cfg.SAVE_DEBUG_PATH
        if self.save_debug:
            # Only the debug/visualization path reads these attributes.
            self.debug_show_name = cfg.DEBUG_SHOW_NAME
            self.vis_thresh = cfg.VIS_THRESH
            device = torch.device(cfg.MODEL.DEVICE)
            self.pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(
                device).view(3, 1, 1)
            self.pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(
                device).view(3, 1, 1)
            self.bgr = (cfg.INPUT.FORMAT == 'BGR')

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Return the parent's config dict plus the raw ``cfg`` under ``'cfg'``.

        ``__init__`` reads options that ``Res5ROIHeads.from_config`` does not
        extract, so the whole config object is passed through.
        """
        ret = super().from_config(cfg, input_shape)
        ret['cfg'] = cfg
        return ret

    def forward(self, images, features, proposals, targets=None,
                ann_type='box', classifier_info=(None, None, None)):
        """Run the ROI head; supports debug visualization and image labels.

        Args:
            images: batch images. Only used for visualization when
                ``save_debug`` is enabled; otherwise deleted immediately.
            features: mapping of backbone feature maps; the entries named in
                ``self.in_features`` are pooled.
            proposals: list of per-image ``Instances`` with ``proposal_boxes``.
            targets: per-image ground truth used in training. When
                ``ann_type != 'box'`` each target must provide
                ``_pos_category_ids``.
            ann_type: ``'box'`` selects the box-supervised path; any other
                value selects the image-label path.
            classifier_info: forwarded unchanged to the box predictor; it is
                shared across the batch.

        Returns:
            ``(proposals, losses)`` in training mode,
            ``(pred_instances, {})`` in inference mode.
        """
        if not self.save_debug:
            del images

        if self.training:
            if ann_type == 'box':
                proposals = self.label_and_sample_proposals(
                    proposals, targets)
            else:
                proposals = self.get_top_proposals(proposals)

        proposal_boxes = [x.proposal_boxes for x in proposals]
        box_features = self._shared_roi_transform(
            [features[f] for f in self.in_features], proposal_boxes
        )
        # Average over the spatial dims once; the result feeds both the
        # predictor and (optionally) the per-proposal `feat` field.
        pooled_features = box_features.mean(dim=[2, 3])
        predictions = self.box_predictor(
            pooled_features, classifier_info=classifier_info)

        if self.add_feature_to_prop:
            feats_per_image = pooled_features.split(
                [len(p) for p in proposals], dim=0)
            for feat, p in zip(feats_per_image, proposals):
                p.feat = feat

        if self.training:
            del features
            if ann_type != 'box':
                image_labels = [x._pos_category_ids for x in targets]
                losses = self.box_predictor.image_label_losses(
                    predictions, proposals, image_labels,
                    classifier_info=classifier_info,
                    ann_type=ann_type)
            else:
                losses = self.box_predictor.losses(
                    (predictions[0], predictions[1]), proposals)
                if self.with_image_labels:
                    # Adds a zero 'image_loss' to box batches; presumably so
                    # the loss keys match image-label batches — confirm.
                    assert 'image_loss' not in losses
                    losses['image_loss'] = predictions[0].new_zeros([1])[0]
            if self.save_debug:
                if ann_type != 'box':
                    image_labels = [x._pos_category_ids for x in targets]
                else:
                    image_labels = [[] for _ in targets]
                debug_second_stage(
                    self._denormalize_images(images),
                    targets, proposals=proposals,
                    save_debug=self.save_debug,
                    debug_show_name=self.debug_show_name,
                    vis_thresh=self.vis_thresh,
                    image_labels=image_labels,
                    save_debug_path=self.save_debug_path,
                    bgr=self.bgr)
            return proposals, losses

        pred_instances, _ = self.box_predictor.inference(predictions, proposals)
        pred_instances = self.forward_with_given_boxes(features, pred_instances)
        if self.save_debug:
            debug_second_stage(
                self._denormalize_images(images),
                pred_instances, proposals=proposals,
                save_debug=self.save_debug,
                debug_show_name=self.debug_show_name,
                vis_thresh=self.vis_thresh,
                save_debug_path=self.save_debug_path,
                bgr=self.bgr)
        return pred_instances, {}

    def _denormalize_images(self, images):
        """Return ``x * pixel_std + pixel_mean`` for each image (on a clone).

        Only valid when ``save_debug`` is enabled, because ``pixel_mean`` and
        ``pixel_std`` are set only in that case.
        """
        return [x.clone() * self.pixel_std + self.pixel_mean for x in images]

    def get_top_proposals(self, proposals):
        """Select proposals for image-label (weakly supervised) training.

        Clips each image's proposal boxes to its ``image_size``, keeps the
        first ``ws_num_props`` proposals per image, detaches the box tensor
        from autograd and, if ``add_image_box`` is set, appends an image box.

        Args:
            proposals: list of per-image ``Instances`` with
                ``proposal_boxes`` (presumably sorted by objectness, so the
                slice keeps the highest-scoring ones — confirm with the RPN).

        Returns:
            A new list of per-image ``Instances``.
        """
        for p in proposals:
            # Boxes.clip modifies the boxes in place.
            p.proposal_boxes.clip(p.image_size)
        proposals = [p[:self.ws_num_props] for p in proposals]
        for i, p in enumerate(proposals):
            p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach()
            if self.add_image_box:
                proposals[i] = self._add_image_box(p)
        return proposals

    def _add_image_box(self, p, use_score=False):
        """Return ``p`` with one extra image-level box appended.

        The added box spans the whole image, or, when
        ``image_box_size < 1``, a centered box whose width and height are
        ``image_box_size`` times the image's. Its ``objectness_logits`` is 1.
        With ``use_score`` it also gets ``scores`` = 1 and
        ``pred_classes`` = 0.

        Args:
            p: ``Instances`` with ``proposal_boxes``, ``objectness_logits``
                and ``image_size`` as ``(h, w)``.
            use_score: also set ``scores`` and ``pred_classes`` on the box.

        Returns:
            ``Instances.cat([p, image_box])``.
        """
        n = 1
        h, w = p.image_size
        image_box = Instances(p.image_size)
        if self.image_box_size < 1.0:
            margin = (1. - self.image_box_size) / 2.
            coords = [w * margin, h * margin,
                      w * (1. - margin), h * (1. - margin)]
        else:
            coords = [0, 0, w, h]
        image_box.proposal_boxes = Boxes(
            p.proposal_boxes.tensor.new_tensor(coords).view(n, 4))
        if use_score:
            image_box.scores = p.objectness_logits.new_ones(n)
            image_box.pred_classes = p.objectness_logits.new_zeros(
                n, dtype=torch.long)
        image_box.objectness_logits = p.objectness_logits.new_ones(n)
        return Instances.cat([p, image_box])