zdou0830's picture
desco
749745d
raw
history blame
No virus
5.08 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized R-CNN framework
"""
import torch
from torch import nn
from maskrcnn_benchmark.structures.image_list import to_image_list
from ..backbone import build_backbone
from ..rpn import build_rpn
from ..roi_heads import build_roi_heads
import timeit
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN. Currently supports boxes and masks.
    It consists of three main parts:
    - backbone
    - rpn
    - heads: takes the features + the proposals from the RPN and computes
        detections / masks from it.

    Layer freezing is controlled by cfg.MODEL.BACKBONE.FREEZE,
    cfg.MODEL.FPN.FREEZE, cfg.MODEL.RPN.FREEZE and cfg.MODEL.LINEAR_PROB, and
    is (re)applied every time ``train()`` (or ``eval()``) is called.
    """

    # Parameter-name substrings that are left untouched under linear probing;
    # every other rpn / roi_heads parameter gets requires_grad=False.
    _LINEAR_PROB_TRAINABLE_KEYS = (
        "bbox_pred",
        "cls_logits",
        "centerness",
        "cosine_scale",
    )

    def __init__(self, cfg):
        super(GeneralizedRCNN, self).__init__()

        self.backbone = build_backbone(cfg)
        self.rpn = build_rpn(cfg)
        self.roi_heads = build_roi_heads(cfg)
        self.DEBUG = cfg.MODEL.DEBUG
        self.ONNX = cfg.MODEL.ONNX

        self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
        self.freeze_fpn = cfg.MODEL.FPN.FREEZE
        self.freeze_rpn = cfg.MODEL.RPN.FREEZE

        if cfg.MODEL.LINEAR_PROB:
            assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
            if hasattr(self.backbone, "fpn"):
                assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
        self.linear_prob = cfg.MODEL.LINEAR_PROB

    @staticmethod
    def _freeze_module(module):
        """Put ``module`` in eval mode and clear requires_grad on all its parameters."""
        module.eval()
        for p in module.parameters():
            p.requires_grad = False

    def _freeze_non_predictor_params(self, module):
        """Clear requires_grad on ``module``'s parameters except predictor ones.

        A parameter is left untouched when its name contains any substring in
        ``_LINEAR_PROB_TRAINABLE_KEYS``. Values that are not an ``nn.Module``
        are skipped instead of raising AttributeError. That covers None, and
        whatever falsy placeholder RPN-only models keep in ``roi_heads``
        (``forward`` tests it by truthiness), presumably an empty container.
        """
        if not isinstance(module, nn.Module):
            return
        for name, param in module.named_parameters():
            if not any(key in name for key in self._LINEAR_PROB_TRAINABLE_KEYS):
                param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen layers frozen.

        ``nn.Module.train`` switches every submodule to ``mode``. The configured
        frozen parts are then put back into eval mode and have requires_grad
        cleared. This also runs for ``mode=False``, which includes calls made
        through ``eval()``.

        Args:
            mode (bool): True for training mode, False for evaluation mode.

        Returns:
            GeneralizedRCNN: ``self``. This follows the ``nn.Module.train``
            contract, so ``model.train()`` / ``model.eval()`` can be chained or
            reassigned.
        """
        super(GeneralizedRCNN, self).train(mode)
        if self.freeze_backbone:
            self._freeze_module(self.backbone.body)
        if self.freeze_fpn:
            self._freeze_module(self.backbone.fpn)
        if self.freeze_rpn:
            self._freeze_module(self.rpn)
        if self.linear_prob:
            self._freeze_non_predictor_params(self.rpn)
            self._freeze_non_predictor_params(self.roi_heads)
        return self

    def forward(self, images, targets=None):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
                If self.DEBUG is set, a tuple ``(result, debug_info)`` is returned
                in either mode instead. Here ``result`` is the detection output
                (losses are not returned). ``debug_info`` holds the input and
                feature sizes, per-stage wall-clock timings, and the raw
                proposals and detections.

        Raises:
            ValueError: in training mode when ``targets`` is None.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        debug = self.DEBUG
        debug_info = {}
        # timeit.default_timer is the documented alias of time.perf_counter.
        timer = timeit.default_timer

        if debug:
            debug_info["input_size"] = images[0].size()
            tic = timer()
        if self.ONNX:
            # In ONNX mode the input goes to the backbone as-is (no ImageList).
            features = self.backbone(images)
        else:
            images = to_image_list(images)
            features = self.backbone(images.tensors)
        if debug:
            debug_info["feat_time"] = timer() - tic
            debug_info["feat_size"] = [feat.size() for feat in features]
            tic = timer()

        proposals, proposal_losses = self.rpn(images, features, targets)
        if debug:
            debug_info["rpn_time"] = timer() - tic
            debug_info["#rpn"] = list(proposals)
            tic = timer()

        if self.roi_heads:
            _, result, detector_losses = self.roi_heads(features, proposals, targets)
        else:
            # RPN-only models don't have roi_heads
            result = proposals
            detector_losses = {}
        if debug:
            debug_info["rcnn_time"] = timer() - tic
            debug_info["#rcnn"] = result
            return result, debug_info

        if self.training:
            losses = {}
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses
        return result