| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import contextlib | 
					
					
						
						| 
							 | 
						from unittest import mock | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from detectron2.modeling import poolers | 
					
					
						
						| 
							 | 
						from detectron2.modeling.proposal_generator import rpn | 
					
					
						
						| 
							 | 
						from detectron2.modeling.roi_heads import keypoint_head, mask_head | 
					
					
						
						| 
							 | 
						from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from .c10 import ( | 
					
					
						
						| 
							 | 
						    Caffe2Compatible, | 
					
					
						
						| 
							 | 
						    Caffe2FastRCNNOutputsInference, | 
					
					
						
						| 
							 | 
						    Caffe2KeypointRCNNInference, | 
					
					
						
						| 
							 | 
						    Caffe2MaskRCNNInference, | 
					
					
						
						| 
							 | 
						    Caffe2ROIPooler, | 
					
					
						
						| 
							 | 
						    Caffe2RPN, | 
					
					
						
						| 
							 | 
						    caffe2_fast_rcnn_outputs_inference, | 
					
					
						
						| 
							 | 
						    caffe2_keypoint_rcnn_inference, | 
					
					
						
						| 
							 | 
						    caffe2_mask_rcnn_inference, | 
					
					
						
						| 
							 | 
						) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class GenericMixin: | 
					
					
						
						| 
							 | 
						    pass | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class Caffe2CompatibleConverter: | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    A GenericUpdater which implements the `create_from` interface, by modifying | 
					
					
						
						| 
							 | 
						    module object and assign it with another class replaceCls. | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __init__(self, replaceCls): | 
					
					
						
						| 
							 | 
						        self.replaceCls = replaceCls | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def create_from(self, module): | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        assert isinstance(module, torch.nn.Module) | 
					
					
						
						| 
							 | 
						        if issubclass(self.replaceCls, GenericMixin): | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            new_class = type( | 
					
					
						
						| 
							 | 
						                "{}MixedWith{}".format(self.replaceCls.__name__, module.__class__.__name__), | 
					
					
						
						| 
							 | 
						                (self.replaceCls, module.__class__), | 
					
					
						
						| 
							 | 
						                {},   | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						            module.__class__ = new_class | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            module.__class__ = self.replaceCls | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if isinstance(module, Caffe2Compatible): | 
					
					
						
						| 
							 | 
						            module.tensor_mode = False | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        return module | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def patch(model, target, updater, *args, **kwargs): | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    recursively (post-order) update all modules with the target type and its | 
					
					
						
						| 
							 | 
						    subclasses, make a initialization/composition/inheritance/... via the | 
					
					
						
						| 
							 | 
						    updater.create_from. | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    for name, module in model.named_children(): | 
					
					
						
						| 
							 | 
						        model._modules[name] = patch(module, target, updater, *args, **kwargs) | 
					
					
						
						| 
							 | 
						    if isinstance(model, target): | 
					
					
						
						| 
							 | 
						        return updater.create_from(model, *args, **kwargs) | 
					
					
						
						| 
							 | 
						    return model | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def patch_generalized_rcnn(model): | 
					
					
						
						| 
							 | 
						    ccc = Caffe2CompatibleConverter | 
					
					
						
						| 
							 | 
						    model = patch(model, rpn.RPN, ccc(Caffe2RPN)) | 
					
					
						
						| 
							 | 
						    model = patch(model, poolers.ROIPooler, ccc(Caffe2ROIPooler)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return model | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						@contextlib.contextmanager | 
					
					
						
						| 
							 | 
						def mock_fastrcnn_outputs_inference( | 
					
					
						
						| 
							 | 
						    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers | 
					
					
						
						| 
							 | 
						): | 
					
					
						
						| 
							 | 
						    with mock.patch.object( | 
					
					
						
						| 
							 | 
						        box_predictor_type, | 
					
					
						
						| 
							 | 
						        "inference", | 
					
					
						
						| 
							 | 
						        autospec=True, | 
					
					
						
						| 
							 | 
						        side_effect=Caffe2FastRCNNOutputsInference(tensor_mode), | 
					
					
						
						| 
							 | 
						    ) as mocked_func: | 
					
					
						
						| 
							 | 
						        yield | 
					
					
						
						| 
							 | 
						    if check: | 
					
					
						
						| 
							 | 
						        assert mocked_func.call_count > 0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						@contextlib.contextmanager | 
					
					
						
						| 
							 | 
						def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True): | 
					
					
						
						| 
							 | 
						    with mock.patch( | 
					
					
						
						| 
							 | 
						        "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference() | 
					
					
						
						| 
							 | 
						    ) as mocked_func: | 
					
					
						
						| 
							 | 
						        yield | 
					
					
						
						| 
							 | 
						    if check: | 
					
					
						
						| 
							 | 
						        assert mocked_func.call_count > 0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						@contextlib.contextmanager | 
					
					
						
						| 
							 | 
						def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True): | 
					
					
						
						| 
							 | 
						    with mock.patch( | 
					
					
						
						| 
							 | 
						        "{}.keypoint_rcnn_inference".format(patched_module), | 
					
					
						
						| 
							 | 
						        side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint), | 
					
					
						
						| 
							 | 
						    ) as mocked_func: | 
					
					
						
						| 
							 | 
						        yield | 
					
					
						
						| 
							 | 
						    if check: | 
					
					
						
						| 
							 | 
						        assert mocked_func.call_count > 0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class ROIHeadsPatcher: | 
					
					
						
						| 
							 | 
						    def __init__(self, heads, use_heatmap_max_keypoint): | 
					
					
						
						| 
							 | 
						        self.heads = heads | 
					
					
						
						| 
							 | 
						        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint | 
					
					
						
						| 
							 | 
						        self.previous_patched = {} | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    @contextlib.contextmanager | 
					
					
						
						| 
							 | 
						    def mock_roi_heads(self, tensor_mode=True): | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        Patching several inference functions inside ROIHeads and its subclasses | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						        Args: | 
					
					
						
						| 
							 | 
						            tensor_mode (bool): whether the inputs/outputs are caffe2's tensor | 
					
					
						
						| 
							 | 
						                format or not. Default to True. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__ | 
					
					
						
						| 
							 | 
						        mask_head_mod = mask_head.BaseMaskRCNNHead.__module__ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        mock_ctx_managers = [ | 
					
					
						
						| 
							 | 
						            mock_fastrcnn_outputs_inference( | 
					
					
						
						| 
							 | 
						                tensor_mode=tensor_mode, | 
					
					
						
						| 
							 | 
						                check=True, | 
					
					
						
						| 
							 | 
						                box_predictor_type=type(self.heads.box_predictor), | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						        ] | 
					
					
						
						| 
							 | 
						        if getattr(self.heads, "keypoint_on", False): | 
					
					
						
						| 
							 | 
						            mock_ctx_managers += [ | 
					
					
						
						| 
							 | 
						                mock_keypoint_rcnn_inference( | 
					
					
						
						| 
							 | 
						                    tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint | 
					
					
						
						| 
							 | 
						                ) | 
					
					
						
						| 
							 | 
						            ] | 
					
					
						
						| 
							 | 
						        if getattr(self.heads, "mask_on", False): | 
					
					
						
						| 
							 | 
						            mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        with contextlib.ExitStack() as stack:   | 
					
					
						
						| 
							 | 
						            for mgr in mock_ctx_managers: | 
					
					
						
						| 
							 | 
						                stack.enter_context(mgr) | 
					
					
						
						| 
							 | 
						            yield | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def patch_roi_heads(self, tensor_mode=True): | 
					
					
						
						| 
							 | 
						        self.previous_patched["box_predictor"] = self.heads.box_predictor.inference | 
					
					
						
						| 
							 | 
						        self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference | 
					
					
						
						| 
							 | 
						        self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        def patched_fastrcnn_outputs_inference(predictions, proposal): | 
					
					
						
						| 
							 | 
						            return caffe2_fast_rcnn_outputs_inference( | 
					
					
						
						| 
							 | 
						                True, self.heads.box_predictor, predictions, proposal | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if getattr(self.heads, "keypoint_on", False): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances): | 
					
					
						
						| 
							 | 
						                return caffe2_keypoint_rcnn_inference( | 
					
					
						
						| 
							 | 
						                    self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances | 
					
					
						
						| 
							 | 
						                ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        if getattr(self.heads, "mask_on", False): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            def patched_mask_rcnn_inference(pred_mask_logits, pred_instances): | 
					
					
						
						| 
							 | 
						                return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            mask_head.mask_rcnn_inference = patched_mask_rcnn_inference | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def unpatch_roi_heads(self): | 
					
					
						
						| 
							 | 
						        self.heads.box_predictor.inference = self.previous_patched["box_predictor"] | 
					
					
						
						| 
							 | 
						        keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"] | 
					
					
						
						| 
							 | 
						        mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"] | 
					
					
						
						| 
							 | 
						
 |