zyliu's picture
release iChatApp
0f90f73
raw
history blame
No virus
4.79 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/modeling/roi_heads/detic_fast_rcnn.py
import torch
from fvcore.nn import giou_loss, smooth_l1_loss
from torch import nn
from torch.nn import functional as F
import fvcore.nn.weight_init as weight_init
from detectron2.config import configurable
from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats
# Public API of this module: only the output-layer class is exported.
__all__ = ["GRiTFastRCNNOutputLayers"]
class GRiTFastRCNNOutputLayers(FastRCNNOutputLayers):
    """Fast R-CNN output layers with an MLP box regressor.

    After the parent constructor runs, ``self.bbox_pred`` is overridden with a
    ``Linear -> ReLU -> Linear`` stack that emits 4 box deltas per input row.
    Classification uses a plain mean-reduced softmax cross-entropy.
    """

    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        **kwargs,
    ):
        """Build the head.

        Args:
            input_shape: Shape of the incoming feature. The flattened feature
                size is ``channels * width * height``, where a missing
                (falsy) width or height counts as 1.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(
            input_shape=input_shape,
            **kwargs,
        )
        input_size = (
            input_shape.channels
            * (input_shape.width or 1)
            * (input_shape.height or 1)
        )
        # Override whatever ``bbox_pred`` the parent created with a small MLP
        # that always outputs exactly 4 values.
        self.bbox_pred = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.ReLU(inplace=True),
            nn.Linear(input_size, 4),
        )
        weight_init.c2_xavier_fill(self.bbox_pred[0])
        # Tiny initial weights / zero bias on the last layer so the initial
        # predicted deltas are close to zero.
        nn.init.normal_(self.bbox_pred[-1].weight, std=0.001)
        nn.init.constant_(self.bbox_pred[-1].bias, 0)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Return the constructor kwargs produced by the parent, unchanged."""
        ret = super().from_config(cfg, input_shape)
        return ret

    def losses(self, predictions, proposals):
        """Compute the classification and box-regression losses.

        Args:
            predictions: Pair ``(scores, proposal_deltas)`` as returned by
                :meth:`forward`.
            proposals: Sequence of per-image proposal objects exposing
                ``gt_classes``, ``proposal_boxes`` and optionally ``gt_boxes``.

        Returns:
            dict with the keys ``"loss_cls"`` and ``"loss_box_reg"``.
        """
        scores, proposal_deltas = predictions
        if len(proposals):
            gt_classes = cat([p.gt_classes for p in proposals], dim=0)
        else:
            # Empty batch: keep an integer dtype and the scores' device so the
            # tensor stays usable as a class index further down.
            gt_classes = torch.empty(0, dtype=torch.long, device=scores.device)
        num_classes = self.num_classes
        _log_classification_stats(scores, gt_classes)

        if len(proposals):
            proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
            assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
            # Proposals without ``gt_boxes`` fall back to their own boxes as
            # the regression target.
            gt_boxes = cat(
                [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
                dim=0,
            )
        else:
            proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

        loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes)
        return {
            "loss_cls": loss_cls,
            "loss_box_reg": self.box_reg_loss(
                proposal_boxes, gt_boxes, proposal_deltas, gt_classes,
                num_classes=num_classes),
        }

    def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes):
        """Return the mean cross-entropy, or a scalar zero for empty logits."""
        if pred_class_logits.numel() == 0:
            # Scalar zero with the dtype/device of the logits.
            return pred_class_logits.new_zeros([1])[0]

        loss = F.cross_entropy(
            pred_class_logits, gt_classes, reduction="mean")
        return loss

    def box_reg_loss(
            self, proposal_boxes, gt_boxes, pred_deltas, gt_classes,
            num_classes=-1):
        """Compute the box-regression loss over foreground proposals.

        Rows whose class lies in ``[0, num_classes)`` are treated as
        foreground; other labels are ignored (presumably the background label
        is ``num_classes`` -- TODO confirm against the label convention).

        Args:
            proposal_boxes: Proposal boxes, one row per proposal.
            gt_boxes: Target boxes aligned row-wise with ``proposal_boxes``.
            pred_deltas: Predicted deltas; either ``box_dim`` columns per row,
                or a layout viewable as ``(-1, self.num_classes, box_dim)``.
            gt_classes: Class label per proposal.
            num_classes: Upper bound for foreground labels; non-positive
                values fall back to ``self.num_classes``.

        Returns:
            The summed loss divided by the total number of proposals
            (at least 1).

        Raises:
            ValueError: If ``self.box_reg_loss_type`` is neither
                ``"smooth_l1"`` nor ``"giou"``.
        """
        num_classes = num_classes if num_classes > 0 else self.num_classes
        box_dim = proposal_boxes.shape[1]
        fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < num_classes))[0]
        if pred_deltas.shape[1] == box_dim:
            fg_pred_deltas = pred_deltas[fg_inds]
        else:
            # NOTE(review): the view uses ``self.num_classes`` rather than the
            # ``num_classes`` argument; confirm this is intended if they can
            # ever differ.
            fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[
                fg_inds, gt_classes[fg_inds]
            ]

        if self.box_reg_loss_type == "smooth_l1":
            gt_pred_deltas = self.box2box_transform.get_deltas(
                proposal_boxes[fg_inds],
                gt_boxes[fg_inds],
            )
            loss_box_reg = smooth_l1_loss(
                fg_pred_deltas, gt_pred_deltas, self.smooth_l1_beta, reduction="sum"
            )
        elif self.box_reg_loss_type == "giou":
            fg_pred_boxes = self.box2box_transform.apply_deltas(
                fg_pred_deltas, proposal_boxes[fg_inds]
            )
            loss_box_reg = giou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
        else:
            raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")

        # Normalise by all proposals (not just foreground); the max() guards
        # against division by zero on an empty batch.
        return loss_box_reg / max(gt_classes.numel(), 1.0)

    def predict_probs(self, predictions, proposals):
        """Return softmax class probabilities split per image.

        Args:
            predictions: Sequence whose first element holds the class scores
                for all proposals concatenated along dim 0.
            proposals: Per-image proposal objects; ``len(p)`` gives the number
                of rows belonging to each image.

        Returns:
            Tuple of tensors, one per entry in ``proposals``.
        """
        scores = predictions[0]
        num_inst_per_image = [len(p) for p in proposals]
        probs = F.softmax(scores, dim=-1)
        return probs.split(num_inst_per_image, dim=0)

    def forward(self, x):
        """Run the classification and box-regression heads.

        Args:
            x: Input features; anything with more than 2 dims is flattened
                from dim 1 onward.

        Returns:
            Pair ``(scores, proposal_deltas)`` from ``self.cls_score`` and
            ``self.bbox_pred`` respectively.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        scores = self.cls_score(x)
        proposal_deltas = self.bbox_pred(x)
        return scores, proposal_deltas