Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
import inspect | |
import logging | |
import numpy as np | |
from typing import Dict, List, Optional, Tuple | |
import torch | |
from torch import nn | |
from detectron2.config import configurable | |
from detectron2.layers import ShapeSpec, nonzero_tuple | |
from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou | |
from detectron2.utils.events import get_event_storage | |
from detectron2.utils.registry import Registry | |
from detectron2.modeling.box_regression import Box2BoxTransform | |
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference | |
from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads | |
from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient | |
from detectron2.modeling.roi_heads.box_head import build_box_head | |
from .detic_fast_rcnn import DeticFastRCNNOutputLayers | |
from ..debug import debug_second_stage | |
from torch.cuda.amp import autocast | |
# Registered so the head can be looked up by class name in ROI_HEADS_REGISTRY.
@ROI_HEADS_REGISTRY.register()
class CustomRes5ROIHeads(Res5ROIHeads):
    """Res5 ROI heads using a Detic box predictor, with image-label training.

    Differences from ``Res5ROIHeads`` that are visible in this class:

    * ``box_predictor`` is replaced by ``DeticFastRCNNOutputLayers``.
    * ``forward`` accepts ``ann_type``; for anything other than ``'box'`` the
      first ``ws_num_props`` proposals are used and the loss comes from
      ``box_predictor.image_label_losses``.
    * When ``cfg.SAVE_DEBUG`` is set, inputs and outputs are passed to
      ``debug_second_stage`` for visualisation.
    """

    @configurable
    def __init__(self, **kwargs):
        """Build the parent heads, then install the Detic predictor.

        Args:
            **kwargs: constructor arguments as returned by ``from_config``.
                The ``cfg`` entry is removed here; everything else is
                forwarded to ``Res5ROIHeads.__init__``.
        """
        cfg = kwargs.pop('cfg')
        super().__init__(**kwargs)

        # Predictor input width: RES2_OUT_CHANNELS * 8
        # (presumably the res5 stage width -- confirm against the backbone).
        stage_channel_factor = 2 ** 3
        out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor

        self.with_image_labels = cfg.WITH_IMAGE_LABELS
        self.ws_num_props = cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS
        self.add_image_box = cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX
        self.add_feature_to_prop = cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP
        self.image_box_size = cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE
        self.box_predictor = DeticFastRCNNOutputLayers(
            cfg, ShapeSpec(channels=out_channels, height=1, width=1)
        )

        self.save_debug = cfg.SAVE_DEBUG
        self.save_debug_path = cfg.SAVE_DEBUG_PATH
        if self.save_debug:
            # These attributes exist only in debug mode; they are read
            # exclusively by the debug paths in ``forward``.
            self.debug_show_name = cfg.DEBUG_SHOW_NAME
            self.vis_thresh = cfg.VIS_THRESH
            device = torch.device(cfg.MODEL.DEVICE)
            self.pixel_mean = torch.Tensor(
                cfg.MODEL.PIXEL_MEAN).to(device).view(3, 1, 1)
            self.pixel_std = torch.Tensor(
                cfg.MODEL.PIXEL_STD).to(device).view(3, 1, 1)
            self.bgr = (cfg.INPUT.FORMAT == 'BGR')

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Return the parent's constructor arguments plus ``cfg`` itself.

        ``__init__`` pops the ``'cfg'`` entry to read the extra options.
        """
        ret = super().from_config(cfg, input_shape)
        ret['cfg'] = cfg
        return ret

    def forward(self, images, features, proposals, targets=None,
                ann_type='box', classifier_info=(None, None, None)):
        """Run the box head with optional image-label training and debug output.

        Args:
            images: batch of input images; only read when ``save_debug`` is on.
            features: mapping indexed by the names in ``self.in_features``.
            proposals: per-image proposals exposing ``proposal_boxes``.
            targets: per-image ground truth. Required in training; for
                ``ann_type != 'box'`` each entry must expose
                ``_pos_category_ids``.
            ann_type: ``'box'`` selects box-supervised training; any other
                value selects the image-label loss.
            classifier_info: forwarded unchanged to the box predictor; shared
                across the batch.

        Returns:
            ``(proposals, losses)`` in training mode, otherwise
            ``(pred_instances, {})``.
        """
        if not self.save_debug:
            # Images are only needed for visualisation.
            del images

        box_supervised = ann_type == 'box'

        if self.training:
            if box_supervised:
                proposals = self.label_and_sample_proposals(
                    proposals, targets)
            else:
                proposals = self.get_top_proposals(proposals)

        proposal_boxes = [x.proposal_boxes for x in proposals]
        box_features = self._shared_roi_transform(
            [features[f] for f in self.in_features], proposal_boxes
        )
        # Pool over the two trailing dims once; the result feeds the predictor
        # and, optionally, is attached to the proposals.
        pooled_features = box_features.mean(dim=[2, 3])
        predictions = self.box_predictor(
            pooled_features, classifier_info=classifier_info)

        if self.add_feature_to_prop:
            feats_per_image = pooled_features.split(
                [len(p) for p in proposals], dim=0)
            for feat, p in zip(feats_per_image, proposals):
                p.feat = feat

        if self.training:
            del features
            if not box_supervised:
                image_labels = [x._pos_category_ids for x in targets]
                losses = self.box_predictor.image_label_losses(
                    predictions, proposals, image_labels,
                    classifier_info=classifier_info,
                    ann_type=ann_type)
            else:
                losses = self.box_predictor.losses(
                    (predictions[0], predictions[1]), proposals)
                if self.with_image_labels:
                    # Zero placeholder so 'image_loss' is present for box
                    # batches too (presumably to keep the loss keys uniform
                    # across annotation types -- confirm with the trainer).
                    assert 'image_loss' not in losses
                    losses['image_loss'] = predictions[0].new_zeros([1])[0]

            if self.save_debug:
                if box_supervised:
                    # No image-level labels to show for box-annotated data.
                    image_labels = [[] for _ in targets]
                debug_second_stage(
                    self._denormalize_images(images),
                    targets, proposals=proposals,
                    save_debug=self.save_debug,
                    debug_show_name=self.debug_show_name,
                    vis_thresh=self.vis_thresh,
                    image_labels=image_labels,
                    save_debug_path=self.save_debug_path,
                    bgr=self.bgr)
            return proposals, losses

        pred_instances, _ = self.box_predictor.inference(predictions, proposals)
        pred_instances = self.forward_with_given_boxes(features, pred_instances)
        if self.save_debug:
            debug_second_stage(
                self._denormalize_images(images),
                pred_instances, proposals=proposals,
                save_debug=self.save_debug,
                debug_show_name=self.debug_show_name,
                vis_thresh=self.vis_thresh,
                save_debug_path=self.save_debug_path,
                bgr=self.bgr)
        return pred_instances, {}

    def _denormalize_images(self, images):
        """Return copies of ``images`` with pixel mean/std normalisation undone.

        Only valid when ``save_debug`` is set, because ``pixel_mean`` and
        ``pixel_std`` are created only in that mode.
        """
        return [x.clone() * self.pixel_std + self.pixel_mean for x in images]

    def get_top_proposals(self, proposals):
        """Prepare proposals for image-label training.

        The boxes of every input entry are clipped to its image size in place.
        Each entry is then cut to its first ``ws_num_props`` proposals, its box
        tensor is detached, and, if ``add_image_box`` is set, an image-level
        box is appended.

        Args:
            proposals: per-image proposals exposing ``proposal_boxes`` and
                ``image_size``.

        Returns:
            A new list with one processed entry per image.
        """
        for p in proposals:
            p.proposal_boxes.clip(p.image_size)
        proposals = [p[:self.ws_num_props] for p in proposals]
        for i, p in enumerate(proposals):
            p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach()
            if self.add_image_box:
                proposals[i] = self._add_image_box(p)
        return proposals

    def _add_image_box(self, p, use_score=False):
        """Append one image-level box to the proposals ``p``.

        When ``image_box_size`` is below 1.0 the box is centred and covers
        that fraction of the width and height; otherwise it is
        ``[0, 0, w, h]``.

        Args:
            p: proposals of one image; must expose ``image_size``,
                ``proposal_boxes`` and ``objectness_logits``.
            use_score: also give the new box ``scores`` (ones) and
                ``pred_classes`` (zeros, ``torch.long``).

        Returns:
            ``Instances.cat([p, image_box])``.
        """
        image_box = Instances(p.image_size)
        n = 1
        h, w = p.image_size
        if self.image_box_size < 1.0:
            f = self.image_box_size
            coords = [w * (1. - f) / 2.,
                      h * (1. - f) / 2.,
                      w * (1. - (1. - f) / 2.),
                      h * (1. - (1. - f) / 2.)]
        else:
            coords = [0, 0, w, h]
        image_box.proposal_boxes = Boxes(
            p.proposal_boxes.tensor.new_tensor(coords).view(n, 4))

        if use_score:
            image_box.scores = p.objectness_logits.new_ones(n)
            image_box.pred_classes = \
                p.objectness_logits.new_zeros(n, dtype=torch.long)
        # The original set objectness_logits to ones in both branches.
        image_box.objectness_logits = p.objectness_logits.new_ones(n)
        return Instances.cat([p, image_box])