Paul Engstler
Initial commit
92f0e98
raw
history blame
No virus
827 Bytes
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : prroi_pool.py
# Author : Jiayuan Mao, Tete Xiao
# Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com
# Date : 07/13/2018
#
# This file is part of PreciseRoIPooling.
# Distributed under terms of the MIT license.
# Copyright (c) 2017 Megvii Technology Limited.
import torch.nn as nn
from .functional import prroi_pool2d
__all__ = ['PrRoIPool2D']
class PrRoIPool2D(nn.Module):
    """Precise RoI Pooling (PrRoI Pooling) over 2-D feature maps.

    Thin ``nn.Module`` wrapper that stores a fixed pooling configuration and
    forwards every call to :func:`prroi_pool2d`.

    Args:
        pooled_height: Height of each pooled output; coerced with ``int()``.
        pooled_width: Width of each pooled output; coerced with ``int()``.
        spatial_scale: Scale factor passed through to ``prroi_pool2d``;
            coerced with ``float()``. Presumably maps RoI coordinates onto
            the feature-map grid (e.g. ``1 / stride``) -- TODO confirm
            against ``functional.prroi_pool2d``.
    """

    def __init__(self, pooled_height, pooled_width, spatial_scale):
        super().__init__()
        # Coerce once here so the functional op always receives plain
        # Python int/float values, whatever numeric type the caller passed.
        self.pooled_height = int(pooled_height)
        self.pooled_width = int(pooled_width)
        self.spatial_scale = float(spatial_scale)

    def forward(self, features, rois):
        """Pool every region in ``rois`` from ``features``.

        Args:
            features: Input feature tensor; assumed (N, C, H, W) -- TODO
                confirm against ``functional.prroi_pool2d``.
            rois: Region tensor; assumed (K, 5) rows of
                ``[batch_index, x0, y0, x1, y1]`` -- TODO confirm against
                ``functional.prroi_pool2d``.

        Returns:
            The result of ``prroi_pool2d`` for this module's configuration.
        """
        return prroi_pool2d(features, rois, self.pooled_height,
                            self.pooled_width, self.spatial_scale)

    def extra_repr(self):
        """Return the pooling configuration shown inside ``repr(module)``."""
        return 'pooled_height={}, pooled_width={}, spatial_scale={}'.format(
            self.pooled_height, self.pooled_width, self.spatial_scale)