| | |
| |
|
| | import contextlib |
| | from unittest import mock |
| | import torch |
| |
|
| | from detectron2.modeling import poolers |
| | from detectron2.modeling.proposal_generator import rpn |
| | from detectron2.modeling.roi_heads import keypoint_head, mask_head |
| | from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers |
| |
|
| | from .c10 import ( |
| | Caffe2Compatible, |
| | Caffe2FastRCNNOutputsInference, |
| | Caffe2KeypointRCNNInference, |
| | Caffe2MaskRCNNInference, |
| | Caffe2ROIPooler, |
| | Caffe2RPN, |
| | caffe2_fast_rcnn_outputs_inference, |
| | caffe2_keypoint_rcnn_inference, |
| | caffe2_mask_rcnn_inference, |
| | ) |
| |
|
| |
|
class GenericMixin:
    """
    Marker base class for replacement classes used by ``Caffe2CompatibleConverter``.

    When a converter's ``replaceCls`` subclasses this marker, the converter builds a
    new class on the fly that inherits from both ``replaceCls`` and the module's
    current class (``replaceCls`` acts as a mixin). Otherwise the module's class is
    swapped for ``replaceCls`` outright.
    """
| |
|
| |
|
class Caffe2CompatibleConverter:
    """
    Updater implementing the ``create_from`` interface expected by ``patch``.

    Rather than constructing a new module, it changes the ``__class__`` of the
    module object it is given and returns that same object:

    * if ``replaceCls`` subclasses ``GenericMixin``, a class named
      ``"<replaceCls>MixedWith<original class>"`` is created on the fly with bases
      ``(replaceCls, original class)``, so ``replaceCls`` acts as a mixin;
    * otherwise the module's class is replaced by ``replaceCls`` outright.

    If the re-classed module is a ``Caffe2Compatible`` instance, its
    ``tensor_mode`` attribute is set to ``False``.
    """

    def __init__(self, replaceCls):
        self.replaceCls = replaceCls

    def create_from(self, module):
        """Re-class ``module`` (a ``torch.nn.Module``) in place and return it."""
        assert isinstance(module, torch.nn.Module)

        new_cls = self.replaceCls
        if issubclass(new_cls, GenericMixin):
            # Mixin mode: keep the original class in the MRO behind replaceCls.
            original_cls = module.__class__
            mixed_name = "{}MixedWith{}".format(new_cls.__name__, original_cls.__name__)
            new_cls = type(mixed_name, (new_cls, original_cls), {})
        module.__class__ = new_cls

        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = False
        return module
| |
|
| |
|
def patch(model, target, updater, *args, **kwargs):
    """
    Recursively rewrite ``model`` bottom-up (post-order).

    Children are processed before their parent. Every module that is an instance
    of ``target`` (subclasses included) is replaced by
    ``updater.create_from(module, *args, **kwargs)``. Replaced children are
    written back into their parent's ``_modules`` mapping.

    Returns:
        The root module, or whatever ``updater.create_from`` returned for it when
        the root itself matches ``target``.
    """
    child_slots = model._modules
    for child_name, child in model.named_children():
        child_slots[child_name] = patch(child, target, updater, *args, **kwargs)
    return updater.create_from(model, *args, **kwargs) if isinstance(model, target) else model
| |
|
| |
|
def patch_generalized_rcnn(model):
    """
    Re-class the RPN and ROI pooler submodules of ``model`` as their Caffe2 variants.

    Every ``rpn.RPN`` instance (including subclasses) becomes a ``Caffe2RPN`` and
    every ``poolers.ROIPooler`` instance becomes a ``Caffe2ROIPooler``, via
    ``patch`` with a ``Caffe2CompatibleConverter``. Returns the patched model.
    """
    swaps = (
        (rpn.RPN, Caffe2RPN),
        (poolers.ROIPooler, Caffe2ROIPooler),
    )
    for detectron_cls, caffe2_cls in swaps:
        model = patch(model, detectron_cls, Caffe2CompatibleConverter(caffe2_cls))
    return model
| |
|
| |
|
@contextlib.contextmanager
def mock_fastrcnn_outputs_inference(
    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
):
    """
    Inside the context, route ``box_predictor_type.inference`` to the Caffe2 version.

    The method is patched on the class (``mock.patch.object`` with ``autospec=True``),
    and calls are forwarded to ``Caffe2FastRCNNOutputsInference(tensor_mode)``.

    Args:
        tensor_mode (bool): forwarded to ``Caffe2FastRCNNOutputsInference``.
        check (bool): if True, assert when the ``with`` block exits normally that
            the patched ``inference`` was called at least once.
        box_predictor_type (type): class whose ``inference`` attribute is patched.
    """
    caffe2_inference = Caffe2FastRCNNOutputsInference(tensor_mode)
    patcher = mock.patch.object(
        box_predictor_type, "inference", autospec=True, side_effect=caffe2_inference
    )
    with patcher as mocked_inference:
        yield
    if check:
        assert mocked_inference.call_count > 0
| |
|
| |
|
@contextlib.contextmanager
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
    """
    Inside the context, replace ``<patched_module>.mask_rcnn_inference``.

    Calls to the patched name are forwarded to a ``Caffe2MaskRCNNInference()``
    instance.

    Args:
        tensor_mode: not used by this function.
        patched_module (str): dotted path of the module whose global
            ``mask_rcnn_inference`` is patched.
        check (bool): if True, assert when the ``with`` block exits normally that
            the replacement was called at least once.
    """
    target = "{}.mask_rcnn_inference".format(patched_module)
    with mock.patch(target, side_effect=Caffe2MaskRCNNInference()) as mocked_inference:
        yield
    if check:
        assert mocked_inference.call_count > 0
| |
|
| |
|
@contextlib.contextmanager
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
    """
    Inside the context, replace ``<patched_module>.keypoint_rcnn_inference``.

    Calls to the patched name are forwarded to
    ``Caffe2KeypointRCNNInference(use_heatmap_max_keypoint)``.

    Args:
        tensor_mode: not used by this function.
        patched_module (str): dotted path of the module whose global
            ``keypoint_rcnn_inference`` is patched.
        use_heatmap_max_keypoint: forwarded to ``Caffe2KeypointRCNNInference``.
        check (bool): if True, assert when the ``with`` block exits normally that
            the replacement was called at least once.
    """
    target = "{}.keypoint_rcnn_inference".format(patched_module)
    replacement = Caffe2KeypointRCNNInference(use_heatmap_max_keypoint)
    with mock.patch(target, side_effect=replacement) as mocked_inference:
        yield
    if check:
        assert mocked_inference.call_count > 0
| |
|
| |
|
class ROIHeadsPatcher:
    """
    Redirect the inference functions used by an ROI heads module to Caffe2 versions.

    Two mechanisms are provided:

    * ``mock_roi_heads``: a context manager that patches through ``unittest.mock``
      only for the duration of the ``with`` block.
    * ``patch_roi_heads`` / ``unpatch_roi_heads``: explicit monkey-patching that
      stays in effect until undone.

    Args:
        heads: the ROI heads module. It must expose ``box_predictor``. The optional
            ``keypoint_on`` / ``mask_on`` attributes (treated as False when absent)
            enable the keypoint / mask patches.
        use_heatmap_max_keypoint: forwarded to the Caffe2 keypoint inference.
    """

    def __init__(self, heads, use_heatmap_max_keypoint):
        self.heads = heads
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
        # Originals recorded by patch_roi_heads() and restored by unpatch_roi_heads().
        self.previous_patched = {}

    @contextlib.contextmanager
    def mock_roi_heads(self, tensor_mode=True):
        """
        Patching several inference functions inside ROIHeads and its subclasses

        Args:
            tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
                format or not. Default to True.
        """
        # mock.patch replaces a name inside one module. The keypoint/mask patches
        # therefore only take effect where `keypoint_rcnn_inference` /
        # `mask_rcnn_inference` are looked up as globals of the modules that define
        # BaseKeypointRCNNHead / BaseMaskRCNNHead.
        kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
        mask_head_mod = mask_head.BaseMaskRCNNHead.__module__

        mock_ctx_managers = [
            mock_fastrcnn_outputs_inference(
                tensor_mode=tensor_mode,
                check=True,
                # patch this predictor's concrete class, not the default
                # FastRCNNOutputLayers
                box_predictor_type=type(self.heads.box_predictor),
            )
        ]
        if getattr(self.heads, "keypoint_on", False):
            mock_ctx_managers.append(
                mock_keypoint_rcnn_inference(
                    tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
                )
            )
        if getattr(self.heads, "mask_on", False):
            mock_ctx_managers.append(mock_mask_rcnn_inference(tensor_mode, mask_head_mod))

        # ExitStack enters the managers in list order and exits them in reverse.
        with contextlib.ExitStack() as stack:
            for mgr in mock_ctx_managers:
                stack.enter_context(mgr)
            yield

    def patch_roi_heads(self, tensor_mode=True):
        """
        Monkey-patch the heads' inference functions until ``unpatch_roi_heads``.

        ``box_predictor.inference`` is overridden on the predictor instance. When
        ``keypoint_on`` / ``mask_on`` are set, the module attributes
        ``keypoint_head.keypoint_rcnn_inference`` / ``mask_head.mask_rcnn_inference``
        are reassigned as well. Those are module-level rebinds, not scoped to
        ``self.heads``.

        Args:
            tensor_mode (bool): forwarded as the first argument of
                ``caffe2_fast_rcnn_outputs_inference``, the same way ``mock_roi_heads``
                forwards it. Default to True.
        """
        box_predictor = self.heads.box_predictor
        self.previous_patched["box_predictor"] = box_predictor.inference
        # Record whether `inference` already lived on the instance. If it did not,
        # unpatching must remove our override instead of pinning the saved bound
        # method onto the instance. A pinned bound method would shadow later
        # class-level patches, such as the one mock_roi_heads installs.
        self.previous_patched["box_predictor_owns_inference"] = "inference" in vars(
            box_predictor
        )
        self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference
        self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference

        def patched_fastrcnn_outputs_inference(predictions, proposal):
            return caffe2_fast_rcnn_outputs_inference(
                tensor_mode, self.heads.box_predictor, predictions, proposal
            )

        box_predictor.inference = patched_fastrcnn_outputs_inference

        if getattr(self.heads, "keypoint_on", False):

            def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
                return caffe2_keypoint_rcnn_inference(
                    self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
                )

            keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference

        if getattr(self.heads, "mask_on", False):

            def patched_mask_rcnn_inference(pred_mask_logits, pred_instances):
                return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)

            mask_head.mask_rcnn_inference = patched_mask_rcnn_inference

    def unpatch_roi_heads(self):
        """
        Undo ``patch_roi_heads``, restoring the originals it recorded.

        Raises:
            KeyError: if ``patch_roi_heads`` has not been called on this patcher.
        """
        box_predictor = self.heads.box_predictor
        saved_inference = self.previous_patched["box_predictor"]
        # Missing flag (state recorded by an older version) -> fall back to the
        # original reassignment behaviour.
        if self.previous_patched.get("box_predictor_owns_inference", True):
            box_predictor.inference = saved_inference
        elif "inference" in vars(box_predictor):
            # Drop the instance-level override so lookups reach the class method again.
            del box_predictor.inference
        keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"]
        mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"]
| |
|