Robert001's picture
first commit
b334e29
raw
history blame
No virus
1.36 kB
import torch
from .builder import ANCHOR_GENERATORS
@ANCHOR_GENERATORS.register_module()
class PointGenerator(object):
    """Build flattened point grids and validity masks for a feature map.

    The class keeps no state; every method derives its result purely from
    its arguments.
    """

    def _meshgrid(self, x, y, row_major=True):
        """Expand two 1-D tensors into a pair of flattened grid tensors.

        Args:
            x: 1-D tensor; it is tiled as a whole ``len(y)`` times.
            y: 1-D tensor; each element is repeated ``len(x)`` times in a row.
            row_major: When true the result is ``(xx, yy)``; otherwise the
                two tensors are swapped and ``(yy, xx)`` is returned.

        Returns:
            A tuple of two 1-D tensors, each of length ``len(x) * len(y)``.
        """
        tiled_x = x.repeat(len(y))
        # Turn y into a column, copy it across len(x) columns, then flatten
        # so that every y value forms a contiguous run of length len(x).
        column_y = y.view(-1, 1)
        stretched_y = column_y.repeat(1, len(x)).view(-1)
        return (tiled_x, stretched_y) if row_major else (stretched_y, tiled_x)

    def grid_points(self, featmap_size, stride=16, device='cuda'):
        """Return one ``(x, y, stride)`` row per feature-map location.

        Args:
            featmap_size: Pair unpacked as ``(feat_h, feat_w)``.
            stride: Scalar multiplied with every grid index to obtain the
                coordinates; it is also written into the third column.
            device: Device on which the tensors are created.

        Returns:
            A tensor of shape ``(feat_h * feat_w, 3)`` whose columns are the
            x offset, the y offset and the stride. The x offset changes
            fastest along the rows.
        """
        height, width = featmap_size
        offsets_x = torch.arange(0., width, device=device) * stride
        offsets_y = torch.arange(0., height, device=device) * stride
        grid_x, grid_y = self._meshgrid(offsets_x, offsets_y)
        # A constant column carrying the stride, matching offsets_x in dtype
        # and device.
        stride_column = offsets_x.new_full((grid_x.shape[0], ), stride)
        points = torch.stack([grid_x, grid_y, stride_column], dim=-1)
        return points.to(device)

    def valid_flags(self, featmap_size, valid_size, device='cuda'):
        """Return a flattened boolean mask of the valid feature-map region.

        Args:
            featmap_size: Pair unpacked as ``(feat_h, feat_w)``.
            valid_size: Pair unpacked as ``(valid_h, valid_w)``; neither
                value may exceed its ``featmap_size`` counterpart.
            device: Device on which the mask is created.

        Returns:
            A 1-D bool tensor of length ``feat_h * feat_w`` that is true
            where the column index is below ``valid_w`` and the row index is
            below ``valid_h``.

        Raises:
            AssertionError: If the valid size exceeds the feature-map size.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        cols_ok = torch.zeros(feat_w, dtype=torch.bool, device=device)
        rows_ok = torch.zeros(feat_h, dtype=torch.bool, device=device)
        cols_ok[:valid_w] = True
        rows_ok[:valid_h] = True
        flat_cols, flat_rows = self._meshgrid(cols_ok, rows_ok)
        return flat_cols & flat_rows