File size: 5,075 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""

Implements the Generalized R-CNN framework

"""

import torch
from torch import nn

from maskrcnn_benchmark.structures.image_list import to_image_list

from ..backbone import build_backbone
from ..rpn import build_rpn
from ..roi_heads import build_roi_heads

import timeit


class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN. Currently supports boxes and masks.

    It consists of three main parts:

    - backbone
    - rpn
    - heads: takes the features + the proposals from the RPN and computes
      detections / masks from it.

    Freezing and linear-probing are controlled by config flags read in
    ``__init__`` and (re)applied every time ``train()`` / ``eval()`` is called.
    """

    # Parameter-name substrings that stay trainable under linear probing;
    # every other RPN / ROI-head parameter has requires_grad cleared.
    _LINEAR_PROB_TRAINABLE_KEYS = ("bbox_pred", "cls_logits", "centerness", "cosine_scale")

    def __init__(self, cfg):
        super(GeneralizedRCNN, self).__init__()

        self.backbone = build_backbone(cfg)
        self.rpn = build_rpn(cfg)
        self.roi_heads = build_roi_heads(cfg)
        self.DEBUG = cfg.MODEL.DEBUG
        self.ONNX = cfg.MODEL.ONNX
        self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
        self.freeze_fpn = cfg.MODEL.FPN.FREEZE
        self.freeze_rpn = cfg.MODEL.RPN.FREEZE

        if cfg.MODEL.LINEAR_PROB:
            assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
            if hasattr(self.backbone, "fpn"):
                assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
        self.linear_prob = cfg.MODEL.LINEAR_PROB

    @staticmethod
    def _freeze(module):
        """Put ``module`` in eval mode and disable gradients for all its parameters."""
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

    @classmethod
    def _freeze_all_but_predictors(cls, module):
        """Disable gradients for every parameter of ``module`` whose name does
        not contain one of ``_LINEAR_PROB_TRAINABLE_KEYS``. Does not change the
        module's train/eval mode."""
        for name, param in module.named_parameters():
            if not any(key in name for key in cls._LINEAR_PROB_TRAINABLE_KEYS):
                param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode while keeping frozen sub-modules frozen.

        ``nn.Module.eval()`` delegates to ``train(False)``, so this override
        also runs on ``eval()``. After the mode switch, frozen parts are forced
        back into eval mode and have ``requires_grad`` cleared.

        Arguments:
            mode (bool): ``True`` for training mode, ``False`` for eval mode.

        Returns:
            self, matching ``nn.Module.train`` so ``model.train()`` and
            ``model.eval()`` can be chained or assigned.
        """
        super(GeneralizedRCNN, self).train(mode)
        if self.freeze_backbone:
            self._freeze(self.backbone.body)
        if self.freeze_fpn:
            # The backbone may have no FPN (cf. the hasattr check in __init__);
            # there is then nothing to freeze.
            fpn = getattr(self.backbone, "fpn", None)
            if fpn is not None:
                self._freeze(fpn)
        if self.freeze_rpn:
            self._freeze(self.rpn)
        if self.linear_prob:
            if self.rpn is not None:
                self._freeze_all_but_predictors(self.rpn)
            # Same truthiness test forward() uses to detect "no ROI heads";
            # presumably build_roi_heads can return an empty container for
            # RPN-only models — confirm. Such a value has no named_parameters().
            if self.roi_heads:
                self._freeze_all_but_predictors(self.roi_heads)
        return self

    def forward(self, images, targets=None):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed. When
                ``self.ONNX`` is set they are passed to the backbone (and RPN)
                as-is, without ``to_image_list`` batching.
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
                When ``self.DEBUG`` is set, returns ``(result, debug_info)`` in
                both modes (losses are not returned); ``debug_info`` holds input
                and feature sizes, per-stage wall-clock times and the raw
                proposals / detections.

        Raises:
            ValueError: in training mode when ``targets`` is None.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        debug = self.DEBUG
        if debug:
            debug_info = {"input_size": images[0].size()}
            tic = timeit.default_timer()

        if self.ONNX:
            features = self.backbone(images)
        else:
            images = to_image_list(images)
            features = self.backbone(images.tensors)

        if debug:
            debug_info["feat_time"] = timeit.default_timer() - tic
            debug_info["feat_size"] = [feat.size() for feat in features]
            tic = timeit.default_timer()

        proposals, proposal_losses = self.rpn(images, features, targets)

        if debug:
            debug_info["rpn_time"] = timeit.default_timer() - tic
            debug_info["#rpn"] = list(proposals)
            tic = timeit.default_timer()

        if self.roi_heads:
            _, result, detector_losses = self.roi_heads(features, proposals, targets)
        else:
            # RPN-only models don't have roi_heads
            result = proposals
            detector_losses = {}

        if debug:
            debug_info["rcnn_time"] = timeit.default_timer() - tic
            debug_info["#rcnn"] = result
            return result, debug_info

        if self.training:
            losses = {}
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses

        return result