import math
import torch
from typing import Dict, List, Optional, Tuple, Union
from detectron2.config import configurable
from detectron2.structures import Boxes, Instances, pairwise_iou
from detectron2.utils.events import get_event_storage
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
from detectron2.modeling.poolers import ROIPooler
from detectron2.layers import batched_nms
from .grit_fast_rcnn import GRiTFastRCNNOutputLayers
from ..text.text_decoder import TransformerDecoderTextualHead, GRiTTextDecoder, AutoRegressiveBeamSearch
from ..text.load_text_token import LoadTextTokens
from transformers import BertTokenizer
from iGPT.models.grit_src.grit.data.custom_dataset_mapper import ObjDescription
from ..soft_nms import batched_soft_nms
import logging
logger = logging.getLogger(__name__)
@ROI_HEADS_REGISTRY.register()
class GRiTROIHeadsAndTextDecoder(CascadeROIHeads):
@configurable
def __init__(
self,
*,
text_decoder_transformer,
train_task: list,
test_task: str,
mult_proposal_score: bool = False,
mask_weight: float = 1.0,
object_feat_pooler=None,
soft_nms_enabled=False,
beam_size=1,
**kwargs,
):
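        """
        ROI heads for GRiT: cascade box stages followed by an autoregressive text
        decoder that generates an object description for every detected box.

        Args:
            text_decoder_transformer: textual head wrapped by GRiTTextDecoder.
            train_task: list of task names the model is trained on (e.g. "ObjectDet").
            test_task: task used at inference time; must be one of train_task.
            mult_proposal_score: multiply class scores by proposal scores at test time.
            mask_weight: scale factor applied to the mask losses.
            object_feat_pooler: ROIPooler extracting per-object features for the text decoder.
            soft_nms_enabled: use Soft-NMS instead of standard NMS at inference.
            beam_size: beam size for autoregressive text decoding.
        """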
super().__init__(**kwargs)
self.mult_proposal_score = mult_proposal_score
self.mask_weight = mask_weight
self.object_feat_pooler = object_feat_pooler
self.soft_nms_enabled = soft_nms_enabled
self.test_task = test_task
self.beam_size = beam_size
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
self.tokenizer = tokenizer
assert test_task in train_task, 'GRiT has not been trained on {} task, ' \
'please verify the task name or train a new ' \
'GRiT on {} task'.format(test_task, test_task)
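        # Each task gets its own begin-of-sequence token: the first task reuses BERT's
        # [CLS] id, and every additional task is assigned an id starting at 104
        # (right after [MASK] = 103), which lies in BERT's unused-token range.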
task_begin_tokens = {}
for i, task in enumerate(train_task):
if i == 0:
task_begin_tokens[task] = tokenizer.cls_token_id
else:
task_begin_tokens[task] = 103 + i
self.task_begin_tokens = task_begin_tokens
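        # Autoregressive decoding stops at BERT's [SEP] token and is capped at 40 steps.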
beamsearch_decode = AutoRegressiveBeamSearch(
end_token_id=tokenizer.sep_token_id,
max_steps=40,
beam_size=beam_size,
objectdet=test_task == "ObjectDet",
per_node_beam_size=1,
)
self.text_decoder = GRiTTextDecoder(
text_decoder_transformer,
beamsearch_decode=beamsearch_decode,
begin_token_id=task_begin_tokens[test_task],
loss_type='smooth',
tokenizer=tokenizer,
)
self.get_target_text_tokens = LoadTextTokens(tokenizer, max_text_len=40, padding='do_not_pad')
@classmethod
def from_config(cls, cfg, input_shape):
ret = super().from_config(cfg, input_shape)
text_decoder_transformer = TransformerDecoderTextualHead(
object_feature_size=cfg.MODEL.FPN.OUT_CHANNELS,
vocab_size=cfg.TEXT_DECODER.VOCAB_SIZE,
hidden_size=cfg.TEXT_DECODER.HIDDEN_SIZE,
num_layers=cfg.TEXT_DECODER.NUM_LAYERS,
attention_heads=cfg.TEXT_DECODER.ATTENTION_HEADS,
feedforward_size=cfg.TEXT_DECODER.FEEDFORWARD_SIZE,
mask_future_positions=True,
padding_idx=0,
decoder_type='bert_en',
use_act_checkpoint=cfg.USE_ACT_CHECKPOINT,
)
ret.update({
'text_decoder_transformer': text_decoder_transformer,
'train_task': cfg.MODEL.TRAIN_TASK,
'test_task': cfg.MODEL.TEST_TASK,
'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
'mask_weight': cfg.MODEL.ROI_HEADS.MASK_WEIGHT,
'soft_nms_enabled': cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED,
'beam_size': cfg.MODEL.BEAM_SIZE,
})
return ret
@classmethod
    def _init_box_head(cls, cfg, input_shape):
ret = super()._init_box_head(cfg, input_shape)
del ret['box_predictors']
cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS
box_predictors = []
for box_head, bbox_reg_weights in zip(ret['box_heads'], \
cascade_bbox_reg_weights):
box_predictors.append(
GRiTFastRCNNOutputLayers(
cfg, box_head.output_shape,
box2box_transform=Box2BoxTransform(weights=bbox_reg_weights)
))
ret['box_predictors'] = box_predictors
in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)
sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
object_feat_pooler = ROIPooler(
output_size=cfg.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES,
scales=pooler_scales,
sampling_ratio=sampling_ratio,
pooler_type=pooler_type,
)
ret['object_feat_pooler'] = object_feat_pooler
return ret
def check_if_all_background(self, proposals, targets, stage):
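        """
        If every proposal in the batch is labeled background, copy the first ground-truth
        box (and its class/description) into the first proposal so that every training
        step has at least one foreground sample for the box and text-decoder losses.
        """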
all_background = True
for proposals_per_image in proposals:
if not (proposals_per_image.gt_classes == self.num_classes).all():
all_background = False
if all_background:
logger.info('all proposals are background at stage {}'.format(stage))
proposals[0].proposal_boxes.tensor[0, :] = targets[0].gt_boxes.tensor[0, :]
proposals[0].gt_boxes.tensor[0, :] = targets[0].gt_boxes.tensor[0, :]
proposals[0].objectness_logits[0] = math.log((1.0 - 1e-10) / (1 - (1.0 - 1e-10)))
proposals[0].gt_classes[0] = targets[0].gt_classes[0]
proposals[0].gt_object_descriptions.data[0] = targets[0].gt_object_descriptions.data[0]
if 'foreground' in proposals[0].get_fields().keys():
proposals[0].foreground[0] = 1
return proposals
def _forward_box(self, features, proposals, targets=None, task="ObjectDet"):
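        """
        Run the cascade box stages. During training, additionally pool per-object
        features for the foreground proposals and compute the text decoder loss;
        during inference, average scores over the stages, run (soft-)NMS, and decode
        an object description for each kept box.
        """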
if self.training:
proposals = self.check_if_all_background(proposals, targets, 0)
if (not self.training) and self.mult_proposal_score:
if len(proposals) > 0 and proposals[0].has('scores'):
proposal_scores = [p.get('scores') for p in proposals]
else:
proposal_scores = [p.get('objectness_logits') for p in proposals]
features = [features[f] for f in self.box_in_features]
head_outputs = []
prev_pred_boxes = None
image_sizes = [x.image_size for x in proposals]
for k in range(self.num_cascade_stages):
if k > 0:
proposals = self._create_proposals_from_boxes(
prev_pred_boxes, image_sizes,
logits=[p.objectness_logits for p in proposals])
if self.training:
proposals = self._match_and_label_boxes_GRiT(
proposals, k, targets)
proposals = self.check_if_all_background(proposals, targets, k)
predictions = self._run_stage(features, proposals, k)
prev_pred_boxes = self.box_predictor[k].predict_boxes(
(predictions[0], predictions[1]), proposals)
head_outputs.append((self.box_predictor[k], predictions, proposals))
if self.training:
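            # Pool per-object features from the final stage's proposals and keep only
            # the foreground ones; these are the inputs to the text decoder.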
object_features = self.object_feat_pooler(features, [x.proposal_boxes for x in proposals])
object_features = _ScaleGradient.apply(object_features, 1.0 / self.num_cascade_stages)
foreground = torch.cat([x.foreground for x in proposals])
object_features = object_features[foreground > 0]
object_descriptions = []
for x in proposals:
object_descriptions += x.gt_object_descriptions[x.foreground > 0].data
object_descriptions = ObjDescription(object_descriptions)
object_descriptions = object_descriptions.data
if len(object_descriptions) > 0:
begin_token = self.task_begin_tokens[task]
text_decoder_inputs = self.get_target_text_tokens(object_descriptions, object_features, begin_token)
object_features = object_features.view(
object_features.shape[0], object_features.shape[1], -1).permute(0, 2, 1).contiguous()
text_decoder_inputs.update({'object_features': object_features})
text_decoder_loss = self.text_decoder(text_decoder_inputs)
else:
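                # No foreground proposals in this batch: use a zero text decoder loss
                # (created on the same device/dtype as the box predictions).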
text_decoder_loss = head_outputs[0][1][0].new_zeros([1])[0]
losses = {}
storage = get_event_storage()
            # RoI head losses (the proposal generator loss is computed in grit.py)
for stage, (predictor, predictions, proposals) in enumerate(head_outputs):
with storage.name_scope("stage{}".format(stage)):
stage_losses = predictor.losses(
(predictions[0], predictions[1]), proposals)
losses.update({k + "_stage{}".format(stage): v for k, v in stage_losses.items()})
# Text Decoder loss
losses.update({'text_decoder_loss': text_decoder_loss})
return losses
else:
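            # Inference: average class scores and logits over the cascade stages,
            # then run NMS and the text decoder on the surviving boxes.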
scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs]
logits_per_stage = [(h[1][0],) for h in head_outputs]
scores = [
sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
for scores_per_image in zip(*scores_per_stage)
]
logits = [
sum(list(logits_per_image)) * (1.0 / self.num_cascade_stages)
for logits_per_image in zip(*logits_per_stage)
]
if self.mult_proposal_score:
scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores)]
predictor, predictions, proposals = head_outputs[-1]
boxes = predictor.predict_boxes(
(predictions[0], predictions[1]), proposals)
assert len(boxes) == 1
pred_instances, _ = self.fast_rcnn_inference_GRiT(
boxes,
scores,
logits,
image_sizes,
predictor.test_score_thresh,
predictor.test_nms_thresh,
predictor.test_topk_per_image,
self.soft_nms_enabled,
)
assert len(pred_instances) == 1, "Only support one image"
for i, pred_instance in enumerate(pred_instances):
if len(pred_instance.pred_boxes) > 0:
object_features = self.object_feat_pooler(features, [pred_instance.pred_boxes])
object_features = object_features.view(
object_features.shape[0], object_features.shape[1], -1).permute(0, 2, 1).contiguous()
text_decoder_output = self.text_decoder({'object_features': object_features})
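                    # With beam_size > 1 on ObjectDet, every beam hypothesis becomes its
                    # own instance; the merged set is re-ranked by the combined object
                    # score computed below.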
if self.beam_size > 1 and self.test_task == "ObjectDet":
pred_boxes = []
pred_scores = []
pred_classes = []
pred_object_descriptions = []
for beam_id in range(self.beam_size):
pred_boxes.append(pred_instance.pred_boxes.tensor)
# object score = sqrt(objectness score x description score)
pred_scores.append((pred_instance.scores *
torch.exp(text_decoder_output['logprobs'])[:, beam_id]) ** 0.5)
pred_classes.append(pred_instance.pred_classes)
for prediction in text_decoder_output['predictions'][:, beam_id, :]:
# convert text tokens to words
description = self.tokenizer.decode(prediction.tolist()[1:], skip_special_tokens=True)
pred_object_descriptions.append(description)
merged_instances = Instances(image_sizes[0])
if torch.cat(pred_scores, dim=0).shape[0] <= predictor.test_topk_per_image:
merged_instances.scores = torch.cat(pred_scores, dim=0)
merged_instances.pred_boxes = Boxes(torch.cat(pred_boxes, dim=0))
merged_instances.pred_classes = torch.cat(pred_classes, dim=0)
merged_instances.pred_object_descriptions = ObjDescription(pred_object_descriptions)
else:
pred_scores, top_idx = torch.topk(
torch.cat(pred_scores, dim=0), predictor.test_topk_per_image)
merged_instances.scores = pred_scores
merged_instances.pred_boxes = Boxes(torch.cat(pred_boxes, dim=0)[top_idx, :])
merged_instances.pred_classes = torch.cat(pred_classes, dim=0)[top_idx]
merged_instances.pred_object_descriptions = \
ObjDescription(ObjDescription(pred_object_descriptions)[top_idx].data)
pred_instances[i] = merged_instances
else:
# object score = sqrt(objectness score x description score)
pred_instance.scores = (pred_instance.scores *
torch.exp(text_decoder_output['logprobs'])) ** 0.5
pred_object_descriptions = []
for prediction in text_decoder_output['predictions']:
# convert text tokens to words
description = self.tokenizer.decode(prediction.tolist()[1:], skip_special_tokens=True)
pred_object_descriptions.append(description)
pred_instance.pred_object_descriptions = ObjDescription(pred_object_descriptions)
else:
pred_instance.pred_object_descriptions = ObjDescription([])
return pred_instances
def forward(self, features, proposals, targets=None, targets_task="ObjectDet"):
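        """
        Training: label/sample proposals, then compute box, text decoder, and (optional)
        mask losses. Inference: return per-box predictions with decoded object descriptions.
        """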
if self.training:
proposals = self.label_and_sample_proposals(
proposals, targets)
losses = self._forward_box(features, proposals, targets, task=targets_task)
if targets[0].has('gt_masks'):
mask_losses = self._forward_mask(features, proposals)
losses.update({k: v * self.mask_weight \
for k, v in mask_losses.items()})
else:
losses.update(self._get_empty_mask_loss(device=proposals[0].objectness_logits.device))
return proposals, losses
else:
pred_instances = self._forward_box(features, proposals, task=self.test_task)
pred_instances = self.forward_with_given_boxes(features, pred_instances)
return pred_instances, {}
@torch.no_grad()
def _match_and_label_boxes_GRiT(self, proposals, stage, targets):
"""
Add "gt_object_description" and "foreground" to detectron2's _match_and_label_boxes
"""
num_fg_samples, num_bg_samples = [], []
for proposals_per_image, targets_per_image in zip(proposals, targets):
match_quality_matrix = pairwise_iou(
targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
)
# proposal_labels are 0 or 1
matched_idxs, proposal_labels = self.proposal_matchers[stage](match_quality_matrix)
if len(targets_per_image) > 0:
gt_classes = targets_per_image.gt_classes[matched_idxs]
# Label unmatched proposals (0 label from matcher) as background (label=num_classes)
gt_classes[proposal_labels == 0] = self.num_classes
foreground = torch.ones_like(gt_classes)
foreground[proposal_labels == 0] = 0
gt_boxes = targets_per_image.gt_boxes[matched_idxs]
gt_object_descriptions = targets_per_image.gt_object_descriptions[matched_idxs]
else:
gt_classes = torch.zeros_like(matched_idxs) + self.num_classes
foreground = torch.zeros_like(gt_classes)
gt_boxes = Boxes(
targets_per_image.gt_boxes.tensor.new_zeros((len(proposals_per_image), 4))
)
gt_object_descriptions = ObjDescription(['None' for i in range(len(proposals_per_image))])
proposals_per_image.gt_classes = gt_classes
proposals_per_image.gt_boxes = gt_boxes
proposals_per_image.gt_object_descriptions = gt_object_descriptions
proposals_per_image.foreground = foreground
num_fg_samples.append((proposal_labels == 1).sum().item())
num_bg_samples.append(proposal_labels.numel() - num_fg_samples[-1])
# Log the number of fg/bg samples in each stage
storage = get_event_storage()
storage.put_scalar(
"stage{}/roi_head/num_fg_samples".format(stage),
sum(num_fg_samples) / len(num_fg_samples),
)
storage.put_scalar(
"stage{}/roi_head/num_bg_samples".format(stage),
sum(num_bg_samples) / len(num_bg_samples),
)
return proposals
def fast_rcnn_inference_GRiT(
self,
boxes: List[torch.Tensor],
scores: List[torch.Tensor],
logits: List[torch.Tensor],
image_shapes: List[Tuple[int, int]],
score_thresh: float,
nms_thresh: float,
topk_per_image: int,
soft_nms_enabled: bool,
):
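        """
        Per-image wrapper around fast_rcnn_inference_single_image_GRiT; returns the kept
        Instances and the indices of the surviving predictions for each image.
        """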
result_per_image = [
self.fast_rcnn_inference_single_image_GRiT(
boxes_per_image, scores_per_image, logits_per_image, image_shape,
score_thresh, nms_thresh, topk_per_image, soft_nms_enabled
)
for scores_per_image, boxes_per_image, image_shape, logits_per_image \
in zip(scores, boxes, image_shapes, logits)
]
return [x[0] for x in result_per_image], [x[1] for x in result_per_image]
def fast_rcnn_inference_single_image_GRiT(
self,
boxes,
scores,
logits,
image_shape: Tuple[int, int],
score_thresh: float,
nms_thresh: float,
topk_per_image: int,
soft_nms_enabled,
):
"""
Add soft NMS to detectron2's fast_rcnn_inference_single_image
"""
valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
if not valid_mask.all():
boxes = boxes[valid_mask]
scores = scores[valid_mask]
logits = logits[valid_mask]
scores = scores[:, :-1]
logits = logits[:, :-1]
num_bbox_reg_classes = boxes.shape[1] // 4
# Convert to Boxes to use the `clip` function ...
boxes = Boxes(boxes.reshape(-1, 4))
boxes.clip(image_shape)
boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4) # R x C x 4
# 1. Filter results based on detection scores. It can make NMS more efficient
# by filtering out low-confidence detections.
filter_mask = scores > score_thresh # R x K
# R' x 2. First column contains indices of the R predictions;
# Second column contains indices of classes.
filter_inds = filter_mask.nonzero()
if num_bbox_reg_classes == 1:
boxes = boxes[filter_inds[:, 0], 0]
else:
boxes = boxes[filter_mask]
scores = scores[filter_mask]
logits = logits[filter_mask]
# 2. Apply NMS for each class independently.
if not soft_nms_enabled:
keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
else:
keep, soft_nms_scores = batched_soft_nms(
boxes,
scores,
filter_inds[:, 1],
"linear",
0.5,
nms_thresh,
0.001,
)
scores[keep] = soft_nms_scores
if topk_per_image >= 0:
keep = keep[:topk_per_image]
boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]
logits = logits[keep]
result = Instances(image_shape)
result.pred_boxes = Boxes(boxes)
result.scores = scores
result.pred_classes = filter_inds[:, 1]
result.logits = logits
return result, filter_inds[:, 0]
def _get_empty_mask_loss(self, device):
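        """
        Return a zero 'loss_mask' when the batch has no mask annotations, so the set of
        loss keys stays constant across iterations (empty dict if the mask head is off).
        """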
if self.mask_on:
return {'loss_mask': torch.zeros(
(1, ), device=device, dtype=torch.float32)[0]}
else:
return {}
def _create_proposals_from_boxes(self, boxes, image_sizes, logits):
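        """
        Wrap the previous stage's predicted boxes into Instances for the next stage,
        carrying the objectness logits along; at training time, empty boxes are dropped.
        """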
boxes = [Boxes(b.detach()) for b in boxes]
proposals = []
for boxes_per_image, image_size, logit in zip(
boxes, image_sizes, logits):
boxes_per_image.clip(image_size)
if self.training:
inds = boxes_per_image.nonempty()
boxes_per_image = boxes_per_image[inds]
logit = logit[inds]
prop = Instances(image_size)
prop.proposal_boxes = boxes_per_image
prop.objectness_logits = logit
proposals.append(prop)
return proposals
def _run_stage(self, features, proposals, stage):
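        """
        Run the box head and box predictor of a single cascade stage; gradients through
        the pooled features are scaled by 1 / num_cascade_stages, as in Cascade R-CNN.
        """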
pool_boxes = [x.proposal_boxes for x in proposals]
box_features = self.box_pooler(features, pool_boxes)
box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)
box_features = self.box_head[stage](box_features)
return self.box_predictor[stage](box_features)