from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from kornia.color import rgb_to_grayscale
from kornia.feature import LocalFeatureMatcher, LoFTR
from kornia.geometry.homography import find_homography_dlt_iterated
from kornia.geometry.ransac import RANSAC
from kornia.geometry.transform import warp_perspective


class ImageStitcher(nn.Module):
"""Stitch two images with overlapping fields of view.
Args:
matcher: image feature matching module.
estimator: method to compute homography, either "vanilla" or "ransac".
"ransac" is slower with a better accuracy.
blending_method: method to blend two images together.
Only "naive" is currently supported.
Note:
Current implementation requires strict image ordering from left to right.
.. code-block:: python
IS = ImageStitcher(KF.LoFTR(pretrained='outdoor'), estimator='ransac').cuda()
# Compute the stitched result with less GPU memory cost.
with torch.inference_mode():
out = IS(img_left, img_right)
# Show the result
plt.imshow(K.tensor_to_image(out))
"""

    def __init__(
self, matcher: nn.Module, estimator: str = 'ransac', blending_method: str = "naive",
) -> None:
super().__init__()
self.matcher = matcher
self.estimator = estimator
self.blending_method = blending_method
if estimator not in ['ransac', 'vanilla']:
raise NotImplementedError(f"Unsupported estimator {estimator}. Use ‘ransac’ or ‘vanilla’ instead.")
if estimator == "ransac":
self.ransac = RANSAC('homography')

    def _estimate_homography(self, keypoints1: torch.Tensor, keypoints2: torch.Tensor) -> torch.Tensor:
"""Estimate homography by the matched keypoints.
Args:
keypoints1: matched keypoint set from an image, shaped as :math:`(N, 2)`.
keypoints2: matched keypoint set from the other image, shaped as :math:`(N, 2)`.
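
        Returns:
            The estimated homography with shape :math:`(1, 3, 3)`, mapping coordinates
            from ``keypoints2`` (the second image) into the frame of ``keypoints1``.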
"""
homo: torch.Tensor
if self.estimator == "vanilla":
homo = find_homography_dlt_iterated(
keypoints2[None],
keypoints1[None],
torch.ones_like(keypoints1[None, :, 0])
)
elif self.estimator == "ransac":
homo, _ = self.ransac(keypoints2, keypoints1)
homo = homo[None]
else:
raise NotImplementedError(f"Unsupported estimator {self.estimator}. Use ‘ransac’ or ‘vanilla’ instead.")
return homo

    def estimate_transform(self, **kwargs: torch.Tensor) -> torch.Tensor:
        """Compute the homography for each image pair from the matcher output."""
homos: List[torch.Tensor] = []
kp1, kp2, idx = kwargs['keypoints0'], kwargs['keypoints1'], kwargs['batch_indexes']
for i in range(len(idx.unique())):
homos.append(self._estimate_homography(kp1[idx == i], kp2[idx == i]))
if len(homos) == 0:
raise RuntimeError("Compute homography failed. No matched keypoints found.")
return torch.cat(homos)

    def blend_image(self, src_img: torch.Tensor, dst_img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Blend two images together."""
out: torch.Tensor
if self.blending_method == "naive":
out = torch.where(mask == 1, src_img, dst_img)
else:
raise NotImplementedError(f"Unsupported blending method {self.blending_method}. Use ‘naive’.")
return out

    def preprocess(self, image_1: torch.Tensor, image_2: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Preprocess input to the required format."""
# TODO: probably perform histogram matching here.
        if isinstance(self.matcher, (LoFTR, LocalFeatureMatcher)):
            input_dict: Dict[str, torch.Tensor] = {  # LoFTR works on grayscale images only
"image0": rgb_to_grayscale(image_1),
"image1": rgb_to_grayscale(image_2)
}
else:
raise NotImplementedError(f"The preprocessor for {self.matcher} has not been implemented.")
return input_dict

    def postprocess(self, image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Crop the redundant zero-mask region on the right of the stitched image.

        Note:
            Assumes no batch mode. This method keeps all valid regions after stitching.
        """
mask_: torch.Tensor = mask.sum((0, 1))
index: int = int(mask_.bool().any(0).long().argmin().item())
if index == 0: # If no redundant space
return image
return image[..., :index]

    def on_matcher(self, data) -> dict:
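        """Run the wrapped feature matcher on the preprocessed input."""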
return self.matcher(data)

    def stitch_pair(
self, images_left: torch.Tensor, images_right: torch.Tensor,
mask_left: Optional[torch.Tensor] = None, mask_right: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
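        """Warp the right image into the left image's frame and blend the pair.

        Returns the blended image and its validity mask, both with the combined
        width of the two inputs (the empty right margin is cropped later in
        ``postprocess``).
        """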
# Compute the transformed images
input_dict: Dict[str, torch.Tensor] = self.preprocess(images_left, images_right)
out_shape: Tuple[int, int] = (images_left.shape[-2], images_left.shape[-1] + images_right.shape[-1])
correspondences: dict = self.on_matcher(input_dict)
homo: torch.Tensor = self.estimate_transform(**correspondences)
src_img = warp_perspective(images_right, homo, out_shape)
dst_img = torch.cat([images_left, torch.zeros_like(images_right)], dim=-1)
# Compute the transformed masks
if mask_left is None:
mask_left = torch.ones_like(images_left)
if mask_right is None:
mask_right = torch.ones_like(images_right)
        # Use 'nearest' so the warped mask stays binary (no interpolated values).
src_mask = warp_perspective(mask_right, homo, out_shape, mode='nearest')
dst_mask = torch.cat([mask_left, torch.zeros_like(mask_right)], dim=-1)
return self.blend_image(src_img, dst_img, src_mask), (dst_mask + src_mask).bool().to(src_mask.dtype)

    def forward(self, *imgs: torch.Tensor) -> torch.Tensor:
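        """Stitch the given images from left to right into a single panorama."""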
img_out = imgs[0]
mask_left = torch.ones_like(img_out)
for i in range(len(imgs) - 1):
img_out, mask_left = self.stitch_pair(img_out, imgs[i + 1], mask_left)
return self.postprocess(img_out, mask_left)
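

if __name__ == "__main__":
    # Usage sketch (illustrative only, not part of the kornia API): stitch two
    # overlapping RGB images with a pretrained LoFTR matcher. Inputs are assumed
    # to be (1, 3, H, W) tensors with values in [0, 1], ordered left to right.
    # The random tensors below are placeholders so the snippet is self-contained;
    # substitute real overlapping photographs, otherwise the matcher finds no
    # correspondences and estimate_transform raises a RuntimeError.
    stitcher = ImageStitcher(LoFTR(pretrained="outdoor"), estimator="ransac")
    img_left = torch.rand(1, 3, 480, 640)   # placeholder; use a real photograph
    img_right = torch.rand(1, 3, 480, 640)  # placeholder; use a real photograph

    with torch.inference_mode():  # lower memory footprint during stitching
        panorama = stitcher(img_left, img_right)
    print(panorama.shape)  # (1, 3, 480, W); empty right-hand columns are cropped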