File size: 1,150 Bytes
29a229f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright (c) Facebook, Inc. and its affiliates.
import torch

from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY
from detectron2.modeling.proposal_generator.rpn import RPN
from detectron2.structures import ImageList


@PROPOSAL_GENERATOR_REGISTRY.register()
class TridentRPN(RPN):
    """
    Trident RPN subnetwork.

    Behaves like :class:`RPN`, except that the input images (and the
    ground-truth instances, when supplied) are replicated once per active
    branch before being handed to the parent implementation.
    """

    def __init__(self, cfg, input_shape):
        """
        Args:
            cfg: config node; ``MODEL.TRIDENT.NUM_BRANCH`` and
                ``MODEL.TRIDENT.TEST_BRANCH_IDX`` are read here, the rest is
                consumed by :class:`RPN`.
            input_shape: forwarded unchanged to :class:`RPN`.
        """
        super().__init__(cfg, input_shape)

        trident_cfg = cfg.MODEL.TRIDENT
        self.num_branch = trident_cfg.NUM_BRANCH
        # Any TEST_BRANCH_IDX other than -1 switches on the single-copy path
        # that `forward` takes outside of training.
        self.trident_fast = trident_cfg.TEST_BRANCH_IDX != -1

    def forward(self, images, features, gt_instances=None):
        """
        See :class:`RPN.forward`.

        The images are repeated ``self.num_branch`` times, unless the module
        is not in training mode and ``self.trident_fast`` is set, in which
        case a single copy is used. ``gt_instances`` is repeated the same
        number of times when it is not ``None``.
        """
        if not self.training and self.trident_fast:
            repeats = 1
        else:
            repeats = self.num_branch

        # Duplicate the image tensor (concatenated along dim 0) together with
        # the per-image sizes so both stay aligned across branches.
        image_tensor = images.tensor
        stacked_tensor = torch.cat([image_tensor] * repeats)
        stacked_sizes = images.image_sizes * repeats
        branch_images = ImageList(stacked_tensor, stacked_sizes)

        # Ground truth is optional; replicate it only when it was provided.
        if gt_instances is None:
            branch_gt_instances = None
        else:
            branch_gt_instances = gt_instances * repeats

        return super().forward(branch_images, features, branch_gt_instances)