from typing import Tuple

import torch


def extract_patches(
    tensor: torch.Tensor,
    required_corners: torch.Tensor,
    ps: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Crop square ``ps`` x ``ps`` patches out of a single image.

    Args:
        tensor: Image of shape ``(c, h, w)``.
        required_corners: ``(n, 2)`` tensor of requested top-left patch
            corners in ``(x, y)`` order. Values are converted to integers
            (truncation) and clamped to ``[0, w - 1 - ps]`` along x and
            ``[0, h - 1 - ps]`` along y. The input tensor is never modified.
        ps: Patch side length in pixels.

    Returns:
        A pair ``(patches, corners)``: ``patches`` has shape
        ``(n, c, ps, ps)`` with ``patches[k, :, i, j] ==
        tensor[:, y_k + i, x_k + j]``; ``corners`` is the ``(n, 2)`` float
        tensor of the clamped ``(x, y)`` corners that were actually used.

    Raises:
        ValueError: If the patch is too large for the image, i.e. the upper
            clamp bound ``w - 1 - ps`` or ``h - 1 - ps`` would be negative.
    """
    c, h, w = tensor.shape
    if w - 1 - ps < 0 or h - 1 - ps < 0:
        # A negative upper bound would produce negative (wrapping) indices.
        raise ValueError(
            f"patch size {ps} is too large for an image of size {h}x{w}"
        )

    corner = required_corners.long()
    # `clamp` is out-of-place, so the caller's tensor is left untouched even
    # when `required_corners` is already int64 (where `.long()` is a no-op).
    x0 = corner[:, 0].clamp(min=0, max=w - 1 - ps)
    y0 = corner[:, 1].clamp(min=0, max=h - 1 - ps)
    corner = torch.stack((x0, y0), dim=1)

    # Broadcast the per-patch offsets against the corners instead of building
    # an explicit meshgrid; this also works for n == 0.
    offset = torch.arange(ps, device=corner.device)
    rows = y0[:, None, None] + offset[:, None]  # (n, ps, 1)
    cols = x0[:, None, None] + offset[None, :]  # (n, 1, ps)
    sampled = tensor[:, rows, cols]  # (c, n, ps, ps)
    return sampled.permute(1, 0, 2, 3), corner.float()


def batch_extract_patches(tensor: torch.Tensor, kpts: torch.Tensor, ps: int):
    """Crop patches around keypoints for every image of a batch.

    For each keypoint the requested top-left corner is
    ``kpt - ps / 2 - 1``; it is then clamped by :func:`extract_patches`.

    Args:
        tensor: Batch of images of shape ``(b, c, h, w)``.
        kpts: Keypoints of shape ``(b, n, 2)`` in ``(x, y)`` order.
        ps: Patch side length in pixels.

    Returns:
        A pair ``(patches, corners)``: ``patches`` has shape
        ``(b, n, c, ps, ps)`` and the dtype of ``tensor``; ``corners`` has
        shape ``(b, n, 2)`` and holds the clamped ``(x, y)`` corners. It uses
        the dtype of ``tensor`` when that is float32/float64 and float32
        otherwise, so integer or half-precision images cannot corrupt the
        coordinates.
    """
    _, c, _, _ = tensor.shape
    b, n, _ = kpts.shape
    if tensor.dtype in (torch.float32, torch.float64):
        corner_dtype = tensor.dtype
    else:
        corner_dtype = torch.float32

    out = torch.zeros((b, n, c, ps, ps), dtype=tensor.dtype, device=tensor.device)
    corners = torch.zeros((b, n, 2), dtype=corner_dtype, device=tensor.device)
    for i in range(b):
        out[i], corners[i] = extract_patches(tensor[i], kpts[i] - ps / 2 - 1, ps)
    return out, corners


def draw_image_patches(img, patches, corners):
    """Paste patches into ``img`` in place at the given corners.

    Args:
        img: Destination batch of shape ``(b, c, h, w)``; modified in place.
        patches: Patches of shape ``(b, n, c, p, p)``.
        corners: ``(b, n, 2)`` top-left corners in ``(x, y)`` order, i.e. the
            first coordinate indexes the last (width) axis of ``img``.
            Values are converted to integers before use.

    Returns:
        None.
    """
    p = patches.shape[-1]
    # One host transfer for all corners instead of one per patch.
    for i, image_corners in enumerate(corners.long().tolist()):
        for k, (x0, y0) in enumerate(image_corners):
            img[i, :, y0 : y0 + p, x0 : x0 + p] = patches[i, k]


def build_heatmap(img, patches, corners):
    """Render patches onto a zero canvas shaped like ``img``.

    Args:
        img: Reference batch; only its shape, dtype and device are used.
        patches: Patches of shape ``(b, n, c, p, p)``.
        corners: ``(b, n, 2)`` top-left ``(x, y)`` corners of the patches.

    Returns:
        A pair ``(hmap, mask)``: ``hmap`` is the canvas with dimension 1
        squeezed (``b x h x w`` when that dimension has size 1) and ``mask``
        is a float tensor that is 1 where ``hmap > 0`` and 0 elsewhere.
    """
    hmap = torch.zeros_like(img)
    draw_image_patches(hmap, patches, corners.long())
    hmap = hmap.squeeze(1)
    return hmap, (hmap > 0.0).float()  # bxhxw