Spaces:
Sleeping
Sleeping
File size: 5,075 Bytes
749745d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized R-CNN framework
"""
import torch
from torch import nn
from maskrcnn_benchmark.structures.image_list import to_image_list
from ..backbone import build_backbone
from ..rpn import build_rpn
from ..roi_heads import build_roi_heads
import timeit
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN. Currently supports boxes and masks.
    It consists of three main parts:
    - backbone
    - rpn
    - heads: takes the features + the proposals from the RPN and computes
        detections / masks from it.

    Freezing behaviour is driven by the config flags read in ``__init__``
    (``MODEL.BACKBONE.FREEZE``, ``MODEL.FPN.FREEZE``, ``MODEL.RPN.FREEZE``,
    ``MODEL.LINEAR_PROB``) and is (re-)applied on every call to ``train()``.
    """

    # Parameters whose name contains one of these substrings are left
    # untouched by linear probing; every other rpn / roi_heads parameter
    # gets requires_grad = False.
    _LINEAR_PROBE_KEYS = ("bbox_pred", "cls_logits", "centerness", "cosine_scale")

    def __init__(self, cfg):
        """
        Build the sub-modules from ``cfg`` and cache the flags used later.

        Arguments:
            cfg: config node; the fields read here are ``MODEL.DEBUG``,
                ``MODEL.ONNX``, ``MODEL.BACKBONE.FREEZE``, ``MODEL.FPN.FREEZE``,
                ``MODEL.RPN.FREEZE`` and ``MODEL.LINEAR_PROB``. It is also
                passed to ``build_backbone``, ``build_rpn`` and
                ``build_roi_heads``.

        Raises:
            AssertionError: if linear probing is requested but the backbone
                (or the FPN, when the backbone has one) is not frozen.
        """
        super().__init__()

        self.backbone = build_backbone(cfg)
        self.rpn = build_rpn(cfg)
        self.roi_heads = build_roi_heads(cfg)

        self.DEBUG = cfg.MODEL.DEBUG
        self.ONNX = cfg.MODEL.ONNX
        self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
        self.freeze_fpn = cfg.MODEL.FPN.FREEZE
        self.freeze_rpn = cfg.MODEL.RPN.FREEZE

        if cfg.MODEL.LINEAR_PROB:
            assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
            if hasattr(self.backbone, "fpn"):
                assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
        self.linear_prob = cfg.MODEL.LINEAR_PROB

    @staticmethod
    def _freeze(module):
        """Put ``module`` in eval mode and disable gradients for its parameters."""
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

    def _freeze_for_linear_probe(self, module):
        """Disable gradients for every parameter of ``module`` whose name
        matches none of ``_LINEAR_PROBE_KEYS``. ``None`` is a no-op."""
        if module is None:
            return
        for name, param in module.named_parameters():
            if not any(key in name for key in self._LINEAR_PROBE_KEYS):
                param.requires_grad = False

    def train(self, mode=True):
        """
        Switch the model to training (or eval) mode while keeping the
        configured layers frozen.

        Arguments:
            mode (bool): forwarded to ``nn.Module.train``.

        Returns:
            self, as ``nn.Module.train`` does. ``nn.Module.eval()`` delegates
            to ``train(False)``, so without this return value
            ``model = model.eval()`` would rebind ``model`` to ``None``.
        """
        super().train(mode)

        # Frozen parts are forced back to eval mode regardless of ``mode``.
        if self.freeze_backbone:
            self._freeze(self.backbone.body)
        if self.freeze_fpn:
            self._freeze(self.backbone.fpn)
        if self.freeze_rpn:
            self._freeze(self.rpn)

        if self.linear_prob:
            self._freeze_for_linear_probe(self.rpn)
            self._freeze_for_linear_probe(self.roi_heads)

        return self

    def forward(self, images, targets=None):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
                When ``self.DEBUG`` is set, a ``(result, debug_info)`` tuple is
                returned instead (in training mode too), where ``debug_info``
                holds per-stage timings and sizes.

        Raises:
            ValueError: in training mode when ``targets`` is None.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        debug = self.DEBUG
        if debug:
            debug_info = {"input_size": images[0].size()}
            tic = timeit.default_timer()

        if self.ONNX:
            # In ONNX mode the input is handed to the backbone as-is.
            features = self.backbone(images)
        else:
            images = to_image_list(images)
            features = self.backbone(images.tensors)

        if debug:
            debug_info["feat_time"] = timeit.default_timer() - tic
            debug_info["feat_size"] = [feat.size() for feat in features]
            tic = timeit.default_timer()

        proposals, proposal_losses = self.rpn(images, features, targets)

        if debug:
            debug_info["rpn_time"] = timeit.default_timer() - tic
            debug_info["#rpn"] = list(proposals)
            tic = timeit.default_timer()

        if self.roi_heads:
            x, result, detector_losses = self.roi_heads(features, proposals, targets)
        else:
            # RPN-only models don't have roi_heads
            x = features
            result = proposals
            detector_losses = {}

        if debug:
            debug_info["rcnn_time"] = timeit.default_timer() - tic
            debug_info["#rcnn"] = result
            return result, debug_info

        if self.training:
            # Proposal losses are merged last, so they win on a key clash.
            losses = {}
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses

        return result
|