import torch

from .builder import ANCHOR_GENERATORS


@ANCHOR_GENERATORS.register_module()
class PointGenerator:
    """Generate per-location points and validity masks for a feature map.

    Both public methods return results flattened over the feature map with
    the x (width) index varying fastest.
    """

    def _meshgrid(self, x, y, row_major=True):
        """Expand two 1-D tensors into a flattened 2-D grid.

        Args:
            x: 1-D tensor of values along the fast-varying axis.
            y: 1-D tensor of values along the slow-varying axis.
            row_major: If True return ``(xx, yy)``; otherwise ``(yy, xx)``.

        Returns:
            Two 1-D tensors of length ``len(x) * len(y)``. ``xx`` is ``x``
            tiled ``len(y)`` times; ``yy`` holds each element of ``y``
            repeated ``len(x)`` times in a row.
        """
        num_cols = len(x)
        num_rows = len(y)
        tiled_x = x.repeat(num_rows)
        # Turn y into a column, duplicate it across columns, then flatten so
        # every y value appears num_cols times consecutively.
        tiled_y = y.view(-1, 1).repeat(1, num_cols).view(-1)
        if not row_major:
            return tiled_y, tiled_x
        return tiled_x, tiled_y

    def grid_points(self, featmap_size, stride=16, device='cuda'):
        """Build the (x, y, stride) triple for every feature-map location.

        Args:
            featmap_size: ``(feat_h, feat_w)`` of the feature map.
            stride: Multiplier applied to the integer grid indices; also
                stored as the third column of the output.
            device: Device on which the tensors are created.

        Returns:
            Tensor of shape ``(feat_h * feat_w, 3)`` whose columns are the
            scaled x offset, the scaled y offset and the stride.
        """
        feat_h, feat_w = featmap_size
        xs = torch.arange(0., feat_w, device=device) * stride
        ys = torch.arange(0., feat_h, device=device) * stride
        grid_x, grid_y = self._meshgrid(xs, ys)
        # Constant column carrying the stride, created with the same dtype
        # and device as the offsets.
        stride_col = xs.new_full((grid_x.shape[0], ), stride)
        points = torch.stack([grid_x, grid_y, stride_col], dim=-1)
        return points.to(device)

    def valid_flags(self, featmap_size, valid_size, device='cuda'):
        """Flag which feature-map locations fall inside the valid region.

        Args:
            featmap_size: ``(feat_h, feat_w)`` of the feature map.
            valid_size: ``(valid_h, valid_w)``; locations with row index
                below ``valid_h`` and column index below ``valid_w`` are
                flagged. Must not exceed ``featmap_size`` in either
                dimension.
            device: Device on which the mask is created.

        Returns:
            1-D bool tensor of length ``feat_h * feat_w``.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w

        def _axis_mask(length, num_valid):
            # Bool vector whose first ``num_valid`` entries are True.
            mask = torch.zeros(length, dtype=torch.bool, device=device)
            mask[:num_valid] = 1
            return mask

        cols_ok = _axis_mask(feat_w, valid_w)
        rows_ok = _axis_mask(feat_h, valid_h)
        flat_cols, flat_rows = self._meshgrid(cols_ok, rows_ok)
        return flat_cols & flat_rows