Commit e44f283 (1 parent: 42d9aa3)
initial commit

Files changed:
- .gitattributes +1 -0
- LICENSE +31 -0
- app.py +268 -0
- packages.txt +4 -0
- requirements.txt +7 -0
- ripe/__init__.py +1 -0
- ripe/benchmarks/imw_2020.py +320 -0
- ripe/data/__init__.py +0 -0
- ripe/data/data_transforms.py +204 -0
- ripe/data/datasets/__init__.py +0 -0
- ripe/data/datasets/acdc.py +154 -0
- ripe/data/datasets/dataset_combinator.py +88 -0
- ripe/data/datasets/disk_imw.py +160 -0
- ripe/data/datasets/disk_megadepth.py +157 -0
- ripe/data/datasets/tokyo247.py +134 -0
- ripe/losses/__init__.py +0 -0
- ripe/losses/contrastive_loss.py +88 -0
- ripe/matcher/__init__.py +0 -0
- ripe/matcher/concurrent_matcher.py +97 -0
- ripe/matcher/pose_estimator_poselib.py +31 -0
- ripe/model_zoo/__init__.py +1 -0
- ripe/model_zoo/vgg_hyper.py +39 -0
- ripe/models/__init__.py +0 -0
- ripe/models/backbones/__init__.py +0 -0
- ripe/models/backbones/backbone_base.py +61 -0
- ripe/models/backbones/vgg.py +99 -0
- ripe/models/backbones/vgg_utils.py +143 -0
- ripe/models/ripe.py +303 -0
- ripe/models/upsampler/hypercolumn_features.py +54 -0
- ripe/models/upsampler/interpolate_sparse2d.py +37 -0
- ripe/scheduler/__init__.py +0 -0
- ripe/scheduler/constant.py +6 -0
- ripe/scheduler/expDecay.py +26 -0
- ripe/scheduler/linearLR.py +37 -0
- ripe/scheduler/linear_with_plateaus.py +44 -0
- ripe/train.py +410 -0
- ripe/utils/__init__.py +2 -0
- ripe/utils/image_utils.py +62 -0
- ripe/utils/pose_error.py +62 -0
- ripe/utils/pylogger.py +32 -0
- ripe/utils/utils.py +192 -0
- ripe/utils/wandb_utils.py +16 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,31 @@
Software Copyright License for Academic Use of RIPE, Version 2.0

© Copyright (2025) Fraunhofer-Gesellschaft zur Förderung der angewandten Forschung e.V.

1. INTRODUCTION

RIPE which means any source code, object code or binary files provided by Fraunhofer excluding third party software and materials, is made available under this Software Copyright License.

2. COPYRIGHT LICENSE

Internal use of RIPE, in source and binary forms, with or without modification, is permitted without payment of copyright license fees for non-commercial purposes of evaluation, testing and academic research.

No right or license, express or implied, is granted to any part of RIPE except and solely to the extent as expressly set forth herein. Any commercial use or exploitation of RIPE and/or any modifications thereto under this license are prohibited.

For any other use of RIPE than permitted by this software copyright license You need another license from Fraunhofer. In such case please contact Fraunhofer under the CONTACT INFORMATION below.

3. LIMITED PATENT LICENSE

If Fraunhofer patents are implemented by RIPE their use is permitted for internal non-commercial purposes of evaluation, testing and academic research. No patent grant is provided for any other use, including but not limited to commercial use or exploitation.

Fraunhofer provides no warranty of patent non-infringement with respect to RIPE.

4. DISCLAIMER

RIPE is provided by Fraunhofer "AS IS" and WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES, including but not limited to the implied warranties of fitness for a particular purpose. IN NO EVENT SHALL FRAUNHOFER BE LIABLE for any direct, indirect, incidental, special, exemplary, or consequential damages, including but not limited to procurement of substitute goods or services; loss of use, data, or profits, or business interruption, however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence), arising in any way out of the use of the Fraunhofer Software, even if advised of the possibility of such damage.

5. CONTACT INFORMATION

Fraunhofer-Institut für Nachrichtentechnik, Heinrich-Hertz-Institut, HHI
Einsteinufer 37, 10587 Berlin, Germany
info@hhi.fraunhofer.de
app.py ADDED
@@ -0,0 +1,268 @@
# This is a small gradio interface to access our RIPE keypoint extractor.
# You can either upload two images or use one of the example image pairs.

import os

import gradio as gr
from PIL import Image

from ripe import vgg_hyper

SEED = 32000
os.environ["PYTHONHASHSEED"] = str(SEED)

import random
from pathlib import Path

import numpy as np
import torch

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
import cv2
import kornia.feature as KF
import kornia.geometry as KG

from ripe.utils.utils import cv2_matches_from_kornia, to_cv_kpts

MIN_SIZE = 512
MAX_SIZE = 768

description_text = """
<p align='center'>
<h1 align='center'>🌊🌺 ICCV 2025 🌺🌊</h1>
<p align='center'>
<a href='https://scholar.google.com/citations?user=ybMR38kAAAAJ'>Johannes Künzel</a> ·
<a href='https://scholar.google.com/citations?user=5yTuyGIAAAAJ'>Anna Hilsmann</a> ·
<a href='https://scholar.google.com/citations?user=BCElyCkAAAAJ'>Peter Eisert</a>
</p>
<h2 align='center'>
<a href='https://arxiv.org/abs/2507.04839'>Arxiv</a> |
<a href='https://fraunhoferhhi.github.io/RIPE/'>Project Page</a> |
<a href='https://github.com/fraunhoferhhi/RIPE'>Code</a>
</h2>
</p>

<br/>
<div align='center'>

### This demo showcases our new keypoint extractor model, RIPE (Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction).

### RIPE is trained without requiring pose or depth supervision or artificial augmentations. By leveraging reinforcement learning, it learns to extract keypoints solely based on whether an image pair depicts the same scene or not.

### For more detailed information, please refer to our [paper](link to be added).

The demo code extracts the 2048-top keypoints from the two input images. It uses the mutual nearest neighbor (MNN) descriptor matcher from kornia to find matches between the two images.
If the number of matches is greater than 8, it applies RANSAC to filter out outliers based on the inlier threshold provided by the user.
Images are resized to fit within a maximum size of 2048x2048 pixels with maintained aspect ratio.

</div>
"""

model = vgg_hyper("./weights_ripe.pth")


def get_new_image_size(image, min_size=1600, max_size=2048):
    """
    Get a new size for the image that is scaled to fit between min_size and max_size while maintaining the aspect ratio.

    Args:
        image (PIL.Image): Input image.
        min_size (int): Minimum allowed size for width and height.
        max_size (int): Maximum allowed size for width and height.

    Returns:
        tuple: New size (width, height) for the image.
    """
    width, height = image.size

    aspect_ratio = width / height
    if width > height:
        new_width = max(min_size, min(max_size, width))
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = max(min_size, min(max_size, height))
        new_width = int(new_height * aspect_ratio)

    new_size = (new_width, new_height)

    return new_size


def extract_keypoints(image1, image2, inl_th):
    """
    Extract keypoints from two input images using the RIPE model.

    Args:
        image1 (PIL.Image): First input image.
        image2 (PIL.Image): Second input image.
        inl_th (float): RANSAC inlier threshold.

    Returns:
        dict: A dictionary containing keypoints and matches.
    """
    log_text = "Extracting keypoints and matches with RIPE\n"

    log_text += f"Image 1 size: {image1.size}\n"
    log_text += f"Image 2 size: {image2.size}\n"

    # check not larger than 2048x2048
    new_size = get_new_image_size(image1, min_size=MIN_SIZE, max_size=MAX_SIZE)
    image1 = image1.resize(new_size)

    new_size = get_new_image_size(image2, min_size=MIN_SIZE, max_size=MAX_SIZE)
    image2 = image2.resize(new_size)

    log_text += f"Resized Image 1 size: {image1.size}\n"
    log_text += f"Resized Image 2 size: {image2.size}\n"

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(dev)

    image1 = image1.convert("RGB")
    image2 = image2.convert("RGB")

    image1_original = image1.copy()
    image2_original = image2.copy()

    # convert PIL images to numpy arrays
    image1_original = np.array(image1_original)
    image2_original = np.array(image2_original)

    # convert PIL images to tensors
    image1 = torch.tensor(np.array(image1)).permute(2, 0, 1).float() / 255.0
    image2 = torch.tensor(np.array(image2)).permute(2, 0, 1).float() / 255.0

    image1 = image1.to(dev).unsqueeze(0)  # Add batch dimension
    image2 = image2.to(dev).unsqueeze(0)  # Add batch dimension

    kpts_1, desc_1, score_1 = model.detectAndCompute(image1, threshold=0.5, top_k=2048)
    kpts_2, desc_2, score_2 = model.detectAndCompute(image2, threshold=0.5, top_k=2048)

    log_text += f"Number of keypoints in image 1: {kpts_1.shape[0]}\n"
    log_text += f"Number of keypoints in image 2: {kpts_2.shape[0]}\n"

    matcher = KF.DescriptorMatcher("mnn")  # threshold is not used with mnn
    match_dists, match_idxs = matcher(desc_1, desc_2)

    log_text += f"Number of MNN matches: {match_idxs.shape[0]}\n"

    cv2_matches = cv2_matches_from_kornia(match_dists, match_idxs)

    do_ransac = match_idxs.shape[0] > 8

    if do_ransac:
        matched_pts_1 = kpts_1[match_idxs[:, 0]]
        matched_pts_2 = kpts_2[match_idxs[:, 1]]

        H, mask = KG.ransac.RANSAC(model_type="fundamental", inl_th=inl_th)(matched_pts_1, matched_pts_2)
        matchesMask = mask.int().ravel().tolist()

        log_text += f"RANSAC found {mask.sum().item()} inliers out of {mask.shape[0]} matches with an inlier threshold of {inl_th}.\n"
    else:
        log_text += "Not enough matches for RANSAC, skipping RANSAC step.\n"

    kpts_1 = to_cv_kpts(kpts_1, score_1)
    kpts_2 = to_cv_kpts(kpts_2, score_2)

    keypoints_raw_1 = cv2.drawKeypoints(image1_original, kpts_1, image1_original, color=(0, 255, 0))
    keypoints_raw_2 = cv2.drawKeypoints(image2_original, kpts_2, image2_original, color=(0, 255, 0))

    # pad height smaller image to match the height of the larger image
    if keypoints_raw_1.shape[0] < keypoints_raw_2.shape[0]:
        pad_height = keypoints_raw_2.shape[0] - keypoints_raw_1.shape[0]
        keypoints_raw_1 = np.pad(
            keypoints_raw_1, ((0, pad_height), (0, 0), (0, 0)), mode="constant", constant_values=255
        )
    elif keypoints_raw_1.shape[0] > keypoints_raw_2.shape[0]:
        pad_height = keypoints_raw_1.shape[0] - keypoints_raw_2.shape[0]
        keypoints_raw_2 = np.pad(
            keypoints_raw_2, ((0, pad_height), (0, 0), (0, 0)), mode="constant", constant_values=255
        )

    # concatenate keypoints images horizontally
    keypoints_raw = np.concatenate((keypoints_raw_1, keypoints_raw_2), axis=1)
    keypoints_raw_pil = Image.fromarray(keypoints_raw)

    result_raw = cv2.drawMatches(
        image1_original,
        kpts_1,
        image2_original,
        kpts_2,
        cv2_matches,
        None,
        matchColor=(0, 255, 0),
        matchesMask=None,
        # matchesMask=None,
        flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
    )

    if not do_ransac:
        result_ransac = None
    else:
        result_ransac = cv2.drawMatches(
            image1_original,
            kpts_1,
            image2_original,
            kpts_2,
            cv2_matches,
            None,
            matchColor=(0, 255, 0),
            matchesMask=matchesMask,
            singlePointColor=(0, 0, 255),
            flags=cv2.DrawMatchesFlags_DEFAULT,
        )

    # result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB for display

    # convert to PIL Image
    result_raw_pil = Image.fromarray(result_raw)
    if result_ransac is not None:
        result_ransac_pil = Image.fromarray(result_ransac)
    else:
        result_ransac_pil = None

    return log_text, result_ransac_pil, result_raw_pil, keypoints_raw_pil


demo = gr.Interface(
    fn=extract_keypoints,
    inputs=[
        gr.Image(type="pil", label="Image 1"),
        gr.Image(type="pil", label="Image 2"),
        gr.Slider(
            minimum=0.1,
            maximum=3.0,
            step=0.1,
            value=0.5,
            label="RANSAC inlier threshold",
            info="Threshold for RANSAC inlier detection. Lower values may yield fewer inliers but more robust matches.",
        ),
    ],
    outputs=[
        gr.Textbox(type="text", label="Log"),
        gr.Image(type="pil", label="Keypoints and Matches (RANSAC)"),
        gr.Image(type="pil", label="Keypoints and Matches"),
        gr.Image(type="pil", label="Keypoint Detection Results"),
    ],
    title="RIPE: Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction",
    description=description_text,
    examples=[
        [
            "assets_gradio/all_souls_000013.jpg",
            "assets_gradio/all_souls_000055.jpg",
        ],
        [
            "assets_gradio/167170681_0e5c42fd21_o.jpg",
            "assets_gradio/170804731_6bf4fbecd4_o.jpg",
        ],
        [
            "assets_gradio/4171014767_0fe879b783_o.jpg",
            "assets_gradio/4174108353_20422632d6_o.jpg",
        ],
    ],
    flagging_mode="never",
    theme="default",
)
demo.launch()
packages.txt ADDED
@@ -0,0 +1,4 @@
libeigen3-dev
cmake
build-essential
python3-opencv
requirements.txt ADDED
@@ -0,0 +1,7 @@
setuptools
poselib @ git+https://github.com/PoseLib/PoseLib.git@56d158f744d3561b0b70174e6d8ca9a7fc9bd9c1
opencv-python
kornia
numpy
torch
torchvision
ripe/__init__.py ADDED
@@ -0,0 +1 @@
from .model_zoo import vgg_hyper  # noqa: F401
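For orientation, the vgg_hyper entry point exported here is what app.py builds on. The following is a minimal, non-Gradio sketch of the same pipeline; the weight path and image file names are placeholders, and the detectAndCompute keyword arguments simply mirror the demo code above.

# Minimal sketch (assumes weights_ripe.pth and two local images exist).
import numpy as np
import torch
from PIL import Image
import kornia.feature as KF

from ripe import vgg_hyper

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vgg_hyper("./weights_ripe.pth").to(dev)

def load(path):
    # HWC uint8 image -> 1xCxHxW float tensor in [0, 1], as in app.py
    img = np.array(Image.open(path).convert("RGB"))
    return torch.tensor(img).permute(2, 0, 1).float().div(255.0).unsqueeze(0).to(dev)

img1, img2 = load("image_a.jpg"), load("image_b.jpg")
kpts_1, desc_1, score_1 = model.detectAndCompute(img1, threshold=0.5, top_k=2048)
kpts_2, desc_2, score_2 = model.detectAndCompute(img2, threshold=0.5, top_k=2048)

# mutual-nearest-neighbour descriptor matching, as in the demo
match_dists, match_idxs = KF.DescriptorMatcher("mnn")(desc_1, desc_2)
print(f"{kpts_1.shape[0]} / {kpts_2.shape[0]} keypoints, {match_idxs.shape[0]} MNN matches")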
ripe/benchmarks/imw_2020.py ADDED
@@ -0,0 +1,320 @@
import os
from pathlib import Path

import cv2
import kornia.feature as KF
import matplotlib.pyplot as plt
import numpy as np
import poselib
import torch
from tqdm import tqdm

from ripe import utils
from ripe.data.data_transforms import Compose, Normalize, Resize
from ripe.data.datasets.disk_imw import DISK_IMW
from ripe.utils.pose_error import AUCMetric, relative_pose_error
from ripe.utils.utils import (
    cv2_matches_from_kornia,
    cv_resize_and_pad_to_shape,
    to_cv_kpts,
)

log = utils.get_pylogger(__name__)


class IMW_2020_Benchmark:
    def __init__(
        self,
        use_predefined_subset: bool = True,
        conf_inference=None,
        edge_input_divisible_by=None,
    ):
        data_dir = os.getenv("DATA_DIR")
        if data_dir is None:
            raise ValueError("Environment variable DATA_DIR is not set.")
        root_path = Path(data_dir) / "disk-data"

        self.data = DISK_IMW(
            str(
                root_path
            ),  # Resize only to ensure that the input size is divisible the value of edge_input_divisible_by
            transforms=Compose(
                [
                    Resize(None, edge_input_divisible_by),
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ]
            ),
        )
        self.ids_subset = None
        self.results = []
        self.conf_inference = conf_inference

        # fmt: off
        if use_predefined_subset:
            self.ids_subset = [4921, 3561, 3143, 6040, 802, 6828, 5338, 9275, 10764, 10085, 5124, 11355, 7, 10027, 2161, 4433, 6887, 3311, 10766,
                               11451, 11433, 8539, 2581, 10300, 10562, 1723, 8803, 6275, 10140, 11487, 6238, 638, 8092, 9979, 201, 10394, 3414,
                               9002, 7456, 2431, 632, 6589, 9265, 9889, 3139, 7890, 10619, 4899, 675, 176, 4309, 4814, 3833, 3519, 148, 4560, 10705,
                               3744, 1441, 4049, 1791, 5106, 575, 1540, 1105, 6791, 1383, 9344, 501, 2504, 4335, 8992, 10970, 10786, 10405, 9317,
                               5279, 1396, 5044, 9408, 11125, 10417, 7627, 7480, 1358, 7738, 5461, 10178, 9226, 8106, 2766, 6216, 4032, 7298, 259,
                               3021, 2645, 8756, 7513, 3163, 2510, 6701, 6684, 3159, 9689, 7425, 6066, 1904, 6382, 3052, 777, 6277, 7409, 5997, 2987,
                               11316, 2894, 4528, 1927, 10366, 8605, 2726, 1886, 2416, 2164, 3352, 2997, 6636, 6765, 5609, 3679, 76, 10956, 3612, 6699,
                               1741, 8811, 3755, 1285, 9520, 2476, 3977, 370, 9823, 1834, 7551, 6227, 7303, 6399, 4758, 10713, 5050, 380, 11056, 7620,
                               4826, 6090, 9011, 7523, 7355, 8021, 9801, 1801, 6522, 7138, 10017, 8732, 6402, 3116, 4031, 6088, 3975, 9841, 9082, 9412,
                               5406, 217, 2385, 8791, 8361, 494, 4319, 5275, 3274, 335, 6731, 207, 10095, 3068, 5996, 3951, 2808, 5877, 6134, 7772, 10042,
                               8574, 5501, 10885, 7871]
            # self.ids_subset = self.ids_subset[:10]
        # fmt: on

    def evaluate_sample(self, model, sample, dev):
        img_1 = sample["src_image"].unsqueeze(0).to(dev)
        img_2 = sample["trg_image"].unsqueeze(0).to(dev)

        scale_h_1, scale_w_1 = (
            sample["orig_size_src"][0] / img_1.shape[2],
            sample["orig_size_src"][1] / img_1.shape[3],
        )
        scale_h_2, scale_w_2 = (
            sample["orig_size_trg"][0] / img_2.shape[2],
            sample["orig_size_trg"][1] / img_2.shape[3],
        )

        M = None
        info = {}
        kpts_1, desc_1, score_1 = None, None, None
        kpts_2, desc_2, score_2 = None, None, None
        match_dists, match_idxs = None, None

        try:
            kpts_1, desc_1, score_1 = model.detectAndCompute(img_1, **self.conf_inference)
            kpts_2, desc_2, score_2 = model.detectAndCompute(img_2, **self.conf_inference)

            if kpts_1.dim() == 3:
                assert kpts_1.shape[0] == 1 and kpts_2.shape[0] == 1, "Batch size must be 1"

                kpts_1, desc_1, score_1 = (
                    kpts_1.squeeze(0),
                    desc_1[0].squeeze(0),
                    score_1[0].squeeze(0),
                )
                kpts_2, desc_2, score_2 = (
                    kpts_2.squeeze(0),
                    desc_2[0].squeeze(0),
                    score_2[0].squeeze(0),
                )

            scale_1 = torch.tensor([scale_w_1, scale_h_1], dtype=torch.float).to(dev)
            scale_2 = torch.tensor([scale_w_2, scale_h_2], dtype=torch.float).to(dev)

            kpts_1 = kpts_1 * scale_1
            kpts_2 = kpts_2 * scale_2

            matcher = KF.DescriptorMatcher("mnn")  # threshold is not used with mnn
            match_dists, match_idxs = matcher(desc_1, desc_2)

            matched_pts_1 = kpts_1[match_idxs[:, 0]]
            matched_pts_2 = kpts_2[match_idxs[:, 1]]

            camera_1 = sample["src_camera"]
            camera_2 = sample["trg_camera"]

            M, info = poselib.estimate_relative_pose(
                matched_pts_1.cpu().numpy(),
                matched_pts_2.cpu().numpy(),
                camera_1.to_cameradict(),
                camera_2.to_cameradict(),
                {
                    "max_epipolar_error": 0.5,
                },
                {},
            )
        except RuntimeError as e:
            if "No keypoints detected" in str(e):
                pass
            else:
                raise e

        success = M is not None
        if success:
            M = {
                "R": torch.tensor(M.R, dtype=torch.float),
                "t": torch.tensor(M.t, dtype=torch.float),
            }
            inl = info["inliers"]
        else:
            M = {
                "R": torch.eye(3, dtype=torch.float),
                "t": torch.zeros((3), dtype=torch.float),
            }
            inl = np.zeros((0,)).astype(bool)

        t_err, r_err = relative_pose_error(sample["s2t_R"].cpu(), sample["s2t_T"].cpu(), M["R"], M["t"])

        rel_pose_error = max(t_err.item(), r_err.item()) if success else np.inf
        ransac_inl = np.sum(inl)
        ransac_inl_ratio = np.mean(inl)

        if success:
            assert match_dists is not None and match_idxs is not None, "Matches must be computed"
            cv_keypoints_src = to_cv_kpts(kpts_1, score_1)
            cv_keypoints_trg = to_cv_kpts(kpts_2, score_2)
            cv_matches = cv2_matches_from_kornia(match_dists, match_idxs)
            cv_mask = [int(m) for m in inl]
        else:
            cv_keypoints_src, cv_keypoints_trg = [], []
            cv_matches, cv_mask = [], []

        estimation = {
            "success": success,
            "M_0to1": M,
            "inliers": torch.tensor(inl).to(img_1),
            "rel_pose_error": rel_pose_error,
            "ransac_inl": ransac_inl,
            "ransac_inl_ratio": ransac_inl_ratio,
            "path_src_image": sample["src_path"],
            "path_trg_image": sample["trg_path"],
            "cv_keypoints_src": cv_keypoints_src,
            "cv_keypoints_trg": cv_keypoints_trg,
            "cv_matches": cv_matches,
            "cv_mask": cv_mask,
        }

        return estimation

    def evaluate(self, model, dev, progress_bar=False):
        model.eval()

        # reset results
        self.results = []

        for idx in tqdm(
            self.ids_subset if self.ids_subset is not None else range(len(self.data)),
            disable=not progress_bar,
        ):
            sample = self.data[idx]
            self.results.append(self.evaluate_sample(model, sample, dev))

    def get_auc(self, threshold=5, downsampled=False):
        if len(self.results) == 0:
            raise ValueError("No results to log. Run evaluate first.")

        summary_results = self.calc_auc(downsampled=downsampled)

        return summary_results[f"rel_pose_error@{threshold}°{'__original' if not downsampled else '__downsampled'}"]

    def plot_results(self, num_samples=10, logger=None, step=None, downsampled=False):
        if len(self.results) == 0:
            raise ValueError("No results to plot. Run evaluate first.")

        plot_data = []

        for result in self.results[:num_samples]:
            img1 = cv2.imread(result["path_src_image"])
            img2 = cv2.imread(result["path_trg_image"])

            # from BGR to RGB
            img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
            img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)

            plt_matches = cv2.drawMatches(
                img1,
                result["cv_keypoints_src"],
                img2,
                result["cv_keypoints_trg"],
                result["cv_matches"],
                None,
                matchColor=None,
                matchesMask=result["cv_mask"],
                flags=cv2.DrawMatchesFlags_DEFAULT,
            )
            file_name = (
                Path(result["path_src_image"]).parent.parent.name
                + "_"
                + Path(result["path_src_image"]).stem
                + Path(result["path_trg_image"]).stem
                + ("_downsampled" if downsampled else "")
                + ".png"
            )
            # print rel_pose_error on image
            plt_matches = cv2.putText(
                plt_matches,
                f"rel_pose_error: {result['rel_pose_error']:.2f} num_inliers: {result['ransac_inl']} inl_ratio: {result['ransac_inl_ratio']:.2f} num_matches: {len(result['cv_matches'])} num_keypoints: {len(result['cv_keypoints_src'])}/{len(result['cv_keypoints_trg'])}",
                (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 0, 0),
                2,
                cv2.LINE_8,
            )

            plot_data.append({"file_name": file_name, "image": plt_matches})

        if logger is None:
            log.info("No logger provided. Using plt to plot results.")
            for image in plot_data:
                plt.imsave(
                    image["file_name"],
                    cv_resize_and_pad_to_shape(image["image"], (1024, 2048)),
                )
                plt.close()
        else:
            import wandb

            log.info(f"Logging images to wandb with step={step}")
            if not downsampled:
                logger.log(
                    {
                        "examples": [
                            wandb.Image(cv_resize_and_pad_to_shape(image["image"], (1024, 2048))) for image in plot_data
                        ]
                    },
                    step=step,
                )
            else:
                logger.log(
                    {
                        "examples_downsampled": [
                            wandb.Image(cv_resize_and_pad_to_shape(image["image"], (1024, 2048))) for image in plot_data
                        ]
                    },
                    step=step,
                )

    def log_results(self, logger=None, step=None, downsampled=False):
        if len(self.results) == 0:
            raise ValueError("No results to log. Run evaluate first.")

        summary_results = self.calc_auc(downsampled=downsampled)

        if logger is not None:
            logger.log(summary_results, step=step)
        else:
            log.warning("No logger provided. Printing results instead.")
            print(self.calc_auc())

    def print_results(self):
        if len(self.results) == 0:
            raise ValueError("No results to print. Run evaluate first.")

        print(self.calc_auc())

    def calc_auc(self, auc_thresholds=None, downsampled=False):
        if auc_thresholds is None:
            auc_thresholds = [5, 10, 20]
        if not isinstance(auc_thresholds, list):
            auc_thresholds = [auc_thresholds]

        if len(self.results) == 0:
            raise ValueError("No results to calculate auc. Run evaluate first.")

        rel_pose_errors = [r["rel_pose_error"] for r in self.results]

        pose_aucs = AUCMetric(auc_thresholds, rel_pose_errors).compute()
        assert isinstance(pose_aucs, list) and len(pose_aucs) == len(auc_thresholds)

        ext = "_downsampled" if downsampled else "_original"

        summary = {}
        for i, ath in enumerate(auc_thresholds):
            summary[f"rel_pose_error@{ath}°_{ext}"] = pose_aucs[i]

        return summary
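For context, a minimal driver for this benchmark could look like the sketch below. The DATA_DIR value, the weight path, and the conf_inference keyword arguments (taken from the detectAndCompute call in app.py) are assumptions, not part of this commit.

# Hypothetical driver script (assumes DATA_DIR points at a folder containing disk-data/imw2020-val
# and that weights_ripe.pth is available locally).
import os
import torch

from ripe import vgg_hyper
from ripe.benchmarks.imw_2020 import IMW_2020_Benchmark

os.environ.setdefault("DATA_DIR", "/path/to/data")

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vgg_hyper("./weights_ripe.pth").to(dev)

benchmark = IMW_2020_Benchmark(
    use_predefined_subset=True,
    conf_inference={"threshold": 0.5, "top_k": 2048},  # assumed kwargs, mirroring app.py
)
benchmark.evaluate(model, dev, progress_bar=True)
benchmark.print_results()  # pose AUC at 5/10/20 degrees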
ripe/data/__init__.py ADDED
File without changes
ripe/data/data_transforms.py ADDED
@@ -0,0 +1,204 @@
import collections
import collections.abc

import kornia.geometry as KG
import numpy as np
import torch
from torchvision.transforms import functional as TF


class Compose:
    """Composes several transforms together. The transforms are applied in the order they are passed in.
    Args: transforms (list): A list of transforms to be applied.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, src, trg, src_mask, trg_mask, h):
        for t in self.transforms:
            src, trg, src_mask, trg_mask, h = t(src, trg, src_mask, trg_mask, h)

        return src, trg, src_mask, trg_mask, h


class Transform:
    """Base class for all transforms. It provides a method to apply a transformation function to the input images and masks.
    Args:
        src (torch.Tensor): The source image tensor.
        trg (torch.Tensor): The target image tensor.
        src_mask (torch.Tensor): The source image mask tensor.
        trg_mask (torch.Tensor): The target image mask tensor.
        h (torch.Tensor): The homography matrix tensor.
    Returns:
        tuple: A tuple containing the transformed source image, the transformed target image, the transformed source mask,
        the transformed target mask and the updated homography matrix.
    """

    def __init__(self):
        pass

    def apply_transform(self, src, trg, src_mask, trg_mask, h, transfrom_function):
        src, trg, src_mask, trg_mask, h = transfrom_function(src, trg, src_mask, trg_mask, h)
        return src, trg, src_mask, trg_mask, h


class Normalize(Transform):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, src, trg, src_mask, trg_mask, h):
        return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)

    def transform_function(self, src, trg, src_mask, trg_mask, h):
        src = TF.normalize(src, mean=self.mean, std=self.std)
        trg = TF.normalize(trg, mean=self.mean, std=self.std)
        return src, trg, src_mask, trg_mask, h


class ResizeAndPadWithHomography(Transform):
    def __init__(self, target_size_longer_side=768):
        self.target_size = target_size_longer_side

    def __call__(self, src, trg, src_mask, trg_mask, h):
        return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)

    def transform_function(self, src, trg, src_mask, trg_mask, h):
        src_w, src_h = src.shape[-1], src.shape[-2]
        trg_w, trg_h = trg.shape[-1], trg.shape[-2]

        # Resizing logic for both images
        scale_src, new_src_w, new_src_h = self.compute_resize(src_w, src_h)
        scale_trg, new_trg_w, new_trg_h = self.compute_resize(trg_w, trg_h)

        # Resize both images
        src_resized = TF.resize(src, [new_src_h, new_src_w])
        trg_resized = TF.resize(trg, [new_trg_h, new_trg_w])

        src_mask_resized = TF.resize(src_mask, [new_src_h, new_src_w])
        trg_mask_resized = TF.resize(trg_mask, [new_trg_h, new_trg_w])

        # Pad the resized images to be square (768x768)
        src_padded, src_padding = self.apply_padding(src_resized, new_src_w, new_src_h)
        trg_padded, trg_padding = self.apply_padding(trg_resized, new_trg_w, new_trg_h)

        src_mask_padded, _ = self.apply_padding(src_mask_resized, new_src_w, new_src_h)
        trg_mask_padded, _ = self.apply_padding(trg_mask_resized, new_trg_w, new_trg_h)

        # Update the homography matrix
        h = self.update_homography(h, scale_src, src_padding, scale_trg, trg_padding)

        return src_padded, trg_padded, src_mask_padded, trg_mask_padded, h

    def compute_resize(self, w, h):
        if w > h:
            scale = self.target_size / w
            new_w = self.target_size
            new_h = int(h * scale)
        else:
            scale = self.target_size / h
            new_h = self.target_size
            new_w = int(w * scale)
        return scale, new_w, new_h

    def apply_padding(self, img, new_w, new_h):
        pad_w = (self.target_size - new_w) // 2
        pad_h = (self.target_size - new_h) // 2
        padding = [
            pad_w,
            pad_h,
            self.target_size - new_w - pad_w,
            self.target_size - new_h - pad_h,
        ]
        img_padded = TF.pad(img, padding, fill=0)  # Zero-pad
        return img_padded, padding

    def update_homography(self, h, scale_src, padding_src, scale_trg, padding_trg):
        # Create the scaling matrices
        scale_matrix_src = np.array([[scale_src, 0, 0], [0, scale_src, 0], [0, 0, 1]])
        scale_matrix_trg = np.array([[scale_trg, 0, 0], [0, scale_trg, 0], [0, 0, 1]])

        # Create the padding translation matrices
        pad_matrix_src = np.array([[1, 0, padding_src[0]], [0, 1, padding_src[1]], [0, 0, 1]])
        pad_matrix_trg = np.array([[1, 0, -padding_trg[0]], [0, 1, -padding_trg[1]], [0, 0, 1]])

        # Update the homography: apply scaling and translation
        h_updated = (
            pad_matrix_trg
            @ scale_matrix_trg
            @ h.numpy()
            @ np.linalg.inv(scale_matrix_src)
            @ np.linalg.inv(pad_matrix_src)
        )

        return torch.from_numpy(h_updated).float()


class Resize(Transform):
    def __init__(self, output_size, edge_divisible_by=None, side="long", antialias=True):
        self.output_size = output_size
        self.edge_divisible_by = edge_divisible_by
        self.side = side
        self.antialias = antialias

    def __call__(self, src, trg, src_mask, trg_mask, h):
        return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)

    def transform_function(self, src, trg, src_mask, trg_mask, h):
        new_size_src = self.get_new_image_size(src)
        new_size_trg = self.get_new_image_size(trg)

        src, T_src = self.resize(src, new_size_src)
        trg, T_trg = self.resize(trg, new_size_trg)

        src_mask, _ = self.resize(src_mask, new_size_src)
        trg_mask, _ = self.resize(trg_mask, new_size_trg)

        h = torch.from_numpy(T_trg @ h.numpy() @ T_src).float()

        return src, trg, src_mask, trg_mask, h

    def resize(self, img, size):
        h, w = img.shape[-2:]

        img = KG.transform.resize(
            img,
            size,
            side=self.side,
            antialias=self.antialias,
            align_corners=None,
            interpolation="bilinear",
        )

        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
        T = np.diag([scale[0].item(), scale[1].item(), 1])

        return img, T

    def get_new_image_size(self, img):
        h, w = img.shape[-2:]

        if isinstance(self.output_size, collections.abc.Iterable):
            assert len(self.output_size) == 2
            return tuple(self.output_size)
        if self.output_size is None:  # keep the original size, but possibly make it divisible by edge_divisible_by
            size = (h, w)
        else:
            side_size = self.output_size
            aspect_ratio = w / h
            if self.side not in ("short", "long", "vert", "horz"):
                raise ValueError(f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{self.side}'")
            if self.side == "vert":
                size = side_size, int(side_size * aspect_ratio)
            elif self.side == "horz":
                size = int(side_size / aspect_ratio), side_size
            elif (self.side == "short") ^ (aspect_ratio < 1.0):
                size = side_size, int(side_size * aspect_ratio)
            else:
                size = int(side_size / aspect_ratio), side_size

        if self.edge_divisible_by is not None:
            df = self.edge_divisible_by
            size = list(map(lambda x: int(x // df * df), size))
        return size
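For reference, a small sketch of how these transforms compose. It mirrors the pipeline built in ripe/benchmarks/imw_2020.py (Resize that only rounds the edges, then ImageNet normalization); the tensor shapes and the use of float masks are assumptions made to keep the sketch simple.

# Sketch only: random 3xHxW images, all-ones float masks, identity homography.
import torch
from ripe.data.data_transforms import Compose, Normalize, Resize

transforms = Compose(
    [
        Resize(None, edge_divisible_by=16),  # keep size, round both edges down to a multiple of 16
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

src = torch.rand(3, 600, 800)
trg = torch.rand(3, 480, 640)
src_mask = torch.ones(1, 600, 800)
trg_mask = torch.ones(1, 480, 640)
h = torch.eye(3)

src, trg, src_mask, trg_mask, h = transforms(src, trg, src_mask, trg_mask, h)
print(src.shape, trg.shape)  # spatial sizes rounded down to multiples of 16 (600 -> 592)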
ripe/data/datasets/__init__.py ADDED
File without changes
ripe/data/datasets/acdc.py ADDED
@@ -0,0 +1,154 @@
from pathlib import Path
from typing import Any, Callable, Dict, Optional

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

from ripe import utils
from ripe.data.data_transforms import Compose
from ripe.utils.utils import get_other_random_id

log = utils.get_pylogger(__name__)


class ACDC(Dataset):
    def __init__(
        self,
        root: Path,
        stage: str = "train",
        condition: str = "rain",
        transforms: Optional[Callable] = None,
        positive_only: bool = False,
    ) -> None:
        self.root = root
        self.stage = stage
        self.condition = condition
        self.transforms = transforms
        self.positive_only = positive_only

        if isinstance(self.root, str):
            self.root = Path(self.root)

        if not self.root.exists():
            raise FileNotFoundError(f"Dataset not found at {self.root}")

        if transforms is None:
            self.transforms = Compose([])
        else:
            self.transforms = transforms

        if self.stage not in ["train", "val", "test", "pred"]:
            raise RuntimeError(
                "Unknown option "
                + self.stage
                + " as training stage variable. Valid options: 'train', 'val', 'test' and 'pred'"
            )

        if self.stage == "pred":  # prediction uses the test set
            self.stage = "test"

        if self.stage in ["val", "test", "pred"]:
            self.positive_only = True
            log.info(f"{self.stage} stage: Using only positive pairs!")

        weather_conditions = ["fog", "night", "rain", "snow"]

        if self.condition not in weather_conditions + ["all"]:
            raise RuntimeError(
                "Unknown option "
                + self.condition
                + " as weather condition variable. Valid options: 'fog', 'night', 'rain', 'snow' and 'all'"
            )

        self.weather_condition_query = weather_conditions if self.condition == "all" else [self.condition]

        self._read_sample_files()

        if positive_only:
            log.warning("Using only positive pairs!")
        log.info(f"Found {len(self.src_images)} source images and {len(self.trg_images)} target images.")

    def _read_sample_files(self):
        file_name_pattern_ref = "_ref_anon.png"
        file_name_pattern = "_rgb_anon.png"

        self.trg_images = []
        self.src_images = []

        for weather_condition in self.weather_condition_query:
            rgb_files = sorted(
                list(self.root.glob("rgb_anon/" + weather_condition + "/" + self.stage + "/**/*" + file_name_pattern)),
                key=lambda i: i.stem[:21],
            )

            src_images = sorted(
                list(
                    self.root.glob(
                        "rgb_anon/" + weather_condition + "/" + self.stage + "_ref" + "/**/*" + file_name_pattern_ref
                    )
                ),
                key=lambda i: i.stem[:21],
            )

            self.trg_images += rgb_files
            self.src_images += src_images

    def __len__(self) -> int:
        if self.positive_only:
            return len(self.trg_images)
        return 2 * len(self.trg_images)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample: Any = {}

        positive_sample = (idx % 2 == 0) or (self.positive_only)
        if not self.positive_only:
            idx = idx // 2

        sample["label"] = positive_sample

        if positive_sample:
            sample["src_path"] = str(self.src_images[idx])
            sample["trg_path"] = str(self.trg_images[idx])

            assert self.src_images[idx].stem[:21] == self.trg_images[idx].stem[:21], (
                f"Source and target image mismatch: {self.src_images[idx]} vs {self.trg_images[idx]}"
            )

            src_img = read_image(sample["src_path"])
            trg_img = read_image(sample["trg_path"])

            homography = torch.eye(3, dtype=torch.float32)
        else:
            sample["src_path"] = str(self.src_images[idx])
            idx_other = get_other_random_id(idx, len(self) // 2)
            sample["trg_path"] = str(self.trg_images[idx_other])

            assert self.src_images[idx].stem[:21] != self.trg_images[idx_other].stem[:21], (
                f"Source and target image match for negative sample: {self.src_images[idx]} vs {self.trg_images[idx_other]}"
            )

            src_img = read_image(sample["src_path"])
            trg_img = read_image(sample["trg_path"])

            homography = torch.zeros((3, 3), dtype=torch.float32)

        src_img = src_img / 255.0
        trg_img = trg_img / 255.0

        _, H, W = src_img.shape

        src_mask = torch.ones((1, H, W), dtype=torch.uint8)
        trg_mask = torch.ones((1, H, W), dtype=torch.uint8)

        if self.transforms:
            src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)

        sample["src_image"] = src_img
        sample["trg_image"] = trg_img
        sample["src_mask"] = src_mask.to(torch.bool)
        sample["trg_mask"] = trg_mask.to(torch.bool)
        sample["homography"] = homography

        return sample
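For orientation, a minimal sketch of how this dataset might be instantiated and what a sample contains; the root path is a placeholder. Even indices yield positive (same-scene) pairs, odd indices yield negative pairs.

# Sketch: read one ACDC sample (assumes the ACDC data is available at the given root).
from ripe.data.datasets.acdc import ACDC

ds = ACDC(root="/path/to/ACDC", stage="train", condition="all")

sample = ds[0]                    # even index -> positive pair
print(sample["label"])            # True for a positive (same-scene) pair
print(sample["src_image"].shape)  # 3xHxW float tensor in [0, 1]
print(sample["homography"])       # identity for positives, zeros for negatives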
ripe/data/datasets/dataset_combinator.py ADDED
@@ -0,0 +1,88 @@
import torch

from ripe import utils

log = utils.get_pylogger(__name__)


class DatasetCombinator:
    """Combines multiple datasets into one. Length of the combined dataset is the length of the
    longest dataset. Shorter datasets are looped over.

    Args:
        datasets: List of datasets to combine.
        mode: How to sample from the datasets. Can be either "uniform" or "weighted".
            In "uniform" mode, each dataset is sampled with equal probability.
            In "weighted" mode, each dataset is sampled with probability proportional to its length.
    """

    def __init__(self, datasets, mode="uniform", weights=None):
        self.datasets = datasets

        names_datasets = [type(ds).__name__ for ds in self.datasets]
        self.lengths = [len(ds) for ds in datasets]

        if mode == "weighted":
            self.probs_datasets = [length / sum(self.lengths) for length in self.lengths]
        elif mode == "uniform":
            self.probs_datasets = [1 / len(self.datasets) for _ in self.datasets]
        elif mode == "custom":
            assert weights is not None, "Weights must be provided in custom mode"
            assert len(weights) == len(datasets), "Number of weights must match number of datasets"
            assert sum(weights) == 1.0, "Weights must sum to 1"
            self.probs_datasets = weights
        else:
            raise ValueError(f"Unknown mode {mode}")

        log.info("Got the following datasets: ")

        for name, length, prob in zip(names_datasets, self.lengths, self.probs_datasets):
            log.info(f"{name} with {length} samples and probability {prob}")
        log.info(f"Total number of samples: {sum(self.lengths)}")

        self.num_samples = max(self.lengths)

        self.dataset_dist = torch.distributions.Categorical(probs=torch.tensor(self.probs_datasets))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int):
        positive_sample = idx % 2 == 0

        if positive_sample:
            dataset_idx = self.dataset_dist.sample().item()

            idx = torch.randint(0, self.lengths[dataset_idx], (1,)).item()
            while idx % 2 == 1:
                idx = torch.randint(0, self.lengths[dataset_idx], (1,)).item()

            return self.datasets[dataset_idx][idx]
        else:
            dataset_idx_1 = self.dataset_dist.sample().item()
            dataset_idx_2 = self.dataset_dist.sample().item()

            if dataset_idx_1 == dataset_idx_2:
                idx = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
                while idx % 2 == 0:
                    idx = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
                return self.datasets[dataset_idx_1][idx]

            else:
                idx_1 = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
                idx_2 = torch.randint(0, self.lengths[dataset_idx_2], (1,)).item()

                sample_1 = self.datasets[dataset_idx_1][idx_1]
                sample_2 = self.datasets[dataset_idx_2][idx_2]

                sample = {
                    "label": False,
                    "src_path": sample_1["src_path"],
                    "trg_path": sample_2["trg_path"],
                    "src_image": sample_1["src_image"],
                    "trg_image": sample_2["trg_image"],
                    "src_mask": sample_1["src_mask"],
                    "trg_mask": sample_2["trg_mask"],
                    "homography": sample_2["homography"],
                }
                return sample
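For reference, the combinator is a plain map-style dataset, so it can be wrapped in a standard torch DataLoader; note that besides the two modes described in the docstring, the code above also accepts mode="custom" with explicit per-dataset weights. A minimal sketch with two ACDC splits (the root path is a placeholder):

# Sketch: mix two pair datasets; even indices yield positive pairs, odd indices negative pairs.
from ripe.data.datasets.acdc import ACDC
from ripe.data.datasets.dataset_combinator import DatasetCombinator

ds_rain = ACDC(root="/path/to/ACDC", stage="train", condition="rain")
ds_night = ACDC(root="/path/to/ACDC", stage="train", condition="night")

combined = DatasetCombinator([ds_rain, ds_night], mode="weighted")  # sampling proportional to dataset length

print(len(combined))      # length of the longest dataset
sample = combined[1]      # odd index -> negative pair, possibly drawn across the two datasets
print(sample["label"])    # False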
ripe/data/datasets/disk_imw.py ADDED
@@ -0,0 +1,160 @@
import json
import random
from itertools import accumulate
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

from ripe import utils
from ripe.data.data_transforms import Compose
from ripe.utils.image_utils import Camera, cameras2F

log = utils.get_pylogger(__name__)


class DISK_IMW(Dataset):
    def __init__(
        self,
        root: str,
        stage: str = "val",
        # condition: str = "rain",
        transforms: Optional[Callable] = None,
    ) -> None:
        self.root = root
        self.stage = stage
        self.transforms = transforms

        if isinstance(self.root, str):
            self.root = Path(self.root)

        if not self.root.exists():
            raise FileNotFoundError(f"Dataset not found at {self.root}")

        if transforms is None:
            self.transforms = Compose([])
        else:
            self.transforms = transforms

        if self.stage not in ["val"]:
            raise RuntimeError("Unknown option " + self.stage + " as training stage variable. Valid options: 'val'")

        json_path = self.root / "imw2020-val" / "dataset.json"
        with open(json_path) as json_file:
            json_data = json.load(json_file)

        self.scenes = []

        for scene in json_data:
            self.scenes.append(Scene(self.root / "imw2020-val", json_data[scene]))

        self.tuples_per_scene = [len(scene) for scene in self.scenes]

    def __len__(self) -> int:
        return sum(self.tuples_per_scene)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample: Any = {}

        i_scene, i_image = self._get_scene_and_image_id_from_idx(idx)

        sample["src_path"], sample["trg_path"], path_calib_src, path_calib_trg = self.scenes[i_scene][i_image]

        cam_src = Camera.from_calibration_file(path_calib_src)
        cam_trg = Camera.from_calibration_file(path_calib_trg)

        F = self.get_F(cam_src, cam_trg)
        s2t_R, s2t_T = self.get_relative_pose(cam_src, cam_trg)

        src_img = read_image(sample["src_path"]) / 255.0
        trg_img = read_image(sample["trg_path"]) / 255.0

        _, H_src, W_src = src_img.shape
        _, H_trg, W_trg = trg_img.shape

        src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
        trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)

        H = torch.eye(3)
        if self.transforms:
            src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, H)

        # check which transformations are in self.transforms: only Normalize and Resize are allowed here
        for t in self.transforms.transforms:
            if t.__class__.__name__ not in ["Normalize", "Resize"]:
                raise ValueError(f"Transform {t.__class__.__name__} not allowed in DISK_IMW dataset")

        sample["src_image"] = src_img
        sample["trg_image"] = trg_img
        sample["orig_size_src"] = (H_src, W_src)
        sample["orig_size_trg"] = (H_trg, W_trg)
        sample["src_mask"] = src_mask.to(torch.bool)
        sample["trg_mask"] = trg_mask.to(torch.bool)
        sample["F"] = F
        sample["s2t_R"] = s2t_R
        sample["s2t_T"] = s2t_T
        sample["src_camera"] = cam_src
        sample["trg_camera"] = cam_trg

        return sample

    def get_relative_pose(self, cam_src: Camera, cam_trg: Camera) -> Tuple[torch.Tensor, torch.Tensor]:
        R = cam_trg.R @ cam_src.R.T
        T = cam_trg.t - R @ cam_src.t

        return R, T

    def get_F(self, cam_src: Camera, cam_trg: Camera) -> torch.Tensor:
        F = cameras2F(cam_src, cam_trg)

        return F

    def _get_scene_and_image_id_from_idx(self, idx: int) -> Tuple[int, int]:
        accumulated_tuples = accumulate(self.tuples_per_scene)

        if idx >= sum(self.tuples_per_scene):
            raise IndexError(f"Index {idx} out of bounds")

        idx_scene = None
        for i, accumulated_tuple in enumerate(accumulated_tuples):
            idx_scene = i
            if idx < accumulated_tuple:
                break

        idx_image = idx - sum(self.tuples_per_scene[:idx_scene])

        return idx_scene, idx_image

    def _get_other_random_scene_and_image_id(self, scene_id_to_exclude: int) -> Tuple[int, int]:
        possible_scene_ids = list(range(len(self.scenes)))
        possible_scene_ids.remove(scene_id_to_exclude)

        idx_scene = random.choice(possible_scene_ids)
        idx_image = random.randint(0, len(self.scenes[idx_scene]) - 1)

        return idx_scene, idx_image


class Scene:
    def __init__(self, root_path, scene_data: Dict[str, Any]) -> None:
        self.root_path = root_path
        self.image_path = Path(scene_data["image_path"])
        self.calib_path = Path(scene_data["calib_path"])
        self.image_names = scene_data["images"]
        self.tuples = scene_data["tuples"]

    def __len__(self) -> int:
        return len(self.tuples)

    def __getitem__(self, idx: int) -> Tuple[str, str, str, str]:
        idx_1 = self.tuples[idx][0]
        idx_2 = self.tuples[idx][1]

        path_image_1 = str(self.root_path / self.image_path / self.image_names[idx_1]) + ".jpg"
        path_image_2 = str(self.root_path / self.image_path / self.image_names[idx_2]) + ".jpg"
        path_calib_1 = str(self.root_path / self.calib_path / ("calibration_" + self.image_names[idx_1])) + ".h5"
        path_calib_2 = str(self.root_path / self.calib_path / ("calibration_" + self.image_names[idx_2])) + ".h5"

        return path_image_1, path_image_2, path_calib_1, path_calib_2
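A short usage sketch for this validation loader; the dataset root is a placeholder and assumes the `imw2020-val` folder layout with the `dataset.json` file the class expects:

    # minimal sketch (paths are placeholders)
    from ripe.data.datasets.disk_imw import DISK_IMW

    val_set = DISK_IMW("/data/disk", stage="val")
    sample = val_set[0]
    print(sample["src_image"].shape)   # [3, H, W] float tensor scaled to [0, 1]
    print(sample["F"].shape)           # fundamental matrix relating src and trg views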
ripe/data/datasets/disk_megadepth.py
ADDED
|
@@ -0,0 +1,157 @@
import json
import random
from itertools import accumulate
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

from ripe import utils
from ripe.data.data_transforms import Compose

log = utils.get_pylogger(__name__)


class DISK_Megadepth(Dataset):
    def __init__(
        self,
        root: str,
        max_scene_size: int,
        stage: str = "train",
        # condition: str = "rain",
        transforms: Optional[Callable] = None,
        positive_only: bool = False,
    ) -> None:
        self.root = root
        self.stage = stage
        self.transforms = transforms
        self.positive_only = positive_only

        if isinstance(self.root, str):
            self.root = Path(self.root)

        if not self.root.exists():
            raise FileNotFoundError(f"Dataset not found at {self.root}")

        if transforms is None:
            self.transforms = Compose([])
        else:
            self.transforms = transforms

        if self.stage not in ["train"]:
            raise RuntimeError("Unknown option " + self.stage + " as training stage variable. Valid options: 'train'")

        json_path = self.root / "megadepth" / "dataset.json"
        with open(json_path) as json_file:
            json_data = json.load(json_file)

        self.scenes = []

        for scene in json_data:
            self.scenes.append(Scene(self.root / "megadepth", json_data[scene], max_scene_size))

        self.tuples_per_scene = [len(scene) for scene in self.scenes]

        if positive_only:
            log.warning("Using only positive pairs!")

    def __len__(self) -> int:
        if self.positive_only:
            return sum(self.tuples_per_scene)
        return 2 * sum(self.tuples_per_scene)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample: Any = {}

        positive_sample = idx % 2 == 0 or self.positive_only
        if not self.positive_only:
            idx = idx // 2

        sample["label"] = positive_sample

        i_scene, i_image = self._get_scene_and_image_id_from_idx(idx)

        if positive_sample:
            sample["src_path"], sample["trg_path"] = self.scenes[i_scene][i_image]

            homography = torch.eye(3, dtype=torch.float32)
        else:
            sample["src_path"], _ = self.scenes[i_scene][i_image]

            i_scene_other, i_image_other = self._get_other_random_scene_and_image_id(i_scene)

            sample["trg_path"], _ = self.scenes[i_scene_other][i_image_other]

            homography = torch.zeros((3, 3), dtype=torch.float32)

        src_img = read_image(sample["src_path"]) / 255.0
        trg_img = read_image(sample["trg_path"]) / 255.0

        _, H_src, W_src = src_img.shape
        _, H_trg, W_trg = trg_img.shape

        src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
        trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)

        if self.transforms:
            src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)

        sample["src_image"] = src_img
        sample["trg_image"] = trg_img
        sample["src_mask"] = src_mask.to(torch.bool)
        sample["trg_mask"] = trg_mask.to(torch.bool)
        sample["homography"] = homography

        return sample

    def _get_scene_and_image_id_from_idx(self, idx: int) -> Tuple[int, int]:
        accumulated_tuples = accumulate(self.tuples_per_scene)

        if idx >= sum(self.tuples_per_scene):
            raise IndexError(f"Index {idx} out of bounds")

        idx_scene = None
        for i, accumulated_tuple in enumerate(accumulated_tuples):
            idx_scene = i
            if idx < accumulated_tuple:
                break

        idx_image = idx - sum(self.tuples_per_scene[:idx_scene])

        return idx_scene, idx_image

    def _get_other_random_scene_and_image_id(self, scene_id_to_exclude: int) -> Tuple[int, int]:
        possible_scene_ids = list(range(len(self.scenes)))
        possible_scene_ids.remove(scene_id_to_exclude)

        idx_scene = random.choice(possible_scene_ids)
        idx_image = random.randint(0, len(self.scenes[idx_scene]) - 1)

        return idx_scene, idx_image


class Scene:
    def __init__(self, root_path, scene_data: Dict[str, Any], max_size_scene) -> None:
        self.root_path = root_path
        self.image_path = Path(scene_data["image_path"])
        self.image_names = scene_data["images"]

        # randomly sample tuples
        if max_size_scene > 0:
            self.tuples = random.sample(scene_data["tuples"], min(max_size_scene, len(scene_data["tuples"])))

    def __len__(self) -> int:
        return len(self.tuples)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        idx_1, idx_2 = random.sample([0, 1, 2], 2)

        idx_1 = self.tuples[idx][idx_1]
        idx_2 = self.tuples[idx][idx_2]

        path_image_1 = str(self.root_path / self.image_path / self.image_names[idx_1])
        path_image_2 = str(self.root_path / self.image_path / self.image_names[idx_2])

        return path_image_1, path_image_2
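Even indices yield positive pairs and odd indices negatives, which is why `__len__` is doubled unless `positive_only` is set. A hedged sketch of wrapping the dataset in a DataLoader; the root path is a placeholder, and default collation assumes the configured transforms resize all images to a common resolution:

    # sketch only: not the training configuration shipped with the repo
    from torch.utils.data import DataLoader
    from ripe.data.datasets.disk_megadepth import DISK_Megadepth

    train_set = DISK_Megadepth("/data/disk", max_scene_size=100, stage="train")
    loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch["label"])   # mix of True (positive) and False (negative) pairs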
ripe/data/datasets/tokyo247.py
ADDED
|
@@ -0,0 +1,134 @@
import os
import random
from glob import glob
from typing import Any, Callable, Optional

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

from ripe import utils
from ripe.data.data_transforms import Compose

log = utils.get_pylogger(__name__)


class Tokyo247(Dataset):
    def __init__(
        self,
        root: str,
        stage: str = "train",
        transforms: Optional[Callable] = None,
        positive_only: bool = False,
    ):
        if stage != "train":
            raise ValueError("Tokyo247Dataset only supports the 'train' stage.")

        # check if the root directory exists
        if not os.path.isdir(root):
            raise FileNotFoundError(f"Directory {root} does not exist.")

        self.root_dir = root
        self.transforms = transforms if transforms is not None else Compose([])
        self.positive_only = positive_only

        self.image_paths = []
        self.positive_pairs = []

        # Collect images grouped by location folder
        self.locations = {}
        for location_rough in sorted(os.listdir(self.root_dir)):
            location_rough_path = os.path.join(self.root_dir, location_rough)

            # check if the location_rough_path is a directory
            if not os.path.isdir(location_rough_path):
                continue

            for location_fine in sorted(os.listdir(location_rough_path)):
                location_fine_path = os.path.join(self.root_dir, location_rough, location_fine)

                if os.path.isdir(location_fine_path):
                    images = sorted(
                        glob(os.path.join(location_fine_path, "*.png")),
                        key=lambda i: int(i[-7:-4]),
                    )
                    if len(images) >= 12:
                        self.locations[location_fine] = images
                        self.image_paths.extend(images)

        # Generate positive pairs from neighbouring views, closing the circle with (last, first)
        for _, images in self.locations.items():
            for i in range(len(images) - 1):
                self.positive_pairs.append((images[i], images[i + 1]))
            self.positive_pairs.append((images[-1], images[0]))

        if positive_only:
            log.warning("Using only positive pairs!")

        log.info(f"Found {len(self.positive_pairs)} image pairs.")

    def __len__(self):
        if self.positive_only:
            return len(self.positive_pairs)
        return 2 * len(self.positive_pairs)

    def __getitem__(self, idx):
        sample: Any = {}

        positive_sample = (idx % 2 == 0) or (self.positive_only)
        if not self.positive_only:
            idx = idx // 2

        sample["label"] = positive_sample

        if positive_sample:  # Positive pair
            img1_path, img2_path = self.positive_pairs[idx]

            assert os.path.dirname(img1_path) == os.path.dirname(img2_path), (
                f"Source and target image mismatch: {img1_path} vs {img2_path}"
            )

            homography = torch.eye(3, dtype=torch.float32)
        else:  # Negative pair
            img1_path = random.choice(self.image_paths)
            img2_path = random.choice(self.image_paths)

            # Ensure images are from different folders
            esc = 0
            while os.path.dirname(img1_path) == os.path.dirname(img2_path):
                img2_path = random.choice(self.image_paths)

                esc += 1
                if esc > 100:
                    raise RuntimeError("Could not find a negative pair.")

            assert os.path.dirname(img1_path) != os.path.dirname(img2_path), (
                f"Source and target image match for negative pair: {img1_path} vs {img2_path}"
            )

            homography = torch.zeros((3, 3), dtype=torch.float32)

        sample["src_path"] = img1_path
        sample["trg_path"] = img2_path

        # Load images
        src_img = read_image(sample["src_path"]) / 255.0
        trg_img = read_image(sample["trg_path"]) / 255.0

        _, H_src, W_src = src_img.shape
        _, H_trg, W_trg = trg_img.shape

        src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
        trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)

        # Apply transformations
        if self.transforms:
            src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)

        sample["src_image"] = src_img
        sample["trg_image"] = trg_img
        sample["src_mask"] = src_mask.to(torch.bool)
        sample["trg_mask"] = trg_mask.to(torch.bool)
        sample["homography"] = homography

        return sample
ripe/losses/__init__.py
ADDED
|
File without changes
|
ripe/losses/contrastive_loss.py
ADDED
|
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def second_nearest_neighbor(desc1, desc2):
    if desc2.shape[0] < 2:  # the SNN check needs at least two candidate descriptors
        raise ValueError("desc2 should have at least 2 descriptors")

    dist = torch.cdist(desc1, desc2, p=2)

    vals, idxs = torch.topk(dist, 2, dim=1, largest=False)
    idxs_in_2 = idxs[:, 1]
    idxs_in_1 = torch.arange(0, idxs_in_2.size(0), device=dist.device)

    matches_idxs = torch.cat([idxs_in_1.view(-1, 1), idxs_in_2.view(-1, 1)], 1)

    return vals[:, 1].view(-1, 1), matches_idxs


def contrastive_loss(
    desc1,
    desc2,
    matches,
    inliers,
    label,
    logits_1,
    logits_2,
    pos_margin=1.0,
    neg_margin=1.0,
):
    if inliers.sum() < 8:  # if there are too few inliers, calculate the loss on all matches
        inliers = torch.ones_like(inliers)

    matched_inliers_descs1 = desc1[matches[:, 0][inliers]]
    matched_inliers_descs2 = desc2[matches[:, 1][inliers]]

    if logits_1 is not None and logits_2 is not None:
        matched_inliers_logits1 = logits_1[matches[:, 0][inliers]]
        matched_inliers_logits2 = logits_2[matches[:, 1][inliers]]
        logits = torch.minimum(matched_inliers_logits1, matched_inliers_logits2)
    else:
        logits = torch.ones_like(matches[:, 0][inliers])

    if label:
        snn_match_dists_1, idx1 = second_nearest_neighbor(matched_inliers_descs1, desc2)
        snn_match_dists_2, idx2 = second_nearest_neighbor(matched_inliers_descs2, desc1)

        dists = torch.hstack((snn_match_dists_1, snn_match_dists_2))
        min_dists_idx = torch.min(dists, dim=1).indices.unsqueeze(1)

        dists_hard = torch.gather(dists, 1, min_dists_idx).squeeze(-1)
        dists_pos = F.pairwise_distance(matched_inliers_descs1, matched_inliers_descs2)

        contrastive_loss = torch.clamp(pos_margin + dists_pos - dists_hard, min=0.0)

        contrastive_loss = contrastive_loss * logits

        contrastive_loss = contrastive_loss.sum() / (logits.sum() + 1e-8)  # small epsilon to avoid division by zero
    else:
        dists = F.pairwise_distance(matched_inliers_descs1, matched_inliers_descs2)
        contrastive_loss = torch.clamp(neg_margin - dists, min=0.0)

        contrastive_loss = contrastive_loss * logits

        contrastive_loss = contrastive_loss.sum() / (logits.sum() + 1e-8)  # small epsilon to avoid division by zero

    return contrastive_loss


class ContrastiveLoss(nn.Module):
    def __init__(self, pos_margin=1.0, neg_margin=1.0):
        super().__init__()
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin

    def forward(self, desc1, desc2, matches, inliers, label, logits_1=None, logits_2=None):
        return contrastive_loss(
            desc1,
            desc2,
            matches,
            inliers,
            label,
            logits_1,
            logits_2,
            self.pos_margin,
            self.neg_margin,
        )
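For positive pairs the loss pushes the matched distance below the harder of the two second-nearest-neighbour distances plus the margin; for negative image pairs all matched distances are simply pushed above the negative margin. A toy sketch with random tensors, purely to illustrate the expected shapes (no real descriptors involved):

    # toy shapes only; values are random and purely illustrative
    import torch
    from ripe.losses.contrastive_loss import ContrastiveLoss

    loss_fn = ContrastiveLoss(pos_margin=1.0, neg_margin=1.0)
    desc1, desc2 = torch.randn(100, 256), torch.randn(120, 256)
    matches = torch.stack([torch.arange(50), torch.arange(50)], dim=1)  # [M, 2] index pairs
    inliers = torch.ones(50, dtype=torch.bool)                          # geometric inlier mask
    loss = loss_fn(desc1, desc2, matches, inliers, label=True)
    print(loss.item())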
ripe/matcher/__init__.py
ADDED
|
File without changes
|
ripe/matcher/concurrent_matcher.py
ADDED
|
@@ -0,0 +1,97 @@
import concurrent.futures

import torch


class ConcurrentMatcher:
    """A class that performs matching and geometric filtering in parallel using a thread pool executor.

    It matches keypoints from two sets of descriptors and applies a robust estimator to filter the matches based on geometric constraints.

    Args:
        matcher (callable): A callable that takes two sets of descriptors and returns distances and indices of matches.
        robust_estimator (callable): A callable that estimates a geometric transformation and returns inliers.
        min_num_matches (int, optional): Minimum number of matches required to perform geometric filtering. Defaults to 8.
        max_workers (int, optional): Maximum number of threads in the thread pool executor. Defaults to 12.
    """

    def __init__(self, matcher, robust_estimator, min_num_matches=8, max_workers=12):
        self.matcher = matcher
        self.robust_estimator = robust_estimator
        self.min_num_matches = min_num_matches

        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

    @torch.no_grad()
    def __call__(
        self,
        kpts1,
        kpts2,
        pdesc1,
        pdesc2,
        selected_mask1,
        selected_mask2,
        inl_th,
        label=None,
    ):
        dev = pdesc1.device
        B = pdesc1.shape[0]

        batch_rel_idx_matches = [None] * B
        batch_idx_matches = [None] * B
        future_results = [None] * B

        for b in range(B):
            if selected_mask1[b].sum() < 16 or selected_mask2[b].sum() < 16:
                continue

            dists, idx_matches = self.matcher(pdesc1[b][selected_mask1[b]], pdesc2[b][selected_mask2[b]])

            batch_rel_idx_matches[b] = idx_matches.clone()

            # convert to ABSOLUTE indices
            idx_matches[:, 0] = torch.nonzero(selected_mask1[b], as_tuple=False)[idx_matches[:, 0]].squeeze()
            idx_matches[:, 1] = torch.nonzero(selected_mask2[b], as_tuple=False)[idx_matches[:, 1]].squeeze()

            batch_idx_matches[b] = idx_matches

            # if not enough matches
            if idx_matches.shape[0] < self.min_num_matches:
                ransac_inliers = torch.zeros((idx_matches.shape[0]), device=dev).bool()
                future_results[b] = (None, ransac_inliers)
                continue

            # use label information to exclude negative pairs from the geometric filtering process -> enforces more discriminative descriptors
            if label is not None and label[b] == 0:
                ransac_inliers = torch.ones((idx_matches.shape[0]), device=dev).bool()
                future_results[b] = (None, ransac_inliers)
                continue

            mkpts1 = kpts1[b][idx_matches[:, 0]]
            mkpts2 = kpts2[b][idx_matches[:, 1]]

            future_results[b] = self.executor.submit(self.robust_estimator, mkpts1, mkpts2, inl_th)

        batch_ransac_inliers = [None] * B
        batch_Fm = [None] * B

        for b in range(B):
            future_result = future_results[b]
            if future_result is None:
                ransac_inliers = None
                Fm = None
            elif isinstance(future_result, tuple):
                Fm, ransac_inliers = future_result
            else:
                Fm, ransac_inliers = future_result.result()

                # if no inliers
                if ransac_inliers.sum() == 0:
                    ransac_inliers = ransac_inliers.squeeze(
                        -1
                    )  # kornia.geometry.ransac.RANSAC returns an (N, 1) tensor if there are no inliers and an (N,) tensor otherwise
                    Fm = None

            batch_ransac_inliers[b] = ransac_inliers
            batch_Fm[b] = Fm

        return batch_rel_idx_matches, batch_idx_matches, batch_ransac_inliers, batch_Fm
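A hedged sketch of how the matcher and robust estimator might be wired together; the mutual-nearest-neighbour matcher shown here is an illustrative stand-in with the expected (distances, index pairs) return signature, not necessarily the matcher used by the training configuration, and the keypoint/descriptor tensors are assumed to come from the model's training-time forward pass:

    # illustrative wiring only
    from kornia.feature import match_mnn
    from ripe.matcher.concurrent_matcher import ConcurrentMatcher
    from ripe.matcher.pose_estimator_poselib import PoseLibRelativePoseEstimator

    matcher = ConcurrentMatcher(
        matcher=match_mnn,                              # (desc1, desc2) -> (dists, [M, 2] indices)
        robust_estimator=PoseLibRelativePoseEstimator(),
        min_num_matches=8,
    )
    # kpts*, desc*, mask*, labels: batched tensors produced during training (assumed here)
    rel_idx, abs_idx, inliers, Fms = matcher(kpts1, kpts2, desc1, desc2, mask1, mask2, inl_th=2.0, label=labels)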
ripe/matcher/pose_estimator_poselib.py
ADDED
|
@@ -0,0 +1,31 @@
import poselib
import torch


class PoseLibRelativePoseEstimator:
    """PoseLibRelativePoseEstimator estimates the fundamental matrix using the poselib library.

    It uses poselib's estimate_fundamental function to compute the fundamental matrix and inliers based on the provided points.

    Args:
        None
    """

    def __init__(self):
        pass

    def __call__(self, pts0, pts1, inl_th):
        F, info = poselib.estimate_fundamental(
            pts0.cpu().numpy(),
            pts1.cpu().numpy(),
            {
                "max_epipolar_error": inl_th,
            },
        )

        success = F is not None
        if success:
            inliers = info.pop("inliers")
            inliers = torch.tensor(inliers, dtype=torch.bool, device=pts0.device)
        else:
            inliers = torch.zeros(pts0.shape[0], dtype=torch.bool, device=pts0.device)

        return F, inliers
ripe/model_zoo/__init__.py
ADDED
|
@@ -0,0 +1 @@
from .vgg_hyper import vgg_hyper  # noqa: F401
ripe/model_zoo/vgg_hyper.py
ADDED
|
@@ -0,0 +1,39 @@
from pathlib import Path

import torch

from ripe.models.backbones.vgg import VGG
from ripe.models.ripe import RIPE
from ripe.models.upsampler.hypercolumn_features import HyperColumnFeatures


def vgg_hyper(model_path: Path = None, desc_shares=None):
    if model_path is None:
        # check whether the weights file already exists in /tmp
        model_path = Path("/tmp/ripe_weights.pth")

        if model_path.exists():
            print(f"Using existing weights from {model_path}")
        else:
            print("Weights file not found. Downloading ...")
            torch.hub.download_url_to_file(
                "https://cvg.hhi.fraunhofer.de/RIPE/ripe_weights.pth",
                "/tmp/ripe_weights.pth",
            )
    else:
        if not model_path.exists():
            print(f"Error: {model_path} does not exist.")
            raise FileNotFoundError(f"Error: {model_path} does not exist.")

    backbone = VGG(pretrained=False)
    upsampler = HyperColumnFeatures()

    extractor = RIPE(
        net=backbone,
        upsampler=upsampler,
        desc_shares=desc_shares,
    )

    extractor.load_state_dict(torch.load(model_path, map_location="cpu"))

    return extractor
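A quick usage sketch for the model-zoo entry point; the image path is a placeholder, and the printed shapes are what the default 256-dimensional descriptor configuration is expected to produce:

    # usage sketch (image path is a placeholder)
    import torch
    from torchvision.io import read_image
    from ripe.model_zoo import vgg_hyper

    device = "cuda" if torch.cuda.is_available() else "cpu"
    extractor = vgg_hyper().eval().to(device)

    img = read_image("example.jpg") / 255.0
    kpts, descs, scores = extractor.detectAndCompute(img.to(device))
    print(kpts.shape, descs.shape, scores.shape)   # roughly [N, 2], [N, 256], [N]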
ripe/models/__init__.py
ADDED
|
File without changes
|
ripe/models/backbones/__init__.py
ADDED
|
File without changes
|
ripe/models/backbones/backbone_base.py
ADDED
|
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn


class BackboneBase(nn.Module):
    """Base class for backbone networks. Provides a standard interface for preprocessing inputs and
    defining encoder dimensions.

    Args:
        nchannels (int): Number of input channels.
        use_instance_norm (bool): Whether to apply instance normalization.
    """

    def __init__(self, nchannels=3, use_instance_norm=False):
        super().__init__()
        assert nchannels > 0, "Number of channels must be positive."
        self.nchannels = nchannels
        self.use_instance_norm = use_instance_norm
        self.norm = nn.InstanceNorm2d(nchannels) if use_instance_norm else None

    def get_dim_layers_encoder(self):
        """Get dimensions of encoder layers."""
        raise NotImplementedError("Subclasses must implement this method.")

    def _forward(self, x):
        """Define the forward pass for the backbone."""
        raise NotImplementedError("Subclasses must implement this method.")

    def forward(self, x: torch.Tensor, preprocess=True):
        """Forward pass with optional preprocessing.

        Args:
            x (Tensor): Input tensor.
            preprocess (bool): Whether to apply shape and channel adjustments.
        """
        if preprocess:
            if x.dim() != 4:
                if x.dim() == 2 and x.shape[0] > 3 and x.shape[1] > 3:
                    x = x.unsqueeze(0).unsqueeze(0)
                elif x.dim() == 3:
                    x = x.unsqueeze(0)
                else:
                    raise ValueError(f"Unexpected input shape: {x.shape}")

            # collapse RGB input to a single channel if the backbone expects grayscale
            if self.nchannels == 1 and x.shape[1] != 1:
                if len(x.shape) == 4:  # Assumes (batch, channel, height, width)
                    x = torch.mean(x, axis=1, keepdim=True)
                else:
                    raise ValueError(f"Unexpected input shape: {x.shape}")

            # replicate a grayscale input to three channels if the backbone expects RGB
            if self.nchannels == 3 and x.shape[1] == 1:
                if len(x.shape) == 4:
                    x = x.repeat(1, 3, 1, 1)
                else:
                    raise ValueError(f"Unexpected input shape: {x.shape}")

            if self.use_instance_norm:
                x = self.norm(x)

        return self._forward(x)
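A minimal illustration of the subclass contract the base class expects (a toy backbone, not part of the repository): implement `get_dim_layers_encoder` and `_forward`, and let the base class handle input preprocessing.

    # toy subclass sketch
    import torch.nn as nn
    from ripe.models.backbones.backbone_base import BackboneBase

    class ToyBackbone(BackboneBase):
        def __init__(self):
            super().__init__(nchannels=3, use_instance_norm=True)
            self.conv = nn.Conv2d(3, 64, 3, padding=1)

        def get_dim_layers_encoder(self):
            return [64]  # channel dimensions of the encoder outputs

        def _forward(self, x):
            feat = self.conv(x)
            return {"heatmap": feat.mean(1, keepdim=True), "coarse_descs": [feat]}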
ripe/models/backbones/vgg.py
ADDED
|
@@ -0,0 +1,99 @@
# adapted from: https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/encoder.py and https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/decoder.py

import torch.nn as nn
import torch.nn.functional as F

from .backbone_base import BackboneBase
from .vgg_utils import VGG19, ConvRefiner, Decoder


class VGG(BackboneBase):
    def __init__(self, nchannels=3, pretrained=True, use_instance_norm=True, mode="dect"):
        super().__init__(nchannels=nchannels, use_instance_norm=use_instance_norm)

        self.nchannels = nchannels
        self.mode = mode

        if self.mode not in ["dect", "desc", "dect+desc"]:
            raise ValueError("mode should be 'dect', 'desc' or 'dect+desc'")

        NUM_OUTPUT_CHANNELS, hidden_blocks = self._get_mode_params(mode)
        conv_refiner = self._create_conv_refiner(NUM_OUTPUT_CHANNELS, hidden_blocks)

        self.encoder = VGG19(pretrained=pretrained, num_input_channels=nchannels)
        self.decoder = Decoder(conv_refiner, num_prototypes=NUM_OUTPUT_CHANNELS)

    def _get_mode_params(self, mode):
        """Get the number of output channels and the number of hidden blocks for the ConvRefiner.

        Depending on the mode, the ConvRefiner will have a different number of output channels.
        """

        if mode == "dect":
            return 1, 8
        elif mode == "desc":
            return 256, 5
        elif mode == "dect+desc":
            return 256 + 1, 8

    def _create_conv_refiner(self, num_output_channels, hidden_blocks):
        return nn.ModuleDict(
            {
                "8": ConvRefiner(
                    512,
                    512,
                    256 + num_output_channels,
                    hidden_blocks=hidden_blocks,
                    residual=True,
                ),
                "4": ConvRefiner(
                    256 + 256,
                    256,
                    128 + num_output_channels,
                    hidden_blocks=hidden_blocks,
                    residual=True,
                ),
                "2": ConvRefiner(
                    128 + 128,
                    128,
                    64 + num_output_channels,
                    hidden_blocks=hidden_blocks,
                    residual=True,
                ),
                "1": ConvRefiner(
                    64 + 64,
                    64,
                    1 + num_output_channels,
                    hidden_blocks=hidden_blocks,
                    residual=True,
                ),
            }
        )

    def get_dim_layers_encoder(self):
        return self.encoder.get_dim_layers()

    def _forward(self, x):
        features, sizes = self.encoder(x)
        output = 0
        context = None
        scales = self.decoder.scales
        for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
            delta_descriptor, context = self.decoder(feature_map, scale=scale, context=context)
            output = output + delta_descriptor
            if idx < len(scales) - 1:
                size = sizes[-(idx + 2)]
                output = F.interpolate(output, size=size, mode="bilinear", align_corners=False)
                context = F.interpolate(context, size=size, mode="bilinear", align_corners=False)

        if self.mode == "dect":
            return {"heatmap": output, "coarse_descs": features}
        elif self.mode == "desc":
            return {"fine_descs": output, "coarse_descs": features}
        elif self.mode == "dect+desc":
            logits = output[:, :1].contiguous()
            descs = output[:, 1:].contiguous()

            return {"heatmap": logits, "fine_descs": descs, "coarse_descs": features}
        else:
            raise ValueError("mode should be 'dect', 'desc' or 'dect+desc'")
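The decoder refines coarse-to-fine, so in detection mode the heatmap is expected to come back at the input resolution. A shape smoke test, as a sketch with random weights:

    # shape smoke test (sketch; weights are random here)
    import torch
    from ripe.models.backbones.vgg import VGG

    backbone = VGG(pretrained=False, mode="dect")
    out = backbone(torch.randn(1, 3, 256, 256))
    print(out["heatmap"].shape)                        # expected: [1, 1, 256, 256]
    print([f.shape[1] for f in out["coarse_descs"]])   # encoder channels: [64, 128, 256, 512]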
ripe/models/backbones/vgg_utils.py
ADDED
|
@@ -0,0 +1,143 @@
# adapted from: https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/encoder.py and https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/decoder.py

import torch
import torch.nn as nn
import torchvision.models as tvm

from ripe import utils

log = utils.get_pylogger(__name__)


class Decoder(nn.Module):
    def __init__(self, layers, *args, super_resolution=False, num_prototypes=1, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.layers = layers
        self.scales = self.layers.keys()
        self.super_resolution = super_resolution
        self.num_prototypes = num_prototypes

    def forward(self, features, context=None, scale=None):
        if context is not None:
            features = torch.cat((features, context), dim=1)
        stuff = self.layers[scale](features)
        logits, context = (
            stuff[:, : self.num_prototypes],
            stuff[:, self.num_prototypes :],
        )
        return logits, context


class ConvRefiner(nn.Module):
    def __init__(
        self,
        in_dim=6,
        hidden_dim=16,
        out_dim=2,
        dw=True,
        kernel_size=5,
        hidden_blocks=5,
        residual=False,
    ):
        super().__init__()
        self.block1 = self.create_block(
            in_dim,
            hidden_dim,
            dw=False,
            kernel_size=1,
        )
        self.hidden_blocks = nn.Sequential(
            *[
                self.create_block(
                    hidden_dim,
                    hidden_dim,
                    dw=dw,
                    kernel_size=kernel_size,
                )
                for hb in range(hidden_blocks)
            ]
        )
        self.hidden_blocks = self.hidden_blocks
        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
        self.residual = residual

    def create_block(
        self,
        in_dim,
        out_dim,
        dw=True,
        kernel_size=5,
        bias=True,
        norm_type=nn.BatchNorm2d,
    ):
        num_groups = 1 if not dw else in_dim
        if dw:
            assert out_dim % in_dim == 0, "outdim must be divisible by indim for depthwise"
        conv1 = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=num_groups,
            bias=bias,
        )
        norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels=out_dim)
        relu = nn.ReLU(inplace=True)
        conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
        return nn.Sequential(conv1, norm, relu, conv2)

    def forward(self, feats):
        b, c, hs, ws = feats.shape
        x0 = self.block1(feats)
        x = self.hidden_blocks(x0)
        if self.residual:
            x = (x + x0) / 1.4
        x = self.out_conv(x)
        return x


class VGG19(nn.Module):
    def __init__(self, pretrained=False, num_input_channels=3) -> None:
        super().__init__()
        self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
        # Maxpool layers: 6, 13, 26, 39

        if num_input_channels != 3:
            log.info(f"Changing input channels from 3 to {num_input_channels}")
            self.layers[0] = nn.Conv2d(num_input_channels, 64, 3, 1, 1)

    def get_dim_layers(self):
        return [64, 128, 256, 512]

    def forward(self, x, **kwargs):
        feats = []
        sizes = []
        for layer in self.layers:
            if isinstance(layer, nn.MaxPool2d):
                feats.append(x)
                sizes.append(x.shape[-2:])
            x = layer(x)
        return feats, sizes


class VGG(nn.Module):
    def __init__(self, size="19", pretrained=False) -> None:
        super().__init__()
        if size == "11":
            self.layers = nn.ModuleList(tvm.vgg11_bn(pretrained=pretrained).features[:22])
        elif size == "13":
            self.layers = nn.ModuleList(tvm.vgg13_bn(pretrained=pretrained).features[:28])
        elif size == "19":
            self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
            # Maxpool layers: 6, 13, 26, 39

    def forward(self, x, **kwargs):
        feats = []
        sizes = []
        for layer in self.layers:
            if isinstance(layer, nn.MaxPool2d):
                feats.append(x)
                sizes.append(x.shape[-2:])
            x = layer(x)
        return feats, sizes
ripe/models/ripe.py
ADDED
|
@@ -0,0 +1,303 @@
| 1 |
+
from typing import List, Optional
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from ripe import utils
|
| 9 |
+
from ripe.utils.utils import gridify
|
| 10 |
+
|
| 11 |
+
log = utils.get_pylogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class KeypointSampler(nn.Module):
|
| 15 |
+
"""
|
| 16 |
+
Sample keypoints according to a Heatmap
|
| 17 |
+
Adapted from: https://github.com/verlab/DALF_CVPR_2023/blob/main/modules/models/DALF.py
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, window_size=8):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.window_size = window_size
|
| 23 |
+
self.idx_cells = None # Cache for meshgrid indices
|
| 24 |
+
|
| 25 |
+
def sample(self, grid):
|
| 26 |
+
"""
|
| 27 |
+
Sample keypoints given a grid where each cell has logits stacked in last dimension
|
| 28 |
+
Input
|
| 29 |
+
grid: [B, C, H//w, W//w, w*w]
|
| 30 |
+
|
| 31 |
+
Returns
|
| 32 |
+
log_probs: [B, C, H//w, W//w ] - logprobs of selected samples
|
| 33 |
+
choices: [B, C, H//w, W//w] indices of choices
|
| 34 |
+
accept_mask: [B, C, H//w, W//w] mask of accepted keypoints
|
| 35 |
+
|
| 36 |
+
"""
|
| 37 |
+
chooser = torch.distributions.Categorical(logits=grid)
|
| 38 |
+
choices = chooser.sample()
|
| 39 |
+
logits_selected = torch.gather(grid, -1, choices.unsqueeze(-1)).squeeze(-1)
|
| 40 |
+
|
| 41 |
+
flipper = torch.distributions.Bernoulli(logits=logits_selected)
|
| 42 |
+
accepted_choices = flipper.sample()
|
| 43 |
+
|
| 44 |
+
# Sum log-probabilities is equivalent to multiplying the probabilities
|
| 45 |
+
log_probs = chooser.log_prob(choices) + flipper.log_prob(accepted_choices)
|
| 46 |
+
|
| 47 |
+
accept_mask = accepted_choices.gt(0)
|
| 48 |
+
|
| 49 |
+
return (
|
| 50 |
+
log_probs.squeeze(1),
|
| 51 |
+
choices,
|
| 52 |
+
accept_mask.squeeze(1),
|
| 53 |
+
logits_selected.squeeze(1),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def precompute_idx_cells(self, H, W, device):
|
| 57 |
+
idx_cells = gridify(
|
| 58 |
+
torch.dstack(
|
| 59 |
+
torch.meshgrid(
|
| 60 |
+
torch.arange(H, dtype=torch.float32, device=device),
|
| 61 |
+
torch.arange(W, dtype=torch.float32, device=device),
|
| 62 |
+
)
|
| 63 |
+
)
|
| 64 |
+
.permute(2, 0, 1)
|
| 65 |
+
.unsqueeze(0)
|
| 66 |
+
.expand(1, -1, -1, -1),
|
| 67 |
+
window_size=self.window_size,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
return idx_cells
|
| 71 |
+
|
| 72 |
+
def forward(self, x, mask_padding=None):
|
| 73 |
+
"""
|
| 74 |
+
Sample keypoints from a heatmap
|
| 75 |
+
Input
|
| 76 |
+
x: [B, C, H, W] Heatmap
|
| 77 |
+
mask_padding: [B, 1, H, W] Mask for padding (optional)
|
| 78 |
+
Returns
|
| 79 |
+
keypoints: [B, H//w, W//w, 2] Keypoints in (x, y) format
|
| 80 |
+
log_probs: [B, H//w, W//w] Log probabilities of selected keypoints
|
| 81 |
+
mask: [B, H//w, W//w] Mask of accepted keypoints
|
| 82 |
+
mask_padding: [B, 1, H//w, W//w] Mask of padding (optional)
|
| 83 |
+
logits_selected: [B, H//w, W//w] Logits of selected keypoints
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
B, C, H, W = x.shape
|
| 87 |
+
|
| 88 |
+
keypoint_cells = gridify(x, self.window_size)
|
| 89 |
+
|
| 90 |
+
mask_padding = (
|
| 91 |
+
(torch.min(gridify(mask_padding, self.window_size), dim=4).values) if mask_padding is not None else None
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
if self.idx_cells is None or self.idx_cells.shape[2:4] != (
|
| 95 |
+
H // self.window_size,
|
| 96 |
+
W // self.window_size,
|
| 97 |
+
):
|
| 98 |
+
self.idx_cells = self.precompute_idx_cells(H, W, x.device)
|
| 99 |
+
|
| 100 |
+
log_probs, idx, mask, logits_selected = self.sample(keypoint_cells)
|
| 101 |
+
|
| 102 |
+
keypoints = (
|
| 103 |
+
torch.gather(
|
| 104 |
+
self.idx_cells.expand(B, -1, -1, -1, -1),
|
| 105 |
+
-1,
|
| 106 |
+
idx.repeat(1, 2, 1, 1).unsqueeze(-1),
|
| 107 |
+
)
|
| 108 |
+
.squeeze(-1)
|
| 109 |
+
.permute(0, 2, 3, 1)
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# flip keypoints to (x, y) format
|
| 113 |
+
return keypoints.flip(-1), log_probs, mask, mask_padding, logits_selected
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class RIPE(nn.Module):
|
| 117 |
+
"""
|
| 118 |
+
Base class for extracting keypoints and descriptors
|
| 119 |
+
Input
|
| 120 |
+
x: [B, C, H, W] Images
|
| 121 |
+
|
| 122 |
+
Returns
|
| 123 |
+
kpts:
|
| 124 |
+
list of size [B] with detected keypoints
|
| 125 |
+
descs:
|
| 126 |
+
list of size [B] with descriptors
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
net,
|
| 132 |
+
upsampler,
|
| 133 |
+
window_size: int = 8,
|
| 134 |
+
non_linearity_dect=None,
|
| 135 |
+
desc_shares: Optional[List[int]] = None,
|
| 136 |
+
descriptor_dim: int = 256,
|
| 137 |
+
device=None,
|
| 138 |
+
):
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.net = net
|
| 141 |
+
|
| 142 |
+
self.detector = KeypointSampler(window_size)
|
| 143 |
+
self.upsampler = upsampler
|
| 144 |
+
self.sampler = None
|
| 145 |
+
self.window_size = window_size
|
| 146 |
+
self.non_linearity_dect = non_linearity_dect if non_linearity_dect is not None else nn.Identity()
|
| 147 |
+
|
| 148 |
+
log.info(f"Training with window size {window_size}.")
|
| 149 |
+
log.info(f"Use {non_linearity_dect} as final non-linearity before the detection heatmap.")
|
| 150 |
+
|
| 151 |
+
dim_coarse_desc = self.get_dim_raw_desc()
|
| 152 |
+
|
| 153 |
+
if desc_shares is not None:
|
| 154 |
+
assert upsampler.name == "HyperColumnFeatures", (
|
| 155 |
+
"Individual descriptor convolutions are only supported with HyperColumnFeatures"
|
| 156 |
+
)
|
| 157 |
+
assert len(desc_shares) == 4, "desc_shares should have 4 elements"
|
| 158 |
+
assert sum(desc_shares) == descriptor_dim, f"sum of desc_shares should be {descriptor_dim}"
|
| 159 |
+
|
| 160 |
+
self.conv_dim_reduction_coarse_desc = nn.ModuleList()
|
| 161 |
+
|
| 162 |
+
for dim_in, dim_out in zip(dim_coarse_desc, desc_shares):
|
| 163 |
+
log.info(f"Training dim reduction descriptor with {dim_in} -> {dim_out} 1x1 conv")
|
| 164 |
+
self.conv_dim_reduction_coarse_desc.append(
|
| 165 |
+
nn.Conv1d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)
|
| 166 |
+
)
|
| 167 |
+
else:
|
| 168 |
+
if descriptor_dim is not None:
|
| 169 |
+
log.info(f"Training dim reduction descriptor with {sum(dim_coarse_desc)} -> {descriptor_dim} 1x1 conv")
|
| 170 |
+
self.conv_dim_reduction_coarse_desc = nn.Conv1d(
|
| 171 |
+
sum(dim_coarse_desc),
|
| 172 |
+
descriptor_dim,
|
| 173 |
+
kernel_size=1,
|
| 174 |
+
stride=1,
|
| 175 |
+
padding=0,
|
| 176 |
+
)
|
| 177 |
+
else:
|
| 178 |
+
log.warning(
|
| 179 |
+
f"No descriptor dimension specified, no 1x1 conv will be applied! Direct usage of {sum(dim_coarse_desc)}-dimensional raw descriptor"
|
| 180 |
+
)
|
| 181 |
+
self.conv_dim_reduction_coarse_desc = nn.Identity()
|
| 182 |
+
|
| 183 |
+
def get_dim_raw_desc(self):
|
| 184 |
+
layers_dims_encoder = self.net.get_dim_layers_encoder()
|
| 185 |
+
|
| 186 |
+
if self.upsampler.name == "InterpolateSparse2d":
|
| 187 |
+
return [layers_dims_encoder[-1]]
|
| 188 |
+
elif self.upsampler.name == "HyperColumnFeatures":
|
| 189 |
+
return layers_dims_encoder
|
| 190 |
+
else:
|
| 191 |
+
raise ValueError(f"Unknown interpolator {self.upsampler.name}")
|
| 192 |
+
|
| 193 |
+
@torch.inference_mode()
|
| 194 |
+
def detectAndCompute(self, img, threshold=0.5, top_k=2048, output_aux=False):
|
| 195 |
+
self.train(False)
|
| 196 |
+
|
| 197 |
+
if img.dim() == 3:
|
| 198 |
+
img = img.unsqueeze(0)
|
| 199 |
+
|
| 200 |
+
out = self(img, training=False)
|
| 201 |
+
B, K, H, W = out["heatmap"].shape
|
| 202 |
+
|
| 203 |
+
assert B == 1, "Batch size should be 1"
|
| 204 |
+
|
| 205 |
+
kpts = [{"xy": self.NMS(out["heatmap"][b], threshold)} for b in range(B)]
|
| 206 |
+
|
| 207 |
+
if top_k is not None:
|
| 208 |
+
        for b in range(B):
            scores = out["heatmap"][b].squeeze(0)[kpts[b]["xy"][:, 1].long(), kpts[b]["xy"][:, 0].long()]
            sorted_idx = torch.argsort(-scores)
            kpts[b]["xy"] = kpts[b]["xy"][sorted_idx[:top_k]]
            if "logprobs" in kpts[b]:
                # keep the log-probabilities aligned with the re-ordered keypoints
                kpts[b]["logprobs"] = kpts[b]["logprobs"][sorted_idx[:top_k]]

        if kpts[0]["xy"].shape[0] == 0:
            raise RuntimeError("No keypoints detected")

        # the following works for batch size 1 only
        descs = self.get_descs(out["coarse_descs"], img, kpts[0]["xy"].unsqueeze(0), H, W)
        descs = descs.squeeze(0)

        score_map = out["heatmap"][0].squeeze(0)

        kpts = kpts[0]["xy"]

        scores = score_map[kpts[:, 1], kpts[:, 0]]
        scores /= score_map.max()

        sort_idx = torch.argsort(-scores)
        kpts, descs, scores = kpts[sort_idx], descs[sort_idx], scores[sort_idx]

        if output_aux:
            return (
                kpts.float(),
                descs,
                scores,
                {
                    "heatmap": out["heatmap"],
                    "descs": out["coarse_descs"],
                    "conv": self.conv_dim_reduction_coarse_desc,
                },
            )

        return kpts.float(), descs, scores

    def NMS(self, x, threshold=3.0, kernel_size=3):
        pad = kernel_size // 2
        local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)

        pos = (x == local_max) & (x > threshold)
        return pos.nonzero()[..., 1:].flip(-1)

    def get_descs(self, feature_map, guidance, kpts, H, W):
        descs = self.upsampler(feature_map, kpts, H, W)

        if isinstance(self.conv_dim_reduction_coarse_desc, nn.ModuleList):
            # individual descriptor convolutions for each layer
            desc_conv = []
            for desc, conv in zip(descs, self.conv_dim_reduction_coarse_desc):
                desc_conv.append(conv(desc.permute(0, 2, 1)).permute(0, 2, 1))
            desc = torch.cat(desc_conv, dim=-1)
        else:
            desc = torch.cat(descs, dim=-1)
            desc = self.conv_dim_reduction_coarse_desc(desc.permute(0, 2, 1)).permute(0, 2, 1)

        desc = F.normalize(desc, dim=2)

        return desc

    def forward(self, x, mask_padding=None, training=False):
        B, C, H, W = x.shape
        out = self.net(x)
        out["heatmap"] = self.non_linearity_dect(out["heatmap"])
        if training:
            kpts, log_probs, mask, mask_padding, logits_selected = self.detector(out["heatmap"], mask_padding)

            # discard keypoints closer than 16 px to the image border
            filter_A = kpts[:, :, :, 0] >= 16
            filter_B = kpts[:, :, :, 1] >= 16
            filter_C = kpts[:, :, :, 0] < W - 16
            filter_D = kpts[:, :, :, 1] < H - 16
            filter_all = filter_A * filter_B * filter_C * filter_D

            mask = mask * filter_all

            return (
                kpts.view(B, -1, 2),
                log_probs.view(B, -1),
                mask.view(B, -1),
                mask_padding.view(B, -1),
                logits_selected.view(B, -1),
                out,
            )
        else:
            return out


def output_number_trainable_params(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    nb_params = sum([np.prod(p.size()) for p in model_parameters])

    print(f"Number of trainable parameters: {nb_params:d}")

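The NMS helper above keeps only pixels that are both a local maximum and above a score threshold, and returns their coordinates in (x, y) order. A minimal sketch of that behaviour on a toy score map (the single-channel (1, H, W) layout and the threshold value are assumptions for illustration):

import torch
from torch import nn

# toy single-channel score map (1 x H x W) with one clear peak at (x=3, y=2)
heatmap = torch.zeros(1, 5, 5)
heatmap[0, 2, 3] = 5.0

kernel_size, threshold = 3, 3.0
local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)(heatmap)
pos = (heatmap == local_max) & (heatmap > threshold)

# nonzero() yields (channel, row, col); dropping the channel and flipping the
# last axis gives keypoints as (x, y), mirroring the NMS helper above
print(pos.nonzero()[..., 1:].flip(-1))  # -> tensor([[3, 2]])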
ripe/models/upsampler/hypercolumn_features.py
ADDED
@@ -0,0 +1,54 @@
import torch
import torch.nn.functional as F
from torch import nn


class HyperColumnFeatures(nn.Module):
    """
    Interpolate 3D tensor given N sparse 2D positions
    Input
        x: list([C, H, W]) list of feature tensors at different scales (e.g. from a U-Net) -> extract hypercolumn features
        pos: [N, 2] tensor of positions
        H: int, height of the OUTPUT map
        W: int, width of the OUTPUT map

    Returns
        [N, C] sampled features at 2d positions
    """

    def __init__(self, mode="bilinear"):
        super().__init__()
        self.mode = mode
        self.name = "HyperColumnFeatures"

    def normgrid(self, x, H, W):
        return 2.0 * (x / (torch.tensor([W - 1, H - 1], device=x.device, dtype=x.dtype))) - 1.0

    def extract_values_at_poses(self, x, pos, H, W):
        """Extract values from tensor x at the positions given by pos.

        Args:
            - x (Tensor): Tensor of size (C, H, W).
            - pos (Tensor): Tensor of size (N, 2) containing the x, y positions.

        Returns:
            - values (Tensor): Tensor of size (N, C) with the values from f at the positions given by p.
        """

        # check if grid is float32
        if x.dtype != torch.float32:
            x = x.to(torch.float32)

        grid = self.normgrid(pos, H, W).unsqueeze(-2)

        x = F.grid_sample(x, grid, mode=self.mode, align_corners=True)
        return x.permute(0, 2, 3, 1).squeeze(-2)

    def forward(self, x, pos, H, W):
        descs = []

        for layer in x:
            desc = self.extract_values_at_poses(layer, pos, H, W)
            descs.append(desc)

        return descs

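A small usage sketch for the hypercolumn sampler. It assumes batched feature maps of shape (B, C, h, w) and keypoint positions in (x, y) pixel coordinates of the full-resolution image, which is how the model code above calls it; the shapes below are made up:

import torch
from ripe.models.upsampler.hypercolumn_features import HyperColumnFeatures

sampler = HyperColumnFeatures(mode="bilinear")

H, W = 128, 160                              # full-resolution image size
feature_maps = [
    torch.randn(1, 64, H // 2, W // 2),      # e.g. two encoder stages at different scales
    torch.randn(1, 128, H // 4, W // 4),
]
kpts = torch.tensor([[[10.0, 20.0], [50.0, 70.0]]])  # (B=1, N=2, 2) in (x, y)

descs = sampler(feature_maps, kpts, H, W)
print([d.shape for d in descs])  # [torch.Size([1, 2, 64]), torch.Size([1, 2, 128])]

The per-scale descriptors are returned as a list; the model concatenates them along the channel dimension to form the hypercolumn descriptor.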
ripe/models/upsampler/interpolate_sparse2d.py
ADDED
@@ -0,0 +1,37 @@
import torch
import torch.nn.functional as F
from torch import nn


class InterpolateSparse2d(nn.Module):
    """
    Interpolate 3D tensor given N sparse 2D positions
    Input
        x: list([C, H, W]) feature tensors at different scales (e.g. from a U-Net), ONLY the last one is used
        pos: [N, 2] tensor of positions
        H: int, height of the OUTPUT map
        W: int, width of the OUTPUT map

    Returns
        [N, C] sampled features at 2d positions
    """

    def __init__(self, mode="bicubic"):
        super().__init__()
        self.mode = mode
        self.name = "InterpolateSparse2d"

    def normgrid(self, x, H, W):
        return 2.0 * (x / (torch.tensor([W - 1, H - 1], device=x.device, dtype=x.dtype))) - 1.0

    def forward(self, x, pos, H, W):
        x = x[-1]  # only use the last layer

        # check if grid is float32
        if x.dtype != torch.float32:
            x = x.to(torch.float32)

        grid = self.normgrid(pos, H, W).unsqueeze(-2)

        x = F.grid_sample(x, grid, mode=self.mode, align_corners=True)
        return [x.permute(0, 2, 3, 1).squeeze(-2)]

ripe/scheduler/__init__.py
ADDED
File without changes
ripe/scheduler/constant.py
ADDED
@@ -0,0 +1,6 @@
class ConstantScheduler:
    def __init__(self, value):
        self.value = value

    def __call__(self, step):
        return self.value

ripe/scheduler/expDecay.py
ADDED
@@ -0,0 +1,26 @@
import numpy as np

from ripe import utils

log = utils.get_pylogger(__name__)


class ExpDecay:
    """Exponential decay scheduler.

    args:
        a: float, a + c = initial value
        b: decay rate
        c: float, final value

    f(x) = a * e^(-b * x) + c
    """

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

        log.info(f"ExpDecay: a={a}, b={b}, c={c}")

    def __call__(self, step):
        return self.a * np.exp(-self.b * step) + self.c

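A quick numeric illustration of the f(x) = a * e^(-b * x) + c schedule (parameter values chosen only for the example):

from ripe.scheduler.expDecay import ExpDecay

sched = ExpDecay(a=1.0, b=0.01, c=0.5)
print(round(sched(0), 3))       # 1.5   (a + c at step 0)
print(round(sched(100), 3))     # 0.868 (1.0 * e^-1 + 0.5)
print(round(sched(10_000), 3))  # 0.5   (decays towards c)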
ripe/scheduler/linearLR.py
ADDED
@@ -0,0 +1,37 @@
class StepLinearLR:
    """Decay the learning rate by a linearly changing factor at each STEP (not epoch).

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        steps_init (int): Step count to start from.
        num_steps (int): Total number of steps in the training process.
        initial_lr (float): Initial learning rate.
        final_lr (float): Final learning rate.
    """

    def __init__(self, optimizer, steps_init, num_steps, initial_lr, final_lr):
        self.optimizer = optimizer
        self.num_steps = num_steps
        self.initial_lr = initial_lr
        self.final_lr = final_lr
        self.i_step = steps_init
        self.decay_factor = (final_lr - initial_lr) / num_steps

    def step(self):
        """Decay the learning rate by decay_factor."""
        self.i_step += 1

        if self.i_step > self.num_steps:
            return

        lr = self.initial_lr + self.i_step * self.decay_factor
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        return self.optimizer.param_groups[0]["lr"]

    def get_last_lr(self):
        return self.optimizer.param_groups[0]["lr"]

    def get_step(self):
        return self.i_step

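A minimal sketch of driving StepLinearLR with a PyTorch optimizer; the parameter, learning rates, and step count are placeholders:

import torch
from ripe.scheduler.linearLR import StepLinearLR

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=1e-3)

sched = StepLinearLR(opt, steps_init=0, num_steps=100, initial_lr=1e-3, final_lr=1e-5)

for _ in range(100):
    # ... forward / backward / opt.step() would go here ...
    sched.step()  # called once per optimization step, not once per epoch

print(sched.get_last_lr())  # ~1e-5 once the schedule has run its course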
ripe/scheduler/linear_with_plateaus.py
ADDED
@@ -0,0 +1,44 @@
from ripe import utils

log = utils.get_pylogger(__name__)


class LinearWithPlateaus:
    """Linear scheduler with plateaus.

    Stays at `start_val` for `plateau_start_steps` steps and at `end_val` for `plateau_end_steps` steps.
    Linearly changes from `start_val` to `end_val` during the remaining steps.
    """

    def __init__(
        self,
        start_val,
        end_val,
        steps_total,
        rel_length_start_plateau=0.0,
        rel_length_end_plateu=0.0,
    ):
        self.start_val = start_val
        self.end_val = end_val
        self.steps_total = steps_total
        self.plateau_start_steps = steps_total * rel_length_start_plateau
        self.plateau_end_steps = steps_total * rel_length_end_plateu

        assert self.plateau_start_steps >= 0
        assert self.plateau_end_steps >= 0
        assert self.plateau_start_steps + self.plateau_end_steps <= self.steps_total

        self.slope = (end_val - start_val) / (steps_total - self.plateau_start_steps - self.plateau_end_steps)

        log.info(
            f"LinearWithPlateaus: start_val={start_val}, end_val={end_val}, steps_total={steps_total}, "
            f"plateau_start_steps={self.plateau_start_steps}, plateau_end_steps={self.plateau_end_steps}"
        )

    def __call__(self, step):
        if step < self.plateau_start_steps:
            return self.start_val
        if step < self.steps_total - self.plateau_end_steps:
            return self.start_val + self.slope * (step - self.plateau_start_steps)
        return self.end_val

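For example, a ramp with plateaus at both ends behaves as follows (the values are chosen only for illustration):

from ripe.scheduler.linear_with_plateaus import LinearWithPlateaus

sched = LinearWithPlateaus(
    start_val=0.0, end_val=1.0, steps_total=100,
    rel_length_start_plateau=0.2, rel_length_end_plateu=0.2,
)

print(sched(10))  # 0.0  (still on the start plateau, steps 0-19)
print(sched(50))  # 0.5  (halfway through the linear ramp, steps 20-79)
print(sched(95))  # 1.0  (end plateau)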
ripe/train.py
ADDED
@@ -0,0 +1,410 @@
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

SEED = 32000

import collections
import os

import hydra
from hydra.utils import instantiate
from lightning.fabric import Fabric

print(SEED)
import random

os.environ["PYTHONHASHSEED"] = str(SEED)

import numpy as np
import torch
import tqdm
import wandb
from torch.optim.adamw import AdamW
from torch.utils.data import DataLoader

from ripe import utils
from ripe.benchmarks.imw_2020 import IMW_2020_Benchmark
from ripe.utils.utils import get_rewards
from ripe.utils.wandb_utils import get_flattened_wandb_cfg

log = utils.get_pylogger(__name__)
from pathlib import Path

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


def unpack_batch(batch):
    src_image = batch["src_image"]
    trg_image = batch["trg_image"]
    trg_mask = batch["trg_mask"]
    src_mask = batch["src_mask"]
    label = batch["label"]
    H = batch["homography"]

    return src_image, trg_image, src_mask, trg_mask, H, label


@hydra.main(config_path="../conf/", config_name="config", version_base=None)
def train(cfg):
    """Main training function for the RIPE model."""
    # Prepare model, data and hyperparameters

    strategy = "ddp" if cfg.num_gpus > 1 else "auto"
    fabric = Fabric(
        accelerator="cuda",
        devices=cfg.num_gpus,
        precision=cfg.precision,
        strategy=strategy,
    )
    fabric.launch()

    output_dir = Path(cfg.output_dir)
    experiment_name = output_dir.parent.parent.parent.name
    run_id = output_dir.parent.parent.name
    timestamp = output_dir.parent.name + "_" + output_dir.name

    experiment_name = run_id + " " + timestamp + " " + experiment_name

    # setup logger
    wandb_logger = wandb.init(
        project=cfg.project_name,
        name=experiment_name,
        config=get_flattened_wandb_cfg(cfg),
        dir=cfg.output_dir,
        mode=cfg.wandb_mode,
    )

    min_nums_matches = {"homography": 4, "fundamental": 8, "fundamental_7pt": 7}
    min_num_matches = min_nums_matches[cfg.transformation_model]
    print(f"Minimum number of matches for {cfg.transformation_model} is {min_num_matches}")

    batch_size = cfg.batch_size
    steps = cfg.num_steps
    lr = cfg.lr

    num_grad_accs = (
        cfg.num_grad_accs
    )  # this performs grad accumulation to simulate larger batch size, set to 1 to disable

    # instantiate dataset
    ds = instantiate(cfg.data)

    # prepare dataloader
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        persistent_workers=False,
        num_workers=cfg.num_workers,
    )
    dl = fabric.setup_dataloaders(dl)
    i_dl = iter(dl)

    # create matcher
    matcher = instantiate(cfg.matcher)

    if cfg.desc_loss_weight != 0.0:
        descriptor_loss = instantiate(cfg.descriptor_loss)
    else:
        log.warning(
            "Descriptor loss weight is 0.0, descriptor loss will not be used. 1x1 conv for descriptors will be deactivated!"
        )
        descriptor_loss = None

    upsampler = instantiate(cfg.upsampler) if "upsampler" in cfg else None

    # create network
    net = instantiate(cfg.network)(
        net=instantiate(cfg.backbones),
        upsampler=upsampler,
        descriptor_dim=cfg.descriptor_dim if descriptor_loss is not None else None,
        device=fabric.device,
    ).train()

    # get num parameters
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    log.info(f"Number of parameters: {num_params}")

    fp_penalty = cfg.fp_penalty  # small penalty for not finding a match
    kp_penalty = cfg.kp_penalty  # small penalty for low logprob keypoints

    opt_pi = AdamW(filter(lambda x: x.requires_grad, net.parameters()), lr=lr, weight_decay=1e-5)
    net, opt_pi = fabric.setup(net, opt_pi)

    if cfg.lr_scheduler:
        scheduler = instantiate(cfg.lr_scheduler)(optimizer=opt_pi, steps_init=0)
    else:
        scheduler = None

    val_benchmark = IMW_2020_Benchmark(
        use_predefined_subset=True,
        conf_inference=cfg.conf_inference,
        edge_input_divisible_by=None,
    )

    # moving average of skipped batches
    # this is used to monitor how many batches were skipped due to not enough keypoints
    # this is useful to detect if the model is not learning anything -> should be zero
    ma_skipped_batches = collections.deque(maxlen=100)

    opt_pi.zero_grad()

    # initialize schedulers
    alpha_scheduler = instantiate(cfg.alpha_scheduler)
    beta_scheduler = instantiate(cfg.beta_scheduler)
    inl_th_scheduler = instantiate(cfg.inl_th)

    # ====== Training Loop ======
    # check if the model is in training mode
    net.train()

    with tqdm.tqdm(total=steps) as pbar:
        for i_step in range(steps):
            alpha = alpha_scheduler(i_step)
            beta = beta_scheduler(i_step)
            inl_th = inl_th_scheduler(i_step)

            if scheduler:
                scheduler.step()

            # Initialize vars for current step
            # We need to handle batching because the description can have arbitrary number of keypoints
            sum_reward_batch = 0
            sum_num_keypoints_1 = 0
            sum_num_keypoints_2 = 0
            loss = None
            loss_policy_stack = None
            loss_desc_stack = None
            loss_kp_stack = None

            try:
                batch = next(i_dl)
            except StopIteration:
                i_dl = iter(dl)
                batch = next(i_dl)

            p1, p2, mask_padding_1, mask_padding_2, Hs, label = unpack_batch(batch)

            (
                kpts1,
                logprobs1,
                selected_mask1,
                mask_padding_grid_1,
                logits_selected_1,
                out1,
            ) = net(p1, mask_padding_1, training=True)
            (
                kpts2,
                logprobs2,
                selected_mask2,
                mask_padding_grid_2,
                logits_selected_2,
                out2,
            ) = net(p2, mask_padding_2, training=True)

            # upsample coarse descriptors for all keypoints from the intermediate feature maps from the encoder
            desc_1 = net.get_descs(out1["coarse_descs"], p1, kpts1, p1.shape[2], p1.shape[3])
            desc_2 = net.get_descs(out2["coarse_descs"], p2, kpts2, p2.shape[2], p2.shape[3])

            if cfg.padding_filter_mode == "ignore":  # remove keypoints that are in padding
                batch_mask_selection_for_matching_1 = selected_mask1 & mask_padding_grid_1
                batch_mask_selection_for_matching_2 = selected_mask2 & mask_padding_grid_2
            elif cfg.padding_filter_mode == "punish":
                batch_mask_selection_for_matching_1 = selected_mask1  # keep all keypoints
                batch_mask_selection_for_matching_2 = selected_mask2  # punish the keypoints in the padding area
            else:
                raise ValueError(f"Unknown padding filter mode: {cfg.padding_filter_mode}")

            (
                batch_rel_idx_matches,
                batch_abs_idx_matches,
                batch_ransac_inliers,
                batch_Fm,
            ) = matcher(
                kpts1,
                kpts2,
                desc_1,
                desc_2,
                batch_mask_selection_for_matching_1,
                batch_mask_selection_for_matching_2,
                inl_th,
                label if cfg.no_filtering_negatives else None,
            )

            for b in range(batch_size):
                # ignore if less than 16 keypoints have been detected
                if batch_rel_idx_matches[b] is None:
                    ma_skipped_batches.append(1)
                    continue
                else:
                    ma_skipped_batches.append(0)

                mask_selection_for_matching_1 = batch_mask_selection_for_matching_1[b]
                mask_selection_for_matching_2 = batch_mask_selection_for_matching_2[b]

                rel_idx_matches = batch_rel_idx_matches[b]
                abs_idx_matches = batch_abs_idx_matches[b]
                ransac_inliers = batch_ransac_inliers[b]

                if cfg.selected_only:
                    # every SELECTED keypoint with every other SELECTED keypoint
                    dense_logprobs = logprobs1[b][mask_selection_for_matching_1].view(-1, 1) + logprobs2[b][
                        mask_selection_for_matching_2
                    ].view(1, -1)
                else:
                    if cfg.padding_filter_mode == "ignore":
                        # every keypoint with every other keypoint, but WITHOUT keypoints in the padding area
                        dense_logprobs = logprobs1[b][mask_padding_grid_1[b]].view(-1, 1) + logprobs2[b][
                            mask_padding_grid_2[b]
                        ].view(1, -1)
                    elif cfg.padding_filter_mode == "punish":
                        # every keypoint with every other keypoint, also WITH keypoints in the padding areas -> will be punished by the reward
                        dense_logprobs = logprobs1[b].view(-1, 1) + logprobs2[b].view(1, -1)
                    else:
                        raise ValueError(f"Unknown padding filter mode: {cfg.padding_filter_mode}")

                reward = None

                if cfg.reward_type == "inlier":
                    reward = (
                        0.5 if cfg.no_filtering_negatives and not label[b] else 1.0
                    )  # reward is 1.0 if the pair is positive, 0.5 if negative and no filtering is applied
                elif cfg.reward_type == "inlier_ratio":
                    ratio_inlier = ransac_inliers.sum() / len(abs_idx_matches)
                    reward = ratio_inlier  # reward is the ratio of inliers -> higher if more matches are inliers
                elif cfg.reward_type == "inlier+inlier_ratio":
                    ratio_inlier = ransac_inliers.sum() / len(abs_idx_matches)
                    reward = (
                        (1.0 - beta) * 1.0 + beta * ratio_inlier
                    )  # reward is a combination of the ratio of inliers and the number of inliers -> gradually changes
                else:
                    raise ValueError(f"Unknown reward type: {cfg.reward_type}")

                dense_rewards = get_rewards(
                    reward,
                    kpts1[b],
                    kpts2[b],
                    mask_selection_for_matching_1,
                    mask_selection_for_matching_2,
                    mask_padding_grid_1[b],
                    mask_padding_grid_2[b],
                    rel_idx_matches,
                    abs_idx_matches,
                    ransac_inliers,
                    label[b],
                    fp_penalty * alpha,
                    use_whitening=cfg.use_whitening,
                    selected_only=cfg.selected_only,
                    filter_mode=cfg.padding_filter_mode,
                )

                if descriptor_loss is not None:
                    hard_loss = descriptor_loss(
                        desc1=desc_1[b],
                        desc2=desc_2[b],
                        matches=abs_idx_matches,
                        inliers=ransac_inliers,
                        label=label[b],
                        logits_1=None,
                        logits_2=None,
                    )
                    loss_desc_stack = (
                        hard_loss if loss_desc_stack is None else torch.hstack((loss_desc_stack, hard_loss))
                    )

                sum_reward_batch += dense_rewards.sum()

                current_loss_policy = (dense_rewards * dense_logprobs).view(-1)

                loss_policy_stack = (
                    current_loss_policy
                    if loss_policy_stack is None
                    else torch.hstack((loss_policy_stack, current_loss_policy))
                )

                if kp_penalty != 0.0:
                    # keypoints with low logprob are penalized:
                    # as they get large negative logprob values, multiplying them with the penalty will make the loss larger
                    loss_kp = (
                        logprobs1[b][mask_selection_for_matching_1]
                        * torch.full_like(
                            logprobs1[b][mask_selection_for_matching_1],
                            kp_penalty * alpha,
                        )
                    ).mean() + (
                        logprobs2[b][mask_selection_for_matching_2]
                        * torch.full_like(
                            logprobs2[b][mask_selection_for_matching_2],
                            kp_penalty * alpha,
                        )
                    ).mean()
                    loss_kp_stack = loss_kp if loss_kp_stack is None else torch.hstack((loss_kp_stack, loss_kp))

                sum_num_keypoints_1 += mask_selection_for_matching_1.sum()
                sum_num_keypoints_2 += mask_selection_for_matching_2.sum()

            loss = loss_policy_stack.mean()
            if loss_kp_stack is not None:
                loss += loss_kp_stack.mean()

            loss = -loss

            if descriptor_loss is not None:
                loss += cfg.desc_loss_weight * loss_desc_stack.mean()

            pbar.set_description(
                f"LP: {loss.item():.4f} - Det: ({sum_num_keypoints_1 / batch_size:.4f}, {sum_num_keypoints_2 / batch_size:.4f}), #mRwd: {sum_reward_batch / batch_size:.1f}"
            )
            pbar.update()

            # backward pass
            loss /= num_grad_accs
            fabric.backward(loss)

            if i_step % num_grad_accs == 0:
                opt_pi.step()
                opt_pi.zero_grad()

            if i_step % cfg.log_interval == 0:
                wandb_logger.log(
                    {
                        "loss": loss.item(),
                        "loss_policy": -loss_policy_stack.mean().item(),
                        "loss_kp": loss_kp_stack.mean().item() if loss_kp_stack is not None else 0.0,
                        "loss_hard": (loss_desc_stack.mean().item() if loss_desc_stack is not None else 0.0),
                        "mean_num_det_kpts1": sum_num_keypoints_1 / batch_size,
                        "mean_num_det_kpts2": sum_num_keypoints_2 / batch_size,
                        "mean_reward": sum_reward_batch / batch_size,
                        "lr": opt_pi.param_groups[0]["lr"],
                        "ma_skipped_batches": sum(ma_skipped_batches) / len(ma_skipped_batches),
                        "inl_th": inl_th,
                    },
                    step=i_step,
                )

            if i_step % cfg.val_interval == 0:
                val_benchmark.evaluate(net, fabric.device, progress_bar=False)
                val_benchmark.log_results(logger=wandb_logger, step=i_step)

                # ensure that the model is in training mode again
                net.train()

    # save the model
    torch.save(
        net.state_dict(),
        output_dir / ("model" + "_" + str(i_step + 1) + "_final" + ".pth"),
    )


if __name__ == "__main__":
    train()

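The core of the objective in this script is a REINFORCE-style policy term: a dense reward matrix over keypoint pairs is multiplied with the summed log-probabilities of the two keypoints forming each pair, and the mean is negated so that gradient descent maximizes the expected reward. A toy sketch of just that term, with made-up shapes and reward values, purely for illustration:

import torch

logprobs1 = torch.log(torch.rand(5, requires_grad=True))  # per-keypoint log-probs, image 1
logprobs2 = torch.log(torch.rand(7, requires_grad=True))  # per-keypoint log-probs, image 2

dense_rewards = torch.zeros(5, 7)
dense_rewards[0, 3] = 1.0                   # e.g. one RANSAC-inlier pair gets a positive reward
dense_rewards[dense_rewards == 0] = -0.05   # small penalty for all other pairs

dense_logprobs = logprobs1.view(-1, 1) + logprobs2.view(1, -1)  # log-prob of sampling each pair
loss = -(dense_rewards * dense_logprobs).mean()                 # negate: maximize expected reward
loss.backward()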
ripe/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
# from x_dd.utils import loggers
from ripe.utils.pylogger import get_pylogger  # noqa: F401

ripe/utils/image_utils.py
ADDED
@@ -0,0 +1,62 @@
import h5py
import numpy as np
import torch


class Camera:
    def __init__(self, K, R, t):
        self.K = K
        self.R = R
        self.t = t

    @classmethod
    def from_calibration_file(cls, path: str):
        with h5py.File(path, "r") as f:
            K = torch.tensor(np.array(f["K"]), dtype=torch.float32)
            R = torch.tensor(np.array(f["R"]), dtype=torch.float32)
            T = torch.tensor(np.array(f["T"]), dtype=torch.float32)

        return cls(K, R, T)

    @property
    def K_inv(self):
        return self.K.inverse()

    def to_cameradict(self):
        fx = self.K[0, 0].item()
        fy = self.K[1, 1].item()
        cx = self.K[0, 2].item()
        cy = self.K[1, 2].item()

        params = {
            "model": "PINHOLE",
            "width": int(cx * 2),
            "height": int(cy * 2),
            "params": [fx, fy, cx, cy],
        }

        return params

    def __repr__(self):
        return f"ImageData(K={self.K}, R={self.R}, t={self.t})"


def cameras2F(cam1: Camera, cam2: Camera) -> torch.Tensor:
    E = cameras2E(cam1, cam2)
    return cam2.K_inv.T @ E @ cam1.K_inv


def cameras2E(cam1: Camera, cam2: Camera) -> torch.Tensor:
    R = cam2.R @ cam1.R.T
    T = cam2.t - R @ cam1.t
    return cross_product_matrix(T) @ R


def cross_product_matrix(v) -> torch.Tensor:
    """Following en.wikipedia.org/wiki/Cross_product#Conversion_to_matrix_multiplication."""

    return torch.tensor(
        [[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]],
        dtype=v.dtype,
        device=v.device,
    )

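A small sketch of composing a fundamental matrix from two synthetic cameras with the helpers above; the identity intrinsics and the relative pose are arbitrary choices for the example:

import torch
from ripe.utils.image_utils import Camera, cameras2F

K = torch.eye(3)
cam1 = Camera(K, torch.eye(3), torch.zeros(3))
cam2 = Camera(K, torch.eye(3), torch.tensor([1.0, 0.0, 0.0]))  # pure translation along x

F = cameras2F(cam1, cam2)
print(F)  # with identity intrinsics this equals the essential matrix [t]_x R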
ripe/utils/pose_error.py
ADDED
@@ -0,0 +1,62 @@
# mostly from: https://github.com/cvg/glue-factory/blob/main/gluefactory/geometry/epipolar.py

import numpy as np
import torch


def angle_error_mat(R1, R2):
    cos = (torch.trace(torch.einsum("...ij, ...jk -> ...ik", R1.T, R2)) - 1) / 2
    cos = torch.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
    return torch.rad2deg(torch.abs(torch.arccos(cos)))


def angle_error_vec(v1, v2, eps=1e-10):
    n = torch.clip(v1.norm(dim=-1) * v2.norm(dim=-1), min=eps)
    v1v2 = (v1 * v2).sum(dim=-1)  # dot product in the last dimension
    return torch.rad2deg(torch.arccos(torch.clip(v1v2 / n, -1.0, 1.0)))


def relative_pose_error(R_gt, t_gt, R, t, ignore_gt_t_thr=0.0, eps=1e-10):
    # angle error between 2 vectors
    t_err = angle_error_vec(t, t_gt, eps)
    t_err = torch.minimum(t_err, 180 - t_err)  # handle E ambiguity
    if t_gt.norm() < ignore_gt_t_thr:  # pure rotation is challenging
        t_err = torch.zeros_like(t_err)

    # angle error between 2 rotation matrices
    r_err = angle_error_mat(R, R_gt)

    return t_err, r_err


def cal_error_auc(errors, thresholds):
    sort_idx = np.argsort(errors)
    errors = np.array(errors.copy())[sort_idx]
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.r_[0.0, errors]
    recall = np.r_[0.0, recall]
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(np.round((np.trapz(r, x=e) / t), 4))
    return aucs


class AUCMetric:
    def __init__(self, thresholds, elements=None):
        self._elements = elements
        self.thresholds = thresholds
        if not isinstance(thresholds, list):
            self.thresholds = [thresholds]

    def update(self, tensor):
        assert tensor.dim() == 1
        self._elements += tensor.cpu().numpy().tolist()

    def compute(self):
        if len(self._elements) == 0:
            return np.nan
        else:
            return cal_error_auc(self._elements, self.thresholds)

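A short sketch of the pose-AUC computation on a hand-made error list; the error values and degree thresholds below are arbitrary:

import torch
from ripe.utils.pose_error import AUCMetric, cal_error_auc

errors = [0.5, 1.5, 3.0, 12.0, 45.0]                  # per-pair pose errors in degrees
print(cal_error_auc(errors, thresholds=[5, 10, 20]))  # one AUC value per threshold

metric = AUCMetric(thresholds=[5, 10, 20], elements=[])
metric.update(torch.tensor(errors))
print(metric.compute())                               # same numbers, accumulated incrementally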
ripe/utils/pylogger.py
ADDED
@@ -0,0 +1,32 @@
import logging

# from pytorch_lightning.utilities import rank_zero_only


def init_base_pylogger():
    """Initializes base python command line logger."""

    logging.basicConfig(
        level=logging.WARNING,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler()],
    )


def get_pylogger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python command line logger."""

    if not logging.root.handlers:
        init_base_pylogger()

    logger = logging.getLogger(name)

    logger.setLevel(logging.DEBUG)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    # logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    # for level in logging_levels:
    #     setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger

ripe/utils/utils.py
ADDED
@@ -0,0 +1,192 @@
import random
from typing import List

import cv2
import numpy as np
import torch
from torchvision.transforms.functional import resize

from ripe import utils

log = utils.get_pylogger(__name__)


def gridify(x, window_size):
    """Turn a tensor of BxCxHxW into a tensor of
    BxCx(H//window_size)x(W//window_size)x(window_size**2)

    Params:
        x: Input tensor of shape BxCxHxW
        window_size: Size of the window

    Returns:
        x: Output tensor of shape BxCx(H//window_size)x(W//window_size)x(window_size**2)
    """

    assert x.dim() == 4, "Input tensor x must have 4 dimensions"

    B, C, H, W = x.shape
    x = (
        x.unfold(2, window_size, window_size)
        .unfold(3, window_size, window_size)
        .reshape(B, C, H // window_size, W // window_size, window_size**2)
    )

    return x


def get_grid(B, H, W, device):
    x1_n = torch.meshgrid(*[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (B, H, W)])
    x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
    return x1_n


def cv2_matches_from_kornia(match_dists: torch.Tensor, match_idxs: torch.Tensor) -> List[cv2.DMatch]:
    return [cv2.DMatch(idx[0].item(), idx[1].item(), d.item()) for idx, d in zip(match_idxs, match_dists)]


def to_cv_kpts(kpts, scores):
    kp = kpts.cpu().numpy().astype(np.int16)
    s = scores.cpu().numpy()

    cv_kp = [cv2.KeyPoint(kp[i][0], kp[i][1], 6, 0, s[i]) for i in range(len(kp))]

    return cv_kp


def resize_image(image, min_size=512, max_size=768):
    """Resize image to a new size while maintaining the aspect ratio.

    Params:
        image (torch.tensor): Image to be resized.
        min_size (int): Minimum size of the smaller dimension.
        max_size (int): Maximum size of the larger dimension.

    Returns:
        image: Resized image.
    """

    h, w = image.shape[-2:]

    aspect_ratio = w / h

    if w > h:
        new_w = max(min_size, min(max_size, w))
        new_h = int(new_w / aspect_ratio)
    else:
        new_h = max(min_size, min(max_size, h))
        new_w = int(new_h * aspect_ratio)

    new_size = (new_h, new_w)

    image = resize(image, new_size)

    return image


def get_rewards(
    reward,
    kps1,
    kps2,
    selected_mask1,
    selected_mask2,
    padding_mask1,
    padding_mask2,
    rel_idx_matches,
    abs_idx_matches,
    ransac_inliers,
    label,
    penalty=0.0,
    use_whitening=False,
    selected_only=False,
    filter_mode=None,
):
    with torch.no_grad():
        reward *= 1.0 if label else -1.0

        dense_returns = torch.zeros((len(kps1), len(kps2)), device=kps1.device)

        if filter_mode == "ignore":
            dense_returns[
                abs_idx_matches[:, 0][ransac_inliers],
                abs_idx_matches[:, 1][ransac_inliers],
            ] = reward
        elif filter_mode == "punish":
            in_padding_area = (
                padding_mask1[abs_idx_matches[:, 0]] & padding_mask2[abs_idx_matches[:, 1]]
            )  # both in the image area (not in padding area)

            dense_returns[
                abs_idx_matches[:, 0][ransac_inliers & in_padding_area],
                abs_idx_matches[:, 1][ransac_inliers & in_padding_area],
            ] = reward
            dense_returns[
                abs_idx_matches[:, 0][ransac_inliers & ~in_padding_area],
                abs_idx_matches[:, 1][ransac_inliers & ~in_padding_area],
            ] = -1.0
        else:
            raise ValueError(f"Unknown filter mode: {filter_mode}")

        if selected_only:
            dense_returns = dense_returns[selected_mask1, :][:, selected_mask2]
        if filter_mode == "ignore" and not selected_only:
            dense_returns = dense_returns[padding_mask1, :][:, padding_mask2]

        if penalty != 0.0:
            # pos. pair: small penalty for not finding a match
            # neg. pair: small reward for not finding a match
            penalty_val = penalty if label else -penalty

            dense_returns[dense_returns == 0.0] = penalty_val

        if use_whitening:
            dense_returns = (dense_returns - dense_returns.mean()) / (dense_returns.std() + 1e-6)

        return dense_returns


def get_other_random_id(idx: int, len_dataset: int, min_dist: int = 20):
    for _ in range(10):
        tgt_id = random.randint(0, len_dataset - 1)
        if abs(idx - tgt_id) >= min_dist:
            return tgt_id

    raise ValueError(f"Could not find target image with distance >= {min_dist} from source image {idx}")


def cv_resize_and_pad_to_shape(image, new_shape, padding_color=(0, 0, 0)):
    """Resizes image to new_shape while maintaining the aspect ratio and pads with padding_color if
    needed.

    Params:
        image: Image to be resized.
        new_shape: Expected (height, width) of new image.
        padding_color: Tuple in BGR of padding color
    Returns:
        image: Resized image with padding
    """
    h, w = image.shape[:2]

    scale_h = new_shape[0] / h
    scale_w = new_shape[1] / w

    scale = None
    if scale_w * h > new_shape[0]:
        scale = scale_h
    elif scale_h * w > new_shape[1]:
        scale = scale_w
    else:
        scale = max(scale_h, scale_w)

    new_w, new_h = int(round(w * scale)), int(round(h * scale))

    image = cv2.resize(image, (new_w, new_h))

    missing_h = new_shape[0] - new_h
    missing_w = new_shape[1] - new_w

    top, bottom = missing_h // 2, missing_h - (missing_h // 2)
    left, right = missing_w // 2, missing_w - (missing_w // 2)

    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padding_color)
    return image

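For instance, gridify splits a feature map into non-overlapping windows whose cells can then be scored jointly; the shapes here are arbitrary:

import torch
from ripe.utils.utils import gridify

x = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
cells = gridify(x, window_size=4)
print(cells.shape)  # torch.Size([2, 3, 2, 2, 16]) -> 2x2 windows of 4*4 = 16 values each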
ripe/utils/wandb_utils.py
ADDED
@@ -0,0 +1,16 @@
import omegaconf


def get_flattened_wandb_cfg(conf_dict):
    flattened = {}

    def _flatten(cfg, prefix=""):
        for k, v in cfg.items():
            new_key = f"{prefix}.{k}" if prefix else k
            if isinstance(v, omegaconf.dictconfig.DictConfig):
                _flatten(v, new_key)
            else:
                flattened[new_key] = v

    _flatten(conf_dict)
    return flattened

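A short sketch of the flattening behaviour with an OmegaConf config; the keys and values are invented:

from omegaconf import OmegaConf
from ripe.utils.wandb_utils import get_flattened_wandb_cfg

cfg = OmegaConf.create({"optimizer": {"lr": 1e-4, "betas": [0.9, 0.999]}, "num_steps": 80000})
print(get_flattened_wandb_cfg(cfg))
# {'optimizer.lr': 0.0001, 'optimizer.betas': [0.9, 0.999], 'num_steps': 80000}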