#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized R-CNN framework
"""

import torch
from torch import nn

from maskrcnn_benchmark.structures.image_list import to_image_list

from ..backbone import build_backbone
from ..rpn.rpn import build_rpn
from ..segmentation.segmentation import build_segmentation
from ..roi_heads.roi_heads import build_roi_heads

import time


class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN. Currently supports boxes and masks.
    It consists of three main parts:
    - backbone: computes features from the batched input images
    - proposal generator: an RPN (``build_rpn``), or the module returned by
      ``build_segmentation`` when ``cfg.MODEL.SEG_ON`` is set
    - heads: takes the features + the proposals from the proposal generator
      and computes detections / masks from it. The heads are not built when
      ``cfg.MODEL.TRAIN_DETECTION_ONLY`` is set; the proposals are then
      returned as the result.
    """

    def __init__(self, cfg):
        super(GeneralizedRCNN, self).__init__()
        # Kept so forward() can consult the MODEL.SEG_ON / MODEL.SEG.* switches.
        self.cfg = cfg
        self.backbone = build_backbone(cfg)
        # Proposal generator: segmentation-based when SEG_ON, otherwise an RPN.
        if cfg.MODEL.SEG_ON:
            self.proposal = build_segmentation(cfg)
        else:
            self.proposal = build_rpn(cfg)
        # None means forward() skips the ROI heads and returns the proposals.
        if cfg.MODEL.TRAIN_DETECTION_ONLY:
            self.roi_heads = None
        else:
            self.roi_heads = build_roi_heads(cfg)

    def forward(self, images, targets=None):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            During training: a dict[Tensor] holding the losses, i.e. the ROI-head
            losses (if any) merged with the proposal-generator losses; on a key
            collision the proposal-generator loss wins.
            During testing:
              - if ``cfg.MODEL.SEG_ON``: a tuple ``(result, proposals, seg_results)``
              - otherwise: ``result``
            where ``result`` is the output of the ROI heads (list[BoxList] with
            additional fields like `scores`, `labels` and `mask` for Mask R-CNN
            models), or the proposals themselves when there are no ROI heads.

        Raises:
            ValueError: if called in training mode without ``targets``.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        images = to_image_list(images)
        features = self.backbone(images.tensors)

        seg_on = self.cfg.MODEL.SEG_ON
        if seg_on:
            # The segmentation module returns ((proposals, aux), fuse_feature).
            # aux holds the losses in training mode and the segmentation
            # output in eval mode.
            (proposals, aux), fuse_feature = self.proposal(images, features, targets)
            if self.training:
                proposal_losses = aux
            else:
                seg_results = aux
        else:
            proposals, proposal_losses = self.proposal(images, features, targets)

        if self.roi_heads is not None:
            # Optionally feed the feature map produced by the segmentation
            # module (presumably a fused multi-level map — confirm in
            # build_segmentation) to the heads instead of the raw backbone
            # features.
            if seg_on and self.cfg.MODEL.SEG.USE_FUSE_FEATURE:
                roi_features = fuse_feature
            else:
                roi_features = features
            _, result, detector_losses = self.roi_heads(roi_features, proposals, targets)
        else:
            # RPN-only models don't have roi_heads
            result = proposals
            detector_losses = {}

        if self.training:
            losses = {}
            # detector_losses is {} when there are no ROI heads, so the merge
            # is unconditional; proposal losses are applied last so they take
            # precedence on duplicate keys.
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses

        if seg_on:
            return result, proposals, seg_results
        return result