Spaces:
Build error
Build error
| import torch | |
def extract_patches(
    tensor: torch.Tensor,
    required_corners: torch.Tensor,
    ps: int,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Crop square ``ps`` x ``ps`` patches out of a single image tensor.

    Args:
        tensor: Image of shape ``(c, h, w)``.
        required_corners: ``(n, 2)`` requested top-left patch corners given as
            ``(x, y)``. Values are truncated to integers and then clamped to
            ``[0, w - 1 - ps]`` / ``[0, h - 1 - ps]``. The input tensor is
            never modified.
        ps: Patch side length in pixels.

    Returns:
        A ``(patches, corners)`` tuple. ``patches`` has shape
        ``(n, c, ps, ps)`` with
        ``patches[k, :, a, b] == tensor[:, y_k + a, x_k + b]``; ``corners`` is
        the ``(n, 2)`` float tensor of clamped ``(x, y)`` corners that were
        actually used.
    """
    _, h, w = tensor.shape

    # Clamp out-of-place: ``.long()`` returns the input itself when it is
    # already int64, so an in-place clamp would overwrite the caller's data.
    corner_x = required_corners[:, 0].long().clamp(min=0, max=w - 1 - ps)
    corner_y = required_corners[:, 1].long().clamp(min=0, max=h - 1 - ps)

    # Broadcast the per-patch pixel offsets against the corners instead of
    # building a meshgrid; this avoids any torch-version-dependent keyword
    # and keeps every shape explicit, so n == 0 works as well.
    offset = torch.arange(ps, device=required_corners.device)
    rows = corner_y[:, None, None] + offset[None, :, None]  # (n, ps, 1)
    cols = corner_x[:, None, None] + offset[None, None, :]  # (n, 1, ps)

    # Advanced indexing broadcasts rows/cols to (n, ps, ps) -> (c, n, ps, ps).
    sampled = tensor[:, rows, cols]

    corner = torch.stack((corner_x, corner_y), dim=-1)
    return sampled.permute(1, 0, 2, 3), corner.float()
def batch_extract_patches(tensor: torch.Tensor, kpts: torch.Tensor, ps: int):
    """Extract ``ps`` x ``ps`` patches for every image of a batch.

    Args:
        tensor: Batch of images, shape ``(b, c, h, w)``.
        kpts: Per-image points, shape ``(b, n, 2)``. The corner requested for
            each point is ``kpt - ps / 2 - 1``.
        ps: Patch side length in pixels.

    Returns:
        ``(out, corners)``: ``out`` has shape ``(b, n, c, ps, ps)`` and
        ``corners`` has shape ``(b, n, 2)``; both share the dtype and device
        of ``tensor``.
    """
    _, channels, _, _ = tensor.shape
    batch, num_kpts, _ = kpts.shape

    # new_zeros inherits dtype and device from the image batch.
    out = tensor.new_zeros((batch, num_kpts, channels, ps, ps))
    corners = tensor.new_zeros((batch, num_kpts, 2))

    for idx in range(batch):
        requested = kpts[idx] - ps / 2 - 1
        image_patches, image_corners = extract_patches(tensor[idx], requested, ps)
        out[idx] = image_patches
        corners[idx] = image_corners
    return out, corners
def draw_image_patches(img, patches, corners):
    """Paste patches into ``img`` in place at the given top-left corners.

    Args:
        img: Destination batch of shape ``(b, c, h, w)``; modified in place.
        patches: Patches of shape ``(b, n, c, ph, pw)``.
        corners: ``(b, n, 2)`` top-left corners as ``(x, y)``, i.e.
            ``(column, row)``. Integer or float tensors are accepted; values
            are truncated to integers.

    Patches are written in order, so where two patches of the same image
    overlap the later one wins.
    """
    patch_h, patch_w = patches.shape[-2:]

    # Convert once to plain Python ints: float-valued corners then work as
    # slice bounds, and each patch no longer needs its own tensor -> index
    # conversion.
    for i, image_corners in enumerate(corners.long().tolist()):
        for k, (x, y) in enumerate(image_corners):
            img[i, :, y : y + patch_h, x : x + patch_w] = patches[i, k]
def build_heatmap(img, patches, corners):
    """Draw patches onto a blank canvas shaped like ``img``.

    The canvas is zero-initialised, filled through ``draw_image_patches``
    using ``corners`` cast to int64, and dimension 1 is squeezed afterwards
    (presumably ``img`` is single-channel, ``(b, 1, h, w)`` - confirm with
    callers; ``squeeze(1)`` is a no-op otherwise).

    Returns:
        ``(heatmap, mask)`` where ``mask`` is a float tensor that is 1 where
        the heatmap is strictly positive and 0 elsewhere.
    """
    canvas = torch.zeros_like(img)
    integer_corners = corners.long()
    draw_image_patches(canvas, patches, integer_corners)

    heatmap = canvas.squeeze(1)
    positive_mask = (heatmap > 0.0).float()
    return heatmap, positive_mask  # bxhxw