|
|
|
import torch |
|
from torch import nn |
|
from torch.autograd import Function |
|
|
|
from ..utils import ext_loader |
|
|
|
ext_module = ext_loader.load_ext('_ext', [ |
|
'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward', |
|
'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward', |
|
'right_pool_forward', 'right_pool_backward' |
|
]) |
|
|
|
_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3} |
|
|
|
|
|
class TopPoolFunction(Function): |
|
|
|
@staticmethod |
|
def symbolic(g, input): |
|
output = g.op( |
|
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top'])) |
|
return output |
|
|
|
@staticmethod |
|
def forward(ctx, input): |
|
output = ext_module.top_pool_forward(input) |
|
ctx.save_for_backward(input) |
|
return output |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
input, = ctx.saved_tensors |
|
output = ext_module.top_pool_backward(input, grad_output) |
|
return output |
|
|
|
|
|
class BottomPoolFunction(Function): |
|
|
|
@staticmethod |
|
def symbolic(g, input): |
|
output = g.op( |
|
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom'])) |
|
return output |
|
|
|
@staticmethod |
|
def forward(ctx, input): |
|
output = ext_module.bottom_pool_forward(input) |
|
ctx.save_for_backward(input) |
|
return output |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
input, = ctx.saved_tensors |
|
output = ext_module.bottom_pool_backward(input, grad_output) |
|
return output |
|
|
|
|
|
class LeftPoolFunction(Function): |
|
|
|
@staticmethod |
|
def symbolic(g, input): |
|
output = g.op( |
|
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left'])) |
|
return output |
|
|
|
@staticmethod |
|
def forward(ctx, input): |
|
output = ext_module.left_pool_forward(input) |
|
ctx.save_for_backward(input) |
|
return output |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
input, = ctx.saved_tensors |
|
output = ext_module.left_pool_backward(input, grad_output) |
|
return output |
|
|
|
|
|
class RightPoolFunction(Function): |
|
|
|
@staticmethod |
|
def symbolic(g, input): |
|
output = g.op( |
|
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right'])) |
|
return output |
|
|
|
@staticmethod |
|
def forward(ctx, input): |
|
output = ext_module.right_pool_forward(input) |
|
ctx.save_for_backward(input) |
|
return output |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
input, = ctx.saved_tensors |
|
output = ext_module.right_pool_backward(input, grad_output) |
|
return output |
|
|
|
|
|
class CornerPool(nn.Module): |
|
"""Corner Pooling. |
|
|
|
Corner Pooling is a new type of pooling layer that helps a |
|
convolutional network better localize corners of bounding boxes. |
|
|
|
Please refer to https://arxiv.org/abs/1808.01244 for more details. |
|
Code is modified from https://github.com/princeton-vl/CornerNet-Lite. |
|
|
|
Args: |
|
mode(str): Pooling orientation for the pooling layer |
|
|
|
- 'bottom': Bottom Pooling |
|
- 'left': Left Pooling |
|
- 'right': Right Pooling |
|
- 'top': Top Pooling |
|
|
|
Returns: |
|
Feature map after pooling. |
|
""" |
|
|
|
pool_functions = { |
|
'bottom': BottomPoolFunction, |
|
'left': LeftPoolFunction, |
|
'right': RightPoolFunction, |
|
'top': TopPoolFunction, |
|
} |
|
|
|
cummax_dim_flip = { |
|
'bottom': (2, False), |
|
'left': (3, True), |
|
'right': (3, False), |
|
'top': (2, True), |
|
} |
|
|
|
def __init__(self, mode): |
|
super(CornerPool, self).__init__() |
|
assert mode in self.pool_functions |
|
self.mode = mode |
|
self.corner_pool = self.pool_functions[mode] |
|
|
|
def forward(self, x): |
|
if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0': |
|
if torch.onnx.is_in_onnx_export(): |
|
assert torch.__version__ >= '1.7.0', \ |
|
'When `cummax` serves as an intermediate component whose '\ |
|
'outputs is used as inputs for another modules, it\'s '\ |
|
'expected that pytorch version must be >= 1.7.0, '\ |
|
'otherwise Error appears like: `RuntimeError: tuple '\ |
|
'appears in op that does not forward tuples, unsupported '\ |
|
'kind: prim::PythonOp`.' |
|
|
|
dim, flip = self.cummax_dim_flip[self.mode] |
|
if flip: |
|
x = x.flip(dim) |
|
pool_tensor, _ = torch.cummax(x, dim=dim) |
|
if flip: |
|
pool_tensor = pool_tensor.flip(dim) |
|
return pool_tensor |
|
else: |
|
return self.corner_pool.apply(x) |
|
|