diff --git a/app.py b/app.py index fa9289672a72747a831206e782646df4e8fa37f9..96c6dac4ed89a6bd1783c81b6c145fb4315aa4c1 100644 --- a/app.py +++ b/app.py @@ -9,9 +9,10 @@ from extra_utils.utils import ( match_features, get_model, get_feature_model, - display_matches + display_matches, ) + def run_matching( match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1 ): @@ -277,7 +278,7 @@ def run(config): matcher_info, ] button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs) - + app.launch(share=False) diff --git a/third_party/ALIKE/alike.py b/third_party/ALIKE/alike.py index 303616d52581efce0ae0eb86af70f5ea8984909d..b975f806f3e0f593a3564ae52d9d08187f514b34 100644 --- a/third_party/ALIKE/alike.py +++ b/third_party/ALIKE/alike.py @@ -12,46 +12,89 @@ from soft_detect import DKD import time configs = { - 'alike-t': {'c1': 8, 'c2': 16, 'c3': 32, 'c4': 64, 'dim': 64, 'single_head': True, 'radius': 2, - 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-t.pth')}, - 'alike-s': {'c1': 8, 'c2': 16, 'c3': 48, 'c4': 96, 'dim': 96, 'single_head': True, 'radius': 2, - 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-s.pth')}, - 'alike-n': {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True, 'radius': 2, - 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-n.pth')}, - 'alike-l': {'c1': 32, 'c2': 64, 'c3': 128, 'c4': 128, 'dim': 128, 'single_head': False, 'radius': 2, - 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-l.pth')}, + "alike-t": { + "c1": 8, + "c2": 16, + "c3": 32, + "c4": 64, + "dim": 64, + "single_head": True, + "radius": 2, + "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-t.pth"), + }, + "alike-s": { + "c1": 8, + "c2": 16, + "c3": 48, + "c4": 96, + "dim": 96, + "single_head": True, + "radius": 2, + "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-s.pth"), + }, + "alike-n": { + "c1": 16, + "c2": 32, + "c3": 64, + "c4": 128, + "dim": 128, + "single_head": True, + "radius": 2, + "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-n.pth"), + }, + "alike-l": { + "c1": 32, + "c2": 64, + "c3": 128, + "c4": 128, + "dim": 128, + "single_head": False, + "radius": 2, + "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-l.pth"), + }, } class ALike(ALNet): - def __init__(self, - # ================================== feature encoder - c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128, - single_head: bool = False, - # ================================== detect parameters - radius: int = 2, - top_k: int = 500, scores_th: float = 0.5, - n_limit: int = 5000, - device: str = 'cpu', - model_path: str = '' - ): + def __init__( + self, + # ================================== feature encoder + c1: int = 32, + c2: int = 64, + c3: int = 128, + c4: int = 128, + dim: int = 128, + single_head: bool = False, + # ================================== detect parameters + radius: int = 2, + top_k: int = 500, + scores_th: float = 0.5, + n_limit: int = 5000, + device: str = "cpu", + model_path: str = "", + ): super().__init__(c1, c2, c3, c4, dim, single_head) self.radius = radius self.top_k = top_k self.n_limit = n_limit self.scores_th = scores_th - self.dkd = DKD(radius=self.radius, top_k=self.top_k, - scores_th=self.scores_th, n_limit=self.n_limit) + self.dkd = DKD( + radius=self.radius, + top_k=self.top_k, + scores_th=self.scores_th, + n_limit=self.n_limit, + ) 
self.device = device - if model_path != '': + if model_path != "": state_dict = torch.load(model_path, self.device) self.load_state_dict(state_dict) self.to(self.device) self.eval() - logging.info(f'Loaded model parameters from {model_path}') + logging.info(f"Loaded model parameters from {model_path}") logging.info( - f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB") + f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB" + ) def extract_dense_map(self, image, ret_dict=False): # ==================================================== @@ -81,7 +124,10 @@ class ALike(ALNet): descriptor_map = torch.nn.functional.normalize(descriptor_map, p=2, dim=1) if ret_dict: - return {'descriptor_map': descriptor_map, 'scores_map': scores_map, } + return { + "descriptor_map": descriptor_map, + "scores_map": scores_map, + } else: return descriptor_map, scores_map @@ -104,15 +150,22 @@ class ALike(ALNet): image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio) # ==================== convert image to tensor - image = torch.from_numpy(image).to(self.device).to(torch.float32).permute(2, 0, 1)[None] / 255.0 + image = ( + torch.from_numpy(image) + .to(self.device) + .to(torch.float32) + .permute(2, 0, 1)[None] + / 255.0 + ) # ==================== extract keypoints start = time.time() with torch.no_grad(): descriptor_map, scores_map = self.extract_dense_map(image) - keypoints, descriptors, scores, _ = self.dkd(scores_map, descriptor_map, - sub_pixel=sub_pixel) + keypoints, descriptors, scores, _ = self.dkd( + scores_map, descriptor_map, sub_pixel=sub_pixel + ) keypoints, descriptors, scores = keypoints[0], descriptors[0], scores[0] keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W - 1, H - 1]]) @@ -124,14 +177,16 @@ class ALike(ALNet): end = time.time() - return {'keypoints': keypoints.cpu().numpy(), - 'descriptors': descriptors.cpu().numpy(), - 'scores': scores.cpu().numpy(), - 'scores_map': scores_map.cpu().numpy(), - 'time': end - start, } + return { + "keypoints": keypoints.cpu().numpy(), + "descriptors": descriptors.cpu().numpy(), + "scores": scores.cpu().numpy(), + "scores_map": scores_map.cpu().numpy(), + "time": end - start, + } -if __name__ == '__main__': +if __name__ == "__main__": import numpy as np from thop import profile @@ -139,5 +194,5 @@ if __name__ == '__main__': image = np.random.random((640, 480, 3)).astype(np.float32) flops, params = profile(net, inputs=(image, 9999, False), verbose=False) - print('{:<30} {:<8} GFLops'.format('Computational complexity: ', flops / 1e9)) - print('{:<30} {:<8} KB'.format('Number of parameters: ', params / 1e3)) + print("{:<30} {:<8} GFLops".format("Computational complexity: ", flops / 1e9)) + print("{:<30} {:<8} KB".format("Number of parameters: ", params / 1e3)) diff --git a/third_party/ALIKE/alnet.py b/third_party/ALIKE/alnet.py index 53127063233660c7b96aa15e89aa4a8a1a340dd1..91cb7ee55e502895e7b0037f2add1a35a613cd40 100644 --- a/third_party/ALIKE/alnet.py +++ b/third_party/ALIKE/alnet.py @@ -5,9 +5,13 @@ from typing import Optional, Callable class ConvBlock(nn.Module): - def __init__(self, in_channels, out_channels, - gate: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None): + def __init__( + self, + in_channels, + out_channels, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ): super().__init__() if gate is None: self.gate = 
nn.ReLU(inplace=True) @@ -31,16 +35,16 @@ class ResBlock(nn.Module): expansion: int = 1 def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - gate: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(ResBlock, self).__init__() if gate is None: @@ -50,7 +54,7 @@ class ResBlock(nn.Module): if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: - raise ValueError('ResBlock only supports groups=1 and base_width=64') + raise ValueError("ResBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in ResBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 @@ -81,9 +85,15 @@ class ResBlock(nn.Module): class ALNet(nn.Module): - def __init__(self, c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128, - single_head: bool = True, - ): + def __init__( + self, + c1: int = 32, + c2: int = 64, + c3: int = 128, + c4: int = 128, + dim: int = 128, + single_head: bool = True, + ): super().__init__() self.gate = nn.ReLU(inplace=True) @@ -93,28 +103,48 @@ class ALNet(nn.Module): self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d) - self.block2 = ResBlock(inplanes=c1, planes=c2, stride=1, - downsample=nn.Conv2d(c1, c2, 1), - gate=self.gate, - norm_layer=nn.BatchNorm2d) - self.block3 = ResBlock(inplanes=c2, planes=c3, stride=1, - downsample=nn.Conv2d(c2, c3, 1), - gate=self.gate, - norm_layer=nn.BatchNorm2d) - self.block4 = ResBlock(inplanes=c3, planes=c4, stride=1, - downsample=nn.Conv2d(c3, c4, 1), - gate=self.gate, - norm_layer=nn.BatchNorm2d) + self.block2 = ResBlock( + inplanes=c1, + planes=c2, + stride=1, + downsample=nn.Conv2d(c1, c2, 1), + gate=self.gate, + norm_layer=nn.BatchNorm2d, + ) + self.block3 = ResBlock( + inplanes=c2, + planes=c3, + stride=1, + downsample=nn.Conv2d(c2, c3, 1), + gate=self.gate, + norm_layer=nn.BatchNorm2d, + ) + self.block4 = ResBlock( + inplanes=c3, + planes=c4, + stride=1, + downsample=nn.Conv2d(c3, c4, 1), + gate=self.gate, + norm_layer=nn.BatchNorm2d, + ) # ================================== feature aggregation self.conv1 = resnet.conv1x1(c1, dim // 4) self.conv2 = resnet.conv1x1(c2, dim // 4) self.conv3 = resnet.conv1x1(c3, dim // 4) self.conv4 = resnet.conv1x1(dim, dim // 4) - self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) - self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) - self.upsample32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True) + self.upsample2 = nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=True + ) + self.upsample4 = nn.Upsample( + scale_factor=4, mode="bilinear", align_corners=True + ) + self.upsample8 = nn.Upsample( + scale_factor=8, mode="bilinear", align_corners=True + ) + self.upsample32 = nn.Upsample( + scale_factor=32, mode="bilinear", align_corners=True + ) # ================================== detector and descriptor head self.single_head = single_head @@ 
-153,12 +183,12 @@ class ALNet(nn.Module): return scores_map, descriptor_map -if __name__ == '__main__': +if __name__ == "__main__": from thop import profile net = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True) image = torch.randn(1, 3, 640, 480) flops, params = profile(net, inputs=(image,), verbose=False) - print('{:<30} {:<8} GFLops'.format('Computational complexity: ', flops / 1e9)) - print('{:<30} {:<8} KB'.format('Number of parameters: ', params / 1e3)) + print("{:<30} {:<8} GFLops".format("Computational complexity: ", flops / 1e9)) + print("{:<30} {:<8} KB".format("Number of parameters: ", params / 1e3)) diff --git a/third_party/ALIKE/demo.py b/third_party/ALIKE/demo.py index 9bfbefdd26cfeceefc75f90d1c44a7f922c624a5..a3f5130eea283404412b374c678ba3a1ae6d1c04 100644 --- a/third_party/ALIKE/demo.py +++ b/third_party/ALIKE/demo.py @@ -12,13 +12,13 @@ from alike import ALike, configs class ImageLoader(object): def __init__(self, filepath: str): self.N = 3000 - if filepath.startswith('camera'): + if filepath.startswith("camera"): camera = int(filepath[6:]) self.cap = cv2.VideoCapture(camera) if not self.cap.isOpened(): raise IOError(f"Can't open camera {camera}!") - logging.info(f'Opened camera {camera}') - self.mode = 'camera' + logging.info(f"Opened camera {camera}") + self.mode = "camera" elif os.path.exists(filepath): if os.path.isfile(filepath): self.cap = cv2.VideoCapture(filepath) @@ -27,34 +27,38 @@ class ImageLoader(object): rate = self.cap.get(cv2.CAP_PROP_FPS) self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1 duration = self.N / rate - logging.info(f'Opened video {filepath}') - logging.info(f'Frames: {self.N}, FPS: {rate}, Duration: {duration}s') - self.mode = 'video' + logging.info(f"Opened video {filepath}") + logging.info(f"Frames: {self.N}, FPS: {rate}, Duration: {duration}s") + self.mode = "video" else: - self.images = glob.glob(os.path.join(filepath, '*.png')) + \ - glob.glob(os.path.join(filepath, '*.jpg')) + \ - glob.glob(os.path.join(filepath, '*.ppm')) + self.images = ( + glob.glob(os.path.join(filepath, "*.png")) + + glob.glob(os.path.join(filepath, "*.jpg")) + + glob.glob(os.path.join(filepath, "*.ppm")) + ) self.images.sort() self.N = len(self.images) - logging.info(f'Loading {self.N} images') - self.mode = 'images' + logging.info(f"Loading {self.N} images") + self.mode = "images" else: - raise IOError('Error filepath (camerax/path of images/path of videos): ', filepath) + raise IOError( + "Error filepath (camerax/path of images/path of videos): ", filepath + ) def __getitem__(self, item): - if self.mode == 'camera' or self.mode == 'video': + if self.mode == "camera" or self.mode == "video": if item > self.N: return None ret, img = self.cap.read() if not ret: raise "Can't read image from camera" - if self.mode == 'video': + if self.mode == "video": self.cap.set(cv2.CAP_PROP_POS_FRAMES, item) - elif self.mode == 'images': + elif self.mode == "images": filename = self.images[item] img = cv2.imread(filename) if img is None: - raise Exception('Error reading image %s' % filename) + raise Exception("Error reading image %s" % filename) return img def __len__(self): @@ -99,38 +103,68 @@ class SimpleTracker(object): nn12 = np.argmax(sim, axis=1) nn21 = np.argmax(sim, axis=0) ids1 = np.arange(0, sim.shape[0]) - mask = (ids1 == nn21[nn12]) + mask = ids1 == nn21[nn12] matches = np.stack([ids1[mask], nn12[mask]]) return matches.transpose() -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='ALike Demo.') - 
parser.add_argument('input', type=str, default='', - help='Image directory or movie file or "camera0" (for webcam0).') - parser.add_argument('--model', choices=['alike-t', 'alike-s', 'alike-n', 'alike-l'], default="alike-t", - help="The model configuration") - parser.add_argument('--device', type=str, default='cuda', help="Running device (default: cuda).") - parser.add_argument('--top_k', type=int, default=-1, - help='Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)') - parser.add_argument('--scores_th', type=float, default=0.2, - help='Detector score threshold (default: 0.2).') - parser.add_argument('--n_limit', type=int, default=5000, - help='Maximum number of keypoints to be detected (default: 5000).') - parser.add_argument('--no_display', action='store_true', - help='Do not display images to screen. Useful if running remotely (default: False).') - parser.add_argument('--no_sub_pixel', action='store_true', - help='Do not detect sub-pixel keypoints (default: False).') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ALike Demo.") + parser.add_argument( + "input", + type=str, + default="", + help='Image directory or movie file or "camera0" (for webcam0).', + ) + parser.add_argument( + "--model", + choices=["alike-t", "alike-s", "alike-n", "alike-l"], + default="alike-t", + help="The model configuration", + ) + parser.add_argument( + "--device", type=str, default="cuda", help="Running device (default: cuda)." + ) + parser.add_argument( + "--top_k", + type=int, + default=-1, + help="Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)", + ) + parser.add_argument( + "--scores_th", + type=float, + default=0.2, + help="Detector score threshold (default: 0.2).", + ) + parser.add_argument( + "--n_limit", + type=int, + default=5000, + help="Maximum number of keypoints to be detected (default: 5000).", + ) + parser.add_argument( + "--no_display", + action="store_true", + help="Do not display images to screen. Useful if running remotely (default: False).", + ) + parser.add_argument( + "--no_sub_pixel", + action="store_true", + help="Do not detect sub-pixel keypoints (default: False).", + ) args = parser.parse_args() logging.basicConfig(level=logging.INFO) image_loader = ImageLoader(args.input) - model = ALike(**configs[args.model], - device=args.device, - top_k=args.top_k, - scores_th=args.scores_th, - n_limit=args.n_limit) + model = ALike( + **configs[args.model], + device=args.device, + top_k=args.top_k, + scores_th=args.scores_th, + n_limit=args.n_limit, + ) tracker = SimpleTracker() if not args.no_display: @@ -142,26 +176,26 @@ if __name__ == '__main__': for img in progress_bar: if img is None: break - + img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) pred = model(img_rgb, sub_pixel=not args.no_sub_pixel) - kpts = pred['keypoints'] - desc = pred['descriptors'] - runtime.append(pred['time']) + kpts = pred["keypoints"] + desc = pred["descriptors"] + runtime.append(pred["time"]) out, N_matches = tracker.update(img, kpts, desc) - ave_fps = (1. 
/ np.stack(runtime)).mean() + ave_fps = (1.0 / np.stack(runtime)).mean() status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}" progress_bar.set_description(status) if not args.no_display: - cv2.setWindowTitle(args.model, args.model + ': ' + status) + cv2.setWindowTitle(args.model, args.model + ": " + status) cv2.imshow(args.model, out) - if cv2.waitKey(1) == ord('q'): + if cv2.waitKey(1) == ord("q"): break - logging.info('Finished!') + logging.info("Finished!") if not args.no_display: - logging.info('Press any key to exit!') + logging.info("Press any key to exit!") cv2.waitKey() diff --git a/third_party/ALIKE/hseq/eval.py b/third_party/ALIKE/hseq/eval.py index abca625044013a0cd34a518223c32d3ec8abb8a3..1d91398740e5dee9d2968fb418fcb45febd015ba 100644 --- a/third_party/ALIKE/hseq/eval.py +++ b/third_party/ALIKE/hseq/eval.py @@ -6,29 +6,53 @@ import numpy as np from extract import extract_method use_cuda = torch.cuda.is_available() -device = torch.device('cuda' if use_cuda else 'cpu') - -methods = ['d2', 'lfnet', 'superpoint', 'r2d2', 'aslfeat', 'disk', - 'alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms'] -names = ['D2-Net(MS)', 'LF-Net(MS)', 'SuperPoint', 'R2D2(MS)', 'ASLFeat(MS)', 'DISK', - 'ALike-N', 'ALike-L', 'ALike-N(MS)', 'ALike-L(MS)'] +device = torch.device("cuda" if use_cuda else "cpu") + +methods = [ + "d2", + "lfnet", + "superpoint", + "r2d2", + "aslfeat", + "disk", + "alike-n", + "alike-l", + "alike-n-ms", + "alike-l-ms", +] +names = [ + "D2-Net(MS)", + "LF-Net(MS)", + "SuperPoint", + "R2D2(MS)", + "ASLFeat(MS)", + "DISK", + "ALike-N", + "ALike-L", + "ALike-N(MS)", + "ALike-L(MS)", +] top_k = None n_i = 52 n_v = 56 -cache_dir = 'hseq/cache' -dataset_path = 'hseq/hpatches-sequences-release' +cache_dir = "hseq/cache" +dataset_path = "hseq/hpatches-sequences-release" -def generate_read_function(method, extension='ppm'): +def generate_read_function(method, extension="ppm"): def read_function(seq_name, im_idx): - aux = np.load(os.path.join(dataset_path, seq_name, '%d.%s.%s' % (im_idx, extension, method))) + aux = np.load( + os.path.join( + dataset_path, seq_name, "%d.%s.%s" % (im_idx, extension, method) + ) + ) if top_k is None: - return aux['keypoints'], aux['descriptors'] + return aux["keypoints"], aux["descriptors"] else: - assert ('scores' in aux) - ids = np.argsort(aux['scores'])[-top_k:] - return aux['keypoints'][ids, :], aux['descriptors'][ids, :] + assert "scores" in aux + ids = np.argsort(aux["scores"])[-top_k:] + return aux["keypoints"][ids, :], aux["descriptors"][ids, :] return read_function @@ -39,7 +63,7 @@ def mnn_matcher(descriptors_a, descriptors_b): nn12 = torch.max(sim, dim=1)[1] nn21 = torch.max(sim, dim=0)[1] ids1 = torch.arange(0, sim.shape[0], device=device) - mask = (ids1 == nn21[nn12]) + mask = ids1 == nn21[nn12] matches = torch.stack([ids1[mask], nn12[mask]]) return matches.t().data.cpu().numpy() @@ -73,7 +97,7 @@ def benchmark_features(read_feats): n_feats.append(keypoints_a.shape[0]) # =========== compute homography - ref_img = cv2.imread(os.path.join(dataset_path, seq_name, '1.ppm')) + ref_img = cv2.imread(os.path.join(dataset_path, seq_name, "1.ppm")) ref_img_shape = ref_img.shape for im_idx in range(2, 7): @@ -82,17 +106,19 @@ def benchmark_features(read_feats): matches = mnn_matcher( torch.from_numpy(descriptors_a).to(device=device), - torch.from_numpy(descriptors_b).to(device=device) + torch.from_numpy(descriptors_b).to(device=device), ) - homography = np.loadtxt(os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx))) + homography 
= np.loadtxt( + os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx)) + ) - pos_a = keypoints_a[matches[:, 0], : 2] + pos_a = keypoints_a[matches[:, 0], :2] pos_a_h = np.concatenate([pos_a, np.ones([matches.shape[0], 1])], axis=1) pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h))) - pos_b_proj = pos_b_proj_h[:, : 2] / pos_b_proj_h[:, 2:] + pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:] - pos_b = keypoints_b[matches[:, 1], : 2] + pos_b = keypoints_b[matches[:, 1], :2] dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1)) @@ -103,28 +129,37 @@ def benchmark_features(read_feats): dist = np.array([float("inf")]) for thr in rng: - if seq_name[0] == 'i': + if seq_name[0] == "i": i_err[thr] += np.mean(dist <= thr) else: v_err[thr] += np.mean(dist <= thr) # =========== compute homography gt_homo = homography - pred_homo, _ = cv2.findHomography(keypoints_a[matches[:, 0], : 2], keypoints_b[matches[:, 1], : 2], - cv2.RANSAC) + pred_homo, _ = cv2.findHomography( + keypoints_a[matches[:, 0], :2], + keypoints_b[matches[:, 1], :2], + cv2.RANSAC, + ) if pred_homo is None: homo_dist = np.array([float("inf")]) else: - corners = np.array([[0, 0], - [ref_img_shape[1] - 1, 0], - [0, ref_img_shape[0] - 1], - [ref_img_shape[1] - 1, ref_img_shape[0] - 1]]) + corners = np.array( + [ + [0, 0], + [ref_img_shape[1] - 1, 0], + [0, ref_img_shape[0] - 1], + [ref_img_shape[1] - 1, ref_img_shape[0] - 1], + ] + ) real_warped_corners = homo_trans(corners, gt_homo) warped_corners = homo_trans(corners, pred_homo) - homo_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1)) + homo_dist = np.mean( + np.linalg.norm(real_warped_corners - warped_corners, axis=1) + ) for thr in rng: - if seq_name[0] == 'i': + if seq_name[0] == "i": i_err_homo[thr] += np.mean(homo_dist <= thr) else: v_err_homo[thr] += np.mean(homo_dist <= thr) @@ -136,10 +171,10 @@ def benchmark_features(read_feats): return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches] -if __name__ == '__main__': +if __name__ == "__main__": errors = {} for method in methods: - output_file = os.path.join(cache_dir, method + '.npy') + output_file = os.path.join(cache_dir, method + ".npy") read_function = generate_read_function(method) if os.path.exists(output_file): errors[method] = np.load(output_file, allow_pickle=True) @@ -152,11 +187,11 @@ if __name__ == '__main__': i_err, v_err, i_err_hom, v_err_hom, _ = errors[method] print(f"====={name}=====") - print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end='') + print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end="") for thr in range(1, 4): err = (i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5) - print(f"{err * 100:.2f}%", end=' ') + print(f"{err * 100:.2f}%", end=" ") for thr in range(1, 4): err_hom = (i_err_hom[thr] + v_err_hom[thr]) / ((n_i + n_v) * 5) - print(f"{err_hom * 100:.2f}%", end=' ') - print('') + print(f"{err_hom * 100:.2f}%", end=" ") + print("") diff --git a/third_party/ALIKE/hseq/extract.py b/third_party/ALIKE/hseq/extract.py index 1342e40dd2d0e1d1986e90f995c95b17972ec4e1..df16ae246bf360b529f0640cab5ae79f495e4f61 100644 --- a/third_party/ALIKE/hseq/extract.py +++ b/third_party/ALIKE/hseq/extract.py @@ -9,23 +9,23 @@ from tqdm import tqdm from copy import deepcopy from torchvision.transforms import ToTensor -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from alike import ALike, configs -dataset_root = 'hseq/hpatches-sequences-release' +dataset_root = 
"hseq/hpatches-sequences-release" use_cuda = torch.cuda.is_available() -device = 'cuda' if use_cuda else 'cpu' -methods = ['alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms'] +device = "cuda" if use_cuda else "cpu" +methods = ["alike-n", "alike-l", "alike-n-ms", "alike-l-ms"] class HPatchesDataset(data.Dataset): - def __init__(self, root: str = dataset_root, alteration: str = 'all'): + def __init__(self, root: str = dataset_root, alteration: str = "all"): """ Args: root: dataset root path alteration: # 'all', 'i' for illumination or 'v' for viewpoint """ - assert (Path(root).exists()), f"Dataset root path {root} dose not exist!" + assert Path(root).exists(), f"Dataset root path {root} dose not exist!" self.root = root # get all image file name @@ -35,15 +35,15 @@ class HPatchesDataset(data.Dataset): folders = [x for x in Path(self.root).iterdir() if x.is_dir()] self.seqs = [] for folder in folders: - if alteration == 'i' and folder.stem[0] != 'i': + if alteration == "i" and folder.stem[0] != "i": continue - if alteration == 'v' and folder.stem[0] != 'v': + if alteration == "v" and folder.stem[0] != "v": continue self.seqs.append(folder) self.len = len(self.seqs) - assert (self.len > 0), f'Can not find PatchDataset in path {self.root}' + assert self.len > 0, f"Can not find PatchDataset in path {self.root}" def __getitem__(self, item): folder = self.seqs[item] @@ -51,12 +51,12 @@ class HPatchesDataset(data.Dataset): imgs = [] homos = [] for i in range(1, 7): - img = cv2.imread(str(folder / f'{i}.ppm'), cv2.IMREAD_COLOR) + img = cv2.imread(str(folder / f"{i}.ppm"), cv2.IMREAD_COLOR) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # HxWxC imgs.append(img) if i != 1: - homo = np.loadtxt(str(folder / f'H_1_{i}')).astype('float32') + homo = np.loadtxt(str(folder / f"H_1_{i}")).astype("float32") homos.append(homo) return imgs, homos, folder.stem @@ -68,11 +68,18 @@ class HPatchesDataset(data.Dataset): return self.__class__ -def extract_multiscale(model, img, scale_f=2 ** 0.5, - min_scale=1., max_scale=1., - min_size=0., max_size=99999., - image_size_max=99999, - n_k=0, sort=False): +def extract_multiscale( + model, + img, + scale_f=2**0.5, + min_scale=1.0, + max_scale=1.0, + min_size=0.0, + max_size=99999.0, + image_size_max=99999, + n_k=0, + sort=False, +): H_, W_, three = img.shape assert three == 3, "input image shape should be [HxWx3]" @@ -100,7 +107,9 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5, # extract descriptors with torch.no_grad(): descriptor_map, scores_map = model.extract_dense_map(image) - keypoints_, descriptors_, scores_, _ = model.dkd(scores_map, descriptor_map) + keypoints_, descriptors_, scores_, _ = model.dkd( + scores_map, descriptor_map + ) keypoints.append(keypoints_[0]) descriptors.append(descriptors_[0]) @@ -110,7 +119,9 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5, # down-scale the image for next iteration nh, nw = round(H * s), round(W * s) - image = torch.nn.functional.interpolate(image, (nh, nw), mode='bilinear', align_corners=False) + image = torch.nn.functional.interpolate( + image, (nh, nw), mode="bilinear", align_corners=False + ) # restore value torch.backends.cudnn.benchmark = old_bm @@ -131,29 +142,34 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5, descriptors = descriptors[0:n_k] scores = scores[0:n_k] - return {'keypoints': keypoints, 'descriptors': descriptors, 'scores': scores} + return {"keypoints": keypoints, "descriptors": descriptors, "scores": scores} def extract_method(m): - hpatches = HPatchesDataset(root=dataset_root, 
alteration='all') + hpatches = HPatchesDataset(root=dataset_root, alteration="all") model = m[:7] - min_scale = 0.3 if m[8:] == 'ms' else 1.0 + min_scale = 0.3 if m[8:] == "ms" else 1.0 model = ALike(**configs[model], device=device, top_k=0, scores_th=0.2, n_limit=5000) - progbar = tqdm(hpatches, desc='Extracting for {}'.format(m)) + progbar = tqdm(hpatches, desc="Extracting for {}".format(m)) for imgs, homos, seq_name in progbar: for i in range(1, 7): img = imgs[i - 1] - pred = extract_multiscale(model, img, min_scale=min_scale, max_scale=1, sort=False, n_k=5000) - kpts, descs, scores = pred['keypoints'], pred['descriptors'], pred['scores'] + pred = extract_multiscale( + model, img, min_scale=min_scale, max_scale=1, sort=False, n_k=5000 + ) + kpts, descs, scores = pred["keypoints"], pred["descriptors"], pred["scores"] - with open(os.path.join(dataset_root, seq_name, f'{i}.ppm.{m}'), 'wb') as f: - np.savez(f, keypoints=kpts.cpu().numpy(), - scores=scores.cpu().numpy(), - descriptors=descs.cpu().numpy()) + with open(os.path.join(dataset_root, seq_name, f"{i}.ppm.{m}"), "wb") as f: + np.savez( + f, + keypoints=kpts.cpu().numpy(), + scores=scores.cpu().numpy(), + descriptors=descs.cpu().numpy(), + ) -if __name__ == '__main__': +if __name__ == "__main__": for method in methods: extract_method(method) diff --git a/third_party/ALIKE/soft_detect.py b/third_party/ALIKE/soft_detect.py index 2d23cd13b8a7db9b0398fdc1b235564222d30c90..636ba11d0584c513631fffce31ba2d71be3e6c74 100644 --- a/third_party/ALIKE/soft_detect.py +++ b/third_party/ALIKE/soft_detect.py @@ -17,13 +17,15 @@ import torch.nn.functional as F # v # [ y: range=-1.0~1.0; h: range=0~H ] + def simple_nms(scores, nms_radius: int): - """ Fast Non-maximum suppression to remove nearby points """ - assert (nms_radius >= 0) + """Fast Non-maximum suppression to remove nearby points""" + assert nms_radius >= 0 def max_pool(x): return torch.nn.functional.max_pool2d( - x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) zeros = torch.zeros_like(scores) max_mask = scores == max_pool(scores) @@ -50,8 +52,14 @@ def sample_descriptor(descriptor_map, kpts, bilinear_interp=False): kptsi = kpts[index] # Nx2,(x,y) if bilinear_interp: - descriptors_ = torch.nn.functional.grid_sample(descriptor_map[index].unsqueeze(0), kptsi.view(1, 1, -1, 2), - mode='bilinear', align_corners=True)[0, :, 0, :] # CxN + descriptors_ = torch.nn.functional.grid_sample( + descriptor_map[index].unsqueeze(0), + kptsi.view(1, 1, -1, 2), + mode="bilinear", + align_corners=True, + )[ + 0, :, 0, : + ] # CxN else: kptsi = (kptsi + 1) / 2 * kptsi.new_tensor([[width - 1, height - 1]]) kptsi = kptsi.long() @@ -94,10 +102,10 @@ class DKD(nn.Module): nms_scores = simple_nms(scores_nograd, 2) # remove border - nms_scores[:, :, :self.radius + 1, :] = 0 - nms_scores[:, :, :, :self.radius + 1] = 0 - nms_scores[:, :, h - self.radius:, :] = 0 - nms_scores[:, :, :, w - self.radius:] = 0 + nms_scores[:, :, : self.radius + 1, :] = 0 + nms_scores[:, :, :, : self.radius + 1] = 0 + nms_scores[:, :, h - self.radius :, :] = 0 + nms_scores[:, :, :, w - self.radius :] = 0 # detect keypoints without grad if self.top_k > 0: @@ -121,7 +129,7 @@ class DKD(nn.Module): if len(indices) > self.n_limit: kpts_sc = scores[indices] sort_idx = kpts_sc.sort(descending=True)[1] - sel_idx = sort_idx[:self.n_limit] + sel_idx = sort_idx[: self.n_limit] indices = indices[sel_idx] indices_keypoints.append(indices) @@ -134,42 +142,73 @@ class 
DKD(nn.Module): self.hw_grid = self.hw_grid.to(patches) # to device for b_idx in range(b): patch = patches[b_idx].t() # (H*W) x (kernel**2) - indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M + indices_kpt = indices_keypoints[ + b_idx + ] # one dimension vector, say its size is M patch_scores = patch[indices_kpt] # M x (kernel**2) # max is detached to prevent undesired backprop loops in the graph max_v = patch_scores.max(dim=1).values.detach()[:, None] - x_exp = ((patch_scores - max_v) / self.temperature).exp() # M * (kernel**2), in [0, 1] + x_exp = ( + (patch_scores - max_v) / self.temperature + ).exp() # M * (kernel**2), in [0, 1] # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} } - xy_residual = x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] # Soft-argmax, Mx2 - - hw_grid_dist2 = torch.norm((self.hw_grid[None, :, :] - xy_residual[:, None, :]) / self.radius, - dim=-1) ** 2 + xy_residual = ( + x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] + ) # Soft-argmax, Mx2 + + hw_grid_dist2 = ( + torch.norm( + (self.hw_grid[None, :, :] - xy_residual[:, None, :]) + / self.radius, + dim=-1, + ) + ** 2 + ) scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1) # compute result keypoints - keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2 + keypoints_xy_nms = torch.stack( + [indices_kpt % w, indices_kpt // w], dim=1 + ) # Mx2 keypoints_xy = keypoints_xy_nms + xy_residual - keypoints_xy = keypoints_xy / keypoints_xy.new_tensor( - [w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1) - - kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0), - keypoints_xy.view(1, 1, -1, 2), - mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN + keypoints_xy = ( + keypoints_xy / keypoints_xy.new_tensor([w - 1, h - 1]) * 2 - 1 + ) # (w,h) -> (-1~1,-1~1) + + kptscore = torch.nn.functional.grid_sample( + scores_map[b_idx].unsqueeze(0), + keypoints_xy.view(1, 1, -1, 2), + mode="bilinear", + align_corners=True, + )[ + 0, 0, 0, : + ] # CxN keypoints.append(keypoints_xy) scoredispersitys.append(scoredispersity) kptscores.append(kptscore) else: for b_idx in range(b): - indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M - keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2 - keypoints_xy = keypoints_xy_nms / keypoints_xy_nms.new_tensor( - [w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1) - kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0), - keypoints_xy.view(1, 1, -1, 2), - mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN + indices_kpt = indices_keypoints[ + b_idx + ] # one dimension vector, say its size is M + keypoints_xy_nms = torch.stack( + [indices_kpt % w, indices_kpt // w], dim=1 + ) # Mx2 + keypoints_xy = ( + keypoints_xy_nms / keypoints_xy_nms.new_tensor([w - 1, h - 1]) * 2 + - 1 + ) # (w,h) -> (-1~1,-1~1) + kptscore = torch.nn.functional.grid_sample( + scores_map[b_idx].unsqueeze(0), + keypoints_xy.view(1, 1, -1, 2), + mode="bilinear", + align_corners=True, + )[ + 0, 0, 0, : + ] # CxN keypoints.append(keypoints_xy) scoredispersitys.append(None) kptscores.append(kptscore) @@ -183,8 +222,9 @@ class DKD(nn.Module): :param sub_pixel: whether to use sub-pixel keypoint detection :return: kpts: list[Nx2,...]; kptscores: list[N,....] 
normalised position: -1.0 ~ 1.0 """ - keypoints, scoredispersitys, kptscores = self.detect_keypoints(scores_map, - sub_pixel) + keypoints, scoredispersitys, kptscores = self.detect_keypoints( + scores_map, sub_pixel + ) descriptors = sample_descriptor(descriptor_map, keypoints, sub_pixel) diff --git a/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py b/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py index fc2b44807696ec280672c8f40650fd04fa4d8a36..00ea16cd35dc4362d0d9a294ad8a1762427bc382 100644 --- a/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py +++ b/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py @@ -1,10 +1,11 @@ import sys from pathlib import Path -sys.path.append(str(Path(__file__).parent / '../../../')) + +sys.path.append(str(Path(__file__).parent / "../../../")) from src.config.default import _CN as cfg -cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax" cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0 -cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20] -cfg.ASPAN.COARSE.TRAIN_RES = [480,640] +cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20] +cfg.ASPAN.COARSE.TRAIN_RES = [480, 640] diff --git a/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py b/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py index 886d10d8f55533c8021bcca8395b5a2897fb8734..854132e8c8af3b3c9c85fa797a79a149aff545ef 100644 --- a/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py +++ b/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py @@ -1,10 +1,11 @@ import sys from pathlib import Path -sys.path.append(str(Path(__file__).parent / '../../../')) + +sys.path.append(str(Path(__file__).parent / "../../../")) from src.config.default import _CN as cfg -cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20] -cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20] +cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax" cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0 diff --git a/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py b/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py index f0b9c04cbf3f466e413b345272afe7d7fe4274ea..e2ff53d7a1943f4149c43cdb6f2547c2290651aa 100644 --- a/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py +++ b/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py @@ -1,12 +1,13 @@ import sys from pathlib import Path -sys.path.append(str(Path(__file__).parent / '../../../')) + +sys.path.append(str(Path(__file__).parent / "../../../")) from src.config.default import _CN as cfg -cfg.ASPAN.COARSE.COARSEST_LEVEL= [36,36] -cfg.ASPAN.COARSE.TRAIN_RES = [832,832] -cfg.ASPAN.COARSE.TEST_RES = [1152,1152] -cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +cfg.ASPAN.COARSE.COARSEST_LEVEL = [36, 36] +cfg.ASPAN.COARSE.TRAIN_RES = [832, 832] +cfg.ASPAN.COARSE.TEST_RES = [1152, 1152] +cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax" cfg.TRAINER.CANONICAL_LR = 8e-3 cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs diff --git a/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py b/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py index 1202080b234562d8cc65d924d7cccf0336b9f7c0..b226243478579ba2f1d4f45d8c90c02fb347d7ff 100644 --- a/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py +++ b/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py @@ -1,10 +1,11 @@ import sys from pathlib import Path -sys.path.append(str(Path(__file__).parent / '../../../')) + +sys.path.append(str(Path(__file__).parent 
/ "../../../")) from src.config.default import _CN as cfg -cfg.ASPAN.COARSE.COARSEST_LEVEL= [26,26] -cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +cfg.ASPAN.COARSE.COARSEST_LEVEL = [26, 26] +cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax" cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False cfg.TRAINER.CANONICAL_LR = 8e-3 diff --git a/third_party/ASpanFormer/configs/data/base.py b/third_party/ASpanFormer/configs/data/base.py index 03aab160fa4137ccc04380f94854a56fbb549074..2621621cd3caf2edb11b41a96b11aa6a63afba92 100644 --- a/third_party/ASpanFormer/configs/data/base.py +++ b/third_party/ASpanFormer/configs/data/base.py @@ -4,6 +4,7 @@ Setups in data configs will override all existed setups! """ from yacs.config import CfgNode as CN + _CN = CN() _CN.DATASET = CN() _CN.TRAINER = CN() diff --git a/third_party/ASpanFormer/configs/data/megadepth_test_1500.py b/third_party/ASpanFormer/configs/data/megadepth_test_1500.py index 9616432f52a693ed84f3f12b9b85470b23410eee..a8d07aafd1944188cec525043c775d268b01be1f 100644 --- a/third_party/ASpanFormer/configs/data/megadepth_test_1500.py +++ b/third_party/ASpanFormer/configs/data/megadepth_test_1500.py @@ -8,6 +8,6 @@ cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt" cfg.DATASET.MGDPT_IMG_RESIZE = 1152 -cfg.DATASET.MGDPT_IMG_PAD=True -cfg.DATASET.MGDPT_DF =8 -cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 \ No newline at end of file +cfg.DATASET.MGDPT_IMG_PAD = True +cfg.DATASET.MGDPT_DF = 8 +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 diff --git a/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py b/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py index 8f9b01fdaed254e10b3d55980499b88a00060f04..48b9bd095d64c681d0e64ee9416fb63fbd1f27b5 100644 --- a/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py +++ b/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py @@ -11,9 +11,13 @@ cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0 TEST_BASE_PATH = "data/megadepth/index" cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" -cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500" -cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" -cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val +cfg.DATASET.VAL_NPZ_ROOT = ( + cfg.DATASET.TEST_NPZ_ROOT +) = f"{TEST_BASE_PATH}/scene_info_val_1500" +cfg.DATASET.VAL_LIST_PATH = ( + cfg.DATASET.TEST_LIST_PATH +) = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val # 368 scenes in total for MegaDepth # (with difficulty balanced (further split each scene to 3 sub-scenes)) diff --git a/third_party/ASpanFormer/configs/data/scannet_trainval.py b/third_party/ASpanFormer/configs/data/scannet_trainval.py index c38d6440e2b4ec349e5f168909c7f8c367408813..a9a5b8a332e012a2891bbf7ec8842523b67e7599 100644 --- a/third_party/ASpanFormer/configs/data/scannet_trainval.py +++ b/third_party/ASpanFormer/configs/data/scannet_trainval.py @@ -12,6 +12,10 @@ TEST_BASE_PATH = "assets/scannet_test_1500" cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test" cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH -cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt" -cfg.DATASET.VAL_INTRINSIC_PATH = 
cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" -cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val +cfg.DATASET.VAL_LIST_PATH = ( + cfg.DATASET.TEST_LIST_PATH +) = f"{TEST_BASE_PATH}/scannet_test.txt" +cfg.DATASET.VAL_INTRINSIC_PATH = ( + cfg.DATASET.TEST_INTRINSIC_PATH +) = f"{TEST_BASE_PATH}/intrinsics.npz" +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val diff --git a/third_party/ASpanFormer/demo/demo.py b/third_party/ASpanFormer/demo/demo.py index f3d95b10dc3166c18ad8493be7a3d36a25d8fc3b..dceb13523faec756063b40fd586bcd81f483e274 100644 --- a/third_party/ASpanFormer/demo/demo.py +++ b/third_party/ASpanFormer/demo/demo.py @@ -1,63 +1,91 @@ import os import sys + ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT_DIR) -from src.ASpanFormer.aspanformer import ASpanFormer +from src.ASpanFormer.aspanformer import ASpanFormer from src.config.default import get_cfg_defaults from src.utils.misc import lower_config -import demo_utils +import demo_utils import cv2 import torch import numpy as np import argparse + parser = argparse.ArgumentParser() -parser.add_argument('--config_path', type=str, default='../configs/aspan/outdoor/aspan_test.py', - help='path for config file.') -parser.add_argument('--img0_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg', - help='path for image0.') -parser.add_argument('--img1_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg', - help='path for image1.') -parser.add_argument('--weights_path', type=str, default='../weights/outdoor.ckpt', - help='path for model weights.') -parser.add_argument('--long_dim0', type=int, default=1024, - help='resize for longest dim of image0.') -parser.add_argument('--long_dim1', type=int, default=1024, - help='resize for longest dim of image1.') +parser.add_argument( + "--config_path", + type=str, + default="../configs/aspan/outdoor/aspan_test.py", + help="path for config file.", +) +parser.add_argument( + "--img0_path", + type=str, + default="../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg", + help="path for image0.", +) +parser.add_argument( + "--img1_path", + type=str, + default="../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg", + help="path for image1.", +) +parser.add_argument( + "--weights_path", + type=str, + default="../weights/outdoor.ckpt", + help="path for model weights.", +) +parser.add_argument( + "--long_dim0", type=int, default=1024, help="resize for longest dim of image0." +) +parser.add_argument( + "--long_dim1", type=int, default=1024, help="resize for longest dim of image1." 
+) args = parser.parse_args() -if __name__=='__main__': +if __name__ == "__main__": config = get_cfg_defaults() config.merge_from_file(args.config_path) _config = lower_config(config) - matcher = ASpanFormer(config=_config['aspan']) - state_dict = torch.load(args.weights_path, map_location='cpu')['state_dict'] - matcher.load_state_dict(state_dict,strict=False) - matcher.cuda(),matcher.eval() - - img0,img1=cv2.imread(args.img0_path),cv2.imread(args.img1_path) - img0_g,img1_g=cv2.imread(args.img0_path,0),cv2.imread(args.img1_path,0) - img0,img1=demo_utils.resize(img0,args.long_dim0),demo_utils.resize(img1,args.long_dim1) - img0_g,img1_g=demo_utils.resize(img0_g,args.long_dim0),demo_utils.resize(img1_g,args.long_dim1) - data={'image0':torch.from_numpy(img0_g/255.)[None,None].cuda().float(), - 'image1':torch.from_numpy(img1_g/255.)[None,None].cuda().float()} - with torch.no_grad(): - matcher(data,online_resize=True) - corr0,corr1=data['mkpts0_f'].cpu().numpy(),data['mkpts1_f'].cpu().numpy() - - F_hat,mask_F=cv2.findFundamentalMat(corr0,corr1,method=cv2.FM_RANSAC,ransacReprojThreshold=1) + matcher = ASpanFormer(config=_config["aspan"]) + state_dict = torch.load(args.weights_path, map_location="cpu")["state_dict"] + matcher.load_state_dict(state_dict, strict=False) + matcher.cuda(), matcher.eval() + + img0, img1 = cv2.imread(args.img0_path), cv2.imread(args.img1_path) + img0_g, img1_g = cv2.imread(args.img0_path, 0), cv2.imread(args.img1_path, 0) + img0, img1 = demo_utils.resize(img0, args.long_dim0), demo_utils.resize( + img1, args.long_dim1 + ) + img0_g, img1_g = demo_utils.resize(img0_g, args.long_dim0), demo_utils.resize( + img1_g, args.long_dim1 + ) + data = { + "image0": torch.from_numpy(img0_g / 255.0)[None, None].cuda().float(), + "image1": torch.from_numpy(img1_g / 255.0)[None, None].cuda().float(), + } + with torch.no_grad(): + matcher(data, online_resize=True) + corr0, corr1 = data["mkpts0_f"].cpu().numpy(), data["mkpts1_f"].cpu().numpy() + + F_hat, mask_F = cv2.findFundamentalMat( + corr0, corr1, method=cv2.FM_RANSAC, ransacReprojThreshold=1 + ) if mask_F is not None: - mask_F=mask_F[:,0].astype(bool) + mask_F = mask_F[:, 0].astype(bool) else: - mask_F=np.zeros_like(corr0[:,0]).astype(bool) - - #visualize match - display=demo_utils.draw_match(img0,img1,corr0,corr1) - display_ransac=demo_utils.draw_match(img0,img1,corr0[mask_F],corr1[mask_F]) - cv2.imwrite('match.png',display) - cv2.imwrite('match_ransac.png',display_ransac) - print(len(corr1),len(corr1[mask_F])) \ No newline at end of file + mask_F = np.zeros_like(corr0[:, 0]).astype(bool) + + # visualize match + display = demo_utils.draw_match(img0, img1, corr0, corr1) + display_ransac = demo_utils.draw_match(img0, img1, corr0[mask_F], corr1[mask_F]) + cv2.imwrite("match.png", display) + cv2.imwrite("match_ransac.png", display_ransac) + print(len(corr1), len(corr1[mask_F])) diff --git a/third_party/ASpanFormer/demo/demo_utils.py b/third_party/ASpanFormer/demo/demo_utils.py index a104e25d3f5ee8b7efb6cc5fa0dc27378e22c83f..fcc8f71e02406fef4ac97fef2d0fec7c9196ad57 100644 --- a/third_party/ASpanFormer/demo/demo_utils.py +++ b/third_party/ASpanFormer/demo/demo_utils.py @@ -1,44 +1,88 @@ import cv2 import numpy as np -def resize(image,long_dim): - h,w=image.shape[0],image.shape[1] - image=cv2.resize(image,(int(w*long_dim/max(h,w)),int(h*long_dim/max(h,w)))) + +def resize(image, long_dim): + h, w = image.shape[0], image.shape[1] + image = cv2.resize( + image, (int(w * long_dim / max(h, w)), int(h * long_dim / max(h, w))) + ) return image 
-def draw_points(img,points,color=(0,255,0),radius=3): + +def draw_points(img, points, color=(0, 255, 0), radius=3): dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])] for i in range(points.shape[0]): - cv2.circle(img, dp[i],radius=radius,color=color) + cv2.circle(img, dp[i], radius=radius, color=color) return img - -def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None): + +def draw_match( + img1, + img2, + corr1, + corr2, + inlier=[True], + color=None, + radius1=1, + radius2=1, + resize=None, +): if resize is not None: - scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]] - img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA) - corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis] - corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])] - corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])] + scale1, scale2 = [img1.shape[1] / resize[0], img1.shape[0] / resize[1]], [ + img2.shape[1] / resize[0], + img2.shape[0] / resize[1], + ] + img1, img2 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA), cv2.resize( + img2, resize, interpolation=cv2.INTER_AREA + ) + corr1, corr2 = ( + corr1 / np.asarray(scale1)[np.newaxis], + corr2 / np.asarray(scale2)[np.newaxis], + ) + corr1_key = [ + cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0]) + ] + corr2_key = [ + cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0]) + ] assert len(corr1) == len(corr2) draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))] if color is None: - color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier] - if len(color)==1: - display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None, - matchColor=color[0], - singlePointColor=color[0], - flags=4 - ) + color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier] + if len(color) == 1: + display = cv2.drawMatches( + img1, + corr1_key, + img2, + corr2_key, + draw_matches, + None, + matchColor=color[0], + singlePointColor=color[0], + flags=4, + ) else: - height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1] - display=np.zeros([height,width,3],np.uint8) - display[:img1.shape[0],:img1.shape[1]]=img1 - display[:img2.shape[0],img1.shape[1]:]=img2 + height, width = max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1] + display = np.zeros([height, width, 3], np.uint8) + display[: img1.shape[0], : img1.shape[1]] = img1 + display[: img2.shape[0], img1.shape[1] :] = img2 for i in range(len(corr1)): - left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1]) - cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2])) - cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA) - return display \ No newline at end of file + left_x, left_y, right_x, right_y = ( + int(corr1[i][0]), + int(corr1[i][1]), + int(corr2[i][0] + img1.shape[1]), + int(corr2[i][1]), + ) + cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2])) + cv2.line( + display, + (left_x, left_y), + (right_x, right_y), + cur_color, + 1, + lineType=cv2.LINE_AA, + ) + return display diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py 
b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py index dff6704976cbe9e916c6de6af9e3b755dfbd20bf..0603d4088cd41dc4669ff60368fd1547000c161f 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py +++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py @@ -1,3 +1,3 @@ from .transformer import LocalFeatureTransformer_Flow -from .loftr import LocalFeatureTransformer +from .loftr import LocalFeatureTransformer from .fine_preprocess import FinePreprocess diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py index 632dd22077806d2b53f66a09d0567925a30d1523..984b0df8b6bc8783b6ade4e9dbdf39b8a5673850 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py +++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py @@ -4,39 +4,59 @@ import torch.nn as nn from itertools import product from torch.nn import functional as F + class layernorm2d(nn.Module): - - def __init__(self,dim) : - super().__init__() - self.dim=dim - self.affine=nn.parameter.Parameter(torch.ones(dim), requires_grad=True) - self.bias=nn.parameter.Parameter(torch.zeros(dim), requires_grad=True) - - def forward(self,x): - #x: B*C*H*W - mean,std=x.mean(dim=1,keepdim=True),x.std(dim=1,keepdim=True) - return self.affine[None,:,None,None]*(x-mean)/(std+1e-6)+self.bias[None,:,None,None] + def __init__(self, dim): + super().__init__() + self.dim = dim + self.affine = nn.parameter.Parameter(torch.ones(dim), requires_grad=True) + self.bias = nn.parameter.Parameter(torch.zeros(dim), requires_grad=True) + + def forward(self, x): + # x: B*C*H*W + mean, std = x.mean(dim=1, keepdim=True), x.std(dim=1, keepdim=True) + return ( + self.affine[None, :, None, None] * (x - mean) / (std + 1e-6) + + self.bias[None, :, None, None] + ) class HierachicalAttention(Module): - def __init__(self,d_model,nhead,nsample,radius_scale,nlevel=3): + def __init__(self, d_model, nhead, nsample, radius_scale, nlevel=3): super().__init__() - self.d_model=d_model - self.nhead=nhead - self.nsample=nsample - self.nlevel=nlevel - self.radius_scale=radius_scale + self.d_model = d_model + self.nhead = nhead + self.nsample = nsample + self.nlevel = nlevel + self.radius_scale = radius_scale self.merge_head = nn.Sequential( - nn.Conv1d(d_model*3, d_model, kernel_size=1,bias=False), + nn.Conv1d(d_model * 3, d_model, kernel_size=1, bias=False), nn.ReLU(True), - nn.Conv1d(d_model, d_model, kernel_size=1,bias=False), + nn.Conv1d(d_model, d_model, kernel_size=1, bias=False), ) - self.fullattention=FullAttention(d_model,nhead) - self.temp=nn.parameter.Parameter(torch.tensor(1.),requires_grad=True) - sample_offset=torch.tensor([[pos[0]-nsample[1]/2+0.5, pos[1]-nsample[1]/2+0.5] for pos in product(range(nsample[1]), range(nsample[1]))]) #r^2*2 - self.sample_offset=nn.parameter.Parameter(sample_offset,requires_grad=False) + self.fullattention = FullAttention(d_model, nhead) + self.temp = nn.parameter.Parameter(torch.tensor(1.0), requires_grad=True) + sample_offset = torch.tensor( + [ + [pos[0] - nsample[1] / 2 + 0.5, pos[1] - nsample[1] / 2 + 0.5] + for pos in product(range(nsample[1]), range(nsample[1])) + ] + ) # r^2*2 + self.sample_offset = nn.parameter.Parameter(sample_offset, requires_grad=False) - def forward(self,query,key,value,flow,size_q,size_kv,mask0=None, mask1=None,ds0=[4,4],ds1=[4,4]): + def forward( + self, + query, + key, + value, + flow, + size_q, + size_kv, + mask0=None, + mask1=None, + ds0=[4, 4], 
+ ds1=[4, 4], + ): """ Args: q,k,v (torch.Tensor): [B, C, L] @@ -45,123 +65,217 @@ class HierachicalAttention(Module): Return: all_message (torch.Tensor): [B, C, H, W] """ - - variance=flow[:,:,:,2:] - offset=flow[:,:,:,:2] #B*H*W*2 - bs=query.shape[0] - h0,w0=size_q[0],size_q[1] - h1,w1=size_kv[0],size_kv[1] - variance=torch.exp(0.5*variance)*self.radius_scale #b*h*w*2(pixel scale) - span_scale=torch.clamp((variance*2/self.nsample[1]),min=1) #b*h*w*2 - - sub_sample0,sub_sample1=[ds0,2,1],[ds1,2,1] - q_list=[F.avg_pool2d(query.view(bs,-1,h0,w0),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0] - k_list=[F.avg_pool2d(key.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1] - v_list=[F.avg_pool2d(value.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1] #n_level - - offset_list=[F.avg_pool2d(offset.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1)/sub_size for sub_size in sub_sample0[1:]] #n_level-1 - span_list=[F.avg_pool2d(span_scale.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1) for sub_size in sub_sample0[1:]] #n_level-1 + + variance = flow[:, :, :, 2:] + offset = flow[:, :, :, :2] # B*H*W*2 + bs = query.shape[0] + h0, w0 = size_q[0], size_q[1] + h1, w1 = size_kv[0], size_kv[1] + variance = torch.exp(0.5 * variance) * self.radius_scale # b*h*w*2(pixel scale) + span_scale = torch.clamp((variance * 2 / self.nsample[1]), min=1) # b*h*w*2 + + sub_sample0, sub_sample1 = [ds0, 2, 1], [ds1, 2, 1] + q_list = [ + F.avg_pool2d( + query.view(bs, -1, h0, w0), kernel_size=sub_size, stride=sub_size + ) + for sub_size in sub_sample0 + ] + k_list = [ + F.avg_pool2d( + key.view(bs, -1, h1, w1), kernel_size=sub_size, stride=sub_size + ) + for sub_size in sub_sample1 + ] + v_list = [ + F.avg_pool2d( + value.view(bs, -1, h1, w1), kernel_size=sub_size, stride=sub_size + ) + for sub_size in sub_sample1 + ] # n_level + + offset_list = [ + F.avg_pool2d( + offset.permute(0, 3, 1, 2), + kernel_size=sub_size * self.nsample[0], + stride=sub_size * self.nsample[0], + ).permute(0, 2, 3, 1) + / sub_size + for sub_size in sub_sample0[1:] + ] # n_level-1 + span_list = [ + F.avg_pool2d( + span_scale.permute(0, 3, 1, 2), + kernel_size=sub_size * self.nsample[0], + stride=sub_size * self.nsample[0], + ).permute(0, 2, 3, 1) + for sub_size in sub_sample0[1:] + ] # n_level-1 if mask0 is not None: - mask0,mask1=mask0.view(bs,1,h0,w0),mask1.view(bs,1,h1,w1) - mask0_list=[-F.max_pool2d(-mask0,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0] - mask1_list=[-F.max_pool2d(-mask1,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1] + mask0, mask1 = mask0.view(bs, 1, h0, w0), mask1.view(bs, 1, h1, w1) + mask0_list = [ + -F.max_pool2d(-mask0, kernel_size=sub_size, stride=sub_size) + for sub_size in sub_sample0 + ] + mask1_list = [ + -F.max_pool2d(-mask1, kernel_size=sub_size, stride=sub_size) + for sub_size in sub_sample1 + ] else: - mask0_list=mask1_list=[None,None,None] - - message_list=[] - #full attention at coarse scale - mask0_flatten=mask0_list[0].view(bs,-1) if mask0 is not None else None - mask1_flatten=mask1_list[0].view(bs,-1) if mask1 is not None else None - message_list.append(self.fullattention(q_list[0],k_list[0],v_list[0],mask0_flatten,mask1_flatten,self.temp).view(bs,self.d_model,h0//ds0[0],w0//ds0[1])) - - for index in range(1,self.nlevel): - q,k,v=q_list[index],k_list[index],v_list[index] - 
mask0,mask1=mask0_list[index],mask1_list[index] - s,o=span_list[index-1],offset_list[index-1] #B*h*w(*2) - q,k,v,sample_pixel,mask_sample=self.partition_token(q,k,v,o,s,mask0) #B*Head*D*G*N(G*N=H*W for q) - message_list.append(self.group_attention(q,k,v,1,mask_sample).view(bs,self.d_model,h0//sub_sample0[index],w0//sub_sample0[index])) - #fuse - all_message=torch.cat([F.upsample(message_list[idx],scale_factor=sub_sample0[idx],mode='nearest') \ - for idx in range(self.nlevel)],dim=1).view(bs,-1,h0*w0) #b*3d*H*W - - all_message=self.merge_head(all_message).view(bs,-1,h0,w0) #b*d*H*W + mask0_list = mask1_list = [None, None, None] + + message_list = [] + # full attention at coarse scale + mask0_flatten = mask0_list[0].view(bs, -1) if mask0 is not None else None + mask1_flatten = mask1_list[0].view(bs, -1) if mask1 is not None else None + message_list.append( + self.fullattention( + q_list[0], k_list[0], v_list[0], mask0_flatten, mask1_flatten, self.temp + ).view(bs, self.d_model, h0 // ds0[0], w0 // ds0[1]) + ) + + for index in range(1, self.nlevel): + q, k, v = q_list[index], k_list[index], v_list[index] + mask0, mask1 = mask0_list[index], mask1_list[index] + s, o = span_list[index - 1], offset_list[index - 1] # B*h*w(*2) + q, k, v, sample_pixel, mask_sample = self.partition_token( + q, k, v, o, s, mask0 + ) # B*Head*D*G*N(G*N=H*W for q) + message_list.append( + self.group_attention(q, k, v, 1, mask_sample).view( + bs, self.d_model, h0 // sub_sample0[index], w0 // sub_sample0[index] + ) + ) + # fuse + all_message = torch.cat( + [ + F.upsample( + message_list[idx], scale_factor=sub_sample0[idx], mode="nearest" + ) + for idx in range(self.nlevel) + ], + dim=1, + ).view( + bs, -1, h0 * w0 + ) # b*3d*H*W + + all_message = self.merge_head(all_message).view(bs, -1, h0, w0) # b*d*H*W return all_message - - def partition_token(self,q,k,v,offset,span_scale,maskv): - #q,k,v: B*C*H*W - #o: B*H/2*W/2*2 - #span_scale:B*H*W - bs=q.shape[0] - h,w=q.shape[2],q.shape[3] - hk,wk=k.shape[2],k.shape[3] - offset=offset.view(bs,-1,2) - span_scale=span_scale.view(bs,-1,1,2) - #B*G*2 - offset_sample=self.sample_offset[None,None]*span_scale - sample_pixel=offset[:,:,None]+offset_sample#B*G*r^2*2 - sample_norm=sample_pixel/torch.tensor([wk/2,hk/2]).cuda()[None,None,None]-1 - - q = q.view(bs, -1 , h // self.nsample[0], self.nsample[0], w // self.nsample[0], self.nsample[0]).\ - permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, self.nhead,self.d_model//self.nhead, -1,self.nsample[0]**2)#B*head*D*G*N(G*N=H*W for q) - #sample token - k=F.grid_sample(k, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2 - v=F.grid_sample(v, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2 - #import pdb;pdb.set_trace() + + def partition_token(self, q, k, v, offset, span_scale, maskv): + # q,k,v: B*C*H*W + # o: B*H/2*W/2*2 + # span_scale:B*H*W + bs = q.shape[0] + h, w = q.shape[2], q.shape[3] + hk, wk = k.shape[2], k.shape[3] + offset = offset.view(bs, -1, 2) + span_scale = span_scale.view(bs, -1, 1, 2) + # B*G*2 + offset_sample = self.sample_offset[None, None] * span_scale + sample_pixel = offset[:, :, None] + offset_sample # B*G*r^2*2 + sample_norm = ( + sample_pixel / torch.tensor([wk / 2, hk / 2]).cuda()[None, None, None] - 1 + ) + + q = ( + q.view( + bs, + -1, + h // self.nsample[0], + self.nsample[0], + w // self.nsample[0], + self.nsample[0], + ) + .permute(0, 1, 2, 4, 3, 5) + .contiguous() + .view(bs, self.nhead, self.d_model // 
self.nhead, -1, self.nsample[0] ** 2) + ) # B*head*D*G*N(G*N=H*W for q) + # sample token + k = F.grid_sample(k, grid=sample_norm).view( + bs, self.nhead, self.d_model // self.nhead, -1, self.nsample[1] ** 2 + ) # B*head*D*G*r^2 + v = F.grid_sample(v, grid=sample_norm).view( + bs, self.nhead, self.d_model // self.nhead, -1, self.nsample[1] ** 2 + ) # B*head*D*G*r^2 + # import pdb;pdb.set_trace() if maskv is not None: - mask_sample=F.grid_sample(maskv.view(bs,-1,h,w).float(),grid=sample_norm,mode='nearest')==1 #B*1*G*r^2 + mask_sample = ( + F.grid_sample( + maskv.view(bs, -1, h, w).float(), grid=sample_norm, mode="nearest" + ) + == 1 + ) # B*1*G*r^2 else: - mask_sample=None - return q,k,v,sample_pixel,mask_sample - + mask_sample = None + return q, k, v, sample_pixel, mask_sample - def group_attention(self,query,key,value,temp,mask_sample=None): - #q,k,v: B*Head*D*G*N(G*N=H*W for q) - bs=query.shape[0] - #import pdb;pdb.set_trace() + def group_attention(self, query, key, value, temp, mask_sample=None): + # q,k,v: B*Head*D*G*N(G*N=H*W for q) + bs = query.shape[0] + # import pdb;pdb.set_trace() QK = torch.einsum("bhdgn,bhdgm->bhgnm", query, key) if mask_sample is not None: - num_head,number_n=QK.shape[1],QK.shape[3] - QK.masked_fill_(~(mask_sample[:,:,:,None]).expand(-1,num_head,-1,number_n,-1).bool(), float(-1e8)) + num_head, number_n = QK.shape[1], QK.shape[3] + QK.masked_fill_( + ~(mask_sample[:, :, :, None]) + .expand(-1, num_head, -1, number_n, -1) + .bool(), + float(-1e8), + ) # Compute the attention and the weighted average - softmax_temp = temp / query.size(2)**.5 # sqrt(D) + softmax_temp = temp / query.size(2) ** 0.5 # sqrt(D) A = torch.softmax(softmax_temp * QK, dim=-1) - queried_values = torch.einsum("bhgnm,bhdgm->bhdgn", A, value).contiguous().view(bs,self.d_model,-1) + queried_values = ( + torch.einsum("bhgnm,bhdgm->bhdgn", A, value) + .contiguous() + .view(bs, self.d_model, -1) + ) return queried_values - class FullAttention(Module): - def __init__(self,d_model,nhead): + def __init__(self, d_model, nhead): super().__init__() - self.d_model=d_model - self.nhead=nhead + self.d_model = d_model + self.nhead = nhead - def forward(self, q, k,v , mask0=None, mask1=None, temp=1): - """ Multi-head scaled dot-product attention, a.k.a full attention. + def forward(self, q, k, v, mask0=None, mask1=None, temp=1): + """Multi-head scaled dot-product attention, a.k.a full attention. 
Args: q,k,v: [N, D, L] mask: [N, L] Returns: msg: [N,L] """ - bs=q.shape[0] - q,k,v=q.view(bs,self.nhead,self.d_model//self.nhead,-1),k.view(bs,self.nhead,self.d_model//self.nhead,-1),v.view(bs,self.nhead,self.d_model//self.nhead,-1) + bs = q.shape[0] + q, k, v = ( + q.view(bs, self.nhead, self.d_model // self.nhead, -1), + k.view(bs, self.nhead, self.d_model // self.nhead, -1), + v.view(bs, self.nhead, self.d_model // self.nhead, -1), + ) # Compute the unnormalized attention and apply the masks QK = torch.einsum("nhdl,nhds->nhls", q, k) if mask0 is not None: - QK.masked_fill_(~(mask0[:,None, :, None] * mask1[:, None, None]).bool(), float(-1e8)) + QK.masked_fill_( + ~(mask0[:, None, :, None] * mask1[:, None, None]).bool(), float(-1e8) + ) # Compute the attention and the weighted average - softmax_temp = temp / q.size(2)**.5 # sqrt(D) + softmax_temp = temp / q.size(2) ** 0.5 # sqrt(D) A = torch.softmax(softmax_temp * QK, dim=-1) - queried_values = torch.einsum("nhls,nhds->nhdl", A, v).contiguous().view(bs,self.d_model,-1) + queried_values = ( + torch.einsum("nhls,nhds->nhdl", A, v) + .contiguous() + .view(bs, self.d_model, -1) + ) return queried_values - - + def elu_feature_map(x): return F.elu(x) + 1 + class LinearAttention(Module): def __init__(self, eps=1e-6): super().__init__() @@ -169,7 +283,7 @@ class LinearAttention(Module): self.eps = eps def forward(self, queries, keys, values, q_mask=None, kv_mask=None): - """ Multi-Head linear attention proposed in "Transformers are RNNs" + """Multi-Head linear attention proposed in "Transformers are RNNs" Args: queries: [N, L, H, D] keys: [N, S, H, D] @@ -195,4 +309,4 @@ class LinearAttention(Module): Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length - return queried_values.contiguous() \ No newline at end of file + return queried_values.contiguous() diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py index 5bb8eefd362240a9901a335f0e6e07770ff04567..6c37f76c3d5735508f950bb1239f5e93039b27ff 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py +++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py @@ -9,15 +9,15 @@ class FinePreprocess(nn.Module): super().__init__() self.config = config - self.cat_c_feat = config['fine_concat_coarse_feat'] - self.W = self.config['fine_window_size'] + self.cat_c_feat = config["fine_concat_coarse_feat"] + self.W = self.config["fine_window_size"] - d_model_c = self.config['coarse']['d_model'] - d_model_f = self.config['fine']['d_model'] + d_model_c = self.config["coarse"]["d_model"] + d_model_f = self.config["fine"]["d_model"] self.d_model_f = d_model_f if self.cat_c_feat: self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) - self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) + self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True) self._reset_parameters() @@ -28,32 +28,48 @@ class FinePreprocess(nn.Module): def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): W = self.W - stride = data['hw0_f'][0] // data['hw0_c'][0] + stride = data["hw0_f"][0] // data["hw0_c"][0] - data.update({'W': W}) - if data['b_ids'].shape[0] == 0: + data.update({"W": W}) + if data["b_ids"].shape[0] == 0: feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) return 
feat0, feat1 # 1. unfold(crop) all local windows - feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) - feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) - feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) - feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f0_unfold = F.unfold( + feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2 + ) + feat_f0_unfold = rearrange(feat_f0_unfold, "n (c ww) l -> n l ww c", ww=W**2) + feat_f1_unfold = F.unfold( + feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2 + ) + feat_f1_unfold = rearrange(feat_f1_unfold, "n (c ww) l -> n l ww c", ww=W**2) # 2. select only the predicted matches - feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] - feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + feat_f0_unfold = feat_f0_unfold[data["b_ids"], data["i_ids"]] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data["b_ids"], data["j_ids"]] # option: use coarse-level loftr feature as context: concat and linear if self.cat_c_feat: - feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], - feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] - feat_cf_win = self.merge_feat(torch.cat([ - torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] - repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] - ], -1)) + feat_c_win = self.down_proj( + torch.cat( + [ + feat_c0[data["b_ids"], data["i_ids"]], + feat_c1[data["b_ids"], data["j_ids"]], + ], + 0, + ) + ) # [2n, c] + feat_cf_win = self.merge_feat( + torch.cat( + [ + torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] + repeat(feat_c_win, "n c -> n ww c", ww=W**2), # [2n, ww, cf] + ], + -1, + ) + ) feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) return feat_f0_unfold, feat_f1_unfold diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py index 7dcebaa7beee978b9b8abcec8bb1bd2cc6b60870..eaad9fdac1fbfc7a77f2db7c98c67bc41e335945 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py +++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py @@ -3,11 +3,9 @@ import torch import torch.nn as nn from .attention import LinearAttention + class LoFTREncoderLayer(nn.Module): - def __init__(self, - d_model, - nhead, - attention='linear'): + def __init__(self, d_model, nhead, attention="linear"): super(LoFTREncoderLayer, self).__init__() self.dim = d_model // nhead @@ -22,9 +20,9 @@ class LoFTREncoderLayer(nn.Module): # feed-forward network self.mlp = nn.Sequential( - nn.Linear(d_model*2, d_model*2, bias=False), + nn.Linear(d_model * 2, d_model * 2, bias=False), nn.ReLU(True), - nn.Linear(d_model*2, d_model, bias=False), + nn.Linear(d_model * 2, d_model, bias=False), ) # norm and dropout @@ -43,16 +41,14 @@ class LoFTREncoderLayer(nn.Module): query, key, value = x, source, source # multi-head attention - query = self.q_proj(query).view( - bs, -1, self.nhead, self.dim) # [N, L, (H, D)] - key = self.k_proj(key).view(bs, -1, self.nhead, - self.dim) # [N, S, (H, D)] + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) message = self.attention( - query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] - message = 
self.merge(message.view( - bs, -1, self.nhead*self.dim)) # [N, L, C] + query, key, value, q_mask=x_mask, kv_mask=source_mask + ) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C] message = self.norm1(message) # feed-forward network @@ -69,13 +65,15 @@ class LocalFeatureTransformer(nn.Module): super(LocalFeatureTransformer, self).__init__() self.config = config - self.d_model = config['d_model'] - self.nhead = config['nhead'] - self.layer_names = config['layer_names'] + self.d_model = config["d_model"] + self.nhead = config["nhead"] + self.layer_names = config["layer_names"] encoder_layer = LoFTREncoderLayer( - config['d_model'], config['nhead'], config['attention']) + config["d_model"], config["nhead"], config["attention"] + ) self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))] + ) self._reset_parameters() def _reset_parameters(self): @@ -93,20 +91,18 @@ class LocalFeatureTransformer(nn.Module): """ assert self.d_model == feat0.size( - 2), "the feature number of src and transformer must be equal" + 2 + ), "the feature number of src and transformer must be equal" index = 0 for layer, name in zip(self.layers, self.layer_names): - if name == 'self': - feat0 = layer(feat0, feat0, mask0, mask0, - type='self', index=index) + if name == "self": + feat0 = layer(feat0, feat0, mask0, mask0, type="self", index=index) feat1 = layer(feat1, feat1, mask1, mask1) - elif name == 'cross': + elif name == "cross": feat0 = layer(feat0, feat1, mask0, mask1) - feat1 = layer(feat1, feat0, mask1, mask0, - type='cross', index=index) + feat1 = layer(feat1, feat0, mask1, mask0, type="cross", index=index) index += 1 else: raise KeyError return feat0, feat1 - diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py index c398f770833bf2066cda60a7ff546ec29640d433..125f555f93874af74c6e2595a360939f2f3bbce2 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py +++ b/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py @@ -2,44 +2,42 @@ import copy import torch import torch.nn as nn import torch.nn.functional as F -from .attention import FullAttention, HierachicalAttention ,layernorm2d +from .attention import FullAttention, HierachicalAttention, layernorm2d class messageLayer_ini(nn.Module): - - def __init__(self, d_model, d_flow,d_value, nhead): + def __init__(self, d_model, d_flow, d_value, nhead): super().__init__() super(messageLayer_ini, self).__init__() self.d_model = d_model self.d_flow = d_flow - self.d_value=d_value + self.d_value = d_value self.nhead = nhead - self.attention = FullAttention(d_model,nhead) + self.attention = FullAttention(d_model, nhead) - self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False) - self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False) - self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False) - self.merge_head=nn.Conv1d(d_model,d_model,kernel_size=1,bias=False) + self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False) + self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False) + self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1, bias=False) + self.merge_head = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False) - self.merge_f= self.merge_f = nn.Sequential( - nn.Conv2d(d_model*2, d_model*2, kernel_size=1, bias=False), + 
self.merge_f = self.merge_f = nn.Sequential( + nn.Conv2d(d_model * 2, d_model * 2, kernel_size=1, bias=False), nn.ReLU(True), - nn.Conv2d(d_model*2, d_model, kernel_size=1, bias=False), + nn.Conv2d(d_model * 2, d_model, kernel_size=1, bias=False), ) self.norm1 = layernorm2d(d_model) self.norm2 = layernorm2d(d_model) + def forward(self, x0, x1, pos0, pos1, mask0=None, mask1=None): + # x1,x2: b*d*L + x0, x1 = self.update(x0, x1, pos1, mask0, mask1), self.update( + x1, x0, pos0, mask1, mask0 + ) + return x0, x1 - def forward(self, x0, x1,pos0,pos1,mask0=None,mask1=None): - #x1,x2: b*d*L - x0,x1=self.update(x0,x1,pos1,mask0,mask1),\ - self.update(x1,x0,pos0,mask1,mask0) - return x0,x1 - - - def update(self,f0,f1,pos1,mask0,mask1): + def update(self, f0, f1, pos1, mask0, mask1): """ Args: f0: [N, D, H, W] @@ -47,53 +45,77 @@ class messageLayer_ini(nn.Module): Returns: f0_new: (N, d, h, w) """ - bs,h,w=f0.shape[0],f0.shape[2],f0.shape[3] + bs, h, w = f0.shape[0], f0.shape[2], f0.shape[3] - f0_flatten,f1_flatten=f0.view(bs,self.d_model,-1),f1.view(bs,self.d_model,-1) - pos1_flatten=pos1.view(bs,self.d_value-self.d_model,-1) - f1_flatten_v=torch.cat([f1_flatten,pos1_flatten],dim=1) + f0_flatten, f1_flatten = f0.view(bs, self.d_model, -1), f1.view( + bs, self.d_model, -1 + ) + pos1_flatten = pos1.view(bs, self.d_value - self.d_model, -1) + f1_flatten_v = torch.cat([f1_flatten, pos1_flatten], dim=1) - queries,keys=self.q_proj(f0_flatten),self.k_proj(f1_flatten) - values=self.v_proj(f1_flatten_v).view(bs,self.nhead,self.d_model//self.nhead,-1) - - queried_values=self.attention(queries,keys,values,mask0,mask1) - msg=self.merge_head(queried_values).view(bs,-1,h,w) - msg=self.norm2(self.merge_f(torch.cat([f0,self.norm1(msg)],dim=1))) - return f0+msg + queries, keys = self.q_proj(f0_flatten), self.k_proj(f1_flatten) + values = self.v_proj(f1_flatten_v).view( + bs, self.nhead, self.d_model // self.nhead, -1 + ) + queried_values = self.attention(queries, keys, values, mask0, mask1) + msg = self.merge_head(queried_values).view(bs, -1, h, w) + msg = self.norm2(self.merge_f(torch.cat([f0, self.norm1(msg)], dim=1))) + return f0 + msg class messageLayer_gla(nn.Module): - - def __init__(self,d_model,d_flow,d_value, - nhead,radius_scale,nsample,update_flow=True): + def __init__( + self, d_model, d_flow, d_value, nhead, radius_scale, nsample, update_flow=True + ): super().__init__() self.d_model = d_model - self.d_flow=d_flow - self.d_value=d_value + self.d_flow = d_flow + self.d_value = d_value self.nhead = nhead - self.radius_scale=radius_scale - self.update_flow=update_flow - self.flow_decoder=nn.Sequential( - nn.Conv1d(d_flow, d_flow//2, kernel_size=1, bias=False), - nn.ReLU(True), - nn.Conv1d(d_flow//2, 4, kernel_size=1, bias=False)) - self.attention=HierachicalAttention(d_model,nhead,nsample,radius_scale) - - self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False) - self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False) - self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False) - - d_extra=d_flow if update_flow else 0 - self.merge_f=nn.Sequential( - nn.Conv2d(d_model*2+d_extra, d_model+d_flow, kernel_size=1, bias=False), - nn.ReLU(True), - nn.Conv2d(d_model+d_flow, d_model+d_extra, kernel_size=3,padding=1, bias=False), - ) - self.norm1 = layernorm2d(d_model) - self.norm2 = layernorm2d(d_model+d_extra) + self.radius_scale = radius_scale + self.update_flow = update_flow + self.flow_decoder = nn.Sequential( + nn.Conv1d(d_flow, d_flow // 2, kernel_size=1, bias=False), + 
nn.ReLU(True), + nn.Conv1d(d_flow // 2, 4, kernel_size=1, bias=False), + ) + self.attention = HierachicalAttention(d_model, nhead, nsample, radius_scale) - def forward(self, x0, x1, flow_feature0,flow_feature1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]): + self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False) + self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False) + self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1, bias=False) + + d_extra = d_flow if update_flow else 0 + self.merge_f = nn.Sequential( + nn.Conv2d( + d_model * 2 + d_extra, d_model + d_flow, kernel_size=1, bias=False + ), + nn.ReLU(True), + nn.Conv2d( + d_model + d_flow, + d_model + d_extra, + kernel_size=3, + padding=1, + bias=False, + ), + ) + self.norm1 = layernorm2d(d_model) + self.norm2 = layernorm2d(d_model + d_extra) + + def forward( + self, + x0, + x1, + flow_feature0, + flow_feature1, + pos0, + pos1, + mask0=None, + mask1=None, + ds0=[4, 4], + ds1=[4, 4], + ): """ Args: x0 (torch.Tensor): [B, C, H, W] @@ -101,88 +123,135 @@ class messageLayer_gla(nn.Module): flow_feature0 (torch.Tensor): [B, C', H, W] flow_feature1 (torch.Tensor): [B, C', H, W] """ - flow0,flow1=self.decode_flow(flow_feature0,flow_feature1.shape[2:]),self.decode_flow(flow_feature1,flow_feature0.shape[2:]) - x0_new,flow_feature0_new=self.update(x0,x1,flow0.detach(),flow_feature0,pos1,mask0,mask1,ds0,ds1) - x1_new,flow_feature1_new=self.update(x1,x0,flow1.detach(),flow_feature1,pos0,mask1,mask0,ds1,ds0) - return x0_new,x1_new,flow_feature0_new,flow_feature1_new,flow0,flow1 - - def update(self,x0,x1,flow0,flow_feature0,pos1,mask0,mask1,ds0,ds1): - bs=x0.shape[0] - queries,keys=self.q_proj(x0.view(bs,self.d_model,-1)),self.k_proj(x1.view(bs,self.d_model,-1)) - x1_pos=torch.cat([x1,pos1],dim=1) - values=self.v_proj(x1_pos.view(bs,self.d_value,-1)) - msg=self.attention(queries,keys,values,flow0,x0.shape[2:],x1.shape[2:],mask0,mask1,ds0,ds1) + flow0, flow1 = self.decode_flow( + flow_feature0, flow_feature1.shape[2:] + ), self.decode_flow(flow_feature1, flow_feature0.shape[2:]) + x0_new, flow_feature0_new = self.update( + x0, x1, flow0.detach(), flow_feature0, pos1, mask0, mask1, ds0, ds1 + ) + x1_new, flow_feature1_new = self.update( + x1, x0, flow1.detach(), flow_feature1, pos0, mask1, mask0, ds1, ds0 + ) + return x0_new, x1_new, flow_feature0_new, flow_feature1_new, flow0, flow1 + + def update(self, x0, x1, flow0, flow_feature0, pos1, mask0, mask1, ds0, ds1): + bs = x0.shape[0] + queries, keys = self.q_proj(x0.view(bs, self.d_model, -1)), self.k_proj( + x1.view(bs, self.d_model, -1) + ) + x1_pos = torch.cat([x1, pos1], dim=1) + values = self.v_proj(x1_pos.view(bs, self.d_value, -1)) + msg = self.attention( + queries, + keys, + values, + flow0, + x0.shape[2:], + x1.shape[2:], + mask0, + mask1, + ds0, + ds1, + ) if self.update_flow: - update_feature=torch.cat([x0,flow_feature0],dim=1) + update_feature = torch.cat([x0, flow_feature0], dim=1) else: - update_feature=x0 - msg=self.norm2(self.merge_f(torch.cat([update_feature,self.norm1(msg)],dim=1))) - update_feature=update_feature+msg - - x0_new,flow_feature0_new=update_feature[:,:self.d_model],update_feature[:,self.d_model:] - return x0_new,flow_feature0_new - - def decode_flow(self,flow_feature,kshape): - bs,h,w=flow_feature.shape[0],flow_feature.shape[2],flow_feature.shape[3] - scale_factor=torch.tensor([kshape[1],kshape[0]]).cuda()[None,None,None] - flow=self.flow_decoder(flow_feature.view(bs,-1,h*w)).permute(0,2,1).view(bs,h,w,4) - 
flow_coordinates=torch.sigmoid(flow[:,:,:,:2])*scale_factor - flow_var=flow[:,:,:,2:] - flow=torch.cat([flow_coordinates,flow_var],dim=-1) #B*H*W*4 + update_feature = x0 + msg = self.norm2( + self.merge_f(torch.cat([update_feature, self.norm1(msg)], dim=1)) + ) + update_feature = update_feature + msg + + x0_new, flow_feature0_new = ( + update_feature[:, : self.d_model], + update_feature[:, self.d_model :], + ) + return x0_new, flow_feature0_new + + def decode_flow(self, flow_feature, kshape): + bs, h, w = flow_feature.shape[0], flow_feature.shape[2], flow_feature.shape[3] + scale_factor = torch.tensor([kshape[1], kshape[0]]).cuda()[None, None, None] + flow = ( + self.flow_decoder(flow_feature.view(bs, -1, h * w)) + .permute(0, 2, 1) + .view(bs, h, w, 4) + ) + flow_coordinates = torch.sigmoid(flow[:, :, :, :2]) * scale_factor + flow_var = flow[:, :, :, 2:] + flow = torch.cat([flow_coordinates, flow_var], dim=-1) # B*H*W*4 return flow class flow_initializer(nn.Module): - def __init__(self, dim, dim_flow, nhead, layer_num): super().__init__() - self.layer_num= layer_num + self.layer_num = layer_num self.dim = dim self.dim_flow = dim_flow - encoder_layer = messageLayer_ini( - dim ,dim_flow,dim+dim_flow , nhead) + encoder_layer = messageLayer_ini(dim, dim_flow, dim + dim_flow, nhead) self.layers_coarse = nn.ModuleList( - [copy.deepcopy(encoder_layer) for _ in range(layer_num)]) - self.decoupler = nn.Conv2d( - self.dim, self.dim+self.dim_flow, kernel_size=1) - self.up_merge = nn.Conv2d(2*dim, dim, kernel_size=1) + [copy.deepcopy(encoder_layer) for _ in range(layer_num)] + ) + self.decoupler = nn.Conv2d(self.dim, self.dim + self.dim_flow, kernel_size=1) + self.up_merge = nn.Conv2d(2 * dim, dim, kernel_size=1) - def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]): + def forward( + self, feat0, feat1, pos0, pos1, mask0=None, mask1=None, ds0=[4, 4], ds1=[4, 4] + ): # feat0: [B, C, H0, W0] # feat1: [B, C, H1, W1] # use low-res MHA to initialize flow feature bs = feat0.size(0) - h0,w0,h1,w1=feat0.shape[2],feat0.shape[3],feat1.shape[2],feat1.shape[3] + h0, w0, h1, w1 = feat0.shape[2], feat0.shape[3], feat1.shape[2], feat1.shape[3] # coarse level - sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0), \ - F.avg_pool2d(feat1, ds1, stride=ds1) + sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0), F.avg_pool2d( + feat1, ds1, stride=ds1 + ) + + sub_pos0, sub_pos1 = F.avg_pool2d(pos0, ds0, stride=ds0), F.avg_pool2d( + pos1, ds1, stride=ds1 + ) - sub_pos0,sub_pos1=F.avg_pool2d(pos0, ds0, stride=ds0), \ - F.avg_pool2d(pos1, ds1, stride=ds1) - if mask0 is not None: - mask0,mask1=-F.max_pool2d(-mask0.view(bs,1,h0,w0),ds0,stride=ds0).view(bs,-1),\ - -F.max_pool2d(-mask1.view(bs,1,h1,w1),ds1,stride=ds1).view(bs,-1) - + mask0, mask1 = -F.max_pool2d( + -mask0.view(bs, 1, h0, w0), ds0, stride=ds0 + ).view(bs, -1), -F.max_pool2d( + -mask1.view(bs, 1, h1, w1), ds1, stride=ds1 + ).view( + bs, -1 + ) + for layer in self.layers_coarse: - sub_feat0, sub_feat1 = layer(sub_feat0, sub_feat1,sub_pos0,sub_pos1,mask0,mask1) + sub_feat0, sub_feat1 = layer( + sub_feat0, sub_feat1, sub_pos0, sub_pos1, mask0, mask1 + ) # decouple flow and visual features - decoupled_feature0, decoupled_feature1 = self.decoupler(sub_feat0),self.decoupler(sub_feat1) + decoupled_feature0, decoupled_feature1 = self.decoupler( + sub_feat0 + ), self.decoupler(sub_feat1) + + sub_feat0, sub_flow_feature0 = ( + decoupled_feature0[:, : self.dim], + decoupled_feature0[:, self.dim :], + ) + sub_feat1, 
sub_flow_feature1 = ( + decoupled_feature1[:, : self.dim], + decoupled_feature1[:, self.dim :], + ) + update_feat0, flow_feature0 = F.upsample( + sub_feat0, scale_factor=ds0, mode="bilinear" + ), F.upsample(sub_flow_feature0, scale_factor=ds0, mode="bilinear") + update_feat1, flow_feature1 = F.upsample( + sub_feat1, scale_factor=ds1, mode="bilinear" + ), F.upsample(sub_flow_feature1, scale_factor=ds1, mode="bilinear") - sub_feat0, sub_flow_feature0 = decoupled_feature0[:,:self.dim], decoupled_feature0[:, self.dim:] - sub_feat1, sub_flow_feature1 = decoupled_feature1[:,:self.dim], decoupled_feature1[:, self.dim:] - update_feat0, flow_feature0 = F.upsample(sub_feat0, scale_factor=ds0, mode='bilinear'),\ - F.upsample(sub_flow_feature0, scale_factor=ds0, mode='bilinear') - update_feat1, flow_feature1 = F.upsample(sub_feat1, scale_factor=ds1, mode='bilinear'),\ - F.upsample(sub_flow_feature1, scale_factor=ds1, mode='bilinear') - - feat0 = feat0+self.up_merge(torch.cat([feat0, update_feat0], dim=1)) - feat1 = feat1+self.up_merge(torch.cat([feat1, update_feat1], dim=1)) - - return feat0,feat1,flow_feature0,flow_feature1 #b*c*h*w + feat0 = feat0 + self.up_merge(torch.cat([feat0, update_feat0], dim=1)) + feat1 = feat1 + self.up_merge(torch.cat([feat1, update_feat1], dim=1)) + + return feat0, feat1, flow_feature0, flow_feature1 # b*c*h*w class LocalFeatureTransformer_Flow(nn.Module): @@ -192,27 +261,49 @@ class LocalFeatureTransformer_Flow(nn.Module): super(LocalFeatureTransformer_Flow, self).__init__() self.config = config - self.d_model = config['d_model'] - self.nhead = config['nhead'] + self.d_model = config["d_model"] + self.nhead = config["nhead"] + + self.pos_transform = nn.Conv2d( + config["d_model"], config["d_flow"], kernel_size=1, bias=False + ) + self.ini_layer = flow_initializer( + self.d_model, config["d_flow"], config["nhead"], config["ini_layer_num"] + ) - self.pos_transform=nn.Conv2d(config['d_model'],config['d_flow'],kernel_size=1,bias=False) - self.ini_layer = flow_initializer(self.d_model, config['d_flow'], config['nhead'],config['ini_layer_num']) - encoder_layer = messageLayer_gla( - config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], config['nhead'],config['radius_scale'],config['nsample']) - encoder_layer_last=messageLayer_gla( - config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], config['nhead'],config['radius_scale'],config['nsample'],update_flow=False) - self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(config['layer_num']-1)]+[encoder_layer_last]) + config["d_model"], + config["d_flow"], + config["d_flow"] + config["d_model"], + config["nhead"], + config["radius_scale"], + config["nsample"], + ) + encoder_layer_last = messageLayer_gla( + config["d_model"], + config["d_flow"], + config["d_flow"] + config["d_model"], + config["nhead"], + config["radius_scale"], + config["nsample"], + update_flow=False, + ) + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(config["layer_num"] - 1)] + + [encoder_layer_last] + ) self._reset_parameters() - + def _reset_parameters(self): - for name,p in self.named_parameters(): - if 'temp' in name or 'sample_offset' in name: + for name, p in self.named_parameters(): + if "temp" in name or "sample_offset" in name: continue if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]): + def forward( + self, feat0, feat1, pos0, pos1, mask0=None, mask1=None, ds0=[4, 4], ds1=[4, 4] + ): """ 
Args: feat0 (torch.Tensor): [N, C, H, W] @@ -224,21 +315,37 @@ class LocalFeatureTransformer_Flow(nn.Module): flow_list: [L,N,H,W,4]*1(2) """ bs = feat0.size(0) - - pos0,pos1=self.pos_transform(pos0),self.pos_transform(pos1) - pos0,pos1=pos0.expand(bs,-1,-1,-1),pos1.expand(bs,-1,-1,-1) + + pos0, pos1 = self.pos_transform(pos0), self.pos_transform(pos1) + pos0, pos1 = pos0.expand(bs, -1, -1, -1), pos1.expand(bs, -1, -1, -1) assert self.d_model == feat0.size( - 1), "the feature number of src and transformer must be equal" - - flow_list=[[],[]]# [px,py,sx,sy] + 1 + ), "the feature number of src and transformer must be equal" + + flow_list = [[], []] # [px,py,sx,sy] if mask0 is not None: - mask0,mask1=mask0[:,None].float(),mask1[:,None].float() - feat0,feat1, flow_feature0, flow_feature1 = self.ini_layer(feat0, feat1,pos0,pos1,mask0,mask1,ds0,ds1) + mask0, mask1 = mask0[:, None].float(), mask1[:, None].float() + feat0, feat1, flow_feature0, flow_feature1 = self.ini_layer( + feat0, feat1, pos0, pos1, mask0, mask1, ds0, ds1 + ) for layer in self.layers: - feat0,feat1,flow_feature0,flow_feature1,flow0,flow1=layer(feat0,feat1,flow_feature0,flow_feature1,pos0,pos1,mask0,mask1,ds0,ds1) + feat0, feat1, flow_feature0, flow_feature1, flow0, flow1 = layer( + feat0, + feat1, + flow_feature0, + flow_feature1, + pos0, + pos1, + mask0, + mask1, + ds0, + ds1, + ) flow_list[0].append(flow0) flow_list[1].append(flow1) - flow_list[0]=torch.stack(flow_list[0],dim=0) - flow_list[1]=torch.stack(flow_list[1],dim=0) - feat0, feat1 = feat0.permute(0, 2, 3, 1).view(bs, -1, self.d_model), feat1.permute(0, 2, 3, 1).view(bs, -1, self.d_model) - return feat0, feat1, flow_list \ No newline at end of file + flow_list[0] = torch.stack(flow_list[0], dim=0) + flow_list[1] = torch.stack(flow_list[1], dim=0) + feat0, feat1 = feat0.permute(0, 2, 3, 1).view( + bs, -1, self.d_model + ), feat1.permute(0, 2, 3, 1).view(bs, -1, self.d_model) + return feat0, feat1, flow_list diff --git a/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py b/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py index 01b797a420cf5ccea5b53fee3ceda8b5e157573f..113e912bf219ff6fcbc7a1642454ac08b455fd0d 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py +++ b/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py @@ -5,7 +5,11 @@ from einops.einops import rearrange from .backbone import build_backbone from .utils.position_encoding import PositionEncodingSine -from .aspan_module import LocalFeatureTransformer_Flow, LocalFeatureTransformer, FinePreprocess +from .aspan_module import ( + LocalFeatureTransformer_Flow, + LocalFeatureTransformer, + FinePreprocess, +) from .utils.coarse_matching import CoarseMatching from .utils.fine_matching import FineMatching @@ -19,16 +23,18 @@ class ASpanFormer(nn.Module): # Modules self.backbone = build_backbone(config) self.pos_encoding = PositionEncodingSine( - config['coarse']['d_model'],pre_scaling=[config['coarse']['train_res'],config['coarse']['test_res']]) - self.loftr_coarse = LocalFeatureTransformer_Flow(config['coarse']) - self.coarse_matching = CoarseMatching(config['match_coarse']) + config["coarse"]["d_model"], + pre_scaling=[config["coarse"]["train_res"], config["coarse"]["test_res"]], + ) + self.loftr_coarse = LocalFeatureTransformer_Flow(config["coarse"]) + self.coarse_matching = CoarseMatching(config["match_coarse"]) self.fine_preprocess = FinePreprocess(config) self.loftr_fine = LocalFeatureTransformer(config["fine"]) self.fine_matching = FineMatching() - 
self.coarsest_level=config['coarse']['coarsest_level'] + self.coarsest_level = config["coarse"]["coarsest_level"] def forward(self, data, online_resize=False): - """ + """ Update: data (dict): { 'image0': (torch.Tensor): (N, 1, H, W) @@ -38,96 +44,135 @@ class ASpanFormer(nn.Module): } """ if online_resize: - assert data['image0'].shape[0]==1 and data['image1'].shape[1]==1 - self.resize_input(data,self.config['coarse']['train_res']) + assert data["image0"].shape[0] == 1 and data["image1"].shape[1] == 1 + self.resize_input(data, self.config["coarse"]["train_res"]) else: - data['pos_scale0'],data['pos_scale1']=None,None + data["pos_scale0"], data["pos_scale1"] = None, None # 1. Local Feature CNN - data.update({ - 'bs': data['image0'].size(0), - 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] - }) - - if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + data.update( + { + "bs": data["image0"].size(0), + "hw0_i": data["image0"].shape[2:], + "hw1_i": data["image1"].shape[2:], + } + ) + + if data["hw0_i"] == data["hw1_i"]: # faster & better BN convergence feats_c, feats_f = self.backbone( - torch.cat([data['image0'], data['image1']], dim=0)) + torch.cat([data["image0"], data["image1"]], dim=0) + ) (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split( - data['bs']), feats_f.split(data['bs']) + data["bs"] + ), feats_f.split(data["bs"]) else: # handle different input shapes (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone( - data['image0']), self.backbone(data['image1']) + data["image0"] + ), self.backbone(data["image1"]) - data.update({ - 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], - 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] - }) + data.update( + { + "hw0_c": feat_c0.shape[2:], + "hw1_c": feat_c1.shape[2:], + "hw0_f": feat_f0.shape[2:], + "hw1_f": feat_f1.shape[2:], + } + ) # 2. coarse-level loftr module # add featmap with positional encoding, then flatten it to sequence [N, HW, C] - [feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding(feat_c0,data['pos_scale0']), self.pos_encoding(feat_c1,data['pos_scale1']) - feat_c0 = rearrange(feat_c0, 'n c h w -> n c h w ') - feat_c1 = rearrange(feat_c1, 'n c h w -> n c h w ') + [feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding( + feat_c0, data["pos_scale0"] + ), self.pos_encoding(feat_c1, data["pos_scale1"]) + feat_c0 = rearrange(feat_c0, "n c h w -> n c h w ") + feat_c1 = rearrange(feat_c1, "n c h w -> n c h w ") - #TODO:adjust ds - ds0=[int(data['hw0_c'][0]/self.coarsest_level[0]),int(data['hw0_c'][1]/self.coarsest_level[1])] - ds1=[int(data['hw1_c'][0]/self.coarsest_level[0]),int(data['hw1_c'][1]/self.coarsest_level[1])] + # TODO:adjust ds + ds0 = [ + int(data["hw0_c"][0] / self.coarsest_level[0]), + int(data["hw0_c"][1] / self.coarsest_level[1]), + ] + ds1 = [ + int(data["hw1_c"][0] / self.coarsest_level[0]), + int(data["hw1_c"][1] / self.coarsest_level[1]), + ] if online_resize: - ds0,ds1=[4,4],[4,4] + ds0, ds1 = [4, 4], [4, 4] mask_c0 = mask_c1 = None # mask is useful in training - if 'mask0' in data: - mask_c0, mask_c1 = data['mask0'].flatten( - -2), data['mask1'].flatten(-2) + if "mask0" in data: + mask_c0, mask_c1 = data["mask0"].flatten(-2), data["mask1"].flatten(-2) feat_c0, feat_c1, flow_list = self.loftr_coarse( - feat_c0, feat_c1,pos_encoding0,pos_encoding1,mask_c0,mask_c1,ds0,ds1) + feat_c0, feat_c1, pos_encoding0, pos_encoding1, mask_c0, mask_c1, ds0, ds1 + ) # 3. 
match coarse-level and register predicted offset - self.coarse_matching(feat_c0, feat_c1, flow_list,data, - mask_c0=mask_c0, mask_c1=mask_c1) + self.coarse_matching( + feat_c0, feat_c1, flow_list, data, mask_c0=mask_c0, mask_c1=mask_c1 + ) # 4. fine-level refinement feat_f0_unfold, feat_f1_unfold = self.fine_preprocess( - feat_f0, feat_f1, feat_c0, feat_c1, data) + feat_f0, feat_f1, feat_c0, feat_c1, data + ) if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted feat_f0_unfold, feat_f1_unfold = self.loftr_fine( - feat_f0_unfold, feat_f1_unfold) + feat_f0_unfold, feat_f1_unfold + ) # 5. match fine-level self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) # 6. resize match coordinates back to input resolution if online_resize: - data['mkpts0_f']*=data['online_resize_scale0'] - data['mkpts1_f']*=data['online_resize_scale1'] - + data["mkpts0_f"] *= data["online_resize_scale0"] + data["mkpts1_f"] *= data["online_resize_scale1"] + def load_state_dict(self, state_dict, *args, **kwargs): for k in list(state_dict.keys()): - if k.startswith('matcher.'): - if 'sample_offset' in k: + if k.startswith("matcher."): + if "sample_offset" in k: state_dict.pop(k) else: - state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k) return super().load_state_dict(state_dict, *args, **kwargs) - - def resize_input(self,data,train_res,df=32): - h0,w0,h1,w1=data['image0'].shape[2],data['image0'].shape[3],data['image1'].shape[2],data['image1'].shape[3] - data['image0'],data['image1']=self.resize_df(data['image0'],df),self.resize_df(data['image1'],df) - - if len(train_res)==1: - train_res_h=train_res_w=train_res + + def resize_input(self, data, train_res, df=32): + h0, w0, h1, w1 = ( + data["image0"].shape[2], + data["image0"].shape[3], + data["image1"].shape[2], + data["image1"].shape[3], + ) + data["image0"], data["image1"] = self.resize_df( + data["image0"], df + ), self.resize_df(data["image1"], df) + + if len(train_res) == 1: + train_res_h = train_res_w = train_res else: - train_res_h,train_res_w=train_res[0],train_res[1] - data['pos_scale0'],data['pos_scale1']=[train_res_h/data['image0'].shape[2],train_res_w/data['image0'].shape[3]],\ - [train_res_h/data['image1'].shape[2],train_res_w/data['image1'].shape[3]] - data['online_resize_scale0'],data['online_resize_scale1']=torch.tensor([w0/data['image0'].shape[3],h0/data['image0'].shape[2]])[None].cuda(),\ - torch.tensor([w1/data['image1'].shape[3],h1/data['image1'].shape[2]])[None].cuda() - - def resize_df(self,image,df=32): - h,w=image.shape[2],image.shape[3] - h_new,w_new=h//df*df,w//df*df - if h!=h_new or w!=w_new: - img_new=transforms.Resize([h_new,w_new]).forward(image) + train_res_h, train_res_w = train_res[0], train_res[1] + data["pos_scale0"], data["pos_scale1"] = [ + train_res_h / data["image0"].shape[2], + train_res_w / data["image0"].shape[3], + ], [ + train_res_h / data["image1"].shape[2], + train_res_w / data["image1"].shape[3], + ] + data["online_resize_scale0"], data["online_resize_scale1"] = ( + torch.tensor([w0 / data["image0"].shape[3], h0 / data["image0"].shape[2]])[ + None + ].cuda(), + torch.tensor([w1 / data["image1"].shape[3], h1 / data["image1"].shape[2]])[ + None + ].cuda(), + ) + + def resize_df(self, image, df=32): + h, w = image.shape[2], image.shape[3] + h_new, w_new = h // df * df, w // df * df + if h != h_new or w != w_new: + img_new = transforms.Resize([h_new, w_new]).forward(image) else: - img_new=image + img_new = image return img_new diff --git 
a/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py b/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py index b6e731b3f53ab367c89ef0ea8e1cbffb0d990775..ae8593230b281e960ece68c04dcf214769e50f08 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py +++ b/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py @@ -2,10 +2,12 @@ from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4 def build_backbone(config): - if config['backbone_type'] == 'ResNetFPN': - if config['resolution'] == (8, 2): - return ResNetFPN_8_2(config['resnetfpn']) - elif config['resolution'] == (16, 4): - return ResNetFPN_16_4(config['resnetfpn']) + if config["backbone_type"] == "ResNetFPN": + if config["resolution"] == (8, 2): + return ResNetFPN_8_2(config["resnetfpn"]) + elif config["resolution"] == (16, 4): + return ResNetFPN_16_4(config["resnetfpn"]) else: - raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") + raise ValueError( + f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported." + ) diff --git a/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py b/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py index 985e5b3f273a51e51447a8025ca3aadbe46752eb..948c72940ab00e5741e2788eea841d124333c8ed 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py +++ b/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py @@ -4,12 +4,16 @@ import torch.nn.functional as F def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution without padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False + ) def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) class BasicBlock(nn.Module): @@ -25,8 +29,7 @@ class BasicBlock(nn.Module): self.downsample = None else: self.downsample = nn.Sequential( - conv1x1(in_planes, planes, stride=stride), - nn.BatchNorm2d(planes) + conv1x1(in_planes, planes, stride=stride), nn.BatchNorm2d(planes) ) def forward(self, x): @@ -37,7 +40,7 @@ class BasicBlock(nn.Module): if self.downsample is not None: x = self.downsample(x) - return self.relu(x+y) + return self.relu(x + y) class ResNetFPN_8_2(nn.Module): @@ -50,14 +53,16 @@ class ResNetFPN_8_2(nn.Module): super().__init__() # Config block = BasicBlock - initial_dim = config['initial_dim'] - block_dims = config['block_dims'] + initial_dim = config["initial_dim"] + block_dims = config["block_dims"] # Class Variable self.in_planes = initial_dim # Networks - self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.conv1 = nn.Conv2d( + 1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False + ) self.bn1 = nn.BatchNorm2d(initial_dim) self.relu = nn.ReLU(inplace=True) @@ -84,7 +89,7 @@ class ResNetFPN_8_2(nn.Module): for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -107,13 +112,17 @@ class ResNetFPN_8_2(nn.Module): # FPN x3_out = self.layer3_outconv(x3) - x3_out_2x = 
F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out_2x = F.interpolate( + x3_out, scale_factor=2.0, mode="bilinear", align_corners=True + ) x2_out = self.layer2_outconv(x2) - x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + x2_out = self.layer2_outconv2(x2_out + x3_out_2x) - x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out_2x = F.interpolate( + x2_out, scale_factor=2.0, mode="bilinear", align_corners=True + ) x1_out = self.layer1_outconv(x1) - x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + x1_out = self.layer1_outconv2(x1_out + x2_out_2x) return [x3_out, x1_out] @@ -128,14 +137,16 @@ class ResNetFPN_16_4(nn.Module): super().__init__() # Config block = BasicBlock - initial_dim = config['initial_dim'] - block_dims = config['block_dims'] + initial_dim = config["initial_dim"] + block_dims = config["block_dims"] # Class Variable self.in_planes = initial_dim # Networks - self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.conv1 = nn.Conv2d( + 1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False + ) self.bn1 = nn.BatchNorm2d(initial_dim) self.relu = nn.ReLU(inplace=True) @@ -164,7 +175,7 @@ class ResNetFPN_16_4(nn.Module): for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -188,12 +199,16 @@ class ResNetFPN_16_4(nn.Module): # FPN x4_out = self.layer4_outconv(x4) - x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) + x4_out_2x = F.interpolate( + x4_out, scale_factor=2.0, mode="bilinear", align_corners=True + ) x3_out = self.layer3_outconv(x3) - x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + x3_out = self.layer3_outconv2(x3_out + x4_out_2x) - x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out_2x = F.interpolate( + x3_out, scale_factor=2.0, mode="bilinear", align_corners=True + ) x2_out = self.layer2_outconv(x2) - x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + x2_out = self.layer2_outconv2(x2_out + x3_out_2x) return [x4_out, x2_out] diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py b/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py index 953ee55a09144a4ce0099e709f3a992d021aa0ab..c506479a978c3ebb20c6736ed30f0ef0a351d4b9 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py +++ b/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py @@ -7,8 +7,9 @@ from time import time INF = 1e9 + def mask_border(m, b: int, v): - """ Mask borders with value + """Mask borders with value Args: m (torch.Tensor): [N, H0, W0, H1, W1] b (int) @@ -39,22 +40,21 @@ def mask_border_with_padding(m, bd, v, p_m0, p_m1): h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): - m[b_idx, h0 - bd:] = v - m[b_idx, :, w0 - bd:] = v - m[b_idx, :, :, h1 - bd:] = v - m[b_idx, :, :, :, w1 - bd:] = v + m[b_idx, h0 - bd :] = v + m[b_idx, :, w0 - bd :] = v + m[b_idx, :, :, h1 - bd :] = v + m[b_idx, :, :, :, w1 - bd :] = v def compute_max_candidates(p_m0, p_m1): """Compute the max candidates of all pairs within a batch - + Args: p_m0, 
p_m1 (torch.Tensor): padded masks """ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] - max_cand = torch.sum( - torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + max_cand = torch.sum(torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) return max_cand @@ -63,29 +63,32 @@ class CoarseMatching(nn.Module): super().__init__() self.config = config # general config - self.thr = config['thr'] - self.border_rm = config['border_rm'] + self.thr = config["thr"] + self.border_rm = config["border_rm"] # -- # for trainig fine-level LoFTR - self.train_coarse_percent = config['train_coarse_percent'] - self.train_pad_num_gt_min = config['train_pad_num_gt_min'] - + self.train_coarse_percent = config["train_coarse_percent"] + self.train_pad_num_gt_min = config["train_pad_num_gt_min"] + # we provide 2 options for differentiable matching - self.match_type = config['match_type'] - if self.match_type == 'dual_softmax': - self.temperature=nn.parameter.Parameter(torch.tensor(10.), requires_grad=True) - elif self.match_type == 'sinkhorn': + self.match_type = config["match_type"] + if self.match_type == "dual_softmax": + self.temperature = nn.parameter.Parameter( + torch.tensor(10.0), requires_grad=True + ) + elif self.match_type == "sinkhorn": try: from .superglue import log_optimal_transport except ImportError: raise ImportError("download superglue.py first!") self.log_optimal_transport = log_optimal_transport self.bin_score = nn.Parameter( - torch.tensor(config['skh_init_bin_score'], requires_grad=True)) - self.skh_iters = config['skh_iters'] - self.skh_prefilter = config['skh_prefilter'] + torch.tensor(config["skh_init_bin_score"], requires_grad=True) + ) + self.skh_iters = config["skh_iters"] + self.skh_prefilter = config["skh_prefilter"] else: raise NotImplementedError() - + def forward(self, feat_c0, feat_c1, flow_list, data, mask_c0=None, mask_c1=None): """ Args: @@ -108,29 +111,32 @@ class CoarseMatching(nn.Module): """ N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) # normalize - feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, - [feat_c0, feat_c1]) - - if self.match_type == 'dual_softmax': - sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, - feat_c1) * self.temperature + feat_c0, feat_c1 = map( + lambda feat: feat / feat.shape[-1] ** 0.5, [feat_c0, feat_c1] + ) + + if self.match_type == "dual_softmax": + sim_matrix = ( + torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) * self.temperature + ) if mask_c0 is not None: sim_matrix.masked_fill_( - ~(mask_c0[..., None] * mask_c1[:, None]).bool(), - -INF) + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF + ) conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) - - elif self.match_type == 'sinkhorn': + + elif self.match_type == "sinkhorn": # sinkhorn, dustbin included sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) if mask_c0 is not None: sim_matrix[:, :L, :S].masked_fill_( - ~(mask_c0[..., None] * mask_c1[:, None]).bool(), - -INF) + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF + ) # build uniform prior & use sinkhorn log_assign_matrix = self.log_optimal_transport( - sim_matrix, self.bin_score, self.skh_iters) + sim_matrix, self.bin_score, self.skh_iters + ) assign_matrix = log_assign_matrix.exp() conf_matrix = assign_matrix[:, :-1, :-1] @@ -141,18 +147,21 @@ class CoarseMatching(nn.Module): conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0 conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0 - 
if self.config['sparse_spvs']: - data.update({'conf_matrix_with_bin': assign_matrix.clone()}) + if self.config["sparse_spvs"]: + data.update({"conf_matrix_with_bin": assign_matrix.clone()}) - data.update({'conf_matrix': conf_matrix}) + data.update({"conf_matrix": conf_matrix}) # predict coarse matches from conf_matrix data.update(**self.get_coarse_match(conf_matrix, data)) - #update predicted offset - if flow_list[0].shape[2]==flow_list[1].shape[2] and flow_list[0].shape[3]==flow_list[1].shape[3]: - flow_list=torch.stack(flow_list,dim=0) - data.update({'predict_flow':flow_list}) #[2*L*B*H*W*4] - self.get_offset_match(flow_list,data,mask_c0,mask_c1) + # update predicted offset + if ( + flow_list[0].shape[2] == flow_list[1].shape[2] + and flow_list[0].shape[3] == flow_list[1].shape[3] + ): + flow_list = torch.stack(flow_list, dim=0) + data.update({"predict_flow": flow_list}) # [2*L*B*H*W*4] + self.get_offset_match(flow_list, data, mask_c0, mask_c1) @torch.no_grad() def get_coarse_match(self, conf_matrix, data): @@ -172,28 +181,33 @@ class CoarseMatching(nn.Module): 'mconf' (torch.Tensor): [M]} """ axes_lengths = { - 'h0c': data['hw0_c'][0], - 'w0c': data['hw0_c'][1], - 'h1c': data['hw1_c'][0], - 'w1c': data['hw1_c'][1] + "h0c": data["hw0_c"][0], + "w0c": data["hw0_c"][1], + "h1c": data["hw1_c"][0], + "w1c": data["hw1_c"][1], } _device = conf_matrix.device # 1. confidence thresholding mask = conf_matrix > self.thr - mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', - **axes_lengths) - if 'mask0' not in data: + mask = rearrange( + mask, "b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c", **axes_lengths + ) + if "mask0" not in data: mask_border(mask, self.border_rm, False) else: - mask_border_with_padding(mask, self.border_rm, False, - data['mask0'], data['mask1']) - mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', - **axes_lengths) + mask_border_with_padding( + mask, self.border_rm, False, data["mask0"], data["mask1"] + ) + mask = rearrange( + mask, "b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)", **axes_lengths + ) # 2. mutual nearest - mask = mask \ - * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + mask = ( + mask + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + ) # 3. find all valid coarse matches # this only works when at most one `True` in each row @@ -208,67 +222,79 @@ class CoarseMatching(nn.Module): # NOTE: # The sampling is performed across all pairs in a batch without manually balancing # #samples for fine-level increases w.r.t. 
batch_size - if 'mask0' not in data: - num_candidates_max = mask.size(0) * max( - mask.size(1), mask.size(2)) + if "mask0" not in data: + num_candidates_max = mask.size(0) * max(mask.size(1), mask.size(2)) else: num_candidates_max = compute_max_candidates( - data['mask0'], data['mask1']) - num_matches_train = int(num_candidates_max * - self.train_coarse_percent) + data["mask0"], data["mask1"] + ) + num_matches_train = int(num_candidates_max * self.train_coarse_percent) num_matches_pred = len(b_ids) - assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" - + assert ( + self.train_pad_num_gt_min < num_matches_train + ), "min-num-gt-pad should be less than num-train-matches" + # pred_indices is to select from prediction if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: pred_indices = torch.arange(num_matches_pred, device=_device) else: pred_indices = torch.randint( num_matches_pred, - (num_matches_train - self.train_pad_num_gt_min, ), - device=_device) + (num_matches_train - self.train_pad_num_gt_min,), + device=_device, + ) # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) gt_pad_indices = torch.randint( - len(data['spv_b_ids']), - (max(num_matches_train - num_matches_pred, - self.train_pad_num_gt_min), ), - device=_device) - mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + len(data["spv_b_ids"]), + (max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min),), + device=_device, + ) + mconf_gt = torch.zeros( + len(data["spv_b_ids"]), device=_device + ) # set conf of gt paddings to all zero b_ids, i_ids, j_ids, mconf = map( - lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], - dim=0), - *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], - [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], dim=0), + *zip( + [b_ids, data["spv_b_ids"]], + [i_ids, data["spv_i_ids"]], + [j_ids, data["spv_j_ids"]], + [mconf, mconf_gt], + ) + ) # These matches select patches that feed into fine-level network - coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + coarse_matches = {"b_ids": b_ids, "i_ids": i_ids, "j_ids": j_ids} # 4. 
Update with matches in original image resolution - scale = data['hw0_i'][0] / data['hw0_c'][0] - scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale - scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale - mkpts0_c = torch.stack( - [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], - dim=1) * scale0 - mkpts1_c = torch.stack( - [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], - dim=1) * scale1 + scale = data["hw0_i"][0] / data["hw0_c"][0] + scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale + scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale + mkpts0_c = ( + torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1) + * scale0 + ) + mkpts1_c = ( + torch.stack([j_ids % data["hw1_c"][1], j_ids // data["hw1_c"][1]], dim=1) + * scale1 + ) # These matches is the current prediction (for visualization) - coarse_matches.update({ - 'gt_mask': mconf == 0, - 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches - 'mkpts0_c': mkpts0_c[mconf != 0], - 'mkpts1_c': mkpts1_c[mconf != 0], - 'mconf': mconf[mconf != 0] - }) + coarse_matches.update( + { + "gt_mask": mconf == 0, + "m_bids": b_ids[mconf != 0], # mconf == 0 => gt matches + "mkpts0_c": mkpts0_c[mconf != 0], + "mkpts1_c": mkpts1_c[mconf != 0], + "mconf": mconf[mconf != 0], + } + ) return coarse_matches @torch.no_grad() - def get_offset_match(self, flow_list, data,mask1,mask2): + def get_offset_match(self, flow_list, data, mask1, mask2): """ Args: offset (torch.Tensor): [L, B, H, W, 2] @@ -280,52 +306,62 @@ class CoarseMatching(nn.Module): 'mkpts1_c' (torch.Tensor): [M, 2], 'mconf' (torch.Tensor): [M]} """ - offset1=flow_list[0] - bs,layer_num=offset1.shape[1],offset1.shape[0] - - #left side - offset1=offset1.view(layer_num,bs,-1,4) - conf1=offset1[:,:,:,2:].mean(dim=-1) + offset1 = flow_list[0] + bs, layer_num = offset1.shape[1], offset1.shape[0] + + # left side + offset1 = offset1.view(layer_num, bs, -1, 4) + conf1 = offset1[:, :, :, 2:].mean(dim=-1) if mask1 is not None: - conf1.masked_fill_(~mask1.bool()[None].expand(layer_num,-1,-1),100) - offset1=offset1[:,:,:,:2] - self.get_offset_match_work(offset1,conf1,data,'left') - - #rihgt side - if len(flow_list)==2: - offset2=flow_list[1].view(layer_num,bs,-1,4) - conf2=offset2[:,:,:,2:].mean(dim=-1) + conf1.masked_fill_(~mask1.bool()[None].expand(layer_num, -1, -1), 100) + offset1 = offset1[:, :, :, :2] + self.get_offset_match_work(offset1, conf1, data, "left") + + # rihgt side + if len(flow_list) == 2: + offset2 = flow_list[1].view(layer_num, bs, -1, 4) + conf2 = offset2[:, :, :, 2:].mean(dim=-1) if mask2 is not None: - conf2.masked_fill_(~mask2.bool()[None].expand(layer_num,-1,-1),100) - offset2=offset2[:,:,:,:2] - self.get_offset_match_work(offset2,conf2,data,'right') - + conf2.masked_fill_(~mask2.bool()[None].expand(layer_num, -1, -1), 100) + offset2 = offset2[:, :, :, :2] + self.get_offset_match_work(offset2, conf2, data, "right") @torch.no_grad() - def get_offset_match_work(self, offset,conf, data,side): - bs,layer_num=offset.shape[1],offset.shape[0] + def get_offset_match_work(self, offset, conf, data, side): + bs, layer_num = offset.shape[1], offset.shape[0] # 1. confidence thresholding - mask_conf= conf<2 + mask_conf = conf < 2 for index in range(bs): - mask_conf[:,index,0]=True #safe guard in case that no match survives + mask_conf[:, index, 0] = True # safe guard in case that no match survives # 3. 
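# ---------------------------------------------------------------------------
# Illustrative sketch (standalone): recovering (x, y) pixel coordinates from the
# flattened coarse-grid indices, mirroring the `i_ids % w, i_ids // w` pattern
# used above. The grid width and the 8x coarse-to-image scale are assumptions.
import torch

def indices_to_keypoints(i_ids: torch.Tensor, grid_w: int, scale: float) -> torch.Tensor:
    x = (i_ids % grid_w).float() * scale   # column index -> x in image pixels
    y = (i_ids // grid_w).float() * scale  # row index    -> y in image pixels
    return torch.stack([x, y], dim=1)      # [M, 2]

# index 5 on a 4-wide grid is cell (x=1, y=1) -> (8.0, 8.0) at image resolution
print(indices_to_keypoints(torch.tensor([0, 5, 13]), grid_w=4, scale=8.0))
# ---------------------------------------------------------------------------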
find offset matches - scale = data['hw0_i'][0] / data['hw0_c'][0] - l_ids,b_ids,i_ids = torch.where(mask_conf) - j_coor=offset[l_ids,b_ids,i_ids,:2] *scale#[N,2] - i_coor=torch.stack([i_ids%data['hw0_c'][1],i_ids//data['hw0_c'][1]],dim=1)*scale - #i_coor=torch.as_tensor([[index%data['hw0_c'][1],index//data['hw0_c'][1]] for index in i_ids]).cuda().float()*scale #[N,2] + scale = data["hw0_i"][0] / data["hw0_c"][0] + l_ids, b_ids, i_ids = torch.where(mask_conf) + j_coor = offset[l_ids, b_ids, i_ids, :2] * scale # [N,2] + i_coor = ( + torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1) + * scale + ) + # i_coor=torch.as_tensor([[index%data['hw0_c'][1],index//data['hw0_c'][1]] for index in i_ids]).cuda().float()*scale #[N,2] # These matches is the current prediction (for visualization) - data.update({ - 'offset_bids_'+side: b_ids, # mconf == 0 => gt matches - 'offset_lids_'+side: l_ids, - 'conf'+side: conf[mask_conf] - }) - - if side=='right': - data.update({'offset_kpts0_f_'+side: j_coor.detach(), - 'offset_kpts1_f_'+side: i_coor}) + data.update( + { + "offset_bids_" + side: b_ids, # mconf == 0 => gt matches + "offset_lids_" + side: l_ids, + "conf" + side: conf[mask_conf], + } + ) + + if side == "right": + data.update( + { + "offset_kpts0_f_" + side: j_coor.detach(), + "offset_kpts1_f_" + side: i_coor, + } + ) else: - data.update({'offset_kpts0_f_'+side: i_coor, - 'offset_kpts1_f_'+side: j_coor.detach()}) - - + data.update( + { + "offset_kpts0_f_" + side: i_coor, + "offset_kpts1_f_" + side: j_coor.detach(), + } + ) diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py b/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py index fdc57e84936c805cb387b6239ca4a5ff6154e22e..1ffe9c067b1fb95a75dd102c5947c82d03dbea89 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py +++ b/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py @@ -8,7 +8,7 @@ def lower_config(yacs_cfg): _CN = CN() -_CN.BACKBONE_TYPE = 'ResNetFPN' +_CN.BACKBONE_TYPE = "ResNetFPN" _CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] _CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd _CN.FINE_CONCAT_COARSE_FEAT = True @@ -23,15 +23,15 @@ _CN.COARSE = CN() _CN.COARSE.D_MODEL = 256 _CN.COARSE.D_FFN = 256 _CN.COARSE.NHEAD = 8 -_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 -_CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] +_CN.COARSE.LAYER_NAMES = ["self", "cross"] * 4 +_CN.COARSE.ATTENTION = "linear" # options: ['linear', 'full'] _CN.COARSE.TEMP_BUG_FIX = False # 3. 
Coarse-Matching config _CN.MATCH_COARSE = CN() _CN.MATCH_COARSE.THR = 0.1 _CN.MATCH_COARSE.BORDER_RM = 2 -_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] +_CN.MATCH_COARSE.MATCH_TYPE = "dual_softmax" # options: ['dual_softmax, 'sinkhorn'] _CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 _CN.MATCH_COARSE.SKH_ITERS = 3 _CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 @@ -44,7 +44,7 @@ _CN.FINE = CN() _CN.FINE.D_MODEL = 128 _CN.FINE.D_FFN = 128 _CN.FINE.NHEAD = 8 -_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 -_CN.FINE.ATTENTION = 'linear' +_CN.FINE.LAYER_NAMES = ["self", "cross"] * 1 +_CN.FINE.ATTENTION = "linear" default_cfg = lower_config(_CN) diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py b/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py index 6e77aded52e1eb5c01e22c2738104f3b09d6922a..3f41b1db96016efb58888381284f86d448839ff0 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py +++ b/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py @@ -26,35 +26,46 @@ class FineMatching(nn.Module): """ M, WW, C = feat_f0.shape W = int(math.sqrt(WW)) - scale = data['hw0_i'][0] / data['hw0_f'][0] + scale = data["hw0_i"][0] / data["hw0_f"][0] self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale # corner case: if no coarse matches found if M == 0: - assert self.training == False, "M is always >0, when training, see coarse_matching.py" + assert ( + self.training == False + ), "M is always >0, when training, see coarse_matching.py" # logger.warning('No matches found in coarse-level.') - data.update({ - 'expec_f': torch.empty(0, 3, device=feat_f0.device), - 'mkpts0_f': data['mkpts0_c'], - 'mkpts1_f': data['mkpts1_c'], - }) + data.update( + { + "expec_f": torch.empty(0, 3, device=feat_f0.device), + "mkpts0_f": data["mkpts0_c"], + "mkpts1_f": data["mkpts1_c"], + } + ) return - feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] - sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) - softmax_temp = 1. 
/ C**.5 + feat_f0_picked = feat_f0_picked = feat_f0[:, WW // 2, :] + sim_matrix = torch.einsum("mc,mrc->mr", feat_f0_picked, feat_f1) + softmax_temp = 1.0 / C**0.5 heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) # compute coordinates from heatmap coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] - grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape( + 1, -1, 2 + ) # [1, WW, 2] # compute std over - var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] - std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability - + var = ( + torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) + - coords_normalized**2 + ) # [M, 2] + std = torch.sum( + torch.sqrt(torch.clamp(var, min=1e-10)), -1 + ) # [M] clamp needed for numerical stability + # for fine-level supervision - data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) + data.update({"expec_f": torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) # compute absolute kpt coords self.get_fine_match(coords_normalized, data) @@ -64,11 +75,10 @@ class FineMatching(nn.Module): W, WW, C, scale = self.W, self.WW, self.C, self.scale # mkpts0_f and mkpts1_f - mkpts0_f = data['mkpts0_c'] - scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale - mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] + mkpts0_f = data["mkpts0_c"] + scale1 = scale * data["scale1"][data["b_ids"]] if "scale0" in data else scale + mkpts1_f = ( + data["mkpts1_c"] + (coords_normed * (W // 2) * scale1)[: len(data["mconf"])] + ) - data.update({ - "mkpts0_f": mkpts0_f, - "mkpts1_f": mkpts1_f - }) + data.update({"mkpts0_f": mkpts0_f, "mkpts1_f": mkpts1_f}) diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py b/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py index f95cdb65b48324c4f4ceb20231b1bed992b41116..6101f738f2b2b7ee014fcb53a4032391939ed8cd 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py +++ b/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py @@ -3,10 +3,10 @@ import torch @torch.no_grad() def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): - """ Warp kpts0 from I0 to I1 with depth, K and Rt + """Warp kpts0 from I0 to I1 with depth, K and Rt Also check covisibility and depth consistency. Depth is consistent if relative error < 0.2 (hard-coded). 
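# ---------------------------------------------------------------------------
# Illustrative sketch (standalone): the soft-argmax ("spatial expectation") and
# std computation used above, written in plain torch instead of kornia's
# dsnt/create_meshgrid helpers. W, the toy heatmap and variable names are
# assumptions; coordinates are normalized to [-1, 1] as in the code above.
import torch

def soft_argmax_2d(heatmap: torch.Tensor):
    """heatmap: [M, W, W], each map non-negative and summing to 1."""
    M, W, _ = heatmap.shape
    coords_1d = torch.linspace(-1.0, 1.0, W)
    grid_y, grid_x = torch.meshgrid(coords_1d, coords_1d, indexing="ij")
    grid = torch.stack([grid_x, grid_y], dim=-1).reshape(1, W * W, 2)   # [1, WW, 2]
    prob = heatmap.reshape(M, W * W, 1)
    expectation = (grid * prob).sum(dim=1)                   # [M, 2] expected (x, y)
    var = (grid ** 2 * prob).sum(dim=1) - expectation ** 2   # per-axis variance
    std = torch.sqrt(var.clamp(min=1e-10)).sum(dim=-1)       # [M], clamp for stability
    return expectation, std

# usage: a delta heatmap at the centre gives expectation ~(0, 0) and a tiny std
hm = torch.zeros(1, 5, 5)
hm[0, 2, 2] = 1.0
print(soft_argmax_2d(hm))
# ---------------------------------------------------------------------------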
- + Args: kpts0 (torch.Tensor): [N, L, 2] - , depth0 (torch.Tensor): [N, H, W], @@ -22,33 +22,52 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): # Sample depth, get calculable_mask on depth != 0 kpts0_depth = torch.stack( - [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + [ + depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] + for i in range(kpts0.shape[0]) + ], + dim=0, ) # (N, L) nonzero_mask = kpts0_depth != 0 # Unproject - kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) # Rigid Transform - w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] # Project w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) - w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth # Covisible Check h, w = depth1.shape[1:3] - covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ - (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) w_kpts0_long = w_kpts0.long() w_kpts0_long[~covisible_mask, :] = 0 w_kpts0_depth = torch.stack( - [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + [ + depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] + for i in range(w_kpts0_long.shape[0]) + ], + dim=0, ) # (N, L) - consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + consistent_mask = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() < 0.2 valid_mask = nonzero_mask * covisible_mask * consistent_mask return valid_mask, w_kpts0 diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py b/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py index 07d384ae18370acb99ef00a788f628c967249ace..1da77ecef628e3e263b56fb501b6a6313f05c060 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py +++ b/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py @@ -8,7 +8,7 @@ class PositionEncodingSine(nn.Module): This is a sinusoidal position encoding that generalized to 2-dimensional images """ - def __init__(self, d_model, max_shape=(256, 256),pre_scaling=None): + def __init__(self, d_model, max_shape=(256, 256), pre_scaling=None): """ Args: max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels @@ -18,44 +18,63 @@ class PositionEncodingSine(nn.Module): We will remove the buggy impl after re-training all variants of our released models. 
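# ---------------------------------------------------------------------------
# Illustrative sketch (standalone): the per-point geometry behind warp_kpts above,
# i.e. back-project with depth, apply the relative pose, re-project, for a single
# point without batching or the covisibility/consistency checks. The intrinsics,
# depth and pose values below are made up.
import torch

def warp_point(pt0_xy, depth0, K0, K1, T_0to1):
    """pt0_xy: (x, y) pixel in image 0; depth0: scalar; K*: (3,3); T_0to1: (4,4)."""
    pt_h = torch.tensor([pt0_xy[0], pt0_xy[1], 1.0]) * depth0   # homogeneous pixel * depth
    pt_cam0 = torch.linalg.inv(K0) @ pt_h                       # back-project to camera 0
    pt_cam1 = T_0to1[:3, :3] @ pt_cam0 + T_0to1[:3, 3]          # rigid transform 0 -> 1
    pt1_h = K1 @ pt_cam1                                        # project into image 1
    return pt1_h[:2] / (pt1_h[2] + 1e-4)                        # +1e-4 avoids zero depth

K = torch.tensor([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
T = torch.eye(4)
T[0, 3] = 0.1                                                   # 10 cm sideways motion
print(warp_point((320.0, 240.0), depth0=2.0, K0=K, K1=K, T_0to1=T))  # ~ (345, 240)
# ---------------------------------------------------------------------------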
""" super().__init__() - self.d_model=d_model - self.max_shape=max_shape - self.pre_scaling=pre_scaling + self.d_model = d_model + self.max_shape = max_shape + self.pre_scaling = pre_scaling pe = torch.zeros((d_model, *max_shape)) y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) if pre_scaling[0] is not None and pre_scaling[1] is not None: - train_res,test_res=pre_scaling[0],pre_scaling[1] - x_position,y_position=x_position*train_res[1]/test_res[1],y_position*train_res[0]/test_res[0] + train_res, test_res = pre_scaling[0], pre_scaling[1] + x_position, y_position = ( + x_position * train_res[1] / test_res[1], + y_position * train_res[0] / test_res[0], + ) - div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + div_term = torch.exp( + torch.arange(0, d_model // 2, 2).float() + * (-math.log(10000.0) / (d_model // 2)) + ) div_term = div_term[:, None, None] # [C//4, 1, 1] pe[0::4, :, :] = torch.sin(x_position * div_term) pe[1::4, :, :] = torch.cos(x_position * div_term) pe[2::4, :, :] = torch.sin(y_position * div_term) pe[3::4, :, :] = torch.cos(y_position * div_term) - self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W] - def forward(self, x,scaling=None): + def forward(self, x, scaling=None): """ Args: x: [N, C, H, W] """ - if scaling is None: #onliner scaling overwrites pre_scaling - return x + self.pe[:, :, :x.size(2), :x.size(3)],self.pe[:, :, :x.size(2), :x.size(3)] + if scaling is None: # onliner scaling overwrites pre_scaling + return ( + x + self.pe[:, :, : x.size(2), : x.size(3)], + self.pe[:, :, : x.size(2), : x.size(3)], + ) else: pe = torch.zeros((self.d_model, *self.max_shape)) - y_position = torch.ones(self.max_shape).cumsum(0).float().unsqueeze(0)*scaling[0] - x_position = torch.ones(self.max_shape).cumsum(1).float().unsqueeze(0)*scaling[1] - - div_term = torch.exp(torch.arange(0, self.d_model//2, 2).float() * (-math.log(10000.0) / (self.d_model//2))) + y_position = ( + torch.ones(self.max_shape).cumsum(0).float().unsqueeze(0) * scaling[0] + ) + x_position = ( + torch.ones(self.max_shape).cumsum(1).float().unsqueeze(0) * scaling[1] + ) + + div_term = torch.exp( + torch.arange(0, self.d_model // 2, 2).float() + * (-math.log(10000.0) / (self.d_model // 2)) + ) div_term = div_term[:, None, None] # [C//4, 1, 1] pe[0::4, :, :] = torch.sin(x_position * div_term) pe[1::4, :, :] = torch.cos(x_position * div_term) pe[2::4, :, :] = torch.sin(y_position * div_term) pe[3::4, :, :] = torch.cos(y_position * div_term) - pe=pe.unsqueeze(0).to(x.device) - return x + pe[:, :, :x.size(2), :x.size(3)],pe[:, :, :x.size(2), :x.size(3)] \ No newline at end of file + pe = pe.unsqueeze(0).to(x.device) + return ( + x + pe[:, :, : x.size(2), : x.size(3)], + pe[:, :, : x.size(2), : x.size(3)], + ) diff --git a/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py b/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py index 5cef3a7968413136f6dc9f52b6a1ec87192b006b..16c468d8ee1425be0d4518477263f377bd09873a 100644 --- a/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py +++ b/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py @@ -13,7 +13,7 @@ from .geometry import warp_kpts @torch.no_grad() def mask_pts_at_padded_regions(grid_pt, mask): """For megadepth dataset, zero-padding exists in images""" - mask = repeat(mask, 'n h w -> n (h w) c', c=2) + 
mask = repeat(mask, "n h w -> n (h w) c", c=2) grid_pt[~mask.bool()] = 0 return grid_pt @@ -30,37 +30,55 @@ def spvs_coarse(data, config): 'spv_w_pt0_i': [N, hw0, 2], in original image resolution 'spv_pt1_i': [N, hw1, 2], in original image resolution } - + NOTE: - for scannet dataset, there're 3 kinds of resolution {i, c, f} - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} """ # 1. misc - device = data['image0'].device - N, _, H0, W0 = data['image0'].shape - _, _, H1, W1 = data['image1'].shape - scale = config['ASPAN']['RESOLUTION'][0] - scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale - scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale + device = data["image0"].device + N, _, H0, W0 = data["image0"].shape + _, _, H1, W1 = data["image1"].shape + scale = config["ASPAN"]["RESOLUTION"][0] + scale0 = scale * data["scale0"][:, None] if "scale0" in data else scale + scale1 = scale * data["scale1"][:, None] if "scale0" in data else scale h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) # 2. warp grids # create kpts in meshgrid and resize them to image resolution - grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_c = ( + create_meshgrid(h0, w0, False, device).reshape(1, h0 * w0, 2).repeat(N, 1, 1) + ) # [N, hw, 2] grid_pt0_i = scale0 * grid_pt0_c - grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_c = ( + create_meshgrid(h1, w1, False, device).reshape(1, h1 * w1, 2).repeat(N, 1, 1) + ) grid_pt1_i = scale1 * grid_pt1_c # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt - if 'mask0' in data: - grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) - grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) + if "mask0" in data: + grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data["mask0"]) + grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data["mask1"]) # warp kpts bi-directionally and resize them to coarse-level resolution # (no depth consistency check, since it leads to worse results experimentally) # (unhandled edge case: points with 0-depth will be warped to the left-up corner) - _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) - _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) + _, w_pt0_i = warp_kpts( + grid_pt0_i, + data["depth0"], + data["depth1"], + data["T_0to1"], + data["K0"], + data["K1"], + ) + _, w_pt1_i = warp_kpts( + grid_pt1_i, + data["depth1"], + data["depth0"], + data["T_1to0"], + data["K1"], + data["K0"], + ) w_pt0_c = w_pt0_i / scale1 w_pt1_c = w_pt1_i / scale0 @@ -72,21 +90,26 @@ def spvs_coarse(data, config): # corner case: out of boundary def out_bound_mask(pt, w, h): - return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + return ( + (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + ) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 - loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) - correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1) + loop_back = torch.stack( + [nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0 + ) + correct_0to1 = loop_back == torch.arange(h0 * w0, 
device=device)[None].repeat(N, 1) correct_0to1[:, 0] = False # ignore the top-left corner # 4. construct a gt conf_matrix - conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) + conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device) b_ids, i_ids = torch.where(correct_0to1 != 0) j_ids = nearest_index1[b_ids, i_ids] conf_matrix_gt[b_ids, i_ids, j_ids] = 1 - data.update({'conf_matrix_gt': conf_matrix_gt}) + data.update({"conf_matrix_gt": conf_matrix_gt}) # 5. save coarse matches(gt) for training fine level if len(b_ids) == 0: @@ -96,30 +119,26 @@ def spvs_coarse(data, config): i_ids = torch.tensor([0], device=device) j_ids = torch.tensor([0], device=device) - data.update({ - 'spv_b_ids': b_ids, - 'spv_i_ids': i_ids, - 'spv_j_ids': j_ids - }) + data.update({"spv_b_ids": b_ids, "spv_i_ids": i_ids, "spv_j_ids": j_ids}) # 6. save intermediate results (for fast fine-level computation) - data.update({ - 'spv_w_pt0_i': w_pt0_i, - 'spv_pt1_i': grid_pt1_i - }) + data.update({"spv_w_pt0_i": w_pt0_i, "spv_pt1_i": grid_pt1_i}) def compute_supervision_coarse(data, config): - assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" - data_source = data['dataset_name'][0] - if data_source.lower() in ['scannet', 'megadepth']: + assert ( + len(set(data["dataset_name"])) == 1 + ), "Do not support mixed datasets training!" + data_source = data["dataset_name"][0] + if data_source.lower() in ["scannet", "megadepth"]: spvs_coarse(data, config) else: - raise ValueError(f'Unknown data source: {data_source}') + raise ValueError(f"Unknown data source: {data_source}") ############## ↓ Fine-Level supervision ↓ ############## + @torch.no_grad() def spvs_fine(data, config): """ @@ -129,23 +148,25 @@ def spvs_fine(data, config): """ # 1. misc # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i') - w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i'] - scale = config['ASPAN']['RESOLUTION'][1] - radius = config['ASPAN']['FINE_WINDOW_SIZE'] // 2 + w_pt0_i, pt1_i = data["spv_w_pt0_i"], data["spv_pt1_i"] + scale = config["ASPAN"]["RESOLUTION"][1] + radius = config["ASPAN"]["FINE_WINDOW_SIZE"] // 2 # 2. get coarse prediction - b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + b_ids, i_ids, j_ids = data["b_ids"], data["i_ids"], data["j_ids"] # 3. compute gt - scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale + scale = scale * data["scale1"][b_ids] if "scale0" in data else scale # `expec_f_gt` might exceed the window, i.e. 
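# ---------------------------------------------------------------------------
# Illustrative sketch (standalone): the round-trip ("loop back") consistency check
# used above to build conf_matrix_gt, reduced to one image pair. In the real code
# the nearest-cell indices come from depth/pose warping; here they are toy arrays.
import torch

h0w0, h1w1 = 6, 6
nearest_1 = torch.tensor([1, 2, 0, 4, 5, 3])   # image0 cell -> nearest image1 cell
nearest_0 = torch.tensor([2, 0, 1, 5, 3, 3])   # image1 cell -> nearest image0 cell

loop_back = nearest_0[nearest_1]                  # image0 -> image1 -> image0
correct = loop_back == torch.arange(h0w0)         # a cell survives the round trip
conf_gt = torch.zeros(h0w0, h1w1)
i_ids = torch.where(correct)[0]
conf_gt[i_ids, nearest_1[i_ids]] = 1              # one-hot ground-truth matches
print(correct.tolist(), conf_gt.nonzero().tolist())
# ---------------------------------------------------------------------------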
abs(*) > 1, which would be filtered later - expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2] + expec_f_gt = ( + (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius + ) # [M, 2] data.update({"expec_f_gt": expec_f_gt}) def compute_supervision_fine(data, config): - data_source = data['dataset_name'][0] - if data_source.lower() in ['scannet', 'megadepth']: + data_source = data["dataset_name"][0] + if data_source.lower() in ["scannet", "megadepth"]: spvs_fine(data, config) else: raise NotImplementedError diff --git a/third_party/ASpanFormer/src/config/default.py b/third_party/ASpanFormer/src/config/default.py index 40abd51c3f28ea6dee3c4e9fcee6efac5c080a2f..2850199cfb4d403fe4ec7aa5d61a7de524e4183c 100644 --- a/third_party/ASpanFormer/src/config/default.py +++ b/third_party/ASpanFormer/src/config/default.py @@ -1,9 +1,10 @@ from yacs.config import CfgNode as CN + _CN = CN() ############## ↓ ASPAN Pipeline ↓ ############## _CN.ASPAN = CN() -_CN.ASPAN.BACKBONE_TYPE = 'ResNetFPN' +_CN.ASPAN.BACKBONE_TYPE = "ResNetFPN" _CN.ASPAN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] _CN.ASPAN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd _CN.ASPAN.FINE_CONCAT_COARSE_FEAT = True @@ -17,14 +18,14 @@ _CN.ASPAN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 _CN.ASPAN.COARSE = CN() _CN.ASPAN.COARSE.D_MODEL = 256 _CN.ASPAN.COARSE.D_FFN = 256 -_CN.ASPAN.COARSE.D_FLOW= 128 +_CN.ASPAN.COARSE.D_FLOW = 128 _CN.ASPAN.COARSE.NHEAD = 8 -_CN.ASPAN.COARSE.NLEVEL= 3 -_CN.ASPAN.COARSE.INI_LAYER_NUM = 2 -_CN.ASPAN.COARSE.LAYER_NUM = 4 -_CN.ASPAN.COARSE.NSAMPLE = [2,8] -_CN.ASPAN.COARSE.RADIUS_SCALE= 5 -_CN.ASPAN.COARSE.COARSEST_LEVEL= [26,26] +_CN.ASPAN.COARSE.NLEVEL = 3 +_CN.ASPAN.COARSE.INI_LAYER_NUM = 2 +_CN.ASPAN.COARSE.LAYER_NUM = 4 +_CN.ASPAN.COARSE.NSAMPLE = [2, 8] +_CN.ASPAN.COARSE.RADIUS_SCALE = 5 +_CN.ASPAN.COARSE.COARSEST_LEVEL = [26, 26] _CN.ASPAN.COARSE.TRAIN_RES = None _CN.ASPAN.COARSE.TEST_RES = None @@ -32,7 +33,9 @@ _CN.ASPAN.COARSE.TEST_RES = None _CN.ASPAN.MATCH_COARSE = CN() _CN.ASPAN.MATCH_COARSE.THR = 0.2 _CN.ASPAN.MATCH_COARSE.BORDER_RM = 2 -_CN.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] +_CN.ASPAN.MATCH_COARSE.MATCH_TYPE = ( + "dual_softmax" # options: ['dual_softmax, 'sinkhorn'] +) _CN.ASPAN.MATCH_COARSE.SKH_ITERS = 3 _CN.ASPAN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 _CN.ASPAN.MATCH_COARSE.SKH_PREFILTER = False @@ -46,13 +49,13 @@ _CN.ASPAN.FINE = CN() _CN.ASPAN.FINE.D_MODEL = 128 _CN.ASPAN.FINE.D_FFN = 128 _CN.ASPAN.FINE.NHEAD = 8 -_CN.ASPAN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 -_CN.ASPAN.FINE.ATTENTION = 'linear' +_CN.ASPAN.FINE.LAYER_NAMES = ["self", "cross"] * 1 +_CN.ASPAN.FINE.ATTENTION = "linear" # 5. 
ASPAN Losses # -- # coarse-level _CN.ASPAN.LOSS = CN() -_CN.ASPAN.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy'] +_CN.ASPAN.LOSS.COARSE_TYPE = "focal" # ['focal', 'cross_entropy'] _CN.ASPAN.LOSS.COARSE_WEIGHT = 1.0 # _CN.ASPAN.LOSS.SPARSE_SPVS = False # -- - -- # focal loss (coarse) @@ -64,7 +67,7 @@ _CN.ASPAN.LOSS.NEG_WEIGHT = 1.0 # use `_CN.ASPAN.MATCH_COARSE.MATCH_TYPE` # -- # fine-level -_CN.ASPAN.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2'] +_CN.ASPAN.LOSS.FINE_TYPE = "l2_with_std" # ['l2_with_std', 'l2'] _CN.ASPAN.LOSS.FINE_WEIGHT = 1.0 _CN.ASPAN.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window) @@ -85,24 +88,32 @@ _CN.DATASET.TRAIN_INTRINSIC_PATH = None _CN.DATASET.VAL_DATA_ROOT = None _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) _CN.DATASET.VAL_NPZ_ROOT = None -_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file +_CN.DATASET.VAL_LIST_PATH = ( + None # None if val data from all scenes are bundled into a single npz file +) _CN.DATASET.VAL_INTRINSIC_PATH = None # testing _CN.DATASET.TEST_DATA_SOURCE = None _CN.DATASET.TEST_DATA_ROOT = None _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) _CN.DATASET.TEST_NPZ_ROOT = None -_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file +_CN.DATASET.TEST_LIST_PATH = ( + None # None if test data from all scenes are bundled into a single npz file +) _CN.DATASET.TEST_INTRINSIC_PATH = None # 2. dataset config # general options -_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = ( + 0.4 # discard data with overlap_score < min_overlap_score +) _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile'] # MegaDepth options -_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. +_CN.DATASET.MGDPT_IMG_RESIZE = ( + 640 # resize the longer side, zero-pad bottom-right to square. +) _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 _CN.DATASET.MGDPT_DF = 8 @@ -118,17 +129,17 @@ _CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning # optimizer _CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] _CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime -_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam +_CN.TRAINER.ADAM_DECAY = 0.0 # ADAM: for adam _CN.TRAINER.ADAMW_DECAY = 0.1 # step-based warm-up -_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] -_CN.TRAINER.WARMUP_RATIO = 0. 
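# ---------------------------------------------------------------------------
# Illustrative sketch (standalone): overriding defaults defined with yacs, in the
# same style as the _CN tree above. The tiny node below is a stand-in, not the
# project's full configuration.
from yacs.config import CfgNode as CN

_C = CN()
_C.TRAINER = CN()
_C.TRAINER.OPTIMIZER = "adamw"
_C.TRAINER.WARMUP_STEP = 4800

def get_cfg_defaults():
    return _C.clone()                                    # hand out a copy of the defaults

cfg = get_cfg_defaults()
cfg.merge_from_list(["TRAINER.WARMUP_STEP", 1000])       # command-line style override
cfg.freeze()                                             # make the config read-only
print(cfg.TRAINER.WARMUP_STEP)                           # 1000
# ---------------------------------------------------------------------------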
+_CN.TRAINER.WARMUP_TYPE = "linear" # [linear, constant] +_CN.TRAINER.WARMUP_RATIO = 0.0 _CN.TRAINER.WARMUP_STEP = 4800 # learning rate scheduler -_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR] -_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] +_CN.TRAINER.SCHEDULER = "MultiStepLR" # [MultiStepLR, CosineAnnealing, ExponentialLR] +_CN.TRAINER.SCHEDULER_INTERVAL = "epoch" # [epoch, step] _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR _CN.TRAINER.MSLR_GAMMA = 0.5 _CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing @@ -136,25 +147,33 @@ _CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' in # plotting related _CN.TRAINER.ENABLE_PLOTTING = True -_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test paris for plotting -_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence'] -_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic' +_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test paris for plotting +_CN.TRAINER.PLOT_MODE = "evaluation" # ['evaluation', 'confidence'] +_CN.TRAINER.PLOT_MATCHES_ALPHA = "dynamic" # geometric metrics and pose solver -_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) -_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] -_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] +_CN.TRAINER.EPI_ERR_THR = ( + 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) +) +_CN.TRAINER.POSE_GEO_MODEL = "E" # ['E', 'F', 'H'] +_CN.TRAINER.POSE_ESTIMATION_METHOD = "RANSAC" # [RANSAC, DEGENSAC, MAGSAC] _CN.TRAINER.RANSAC_PIXEL_THR = 0.5 _CN.TRAINER.RANSAC_CONF = 0.99999 _CN.TRAINER.RANSAC_MAX_ITERS = 10000 _CN.TRAINER.USE_MAGSACPP = False # data sampler for train_dataloader -_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] +_CN.TRAINER.DATA_SAMPLER = ( + "scene_balance" # options: ['scene_balance', 'random', 'normal'] +) # 'scene_balance' config _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200 -_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not -_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not +_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = ( + True # whether sample each scene with replacement or not +) +_CN.TRAINER.SB_SUBSET_SHUFFLE = ( + True # after sampling from scenes, whether shuffle within the epoch or not +) _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data # 'random' config _CN.TRAINER.RDM_REPLACEMENT = True diff --git a/third_party/ASpanFormer/src/datasets/__init__.py b/third_party/ASpanFormer/src/datasets/__init__.py index 1860e3ae060a26e4625925861cecdc355f2b08b7..4feb648440e6c8db60de3aa475cd82ce460dcc1c 100644 --- a/third_party/ASpanFormer/src/datasets/__init__.py +++ b/third_party/ASpanFormer/src/datasets/__init__.py @@ -1,3 +1,2 @@ from .scannet import ScanNetDataset from .megadepth import MegaDepthDataset - diff --git a/third_party/ASpanFormer/src/datasets/megadepth.py b/third_party/ASpanFormer/src/datasets/megadepth.py index a70ac715a3f807e37bc5b87ae9446ddd2aa4fc86..7cbf95962df705c14d11483838f13bfd5e036166 100644 --- a/third_party/ASpanFormer/src/datasets/megadepth.py +++ b/third_party/ASpanFormer/src/datasets/megadepth.py @@ -9,20 +9,22 @@ from src.utils.dataset import read_megadepth_gray, read_megadepth_depth class MegaDepthDataset(Dataset): - def __init__(self, - root_dir, - npz_path, - mode='train', - 
min_overlap_score=0.4, - img_resize=None, - df=None, - img_padding=False, - depth_padding=False, - augment_fn=None, - **kwargs): + def __init__( + self, + root_dir, + npz_path, + mode="train", + min_overlap_score=0.4, + img_resize=None, + df=None, + img_padding=False, + depth_padding=False, + augment_fn=None, + **kwargs + ): """ Manage one scene(npz_path) of MegaDepth dataset. - + Args: root_dir (str): megadepth root directory that has `phoenix`. npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. @@ -38,28 +40,36 @@ class MegaDepthDataset(Dataset): super().__init__() self.root_dir = root_dir self.mode = mode - self.scene_id = npz_path.split('.')[0] + self.scene_id = npz_path.split(".")[0] # prepare scene_info and pair_info - if mode == 'test' and min_overlap_score != 0: - logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.") + if mode == "test" and min_overlap_score != 0: + logger.warning( + "You are using `min_overlap_score`!=0 in test mode. Set to 0." + ) min_overlap_score = 0 self.scene_info = np.load(npz_path, allow_pickle=True) - self.pair_infos = self.scene_info['pair_infos'].copy() - del self.scene_info['pair_infos'] - self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] + self.pair_infos = self.scene_info["pair_infos"].copy() + del self.scene_info["pair_infos"] + self.pair_infos = [ + pair_info + for pair_info in self.pair_infos + if pair_info[1] > min_overlap_score + ] # parameters for image resizing, padding and depthmap padding - if mode == 'train': + if mode == "train": assert img_resize is not None and img_padding and depth_padding self.img_resize = img_resize self.df = df self.img_padding = img_padding - self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. + self.depth_max_size = ( + 2000 if depth_padding else None + ) # the upperbound of depthmaps size in megadepth. # for training LoFTR - self.augment_fn = augment_fn if mode == 'train' else None - self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) + self.augment_fn = augment_fn if mode == "train" else None + self.coarse_scale = getattr(kwargs, "coarse_scale", 0.125) def __len__(self): return len(self.pair_infos) @@ -68,60 +78,77 @@ class MegaDepthDataset(Dataset): (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx] # read grayscale image and mask. (1, h, w) and (h, w) - img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) - img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) - + img_name0 = osp.join(self.root_dir, self.scene_info["image_paths"][idx0]) + img_name1 = osp.join(self.root_dir, self.scene_info["image_paths"][idx1]) + # TODO: Support augmentation & handle seeds for each worker correctly. image0, mask0, scale0 = read_megadepth_gray( - img_name0, self.img_resize, self.df, self.img_padding, None) - # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + img_name0, self.img_resize, self.df, self.img_padding, None + ) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) image1, mask1, scale1 = read_megadepth_gray( - img_name1, self.img_resize, self.df, self.img_padding, None) - # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + img_name1, self.img_resize, self.df, self.img_padding, None + ) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) # read depth. 
shape: (h, w) - if self.mode in ['train', 'val']: + if self.mode in ["train", "val"]: depth0 = read_megadepth_depth( - osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) + osp.join(self.root_dir, self.scene_info["depth_paths"][idx0]), + pad_to=self.depth_max_size, + ) depth1 = read_megadepth_depth( - osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) + osp.join(self.root_dir, self.scene_info["depth_paths"][idx1]), + pad_to=self.depth_max_size, + ) else: depth0 = depth1 = torch.tensor([]) # read intrinsics of original size - K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) - K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) + K_0 = torch.tensor( + self.scene_info["intrinsics"][idx0].copy(), dtype=torch.float + ).reshape(3, 3) + K_1 = torch.tensor( + self.scene_info["intrinsics"][idx1].copy(), dtype=torch.float + ).reshape(3, 3) # read and compute relative poses - T0 = self.scene_info['poses'][idx0] - T1 = self.scene_info['poses'][idx1] - T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) + T0 = self.scene_info["poses"][idx0] + T1 = self.scene_info["poses"][idx1] + T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[ + :4, :4 + ] # (4, 4) T_1to0 = T_0to1.inverse() data = { - 'image0': image0, # (1, h, w) - 'depth0': depth0, # (h, w) - 'image1': image1, - 'depth1': depth1, - 'T_0to1': T_0to1, # (4, 4) - 'T_1to0': T_1to0, - 'K0': K_0, # (3, 3) - 'K1': K_1, - 'scale0': scale0, # [scale_w, scale_h] - 'scale1': scale1, - 'dataset_name': 'MegaDepth', - 'scene_id': self.scene_id, - 'pair_id': idx, - 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), + "image0": image0, # (1, h, w) + "depth0": depth0, # (h, w) + "image1": image1, + "depth1": depth1, + "T_0to1": T_0to1, # (4, 4) + "T_1to0": T_1to0, + "K0": K_0, # (3, 3) + "K1": K_1, + "scale0": scale0, # [scale_w, scale_h] + "scale1": scale1, + "dataset_name": "MegaDepth", + "scene_id": self.scene_id, + "pair_id": idx, + "pair_names": ( + self.scene_info["image_paths"][idx0], + self.scene_info["image_paths"][idx1], + ), } # for LoFTR training if mask0 is not None: # img_padding is True if self.coarse_scale: - [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), - scale_factor=self.coarse_scale, - mode='nearest', - recompute_scale_factor=False)[0].bool() - data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + [ts_mask_0, ts_mask_1] = F.interpolate( + torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.coarse_scale, + mode="nearest", + recompute_scale_factor=False, + )[0].bool() + data.update({"mask0": ts_mask_0, "mask1": ts_mask_1}) return data diff --git a/third_party/ASpanFormer/src/datasets/sampler.py b/third_party/ASpanFormer/src/datasets/sampler.py index 81b6f435645632a013476f9a665a0861ab7fcb61..131111c4cf69cd8770058dfac2be717aa183978e 100644 --- a/third_party/ASpanFormer/src/datasets/sampler.py +++ b/third_party/ASpanFormer/src/datasets/sampler.py @@ -3,10 +3,10 @@ from torch.utils.data import Sampler, ConcatDataset class RandomConcatSampler(Sampler): - """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset + """Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset in the ConcatDataset. 
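# ---------------------------------------------------------------------------
# Illustrative sketch (standalone): the relative-pose composition used above. If T0
# and T1 map world coordinates into camera 0 and camera 1 respectively, then
# T_0to1 = T1 @ inv(T0) maps camera-0 coordinates into camera 1. Poses are made up.
import numpy as np
import torch

T0 = np.eye(4)                                    # camera 0 at the world origin
T1 = np.eye(4)
T1[:3, 3] = [-1.0, 0.0, 0.0]                      # camera 1 shifted 1 m along x

T_0to1 = torch.tensor(T1 @ np.linalg.inv(T0), dtype=torch.float)
T_1to0 = T_0to1.inverse()
print(T_0to1[:3, 3], T_1to0[:3, 3])               # opposite translations
# ---------------------------------------------------------------------------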
If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase. - + For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not. Args: shuffle (bool): shuffle the random sampled indices across all sub-datsets. @@ -18,16 +18,19 @@ class RandomConcatSampler(Sampler): TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs. ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 """ - def __init__(self, - data_source: ConcatDataset, - n_samples_per_subset: int, - subset_replacement: bool=True, - shuffle: bool=True, - repeat: int=1, - seed: int=None): + + def __init__( + self, + data_source: ConcatDataset, + n_samples_per_subset: int, + subset_replacement: bool = True, + shuffle: bool = True, + repeat: int = 1, + seed: int = None, + ): if not isinstance(data_source, ConcatDataset): raise TypeError("data_source should be torch.utils.data.ConcatDataset") - + self.data_source = data_source self.n_subset = len(self.data_source.datasets) self.n_samples_per_subset = n_samples_per_subset @@ -37,27 +40,37 @@ class RandomConcatSampler(Sampler): self.shuffle = shuffle self.generator = torch.manual_seed(seed) assert self.repeat >= 1 - + def __len__(self): return self.n_samples - + def __iter__(self): indices = [] # sample from each sub-dataset for d_idx in range(self.n_subset): - low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1] + low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1] high = self.data_source.cumulative_sizes[d_idx] if self.subset_replacement: - rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ), - generator=self.generator, dtype=torch.int64) + rand_tensor = torch.randint( + low, + high, + (self.n_samples_per_subset,), + generator=self.generator, + dtype=torch.int64, + ) else: # sample without replacement len_subset = len(self.data_source.datasets[d_idx]) rand_tensor = torch.randperm(len_subset, generator=self.generator) + low if len_subset >= self.n_samples_per_subset: - rand_tensor = rand_tensor[:self.n_samples_per_subset] - else: # padding with replacement - rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ), - generator=self.generator, dtype=torch.int64) + rand_tensor = rand_tensor[: self.n_samples_per_subset] + else: # padding with replacement + rand_tensor_replacement = torch.randint( + low, + high, + (self.n_samples_per_subset - len_subset,), + generator=self.generator, + dtype=torch.int64, + ) rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) indices.append(rand_tensor) indices = torch.cat(indices) @@ -72,6 +85,6 @@ class RandomConcatSampler(Sampler): _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)] repeat_indices = map(_choice, repeat_indices) indices = torch.cat([indices, *repeat_indices], 0) - + assert indices.shape[0] == self.n_samples return iter(indices.tolist()) diff --git a/third_party/ASpanFormer/src/datasets/scannet.py b/third_party/ASpanFormer/src/datasets/scannet.py index 3520d34c0f08a784ddbf923846a7cb2a847b1787..615e98409b92713ab241aa8658c74cf7b2f8baae 100644 --- a/third_party/ASpanFormer/src/datasets/scannet.py +++ 
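# ---------------------------------------------------------------------------
# Illustrative sketch (standalone): drawing a fixed number of indices from each
# subset of a ConcatDataset via cumulative_sizes, which is the core of the
# scene-balanced sampler above (with-replacement case only; the real sampler also
# handles no-replacement, shuffling and repeats). The toy datasets are assumptions.
import torch
from torch.utils.data import ConcatDataset, TensorDataset

ds = ConcatDataset([TensorDataset(torch.arange(10)), TensorDataset(torch.arange(100, 105))])
n_per_subset = 3
gen = torch.manual_seed(0)

indices = []
for d_idx in range(len(ds.datasets)):
    low = 0 if d_idx == 0 else ds.cumulative_sizes[d_idx - 1]
    high = ds.cumulative_sizes[d_idx]
    indices.append(torch.randint(low, high, (n_per_subset,), generator=gen))
indices = torch.cat(indices)
print(indices.tolist())          # three global indices per subset, with replacement
# ---------------------------------------------------------------------------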
b/third_party/ASpanFormer/src/datasets/scannet.py @@ -10,20 +10,22 @@ from src.utils.dataset import ( read_scannet_gray, read_scannet_depth, read_scannet_pose, - read_scannet_intrinsic + read_scannet_intrinsic, ) class ScanNetDataset(utils.data.Dataset): - def __init__(self, - root_dir, - npz_path, - intrinsic_path, - mode='train', - min_overlap_score=0.4, - augment_fn=None, - pose_dir=None, - **kwargs): + def __init__( + self, + root_dir, + npz_path, + intrinsic_path, + mode="train", + min_overlap_score=0.4, + augment_fn=None, + pose_dir=None, + **kwargs, + ): """Manage one scene of ScanNet Dataset. Args: root_dir (str): ScanNet root directory that contains scene folders. @@ -41,73 +43,81 @@ class ScanNetDataset(utils.data.Dataset): # prepare data_names, intrinsics and extrinsics(T) with np.load(npz_path) as data: - self.data_names = data['name'] - if 'score' in data.keys() and mode not in ['val' or 'test']: - kept_mask = data['score'] > min_overlap_score + self.data_names = data["name"] + if "score" in data.keys() and mode not in ["val" or "test"]: + kept_mask = data["score"] > min_overlap_score self.data_names = self.data_names[kept_mask] self.intrinsics = dict(np.load(intrinsic_path)) # for training LoFTR - self.augment_fn = augment_fn if mode == 'train' else None + self.augment_fn = augment_fn if mode == "train" else None def __len__(self): return len(self.data_names) def _read_abs_pose(self, scene_name, name): - pth = osp.join(self.pose_dir, - scene_name, - 'pose', f'{name}.txt') + pth = osp.join(self.pose_dir, scene_name, "pose", f"{name}.txt") return read_scannet_pose(pth) def _compute_rel_pose(self, scene_name, name0, name1): pose0 = self._read_abs_pose(scene_name, name0) pose1 = self._read_abs_pose(scene_name, name1) - + return np.matmul(pose1, inv(pose0)) # (4, 4) def __getitem__(self, idx): data_name = self.data_names[idx] scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name - scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}" # read the grayscale image which will be resized to (1, 480, 640) - img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg') - img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg') + img_name0 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_0}.jpg") + img_name1 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_1}.jpg") # TODO: Support augmentation & handle seeds for each worker correctly. 
image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None) - # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None) - # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) # read the depthmap which is stored as (480, 640) - if self.mode in ['train', 'val']: - depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png')) - depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png')) + if self.mode in ["train", "val"]: + depth0 = read_scannet_depth( + osp.join(self.root_dir, scene_name, "depth", f"{stem_name_0}.png") + ) + depth1 = read_scannet_depth( + osp.join(self.root_dir, scene_name, "depth", f"{stem_name_1}.png") + ) else: depth0 = depth1 = torch.tensor([]) # read the intrinsic of depthmap - K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3) + K_0 = K_1 = torch.tensor( + self.intrinsics[scene_name].copy(), dtype=torch.float + ).reshape(3, 3) # read and compute relative poses - T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), - dtype=torch.float32) + T_0to1 = torch.tensor( + self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), + dtype=torch.float32, + ) T_1to0 = T_0to1.inverse() data = { - 'image0': image0, # (1, h, w) - 'depth0': depth0, # (h, w) - 'image1': image1, - 'depth1': depth1, - 'T_0to1': T_0to1, # (4, 4) - 'T_1to0': T_1to0, - 'K0': K_0, # (3, 3) - 'K1': K_1, - 'dataset_name': 'ScanNet', - 'scene_id': scene_name, - 'pair_id': idx, - 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'), - osp.join(scene_name, 'color', f'{stem_name_1}.jpg')) + "image0": image0, # (1, h, w) + "depth0": depth0, # (h, w) + "image1": image1, + "depth1": depth1, + "T_0to1": T_0to1, # (4, 4) + "T_1to0": T_1to0, + "K0": K_0, # (3, 3) + "K1": K_1, + "dataset_name": "ScanNet", + "scene_id": scene_name, + "pair_id": idx, + "pair_names": ( + osp.join(scene_name, "color", f"{stem_name_0}.jpg"), + osp.join(scene_name, "color", f"{stem_name_1}.jpg"), + ), } return data diff --git a/third_party/ASpanFormer/src/lightning/data.py b/third_party/ASpanFormer/src/lightning/data.py index 73db514b8924d647814e6c5def919c23393d3ccf..9877df5980c73e9bfb5a1e6ec301e1a84a97ca56 100644 --- a/third_party/ASpanFormer/src/lightning/data.py +++ b/third_party/ASpanFormer/src/lightning/data.py @@ -16,7 +16,7 @@ from torch.utils.data import ( ConcatDataset, DistributedSampler, RandomSampler, - dataloader + dataloader, ) from src.utils.augment import build_augmentor @@ -29,10 +29,11 @@ from src.datasets.sampler import RandomConcatSampler class MultiSceneDataModule(pl.LightningDataModule): - """ + """ For distributed training, each training process is assgined only a part of the training scenes to reduce memory overhead. """ + def __init__(self, args, config): super().__init__() @@ -60,47 +61,51 @@ class MultiSceneDataModule(pl.LightningDataModule): # 2. 
dataset config # general options - self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score + self.min_overlap_score_test = ( + config.DATASET.MIN_OVERLAP_SCORE_TEST + ) # 0.4, omit data with overlap_score < min_overlap_score self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN - self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile'] + self.augment_fn = build_augmentor( + config.DATASET.AUGMENTATION_TYPE + ) # None, options: [None, 'dark', 'mobile'] # MegaDepth options self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840 - self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True - self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True + self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True + self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True self.mgdpt_df = config.DATASET.MGDPT_DF # 8 self.coarse_scale = 1 / config.ASPAN.RESOLUTION[0] # 0.125. for training loftr. # 3.loader parameters self.train_loader_params = { - 'batch_size': args.batch_size, - 'num_workers': args.num_workers, - 'pin_memory': getattr(args, 'pin_memory', True) + "batch_size": args.batch_size, + "num_workers": args.num_workers, + "pin_memory": getattr(args, "pin_memory", True), } self.val_loader_params = { - 'batch_size': 1, - 'shuffle': False, - 'num_workers': args.num_workers, - 'pin_memory': getattr(args, 'pin_memory', True) + "batch_size": 1, + "shuffle": False, + "num_workers": args.num_workers, + "pin_memory": getattr(args, "pin_memory", True), } self.test_loader_params = { - 'batch_size': 1, - 'shuffle': False, - 'num_workers': args.num_workers, - 'pin_memory': True + "batch_size": 1, + "shuffle": False, + "num_workers": args.num_workers, + "pin_memory": True, } - + # 4. sampler self.data_sampler = config.TRAINER.DATA_SAMPLER self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE self.repeat = config.TRAINER.SB_REPEAT - + # (optional) RandomSampler for debugging # misc configurations - self.parallel_load_data = getattr(args, 'parallel_load_data', False) + self.parallel_load_data = getattr(args, "parallel_load_data", False) self.seed = config.TRAINER.SEED # 66 def setup(self, stage=None): @@ -110,7 +115,7 @@ class MultiSceneDataModule(pl.LightningDataModule): stage (str): 'fit' in training phase, and 'test' in testing phase. 
""" - assert stage in ['fit', 'test'], "stage must be either fit or test" + assert stage in ["fit", "test"], "stage must be either fit or test" try: self.world_size = dist.get_world_size() @@ -121,73 +126,94 @@ class MultiSceneDataModule(pl.LightningDataModule): self.rank = 0 logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)") - if stage == 'fit': + if stage == "fit": self.train_dataset = self._setup_dataset( self.train_data_root, self.train_npz_root, self.train_list_path, self.train_intrinsic_path, - mode='train', + mode="train", min_overlap_score=self.min_overlap_score_train, - pose_dir=self.train_pose_root) + pose_dir=self.train_pose_root, + ) # setup multiple (optional) validation subsets if isinstance(self.val_list_path, (list, tuple)): self.val_dataset = [] if not isinstance(self.val_npz_root, (list, tuple)): - self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))] + self.val_npz_root = [ + self.val_npz_root for _ in range(len(self.val_list_path)) + ] for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root): - self.val_dataset.append(self._setup_dataset( - self.val_data_root, - npz_root, - npz_list, - self.val_intrinsic_path, - mode='val', - min_overlap_score=self.min_overlap_score_test, - pose_dir=self.val_pose_root)) + self.val_dataset.append( + self._setup_dataset( + self.val_data_root, + npz_root, + npz_list, + self.val_intrinsic_path, + mode="val", + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root, + ) + ) else: self.val_dataset = self._setup_dataset( self.val_data_root, self.val_npz_root, self.val_list_path, self.val_intrinsic_path, - mode='val', + mode="val", min_overlap_score=self.min_overlap_score_test, - pose_dir=self.val_pose_root) - logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!') + pose_dir=self.val_pose_root, + ) + logger.info(f"[rank:{self.rank}] Train & Val Dataset loaded!") else: # stage == 'test self.test_dataset = self._setup_dataset( self.test_data_root, self.test_npz_root, self.test_list_path, self.test_intrinsic_path, - mode='test', + mode="test", min_overlap_score=self.min_overlap_score_test, - pose_dir=self.test_pose_root) - logger.info(f'[rank:{self.rank}]: Test Dataset loaded!') + pose_dir=self.test_pose_root, + ) + logger.info(f"[rank:{self.rank}]: Test Dataset loaded!") - def _setup_dataset(self, - data_root, - split_npz_root, - scene_list_path, - intri_path, - mode='train', - min_overlap_score=0., - pose_dir=None): - """ Setup train / val / test set""" - with open(scene_list_path, 'r') as f: + def _setup_dataset( + self, + data_root, + split_npz_root, + scene_list_path, + intri_path, + mode="train", + min_overlap_score=0.0, + pose_dir=None, + ): + """Setup train / val / test set""" + with open(scene_list_path, "r") as f: npz_names = [name.split()[0] for name in f.readlines()] - if mode == 'train': - local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed) + if mode == "train": + local_npz_names = get_local_split( + npz_names, self.world_size, self.rank, self.seed + ) else: local_npz_names = npz_names - logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.') - - dataset_builder = self._build_concat_dataset_parallel \ - if self.parallel_load_data \ - else self._build_concat_dataset - return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path, - mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir) + logger.info(f"[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.") + + 
dataset_builder = ( + self._build_concat_dataset_parallel + if self.parallel_load_data + else self._build_concat_dataset + ) + return dataset_builder( + data_root, + local_npz_names, + split_npz_root, + intri_path, + mode=mode, + min_overlap_score=min_overlap_score, + pose_dir=pose_dir, + ) def _build_concat_dataset( self, @@ -196,49 +222,61 @@ class MultiSceneDataModule(pl.LightningDataModule): npz_dir, intrinsic_path, mode, - min_overlap_score=0., - pose_dir=None + min_overlap_score=0.0, + pose_dir=None, ): datasets = [] - augment_fn = self.augment_fn if mode == 'train' else None - data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source - if data_source=='GL3D' and mode=='val': - data_source='MegaDepth' - if str(data_source).lower() == 'megadepth': - npz_names = [f'{n}.npz' for n in npz_names] - if str(data_source).lower() == 'gl3d': - npz_names = [f'{n}.txt' for n in npz_names] - #npz_names=npz_names[:8] - for npz_name in tqdm(npz_names, - desc=f'[rank:{self.rank}] loading {mode} datasets', - disable=int(self.rank) != 0): + augment_fn = self.augment_fn if mode == "train" else None + data_source = ( + self.trainval_data_source + if mode in ["train", "val"] + else self.test_data_source + ) + if data_source == "GL3D" and mode == "val": + data_source = "MegaDepth" + if str(data_source).lower() == "megadepth": + npz_names = [f"{n}.npz" for n in npz_names] + if str(data_source).lower() == "gl3d": + npz_names = [f"{n}.txt" for n in npz_names] + # npz_names=npz_names[:8] + for npz_name in tqdm( + npz_names, + desc=f"[rank:{self.rank}] loading {mode} datasets", + disable=int(self.rank) != 0, + ): # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time. npz_path = osp.join(npz_dir, npz_name) - if data_source == 'ScanNet': + if data_source == "ScanNet": datasets.append( - ScanNetDataset(data_root, - npz_path, - intrinsic_path, - mode=mode, - min_overlap_score=min_overlap_score, - augment_fn=augment_fn, - pose_dir=pose_dir)) - elif data_source == 'MegaDepth': + ScanNetDataset( + data_root, + npz_path, + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir, + ) + ) + elif data_source == "MegaDepth": datasets.append( - MegaDepthDataset(data_root, - npz_path, - mode=mode, - min_overlap_score=min_overlap_score, - img_resize=self.mgdpt_img_resize, - df=self.mgdpt_df, - img_padding=self.mgdpt_img_pad, - depth_padding=self.mgdpt_depth_pad, - augment_fn=augment_fn, - coarse_scale=self.coarse_scale)) + MegaDepthDataset( + data_root, + npz_path, + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale, + ) + ) else: raise NotImplementedError() return ConcatDataset(datasets) - + def _build_concat_dataset_parallel( self, data_root, @@ -246,78 +284,119 @@ class MultiSceneDataModule(pl.LightningDataModule): npz_dir, intrinsic_path, mode, - min_overlap_score=0., + min_overlap_score=0.0, pose_dir=None, ): - augment_fn = self.augment_fn if mode == 'train' else None - data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source - if str(data_source).lower() == 'megadepth': - npz_names = [f'{n}.npz' for n in npz_names] - #npz_names=npz_names[:8] - with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets', - total=len(npz_names), 
disable=int(self.rank) != 0)): - if data_source == 'ScanNet': - datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( - delayed(lambda x: _build_dataset( - ScanNetDataset, - data_root, - osp.join(npz_dir, x), - intrinsic_path, - mode=mode, - min_overlap_score=min_overlap_score, - augment_fn=augment_fn, - pose_dir=pose_dir))(name) - for name in npz_names) - elif data_source == 'MegaDepth': + augment_fn = self.augment_fn if mode == "train" else None + data_source = ( + self.trainval_data_source + if mode in ["train", "val"] + else self.test_data_source + ) + if str(data_source).lower() == "megadepth": + npz_names = [f"{n}.npz" for n in npz_names] + # npz_names=npz_names[:8] + with tqdm_joblib( + tqdm( + desc=f"[rank:{self.rank}] loading {mode} datasets", + total=len(npz_names), + disable=int(self.rank) != 0, + ) + ): + if data_source == "ScanNet": + datasets = Parallel( + n_jobs=math.floor( + len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size() + ) + )( + delayed( + lambda x: _build_dataset( + ScanNetDataset, + data_root, + osp.join(npz_dir, x), + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir, + ) + )(name) + for name in npz_names + ) + elif data_source == "MegaDepth": # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers. raise NotImplementedError() - datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( - delayed(lambda x: _build_dataset( - MegaDepthDataset, - data_root, - osp.join(npz_dir, x), - mode=mode, - min_overlap_score=min_overlap_score, - img_resize=self.mgdpt_img_resize, - df=self.mgdpt_df, - img_padding=self.mgdpt_img_pad, - depth_padding=self.mgdpt_depth_pad, - augment_fn=augment_fn, - coarse_scale=self.coarse_scale))(name) - for name in npz_names) + datasets = Parallel( + n_jobs=math.floor( + len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size() + ) + )( + delayed( + lambda x: _build_dataset( + MegaDepthDataset, + data_root, + osp.join(npz_dir, x), + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale, + ) + )(name) + for name in npz_names + ) else: - raise ValueError(f'Unknown dataset: {data_source}') + raise ValueError(f"Unknown dataset: {data_source}") return ConcatDataset(datasets) def train_dataloader(self): - """ Build training dataloader for ScanNet / MegaDepth. """ - assert self.data_sampler in ['scene_balance'] - logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).') - if self.data_sampler == 'scene_balance': - sampler = RandomConcatSampler(self.train_dataset, - self.n_samples_per_subset, - self.subset_replacement, - self.shuffle, self.repeat, self.seed) + """Build training dataloader for ScanNet / MegaDepth.""" + assert self.data_sampler in ["scene_balance"] + logger.info( + f"[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!)." 
+        )
+        if self.data_sampler == "scene_balance":
+            sampler = RandomConcatSampler(
+                self.train_dataset,
+                self.n_samples_per_subset,
+                self.subset_replacement,
+                self.shuffle,
+                self.repeat,
+                self.seed,
+            )
         else:
             sampler = None
-        dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
+        dataloader = DataLoader(
+            self.train_dataset, sampler=sampler, **self.train_loader_params
+        )
         return dataloader
-    
+
     def val_dataloader(self):
-        """ Build validation dataloader for ScanNet / MegaDepth. """
-        logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
+        """Build validation dataloader for ScanNet / MegaDepth."""
+        logger.info(
+            f"[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init."
+        )
         if not isinstance(self.val_dataset, abc.Sequence):
             sampler = DistributedSampler(self.val_dataset, shuffle=False)
-            return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
+            return DataLoader(
+                self.val_dataset, sampler=sampler, **self.val_loader_params
+            )
         else:
             dataloaders = []
             for dataset in self.val_dataset:
                 sampler = DistributedSampler(dataset, shuffle=False)
-                dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
+                dataloaders.append(
+                    DataLoader(dataset, sampler=sampler, **self.val_loader_params)
+                )
             return dataloaders
 
     def test_dataloader(self, *args, **kwargs):
-        logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
+        logger.info(
+            f"[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init."
+        )
         sampler = DistributedSampler(self.test_dataset, shuffle=False)
         return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
diff --git a/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py b/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py
index ee20cbec4628b73c08358ebf1e1906fb2c0ac13c..9b34b7b7485d4419390614e3fe0174ccc53ac7a9 100644
--- a/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py
+++ b/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py
@@ -1,4 +1,3 @@
-
 from collections import defaultdict
 import pprint
 from loguru import logger
@@ -10,15 +9,19 @@ import pytorch_lightning as pl
 from matplotlib import pyplot as plt
 
 from src.ASpanFormer.aspanformer import ASpanFormer
-from src.ASpanFormer.utils.supervision import compute_supervision_coarse, compute_supervision_fine
+from src.ASpanFormer.utils.supervision import (
+    compute_supervision_coarse,
+    compute_supervision_fine,
+)
 from src.losses.aspan_loss import ASpanLoss
 from src.optimizers import build_optimizer, build_scheduler
 from src.utils.metrics import (
-    compute_symmetrical_epipolar_errors,compute_symmetrical_epipolar_errors_offset_bidirectional,
+    compute_symmetrical_epipolar_errors,
+    compute_symmetrical_epipolar_errors_offset_bidirectional,
     compute_pose_errors,
-    aggregate_metrics
+    aggregate_metrics,
 )
-from src.utils.plotting import make_matching_figures,make_matching_figures_offset
+from src.utils.plotting import make_matching_figures, make_matching_figures_offset
 from src.utils.comm import gather, all_gather
 from src.utils.misc import lower_config, flattenList
 from src.utils.profiler import PassThroughProfiler
@@ -34,200 +37,288 @@ class PL_ASpanFormer(pl.LightningModule):
         # Misc
         self.config = config  # full config
         _config = lower_config(self.config)
-        self.loftr_cfg = lower_config(_config['aspan'])
+        self.loftr_cfg = lower_config(_config["aspan"])
         self.profiler = profiler or
PassThroughProfiler() - self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1) + self.n_vals_plot = max( + config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1 + ) # Matcher: LoFTR - self.matcher = ASpanFormer(config=_config['aspan']) + self.matcher = ASpanFormer(config=_config["aspan"]) self.loss = ASpanLoss(_config) # Pretrained weights print(pretrained_ckpt) if pretrained_ckpt: - print('load') - state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict'] - msg=self.matcher.load_state_dict(state_dict, strict=False) + print("load") + state_dict = torch.load(pretrained_ckpt, map_location="cpu")["state_dict"] + msg = self.matcher.load_state_dict(state_dict, strict=False) print(msg) - logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") - + logger.info(f"Load '{pretrained_ckpt}' as pretrained checkpoint") + # Testing self.dump_dir = dump_dir - + def configure_optimizers(self): # FIXME: The scheduler did not work properly when `--resume_from_checkpoint` optimizer = build_optimizer(self, self.config) scheduler = build_scheduler(self.config, optimizer) return [optimizer], [scheduler] - + def optimizer_step( - self, epoch, batch_idx, optimizer, optimizer_idx, - optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + self, + epoch, + batch_idx, + optimizer, + optimizer_idx, + optimizer_closure, + on_tpu, + using_native_amp, + using_lbfgs, + ): # learning rate warm up warmup_step = self.config.TRAINER.WARMUP_STEP if self.trainer.global_step < warmup_step: - if self.config.TRAINER.WARMUP_TYPE == 'linear': + if self.config.TRAINER.WARMUP_TYPE == "linear": base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR - lr = base_lr + \ - (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \ - abs(self.config.TRAINER.TRUE_LR - base_lr) + lr = base_lr + ( + self.trainer.global_step / self.config.TRAINER.WARMUP_STEP + ) * abs(self.config.TRAINER.TRUE_LR - base_lr) for pg in optimizer.param_groups: - pg['lr'] = lr - elif self.config.TRAINER.WARMUP_TYPE == 'constant': + pg["lr"] = lr + elif self.config.TRAINER.WARMUP_TYPE == "constant": pass else: - raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}') + raise ValueError( + f"Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}" + ) # update params optimizer.step(closure=optimizer_closure) optimizer.zero_grad() - + def _trainval_inference(self, batch): with self.profiler.profile("Compute coarse supervision"): - compute_supervision_coarse(batch, self.config) - + compute_supervision_coarse(batch, self.config) + with self.profiler.profile("LoFTR"): - self.matcher(batch) - + self.matcher(batch) + with self.profiler.profile("Compute fine supervision"): - compute_supervision_fine(batch, self.config) - + compute_supervision_fine(batch, self.config) + with self.profiler.profile("Compute losses"): - self.loss(batch) - + self.loss(batch) + def _compute_metrics(self, batch): with self.profiler.profile("Copmute metrics"): - compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match - compute_symmetrical_epipolar_errors_offset_bidirectional(batch) # compute epi_errs for offset match - compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair + compute_symmetrical_epipolar_errors( + batch + ) # compute epi_errs for each match + compute_symmetrical_epipolar_errors_offset_bidirectional( + batch + ) # compute epi_errs for offset match + compute_pose_errors( + batch, self.config 
+ ) # compute R_errs, t_errs, pose_errs for each pair - rel_pair_names = list(zip(*batch['pair_names'])) - bs = batch['image0'].size(0) + rel_pair_names = list(zip(*batch["pair_names"])) + bs = batch["image0"].size(0) metrics = { # to filter duplicate pairs caused by DistributedSampler - 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], - 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)], - 'epi_errs_offset': [batch['epi_errs_offset_left'][batch['offset_bids_left'] == b].cpu().numpy() for b in range(bs)], #only consider left side - 'R_errs': batch['R_errs'], - 't_errs': batch['t_errs'], - 'inliers': batch['inliers']} - ret_dict = {'metrics': metrics} + "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)], + "epi_errs": [ + batch["epi_errs"][batch["m_bids"] == b].cpu().numpy() + for b in range(bs) + ], + "epi_errs_offset": [ + batch["epi_errs_offset_left"][batch["offset_bids_left"] == b] + .cpu() + .numpy() + for b in range(bs) + ], # only consider left side + "R_errs": batch["R_errs"], + "t_errs": batch["t_errs"], + "inliers": batch["inliers"], + } + ret_dict = {"metrics": metrics} return ret_dict, rel_pair_names - - + def training_step(self, batch, batch_idx): self._trainval_inference(batch) - + # logging - if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0: + if ( + self.trainer.global_rank == 0 + and self.global_step % self.trainer.log_every_n_steps == 0 + ): # scalars - for k, v in batch['loss_scalars'].items(): - if not k.startswith('loss_flow') and not k.startswith('conf_'): - self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step) - - #log offset_loss and conf for each layer and level - layer_num=self.loftr_cfg['coarse']['layer_num'] + for k, v in batch["loss_scalars"].items(): + if not k.startswith("loss_flow") and not k.startswith("conf_"): + self.logger.experiment.add_scalar(f"train/{k}", v, self.global_step) + + # log offset_loss and conf for each layer and level + layer_num = self.loftr_cfg["coarse"]["layer_num"] for layer_index in range(layer_num): - log_title='layer_'+str(layer_index) - self.logger.experiment.add_scalar(log_title+'/offset_loss', batch['loss_scalars']['loss_flow_'+str(layer_index)], self.global_step) - self.logger.experiment.add_scalar(log_title+'/conf_', batch['loss_scalars']['conf_'+str(layer_index)],self.global_step) - + log_title = "layer_" + str(layer_index) + self.logger.experiment.add_scalar( + log_title + "/offset_loss", + batch["loss_scalars"]["loss_flow_" + str(layer_index)], + self.global_step, + ) + self.logger.experiment.add_scalar( + log_title + "/conf_", + batch["loss_scalars"]["conf_" + str(layer_index)], + self.global_step, + ) + # net-params - if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE == 'sinkhorn': + if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE == "sinkhorn": self.logger.experiment.add_scalar( - f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step) + f"skh_bin_score", + self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, + self.global_step, + ) # figures if self.config.TRAINER.ENABLE_PLOTTING: - compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match - figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE) + compute_symmetrical_epipolar_errors( + batch + ) # compute epi_errs for each match + figures = make_matching_figures( + batch, self.config, self.config.TRAINER.PLOT_MODE + ) for k, v in figures.items(): - 
self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step) + self.logger.experiment.add_figure( + f"train_match/{k}", v, self.global_step + ) - #plot offset - if self.global_step%200==0: + # plot offset + if self.global_step % 200 == 0: compute_symmetrical_epipolar_errors_offset_bidirectional(batch) - figures_left = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_left') - figures_right = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right') + figures_left = make_matching_figures_offset( + batch, self.config, self.config.TRAINER.PLOT_MODE, side="_left" + ) + figures_right = make_matching_figures_offset( + batch, self.config, self.config.TRAINER.PLOT_MODE, side="_right" + ) for k, v in figures_left.items(): - self.logger.experiment.add_figure(f'train_offset/{k}'+'_left', v, self.global_step) - figures = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right') + self.logger.experiment.add_figure( + f"train_offset/{k}" + "_left", v, self.global_step + ) + figures = make_matching_figures_offset( + batch, self.config, self.config.TRAINER.PLOT_MODE, side="_right" + ) for k, v in figures_right.items(): - self.logger.experiment.add_figure(f'train_offset/{k}'+'_right', v, self.global_step) - - return {'loss': batch['loss']} + self.logger.experiment.add_figure( + f"train_offset/{k}" + "_right", v, self.global_step + ) + + return {"loss": batch["loss"]} def training_epoch_end(self, outputs): - avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + avg_loss = torch.stack([x["loss"] for x in outputs]).mean() if self.trainer.global_rank == 0: self.logger.experiment.add_scalar( - 'train/avg_loss_on_epoch', avg_loss, - global_step=self.current_epoch) - + "train/avg_loss_on_epoch", avg_loss, global_step=self.current_epoch + ) + def validation_step(self, batch, batch_idx): self._trainval_inference(batch) - - ret_dict, _ = self._compute_metrics(batch) #this func also compute the epi_errors - + + ret_dict, _ = self._compute_metrics( + batch + ) # this func also compute the epi_errors + val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1) figures = {self.config.TRAINER.PLOT_MODE: []} figures_offset = {self.config.TRAINER.PLOT_MODE: []} if batch_idx % val_plot_interval == 0: - figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE) - figures_offset=make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,'_left') + figures = make_matching_figures( + batch, self.config, mode=self.config.TRAINER.PLOT_MODE + ) + figures_offset = make_matching_figures_offset( + batch, self.config, self.config.TRAINER.PLOT_MODE, "_left" + ) return { **ret_dict, - 'loss_scalars': batch['loss_scalars'], - 'figures': figures, - 'figures_offset_left':figures_offset + "loss_scalars": batch["loss_scalars"], + "figures": figures, + "figures_offset_left": figures_offset, } - + def validation_epoch_end(self, outputs): # handle multiple validation sets - multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs + multi_outputs = ( + [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs + ) multi_val_metrics = defaultdict(list) - + for valset_idx, outputs in enumerate(multi_outputs): # since pl performs sanity_check at the very begining of the training cur_epoch = self.trainer.current_epoch - if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check: + if ( + not 
self.trainer.resume_from_checkpoint + and self.trainer.running_sanity_check + ): cur_epoch = -1 # 1. loss_scalars: dict of list, on cpu - _loss_scalars = [o['loss_scalars'] for o in outputs] - loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]} + _loss_scalars = [o["loss_scalars"] for o in outputs] + loss_scalars = { + k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) + for k in _loss_scalars[0] + } # 2. val metrics: dict of list, numpy - _metrics = [o['metrics'] for o in outputs] - metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} - # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 - val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + _metrics = [o["metrics"] for o in outputs] + metrics = { + k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) + for k in _metrics[0] + } + # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 + val_metrics_4tb = aggregate_metrics( + metrics, self.config.TRAINER.EPI_ERR_THR + ) for thr in [5, 10, 20]: - multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}']) - + multi_val_metrics[f"auc@{thr}"].append(val_metrics_4tb[f"auc@{thr}"]) + # 3. figures - _figures = [o['figures'] for o in outputs] - figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]} + _figures = [o["figures"] for o in outputs] + figures = { + k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) + for k in _figures[0] + } # tensorboard records only on rank 0 if self.trainer.global_rank == 0: for k, v in loss_scalars.items(): mean_v = torch.stack(v).mean() - self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch) + self.logger.experiment.add_scalar( + f"val_{valset_idx}/avg_{k}", mean_v, global_step=cur_epoch + ) for k, v in val_metrics_4tb.items(): - self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch) - + self.logger.experiment.add_scalar( + f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch + ) + for k, v in figures.items(): if self.trainer.global_rank == 0: for plot_idx, fig in enumerate(v): self.logger.experiment.add_figure( - f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True) - plt.close('all') + f"val_match_{valset_idx}/{k}/pair-{plot_idx}", + fig, + cur_epoch, + close=True, + ) + plt.close("all") for thr in [5, 10, 20]: # log on all ranks for ModelCheckpoint callback to work properly - self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this + self.log( + f"auc@{thr}", torch.tensor(np.mean(multi_val_metrics[f"auc@{thr}"])) + ) # ckpt monitors on this def test_step(self, batch, batch_idx): with self.profiler.profile("LoFTR"): @@ -238,39 +329,46 @@ class PL_ASpanFormer(pl.LightningModule): with self.profiler.profile("dump_results"): if self.dump_dir is not None: # dump results for further analysis - keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'} - pair_names = list(zip(*batch['pair_names'])) - bs = batch['image0'].shape[0] + keys_to_save = {"mkpts0_f", "mkpts1_f", "mconf", "epi_errs"} + pair_names = list(zip(*batch["pair_names"])) + bs = batch["image0"].shape[0] dumps = [] for b_id in range(bs): item = {} - mask = batch['m_bids'] == b_id - item['pair_names'] = pair_names[b_id] - item['identifier'] = '#'.join(rel_pair_names[b_id]) + mask = batch["m_bids"] == b_id + 
item["pair_names"] = pair_names[b_id] + item["identifier"] = "#".join(rel_pair_names[b_id]) for key in keys_to_save: item[key] = batch[key][mask].cpu().numpy() - for key in ['R_errs', 't_errs', 'inliers']: + for key in ["R_errs", "t_errs", "inliers"]: item[key] = batch[key][b_id] dumps.append(item) - ret_dict['dumps'] = dumps + ret_dict["dumps"] = dumps return ret_dict def test_epoch_end(self, outputs): # metrics: dict of list, numpy - _metrics = [o['metrics'] for o in outputs] - metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + _metrics = [o["metrics"] for o in outputs] + metrics = { + k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) + for k in _metrics[0] + } # [{key: [{...}, *#bs]}, *#batch] if self.dump_dir is not None: Path(self.dump_dir).mkdir(parents=True, exist_ok=True) - _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch] + _dumps = flattenList([o["dumps"] for o in outputs]) # [{...}, #bs*#batch] dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch] - logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}') + logger.info( + f"Prediction and evaluation results will be saved to: {self.dump_dir}" + ) if self.trainer.global_rank == 0: print(self.profiler.summary()) - val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) - logger.info('\n' + pprint.pformat(val_metrics_4tb)) + val_metrics_4tb = aggregate_metrics( + metrics, self.config.TRAINER.EPI_ERR_THR + ) + logger.info("\n" + pprint.pformat(val_metrics_4tb)) if self.dump_dir is not None: - np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps) + np.save(Path(self.dump_dir) / "LoFTR_pred_eval", dumps) diff --git a/third_party/ASpanFormer/src/losses/aspan_loss.py b/third_party/ASpanFormer/src/losses/aspan_loss.py index 0cca52b36fc997415937969f26caba8c41ac2b8e..dc0f33391b95b6f4f39f673ebc07f6991a00491f 100644 --- a/third_party/ASpanFormer/src/losses/aspan_loss.py +++ b/third_party/ASpanFormer/src/losses/aspan_loss.py @@ -3,48 +3,55 @@ from loguru import logger import torch import torch.nn as nn + class ASpanLoss(nn.Module): def __init__(self, config): super().__init__() self.config = config # config under the global namespace - self.loss_config = config['aspan']['loss'] - self.match_type = self.config['aspan']['match_coarse']['match_type'] - self.sparse_spvs = self.config['aspan']['match_coarse']['sparse_spvs'] - self.flow_weight=self.config['aspan']['loss']['flow_weight'] + self.loss_config = config["aspan"]["loss"] + self.match_type = self.config["aspan"]["match_coarse"]["match_type"] + self.sparse_spvs = self.config["aspan"]["match_coarse"]["sparse_spvs"] + self.flow_weight = self.config["aspan"]["loss"]["flow_weight"] # coarse-level - self.correct_thr = self.loss_config['fine_correct_thr'] - self.c_pos_w = self.loss_config['pos_weight'] - self.c_neg_w = self.loss_config['neg_weight'] + self.correct_thr = self.loss_config["fine_correct_thr"] + self.c_pos_w = self.loss_config["pos_weight"] + self.c_neg_w = self.loss_config["neg_weight"] # fine-level - self.fine_type = self.loss_config['fine_type'] - - def compute_flow_loss(self,coarse_corr_gt,flow_list,h0,w0,h1,w1): - #coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]] - #flow_list: [L,B,H,W,4] - loss1=self.flow_loss_worker(flow_list[0],coarse_corr_gt[0],coarse_corr_gt[1],coarse_corr_gt[2],w1) - loss2=self.flow_loss_worker(flow_list[1],coarse_corr_gt[0],coarse_corr_gt[2],coarse_corr_gt[1],w0) - total_loss=(loss1+loss2)/2 + 
self.fine_type = self.loss_config["fine_type"] + + def compute_flow_loss(self, coarse_corr_gt, flow_list, h0, w0, h1, w1): + # coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]] + # flow_list: [L,B,H,W,4] + loss1 = self.flow_loss_worker( + flow_list[0], coarse_corr_gt[0], coarse_corr_gt[1], coarse_corr_gt[2], w1 + ) + loss2 = self.flow_loss_worker( + flow_list[1], coarse_corr_gt[0], coarse_corr_gt[2], coarse_corr_gt[1], w0 + ) + total_loss = (loss1 + loss2) / 2 return total_loss - def flow_loss_worker(self,flow,batch_indicies,self_indicies,cross_indicies,w): - bs,layer_num=flow.shape[1],flow.shape[0] - flow=flow.view(layer_num,bs,-1,4) - gt_flow=torch.stack([cross_indicies%w,cross_indicies//w],dim=1) + def flow_loss_worker(self, flow, batch_indicies, self_indicies, cross_indicies, w): + bs, layer_num = flow.shape[1], flow.shape[0] + flow = flow.view(layer_num, bs, -1, 4) + gt_flow = torch.stack([cross_indicies % w, cross_indicies // w], dim=1) - total_loss_list=[] + total_loss_list = [] for layer_index in range(layer_num): - cur_flow_list=flow[layer_index] - spv_flow=cur_flow_list[batch_indicies,self_indicies][:,:2] - spv_conf=cur_flow_list[batch_indicies,self_indicies][:,2:]#[#coarse,2] - l2_flow_dis=((gt_flow-spv_flow)**2) #[#coarse,2] - total_loss=(spv_conf+torch.exp(-spv_conf)*l2_flow_dis) #[#coarse,2] + cur_flow_list = flow[layer_index] + spv_flow = cur_flow_list[batch_indicies, self_indicies][:, :2] + spv_conf = cur_flow_list[batch_indicies, self_indicies][ + :, 2: + ] # [#coarse,2] + l2_flow_dis = (gt_flow - spv_flow) ** 2 # [#coarse,2] + total_loss = spv_conf + torch.exp(-spv_conf) * l2_flow_dis # [#coarse,2] total_loss_list.append(total_loss.mean()) - total_loss=torch.stack(total_loss_list,dim=-1)*self.flow_weight + total_loss = torch.stack(total_loss_list, dim=-1) * self.flow_weight return total_loss - + def compute_coarse_loss(self, conf, conf_gt, weight=None): - """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt. + """Point-wise CE / Focal Loss with 0 / 1 confidence as gt. Args: conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1) conf_gt (torch.Tensor): (N, HW0, HW1) @@ -56,38 +63,44 @@ class ASpanLoss(nn.Module): if not pos_mask.any(): # assign a wrong gt pos_mask[0, 0, 0] = True if weight is not None: - weight[0, 0, 0] = 0. - c_pos_w = 0. + weight[0, 0, 0] = 0.0 + c_pos_w = 0.0 if not neg_mask.any(): neg_mask[0, 0, 0] = True if weight is not None: - weight[0, 0, 0] = 0. - c_neg_w = 0. - - if self.loss_config['coarse_type'] == 'cross_entropy': - assert not self.sparse_spvs, 'Sparse Supervision for cross-entropy not implemented!' - conf = torch.clamp(conf, 1e-6, 1-1e-6) - loss_pos = - torch.log(conf[pos_mask]) - loss_neg = - torch.log(1 - conf[neg_mask]) + weight[0, 0, 0] = 0.0 + c_neg_w = 0.0 + + if self.loss_config["coarse_type"] == "cross_entropy": + assert ( + not self.sparse_spvs + ), "Sparse Supervision for cross-entropy not implemented!" 
+ conf = torch.clamp(conf, 1e-6, 1 - 1e-6) + loss_pos = -torch.log(conf[pos_mask]) + loss_neg = -torch.log(1 - conf[neg_mask]) if weight is not None: loss_pos = loss_pos * weight[pos_mask] loss_neg = loss_neg * weight[neg_mask] return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() - elif self.loss_config['coarse_type'] == 'focal': - conf = torch.clamp(conf, 1e-6, 1-1e-6) - alpha = self.loss_config['focal_alpha'] - gamma = self.loss_config['focal_gamma'] - + elif self.loss_config["coarse_type"] == "focal": + conf = torch.clamp(conf, 1e-6, 1 - 1e-6) + alpha = self.loss_config["focal_alpha"] + gamma = self.loss_config["focal_gamma"] + if self.sparse_spvs: - pos_conf = conf[:, :-1, :-1][pos_mask] \ - if self.match_type == 'sinkhorn' \ - else conf[pos_mask] - loss_pos = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log() + pos_conf = ( + conf[:, :-1, :-1][pos_mask] + if self.match_type == "sinkhorn" + else conf[pos_mask] + ) + loss_pos = -alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log() # calculate losses for negative samples - if self.match_type == 'sinkhorn': + if self.match_type == "sinkhorn": neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0 - neg_conf = torch.cat([conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0) - loss_neg = - alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log() + neg_conf = torch.cat( + [conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0 + ) + loss_neg = -alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log() else: # These is no dustbin for dual_softmax, so we left unmatchable patches without supervision. # we could also add 'pseudo negtive-samples' @@ -97,32 +110,46 @@ class ASpanLoss(nn.Module): # Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out, # but only through manually setting corresponding regions in sim_matrix to '-inf'. loss_pos = loss_pos * weight[pos_mask] - if self.match_type == 'sinkhorn': + if self.match_type == "sinkhorn": neg_w0 = (weight.sum(-1) != 0)[neg0] neg_w1 = (weight.sum(1) != 0)[neg1] neg_mask = torch.cat([neg_w0, neg_w1], 0) loss_neg = loss_neg[neg_mask] - - loss = c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() \ - if self.match_type == 'sinkhorn' \ - else c_pos_w * loss_pos.mean() + + loss = ( + c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() + if self.match_type == "sinkhorn" + else c_pos_w * loss_pos.mean() + ) return loss # positive and negative elements occupy similar propotions. => more balanced loss weights needed else: # dense supervision (in the case of match_type=='sinkhorn', the dustbin is not supervised.) - loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log() - loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log() + loss_pos = ( + -alpha + * torch.pow(1 - conf[pos_mask], gamma) + * (conf[pos_mask]).log() + ) + loss_neg = ( + -alpha + * torch.pow(conf[neg_mask], gamma) + * (1 - conf[neg_mask]).log() + ) if weight is not None: loss_pos = loss_pos * weight[pos_mask] loss_neg = loss_neg * weight[neg_mask] return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() # each negative element occupy a smaller propotion than positive elements. 
=> higher negative loss weight needed
         else:
-            raise ValueError('Unknown coarse loss: {type}'.format(type=self.loss_config['coarse_type']))
-        
+            raise ValueError(
+                "Unknown coarse loss: {type}".format(
+                    type=self.loss_config["coarse_type"]
+                )
+            )
+
     def compute_fine_loss(self, expec_f, expec_f_gt):
-        if self.fine_type == 'l2_with_std':
+        if self.fine_type == "l2_with_std":
             return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
-        elif self.fine_type == 'l2':
+        elif self.fine_type == "l2":
             return self._compute_fine_loss_l2(expec_f, expec_f_gt)
         else:
             raise NotImplementedError()
@@ -133,9 +160,13 @@ class ASpanLoss(nn.Module):
             expec_f (torch.Tensor): [M, 2]
             expec_f_gt (torch.Tensor): [M, 2]
         """
-        correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+        correct_mask = (
+            torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
+        )
         if correct_mask.sum() == 0:
-            if self.training:  # this seldomly happen when training, since we pad prediction with gt
+            if (
+                self.training
+            ):  # this seldom happens when training, since we pad prediction with gt
                 logger.warning("assign a false supervision to avoid ddp deadlock")
                 correct_mask[0] = True
             else:
@@ -150,20 +181,26 @@ class ASpanLoss(nn.Module):
             expec_f_gt (torch.Tensor): [M, 2]
         """
         # correct_mask tells you which pair to compute fine-loss
-        correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+        correct_mask = (
+            torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
+        )
 
         # use std as weight that measures uncertainty
         std = expec_f[:, 2]
-        inverse_std = 1. / torch.clamp(std, min=1e-10)
-        weight = (inverse_std / torch.mean(inverse_std)).detach()  # avoid minizing loss through increase std
+        inverse_std = 1.0 / torch.clamp(std, min=1e-10)
+        weight = (
+            inverse_std / torch.mean(inverse_std)
+        ).detach()  # avoid minimizing the loss by increasing std
 
         # corner case: no correct coarse match found
         if not correct_mask.any():
-            if self.training:  # this seldomly happen during training, since we pad prediction with gt
-                # sometimes there is not coarse-level gt at all.
+            if (
+                self.training
+            ):  # this seldom happens during training, since we pad prediction with gt
+                # sometimes there is no coarse-level gt at all.
                 logger.warning("assign a false supervision to avoid ddp deadlock")
                 correct_mask[0] = True
-                weight[0] = 0.
+                weight[0] = 0.0
             else:
                 return None
@@ -172,12 +209,15 @@ class ASpanLoss(nn.Module):
         loss = (flow_l2 * weight[correct_mask]).mean()
 
         return loss
-    
+
     @torch.no_grad()
     def compute_c_weight(self, data):
-        """ compute element-wise weights for computing coarse-level loss. """
-        if 'mask0' in data:
-            c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
+        """compute element-wise weights for computing coarse-level loss."""
+        if "mask0" in data:
+            c_weight = (
+                data["mask0"].flatten(-2)[..., None]
+                * data["mask1"].flatten(-2)[:, None]
+            ).float()
         else:
             c_weight = None
         return c_weight
@@ -196,36 +236,54 @@ class ASpanLoss(nn.Module):
         # 1.
coarse-level loss loss_c = self.compute_coarse_loss( - data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn' \ - else data['conf_matrix'], - data['conf_matrix_gt'], - weight=c_weight) - loss = loss_c * self.loss_config['coarse_weight'] + data["conf_matrix_with_bin"] + if self.sparse_spvs and self.match_type == "sinkhorn" + else data["conf_matrix"], + data["conf_matrix_gt"], + weight=c_weight, + ) + loss = loss_c * self.loss_config["coarse_weight"] loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()}) # 2. fine-level loss - loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt']) + loss_f = self.compute_fine_loss(data["expec_f"], data["expec_f_gt"]) if loss_f is not None: - loss += loss_f * self.loss_config['fine_weight'] - loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()}) + loss += loss_f * self.loss_config["fine_weight"] + loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()}) else: assert self.training is False - loss_scalars.update({'loss_f': torch.tensor(1.)}) # 1 is the upper bound - + loss_scalars.update({"loss_f": torch.tensor(1.0)}) # 1 is the upper bound + # 3. flow loss - coarse_corr=[data['spv_b_ids'],data['spv_i_ids'],data['spv_j_ids']] - loss_flow = self.compute_flow_loss(coarse_corr,data['predict_flow'],\ - data['hw0_c'][0],data['hw0_c'][1],data['hw1_c'][0],data['hw1_c'][1]) - loss_flow=loss_flow*self.flow_weight - for index,loss_off in enumerate(loss_flow): - loss_scalars.update({'loss_flow_'+str(index): loss_off.clone().detach().cpu()}) # 1 is the upper bound - conf=data['predict_flow'][0][:,:,:,:,2:] - layer_num=conf.shape[0] + coarse_corr = [data["spv_b_ids"], data["spv_i_ids"], data["spv_j_ids"]] + loss_flow = self.compute_flow_loss( + coarse_corr, + data["predict_flow"], + data["hw0_c"][0], + data["hw0_c"][1], + data["hw1_c"][0], + data["hw1_c"][1], + ) + loss_flow = loss_flow * self.flow_weight + for index, loss_off in enumerate(loss_flow): + loss_scalars.update( + {"loss_flow_" + str(index): loss_off.clone().detach().cpu()} + ) # 1 is the upper bound + conf = data["predict_flow"][0][:, :, :, :, 2:] + layer_num = conf.shape[0] for layer_index in range(layer_num): - loss_scalars.update({'conf_'+str(layer_index): conf[layer_index].mean().clone().detach().cpu()}) # 1 is the upper bound - - - loss+=loss_flow.sum() - #print((loss_c * self.loss_config['coarse_weight']).data,loss_flow.data) - loss_scalars.update({'loss': loss.clone().detach().cpu()}) + loss_scalars.update( + { + "conf_" + + str(layer_index): conf[layer_index] + .mean() + .clone() + .detach() + .cpu() + } + ) # 1 is the upper bound + + loss += loss_flow.sum() + # print((loss_c * self.loss_config['coarse_weight']).data,loss_flow.data) + loss_scalars.update({"loss": loss.clone().detach().cpu()}) data.update({"loss": loss, "loss_scalars": loss_scalars}) diff --git a/third_party/ASpanFormer/src/optimizers/__init__.py b/third_party/ASpanFormer/src/optimizers/__init__.py index e1db2285352586c250912bdd2c4ae5029620ab5f..e4e36c22e00217deccacd589f8924b2f74589456 100644 --- a/third_party/ASpanFormer/src/optimizers/__init__.py +++ b/third_party/ASpanFormer/src/optimizers/__init__.py @@ -7,9 +7,13 @@ def build_optimizer(model, config): lr = config.TRAINER.TRUE_LR if name == "adam": - return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY) + return torch.optim.Adam( + model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY + ) elif name == "adamw": - return torch.optim.AdamW(model.parameters(), lr=lr, 
weight_decay=config.TRAINER.ADAMW_DECAY) + return torch.optim.AdamW( + model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY + ) else: raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") @@ -24,18 +28,27 @@ def build_scheduler(config, optimizer): 'frequency': x, (optional) } """ - scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} + scheduler = {"interval": config.TRAINER.SCHEDULER_INTERVAL} name = config.TRAINER.SCHEDULER - if name == 'MultiStepLR': + if name == "MultiStepLR": scheduler.update( - {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) - elif name == 'CosineAnnealing': + { + "scheduler": MultiStepLR( + optimizer, + config.TRAINER.MSLR_MILESTONES, + gamma=config.TRAINER.MSLR_GAMMA, + ) + } + ) + elif name == "CosineAnnealing": scheduler.update( - {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) - elif name == 'ExponentialLR': + {"scheduler": CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)} + ) + elif name == "ExponentialLR": scheduler.update( - {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) + {"scheduler": ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)} + ) else: raise NotImplementedError() diff --git a/third_party/ASpanFormer/src/utils/augment.py b/third_party/ASpanFormer/src/utils/augment.py index d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d..068751c6c07091bbaed76debd43a73155f61b9bd 100644 --- a/third_party/ASpanFormer/src/utils/augment.py +++ b/third_party/ASpanFormer/src/utils/augment.py @@ -7,16 +7,21 @@ class DarkAug(object): """ def __init__(self) -> None: - self.augmentor = A.Compose([ - A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), - A.Blur(p=0.1, blur_limit=(3, 9)), - A.MotionBlur(p=0.2, blur_limit=(3, 25)), - A.RandomGamma(p=0.1, gamma_limit=(15, 65)), - A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) - ], p=0.75) + self.augmentor = A.Compose( + [ + A.RandomBrightnessContrast( + p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3) + ), + A.Blur(p=0.1, blur_limit=(3, 9)), + A.MotionBlur(p=0.2, blur_limit=(3, 25)), + A.RandomGamma(p=0.1, gamma_limit=(15, 65)), + A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)), + ], + p=0.75, + ) def __call__(self, x): - return self.augmentor(image=x)['image'] + return self.augmentor(image=x)["image"] class MobileAug(object): @@ -25,31 +30,36 @@ class MobileAug(object): """ def __init__(self): - self.augmentor = A.Compose([ - A.MotionBlur(p=0.25), - A.ColorJitter(p=0.5), - A.RandomRain(p=0.1), # random occlusion - A.RandomSunFlare(p=0.1), - A.JpegCompression(p=0.25), - A.ISONoise(p=0.25) - ], p=1.0) + self.augmentor = A.Compose( + [ + A.MotionBlur(p=0.25), + A.ColorJitter(p=0.5), + A.RandomRain(p=0.1), # random occlusion + A.RandomSunFlare(p=0.1), + A.JpegCompression(p=0.25), + A.ISONoise(p=0.25), + ], + p=1.0, + ) def __call__(self, x): - return self.augmentor(image=x)['image'] + return self.augmentor(image=x)["image"] def build_augmentor(method=None, **kwargs): if method is not None: - raise NotImplementedError('Using of augmentation functions are not supported yet!') - if method == 'dark': + raise NotImplementedError( + "Using of augmentation functions are not supported yet!" 
+ ) + if method == "dark": return DarkAug() - elif method == 'mobile': + elif method == "mobile": return MobileAug() elif method is None: return None else: - raise ValueError(f'Invalid augmentation method: {method}') + raise ValueError(f"Invalid augmentation method: {method}") -if __name__ == '__main__': - augmentor = build_augmentor('FDA') +if __name__ == "__main__": + augmentor = build_augmentor("FDA") diff --git a/third_party/ASpanFormer/src/utils/comm.py b/third_party/ASpanFormer/src/utils/comm.py index 26ec9517cc47e224430106d8ae9aa99a3fe49167..9f578cda8933cc358934c645fcf413c63ab4d79d 100644 --- a/third_party/ASpanFormer/src/utils/comm.py +++ b/third_party/ASpanFormer/src/utils/comm.py @@ -98,11 +98,11 @@ def _serialize_to_tensor(data, group): device = torch.device("cpu" if backend == "gloo" else "cuda") buffer = pickle.dumps(data) - if len(buffer) > 1024 ** 3: + if len(buffer) > 1024**3: logger = logging.getLogger(__name__) logger.warning( "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( - get_rank(), len(buffer) / (1024 ** 3), device + get_rank(), len(buffer) / (1024**3), device ) ) storage = torch.ByteStorage.from_buffer(buffer) @@ -122,7 +122,8 @@ def _pad_to_largest_tensor(tensor, group): ), "comm.gather/all_gather must be called from ranks within the given group!" local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) size_list = [ - torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) + torch.zeros([1], dtype=torch.int64, device=tensor.device) + for _ in range(world_size) ] dist.all_gather(size_list, local_size, group=group) @@ -133,7 +134,9 @@ def _pad_to_largest_tensor(tensor, group): # we pad the tensor because torch all_gather does not support # gathering tensors of different shapes if local_size != max_size: - padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) + padding = torch.zeros( + (max_size - local_size,), dtype=torch.uint8, device=tensor.device + ) tensor = torch.cat((tensor, padding), dim=0) return size_list, tensor @@ -164,7 +167,8 @@ def all_gather(data, group=None): # receiving Tensor from all ranks tensor_list = [ - torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) + for _ in size_list ] dist.all_gather(tensor_list, tensor, group=group) @@ -205,7 +209,8 @@ def gather(data, dst=0, group=None): if rank == dst: max_size = max(size_list) tensor_list = [ - torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) + for _ in size_list ] dist.gather(tensor, tensor_list, dst=dst, group=group) @@ -228,7 +233,7 @@ def shared_random_seed(): All workers must call this function, otherwise it will deadlock. """ - ints = np.random.randint(2 ** 31) + ints = np.random.randint(2**31) all_ints = all_gather(ints) return all_ints[0] diff --git a/third_party/ASpanFormer/src/utils/dataloader.py b/third_party/ASpanFormer/src/utils/dataloader.py index 6da37b880a290c2bb3ebb028d0c8dab592acc5c1..b980dfd344714870ecdacd9e7a9742f51c3ee14d 100644 --- a/third_party/ASpanFormer/src/utils/dataloader.py +++ b/third_party/ASpanFormer/src/utils/dataloader.py @@ -3,21 +3,22 @@ import numpy as np # --- PL-DATAMODULE --- + def get_local_split(items: list, world_size: int, rank: int, seed: int): - """ The local rank only loads a split of the dataset. 
""" + """The local rank only loads a split of the dataset.""" n_items = len(items) items_permute = np.random.RandomState(seed).permutation(items) if n_items % world_size == 0: padded_items = items_permute else: padding = np.random.RandomState(seed).choice( - items, - world_size - (n_items % world_size), - replace=True) + items, world_size - (n_items % world_size), replace=True + ) padded_items = np.concatenate([items_permute, padding]) - assert len(padded_items) % world_size == 0, \ - f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' + assert ( + len(padded_items) % world_size == 0 + ), f"len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}" n_per_rank = len(padded_items) // world_size - local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] + local_items = padded_items[n_per_rank * rank : n_per_rank * (rank + 1)] return local_items diff --git a/third_party/ASpanFormer/src/utils/dataset.py b/third_party/ASpanFormer/src/utils/dataset.py index 209bf554acc20e33ea89eb9e7024ba68d0b3a30b..1881446fd69aedb520ae669100cd2a3c2d143a18 100644 --- a/third_party/ASpanFormer/src/utils/dataset.py +++ b/third_party/ASpanFormer/src/utils/dataset.py @@ -15,8 +15,11 @@ except Exception: # --- DATA IO --- + def load_array_from_s3( - path, client, cv_type, + path, + client, + cv_type, use_h5py=False, ): byte_str = client.Get(path) @@ -26,7 +29,7 @@ def load_array_from_s3( data = cv2.imdecode(raw_array, cv_type) else: f = io.BytesIO(byte_str) - data = np.array(h5py.File(f, 'r')['/depth']) + data = np.array(h5py.File(f, "r")["/depth"]) except Exception as ex: print(f"==> Data loading failure: {path}") raise ex @@ -36,9 +39,8 @@ def load_array_from_s3( def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): - cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ - else cv2.IMREAD_COLOR - if str(path).startswith('s3://'): + cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR + if str(path).startswith("s3://"): image = load_array_from_s3(str(path), client, cv_type) else: image = cv2.imread(str(path), cv_type) @@ -54,7 +56,7 @@ def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): def get_resized_wh(w, h, resize=None): if resize is not None: # resize the longer edge scale = resize / max(h, w) - w_new, h_new = int(round(w*scale)), int(round(h*scale)) + w_new, h_new = int(round(w * scale)), int(round(h * scale)) else: w_new, h_new = w, h return w_new, h_new @@ -69,20 +71,22 @@ def get_divisible_wh(w, h, df=None): def pad_bottom_right(inp, pad_size, ret_mask=False): - assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + assert isinstance(pad_size, int) and pad_size >= max( + inp.shape[-2:] + ), f"{pad_size} < {max(inp.shape[-2:])}" mask = None if inp.ndim == 2: padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) - padded[:inp.shape[0], :inp.shape[1]] = inp + padded[: inp.shape[0], : inp.shape[1]] = inp if ret_mask: mask = np.zeros((pad_size, pad_size), dtype=bool) - mask[:inp.shape[0], :inp.shape[1]] = True + mask[: inp.shape[0], : inp.shape[1]] = True elif inp.ndim == 3: padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) - padded[:, :inp.shape[1], :inp.shape[2]] = inp + padded[:, : inp.shape[1], : inp.shape[2]] = inp if ret_mask: mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) - mask[:, :inp.shape[1], :inp.shape[2]] = True + mask[:, : inp.shape[1], : inp.shape[2]] = True else: raise 
NotImplementedError() return padded, mask @@ -90,6 +94,7 @@ def pad_bottom_right(inp, pad_size, ret_mask=False): # --- MEGADEPTH --- + def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): """ Args: @@ -99,7 +104,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No Returns: image (torch.tensor): (1, h, w) mask (torch.tensor): (h, w) - scale (torch.tensor): [w/w_new, h/h_new] + scale (torch.tensor): [w/w_new, h/h_new] """ # read image image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) @@ -110,7 +115,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No w_new, h_new = get_divisible_wh(w_new, h_new, df) image = cv2.resize(image, (w_new, h_new)) - scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float) if padding: # padding pad_to = max(h_new, w_new) @@ -118,7 +123,9 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No else: mask = None - image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + image = ( + torch.from_numpy(image).float()[None] / 255 + ) # (h, w) -> (1, h, w) and normalized if mask is not None: mask = torch.from_numpy(mask) @@ -126,10 +133,10 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No def read_megadepth_depth(path, pad_to=None): - if str(path).startswith('s3://'): + if str(path).startswith("s3://"): depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) else: - depth = np.array(h5py.File(path, 'r')['depth']) + depth = np.array(h5py.File(path, "r")["depth"]) if pad_to is not None: depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) depth = torch.from_numpy(depth).float() # (h, w) @@ -138,6 +145,7 @@ def read_megadepth_depth(path, pad_to=None): # --- ScanNet --- + def read_scannet_gray(path, resize=(640, 480), augment_fn=None): """ Args: @@ -146,7 +154,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None): Returns: image (torch.tensor): (1, h, w) mask (torch.tensor): (h, w) - scale (torch.tensor): [w/w_new, h/h_new] + scale (torch.tensor): [w/w_new, h/h_new] """ # read and resize image image = imread_gray(path, augment_fn) @@ -158,7 +166,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None): def read_scannet_depth(path): - if str(path).startswith('s3://'): + if str(path).startswith("s3://"): depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) else: depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) @@ -168,55 +176,57 @@ def read_scannet_depth(path): def read_scannet_pose(path): - """ Read ScanNet's Camera2World pose and transform it to World2Camera. - + """Read ScanNet's Camera2World pose and transform it to World2Camera. + Returns: pose_w2c (np.ndarray): (4, 4) """ - cam2world = np.loadtxt(path, delimiter=' ') + cam2world = np.loadtxt(path, delimiter=" ") world2cam = inv(cam2world) return world2cam def read_scannet_intrinsic(path): - """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 
- """ - intrinsic = np.loadtxt(path, delimiter=' ') + """Read ScanNet's intrinsic matrix and return the 3x3 matrix.""" + intrinsic = np.loadtxt(path, delimiter=" ") return intrinsic[:-1, :-1] -def read_gl3d_gray(path,resize): - img=cv2.resize(cv2.imread(path,cv2.IMREAD_GRAYSCALE),(int(resize),int(resize))) - img = torch.from_numpy(img).float()[None] / 255 # (h, w) -> (1, h, w) and normalized +def read_gl3d_gray(path, resize): + img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (int(resize), int(resize))) + img = ( + torch.from_numpy(img).float()[None] / 255 + ) # (h, w) -> (1, h, w) and normalized return img + def read_gl3d_depth(file_path): - with open(file_path, 'rb') as fin: + with open(file_path, "rb") as fin: color = None width = None height = None scale = None data_type = None - header = str(fin.readline().decode('UTF-8')).rstrip() - if header == 'PF': + header = str(fin.readline().decode("UTF-8")).rstrip() + if header == "PF": color = True - elif header == 'Pf': + elif header == "Pf": color = False else: - raise Exception('Not a PFM file.') - dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8')) + raise Exception("Not a PFM file.") + dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8")) if dim_match: width, height = map(int, dim_match.groups()) else: - raise Exception('Malformed PFM header.') - scale = float((fin.readline().decode('UTF-8')).rstrip()) + raise Exception("Malformed PFM header.") + scale = float((fin.readline().decode("UTF-8")).rstrip()) if scale < 0: # little-endian - data_type = ' 0 else 0) precs.append(np.mean(prec_) if len(prec_) > 0 else 0) if ret_dict: - return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} if not offset else {f'prec_flow@{t:.0e}': prec for t, prec in zip(thresholds, precs)} + return ( + {f"prec@{t:.0e}": prec for t, prec in zip(thresholds, precs)} + if not offset + else {f"prec_flow@{t:.0e}": prec for t, prec in zip(thresholds, precs)} + ) else: return precs def aggregate_metrics(metrics, epi_err_thr=5e-4): - """ Aggregate metrics for the whole dataset: + """Aggregate metrics for the whole dataset: (This method should be called once per dataset) 1. AUC of the pose error (angular) at the threshold [5, 10, 20] 2. 
Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) """ # filter duplicates - unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) + unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics["identifiers"])) unq_ids = list(unq_ids.values()) - logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') + logger.info(f"Aggregating metrics over {len(unq_ids)} unique items...") # pose auc angular_thresholds = [5, 10, 20] - pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] + pose_errors = np.max(np.stack([metrics["R_errs"], metrics["t_errs"]]), axis=0)[ + unq_ids + ] aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) # matching precision dist_thresholds = [epi_err_thr] - precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) - - #offset precision + precs = epidist_prec( + np.array(metrics["epi_errs"], dtype=object)[unq_ids], dist_thresholds, True + ) # (prec@err_thr) + + # offset precision try: - precs_offset = epidist_prec(np.array(metrics['epi_errs_offset'], dtype=object)[unq_ids], [2e-3], True,offset=True) - return {**aucs, **precs,**precs_offset} + precs_offset = epidist_prec( + np.array(metrics["epi_errs_offset"], dtype=object)[unq_ids], + [2e-3], + True, + offset=True, + ) + return {**aucs, **precs, **precs_offset} except: return {**aucs, **precs} diff --git a/third_party/ASpanFormer/src/utils/misc.py b/third_party/ASpanFormer/src/utils/misc.py index 25e4433f5ffa41adc4c0435cfe2b5696e43b58b3..d9b6a4a5f5920cde89bdecbf2a444aaea8ff51f3 100644 --- a/third_party/ASpanFormer/src/utils/misc.py +++ b/third_party/ASpanFormer/src/utils/misc.py @@ -11,6 +11,7 @@ from pytorch_lightning.utilities import rank_zero_only import cv2 import numpy as np + def lower_config(yacs_cfg): if not isinstance(yacs_cfg, CN): return yacs_cfg @@ -25,7 +26,7 @@ def upper_config(dict_cfg): def log_on(condition, message, level): if condition: - assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] + assert level in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"] logger.log(level, message) @@ -35,32 +36,35 @@ def get_rank_zero_only_logger(logger: _Logger): else: for _level in logger._core.levels.keys(): level = _level.lower() - setattr(logger, level, - lambda x: None) + setattr(logger, level, lambda x: None) logger._log = lambda x: None return logger def setup_gpus(gpus: Union[str, int]) -> int: - """ A temporary fix for pytorch-lighting 1.3.x """ + """A temporary fix for pytorch-lighting 1.3.x""" gpus = str(gpus) gpu_ids = [] - - if ',' not in gpus: + + if "," not in gpus: n_gpus = int(gpus) return n_gpus if n_gpus != -1 else torch.cuda.device_count() else: - gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] - + gpu_ids = [i.strip() for i in gpus.split(",") if i != ""] + # setup environment variables - visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") if visible_devices is None: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) - visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') - logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids) + visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + logger.warning( + f"[Temporary Fix] manually set 
CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}" + ) else: - logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') + logger.warning( + "[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process." + ) return len(gpu_ids) @@ -71,11 +75,11 @@ def flattenList(x): @contextlib.contextmanager def tqdm_joblib(tqdm_object): """Context manager to patch joblib to report into tqdm progress bar given as argument - + Usage: with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) - + When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) ret_vals = Parallel(n_jobs=args.world_size)( delayed(lambda x: _compute_cov_score(pid, *x))(param) @@ -84,6 +88,7 @@ def tqdm_joblib(tqdm_object): total=len(image_ids)*(len(image_ids)-1)/2)) Src: https://stackoverflow.com/a/58936697 """ + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -101,39 +106,79 @@ def tqdm_joblib(tqdm_object): tqdm_object.close() -def draw_points(img,points,color=(0,255,0),radius=3): +def draw_points(img, points, color=(0, 255, 0), radius=3): dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])] for i in range(points.shape[0]): - cv2.circle(img, dp[i],radius=radius,color=color) + cv2.circle(img, dp[i], radius=radius, color=color) return img - -def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None): + +def draw_match( + img1, + img2, + corr1, + corr2, + inlier=[True], + color=None, + radius1=1, + radius2=1, + resize=None, +): if resize is not None: - scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]] - img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA) - corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis] - corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])] - corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])] + scale1, scale2 = [img1.shape[1] / resize[0], img1.shape[0] / resize[1]], [ + img2.shape[1] / resize[0], + img2.shape[0] / resize[1], + ] + img1, img2 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA), cv2.resize( + img2, resize, interpolation=cv2.INTER_AREA + ) + corr1, corr2 = ( + corr1 / np.asarray(scale1)[np.newaxis], + corr2 / np.asarray(scale2)[np.newaxis], + ) + corr1_key = [ + cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0]) + ] + corr2_key = [ + cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0]) + ] assert len(corr1) == len(corr2) draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))] if color is None: - color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier] - if len(color)==1: - display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None, - matchColor=color[0], - singlePointColor=color[0], - flags=4 - ) + color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier] + if len(color) == 1: + display = cv2.drawMatches( + img1, + corr1_key, + img2, + corr2_key, + draw_matches, + None, + matchColor=color[0], + singlePointColor=color[0], + flags=4, + ) else: - 
height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1] - display=np.zeros([height,width,3],np.uint8) - display[:img1.shape[0],:img1.shape[1]]=img1 - display[:img2.shape[0],img1.shape[1]:]=img2 + height, width = max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1] + display = np.zeros([height, width, 3], np.uint8) + display[: img1.shape[0], : img1.shape[1]] = img1 + display[: img2.shape[0], img1.shape[1] :] = img2 for i in range(len(corr1)): - left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1]) - cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2])) - cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA) + left_x, left_y, right_x, right_y = ( + int(corr1[i][0]), + int(corr1[i][1]), + int(corr2[i][0] + img1.shape[1]), + int(corr2[i][1]), + ) + cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2])) + cv2.line( + display, + (left_x, left_y), + (right_x, right_y), + cur_color, + 1, + lineType=cv2.LINE_AA, + ) return display diff --git a/third_party/ASpanFormer/src/utils/plotting.py b/third_party/ASpanFormer/src/utils/plotting.py index 8696880237b6ad9fe48d3c1fc44ed13b691a6c4d..0ca3ef0a336a652e7ca910a5584227da043ac019 100644 --- a/third_party/ASpanFormer/src/utils/plotting.py +++ b/third_party/ASpanFormer/src/utils/plotting.py @@ -4,38 +4,51 @@ import matplotlib.pyplot as plt import matplotlib from copy import deepcopy + def _compute_conf_thresh(data): - dataset_name = data['dataset_name'][0].lower() - if dataset_name == 'scannet': + dataset_name = data["dataset_name"][0].lower() + if dataset_name == "scannet": thr = 5e-4 - elif dataset_name == 'megadepth' or dataset_name=='gl3d': + elif dataset_name == "megadepth" or dataset_name == "gl3d": thr = 1e-4 else: - raise ValueError(f'Unknown dataset: {dataset_name}') + raise ValueError(f"Unknown dataset: {dataset_name}") return thr # --- VISUALIZATION --- # + def make_matching_figure( - img0, img1, mkpts0, mkpts1, color, - kpts0=None, kpts1=None, text=[], dpi=75, path=None): + img0, + img1, + mkpts0, + mkpts1, + color, + kpts0=None, + kpts1=None, + text=[], + dpi=75, + path=None, +): # draw image pair - assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}' + assert ( + mkpts0.shape[0] == mkpts1.shape[0] + ), f"mkpts0: {mkpts0.shape[0]} v.s. 
mkpts1: {mkpts1.shape[0]}" fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) - axes[0].imshow(img0, cmap='gray') - axes[1].imshow(img1, cmap='gray') - for i in range(2): # clear all frames + axes[0].imshow(img0, cmap="gray") + axes[1].imshow(img1, cmap="gray") + for i in range(2): # clear all frames axes[i].get_yaxis().set_ticks([]) axes[i].get_xaxis().set_ticks([]) for spine in axes[i].spines.values(): spine.set_visible(False) plt.tight_layout(pad=1) - + if kpts0 is not None: assert kpts1 is not None - axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) - axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=2) # draw matches if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: @@ -43,164 +56,181 @@ def make_matching_figure( transFigure = fig.transFigure.inverted() fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) - fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), - (fkpts0[i, 1], fkpts1[i, 1]), - transform=fig.transFigure, c=color[i], linewidth=1) - for i in range(len(mkpts0))] - + fig.lines = [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, + c=color[i], + linewidth=1, + ) + for i in range(len(mkpts0)) + ] + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) # put txts - txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + txt_color = "k" if img0[:100, :200].mean() > 200 else "w" fig.text( - 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, - fontsize=15, va='top', ha='left', color=txt_color) + 0.01, + 0.99, + "\n".join(text), + transform=fig.axes[0].transAxes, + fontsize=15, + va="top", + ha="left", + color=txt_color, + ) # save or return figure if path: - plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) plt.close() else: return fig -def _make_evaluation_figure(data, b_id, alpha='dynamic'): - b_mask = data['m_bids'] == b_id +def _make_evaluation_figure(data, b_id, alpha="dynamic"): + b_mask = data["m_bids"] == b_id conf_thr = _compute_conf_thresh(data) - - img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) - img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) - kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() - kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() - + + img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + kpts0 = data["mkpts0_f"][b_mask].cpu().numpy() + kpts1 = data["mkpts1_f"][b_mask].cpu().numpy() + # for megadepth, we visualize matches on the resized image - if 'scale0' in data: - kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]] - kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]] - epi_errs = data['epi_errs'][b_mask].cpu().numpy() + if "scale0" in data: + kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]] + kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]] + epi_errs = data["epi_errs"][b_mask].cpu().numpy() correct_mask = epi_errs < conf_thr precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 n_correct = np.sum(correct_mask) - n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) + n_gt_matches = 
int(data["conf_matrix_gt"][b_id].sum().cpu()) recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) # recall might be larger than 1, since the calculation of conf_matrix_gt # uses groundtruth depths and camera poses, but epipolar distance is used here. # matching info - if alpha == 'dynamic': + if alpha == "dynamic": alpha = dynamic_alpha(len(correct_mask)) color = error_colormap(epi_errs, conf_thr, alpha=alpha) - + text = [ - f'#Matches {len(kpts0)}', - f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', - f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + f"#Matches {len(kpts0)}", + f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}", + f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}", ] - + # make the figure - figure = make_matching_figure(img0, img1, kpts0, kpts1, - color, text=text) + figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text) return figure -def _make_evaluation_figure_offset(data, b_id, alpha='dynamic',side=''): - layer_num=data['predict_flow'][0].shape[0] - b_mask = data['offset_bids'+side] == b_id - conf_thr = 2e-3 #hardcode for scannet(coarse level) - img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) - img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) - - figure_list=[] - #draw offset matches in different layers +def _make_evaluation_figure_offset(data, b_id, alpha="dynamic", side=""): + layer_num = data["predict_flow"][0].shape[0] + + b_mask = data["offset_bids" + side] == b_id + conf_thr = 2e-3 # hardcode for scannet(coarse level) + img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + + figure_list = [] + # draw offset matches in different layers for layer_index in range(layer_num): - l_mask=data['offset_lids'+side]==layer_index - mask=l_mask&b_mask - kpts0 = data['offset_kpts0_f'+side][mask].cpu().numpy() - kpts1 = data['offset_kpts1_f'+side][mask].cpu().numpy() - - epi_errs = data['epi_errs_offset'+side][mask].cpu().numpy() + l_mask = data["offset_lids" + side] == layer_index + mask = l_mask & b_mask + kpts0 = data["offset_kpts0_f" + side][mask].cpu().numpy() + kpts1 = data["offset_kpts1_f" + side][mask].cpu().numpy() + + epi_errs = data["epi_errs_offset" + side][mask].cpu().numpy() correct_mask = epi_errs < conf_thr - + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 n_correct = np.sum(correct_mask) - n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) + n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu()) recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) # recall might be larger than 1, since the calculation of conf_matrix_gt # uses groundtruth depths and camera poses, but epipolar distance is used here. 
# matching info - if alpha == 'dynamic': + if alpha == "dynamic": alpha = dynamic_alpha(len(correct_mask)) color = error_colormap(epi_errs, conf_thr, alpha=alpha) - + text = [ - f'#Matches {len(kpts0)}', - f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', - f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + f"#Matches {len(kpts0)}", + f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}", + f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}", ] - + # make the figure - #import pdb;pdb.set_trace() - figure = make_matching_figure(deepcopy(img0), deepcopy(img1) , kpts0, kpts1, - color, text=text) + # import pdb;pdb.set_trace() + figure = make_matching_figure( + deepcopy(img0), deepcopy(img1), kpts0, kpts1, color, text=text + ) figure_list.append(figure) return figure + def _make_confidence_figure(data, b_id): # TODO: Implement confidence figure raise NotImplementedError() -def make_matching_figures(data, config, mode='evaluation'): - """ Make matching figures for a batch. - +def make_matching_figures(data, config, mode="evaluation"): + """Make matching figures for a batch. + Args: data (Dict): a batch updated by PL_LoFTR. config (Dict): matcher config Returns: figures (Dict[str, List[plt.figure]] """ - assert mode in ['evaluation', 'confidence'] # 'confidence' + assert mode in ["evaluation", "confidence"] # 'confidence' figures = {mode: []} - for b_id in range(data['image0'].size(0)): - if mode == 'evaluation': + for b_id in range(data["image0"].size(0)): + if mode == "evaluation": fig = _make_evaluation_figure( - data, b_id, - alpha=config.TRAINER.PLOT_MATCHES_ALPHA) - elif mode == 'confidence': + data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA + ) + elif mode == "confidence": fig = _make_confidence_figure(data, b_id) else: - raise ValueError(f'Unknown plot mode: {mode}') + raise ValueError(f"Unknown plot mode: {mode}") figures[mode].append(fig) return figures -def make_matching_figures_offset(data, config, mode='evaluation',side=''): - """ Make matching figures for a batch. - + +def make_matching_figures_offset(data, config, mode="evaluation", side=""): + """Make matching figures for a batch. + Args: data (Dict): a batch updated by PL_LoFTR. 
config (Dict): matcher config Returns: figures (Dict[str, List[plt.figure]] """ - assert mode in ['evaluation', 'confidence'] # 'confidence' + assert mode in ["evaluation", "confidence"] # 'confidence' figures = {mode: []} - for b_id in range(data['image0'].size(0)): - if mode == 'evaluation': + for b_id in range(data["image0"].size(0)): + if mode == "evaluation": fig = _make_evaluation_figure_offset( - data, b_id, - alpha=config.TRAINER.PLOT_MATCHES_ALPHA,side=side) - elif mode == 'confidence': + data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA, side=side + ) + elif mode == "confidence": fig = _make_evaluation_figure_offset(data, b_id) else: - raise ValueError(f'Unknown plot mode: {mode}') + raise ValueError(f"Unknown plot mode: {mode}") figures[mode].append(fig) return figures -def dynamic_alpha(n_matches, - milestones=[0, 300, 1000, 2000], - alphas=[1.0, 0.8, 0.4, 0.2]): + +def dynamic_alpha( + n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2] +): if n_matches == 0: return 1.0 ranges = list(zip(alphas, alphas[1:] + [None])) @@ -209,11 +239,15 @@ def dynamic_alpha(n_matches, if _range[1] is None: return _range[0] return _range[1] + (milestones[loc + 1] - n_matches) / ( - milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) + milestones[loc + 1] - milestones[loc] + ) * (_range[0] - _range[1]) def error_colormap(err, thr, alpha=1.0): assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" x = 1 - np.clip(err / (thr * 2), 0, 1) return np.clip( - np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) + np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1), + 0, + 1, + ) diff --git a/third_party/ASpanFormer/src/utils/profiler.py b/third_party/ASpanFormer/src/utils/profiler.py index 6d21ed79fb506ef09c75483355402c48a195aaa9..0275ea34e3eb9cceb4ed809bebeda209749f5bc5 100644 --- a/third_party/ASpanFormer/src/utils/profiler.py +++ b/third_party/ASpanFormer/src/utils/profiler.py @@ -7,7 +7,7 @@ from pytorch_lightning.utilities import rank_zero_only class InferenceProfiler(SimpleProfiler): """ This profiler records duration of actions with cuda.synchronize() - Use this in test time. + Use this in test time. 
""" def __init__(self): @@ -28,12 +28,13 @@ class InferenceProfiler(SimpleProfiler): def build_profiler(name): - if name == 'inference': + if name == "inference": return InferenceProfiler() - elif name == 'pytorch': + elif name == "pytorch": from pytorch_lightning.profiler import PyTorchProfiler + return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) elif name is None: return PassThroughProfiler() else: - raise ValueError(f'Invalid profiler: {name}') + raise ValueError(f"Invalid profiler: {name}") diff --git a/third_party/ASpanFormer/test.py b/third_party/ASpanFormer/test.py index 541ce84662ab4888c6fece30403c5c9983118637..bed3060d931d2f9e5d60ef3b0eb6a9016322fa0f 100644 --- a/third_party/ASpanFormer/test.py +++ b/third_party/ASpanFormer/test.py @@ -10,33 +10,52 @@ from src.lightning.data import MultiSceneDataModule from src.lightning.lightning_aspanformer import PL_ASpanFormer import torch + def parse_args(): # init a costum parser which will be added into pl.Trainer parser # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument( - 'data_cfg_path', type=str, help='data config path') - parser.add_argument( - 'main_cfg_path', type=str, help='main config path') - parser.add_argument( - '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint') - parser.add_argument( - '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir") + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("data_cfg_path", type=str, help="data config path") + parser.add_argument("main_cfg_path", type=str, help="main config path") parser.add_argument( - '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset') + "--ckpt_path", + type=str, + default="weights/indoor_ds.ckpt", + help="path to the checkpoint", + ) parser.add_argument( - '--batch_size', type=int, default=1, help='batch_size per gpu') + "--dump_dir", + type=str, + default=None, + help="if set, the matching results will be dump to dump_dir", + ) parser.add_argument( - '--num_workers', type=int, default=2) + "--profiler_name", + type=str, + default=None, + help="options: [inference, pytorch], or leave it unset", + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch_size per gpu") + parser.add_argument("--num_workers", type=int, default=2) parser.add_argument( - '--thr', type=float, default=None, help='modify the coarse-level matching threshold.') + "--thr", + type=float, + default=None, + help="modify the coarse-level matching threshold.", + ) parser.add_argument( - '--mode', type=str, default='vanilla', help='modify the coarse-level matching threshold.') + "--mode", + type=str, + default="vanilla", + help="modify the coarse-level matching threshold.", + ) parser = pl.Trainer.add_argparse_args(parser) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": # parse arguments args = parse_args() pprint.pprint(vars(args)) @@ -55,7 +74,12 @@ if __name__ == '__main__': # lightning module profiler = build_profiler(args.profiler_name) - model = PL_ASpanFormer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir) + model = PL_ASpanFormer( + config, + pretrained_ckpt=args.ckpt_path, + profiler=profiler, + dump_dir=args.dump_dir, + ) 
loguru_logger.info(f"ASpanFormer-lightning initialized!") # lightning data @@ -63,7 +87,9 @@ if __name__ == '__main__': loguru_logger.info(f"DataModule initialized!") # lightning trainer - trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False) + trainer = pl.Trainer.from_argparse_args( + args, replace_sampler_ddp=False, logger=False + ) loguru_logger.info(f"Start testing!") trainer.test(model, datamodule=data_module, verbose=False) diff --git a/third_party/ASpanFormer/tools/extract.py b/third_party/ASpanFormer/tools/extract.py index 12f55e2f94120d5765f124f8eec867f1d82e0aa7..b3dea56a14f6c100b2c53978678bab69a656cdeb 100644 --- a/third_party/ASpanFormer/tools/extract.py +++ b/third_party/ASpanFormer/tools/extract.py @@ -5,43 +5,77 @@ from tqdm import tqdm from multiprocessing import Pool from functools import partial -scannet_dir='/root/data/ScanNet-v2-1.0.0/data/raw' -dump_dir='/root/data/scannet_dump' -num_process=32 - -def extract(seq,scannet_dir,split,dump_dir): - assert split=='train' or split=='test' - if not os.path.exists(os.path.join(dump_dir,split,seq)): - os.mkdir(os.path.join(dump_dir,split,seq)) - cmd='python reader.py --filename '+os.path.join(scannet_dir,'scans' if split=='train' else 'scans_test',seq,seq+'.sens')+' --output_path '+os.path.join(dump_dir,split,seq)+\ - ' --export_depth_images --export_color_images --export_poses --export_intrinsics' +scannet_dir = "/root/data/ScanNet-v2-1.0.0/data/raw" +dump_dir = "/root/data/scannet_dump" +num_process = 32 + + +def extract(seq, scannet_dir, split, dump_dir): + assert split == "train" or split == "test" + if not os.path.exists(os.path.join(dump_dir, split, seq)): + os.mkdir(os.path.join(dump_dir, split, seq)) + cmd = ( + "python reader.py --filename " + + os.path.join( + scannet_dir, + "scans" if split == "train" else "scans_test", + seq, + seq + ".sens", + ) + + " --output_path " + + os.path.join(dump_dir, split, seq) + + " --export_depth_images --export_color_images --export_poses --export_intrinsics" + ) os.system(cmd) -if __name__=='__main__': + +if __name__ == "__main__": if not os.path.exists(dump_dir): os.mkdir(dump_dir) - os.mkdir(os.path.join(dump_dir,'train')) - os.mkdir(os.path.join(dump_dir,'test')) + os.mkdir(os.path.join(dump_dir, "train")) + os.mkdir(os.path.join(dump_dir, "test")) - train_seq_list=[seq.split('/')[-1] for seq in glob.glob(os.path.join(scannet_dir,'scans','scene*'))] - test_seq_list=[seq.split('/')[-1] for seq in glob.glob(os.path.join(scannet_dir,'scans_test','scene*'))] + train_seq_list = [ + seq.split("/")[-1] + for seq in glob.glob(os.path.join(scannet_dir, "scans", "scene*")) + ] + test_seq_list = [ + seq.split("/")[-1] + for seq in glob.glob(os.path.join(scannet_dir, "scans_test", "scene*")) + ] - extract_train=partial(extract,scannet_dir=scannet_dir,split='train',dump_dir=dump_dir) - extract_test=partial(extract,scannet_dir=scannet_dir,split='test',dump_dir=dump_dir) + extract_train = partial( + extract, scannet_dir=scannet_dir, split="train", dump_dir=dump_dir + ) + extract_test = partial( + extract, scannet_dir=scannet_dir, split="test", dump_dir=dump_dir + ) - num_train_iter=len(train_seq_list)//num_process if len(train_seq_list)%num_process==0 else len(train_seq_list)//num_process+1 - num_test_iter=len(test_seq_list)//num_process if len(test_seq_list)%num_process==0 else len(test_seq_list)//num_process+1 + num_train_iter = ( + len(train_seq_list) // num_process + if len(train_seq_list) % num_process == 0 + else len(train_seq_list) // num_process + 1 + ) + 
num_test_iter = ( + len(test_seq_list) // num_process + if len(test_seq_list) % num_process == 0 + else len(test_seq_list) // num_process + 1 + ) pool = Pool(num_process) for index in tqdm(range(num_train_iter)): - seq_list=train_seq_list[index*num_process:min((index+1)*num_process,len(train_seq_list))] - pool.map(extract_train,seq_list) + seq_list = train_seq_list[ + index * num_process : min((index + 1) * num_process, len(train_seq_list)) + ] + pool.map(extract_train, seq_list) pool.close() pool.join() - + pool = Pool(num_process) for index in tqdm(range(num_test_iter)): - seq_list=test_seq_list[index*num_process:min((index+1)*num_process,len(test_seq_list))] - pool.map(extract_test,seq_list) + seq_list = test_seq_list[ + index * num_process : min((index + 1) * num_process, len(test_seq_list)) + ] + pool.map(extract_test, seq_list) pool.close() - pool.join() \ No newline at end of file + pool.join() diff --git a/third_party/ASpanFormer/tools/preprocess_scene.py b/third_party/ASpanFormer/tools/preprocess_scene.py index d20c0d070243519d67bbd25668ff5eb1657474be..5364058829b7e45eabd61a32a591711645fc1ded 100644 --- a/third_party/ASpanFormer/tools/preprocess_scene.py +++ b/third_party/ASpanFormer/tools/preprocess_scene.py @@ -6,78 +6,63 @@ import numpy as np import os -parser = argparse.ArgumentParser(description='MegaDepth preprocessing script') +parser = argparse.ArgumentParser(description="MegaDepth preprocessing script") -parser.add_argument( - '--base_path', type=str, required=True, - help='path to MegaDepth' -) -parser.add_argument( - '--scene_id', type=str, required=True, - help='scene ID' -) +parser.add_argument("--base_path", type=str, required=True, help="path to MegaDepth") +parser.add_argument("--scene_id", type=str, required=True, help="scene ID") parser.add_argument( - '--output_path', type=str, required=True, - help='path to the output directory' + "--output_path", type=str, required=True, help="path to the output directory" ) args = parser.parse_args() base_path = args.base_path # Remove the trailing / if need be. 
-if base_path[-1] in ['/', '\\']: - base_path = base_path[: - 1] +if base_path[-1] in ["/", "\\"]: + base_path = base_path[:-1] scene_id = args.scene_id -base_depth_path = os.path.join( - base_path, 'phoenix/S6/zl548/MegaDepth_v1' -) -base_undistorted_sfm_path = os.path.join( - base_path, 'Undistorted_SfM' -) +base_depth_path = os.path.join(base_path, "phoenix/S6/zl548/MegaDepth_v1") +base_undistorted_sfm_path = os.path.join(base_path, "Undistorted_SfM") undistorted_sparse_path = os.path.join( - base_undistorted_sfm_path, scene_id, 'sparse-txt' + base_undistorted_sfm_path, scene_id, "sparse-txt" ) if not os.path.exists(undistorted_sparse_path): exit() -depths_path = os.path.join( - base_depth_path, scene_id, 'dense0', 'depths' -) +depths_path = os.path.join(base_depth_path, scene_id, "dense0", "depths") if not os.path.exists(depths_path): exit() -images_path = os.path.join( - base_undistorted_sfm_path, scene_id, 'images' -) +images_path = os.path.join(base_undistorted_sfm_path, scene_id, "images") if not os.path.exists(images_path): exit() # Process cameras.txt -with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f: - raw = f.readlines()[3 :] # skip the header +with open(os.path.join(undistorted_sparse_path, "cameras.txt"), "r") as f: + raw = f.readlines()[3:] # skip the header camera_intrinsics = {} for camera in raw: - camera = camera.split(' ') - camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]] + camera = camera.split(" ") + camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2:]] # Process points3D.txt -with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f: - raw = f.readlines()[3 :] # skip the header +with open(os.path.join(undistorted_sparse_path, "points3D.txt"), "r") as f: + raw = f.readlines()[3:] # skip the header points3D = {} for point3D in raw: - point3D = point3D.split(' ') - points3D[int(point3D[0])] = np.array([ - float(point3D[1]), float(point3D[2]), float(point3D[3]) - ]) - + point3D = point3D.split(" ") + points3D[int(point3D[0])] = np.array( + [float(point3D[1]), float(point3D[2]), float(point3D[3])] + ) + # Process images.txt -with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f: - raw = f.readlines()[4 :] # skip the header +with open(os.path.join(undistorted_sparse_path, "images.txt"), "r") as f: + raw = f.readlines()[4:] # skip the header image_id_to_idx = {} image_names = [] @@ -85,19 +70,19 @@ raw_pose = [] camera = [] points3D_id_to_2D = [] n_points3D = [] -for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])): - image = image.split(' ') - points = points.split(' ') +for idx, (image, points) in enumerate(zip(raw[::2], raw[1::2])): + image = image.split(" ") + points = points.split(" ") image_id_to_idx[int(image[0])] = idx - image_name = image[-1].strip('\n') + image_name = image[-1].strip("\n") image_names.append(image_name) - raw_pose.append([float(elem) for elem in image[1 : -2]]) + raw_pose.append([float(elem) for elem in image[1:-2]]) camera.append(int(image[-2])) current_points3D_id_to_2D = {} - for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]): + for x, y, point3D_id in zip(points[::3], points[1::3], points[2::3]): if int(point3D_id) == -1: continue current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)] @@ -110,12 +95,10 @@ image_paths = [] depth_paths = [] for image_name in image_names: image_path = os.path.join(images_path, image_name) - + # Path to the depth file - depth_path = os.path.join( - depths_path, 
'%s.h5' % os.path.splitext(image_name)[0] - ) - + depth_path = os.path.join(depths_path, "%s.h5" % os.path.splitext(image_name)[0]) + if os.path.exists(depth_path): # Check if depth map or background / foreground mask file_size = os.stat(depth_path).st_size @@ -152,32 +135,22 @@ for idx, image_name in enumerate(image_names): intrinsics.append(K) image_pose = raw_pose[idx] - qvec = image_pose[: 4] + qvec = image_pose[:4] qvec = qvec / np.linalg.norm(qvec) w, x, y, z = qvec - R = np.array([ - [ - 1 - 2 * y * y - 2 * z * z, - 2 * x * y - 2 * z * w, - 2 * x * z + 2 * y * w - ], + R = np.array( [ - 2 * x * y + 2 * z * w, - 1 - 2 * x * x - 2 * z * z, - 2 * y * z - 2 * x * w - ], - [ - 2 * x * z - 2 * y * w, - 2 * y * z + 2 * x * w, - 1 - 2 * x * x - 2 * y * y + [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w], + [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w], + [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y], ] - ]) + ) principal_axis.append(R[2, :]) - t = image_pose[4 : 7] + t = image_pose[4:7] # World-to-Camera pose current_pose = np.zeros([4, 4]) - current_pose[: 3, : 3] = R - current_pose[: 3, 3] = t + current_pose[:3, :3] = R + current_pose[:3, 3] = t current_pose[3, 3] = 1 # Camera-to-World pose # pose = np.zeros([4, 4]) @@ -185,38 +158,38 @@ for idx, image_name in enumerate(image_names): # pose[: 3, 3] = -np.matmul(np.transpose(R), t) # pose[3, 3] = 1 poses.append(current_pose) - + current_points3D_id_to_ndepth = {} for point3D_id in points3D_id_to_2D[idx].keys(): p3d = points3D[point3D_id] - current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1])) + current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / ( + 0.5 * (K[0, 0] + K[1, 1]) + ) points3D_id_to_ndepth.append(current_points3D_id_to_ndepth) principal_axis = np.array(principal_axis) -angles = np.rad2deg(np.arccos( - np.clip( - np.dot(principal_axis, np.transpose(principal_axis)), - -1, 1 - ) -)) +angles = np.rad2deg( + np.arccos(np.clip(np.dot(principal_axis, np.transpose(principal_axis)), -1, 1)) +) # Compute overlap score -overlap_matrix = np.full([n_images, n_images], -1.) -scale_ratio_matrix = np.full([n_images, n_images], -1.) 
+overlap_matrix = np.full([n_images, n_images], -1.0) +scale_ratio_matrix = np.full([n_images, n_images], -1.0) for idx1 in range(n_images): if image_paths[idx1] is None or depth_paths[idx1] is None: continue for idx2 in range(idx1 + 1, n_images): if image_paths[idx2] is None or depth_paths[idx2] is None: continue - matches = ( - points3D_id_to_2D[idx1].keys() & - points3D_id_to_2D[idx2].keys() - ) + matches = points3D_id_to_2D[idx1].keys() & points3D_id_to_2D[idx2].keys() min_num_points3D = min( len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2]) ) - overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1]) # min_num_points3D - overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2]) # min_num_points3D + overlap_matrix[idx1, idx2] = len(matches) / len( + points3D_id_to_2D[idx1] + ) # min_num_points3D + overlap_matrix[idx2, idx1] = len(matches) / len( + points3D_id_to_2D[idx2] + ) # min_num_points3D if len(matches) == 0: continue points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1] @@ -228,7 +201,7 @@ for idx1 in range(n_images): scale_ratio_matrix[idx2, idx1] = min_scale_ratio np.savez( - os.path.join(args.output_path, '%s.npz' % scene_id), + os.path.join(args.output_path, "%s.npz" % scene_id), image_paths=image_paths, depth_paths=depth_paths, intrinsics=intrinsics, @@ -238,5 +211,5 @@ np.savez( angles=angles, n_points3D=n_points3D, points3D_id_to_2D=points3D_id_to_2D, - points3D_id_to_ndepth=points3D_id_to_ndepth -) \ No newline at end of file + points3D_id_to_ndepth=points3D_id_to_ndepth, +) diff --git a/third_party/ASpanFormer/tools/reader.py b/third_party/ASpanFormer/tools/reader.py index f419fbaa8a099fcfede1cea51fcf95a2c1589160..2734a7796ef8235bdbc1be317b6618f3d3185319 100644 --- a/third_party/ASpanFormer/tools/reader.py +++ b/third_party/ASpanFormer/tools/reader.py @@ -6,34 +6,45 @@ from SensorData import SensorData # params parser = argparse.ArgumentParser() # data paths -parser.add_argument('--filename', required=True, help='path to sens file to read') -parser.add_argument('--output_path', required=True, help='path to output folder') -parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true') -parser.add_argument('--export_color_images', dest='export_color_images', action='store_true') -parser.add_argument('--export_poses', dest='export_poses', action='store_true') -parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true') -parser.set_defaults(export_depth_images=False, export_color_images=False, export_poses=False, export_intrinsics=False) +parser.add_argument("--filename", required=True, help="path to sens file to read") +parser.add_argument("--output_path", required=True, help="path to output folder") +parser.add_argument( + "--export_depth_images", dest="export_depth_images", action="store_true" +) +parser.add_argument( + "--export_color_images", dest="export_color_images", action="store_true" +) +parser.add_argument("--export_poses", dest="export_poses", action="store_true") +parser.add_argument( + "--export_intrinsics", dest="export_intrinsics", action="store_true" +) +parser.set_defaults( + export_depth_images=False, + export_color_images=False, + export_poses=False, + export_intrinsics=False, +) opt = parser.parse_args() print(opt) def main(): - if not os.path.exists(opt.output_path): - os.makedirs(opt.output_path) - # load the data - sys.stdout.write('loading %s...' 
% opt.filename) - sd = SensorData(opt.filename) - sys.stdout.write('loaded!\n') - if opt.export_depth_images: - sd.export_depth_images(os.path.join(opt.output_path, 'depth')) - if opt.export_color_images: - sd.export_color_images(os.path.join(opt.output_path, 'color')) - if opt.export_poses: - sd.export_poses(os.path.join(opt.output_path, 'pose')) - if opt.export_intrinsics: - sd.export_intrinsics(os.path.join(opt.output_path, 'intrinsic')) + if not os.path.exists(opt.output_path): + os.makedirs(opt.output_path) + # load the data + sys.stdout.write("loading %s..." % opt.filename) + sd = SensorData(opt.filename) + sys.stdout.write("loaded!\n") + if opt.export_depth_images: + sd.export_depth_images(os.path.join(opt.output_path, "depth")) + if opt.export_color_images: + sd.export_color_images(os.path.join(opt.output_path, "color")) + if opt.export_poses: + sd.export_poses(os.path.join(opt.output_path, "pose")) + if opt.export_intrinsics: + sd.export_intrinsics(os.path.join(opt.output_path, "intrinsic")) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/third_party/ASpanFormer/tools/undistort_mega.py b/third_party/ASpanFormer/tools/undistort_mega.py index 68798ff30e6afa37a0f98571ecfd3f05751868c8..fcd5ff2d77cd45dc9e5cebc48d7a173e31e68caf 100644 --- a/third_party/ASpanFormer/tools/undistort_mega.py +++ b/third_party/ASpanFormer/tools/undistort_mega.py @@ -6,28 +6,20 @@ import os import subprocess -parser = argparse.ArgumentParser(description='MegaDepth Undistortion') +parser = argparse.ArgumentParser(description="MegaDepth Undistortion") parser.add_argument( - '--colmap_path', type=str,default='/usr/bin/', - help='path to colmap executable' + "--colmap_path", type=str, default="/usr/bin/", help="path to colmap executable" ) parser.add_argument( - '--base_path', type=str,default='/root/MegaDepth', - help='path to MegaDepth' + "--base_path", type=str, default="/root/MegaDepth", help="path to MegaDepth" ) args = parser.parse_args() -sfm_path = os.path.join( - args.base_path, 'MegaDepth_v1_SfM' -) -base_depth_path = os.path.join( - args.base_path, 'phoenix/S6/zl548/MegaDepth_v1' -) -output_path = os.path.join( - args.base_path, 'Undistorted_SfM' -) +sfm_path = os.path.join(args.base_path, "MegaDepth_v1_SfM") +base_depth_path = os.path.join(args.base_path, "phoenix/S6/zl548/MegaDepth_v1") +output_path = os.path.join(args.base_path, "Undistorted_SfM") os.mkdir(output_path) @@ -35,35 +27,45 @@ for scene_name in os.listdir(base_depth_path): current_output_path = os.path.join(output_path, scene_name) os.mkdir(current_output_path) - image_path = os.path.join( - base_depth_path, scene_name, 'dense0', 'imgs' - ) + image_path = os.path.join(base_depth_path, scene_name, "dense0", "imgs") if not os.path.exists(image_path): continue - + # Find the maximum image size in scene. max_image_size = 0 for image_name in os.listdir(image_path): max_image_size = max( - max_image_size, - max(imagesize.get(os.path.join(image_path, image_name))) + max_image_size, max(imagesize.get(os.path.join(image_path, image_name))) ) # Undistort the images and update the reconstruction. 
- subprocess.call([ - os.path.join(args.colmap_path, 'colmap'), 'image_undistorter', - '--image_path', os.path.join(sfm_path, scene_name, 'images'), - '--input_path', os.path.join(sfm_path, scene_name, 'sparse', 'manhattan', '0'), - '--output_path', current_output_path, - '--max_image_size', str(max_image_size) - ]) + subprocess.call( + [ + os.path.join(args.colmap_path, "colmap"), + "image_undistorter", + "--image_path", + os.path.join(sfm_path, scene_name, "images"), + "--input_path", + os.path.join(sfm_path, scene_name, "sparse", "manhattan", "0"), + "--output_path", + current_output_path, + "--max_image_size", + str(max_image_size), + ] + ) # Transform the reconstruction to raw text format. - sparse_txt_path = os.path.join(current_output_path, 'sparse-txt') + sparse_txt_path = os.path.join(current_output_path, "sparse-txt") os.mkdir(sparse_txt_path) - subprocess.call([ - os.path.join(args.colmap_path, 'colmap'), 'model_converter', - '--input_path', os.path.join(current_output_path, 'sparse'), - '--output_path', sparse_txt_path, - '--output_type', 'TXT' - ]) \ No newline at end of file + subprocess.call( + [ + os.path.join(args.colmap_path, "colmap"), + "model_converter", + "--input_path", + os.path.join(current_output_path, "sparse"), + "--output_path", + sparse_txt_path, + "--output_type", + "TXT", + ] + ) diff --git a/third_party/ASpanFormer/train.py b/third_party/ASpanFormer/train.py index 21f644763711481e84863ed5d861ec57d95f2d5c..f1aeb79f630932b539500544d4249b1237d06605 100644 --- a/third_party/ASpanFormer/train.py +++ b/third_party/ASpanFormer/train.py @@ -23,41 +23,58 @@ loguru_logger = get_rank_zero_only_logger(loguru_logger) def parse_args(): def str2bool(v): return v.lower() in ("true", "1") + # init a costum parser which will be added into pl.Trainer parser # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument( - 'data_cfg_path', type=str, help='data config path') - parser.add_argument( - 'main_cfg_path', type=str, help='main config path') - parser.add_argument( - '--exp_name', type=str, default='default_exp_name') - parser.add_argument( - '--batch_size', type=int, default=4, help='batch_size per gpu') - parser.add_argument( - '--num_workers', type=int, default=4) + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("data_cfg_path", type=str, help="data config path") + parser.add_argument("main_cfg_path", type=str, help="main config path") + parser.add_argument("--exp_name", type=str, default="default_exp_name") + parser.add_argument("--batch_size", type=int, default=4, help="batch_size per gpu") + parser.add_argument("--num_workers", type=int, default=4) parser.add_argument( - '--pin_memory', type=lambda x: bool(strtobool(x)), - nargs='?', default=True, help='whether loading data to pinned memory or not') + "--pin_memory", + type=lambda x: bool(strtobool(x)), + nargs="?", + default=True, + help="whether loading data to pinned memory or not", + ) parser.add_argument( - '--ckpt_path', type=str, default=None, - help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer') + "--ckpt_path", + type=str, + default=None, + help="pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer", + ) parser.add_argument( - '--disable_ckpt', action='store_true', - help='disable checkpoint saving (useful for debugging).') + "--disable_ckpt", + 
action="store_true", + help="disable checkpoint saving (useful for debugging).", + ) parser.add_argument( - '--profiler_name', type=str, default=None, - help='options: [inference, pytorch], or leave it unset') + "--profiler_name", + type=str, + default=None, + help="options: [inference, pytorch], or leave it unset", + ) parser.add_argument( - '--parallel_load_data', action='store_true', - help='load datasets in with multiple processes.') + "--parallel_load_data", + action="store_true", + help="load datasets in with multiple processes.", + ) parser.add_argument( - '--mode', type=str, default='vanilla', - help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer') + "--mode", + type=str, + default="vanilla", + help="pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer", + ) parser.add_argument( - '--ini', type=str2bool, default=False, - help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer') + "--ini", + type=str2bool, + default=False, + help="pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer", + ) parser = pl.Trainer.add_argparse_args(parser) return parser.parse_args() @@ -83,8 +100,7 @@ def main(): _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS config.TRAINER.SCALING = _scaling config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling - config.TRAINER.WARMUP_STEP = math.floor( - config.TRAINER.WARMUP_STEP / _scaling) + config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling) # lightning module profiler = build_profiler(args.profiler_name) @@ -97,16 +113,22 @@ def main(): # TensorBoard Logger logger = TensorBoardLogger( - save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False) - ckpt_dir = Path(logger.log_dir) / 'checkpoints' + save_dir="logs/tb_logs", name=args.exp_name, default_hp_metric=False + ) + ckpt_dir = Path(logger.log_dir) / "checkpoints" # Callbacks # TODO: update ModelCheckpoint to monitor multiple metrics - ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max', - save_last=True, - dirpath=str(ckpt_dir), - filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}') - lr_monitor = LearningRateMonitor(logging_interval='step') + ckpt_callback = ModelCheckpoint( + monitor="auc@10", + verbose=True, + save_top_k=5, + mode="max", + save_last=True, + dirpath=str(ckpt_dir), + filename="{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}", + ) + lr_monitor = LearningRateMonitor(logging_interval="step") callbacks = [lr_monitor] if not args.disable_ckpt: callbacks.append(ckpt_callback) @@ -114,21 +136,24 @@ def main(): # Lightning Trainer trainer = pl.Trainer.from_argparse_args( args, - plugins=DDPPlugin(find_unused_parameters=False, - num_nodes=args.num_nodes, - sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), + plugins=DDPPlugin( + find_unused_parameters=False, + num_nodes=args.num_nodes, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, + ), gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING, callbacks=callbacks, logger=logger, sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, replace_sampler_ddp=False, # use custom sampler reload_dataloaders_every_epoch=False, # avoid repeated samples! 
- weights_summary='full', - profiler=profiler) + weights_summary="full", + profiler=profiler, + ) loguru_logger.info(f"Trainer initialized!") loguru_logger.info(f"Start training!") trainer.fit(model, datamodule=data_module) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/third_party/DKM/demo/demo_fundamental.py b/third_party/DKM/demo/demo_fundamental.py index e19766d5d3ce1abf0d18483cbbce71b2696983be..643ae3d62d3d4a09d1eb6f7b351ea23f2095b725 100644 --- a/third_party/DKM/demo/demo_fundamental.py +++ b/third_party/DKM/demo/demo_fundamental.py @@ -6,11 +6,12 @@ from dkm.utils.utils import tensor_to_pil import cv2 from dkm import DKMv3_outdoor -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if __name__ == "__main__": from argparse import ArgumentParser + parser = ArgumentParser() parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) @@ -22,7 +23,6 @@ if __name__ == "__main__": # Create model dkm_model = DKMv3_outdoor(device=device) - W_A, H_A = Image.open(im1_path).size W_B, H_B = Image.open(im2_path).size @@ -30,8 +30,13 @@ if __name__ == "__main__": warp, certainty = dkm_model.match(im1_path, im2_path, device=device) # Sample matches for estimation matches, certainty = dkm_model.sample(warp, certainty) - kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) + kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) F, mask = cv2.findFundamentalMat( - kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000 + kpts1.cpu().numpy(), + kpts2.cpu().numpy(), + ransacReprojThreshold=0.2, + method=cv2.USAC_MAGSAC, + confidence=0.999999, + maxIters=10000, ) - # TODO: some better visualization \ No newline at end of file + # TODO: some better visualization diff --git a/third_party/DKM/demo/demo_match.py b/third_party/DKM/demo/demo_match.py index fb901894d8654a884819162d3b9bb8094529e034..aef324e1b19a76498dc0476714149534546e0218 100644 --- a/third_party/DKM/demo/demo_match.py +++ b/third_party/DKM/demo/demo_match.py @@ -6,15 +6,18 @@ from dkm.utils.utils import tensor_to_pil from dkm import DKMv3_outdoor -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if __name__ == "__main__": from argparse import ArgumentParser + parser = ArgumentParser() parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) - parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str) + parser.add_argument( + "--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str + ) args, _ = parser.parse_known_args() im1_path = args.im_A_path @@ -37,12 +40,12 @@ if __name__ == "__main__": x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1) im2_transfer_rgb = F.grid_sample( - x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False + x2[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False )[0] im1_transfer_rgb = F.grid_sample( - x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False + x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False )[0] - warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2) 
- white_im = torch.ones((H,2*W),device=device) + warp_im = torch.cat((im2_transfer_rgb, im1_transfer_rgb), dim=2) + white_im = torch.ones((H, 2 * W), device=device) vis_im = certainty * warp_im + (1 - certainty) * white_im tensor_to_pil(vis_im, unnormalize=False).save(save_path) diff --git a/third_party/DKM/dkm/__init__.py b/third_party/DKM/dkm/__init__.py index a9b47632780acc7762bcccc348e2025fe99f3726..27099047d713e61a103bd0f439f292245ad720a3 100644 --- a/third_party/DKM/dkm/__init__.py +++ b/third_party/DKM/dkm/__init__.py @@ -1,4 +1,4 @@ from .models import ( DKMv3_outdoor, DKMv3_indoor, - ) +) diff --git a/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py index 9c3febe5ca9e3a683bc7122cec635c4f54b66f7c..719e298726528754c3f826d6d2f2fe2ce9b3b903 100644 --- a/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py +++ b/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py @@ -53,7 +53,7 @@ class HpatchesHomogBenchmark: ) return query_coords, query_to_support - def benchmark(self, model, model_name = None): + def benchmark(self, model, model_name=None): n_matches = [] homog_dists = [] for seq_idx, seq_name in tqdm( @@ -71,9 +71,7 @@ class HpatchesHomogBenchmark: H = np.loadtxt( os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx)) ) - dense_matches, dense_certainty = model.match( - im1_path, im2_path - ) + dense_matches, dense_certainty = model.match(im1_path, im2_path) good_matches, _ = model.sample(dense_matches, dense_certainty, 5000) pos_a, pos_b = self.convert_coordinates( good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2 @@ -82,9 +80,9 @@ class HpatchesHomogBenchmark: H_pred, inliers = cv2.findHomography( pos_a, pos_b, - method = cv2.RANSAC, - confidence = 0.99999, - ransacReprojThreshold = 3 * min(w2, h2) / 480, + method=cv2.RANSAC, + confidence=0.99999, + ransacReprojThreshold=3 * min(w2, h2) / 480, ) except: H_pred = None diff --git a/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py b/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py index 6b1193745ff18d239165aeb3376642fb17033874..d9499f1e92fd4df3ad6fe59c37b6c881d5322a51 100644 --- a/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py +++ b/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py @@ -5,8 +5,9 @@ from PIL import Image from tqdm import tqdm import torch.nn.functional as F + class Megadepth1500Benchmark: - def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + def __init__(self, data_root="data/megadepth", scene_names=None) -> None: if scene_names is None: self.scene_names = [ "0015_0.1_0.3.npz", @@ -56,28 +57,24 @@ class Megadepth1500Benchmark: K1[:2] = K1[:2] * scale1 K2[:2] = K2[:2] * scale2 dense_matches, dense_certainty = model.match(im1_path, im2_path) - sparse_matches,_ = model.sample( + sparse_matches, _ = model.sample( dense_matches, dense_certainty, 5000 ) kpts1 = sparse_matches[:, :2] - kpts1 = ( - torch.stack( - ( - w1 * (kpts1[:, 0] + 1) / 2, - h1 * (kpts1[:, 1] + 1) / 2, - ), - axis=-1, - ) + kpts1 = torch.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2, + h1 * (kpts1[:, 1] + 1) / 2, + ), + axis=-1, ) kpts2 = sparse_matches[:, 2:] - kpts2 = ( - torch.stack( - ( - w2 * (kpts2[:, 0] + 1) / 2, - h2 * (kpts2[:, 1] + 1) / 2, - ), - axis=-1, - ) + kpts2 = torch.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2, + h2 * (kpts2[:, 1] + 1) / 2, + ), + axis=-1, ) for _ in range(5): shuffling = np.random.permutation(np.arange(len(kpts1))) @@ -85,7 +82,9 @@ 
class Megadepth1500Benchmark: kpts2 = kpts2[shuffling] try: norm_threshold = 0.5 / ( - np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + np.mean(np.abs(K1[:2, :2])) + + np.mean(np.abs(K2[:2, :2])) + ) R_est, t_est, mask = estimate_pose( kpts1.cpu().numpy(), kpts2.cpu().numpy(), diff --git a/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py b/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py index 0b370644497efd62563105e68e692e10ff339669..5e8d597760a82349d043055f5ca867f1f79fc55a 100644 --- a/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py +++ b/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py @@ -7,14 +7,16 @@ from torch.utils.data import ConcatDataset class MegadepthDenseBenchmark: - def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000, device=None) -> None: + def __init__( + self, data_root="data/megadepth", h=384, w=512, num_samples=2000, device=None + ) -> None: mega = MegadepthBuilder(data_root=data_root) self.dataset = ConcatDataset( mega.build_scenes(split="test_loftr", ht=h, wt=w) ) # fixed resolution of 384,512 self.num_samples = num_samples if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = device def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches): @@ -54,7 +56,9 @@ class MegadepthDenseBenchmark: pck_3_tot = 0.0 pck_5_tot = 0.0 sampler = torch.utils.data.WeightedRandomSampler( - torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples + torch.ones(len(self.dataset)), + replacement=False, + num_samples=self.num_samples, ) dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=8, num_workers=batch_size, sampler=sampler diff --git a/third_party/DKM/dkm/benchmarks/scannet_benchmark.py b/third_party/DKM/dkm/benchmarks/scannet_benchmark.py index ca938cb462c351845ce035f8be0714cf81214452..1ad659f887d3863812a368dcb210fbd7bbadb04e 100644 --- a/third_party/DKM/dkm/benchmarks/scannet_benchmark.py +++ b/third_party/DKM/dkm/benchmarks/scannet_benchmark.py @@ -10,7 +10,7 @@ class ScanNetBenchmark: def __init__(self, data_root="data/scannet") -> None: self.data_root = data_root - def benchmark(self, model, model_name = None): + def benchmark(self, model, model_name=None): model.train(False) with torch.no_grad(): data_root = self.data_root @@ -24,20 +24,20 @@ class ScanNetBenchmark: scene = pairs[pairind] scene_name = f"scene0{scene[0]}_00" im1_path = osp.join( - self.data_root, - "scans_test", - scene_name, - "color", - f"{scene[2]}.jpg", - ) + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[2]}.jpg", + ) im1 = Image.open(im1_path) im2_path = osp.join( - self.data_root, - "scans_test", - scene_name, - "color", - f"{scene[3]}.jpg", - ) + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[3]}.jpg", + ) im2 = Image.open(im2_path) T_gt = rel_pose[pairind].reshape(3, 4) R, t = T_gt[:3, :3], T_gt[:3, 3] @@ -76,24 +76,20 @@ class ScanNetBenchmark: offset = 0.5 kpts1 = sparse_matches[:, :2] - kpts1 = ( - np.stack( - ( - w1 * (kpts1[:, 0] + 1) / 2 - offset, - h1 * (kpts1[:, 1] + 1) / 2 - offset, - ), - axis=-1, - ) + kpts1 = np.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2 - offset, + h1 * (kpts1[:, 1] + 1) / 2 - offset, + ), + axis=-1, ) kpts2 = sparse_matches[:, 2:] - kpts2 = ( - np.stack( - ( - w2 * (kpts2[:, 0] + 1) / 2 - offset, - h2 * (kpts2[:, 1] + 1) / 2 - offset, - ), - axis=-1, - ) + kpts2 = np.stack( + 
( + w2 * (kpts2[:, 0] + 1) / 2 - offset, + h2 * (kpts2[:, 1] + 1) / 2 - offset, + ), + axis=-1, ) for _ in range(5): shuffling = np.random.permutation(np.arange(len(kpts1))) @@ -101,7 +97,8 @@ class ScanNetBenchmark: kpts2 = kpts2[shuffling] try: norm_threshold = 0.5 / ( - np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])) + ) R_est, t_est, mask = estimate_pose( kpts1, kpts2, diff --git a/third_party/DKM/dkm/datasets/scannet.py b/third_party/DKM/dkm/datasets/scannet.py index 6ac39b41480f7585c4755cc30e0677ef74ed5e0c..fc24263c771f5fbb5d1e676257e9ad484a03ae31 100644 --- a/third_party/DKM/dkm/datasets/scannet.py +++ b/third_party/DKM/dkm/datasets/scannet.py @@ -5,10 +5,7 @@ import cv2 import h5py import numpy as np import torch -from torch.utils.data import ( - Dataset, - DataLoader, - ConcatDataset) +from torch.utils.data import Dataset, DataLoader, ConcatDataset import torchvision.transforms.functional as tvf import kornia.augmentation as K @@ -19,21 +16,35 @@ from dkm.utils.transforms import GeometricSequential from tqdm import tqdm + class ScanNetScene: - def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) -> None: - self.scene_root = osp.join(data_root,"scans","scans_train") - self.data_names = scene_info['name'] - self.overlaps = scene_info['score'] + def __init__( + self, + data_root, + scene_info, + ht=384, + wt=512, + min_overlap=0.0, + shake_t=0, + rot_prob=0.0, + ) -> None: + self.scene_root = osp.join(data_root, "scans", "scans_train") + self.data_names = scene_info["name"] + self.overlaps = scene_info["score"] # Only sample 10s - valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0 + valid = (self.data_names[:, -2:] % 10).sum(axis=-1) == 0 self.overlaps = self.overlaps[valid] self.data_names = self.data_names[valid] if len(self.data_names) > 10000: - pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False) + pairinds = np.random.choice( + np.arange(0, len(self.data_names)), 10000, replace=False + ) self.data_names = self.data_names[pairinds] self.overlaps = self.overlaps[pairinds] self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True) - self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False) + self.depth_transform_ops = get_depth_tuple_transform_ops( + resize=(ht, wt), normalize=False + ) self.wt, self.ht = wt, ht self.shake_t = shake_t self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob)) @@ -41,7 +52,7 @@ class ScanNetScene: def load_im(self, im_ref, crop=None): im = Image.open(im_ref) return im - + def load_depth(self, depth_ref, crop=None): depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED) depth = depth / 1000 @@ -50,55 +61,61 @@ class ScanNetScene: def __len__(self): return len(self.data_names) - + def scale_intrinsic(self, K, wi, hi): - sx, sy = self.wt / wi, self.ht / hi - sK = torch.tensor([[sx, 0, 0], - [0, sy, 0], - [0, 0, 1]]) - return sK@K - - def read_scannet_pose(self,path): - """ Read ScanNet's Camera2World pose and transform it to World2Camera. - + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]]) + return sK @ K + + def read_scannet_pose(self, path): + """Read ScanNet's Camera2World pose and transform it to World2Camera. 
+ Returns: pose_w2c (np.ndarray): (4, 4) """ - cam2world = np.loadtxt(path, delimiter=' ') + cam2world = np.loadtxt(path, delimiter=" ") world2cam = np.linalg.inv(cam2world) return world2cam - - def read_scannet_intrinsic(self,path): - """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. - """ - intrinsic = np.loadtxt(path, delimiter=' ') + def read_scannet_intrinsic(self, path): + """Read ScanNet's intrinsic matrix and return the 3x3 matrix.""" + intrinsic = np.loadtxt(path, delimiter=" ") return intrinsic[:-1, :-1] def __getitem__(self, pair_idx): # read intrinsics of original size data_name = self.data_names[pair_idx] scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name - scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' - + scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}" + # read the intrinsic of depthmap - K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root, - scene_name, - 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter + K1 = K2 = self.read_scannet_intrinsic( + osp.join(self.scene_root, scene_name, "intrinsic", "intrinsic_color.txt") + ) # the depth K is not the same, but doesnt really matter # read and compute relative poses - T1 = self.read_scannet_pose(osp.join(self.scene_root, - scene_name, - 'pose', f'{stem_name_1}.txt')) - T2 = self.read_scannet_pose(osp.join(self.scene_root, - scene_name, - 'pose', f'{stem_name_2}.txt')) - T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4) + T1 = self.read_scannet_pose( + osp.join(self.scene_root, scene_name, "pose", f"{stem_name_1}.txt") + ) + T2 = self.read_scannet_pose( + osp.join(self.scene_root, scene_name, "pose", f"{stem_name_2}.txt") + ) + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[ + :4, :4 + ] # (4, 4) # Load positive pair data - im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg') - im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg') - depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png') - depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png') + im_src_ref = os.path.join( + self.scene_root, scene_name, "color", f"{stem_name_1}.jpg" + ) + im_pos_ref = os.path.join( + self.scene_root, scene_name, "color", f"{stem_name_2}.jpg" + ) + depth_src_ref = os.path.join( + self.scene_root, scene_name, "depth", f"{stem_name_1}.png" + ) + depth_pos_ref = os.path.join( + self.scene_root, scene_name, "depth", f"{stem_name_2}.png" + ) im_src = self.load_im(im_src_ref) im_pos = self.load_im(im_pos_ref) @@ -110,42 +127,53 @@ class ScanNetScene: K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height) # Process images im_src, im_pos = self.im_transform_ops((im_src, im_pos)) - depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None])) - - data_dict = {'query': im_src, - 'support': im_pos, - 'query_depth': depth_src[0,0], - 'support_depth': depth_pos[0,0], - 'K1': K1, - 'K2': K2, - 'T_1to2':T_1to2, - } + depth_src, depth_pos = self.depth_transform_ops( + (depth_src[None, None], depth_pos[None, None]) + ) + + data_dict = { + "query": im_src, + "support": im_pos, + "query_depth": depth_src[0, 0], + "support_depth": depth_pos[0, 0], + "K1": K1, + "K2": K2, + "T_1to2": T_1to2, + } return data_dict class ScanNetBuilder: - def __init__(self, data_root = 'data/scannet') -> None: + def __init__(self, 
data_root="data/scannet") -> None: self.data_root = data_root - self.scene_info_root = os.path.join(data_root,'scannet_indices') + self.scene_info_root = os.path.join(data_root, "scannet_indices") self.all_scenes = os.listdir(self.scene_info_root) - - def build_scenes(self, split = 'train', min_overlap=0., **kwargs): + + def build_scenes(self, split="train", min_overlap=0.0, **kwargs): # Note: split doesn't matter here as we always use same scannet_train scenes scene_names = self.all_scenes scenes = [] for scene_name in tqdm(scene_names): - scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True) - scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs)) + scene_info = np.load( + os.path.join(self.scene_info_root, scene_name), allow_pickle=True + ) + scenes.append( + ScanNetScene( + self.data_root, scene_info, min_overlap=min_overlap, **kwargs + ) + ) return scenes - - def weight_scenes(self, concat_dataset, alpha=.5): + + def weight_scenes(self, concat_dataset, alpha=0.5): ns = [] for d in concat_dataset.datasets: ns.append(len(d)) - ws = torch.cat([torch.ones(n)/n**alpha for n in ns]) + ws = torch.cat([torch.ones(n) / n**alpha for n in ns]) return ws if __name__ == "__main__": - mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train')) - mega_test[0] \ No newline at end of file + mega_test = ConcatDataset( + ScanNetBuilder("data/scannet").build_scenes(split="train") + ) + mega_test[0] diff --git a/third_party/DKM/dkm/models/deprecated/build_model.py b/third_party/DKM/dkm/models/deprecated/build_model.py index dd28335f3e348ab6c90b26ba91b95e864b0bbbb9..6b4f6608296c21387b19242681e6e49160c0887e 100644 --- a/third_party/DKM/dkm/models/deprecated/build_model.py +++ b/third_party/DKM/dkm/models/deprecated/build_model.py @@ -10,16 +10,16 @@ dkm_pretrained_urls = { "mega_synthetic": "https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth", "mega": "https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth", }, - "DKMv2":{ + "DKMv2": { "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth", "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth", - } + }, } def DKM(pretrained=True, version="mega_synthetic", device=None): if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") gp_dim = 256 dfn_dim = 384 feat_dim = 256 @@ -150,7 +150,8 @@ def DKM(pretrained=True, version="mega_synthetic", device=None): matcher.load_state_dict(weights) return matcher -def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs): + +def DKMv2(pretrained=True, version="outdoor", resolution="low", **kwargs): gp_dim = 256 dfn_dim = 384 feat_dim = 256 @@ -200,8 +201,8 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs): conv_refiner = nn.ModuleDict( { "16": ConvRefiner( - 2 * 512+128, - 1024+128, + 2 * 512 + 128, + 1024 + 128, 3, kernel_size=kernel_size, dw=dw, @@ -210,8 +211,8 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs): displacement_emb_dim=128, ), "8": ConvRefiner( - 2 * 512+64, - 1024+64, + 2 * 512 + 64, + 1024 + 64, 3, kernel_size=kernel_size, dw=dw, @@ -220,8 +221,8 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs): displacement_emb_dim=64, ), "4": ConvRefiner( - 2 * 256+32, - 512+32, + 2 
* 256 + 32, + 512 + 32, 3, kernel_size=kernel_size, dw=dw, @@ -230,8 +231,8 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs): displacement_emb_dim=32, ), "2": ConvRefiner( - 2 * 64+16, - 128+16, + 2 * 64 + 16, + 128 + 16, 3, kernel_size=kernel_size, dw=dw, @@ -240,7 +241,7 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs): displacement_emb_dim=16, ), "1": ConvRefiner( - 2 * 3+6, + 2 * 3 + 6, 24, 3, kernel_size=kernel_size, @@ -287,16 +288,14 @@ def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs): encoder = Encoder( tv_resnet.resnet50(pretrained=not pretrained), ) # only load pretrained weights if not loading a pretrained matcher ;) - matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs).to(device) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w, **kwargs).to(device) if pretrained: try: weights = torch.hub.load_state_dict_from_url( dkm_pretrained_urls["DKMv2"][version] ) except: - weights = torch.load( - dkm_pretrained_urls["DKMv2"][version] - ) + weights = torch.load(dkm_pretrained_urls["DKMv2"][version]) matcher.load_state_dict(weights) return matcher diff --git a/third_party/DKM/dkm/models/deprecated/local_corr.py b/third_party/DKM/dkm/models/deprecated/local_corr.py index 681fe4c0079561fa7a4c44e82a8879a4a27273a1..227d73b00be7efd7f64c32936b3dcdd7e5b4d123 100644 --- a/third_party/DKM/dkm/models/deprecated/local_corr.py +++ b/third_party/DKM/dkm/models/deprecated/local_corr.py @@ -10,8 +10,8 @@ from ..dkm import ConvRefiner class Stream: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - if device == 'cuda': + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device == "cuda": stream = torch.cuda.current_stream(device=device).cuda_stream else: stream = None @@ -622,7 +622,7 @@ class LocalCorr(ConvRefiner): if __name__ == "__main__": - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") x = torch.randn(2, 128, 32, 32).to(device) y = torch.randn(2, 128, 32, 32).to(device) local_corr = LocalCorr(in_dim=81, hidden_dim=81 * 4) diff --git a/third_party/DKM/dkm/models/dkm.py b/third_party/DKM/dkm/models/dkm.py index 27c3f6d59ad3a8e976e3d719868908ddf443883e..58462e5d14cf9cac6e1fa551298f9fc82f93fcab 100644 --- a/third_party/DKM/dkm/models/dkm.py +++ b/third_party/DKM/dkm/models/dkm.py @@ -19,11 +19,11 @@ class ConvRefiner(nn.Module): dw=False, kernel_size=5, hidden_blocks=3, - displacement_emb = None, - displacement_emb_dim = None, - local_corr_radius = None, - corr_in_other = None, - no_support_fm = False, + displacement_emb=None, + displacement_emb_dim=None, + local_corr_radius=None, + corr_in_other=None, + no_support_fm=False, ): super().__init__() self.block1 = self.create_block( @@ -43,12 +43,13 @@ class ConvRefiner(nn.Module): self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0) if displacement_emb: self.has_displacement_emb = True - self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0) + self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0) else: self.has_displacement_emb = False self.local_corr_radius = local_corr_radius self.corr_in_other = corr_in_other self.no_support_fm = no_support_fm + def create_block( self, in_dim, @@ -86,29 +87,35 @@ class ConvRefiner(nn.Module): [type]: [description] """ device = x.device - b,c,hs,ws = x.shape + b, c, hs, ws = x.shape with torch.no_grad(): x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), 
align_corners=False) if self.has_displacement_emb: query_coords = torch.meshgrid( - ( - torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), - torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), - ) + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), + ) ) query_coords = torch.stack((query_coords[1], query_coords[0])) query_coords = query_coords[None].expand(b, 2, hs, ws) - in_displacement = flow-query_coords + in_displacement = flow - query_coords emb_in_displacement = self.disp_emb(in_displacement) if self.local_corr_radius: - #TODO: should corr have gradient? + # TODO: should corr have gradient? if self.corr_in_other: # Corr in other means take a kxk grid around the predicted coordinate in other image - local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow) + local_corr = local_correlation( + x, y, local_radius=self.local_corr_radius, flow=flow + ) else: # Otherwise we use the warp to sample in the first image # This is actually different operations, especially for large viewpoint changes - local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,) + local_corr = local_correlation( + x, + x_hat, + local_radius=self.local_corr_radius, + ) if self.no_support_fm: x_hat = torch.zeros_like(x) d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1) @@ -269,7 +276,7 @@ class GP(nn.Module): only_nearest_neighbour=False, sigma_noise=0.1, no_cov=False, - predict_features = False, + predict_features=False, ): super().__init__() self.K = kernel(T=T, learn_temperature=learn_temperature) @@ -344,9 +351,9 @@ class GP(nn.Module): b, c, h2, w2 = y.shape f = self.get_pos_enc(y) if self.predict_features: - f = f + y[:,:self.dim] # Stupid way to predict features + f = f + y[:, : self.dim] # Stupid way to predict features b, d, h2, w2 = f.shape - #assert x.shape == y.shape + # assert x.shape == y.shape x, y, f = self.reshape(x), self.reshape(y), self.reshape(f) K_xx = self.K(x, x) K_yy = self.K(y, y) @@ -355,7 +362,12 @@ class GP(nn.Module): sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :] # Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large if len(K_yy[0]) > 2000: - K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)]) + K_yy_inv = torch.cat( + [ + torch.linalg.inv(K_yy[k : k + 1] + sigma_noise[k : k + 1]) + for k in range(b) + ] + ) else: K_yy_inv = torch.linalg.inv(K_yy + sigma_noise) @@ -363,7 +375,9 @@ class GP(nn.Module): mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1) if not self.no_cov: cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx)) - cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1) + cov_x = rearrange( + cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1 + ) local_cov_x = self.get_local_cov(cov_x) local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w") gp_feats = torch.cat((mu_x, local_cov_x), dim=1) @@ -376,6 +390,7 @@ class Encoder(nn.Module): def __init__(self, resnet): super().__init__() self.resnet = resnet + def forward(self, x): x0 = x b, c, h, w = x.shape @@ -404,7 +419,15 @@ class Encoder(nn.Module): class Decoder(nn.Module): def __init__( - self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None, + self, + embedding_decoder, + gps, + proj, + conv_refiner, + transformers=None, + detach=False, + scales="all", 
+ pos_embeddings=None, ): super().__init__() self.embedding_decoder = embedding_decoder @@ -424,17 +447,15 @@ class Decoder(nn.Module): certainty = F.interpolate( certainty, size=(h, w), align_corners=False, mode="bilinear" ) - flow = F.interpolate( - flow, size=(h, w), align_corners=False, mode="bilinear" - ) + flow = F.interpolate(flow, size=(h, w), align_corners=False, mode="bilinear") delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow) flow = torch.stack( - ( - flow[:, 0] + delta_flow[:, 0] / (4 * w), - flow[:, 1] + delta_flow[:, 1] / (4 * h), - ), - dim=1, - ) + ( + flow[:, 0] + delta_flow[:, 0] / (4 * w), + flow[:, 1] + delta_flow[:, 1] / (4 * h), + ), + dim=1, + ) flow = flow.permute(0, 2, 3, 1) certainty = certainty + delta_certainty return flow, certainty @@ -452,8 +473,7 @@ class Decoder(nn.Module): coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") return coarse_coords - - def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None): + def forward(self, f1, f2, upsample=False, dense_flow=None, dense_certainty=None): coarse_scales = self.embedding_decoder.scales() all_scales = self.scales if not upsample else ["8", "4", "2", "1"] sizes = {scale: f1[scale].shape[-2:] for scale in f1} @@ -462,7 +482,10 @@ class Decoder(nn.Module): device = f1[1].device coarsest_scale = int(all_scales[0]) old_stuff = torch.zeros( - b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device + b, + self.embedding_decoder.internal_dim, + *sizes[coarsest_scale], + device=f1[coarsest_scale].device ) dense_corresps = {} if not upsample: @@ -470,17 +493,17 @@ class Decoder(nn.Module): dense_certainty = 0.0 else: dense_flow = F.interpolate( - dense_flow, - size=sizes[coarsest_scale], - align_corners=False, - mode="bilinear", - ) + dense_flow, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) dense_certainty = F.interpolate( - dense_certainty, - size=sizes[coarsest_scale], - align_corners=False, - mode="bilinear", - ) + dense_certainty, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) for new_scale in all_scales: ins = int(new_scale) f1_s, f2_s = f1[ins], f2[ins] @@ -543,14 +566,14 @@ class RegressionMatcher(nn.Module): decoder, h=384, w=512, - use_contrastive_loss = False, - alpha = 1, - beta = 0, - sample_mode = "threshold", - upsample_preds = False, - symmetric = False, - name = None, - use_soft_mutual_nearest_neighbours = False, + use_contrastive_loss=False, + alpha=1, + beta=0, + sample_mode="threshold", + upsample_preds=False, + symmetric=False, + name=None, + use_soft_mutual_nearest_neighbours=False, ): super().__init__() self.encoder = encoder @@ -566,13 +589,13 @@ class RegressionMatcher(nn.Module): self.symmetric = symmetric self.name = name self.sample_thresh = 0.05 - self.upsample_res = (864,1152) + self.upsample_res = (864, 1152) if use_soft_mutual_nearest_neighbours: assert symmetric, "MNS requires symmetric inference" self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours - - def extract_backbone_features(self, batch, batched = True, upsample = True): - #TODO: only extract stride [1,2,4,8] for upsample = True + + def extract_backbone_features(self, batch, batched=True, upsample=True): + # TODO: only extract stride [1,2,4,8] for upsample = True x_q = batch["query"] x_s = batch["support"] if batched: @@ -593,7 +616,7 @@ class RegressionMatcher(nn.Module): dense_certainty = dense_certainty.clone() 
dense_certainty[dense_certainty > upper_thresh] = 1 elif "pow" in self.sample_mode: - dense_certainty = dense_certainty**(1/3) + dense_certainty = dense_certainty ** (1 / 3) elif "naive" in self.sample_mode: dense_certainty = torch.ones_like(dense_certainty) matches, certainty = ( @@ -601,23 +624,28 @@ class RegressionMatcher(nn.Module): dense_certainty.reshape(-1), ) expansion_factor = 4 if "balanced" in self.sample_mode else 1 - good_samples = torch.multinomial(certainty, - num_samples = min(expansion_factor*num, len(certainty)), - replacement=False) + good_samples = torch.multinomial( + certainty, + num_samples=min(expansion_factor * num, len(certainty)), + replacement=False, + ) good_matches, good_certainty = matches[good_samples], certainty[good_samples] if "balanced" not in self.sample_mode: return good_matches, good_certainty from ..utils.kde import kde + density = kde(good_matches, std=0.1) - p = 1 / (density+1) - p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones - balanced_samples = torch.multinomial(p, - num_samples = min(num,len(good_certainty)), - replacement=False) + p = 1 / (density + 1) + p[ + density < 10 + ] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones + balanced_samples = torch.multinomial( + p, num_samples=min(num, len(good_certainty)), replacement=False + ) return good_matches[balanced_samples], good_certainty[balanced_samples] - def forward(self, batch, batched = True): + def forward(self, batch, batched=True): feature_pyramid = self.extract_backbone_features(batch, batched=batched) if batched: f_q_pyramid = { @@ -634,37 +662,43 @@ class RegressionMatcher(nn.Module): else: return dense_corresps - def forward_symmetric(self, batch, upsample = False, batched = True): - feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched) + def forward_symmetric(self, batch, upsample=False, batched=True): + feature_pyramid = self.extract_backbone_features( + batch, upsample=upsample, batched=batched + ) f_q_pyramid = feature_pyramid f_s_pyramid = { scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0])) for scale, f_scale in feature_pyramid.items() } - dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {})) + dense_corresps = self.decoder( + f_q_pyramid, + f_s_pyramid, + upsample=upsample, + **(batch["corresps"] if "corresps" in batch else {}) + ) return dense_corresps - + def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B): - kpts_A, kpts_B = matches[...,:2], matches[...,2:] - kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1) - kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1) + kpts_A, kpts_B = matches[..., :2], matches[..., 2:] + kpts_A = torch.stack( + (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), axis=-1 + ) + kpts_B = torch.stack( + (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), axis=-1 + ) return kpts_A, kpts_B - - def match( - self, - im1_path, - im2_path, - *args, - batched=False, - device = None - ): - assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). 
You can turn off upsample_preds by model.upsample_preds = False " + + def match(self, im1_path, im2_path, *args, batched=False, device=None): + assert not ( + batched and self.upsample_preds + ), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False " if isinstance(im1_path, (str, os.PathLike)): im1, im2 = Image.open(im1_path), Image.open(im2_path) - else: # assume it is a PIL Image + else: # assume it is a PIL Image im1, im2 = im1_path, im2_path if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") symmetric = self.symmetric self.train(False) with torch.no_grad(): @@ -680,7 +714,10 @@ class RegressionMatcher(nn.Module): resize=(hs, ws), normalize=True ) query, support = test_transform((im1, im2)) - batch = {"query": query[None].to(device), "support": support[None].to(device)} + batch = { + "query": query[None].to(device), + "support": support[None].to(device), + } else: b, c, h, w = im1.shape b, c, h2, w2 = im2.shape @@ -690,38 +727,47 @@ class RegressionMatcher(nn.Module): finest_scale = 1 # Run matcher if symmetric: - dense_corresps = self.forward_symmetric(batch, batched = True) + dense_corresps = self.forward_symmetric(batch, batched=True) else: - dense_corresps = self.forward(batch, batched = True) - + dense_corresps = self.forward(batch, batched=True) + if self.upsample_preds: hs, ws = self.upsample_res low_res_certainty = F.interpolate( - dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + dense_corresps[16]["dense_certainty"], + size=(hs, ws), + align_corners=False, + mode="bilinear", ) cert_clamp = 0 factor = 0.5 - low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp) + low_res_certainty = ( + factor * low_res_certainty * (low_res_certainty < cert_clamp) + ) - if self.upsample_preds: + if self.upsample_preds: test_transform = get_tuple_transform_ops( resize=(hs, ws), normalize=True ) query, support = test_transform((im1, im2)) query, support = query[None].to(device), support[None].to(device) - batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]} + batch = { + "query": query, + "support": support, + "corresps": dense_corresps[finest_scale], + } if symmetric: - dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True) + dense_corresps = self.forward_symmetric( + batch, upsample=True, batched=True + ) else: - dense_corresps = self.forward(batch, batched = True, upsample=True) + dense_corresps = self.forward(batch, batched=True, upsample=True) query_to_support = dense_corresps[finest_scale]["dense_flow"] dense_certainty = dense_corresps[finest_scale]["dense_certainty"] - + # Get certainty interpolation dense_certainty = dense_certainty - low_res_certainty - query_to_support = query_to_support.permute( - 0, 2, 3, 1 - ) + query_to_support = query_to_support.permute(0, 2, 3, 1) # Create im1 meshgrid query_coords = torch.meshgrid( ( @@ -735,23 +781,20 @@ class RegressionMatcher(nn.Module): query_coords = query_coords.permute(0, 2, 3, 1) if (query_to_support.abs() > 1).any() and True: wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0 - dense_certainty[wrong[:,None]] = 0 - + dense_certainty[wrong[:, None]] = 0 + query_to_support = torch.clamp(query_to_support, -1, 1) if symmetric: support_coords = query_coords - qts, stq = query_to_support.chunk(2) + qts, stq = 
query_to_support.chunk(2) q_warp = torch.cat((query_coords, qts), dim=-1) s_warp = torch.cat((stq, support_coords), dim=-1) - warp = torch.cat((q_warp, s_warp),dim=2) - dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0] + warp = torch.cat((q_warp, s_warp), dim=2) + dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:, 0] else: warp = torch.cat((query_coords, query_to_support), dim=-1) if batched: - return ( - warp, - dense_certainty - ) + return (warp, dense_certainty) else: return ( warp[0], diff --git a/third_party/DKM/dkm/models/encoders.py b/third_party/DKM/dkm/models/encoders.py index 29077e1797196611e9b59a753130a5b153e0aa05..29fe93443933cf7bbf5c542d8732aabc8c771604 100644 --- a/third_party/DKM/dkm/models/encoders.py +++ b/third_party/DKM/dkm/models/encoders.py @@ -3,10 +3,12 @@ import torch.nn as nn import torch.nn.functional as F import torchvision.models as tvm + class ResNet18(nn.Module): def __init__(self, pretrained=False) -> None: super().__init__() self.net = tvm.resnet18(pretrained=pretrained) + def forward(self, x): self = self.net x1 = x @@ -18,7 +20,7 @@ class ResNet18(nn.Module): x8 = self.layer2(x4) x16 = self.layer3(x8) x32 = self.layer4(x16) - return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1} + return {32: x32, 16: x16, 8: x8, 4: x4, 2: x2, 1: x1} def train(self, mode=True): super().train(mode) @@ -27,33 +29,47 @@ class ResNet18(nn.Module): m.eval() pass + class ResNet50(nn.Module): - def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None: + def __init__( + self, + pretrained=False, + high_res=False, + weights=None, + dilation=None, + freeze_bn=True, + anti_aliased=False, + ) -> None: super().__init__() if dilation is None: - dilation = [False,False,False] + dilation = [False, False, False] if anti_aliased: pass else: if weights is not None: - self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation) + self.net = tvm.resnet50( + weights=weights, replace_stride_with_dilation=dilation + ) else: - self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation) - + self.net = tvm.resnet50( + pretrained=pretrained, replace_stride_with_dilation=dilation + ) + self.high_res = high_res self.freeze_bn = freeze_bn + def forward(self, x): net = self.net - feats = {1:x} + feats = {1: x} x = net.conv1(x) x = net.bn1(x) x = net.relu(x) - feats[2] = x + feats[2] = x x = net.maxpool(x) x = net.layer1(x) - feats[4] = x + feats[4] = x x = net.layer2(x) - feats[8] = x + feats[8] = x x = net.layer3(x) feats[16] = x x = net.layer4(x) @@ -69,36 +85,65 @@ class ResNet50(nn.Module): pass - - class ResNet101(nn.Module): - def __init__(self, pretrained=False, high_res = False, weights = None) -> None: + def __init__(self, pretrained=False, high_res=False, weights=None) -> None: super().__init__() if weights is not None: - self.net = tvm.resnet101(weights = weights) + self.net = tvm.resnet101(weights=weights) else: self.net = tvm.resnet101(pretrained=pretrained) self.high_res = high_res self.scale_factor = 1 if not high_res else 1.5 + def forward(self, x): net = self.net - feats = {1:x} + feats = {1: x} sf = self.scale_factor if self.high_res: x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") x = net.conv1(x) x = net.bn1(x) x = net.relu(x) - feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[2] = ( + x + if not self.high_res + else F.interpolate( + x, 
scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) x = net.maxpool(x) x = net.layer1(x) - feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[4] = ( + x + if not self.high_res + else F.interpolate( + x, scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) x = net.layer2(x) - feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[8] = ( + x + if not self.high_res + else F.interpolate( + x, scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) x = net.layer3(x) - feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[16] = ( + x + if not self.high_res + else F.interpolate( + x, scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) x = net.layer4(x) - feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[32] = ( + x + if not self.high_res + else F.interpolate( + x, scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) return feats def train(self, mode=True): @@ -110,33 +155,64 @@ class ResNet101(nn.Module): class WideResNet50(nn.Module): - def __init__(self, pretrained=False, high_res = False, weights = None) -> None: + def __init__(self, pretrained=False, high_res=False, weights=None) -> None: super().__init__() if weights is not None: - self.net = tvm.wide_resnet50_2(weights = weights) + self.net = tvm.wide_resnet50_2(weights=weights) else: self.net = tvm.wide_resnet50_2(pretrained=pretrained) self.high_res = high_res self.scale_factor = 1 if not high_res else 1.5 + def forward(self, x): net = self.net - feats = {1:x} + feats = {1: x} sf = self.scale_factor if self.high_res: x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") x = net.conv1(x) x = net.bn1(x) x = net.relu(x) - feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[2] = ( + x + if not self.high_res + else F.interpolate( + x, scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) x = net.maxpool(x) x = net.layer1(x) - feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[4] = ( + x + if not self.high_res + else F.interpolate( + x, scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) x = net.layer2(x) - feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[8] = ( + x + if not self.high_res + else F.interpolate( + x, scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) x = net.layer3(x) - feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[16] = ( + x + if not self.high_res + else F.interpolate( + x, scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) x = net.layer4(x) - feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + feats[32] = ( + x + if not self.high_res + else F.interpolate( + x, scale_factor=1 / sf, align_corners=False, mode="bilinear" + ) + ) return feats def train(self, mode=True): @@ -144,4 +220,4 @@ class WideResNet50(nn.Module): for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() - pass \ No newline at end of file + pass diff --git a/third_party/DKM/dkm/models/model_zoo/DKMv3.py 
b/third_party/DKM/dkm/models/model_zoo/DKMv3.py index 6f4c9ede3863d778f679a033d8d2287b8776e894..fe41ab8b6400a4e57b8b08aab556bcba535e384a 100644 --- a/third_party/DKM/dkm/models/model_zoo/DKMv3.py +++ b/third_party/DKM/dkm/models/model_zoo/DKMv3.py @@ -5,9 +5,17 @@ from ..dkm import * from ..encoders import * -def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", device = None, **kwargs): +def DKMv3( + weights, + h, + w, + symmetric=True, + sample_mode="threshold_balanced", + device=None, + **kwargs +): if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") gp_dim = 256 dfn_dim = 384 feat_dim = 256 @@ -57,44 +65,44 @@ def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", de conv_refiner = nn.ModuleDict( { "16": ConvRefiner( - 2 * 512+128+(2*7+1)**2, - 2 * 512+128+(2*7+1)**2, + 2 * 512 + 128 + (2 * 7 + 1) ** 2, + 2 * 512 + 128 + (2 * 7 + 1) ** 2, 3, kernel_size=kernel_size, dw=dw, hidden_blocks=hidden_blocks, displacement_emb=displacement_emb, displacement_emb_dim=128, - local_corr_radius = 7, - corr_in_other = True, + local_corr_radius=7, + corr_in_other=True, ), "8": ConvRefiner( - 2 * 512+64+(2*3+1)**2, - 2 * 512+64+(2*3+1)**2, + 2 * 512 + 64 + (2 * 3 + 1) ** 2, + 2 * 512 + 64 + (2 * 3 + 1) ** 2, 3, kernel_size=kernel_size, dw=dw, hidden_blocks=hidden_blocks, displacement_emb=displacement_emb, displacement_emb_dim=64, - local_corr_radius = 3, - corr_in_other = True, + local_corr_radius=3, + corr_in_other=True, ), "4": ConvRefiner( - 2 * 256+32+(2*2+1)**2, - 2 * 256+32+(2*2+1)**2, + 2 * 256 + 32 + (2 * 2 + 1) ** 2, + 2 * 256 + 32 + (2 * 2 + 1) ** 2, 3, kernel_size=kernel_size, dw=dw, hidden_blocks=hidden_blocks, displacement_emb=displacement_emb, displacement_emb_dim=32, - local_corr_radius = 2, - corr_in_other = True, + local_corr_radius=2, + corr_in_other=True, ), "2": ConvRefiner( - 2 * 64+16, - 128+16, + 2 * 64 + 16, + 128 + 16, 3, kernel_size=kernel_size, dw=dw, @@ -103,7 +111,7 @@ def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", de displacement_emb_dim=16, ), "1": ConvRefiner( - 2 * 3+6, + 2 * 3 + 6, 24, 3, kernel_size=kernel_size, @@ -144,7 +152,16 @@ def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", de ) decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) - encoder = ResNet50(pretrained = False, high_res = False, freeze_bn=False) - matcher = RegressionMatcher(encoder, decoder, h=h, w=w, name = "DKMv3", sample_mode=sample_mode, symmetric = symmetric, **kwargs).to(device) + encoder = ResNet50(pretrained=False, high_res=False, freeze_bn=False) + matcher = RegressionMatcher( + encoder, + decoder, + h=h, + w=w, + name="DKMv3", + sample_mode=sample_mode, + symmetric=symmetric, + **kwargs + ).to(device) res = matcher.load_state_dict(weights) return matcher diff --git a/third_party/DKM/dkm/models/model_zoo/__init__.py b/third_party/DKM/dkm/models/model_zoo/__init__.py index c85da2920c1acfac140ada2d87623203607d42ca..78901ad4f67e152933af8bb56c5478e3d561f30d 100644 --- a/third_party/DKM/dkm/models/model_zoo/__init__.py +++ b/third_party/DKM/dkm/models/model_zoo/__init__.py @@ -8,7 +8,7 @@ import torch from .DKMv3 import DKMv3 -def DKMv3_outdoor(path_to_weights = None, device=None): +def DKMv3_outdoor(path_to_weights=None, device=None): """ Loads DKMv3 outdoor weights, uses internal resolution of (540, 720) by default resolution can be changed by setting 
model.h_resized, model.w_resized later. @@ -16,24 +16,27 @@ def DKMv3_outdoor(path_to_weights = None, device=None): can be turned off by model.upsample_preds = False """ if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if path_to_weights is not None: - weights = torch.load(path_to_weights, map_location='cpu') + weights = torch.load(path_to_weights, map_location="cpu") else: - weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["outdoor"], - map_location='cpu') - return DKMv3(weights, 540, 720, upsample_preds = True, device=device) + weights = torch.hub.load_state_dict_from_url( + weight_urls["DKMv3"]["outdoor"], map_location="cpu" + ) + return DKMv3(weights, 540, 720, upsample_preds=True, device=device) -def DKMv3_indoor(path_to_weights = None, device=None): + +def DKMv3_indoor(path_to_weights=None, device=None): """ Loads DKMv3 indoor weights, uses internal resolution of (480, 640) by default Resolution can be changed by setting model.h_resized, model.w_resized later. """ if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if path_to_weights is not None: weights = torch.load(path_to_weights, map_location=device) else: - weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["indoor"], - map_location=device) - return DKMv3(weights, 480, 640, upsample_preds = False, device=device) + weights = torch.hub.load_state_dict_from_url( + weight_urls["DKMv3"]["indoor"], map_location=device + ) + return DKMv3(weights, 480, 640, upsample_preds=False, device=device) diff --git a/third_party/DKM/dkm/utils/kde.py b/third_party/DKM/dkm/utils/kde.py index fa392455e70fda4c9c77c28bda76bcb7ef9045b0..286a531cede3fe1b46fbb8915bb8ad140b2cb79a 100644 --- a/third_party/DKM/dkm/utils/kde.py +++ b/third_party/DKM/dkm/utils/kde.py @@ -2,25 +2,28 @@ import torch import torch.nn.functional as F import numpy as np -def fast_kde(x, std = 0.1, kernel_size = 9, dilation = 3, padding = 9//2, stride = 1): + +def fast_kde(x, std=0.1, kernel_size=9, dilation=3, padding=9 // 2, stride=1): raise NotImplementedError("WIP, use at your own risk.") # Note: when doing symmetric matching this might not be very exact, since we only check neighbours on the grid - x = x.permute(0,3,1,2) - B,C,H,W = x.shape - K = kernel_size ** 2 - unfolded_x = F.unfold(x,kernel_size=kernel_size, dilation = dilation, padding = padding, stride = stride).reshape(B, C, K, H, W) - scores = (-(unfolded_x - x[:,:,None]).sum(dim=1)**2/(2*std**2)).exp() + x = x.permute(0, 3, 1, 2) + B, C, H, W = x.shape + K = kernel_size**2 + unfolded_x = F.unfold( + x, kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride + ).reshape(B, C, K, H, W) + scores = (-(unfolded_x - x[:, :, None]).sum(dim=1) ** 2 / (2 * std**2)).exp() density = scores.sum(dim=1) return density -def kde(x, std = 0.1, device=None): +def kde(x, std=0.1, device=None): if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if isinstance(x, np.ndarray): x = torch.from_numpy(x) # use a gaussian kernel to estimate density x = x.to(device) - scores = (-torch.cdist(x,x)**2/(2*std**2)).exp() + scores = (-torch.cdist(x, x) ** 2 / (2 * std**2)).exp() density = scores.sum(dim=-1) return density diff --git 
a/third_party/DKM/dkm/utils/local_correlation.py b/third_party/DKM/dkm/utils/local_correlation.py index c0c1c06291d0b760376a2b2162bcf49d6eb1303c..08f7f04881bb9610edf3bd8bdcbda4e32d6e4c54 100644 --- a/third_party/DKM/dkm/utils/local_correlation.py +++ b/third_party/DKM/dkm/utils/local_correlation.py @@ -3,38 +3,42 @@ import torch.nn.functional as F def local_correlation( - feature0, - feature1, - local_radius, - padding_mode="zeros", - flow = None + feature0, feature1, local_radius, padding_mode="zeros", flow=None ): device = feature0.device b, c, h, w = feature0.size() if flow is None: # If flow is None, assume feature0 and feature1 are aligned coords = torch.meshgrid( - ( - torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), - torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), - )) - coords = torch.stack((coords[1], coords[0]), dim=-1)[ - None - ].expand(b, h, w, 2) + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + ) + ) + coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(b, h, w, 2) else: - coords = flow.permute(0,2,3,1) # If using flow, sample around flow target. + coords = flow.permute(0, 2, 3, 1) # If using flow, sample around flow target. r = local_radius local_window = torch.meshgrid( - ( - torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=device), - torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=device), - )) - local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[ - None - ].expand(b, 2*r+1, 2*r+1, 2).reshape(b, (2*r+1)**2, 2) - coords = (coords[:,:,:,None]+local_window[:,None,None]).reshape(b,h,w*(2*r+1)**2,2) + ( + torch.linspace( + -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device=device + ), + torch.linspace( + -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device=device + ), + ) + ) + local_window = ( + torch.stack((local_window[1], local_window[0]), dim=-1)[None] + .expand(b, 2 * r + 1, 2 * r + 1, 2) + .reshape(b, (2 * r + 1) ** 2, 2) + ) + coords = (coords[:, :, :, None] + local_window[:, None, None]).reshape( + b, h, w * (2 * r + 1) ** 2, 2 + ) window_feature = F.grid_sample( feature1, coords, padding_mode=padding_mode, align_corners=False - )[...,None].reshape(b,c,h,w,(2*r+1)**2) - corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature)/(c**.5) + )[..., None].reshape(b, c, h, w, (2 * r + 1) ** 2) + corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature) / (c**0.5) return corr diff --git a/third_party/DKM/dkm/utils/utils.py b/third_party/DKM/dkm/utils/utils.py index 46bbe60260930aed184c6fa5907c837c0177b304..ca5ca11da35d2c201d3351d33798a04cd7781b4f 100644 --- a/third_party/DKM/dkm/utils/utils.py +++ b/third_party/DKM/dkm/utils/utils.py @@ -6,18 +6,18 @@ from torchvision.transforms.functional import InterpolationMode import torch.nn.functional as F from PIL import Image -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py # --- GEOMETRY --- def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): if len(kpts0) < 5: return None - K0inv = np.linalg.inv(K0[:2,:2]) - K1inv = np.linalg.inv(K1[:2,:2]) + K0inv = np.linalg.inv(K0[:2, :2]) + K1inv = np.linalg.inv(K1[:2, :2]) - kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T - kpts1 = (K1inv @ 
(kpts1-K1[None,:2,2]).T).T + kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T + kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T E, mask = cv2.findEssentialMat( kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC diff --git a/third_party/DarkFeat/darkfeat.py b/third_party/DarkFeat/darkfeat.py index e78ad2604aafb759a6241365ac93fd1ef38f76f3..710962a2a8853689b5b0b764ce817d23aa0537ac 100644 --- a/third_party/DarkFeat/darkfeat.py +++ b/third_party/DarkFeat/darkfeat.py @@ -16,11 +16,11 @@ def gather_nd(params, indices): out_shape = orig_shape[:-1] + list(params.shape)[m:] else: raise ValueError( - f'the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}' + f"the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}" ) indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist() - output = params[indices] # (num_samples, ...) + output = params[indices] # (num_samples, ...) return output.reshape(out_shape).contiguous() @@ -59,11 +59,13 @@ def interpolate(pos, inputs, nd=True): w_bottom_right = w_bottom_right[..., None] interpolated_val = ( - w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) + - w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) + - w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + - w_bottom_right * - gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) + w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) + + w_top_right + * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) + + w_bottom_left + * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + + w_bottom_right + * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) ) return interpolated_val @@ -73,24 +75,29 @@ def edge_mask(inputs, n_channel, dilation=1, edge_thld=5): b, c, h, w = inputs.size() device = inputs.device - dii_filter = torch.tensor( - [[0, 1., 0], [0, -2., 0], [0, 1., 0]] - ).view(1, 1, 3, 3) + dii_filter = torch.tensor([[0, 1.0, 0], [0, -2.0, 0], [0, 1.0, 0]]).view(1, 1, 3, 3) dij_filter = 0.25 * torch.tensor( - [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]] - ).view(1, 1, 3, 3) - djj_filter = torch.tensor( - [[0, 0, 0], [1., -2., 1.], [0, 0, 0]] + [[1.0, 0, -1.0], [0, 0.0, 0], [-1.0, 0, 1.0]] ).view(1, 1, 3, 3) + djj_filter = torch.tensor([[0, 0, 0], [1.0, -2.0, 1.0], [0, 0, 0]]).view(1, 1, 3, 3) dii = F.conv2d( - inputs.view(-1, 1, h, w), dii_filter.to(device), padding=dilation, dilation=dilation + inputs.view(-1, 1, h, w), + dii_filter.to(device), + padding=dilation, + dilation=dilation, ).view(b, c, h, w) dij = F.conv2d( - inputs.view(-1, 1, h, w), dij_filter.to(device), padding=dilation, dilation=dilation + inputs.view(-1, 1, h, w), + dij_filter.to(device), + padding=dilation, + dilation=dilation, ).view(b, c, h, w) djj = F.conv2d( - inputs.view(-1, 1, h, w), djj_filter.to(device), padding=dilation, dilation=dilation + inputs.view(-1, 1, h, w), + djj_filter.to(device), + padding=dilation, + dilation=dilation, ).view(b, c, h, w) det = dii * djj - dij * dij @@ -111,11 +118,17 @@ def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_si mask = score_map > score_thld if nms_size > 0: - nms_mask = F.max_pool2d(score_map, kernel_size=nms_size, stride=1, padding=nms_size//2) + nms_mask = F.max_pool2d( + 
score_map, kernel_size=nms_size, stride=1, padding=nms_size // 2 + ) nms_mask = torch.eq(score_map, nms_mask) mask = torch.logical_and(nms_mask, mask) if eof_size > 0: - eof_mask = torch.ones((1, 1, h - 2 * eof_size, w - 2 * eof_size), dtype=torch.float32, device=score_map.device) + eof_mask = torch.ones( + (1, 1, h - 2 * eof_size, w - 2 * eof_size), + dtype=torch.float32, + device=score_map.device, + ) eof_mask = F.pad(eof_mask, [eof_size] * 4, value=0) eof_mask = eof_mask.bool() mask = torch.logical_and(eof_mask, mask) @@ -157,23 +170,20 @@ def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_si # output: [batch_size, C, H, W], [batch_size, C, H, W] def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1): inputs = inputs / moving_instance_max - + batch_size, C, H, W = inputs.shape pad_size = ksize // 2 + (dilation - 1) kernel = torch.ones([C, 1, ksize, ksize], device=inputs.device) / (ksize * ksize) - - pad_inputs = F.pad(inputs, [pad_size] * 4, mode='reflect') + + pad_inputs = F.pad(inputs, [pad_size] * 4, mode="reflect") avg_spatial_inputs = F.conv2d( - pad_inputs, - kernel, - stride=1, - dilation=dilation, - padding=0, - groups=C + pad_inputs, kernel, stride=1, dilation=dilation, padding=0, groups=C ) - avg_channel_inputs = torch.mean(inputs, axis=1, keepdim=True) # channel dimension is 1 + avg_channel_inputs = torch.mean( + inputs, axis=1, keepdim=True + ) # channel dimension is 1 # print(avg_spatial_inputs.shape) alpha = F.softplus(inputs - avg_spatial_inputs) @@ -184,23 +194,36 @@ def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1): class DarkFeat(nn.Module): default_config = { - 'model_path': '', - 'input_type': 'raw-demosaic', - 'kpt_n': 5000, - 'kpt_refinement': True, - 'score_thld': 0.5, - 'edge_thld': 10, - 'multi_scale': False, - 'multi_level': True, - 'nms_size': 3, - 'eof_size': 5, - 'need_norm': True, - 'use_peakiness': True + "model_path": "", + "input_type": "raw-demosaic", + "kpt_n": 5000, + "kpt_refinement": True, + "score_thld": 0.5, + "edge_thld": 10, + "multi_scale": False, + "multi_level": True, + "nms_size": 3, + "eof_size": 5, + "need_norm": True, + "use_peakiness": True, } - def __init__(self, model_path='', inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False): + def __init__( + self, + model_path="", + inchan=3, + dilated=True, + dilation=1, + bn=True, + bn_affine=False, + ): super(DarkFeat, self).__init__() - inchan = 3 if self.default_config['input_type'] == 'rgb' or self.default_config['input_type'] == 'raw-demosaic' else 1 + inchan = ( + 3 + if self.default_config["input_type"] == "rgb" + or self.default_config["input_type"] == "raw-demosaic" + else 1 + ) self.config = {**self.default_config} self.inchan = inchan @@ -209,60 +232,81 @@ class DarkFeat(nn.Module): self.dilation = dilation self.bn = bn self.bn_affine = bn_affine - self.config['model_path'] = model_path + self.config["model_path"] = model_path dim = 128 mchan = 4 - self.conv0 = self._add_conv( 8*mchan) - self.conv1 = self._add_conv( 8*mchan, bn=False) - self.bn1 = self._make_bn(8*mchan) - self.conv2 = self._add_conv( 16*mchan, stride=2) - self.conv3 = self._add_conv( 16*mchan, bn=False) - self.bn3 = self._make_bn(16*mchan) - self.conv4 = self._add_conv( 32*mchan, stride=2) - self.conv5 = self._add_conv( 32*mchan) + self.conv0 = self._add_conv(8 * mchan) + self.conv1 = self._add_conv(8 * mchan, bn=False) + self.bn1 = self._make_bn(8 * mchan) + self.conv2 = self._add_conv(16 * mchan, stride=2) + self.conv3 = self._add_conv(16 * 
mchan, bn=False) + self.bn3 = self._make_bn(16 * mchan) + self.conv4 = self._add_conv(32 * mchan, stride=2) + self.conv5 = self._add_conv(32 * mchan) # replace last 8x8 convolution with 3 3x3 convolutions - self.conv6_0 = self._add_conv( 32*mchan) - self.conv6_1 = self._add_conv( 32*mchan) + self.conv6_0 = self._add_conv(32 * mchan) + self.conv6_1 = self._add_conv(32 * mchan) self.conv6_2 = self._add_conv(dim, bn=False, relu=False) self.out_dim = dim - self.moving_avg_params = nn.ParameterList([ - Parameter(torch.tensor(1.), requires_grad=False), - Parameter(torch.tensor(1.), requires_grad=False), - Parameter(torch.tensor(1.), requires_grad=False) - ]) + self.moving_avg_params = nn.ParameterList( + [ + Parameter(torch.tensor(1.0), requires_grad=False), + Parameter(torch.tensor(1.0), requires_grad=False), + Parameter(torch.tensor(1.0), requires_grad=False), + ] + ) self.clf = nn.Conv2d(128, 2, kernel_size=1) state_dict = torch.load(self.config["model_path"]) new_state_dict = {} - + for key in state_dict: - if 'running_mean' not in key and 'running_var' not in key and 'num_batches_tracked' not in key: + if ( + "running_mean" not in key + and "running_var" not in key + and "num_batches_tracked" not in key + ): new_state_dict[key] = state_dict[key] self.load_state_dict(new_state_dict) - print('Loaded DarkFeat model') - + print("Loaded DarkFeat model") + def _make_bn(self, outd): return nn.BatchNorm2d(outd, affine=self.bn_affine, track_running_stats=False) - def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max', bias=False): + def _add_conv( + self, + outd, + k=3, + stride=1, + dilation=1, + bn=True, + relu=True, + k_pool=1, + pool_type="max", + bias=False, + ): d = self.dilation * dilation - conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride, bias=bias) + conv_params = dict( + padding=((k - 1) * d) // 2, dilation=d, stride=stride, bias=bias + ) ops = nn.ModuleList([]) - ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) ) - if bn and self.bn: ops.append( self._make_bn(outd) ) - if relu: ops.append( nn.ReLU(inplace=True) ) + ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params)) + if bn and self.bn: + ops.append(self._make_bn(outd)) + if relu: + ops.append(nn.ReLU(inplace=True)) self.curchan = outd - + if k_pool > 1: - if pool_type == 'avg': + if pool_type == "avg": ops.append(torch.nn.AvgPool2d(kernel_size=k_pool)) - elif pool_type == 'max': + elif pool_type == "max": ops.append(torch.nn.MaxPool2d(kernel_size=k_pool)) else: print(f"Error, unknown pooling type {pool_type}...") @@ -270,32 +314,32 @@ class DarkFeat(nn.Module): return nn.Sequential(*ops) def forward(self, input): - """ Compute keypoints, scores, descriptors for image """ - data = input['image'] + """Compute keypoints, scores, descriptors for image""" + data = input["image"] H, W = data.shape[2:] - if self.config['input_type'] == 'rgb': + if self.config["input_type"] == "rgb": # 3-channel rgb RGB_mean = [0.485, 0.456, 0.406] - RGB_std = [0.229, 0.224, 0.225] + RGB_std = [0.229, 0.224, 0.225] norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std) data = norm_RGB(data) - elif self.config['input_type'] == 'gray': + elif self.config["input_type"] == "gray": # 1-channel data = torch.mean(data, dim=1, keepdim=True) norm_gray0 = tvf.Normalize(mean=data.mean(), std=data.std()) data = norm_gray0(data) - elif self.config['input_type'] == 'raw': + elif self.config["input_type"] == "raw": # 4-channel pass - elif self.config['input_type'] == 
'raw-demosaic': + elif self.config["input_type"] == "raw-demosaic": # 3-channel pass else: raise NotImplementedError() - + # x: [N, C, H, W] x0 = self.conv0(data) x1 = self.conv1(x0) @@ -309,16 +353,20 @@ class DarkFeat(nn.Module): x6_1 = self.conv6_1(x6_0) x6_2 = self.conv6_2(x6_1) - comb_weights = torch.tensor([1., 2., 3.], device=data.device) + comb_weights = torch.tensor([1.0, 2.0, 3.0], device=data.device) comb_weights /= torch.sum(comb_weights) ksize = [3, 2, 1] det_score_maps = [] for idx, xx in enumerate([x1, x3, x6_2]): - alpha, beta = peakiness_score(xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx]) + alpha, beta = peakiness_score( + xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx] + ) score_vol = alpha * beta det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0] - det_score_map = F.interpolate(det_score_map, size=data.shape[2:], mode='bilinear', align_corners=True) + det_score_map = F.interpolate( + det_score_map, size=data.shape[2:], mode="bilinear", align_corners=True + ) det_score_map = comb_weights[idx] * det_score_map det_score_maps.append(det_score_map) @@ -326,34 +374,42 @@ class DarkFeat(nn.Module): desc = x6_2 score_map = det_score_map - conf = F.softmax(self.clf((desc)**2), dim=1)[:,1:2] - score_map = score_map * F.interpolate(conf, size=score_map.shape[2:], mode='bilinear', align_corners=True) + conf = F.softmax(self.clf((desc) ** 2), dim=1)[:, 1:2] + score_map = score_map * F.interpolate( + conf, size=score_map.shape[2:], mode="bilinear", align_corners=True + ) kpt_inds, kpt_score = extract_kpts( score_map, - k=self.config['kpt_n'], - score_thld=self.config['score_thld'], - nms_size=self.config['nms_size'], - eof_size=self.config['eof_size'], - edge_thld=self.config['edge_thld'] + k=self.config["kpt_n"], + score_thld=self.config["score_thld"], + nms_size=self.config["nms_size"], + eof_size=self.config["eof_size"], + edge_thld=self.config["edge_thld"], ) - descs = F.normalize( - interpolate(kpt_inds.squeeze(0) / 4, desc.squeeze(0).permute(1, 2, 0)), - p=2, - dim=-1 - ).detach().cpu().numpy(), - kpts = np.squeeze(torch.stack([kpt_inds[:, :, 1], kpt_inds[:, :, 0]], dim=-1).cpu(), axis=0) \ - * np.array([W / data.shape[3], H / data.shape[2]], dtype=np.float32) + descs = ( + F.normalize( + interpolate(kpt_inds.squeeze(0) / 4, desc.squeeze(0).permute(1, 2, 0)), + p=2, + dim=-1, + ) + .detach() + .cpu() + .numpy(), + ) + kpts = np.squeeze( + torch.stack([kpt_inds[:, :, 1], kpt_inds[:, :, 0]], dim=-1).cpu(), axis=0 + ) * np.array([W / data.shape[3], H / data.shape[2]], dtype=np.float32) scores = np.squeeze(kpt_score.detach().cpu().numpy(), axis=0) - idxs = np.negative(scores).argsort()[0:self.config['kpt_n']] + idxs = np.negative(scores).argsort()[0 : self.config["kpt_n"]] descs = descs[0][idxs] kpts = kpts[idxs] scores = scores[idxs] return { - 'keypoints': kpts, - 'scores': torch.from_numpy(scores), - 'descriptors': torch.from_numpy(descs.T), + "keypoints": kpts, + "scores": torch.from_numpy(scores), + "descriptors": torch.from_numpy(descs.T), } diff --git a/third_party/DarkFeat/datasets/InvISP/cal_metrics.py b/third_party/DarkFeat/datasets/InvISP/cal_metrics.py index cc3e501664487de4c08ab8c89328dd266fba2868..28811368c5be5a362e8907ec4963a1de7aaa260b 100644 --- a/third_party/DarkFeat/datasets/InvISP/cal_metrics.py +++ b/third_party/DarkFeat/datasets/InvISP/cal_metrics.py @@ -1,8 +1,9 @@ import cv2 import numpy as np import math + # from skimage.metrics import structural_similarity as ssim -from skimage.measure import 
compare_ssim +from skimage.measure import compare_ssim from scipy.misc import imread from glob import glob @@ -14,30 +15,34 @@ parser.add_argument("--path", type=str, help="Path to evaluate images.") args = parser.parse_args() + def psnr(img1, img2): - mse = np.mean( (img1/255. - img2/255.) ** 2 ) - if mse < 1.0e-10: - return 100 - PIXEL_MAX = 1 - return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + mse = np.mean((img1 / 255.0 - img2 / 255.0) ** 2) + if mse < 1.0e-10: + return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + def psnr_raw(img1, img2): - mse = np.mean( (img1 - img2) ** 2 ) - if mse < 1.0e-10: - return 100 - PIXEL_MAX = 1 - return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + mse = np.mean((img1 - img2) ** 2) + if mse < 1.0e-10: + return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) def my_ssim(img1, img2): - return compare_ssim(img1, img2, data_range=img1.max() - img1.min(), multichannel=True) + return compare_ssim( + img1, img2, data_range=img1.max() - img1.min(), multichannel=True + ) def quan_eval(path, suffix="jpg"): # path: /disk2/yazhou/projects/IISP/exps/test_final_unet_globalEDV2/ # ours - gt_imgs = sorted(glob(path+"tar*.%s"%suffix)) - pred_imgs = sorted(glob(path+"pred*.%s"%suffix)) + gt_imgs = sorted(glob(path + "tar*.%s" % suffix)) + pred_imgs = sorted(glob(path + "pred*.%s" % suffix)) # with open(split_path + "test_gt.txt", 'r') as f_gt, open(split_path+"test_rgb.txt","r") as f_rgb: # gt_imgs = [line.rstrip() for line in f_gt.readlines()] @@ -45,8 +50,8 @@ def quan_eval(path, suffix="jpg"): assert len(gt_imgs) == len(pred_imgs) - psnr_avg = 0. - ssim_avg = 0. + psnr_avg = 0.0 + ssim_avg = 0.0 for i in range(len(gt_imgs)): gt = imread(gt_imgs[i]) pred = imread(pred_imgs[i]) @@ -66,21 +71,23 @@ def quan_eval(path, suffix="jpg"): return psnr_avg, ssim_avg + def mse(gt, pred): - return np.mean((gt-pred)**2) + return np.mean((gt - pred) ** 2) + def mse_raw(path, suffix="npy"): - gt_imgs = sorted(glob(path+"raw_tar*.%s"%suffix)) - pred_imgs = sorted(glob(path+"raw_pred*.%s"%suffix)) + gt_imgs = sorted(glob(path + "raw_tar*.%s" % suffix)) + pred_imgs = sorted(glob(path + "raw_pred*.%s" % suffix)) # with open(split_path + "test_gt.txt", 'r') as f_gt, open(split_path+"test_rgb.txt","r") as f_rgb: # gt_imgs = [line.rstrip() for line in f_gt.readlines()] # pred_imgs = [line.rstrip() for line in f_rgb.readlines()] - + assert len(gt_imgs) == len(pred_imgs) - mse_avg = 0. - psnr_avg = 0. 
+ mse_avg = 0.0 + psnr_avg = 0.0 for i in range(len(gt_imgs)): gt = np.load(gt_imgs[i]) pred = np.load(pred_imgs[i]) @@ -100,6 +107,7 @@ def mse_raw(path, suffix="npy"): return mse_avg, psnr_avg + test_full = False # if test_full: @@ -107,8 +115,10 @@ test_full = False # mse_avg, psnr_avg_raw = mse_raw(ROOT_PATH+"%s/vis_%s_full/"%(args.task, args.ckpt)) # else: psnr_avg, ssim_avg = quan_eval(args.path, "jpg") -mse_avg, psnr_avg_raw = mse_raw(args.path) - -print("pnsr: {}, ssim: {}, mse: {}, psnr raw: {}".format(psnr_avg, ssim_avg, mse_avg, psnr_avg_raw)) - +mse_avg, psnr_avg_raw = mse_raw(args.path) +print( + "pnsr: {}, ssim: {}, mse: {}, psnr raw: {}".format( + psnr_avg, ssim_avg, mse_avg, psnr_avg_raw + ) +) diff --git a/third_party/DarkFeat/datasets/InvISP/config/config.py b/third_party/DarkFeat/datasets/InvISP/config/config.py index dc42182ecf7464cc85ed5c77b7aeb9ee4e3ecd74..d0b041cd724db5d8edf629fd56dfba10b83ea6c0 100644 --- a/third_party/DarkFeat/datasets/InvISP/config/config.py +++ b/third_party/DarkFeat/datasets/InvISP/config/config.py @@ -5,17 +5,37 @@ BATCH_SIZE = 1 DATA_PATH = "./data/" - def get_arguments(): parser = argparse.ArgumentParser(description="training codes") - + parser.add_argument("--task", type=str, help="Name of this training") - parser.add_argument("--data_path", type=str, default=DATA_PATH, help="Dataset root path.") - parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size for training. ") - parser.add_argument("--debug_mode", dest='debug_mode', action='store_true', help="If debug mode, load less data.") - parser.add_argument("--gamma", dest='gamma', action='store_true', help="Use gamma compression for raw data.") - parser.add_argument("--camera", type=str, default="NIKON_D700", choices=["NIKON_D700", "Canon_EOS_5D"], help="Choose which camera to use. ") - parser.add_argument("--rgb_weight", type=float, default=1, help="Weight for rgb loss. ") - - + parser.add_argument( + "--data_path", type=str, default=DATA_PATH, help="Dataset root path." + ) + parser.add_argument( + "--batch_size", type=int, default=BATCH_SIZE, help="Batch size for training. " + ) + parser.add_argument( + "--debug_mode", + dest="debug_mode", + action="store_true", + help="If debug mode, load less data.", + ) + parser.add_argument( + "--gamma", + dest="gamma", + action="store_true", + help="Use gamma compression for raw data.", + ) + parser.add_argument( + "--camera", + type=str, + default="NIKON_D700", + choices=["NIKON_D700", "Canon_EOS_5D"], + help="Choose which camera to use. ", + ) + parser.add_argument( + "--rgb_weight", type=float, default=1, help="Weight for rgb loss. 
" + ) + return parser diff --git a/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py b/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py index 62271771a17a4863b730136d49f2a23aed0e49b2..3445a409b756b5f2ae6f0f4d1da2c589268635e1 100644 --- a/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py +++ b/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py @@ -10,22 +10,27 @@ import scipy.io as scio parser = argparse.ArgumentParser(description="data preprocess") parser.add_argument("--camera", type=str, default="NIKON_D700", help="Camera Name") -parser.add_argument("--Bayer_Pattern", type=str, default="RGGB", help="Bayer Pattern of RAW") -parser.add_argument("--JPEG_Quality", type=int, default=90, help="Jpeg Quality of the ground truth.") +parser.add_argument( + "--Bayer_Pattern", type=str, default="RGGB", help="Bayer Pattern of RAW" +) +parser.add_argument( + "--JPEG_Quality", type=int, default=90, help="Jpeg Quality of the ground truth." +) args = parser.parse_args() camera_name = args.camera Bayer_Pattern = args.Bayer_Pattern JPEG_Quality = args.JPEG_Quality -dng_path = sorted(glob.glob('/mnt/nvme2n1/hyz/data/' + camera_name + '/DNG/*.cr2')) -rgb_target_path = '/mnt/nvme2n1/hyz/data/'+ camera_name + '/RGB/' -raw_input_path = '/mnt/nvme2n1/hyz/data/' + camera_name + '/RAW/' +dng_path = sorted(glob.glob("/mnt/nvme2n1/hyz/data/" + camera_name + "/DNG/*.cr2")) +rgb_target_path = "/mnt/nvme2n1/hyz/data/" + camera_name + "/RGB/" +raw_input_path = "/mnt/nvme2n1/hyz/data/" + camera_name + "/RAW/" if not os.path.isdir(rgb_target_path): os.mkdir(rgb_target_path) if not os.path.isdir(raw_input_path): os.mkdir(raw_input_path) - + + def flip(raw_img, flip): if flip == 3: raw_img = np.rot90(raw_img, k=2) @@ -38,19 +43,19 @@ def flip(raw_img, flip): return raw_img - for path in dng_path: print("Start Processing %s" % os.path.basename(path)) raw = rawpy.imread(path) - file_name = path.split('/')[-1].split('.')[0] - im = raw.postprocess(use_camera_wb=True,no_auto_bright=True) + file_name = path.split("/")[-1].split(".")[0] + im = raw.postprocess(use_camera_wb=True, no_auto_bright=True) flip_val = raw.sizes.flip cwb = raw.camera_whitebalance raw_img = raw.raw_image_visible - if camera_name == 'Canon_EOS_5D': + if camera_name == "Canon_EOS_5D": raw_img = np.maximum(raw_img - 127.0, 0) de_raw = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw_img, Bayer_Pattern) de_raw = flip(de_raw, flip_val) - rgb_img = PILImage.fromarray(im).save(rgb_target_path + file_name + '.jpg', quality = JPEG_Quality, subsampling = 1) - np.savez(raw_input_path + file_name + '.npz', raw=de_raw, wb=cwb) - + rgb_img = PILImage.fromarray(im).save( + rgb_target_path + file_name + ".jpg", quality=JPEG_Quality, subsampling=1 + ) + np.savez(raw_input_path + file_name + ".npz", raw=de_raw, wb=cwb) diff --git a/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py b/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py index 4c71bd3b4162bd21761983deef6b94fa46a364f6..9f0106b9f5175c8cd003cbdcab21f6c9c71e262d 100644 --- a/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py +++ b/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py @@ -14,119 +14,147 @@ from .base_dataset import BaseDataset class FiveKDatasetTrain(BaseDataset): def __init__(self, opt): - super().__init__(opt=opt) + super().__init__(opt=opt) self.patch_size = 256 input_RAWs_WBs, target_RGBs = self.load(is_train=True) - assert len(input_RAWs_WBs) == len(target_RGBs) - self.data = 
{'input_RAWs_WBs':input_RAWs_WBs, 'target_RGBs':target_RGBs} + assert len(input_RAWs_WBs) == len(target_RGBs) + self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs} def random_flip(self, input_raw, target_rgb): idx = np.random.randint(2) - input_raw = np.flip(input_raw,axis=idx).copy() - target_rgb = np.flip(target_rgb,axis=idx).copy() - + input_raw = np.flip(input_raw, axis=idx).copy() + target_rgb = np.flip(target_rgb, axis=idx).copy() + return input_raw, target_rgb def random_rotate(self, input_raw, target_rgb): idx = np.random.randint(4) - input_raw = np.rot90(input_raw,k=idx) - target_rgb = np.rot90(target_rgb,k=idx) + input_raw = np.rot90(input_raw, k=idx) + target_rgb = np.rot90(target_rgb, k=idx) return input_raw, target_rgb - def random_crop(self, patch_size, input_raw, target_rgb,flow=False,demos=False): + def random_crop(self, patch_size, input_raw, target_rgb, flow=False, demos=False): H, W, _ = input_raw.shape rnd_h = random.randint(0, max(0, H - patch_size)) rnd_w = random.randint(0, max(0, W - patch_size)) - patch_input_raw = input_raw[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :] + patch_input_raw = input_raw[ + rnd_h : rnd_h + patch_size, rnd_w : rnd_w + patch_size, : + ] if flow or demos: - patch_target_rgb = target_rgb[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :] + patch_target_rgb = target_rgb[ + rnd_h : rnd_h + patch_size, rnd_w : rnd_w + patch_size, : + ] else: - patch_target_rgb = target_rgb[rnd_h*2:rnd_h*2 + patch_size*2, rnd_w*2:rnd_w*2 + patch_size*2, :] + patch_target_rgb = target_rgb[ + rnd_h * 2 : rnd_h * 2 + patch_size * 2, + rnd_w * 2 : rnd_w * 2 + patch_size * 2, + :, + ] return patch_input_raw, patch_target_rgb - + def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False): - input_raw, target_rgb = self.random_crop(patch_size, input_raw,target_rgb,flow=flow, demos=demos) - input_raw, target_rgb = self.random_rotate(input_raw,target_rgb) - input_raw, target_rgb = self.random_flip(input_raw,target_rgb) - + input_raw, target_rgb = self.random_crop( + patch_size, input_raw, target_rgb, flow=flow, demos=demos + ) + input_raw, target_rgb = self.random_rotate(input_raw, target_rgb) + input_raw, target_rgb = self.random_flip(input_raw, target_rgb) + return input_raw, target_rgb def __len__(self): - return len(self.data['input_RAWs_WBs']) + return len(self.data["input_RAWs_WBs"]) + + def __getitem__(self, idx): + input_raw_wb_path = self.data["input_RAWs_WBs"][idx] + target_rgb_path = self.data["target_RGBs"][idx] - def __getitem__(self, idx): - input_raw_wb_path = self.data['input_RAWs_WBs'][idx] - target_rgb_path = self.data['target_RGBs'][idx] - target_rgb_img = imread(target_rgb_path) input_raw_wb = np.load(input_raw_wb_path) - input_raw_img = input_raw_wb['raw'] - wb = input_raw_wb['wb'] - wb = wb / wb.max() - input_raw_img = input_raw_img * wb[:-1] + input_raw_img = input_raw_wb["raw"] + wb = input_raw_wb["wb"] + wb = wb / wb.max() + input_raw_img = input_raw_img * wb[:-1] self.patch_size = 256 - input_raw_img, target_rgb_img = self.aug(self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True) - - if self.gamma: - norm_value = np.power(4095, 1/2.2) if self.camera_name=='Canon_EOS_5D' else np.power(16383, 1/2.2) - input_raw_img = np.power(input_raw_img, 1/2.2) + input_raw_img, target_rgb_img = self.aug( + self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True + ) + + if self.gamma: + norm_value = ( + np.power(4095, 1 / 2.2) + if self.camera_name == "Canon_EOS_5D" + else 
np.power(16383, 1 / 2.2) + ) + input_raw_img = np.power(input_raw_img, 1 / 2.2) else: - norm_value = 4095 if self.camera_name=='Canon_EOS_5D' else 16383 + norm_value = 4095 if self.camera_name == "Canon_EOS_5D" else 16383 target_rgb_img = self.norm_img(target_rgb_img, max_value=255) - input_raw_img = self.norm_img(input_raw_img, max_value=norm_value) + input_raw_img = self.norm_img(input_raw_img, max_value=norm_value) target_raw_img = input_raw_img.copy() input_raw_img = self.np2tensor(input_raw_img).float() target_rgb_img = self.np2tensor(target_rgb_img).float() target_raw_img = self.np2tensor(target_raw_img).float() - - sample = {'input_raw':input_raw_img, 'target_rgb':target_rgb_img, 'target_raw':target_raw_img, - 'file_name':input_raw_wb_path.split("/")[-1].split(".")[0]} + + sample = { + "input_raw": input_raw_img, + "target_rgb": target_rgb_img, + "target_raw": target_raw_img, + "file_name": input_raw_wb_path.split("/")[-1].split(".")[0], + } return sample + class FiveKDatasetTest(BaseDataset): def __init__(self, opt): super().__init__(opt=opt) self.patch_size = 256 - + input_RAWs_WBs, target_RGBs = self.load(is_train=False) - assert len(input_RAWs_WBs) == len(target_RGBs) - self.data = {'input_RAWs_WBs':input_RAWs_WBs, 'target_RGBs':target_RGBs} + assert len(input_RAWs_WBs) == len(target_RGBs) + self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs} def __len__(self): - return len(self.data['input_RAWs_WBs']) + return len(self.data["input_RAWs_WBs"]) + + def __getitem__(self, idx): + input_raw_wb_path = self.data["input_RAWs_WBs"][idx] + target_rgb_path = self.data["target_RGBs"][idx] - def __getitem__(self, idx): - input_raw_wb_path = self.data['input_RAWs_WBs'][idx] - target_rgb_path = self.data['target_RGBs'][idx] - target_rgb_img = imread(target_rgb_path) input_raw_wb = np.load(input_raw_wb_path) - input_raw_img = input_raw_wb['raw'] - wb = input_raw_wb['wb'] - wb = wb / wb.max() - input_raw_img = input_raw_img * wb[:-1] - - if self.gamma: - norm_value = np.power(4095, 1/2.2) if self.camera_name=='Canon_EOS_5D' else np.power(16383, 1/2.2) - input_raw_img = np.power(input_raw_img, 1/2.2) + input_raw_img = input_raw_wb["raw"] + wb = input_raw_wb["wb"] + wb = wb / wb.max() + input_raw_img = input_raw_img * wb[:-1] + + if self.gamma: + norm_value = ( + np.power(4095, 1 / 2.2) + if self.camera_name == "Canon_EOS_5D" + else np.power(16383, 1 / 2.2) + ) + input_raw_img = np.power(input_raw_img, 1 / 2.2) else: - norm_value = 4095 if self.camera_name=='Canon_EOS_5D' else 16383 + norm_value = 4095 if self.camera_name == "Canon_EOS_5D" else 16383 target_rgb_img = self.norm_img(target_rgb_img, max_value=255) - input_raw_img = self.norm_img(input_raw_img, max_value=norm_value) + input_raw_img = self.norm_img(input_raw_img, max_value=norm_value) target_raw_img = input_raw_img.copy() input_raw_img = self.np2tensor(input_raw_img).float() target_rgb_img = self.np2tensor(target_rgb_img).float() target_raw_img = self.np2tensor(target_raw_img).float() - - sample = {'input_raw':input_raw_img, 'target_rgb':target_rgb_img, 'target_raw':target_raw_img, - 'file_name':input_raw_wb_path.split("/")[-1].split(".")[0]} - return sample + sample = { + "input_raw": input_raw_img, + "target_rgb": target_rgb_img, + "target_raw": target_raw_img, + "file_name": input_raw_wb_path.split("/")[-1].split(".")[0], + } + return sample diff --git a/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py b/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py index 
34c5de9f75dbfb5323c2cdad532cb0a42c09df22..1ec55b4edd7663c8323a9b197e938083c6ed2497 100644 --- a/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py +++ b/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py @@ -3,16 +3,17 @@ import numpy as np from torch.utils.data import Dataset import torch + class BaseDataset(Dataset): def __init__(self, opt): self.crop_size = 512 self.debug_mode = opt.debug_mode - self.data_path = opt.data_path # dataset path. e.g., ./data/ - self.camera_name = opt.camera + self.data_path = opt.data_path # dataset path. e.g., ./data/ + self.camera_name = opt.camera self.gamma = opt.gamma def norm_img(self, img, max_value): - img = img / float(max_value) + img = img / float(max_value) return img def pack_raw(self, raw): @@ -20,15 +21,20 @@ class BaseDataset(Dataset): im = np.expand_dims(raw, axis=2) H, W = raw.shape[0], raw.shape[1] # RGBG - out = np.concatenate((im[0:H:2, 0:W:2, :], - im[0:H:2, 1:W:2, :], - im[1:H:2, 1:W:2, :], - im[1:H:2, 0:W:2, :]), axis=2) + out = np.concatenate( + ( + im[0:H:2, 0:W:2, :], + im[0:H:2, 1:W:2, :], + im[1:H:2, 1:W:2, :], + im[1:H:2, 0:W:2, :], + ), + axis=2, + ) return out - + def np2tensor(self, array): - return torch.Tensor(array).permute(2,0,1) - + return torch.Tensor(array).permute(2, 0, 1) + def center_crop(self, img, crop_size=None): H = img.shape[0] W = img.shape[1] @@ -37,44 +43,43 @@ class BaseDataset(Dataset): th, tw = crop_size[0], crop_size[1] else: th, tw = self.crop_size, self.crop_size - x1_img = int(round((W - tw) / 2.)) - y1_img = int(round((H - th) / 2.)) + x1_img = int(round((W - tw) / 2.0)) + y1_img = int(round((H - th) / 2.0)) if img.ndim == 3: - input_patch = img[y1_img:y1_img + th, x1_img:x1_img + tw, :] + input_patch = img[y1_img : y1_img + th, x1_img : x1_img + tw, :] else: - input_patch = img[y1_img:y1_img + th, x1_img:x1_img + tw] + input_patch = img[y1_img : y1_img + th, x1_img : x1_img + tw] return input_patch def load(self, is_train=True): # ./data - # ./data/NIKON D700/RAW, ./data/NIKON D700/RGB - # ./data/Canon EOS 5D/RAW, ./data/Canon EOS 5D/RGB - # ./data/NIKON D700_train.txt, ./data/NIKON D700_test.txt - # ./data/NIKON D700_train.txt: a0016, ... - input_RAWs_WBs = [] - target_RGBs = [] - - data_path = self.data_path # ./data/ + # ./data/NIKON D700/RAW, ./data/NIKON D700/RGB + # ./data/Canon EOS 5D/RAW, ./data/Canon EOS 5D/RGB + # ./data/NIKON D700_train.txt, ./data/NIKON D700_test.txt + # ./data/NIKON D700_train.txt: a0016, ... 
+ input_RAWs_WBs = [] + target_RGBs = [] + + data_path = self.data_path # ./data/ if is_train: txt_path = data_path + self.camera_name + "_train.txt" else: txt_path = data_path + self.camera_name + "_test.txt" with open(txt_path, "r") as f_read: - # valid_camera_list = [os.path.basename(line.strip()).split('.')[0] for line in f_read.readlines()] - valid_camera_list = [line.strip() for line in f_read.readlines()] - + # valid_camera_list = [os.path.basename(line.strip()).split('.')[0] for line in f_read.readlines()] + valid_camera_list = [line.strip() for line in f_read.readlines()] + if self.debug_mode: valid_camera_list = valid_camera_list[:10] - - for i,name in enumerate(valid_camera_list): - full_name = data_path + self.camera_name - input_RAWs_WBs.append(full_name + "/RAW/" + name + ".npz") - target_RGBs.append(full_name + "/RGB/" + name + ".jpg") - - return input_RAWs_WBs, target_RGBs + for i, name in enumerate(valid_camera_list): + full_name = data_path + self.camera_name + input_RAWs_WBs.append(full_name + "/RAW/" + name + ".npz") + target_RGBs.append(full_name + "/RGB/" + name + ".jpg") + + return input_RAWs_WBs, target_RGBs def __len__(self): return 0 diff --git a/third_party/DarkFeat/datasets/InvISP/model/loss.py b/third_party/DarkFeat/datasets/InvISP/model/loss.py index abe8b599d5402c367bb7c84b7e370964d8273518..62a028ec26a8d7f8ef857e0582ac74800dac212e 100644 --- a/third_party/DarkFeat/datasets/InvISP/model/loss.py +++ b/third_party/DarkFeat/datasets/InvISP/model/loss.py @@ -2,14 +2,15 @@ import torch.nn.functional as F import torch -def l1_loss(output, target_rgb, target_raw, weight=1.): - raw_loss = F.l1_loss(output['reconstruct_raw'], target_raw) - rgb_loss = F.l1_loss(output['reconstruct_rgb'], target_rgb) +def l1_loss(output, target_rgb, target_raw, weight=1.0): + raw_loss = F.l1_loss(output["reconstruct_raw"], target_raw) + rgb_loss = F.l1_loss(output["reconstruct_rgb"], target_rgb) total_loss = raw_loss + weight * rgb_loss return total_loss, raw_loss, rgb_loss -def l2_loss(output, target_rgb, target_raw, weight=1.): - raw_loss = F.mse_loss(output['reconstruct_raw'], target_raw) - rgb_loss = F.mse_loss(output['reconstruct_rgb'], target_rgb) + +def l2_loss(output, target_rgb, target_raw, weight=1.0): + raw_loss = F.mse_loss(output["reconstruct_raw"], target_raw) + rgb_loss = F.mse_loss(output["reconstruct_rgb"], target_rgb) total_loss = raw_loss + weight * rgb_loss - return total_loss, raw_loss, rgb_loss \ No newline at end of file + return total_loss, raw_loss, rgb_loss diff --git a/third_party/DarkFeat/datasets/InvISP/model/model.py b/third_party/DarkFeat/datasets/InvISP/model/model.py index 9dd0e33cee8ebb26d621ece84622bd2611b33a60..52938290b7ca895a7c71173d40f90df5cd51b0d0 100644 --- a/third_party/DarkFeat/datasets/InvISP/model/model.py +++ b/third_party/DarkFeat/datasets/InvISP/model/model.py @@ -14,12 +14,12 @@ def initialize_weights(net_l, scale=1): for net in net_l: for m in net.modules(): if isinstance(m, nn.Conv2d): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') + init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Linear): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') + init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale if m.bias is not None: m.bias.data.zero_() @@ -49,7 +49,7 @@ def initialize_weights_xavier(net_l, scale=1): class DenseBlock(nn.Module): - def __init__(self, channel_in, channel_out, init='xavier', gc=32, 
bias=True): + def __init__(self, channel_in, channel_out, init="xavier", gc=32, bias=True): super(DenseBlock, self).__init__() self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias) self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias) @@ -58,12 +58,14 @@ class DenseBlock(nn.Module): self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - if init == 'xavier': - initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) + if init == "xavier": + initialize_weights_xavier( + [self.conv1, self.conv2, self.conv3, self.conv4], 0.1 + ) else: initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) initialize_weights(self.conv5, 0) - + def forward(self, x): x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) @@ -73,10 +75,11 @@ class DenseBlock(nn.Module): return x5 -def subnet(net_structure, init='xavier'): + +def subnet(net_structure, init="xavier"): def constructor(channel_in, channel_out): - if net_structure == 'DBNet': - if init == 'xavier': + if net_structure == "DBNet": + if init == "xavier": return DenseBlock(channel_in, channel_out, init) else: return DenseBlock(channel_in, channel_out) @@ -93,8 +96,8 @@ class InvBlock(nn.Module): # channel_num: 3 # channel_split_num: 1 - self.split_len1 = channel_split_num # 1 - self.split_len2 = channel_num - channel_split_num # 2 + self.split_len1 = channel_split_num # 1 + self.split_len2 = channel_num - channel_split_num # 2 self.clamp = clamp @@ -102,38 +105,51 @@ class InvBlock(nn.Module): self.G = subnet_constructor(self.split_len1, self.split_len2) self.H = subnet_constructor(self.split_len1, self.split_len2) - in_channels = 3 + in_channels = 3 self.invconv = InvertibleConv1x1(in_channels, LU_decomposed=True) self.flow_permutation = lambda z, logdet, rev: self.invconv(z, logdet, rev) - + def forward(self, x, rev=False): - if not rev: - # invert1x1conv - x, logdet = self.flow_permutation(x, logdet=0, rev=False) - - # split to 1 channel and 2 channel. - x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2)) - - y1 = x1 + self.F(x2) # 1 channel + if not rev: + # invert1x1conv + x, logdet = self.flow_permutation(x, logdet=0, rev=False) + + # split to 1 channel and 2 channel. + x1, x2 = ( + x.narrow(1, 0, self.split_len1), + x.narrow(1, self.split_len1, self.split_len2), + ) + + y1 = x1 + self.F(x2) # 1 channel self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1) - y2 = x2.mul(torch.exp(self.s)) + self.G(y1) # 2 channel + y2 = x2.mul(torch.exp(self.s)) + self.G(y1) # 2 channel out = torch.cat((y1, y2), 1) else: - # split. - x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2)) + # split. 
+ x1, x2 = ( + x.narrow(1, 0, self.split_len1), + x.narrow(1, self.split_len1, self.split_len2), + ) self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1) - y2 = (x2 - self.G(x1)).div(torch.exp(self.s)) - y1 = x1 - self.F(y2) + y2 = (x2 - self.G(x1)).div(torch.exp(self.s)) + y1 = x1 - self.F(y2) - x = torch.cat((y1, y2), 1) + x = torch.cat((y1, y2), 1) - # inv permutation + # inv permutation out, logdet = self.flow_permutation(x, logdet=0, rev=True) return out + class InvISPNet(nn.Module): - def __init__(self, channel_in=3, channel_out=3, subnet_constructor=subnet('DBNet'), block_num=8): + def __init__( + self, + channel_in=3, + channel_out=3, + subnet_constructor=subnet("DBNet"), + block_num=8, + ): super(InvISPNet, self).__init__() operations = [] @@ -141,10 +157,12 @@ class InvISPNet(nn.Module): channel_num = channel_in channel_split_num = 1 - for j in range(block_num): - b = InvBlock(subnet_constructor, channel_num, channel_split_num) # one block is one flow step. + for j in range(block_num): + b = InvBlock( + subnet_constructor, channel_num, channel_split_num + ) # one block is one flow step. operations.append(b) - + self.operations = nn.ModuleList(operations) self.initialize() @@ -153,27 +171,26 @@ class InvISPNet(nn.Module): for m in self.modules(): if isinstance(m, nn.Conv2d): init.xavier_normal_(m.weight) - m.weight.data *= 1. # for residual block + m.weight.data *= 1.0 # for residual block if m.bias is not None: - m.bias.data.zero_() + m.bias.data.zero_() elif isinstance(m, nn.Linear): init.xavier_normal_(m.weight) - m.weight.data *= 1. + m.weight.data *= 1.0 if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias.data, 0.0) - + def forward(self, x, rev=False): - out = x # x: [N,3,H,W] - - if not rev: + out = x # x: [N,3,H,W] + + if not rev: for op in self.operations: out = op.forward(out, rev) else: for op in reversed(self.operations): out = op.forward(out, rev) - - return out + return out diff --git a/third_party/DarkFeat/datasets/InvISP/model/modules.py b/third_party/DarkFeat/datasets/InvISP/model/modules.py index 88244c0b211860d97be78ba4f60f4743228171a7..b32c312d13284bc5a4837df756ed58c505b60768 100644 --- a/third_party/DarkFeat/datasets/InvISP/model/modules.py +++ b/third_party/DarkFeat/datasets/InvISP/model/modules.py @@ -47,7 +47,7 @@ def unsqueeze2d(input, factor): if factor == 1: return input - factor2 = factor ** 2 + factor2 = factor**2 B, C, H, W = input.size() diff --git a/third_party/DarkFeat/datasets/InvISP/model/utils.py b/third_party/DarkFeat/datasets/InvISP/model/utils.py index d1bef31afd7d61d4c942ffd895c818b90571b4b7..a1ab33bf1ba26ee027e1c051f63b0a29fefe6706 100644 --- a/third_party/DarkFeat/datasets/InvISP/model/utils.py +++ b/third_party/DarkFeat/datasets/InvISP/model/utils.py @@ -27,7 +27,7 @@ def uniform_binning_correction(x, n_bits=8): objective: Equivalent to -q(x)*log(q(x)). """ b, c, h, w = x.size() - n_bins = 2 ** n_bits + n_bins = 2**n_bits chw = c * h * w x += torch.zeros_like(x).uniform_(0, 1.0 / n_bins) @@ -42,11 +42,7 @@ def split_feature(tensor, type="split"): C = tensor.size(1) if type == "split": # return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...] - return tensor[:, :1, ...], tensor[:,1:, ...] + return tensor[:, :1, ...], tensor[:, 1:, ...] elif type == "cross": # return tensor[:, 0::2, ...], tensor[:, 1::2, ...] - return tensor[:, 0::2, ...], tensor[:, 1::2, ...] - - - - + return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 
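The InvBlock reformatted above is an affine coupling step: the channels are split, one half is updated as a function of the other, and the same arithmetic can be run in reverse to recover the input exactly. Below is a minimal sketch of that coupling logic only, for illustration and not part of the patch: stand-in convolutions replace the DenseBlock sub-networks F/G/H, and the invertible 1x1 convolution used by the real block is omitted.

import torch

def couple(x, F, G, H, split=1, clamp=1.0, rev=False):
    # split channels the same way InvBlock does (split_len1 / split_len2)
    x1, x2 = x[:, :split], x[:, split:]
    if not rev:
        y1 = x1 + F(x2)                             # additive update of the first part
        s = clamp * (torch.sigmoid(H(y1)) * 2 - 1)  # bounded log-scale, as in the block
        y2 = x2 * torch.exp(s) + G(y1)              # affine update of the second part
    else:
        s = clamp * (torch.sigmoid(H(x1)) * 2 - 1)
        y2 = (x2 - G(x1)) / torch.exp(s)
        y1 = x1 - F(y2)
    return torch.cat((y1, y2), dim=1)

# stand-ins for the F/G/H sub-networks (shapes follow split_len1=1, split_len2=2)
F = torch.nn.Conv2d(2, 1, 3, padding=1)
G = torch.nn.Conv2d(1, 2, 3, padding=1)
H = torch.nn.Conv2d(1, 2, 3, padding=1)

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    y = couple(x, F, G, H)
    x_rec = couple(y, F, G, H, rev=True)
print(torch.allclose(x, x_rec, atol=1e-5))  # True: the coupling is exactly invertible

Running the sketch, applying couple with rev=True to the forward output recovers the input up to float32 precision, which is the property the full InvISPNet relies on when mapping rendered RGB back to RAW in the test scripts that follow.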
diff --git a/third_party/DarkFeat/datasets/InvISP/test_raw.py b/third_party/DarkFeat/datasets/InvISP/test_raw.py index 37610f8268e4586864e0275236c5bb1932f894df..8c3c30faf6662b04fe34f63de0d729ebcec86517 100644 --- a/third_party/DarkFeat/datasets/InvISP/test_raw.py +++ b/third_party/DarkFeat/datasets/InvISP/test_raw.py @@ -18,101 +18,145 @@ from utils.JPEG import DiffJPEG from utils.commons import denorm, preprocess_test_patch -os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') -os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()])) +os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp") +os.environ["CUDA_VISIBLE_DEVICES"] = str( + np.argmax([int(x.split()[2]) for x in open("tmp", "r").readlines()]) +) # os.environ['CUDA_VISIBLE_DEVICES'] = '7' -os.system('rm tmp') +os.system("rm tmp") DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda() parser = get_arguments() -parser.add_argument("--ckpt", type=str, help="Checkpoint path.") -parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save checkpoint. ") -parser.add_argument("--split_to_patch", dest='split_to_patch', action='store_true', help="Test on patch. ") +parser.add_argument("--ckpt", type=str, help="Checkpoint path.") +parser.add_argument( + "--out_path", type=str, default="./exps/", help="Path to save checkpoint. " +) +parser.add_argument( + "--split_to_patch", + dest="split_to_patch", + action="store_true", + help="Test on patch. ", +) args = parser.parse_args() print("Parsed arguments: {}".format(args)) ckpt_name = args.ckpt.split("/")[-1].split(".")[0] if args.split_to_patch: - os.makedirs(args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name), exist_ok=True) - out_path = args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name) + os.makedirs( + args.out_path + "%s/results_metric_%s/" % (args.task, ckpt_name), exist_ok=True + ) + out_path = args.out_path + "%s/results_metric_%s/" % (args.task, ckpt_name) else: - os.makedirs(args.out_path+"%s/results_%s/"%(args.task, ckpt_name), exist_ok=True) - out_path = args.out_path+"%s/results_%s/"%(args.task, ckpt_name) + os.makedirs( + args.out_path + "%s/results_%s/" % (args.task, ckpt_name), exist_ok=True + ) + out_path = args.out_path + "%s/results_%s/" % (args.task, ckpt_name) def main(args): # ======================================define the model============================================ net = InvISPNet(channel_in=3, channel_out=3, block_num=8) device = torch.device("cuda:0") - + net.to(device) net.eval() # load the pretrained weight if there exists one if os.path.isfile(args.ckpt): net.load_state_dict(torch.load(args.ckpt), strict=False) print("[INFO] Loaded checkpoint: {}".format(args.ckpt)) - - print("[INFO] Start data load and preprocessing") - RAWDataset = FiveKDatasetTest(opt=args) - dataloader = DataLoader(RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True) - - input_RGBs = sorted(glob(out_path+"pred*jpg")) - input_RGBs_names = [path.split("/")[-1].split(".")[0][5:] for path in input_RGBs] - - print("[INFO] Start test...") + + print("[INFO] Start data load and preprocessing") + RAWDataset = FiveKDatasetTest(opt=args) + dataloader = DataLoader( + RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True + ) + + input_RGBs = sorted(glob(out_path + "pred*jpg")) + input_RGBs_names = [path.split("/")[-1].split(".")[0][5:] for path in input_RGBs] + + print("[INFO] Start test...") for i_batch, sample_batched in 
enumerate(tqdm(dataloader)): step_time = time.time() - - input, target_rgb, target_raw = sample_batched['input_raw'].to(device), sample_batched['target_rgb'].to(device), \ - sample_batched['target_raw'].to(device) - file_name = sample_batched['file_name'][0] + + input, target_rgb, target_raw = ( + sample_batched["input_raw"].to(device), + sample_batched["target_rgb"].to(device), + sample_batched["target_raw"].to(device), + ) + file_name = sample_batched["file_name"][0] if args.split_to_patch: - input_list, target_rgb_list, target_raw_list = preprocess_test_patch(input, target_rgb, target_raw) + input_list, target_rgb_list, target_raw_list = preprocess_test_patch( + input, target_rgb, target_raw + ) else: - # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution - input_list, target_rgb_list, target_raw_list = [input[:,:,::2,::2]], [target_rgb[:,:,::2,::2]], [target_raw[:,:,::2,::2]] - + # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution + input_list, target_rgb_list, target_raw_list = ( + [input[:, :, ::2, ::2]], + [target_rgb[:, :, ::2, ::2]], + [target_raw[:, :, ::2, ::2]], + ) + for i_patch in range(len(input_list)): - file_name_patch = file_name + "_%05d"%i_patch + file_name_patch = file_name + "_%05d" % i_patch idx = input_RGBs_names.index(file_name_patch) input_RGB_path = input_RGBs[idx] - input_RGB = torch.from_numpy(np.array(PILImage.open(input_RGB_path))/255.0).unsqueeze(0).permute(0,3,1,2).float().to(device) - - target_raw_patch = target_raw_list[i_patch] - + input_RGB = ( + torch.from_numpy(np.array(PILImage.open(input_RGB_path)) / 255.0) + .unsqueeze(0) + .permute(0, 3, 1, 2) + .float() + .to(device) + ) + + target_raw_patch = target_raw_list[i_patch] + with torch.no_grad(): reconstruct_raw = net(input_RGB, rev=True) - - pred_raw = reconstruct_raw.detach().permute(0,2,3,1) + + pred_raw = reconstruct_raw.detach().permute(0, 2, 3, 1) pred_raw = torch.clamp(pred_raw, 0, 1) - - target_raw_patch = target_raw_patch.permute(0,2,3,1) + + target_raw_patch = target_raw_patch.permute(0, 2, 3, 1) pred_raw = denorm(pred_raw, 255) target_raw_patch = denorm(target_raw_patch, 255) pred_raw = pred_raw.cpu().numpy() target_raw_patch = target_raw_patch.cpu().numpy().astype(np.float32) - raw_pred = PILImage.fromarray(np.uint8(pred_raw[0,:,:,0])) - raw_tar_pred = PILImage.fromarray(np.hstack((np.uint8(target_raw_patch[0,:,:,0]), np.uint8(pred_raw[0,:,:,0])))) - - raw_tar = PILImage.fromarray(np.uint8(target_raw_patch[0,:,:,0])) + raw_pred = PILImage.fromarray(np.uint8(pred_raw[0, :, :, 0])) + raw_tar_pred = PILImage.fromarray( + np.hstack( + ( + np.uint8(target_raw_patch[0, :, :, 0]), + np.uint8(pred_raw[0, :, :, 0]), + ) + ) + ) - raw_pred.save(out_path+"raw_pred_%s_%05d.jpg"%(file_name, i_patch)) - raw_tar.save(out_path+"raw_tar_%s_%05d.jpg"%(file_name, i_patch)) - raw_tar_pred.save(out_path+"raw_gt_pred_%s_%05d.jpg"%(file_name, i_patch)) - - np.save(out_path+"raw_pred_%s_%05d.npy"%(file_name, i_patch), pred_raw[0,:,:,:]/255.0) - np.save(out_path+"raw_tar_%s_%05d.npy"%(file_name, i_patch), target_raw_patch[0,:,:,:]/255.0) + raw_tar = PILImage.fromarray(np.uint8(target_raw_patch[0, :, :, 0])) - del reconstruct_raw + raw_pred.save(out_path + "raw_pred_%s_%05d.jpg" % (file_name, i_patch)) + raw_tar.save(out_path + "raw_tar_%s_%05d.jpg" % (file_name, i_patch)) + raw_tar_pred.save( + out_path + "raw_gt_pred_%s_%05d.jpg" % (file_name, i_patch) + ) + np.save( + out_path + "raw_pred_%s_%05d.npy" % (file_name, i_patch), + pred_raw[0, :, :, :] / 255.0, 
+ ) + np.save( + out_path + "raw_tar_%s_%05d.npy" % (file_name, i_patch), + target_raw_patch[0, :, :, :] / 255.0, + ) -if __name__ == '__main__': + del reconstruct_raw + + +if __name__ == "__main__": torch.set_num_threads(4) main(args) - diff --git a/third_party/DarkFeat/datasets/InvISP/test_rgb.py b/third_party/DarkFeat/datasets/InvISP/test_rgb.py index d1e054b899d9142609e3f90f4a12d367a45aeac0..5c1c9f1839acd58e71b4dc244b0ce3132d09b8c7 100644 --- a/third_party/DarkFeat/datasets/InvISP/test_rgb.py +++ b/third_party/DarkFeat/datasets/InvISP/test_rgb.py @@ -16,90 +16,133 @@ from utils.JPEG import DiffJPEG from utils.commons import denorm, preprocess_test_patch from tqdm import tqdm -os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') -os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()])) +os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp") +os.environ["CUDA_VISIBLE_DEVICES"] = str( + np.argmax([int(x.split()[2]) for x in open("tmp", "r").readlines()]) +) # os.environ['CUDA_VISIBLE_DEVICES'] = '7' -os.system('rm tmp') +os.system("rm tmp") DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda() parser = get_arguments() -parser.add_argument("--ckpt", type=str, help="Checkpoint path.") -parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save results. ") -parser.add_argument("--split_to_patch", dest='split_to_patch', action='store_true', help="Test on patch. ") +parser.add_argument("--ckpt", type=str, help="Checkpoint path.") +parser.add_argument( + "--out_path", type=str, default="./exps/", help="Path to save results. " +) +parser.add_argument( + "--split_to_patch", + dest="split_to_patch", + action="store_true", + help="Test on patch. ", +) args = parser.parse_args() print("Parsed arguments: {}".format(args)) ckpt_name = args.ckpt.split("/")[-1].split(".")[0] if args.split_to_patch: - os.makedirs(args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name), exist_ok=True) - out_path = args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name) + os.makedirs( + args.out_path + "%s/results_metric_%s/" % (args.task, ckpt_name), exist_ok=True + ) + out_path = args.out_path + "%s/results_metric_%s/" % (args.task, ckpt_name) else: - os.makedirs(args.out_path+"%s/results_%s/"%(args.task, ckpt_name), exist_ok=True) - out_path = args.out_path+"%s/results_%s/"%(args.task, ckpt_name) + os.makedirs( + args.out_path + "%s/results_%s/" % (args.task, ckpt_name), exist_ok=True + ) + out_path = args.out_path + "%s/results_%s/" % (args.task, ckpt_name) def main(args): # ======================================define the model============================================ net = InvISPNet(channel_in=3, channel_out=3, block_num=8) device = torch.device("cuda:0") - + net.to(device) net.eval() # load the pretrained weight if there exists one if os.path.isfile(args.ckpt): net.load_state_dict(torch.load(args.ckpt), strict=False) print("[INFO] Loaded checkpoint: {}".format(args.ckpt)) - - print("[INFO] Start data load and preprocessing") - RAWDataset = FiveKDatasetTest(opt=args) - dataloader = DataLoader(RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True) - - print("[INFO] Start test...") + + print("[INFO] Start data load and preprocessing") + RAWDataset = FiveKDatasetTest(opt=args) + dataloader = DataLoader( + RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True + ) + + print("[INFO] Start test...") for i_batch, sample_batched in enumerate(tqdm(dataloader)): - step_time = 
time.time() - - input, target_rgb, target_raw = sample_batched['input_raw'].to(device), sample_batched['target_rgb'].to(device), \ - sample_batched['target_raw'].to(device) - file_name = sample_batched['file_name'][0] - + step_time = time.time() + + input, target_rgb, target_raw = ( + sample_batched["input_raw"].to(device), + sample_batched["target_rgb"].to(device), + sample_batched["target_raw"].to(device), + ) + file_name = sample_batched["file_name"][0] + if args.split_to_patch: - input_list, target_rgb_list, target_raw_list = preprocess_test_patch(input, target_rgb, target_raw) + input_list, target_rgb_list, target_raw_list = preprocess_test_patch( + input, target_rgb, target_raw + ) else: - # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution - input_list, target_rgb_list, target_raw_list = [input[:,:,::2,::2]], [target_rgb[:,:,::2,::2]], [target_raw[:,:,::2,::2]] - + # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution + input_list, target_rgb_list, target_raw_list = ( + [input[:, :, ::2, ::2]], + [target_rgb[:, :, ::2, ::2]], + [target_raw[:, :, ::2, ::2]], + ) + for i_patch in range(len(input_list)): input_patch = input_list[i_patch] target_rgb_patch = target_rgb_list[i_patch] - target_raw_patch = target_raw_list[i_patch] - + target_raw_patch = target_raw_list[i_patch] + with torch.no_grad(): reconstruct_rgb = net(input_patch) reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1) - - pred_rgb = reconstruct_rgb.detach().permute(0,2,3,1) - target_rgb_patch = target_rgb_patch.permute(0,2,3,1) - + + pred_rgb = reconstruct_rgb.detach().permute(0, 2, 3, 1) + target_rgb_patch = target_rgb_patch.permute(0, 2, 3, 1) + pred_rgb = denorm(pred_rgb, 255) target_rgb_patch = denorm(target_rgb_patch, 255) pred_rgb = pred_rgb.cpu().numpy() target_rgb_patch = target_rgb_patch.cpu().numpy().astype(np.float32) - + # print(type(pred_rgb)) - pred = PILImage.fromarray(np.uint8(pred_rgb[0,:,:,:])) - tar_pred = PILImage.fromarray(np.hstack((np.uint8(target_rgb_patch[0,:,:,:]), np.uint8(pred_rgb[0,:,:,:])))) - - tar = PILImage.fromarray(np.uint8(target_rgb_patch[0,:,:,:])) - - pred.save(out_path+"pred_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1) - tar.save(out_path+"tar_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1) - tar_pred.save(out_path+"gt_pred_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1) - + pred = PILImage.fromarray(np.uint8(pred_rgb[0, :, :, :])) + tar_pred = PILImage.fromarray( + np.hstack( + ( + np.uint8(target_rgb_patch[0, :, :, :]), + np.uint8(pred_rgb[0, :, :, :]), + ) + ) + ) + + tar = PILImage.fromarray(np.uint8(target_rgb_patch[0, :, :, :])) + + pred.save( + out_path + "pred_%s_%05d.jpg" % (file_name, i_patch), + quality=90, + subsampling=1, + ) + tar.save( + out_path + "tar_%s_%05d.jpg" % (file_name, i_patch), + quality=90, + subsampling=1, + ) + tar_pred.save( + out_path + "gt_pred_%s_%05d.jpg" % (file_name, i_patch), + quality=90, + subsampling=1, + ) + del reconstruct_rgb -if __name__ == '__main__': + +if __name__ == "__main__": torch.set_num_threads(4) main(args) - diff --git a/third_party/DarkFeat/datasets/InvISP/train.py b/third_party/DarkFeat/datasets/InvISP/train.py index 16186cb38d825ac1299e5c4164799d35bfa79907..4022c4a8f523b97ffeb928263b14a79bd8b54a20 100644 --- a/third_party/DarkFeat/datasets/InvISP/train.py +++ b/third_party/DarkFeat/datasets/InvISP/train.py @@ -14,85 +14,130 @@ from config.config import get_arguments from utils.JPEG import DiffJPEG -os.system('nvidia-smi -q 
-d Memory |grep -A4 GPU|grep Free >tmp') -os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()])) +os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp") +os.environ["CUDA_VISIBLE_DEVICES"] = str( + np.argmax([int(x.split()[2]) for x in open("tmp", "r").readlines()]) +) # os.environ['CUDA_VISIBLE_DEVICES'] = "1" -os.system('rm tmp') +os.system("rm tmp") DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda() parser = get_arguments() -parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save checkpoint. ") -parser.add_argument("--resume", dest='resume', action='store_true', help="Resume training. ") -parser.add_argument("--loss", type=str, default="L1", choices=["L1", "L2"], help="Choose which loss function to use. ") +parser.add_argument( + "--out_path", type=str, default="./exps/", help="Path to save checkpoint. " +) +parser.add_argument( + "--resume", dest="resume", action="store_true", help="Resume training. " +) +parser.add_argument( + "--loss", + type=str, + default="L1", + choices=["L1", "L2"], + help="Choose which loss function to use. ", +) parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate") -parser.add_argument("--aug", dest='aug', action='store_true', help="Use data augmentation.") +parser.add_argument( + "--aug", dest="aug", action="store_true", help="Use data augmentation." +) args = parser.parse_args() print("Parsed arguments: {}".format(args)) os.makedirs(args.out_path, exist_ok=True) -os.makedirs(args.out_path+"%s"%args.task, exist_ok=True) -os.makedirs(args.out_path+"%s/checkpoint"%args.task, exist_ok=True) +os.makedirs(args.out_path + "%s" % args.task, exist_ok=True) +os.makedirs(args.out_path + "%s/checkpoint" % args.task, exist_ok=True) -with open(args.out_path+"%s/commandline_args.yaml"%args.task , 'w') as f: +with open(args.out_path + "%s/commandline_args.yaml" % args.task, "w") as f: json.dump(args.__dict__, f, indent=2) + def main(args): # ======================================define the model====================================== net = InvISPNet(channel_in=3, channel_out=3, block_num=8) net.cuda() # load the pretrained weight if there exists one if args.resume: - net.load_state_dict(torch.load(args.out_path+"%s/checkpoint/latest.pth"%args.task)) - print("[INFO] loaded " + args.out_path+"%s/checkpoint/latest.pth"%args.task) + net.load_state_dict( + torch.load(args.out_path + "%s/checkpoint/latest.pth" % args.task) + ) + print("[INFO] loaded " + args.out_path + "%s/checkpoint/latest.pth" % args.task) optimizer = torch.optim.Adam(net.parameters(), lr=args.lr) - scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.5) - + scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.5) + print("[INFO] Start data loading and preprocessing") - RAWDataset = FiveKDatasetTrain(opt=args) - dataloader = DataLoader(RAWDataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True) + RAWDataset = FiveKDatasetTrain(opt=args) + dataloader = DataLoader( + RAWDataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=0, + drop_last=True, + ) print("[INFO] Start to train") step = 0 for epoch in range(0, 300): - epoch_time = time.time() - + epoch_time = time.time() + for i_batch, sample_batched in enumerate(dataloader): - step_time = time.time() + step_time = time.time() - input, target_rgb, target_raw = sample_batched['input_raw'].cuda(), sample_batched['target_rgb'].cuda(), \ - 
sample_batched['target_raw'].cuda() - - reconstruct_rgb = net(input) + input, target_rgb, target_raw = ( + sample_batched["input_raw"].cuda(), + sample_batched["target_rgb"].cuda(), + sample_batched["target_raw"].cuda(), + ) + + reconstruct_rgb = net(input) reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1) rgb_loss = F.l1_loss(reconstruct_rgb, target_rgb) reconstruct_rgb = DiffJPEG(reconstruct_rgb) reconstruct_raw = net(reconstruct_rgb, rev=True) raw_loss = F.l1_loss(reconstruct_raw, target_raw) - + loss = args.rgb_weight * rgb_loss + raw_loss - + optimizer.zero_grad() loss.backward() optimizer.step() - - print("task: %s Epoch: %d Step: %d || loss: %.5f raw_loss: %.5f rgb_loss: %.5f || lr: %f time: %f"%( - args.task, epoch, step, loss.detach().cpu().numpy(), raw_loss.detach().cpu().numpy(), - rgb_loss.detach().cpu().numpy(), optimizer.param_groups[0]['lr'], time.time()-step_time - )) - step += 1 - - torch.save(net.state_dict(), args.out_path+"%s/checkpoint/latest.pth"%args.task) - if (epoch+1) % 10 == 0: + + print( + "task: %s Epoch: %d Step: %d || loss: %.5f raw_loss: %.5f rgb_loss: %.5f || lr: %f time: %f" + % ( + args.task, + epoch, + step, + loss.detach().cpu().numpy(), + raw_loss.detach().cpu().numpy(), + rgb_loss.detach().cpu().numpy(), + optimizer.param_groups[0]["lr"], + time.time() - step_time, + ) + ) + step += 1 + + torch.save( + net.state_dict(), args.out_path + "%s/checkpoint/latest.pth" % args.task + ) + if (epoch + 1) % 10 == 0: # os.makedirs(args.out_path+"%s/checkpoint/%04d"%(args.task,epoch), exist_ok=True) - torch.save(net.state_dict(), args.out_path+"%s/checkpoint/%04d.pth"%(args.task,epoch)) - print("[INFO] Successfully saved "+args.out_path+"%s/checkpoint/%04d.pth"%(args.task,epoch)) - scheduler.step() - - print("[INFO] Epoch time: ", time.time()-epoch_time, "task: ", args.task) + torch.save( + net.state_dict(), + args.out_path + "%s/checkpoint/%04d.pth" % (args.task, epoch), + ) + print( + "[INFO] Successfully saved " + + args.out_path + + "%s/checkpoint/%04d.pth" % (args.task, epoch) + ) + scheduler.step() + + print("[INFO] Epoch time: ", time.time() - epoch_time, "task: ", args.task) + -if __name__ == '__main__': +if __name__ == "__main__": torch.set_num_threads(4) main(args) diff --git a/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py b/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py index 8997ee98a41668b4737a9b2acc2341032f173bd3..7cdd7fa91ee424250f241ecc7de63d868795aaa7 100644 --- a/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py +++ b/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py @@ -1,5 +1,3 @@ - - import torch import torch.nn as nn @@ -8,16 +6,16 @@ from .compression import compress_jpeg from .decompression import decompress_jpeg -class DiffJPEG(nn.Module): +class DiffJPEG(nn.Module): def __init__(self, differentiable=True, quality=75): - ''' Initialize the DiffJPEG layer + """Initialize the DiffJPEG layer Inputs: height(int): Original image height width(int): Original image width differentiable(bool): If true uses custom differentiable rounding function, if false uses standrard torch.round - quality(float): Quality factor for jpeg compression scheme. - ''' + quality(float): Quality factor for jpeg compression scheme. 
+ """ super(DiffJPEG, self).__init__() if differentiable: rounding = diff_round @@ -31,13 +29,10 @@ class DiffJPEG(nn.Module): self.decompress = decompress_jpeg(rounding=rounding, factor=factor) def forward(self, x): - ''' - ''' + """ """ org_height = x.shape[2] org_width = x.shape[3] y, cb, cr = self.compress(x) recovered = self.decompress(y, cb, cr, org_height, org_width) return recovered - - diff --git a/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py b/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py index e2ebd9bdc184e869ade58eea1c6763baa1d9fc91..4ef225505d21728f63d34cec55e5335a50130e17 100644 --- a/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py +++ b/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py @@ -1,58 +1,65 @@ # Standard libraries import numpy as np + # PyTorch import torch import torch.nn as nn import math y_table = np.array( - [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, - 55], [14, 13, 16, 24, 40, 57, 69, 56], - [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, - 77], [24, 35, 55, 64, 81, 104, 113, 92], - [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], - dtype=np.float32).T + [ + [16, 11, 10, 16, 24, 40, 51, 61], + [12, 12, 14, 19, 26, 58, 60, 55], + [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], + [18, 22, 37, 56, 68, 109, 103, 77], + [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], + [72, 92, 95, 98, 112, 100, 103, 99], + ], + dtype=np.float32, +).T y_table = nn.Parameter(torch.from_numpy(y_table)) # c_table = np.empty((8, 8), dtype=np.float32) c_table.fill(99) -c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], - [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table[:4, :4] = np.array( + [[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]] +).T c_table = nn.Parameter(torch.from_numpy(c_table)) def diff_round_back(x): - """ Differentiable rounding function + """Differentiable rounding function Input: x(tensor) Output: x(tensor) """ - return torch.round(x) + (x - torch.round(x))**3 - + return torch.round(x) + (x - torch.round(x)) ** 3 def diff_round(input_tensor): test = 0 for n in range(1, 10): - test += math.pow(-1, n+1) / n * torch.sin(2 * math.pi * n * input_tensor) + test += math.pow(-1, n + 1) / n * torch.sin(2 * math.pi * n * input_tensor) final_tensor = input_tensor - 1 / math.pi * test return final_tensor class Quant(torch.autograd.Function): - @staticmethod def forward(ctx, input): input = torch.clamp(input, 0, 1) - output = (input * 255.).round() / 255. + output = (input * 255.0).round() / 255.0 return output @staticmethod def backward(ctx, grad_output): return grad_output + class Quantization(nn.Module): def __init__(self): super(Quantization, self).__init__() @@ -62,14 +69,14 @@ class Quantization(nn.Module): def quality_to_factor(quality): - """ Calculate factor corresponding to quality + """Calculate factor corresponding to quality Input: quality(float): Quality for jpeg compression Output: factor(float): Compression factor """ if quality < 50: - quality = 5000. / quality + quality = 5000.0 / quality else: - quality = 200. - quality*2 - return quality / 100. 
\ No newline at end of file + quality = 200.0 - quality * 2 + return quality / 100.0 diff --git a/third_party/DarkFeat/datasets/InvISP/utils/commons.py b/third_party/DarkFeat/datasets/InvISP/utils/commons.py index e594e0597bac601edc2015d9cae670799f981495..ea546a3fa517304e97652f00c5cc65a8a2b512d6 100644 --- a/third_party/DarkFeat/datasets/InvISP/utils/commons.py +++ b/third_party/DarkFeat/datasets/InvISP/utils/commons.py @@ -5,6 +5,7 @@ def denorm(img, max_value): img = img * float(max_value) return img + def preprocess_test_patch(input_image, target_image, gt_image): input_patch_list = [] target_patch_list = [] @@ -13,11 +14,26 @@ def preprocess_test_patch(input_image, target_image, gt_image): W = input_image.shape[3] for i in range(3): for j in range(3): - input_patch = input_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)] - target_patch = target_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)] - gt_patch = gt_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)] + input_patch = input_image[ + :, + :, + int(i * H / 3) : int((i + 1) * H / 3), + int(j * W / 3) : int((j + 1) * W / 3), + ] + target_patch = target_image[ + :, + :, + int(i * H / 3) : int((i + 1) * H / 3), + int(j * W / 3) : int((j + 1) * W / 3), + ] + gt_patch = gt_image[ + :, + :, + int(i * H / 3) : int((i + 1) * H / 3), + int(j * W / 3) : int((j + 1) * W / 3), + ] input_patch_list.append(input_patch) target_patch_list.append(target_patch) gt_patch_list.append(gt_patch) - + return input_patch_list, target_patch_list, gt_patch_list diff --git a/third_party/DarkFeat/datasets/InvISP/utils/compression.py b/third_party/DarkFeat/datasets/InvISP/utils/compression.py index 3ae22f8839517bfd7e3c774528943e8fff59dce7..9519bb99cedd1cf64efc3dacc07d59603d9e7508 100644 --- a/third_party/DarkFeat/datasets/InvISP/utils/compression.py +++ b/third_party/DarkFeat/datasets/InvISP/utils/compression.py @@ -1,40 +1,47 @@ # Standard libraries import itertools import numpy as np + # PyTorch import torch import torch.nn as nn + # Local from . 
import JPEG_utils class rgb_to_ycbcr_jpeg(nn.Module): - """ Converts RGB image to YCbCr + """Converts RGB image to YCbCr Input: image(tensor): batch x 3 x height x width Outpput: result(tensor): batch x height x width x 3 """ + def __init__(self): super(rgb_to_ycbcr_jpeg, self).__init__() matrix = np.array( - [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], - [0.5, -0.418688, -0.081312]], dtype=np.float32).T - self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + [ + [0.299, 0.587, 0.114], + [-0.168736, -0.331264, 0.5], + [0.5, -0.418688, -0.081312], + ], + dtype=np.float32, + ).T + self.shift = nn.Parameter(torch.tensor([0.0, 128.0, 128.0])) # self.matrix = nn.Parameter(torch.from_numpy(matrix)) def forward(self, image): image = image.permute(0, 2, 3, 1) result = torch.tensordot(image, self.matrix, dims=1) + self.shift - # result = torch.from_numpy(result) + # result = torch.from_numpy(result) result.view(image.shape) return result - class chroma_subsampling(nn.Module): - """ Chroma subsampling on CbCv channels + """Chroma subsampling on CbCv channels Input: image(tensor): batch x height x width x 3 Output: @@ -42,27 +49,28 @@ class chroma_subsampling(nn.Module): cb(tensor): batch x height/2 x width/2 cr(tensor): batch x height/2 x width/2 """ + def __init__(self): super(chroma_subsampling, self).__init__() def forward(self, image): image_2 = image.permute(0, 3, 1, 2).clone() - avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2), - count_include_pad=False) + avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2), count_include_pad=False) cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1)) cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1)) cb = cb.permute(0, 2, 3, 1) cr = cr.permute(0, 2, 3, 1) return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) - + class block_splitting(nn.Module): - """ Splitting image into patches + """Splitting image into patches Input: image(tensor): batch x height x width - Output: + Output: patch(tensor): batch x h*w/64 x h x w """ + def __init__(self): super(block_splitting, self).__init__() self.k = 8 @@ -75,26 +83,30 @@ class block_splitting(nn.Module): image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) - + class dct_8x8(nn.Module): - """ Discrete Cosine Transformation + """Discrete Cosine Transformation Input: image(tensor): batch x height x width Output: dcp(tensor): batch x height x width """ + def __init__(self): super(dct_8x8, self).__init__() tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) for x, y, u, v in itertools.product(range(8), repeat=4): tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos( - (2 * y + 1) * v * np.pi / 16) - alpha = np.array([1. 
/ np.sqrt(2)] + [1] * 7) + (2 * y + 1) * v * np.pi / 16 + ) + alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7) # - self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) - self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() ) - + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter( + torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() + ) + def forward(self, image): image = image - 128 result = self.scale * torch.tensordot(image, self.tensor, dims=2) @@ -103,7 +115,7 @@ class dct_8x8(nn.Module): class y_quantize(nn.Module): - """ JPEG Quantization for Y channel + """JPEG Quantization for Y channel Input: image(tensor): batch x height x width rounding(function): rounding function to use @@ -111,6 +123,7 @@ class y_quantize(nn.Module): Output: image(tensor): batch x height x width """ + def __init__(self, rounding, factor=1): super(y_quantize, self).__init__() self.rounding = rounding @@ -124,7 +137,7 @@ class y_quantize(nn.Module): class c_quantize(nn.Module): - """ JPEG Quantization for CrCb channels + """JPEG Quantization for CrCb channels Input: image(tensor): batch x height x width rounding(function): rounding function to use @@ -132,6 +145,7 @@ class c_quantize(nn.Module): Output: image(tensor): batch x height x width """ + def __init__(self, rounding, factor=1): super(c_quantize, self).__init__() self.rounding = rounding @@ -145,41 +159,39 @@ class c_quantize(nn.Module): class compress_jpeg(nn.Module): - """ Full JPEG compression algortihm + """Full JPEG compression algortihm Input: - imgs(tensor): batch x 3 x height x width + imgs(tensor): batch x 3 x height x width rounding(function): rounding function to use factor(float): Compression factor Ouput: compressed(dict(tensor)): batch x h*w/64 x 8 x 8 """ + def __init__(self, rounding=torch.round, factor=1): super(compress_jpeg, self).__init__() self.l1 = nn.Sequential( rgb_to_ycbcr_jpeg(), - # comment this line if no subsampling - chroma_subsampling() - ) - self.l2 = nn.Sequential( - block_splitting(), - dct_8x8() + # comment this line if no subsampling + chroma_subsampling(), ) + self.l2 = nn.Sequential(block_splitting(), dct_8x8()) self.c_quantize = c_quantize(rounding=rounding, factor=factor) self.y_quantize = y_quantize(rounding=rounding, factor=factor) - + def forward(self, image): - y, cb, cr = self.l1(image*255) # modify + y, cb, cr = self.l1(image * 255) # modify # y, cb, cr = result[:,:,:,0], result[:,:,:,1], result[:,:,:,2] - components = {'y': y, 'cb': cb, 'cr': cr} + components = {"y": y, "cb": cb, "cr": cr} for k in components.keys(): comp = self.l2(components[k]) # print(comp.shape) - if k in ('cb', 'cr'): + if k in ("cb", "cr"): comp = self.c_quantize(comp) else: comp = self.y_quantize(comp) components[k] = comp - return components['y'], components['cb'], components['cr'] \ No newline at end of file + return components["y"], components["cb"], components["cr"] diff --git a/third_party/DarkFeat/datasets/InvISP/utils/decompression.py b/third_party/DarkFeat/datasets/InvISP/utils/decompression.py index b73ff96d5f6818e1d0464b9c4133f559a3b23fba..8a006442522b8b39261c78be85fcf16e7400fe7e 100644 --- a/third_party/DarkFeat/datasets/InvISP/utils/decompression.py +++ b/third_party/DarkFeat/datasets/InvISP/utils/decompression.py @@ -1,21 +1,24 @@ # Standard libraries import itertools import numpy as np + # PyTorch import torch import torch.nn as nn + # Local from . 
import JPEG_utils as utils class y_dequantize(nn.Module): - """ Dequantize Y channel + """Dequantize Y channel Inputs: image(tensor): batch x height x width factor(float): compression factor Outputs: image(tensor): batch x height x width """ + def __init__(self, factor=1): super(y_dequantize, self).__init__() self.y_table = utils.y_table @@ -26,13 +29,14 @@ class y_dequantize(nn.Module): class c_dequantize(nn.Module): - """ Dequantize CbCr channel + """Dequantize CbCr channel Inputs: image(tensor): batch x height x width factor(float): compression factor Outputs: image(tensor): batch x height x width """ + def __init__(self, factor=1): super(c_dequantize, self).__init__() self.factor = factor @@ -43,24 +47,26 @@ class c_dequantize(nn.Module): class idct_8x8(nn.Module): - """ Inverse discrete Cosine Transformation + """Inverse discrete Cosine Transformation Input: dcp(tensor): batch x height x width Output: image(tensor): batch x height x width """ + def __init__(self): super(idct_8x8, self).__init__() - alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7) self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) for x, y, u, v in itertools.product(range(8), repeat=4): tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos( - (2 * v + 1) * y * np.pi / 16) + (2 * v + 1) * y * np.pi / 16 + ) self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) def forward(self, image): - + image = image * self.alpha result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 result.view(image.shape) @@ -68,7 +74,7 @@ class idct_8x8(nn.Module): class block_merging(nn.Module): - """ Merge pathces into image + """Merge pathces into image Inputs: patches(tensor) batch x height*width/64, height x width height(int) @@ -76,30 +82,32 @@ class block_merging(nn.Module): Output: image(tensor): batch x height x width """ + def __init__(self): super(block_merging, self).__init__() - + def forward(self, patches, height, width): k = 8 batch_size = patches.shape[0] - # print(patches.shape) # (1,1024,8,8) - image_reshaped = patches.view(batch_size, height//k, width//k, k, k) + # print(patches.shape) # (1,1024,8,8) + image_reshaped = patches.view(batch_size, height // k, width // k, k, k) image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) return image_transposed.contiguous().view(batch_size, height, width) class chroma_upsampling(nn.Module): - """ Upsample chroma layers - Input: + """Upsample chroma layers + Input: y(tensor): y channel image cb(tensor): cb channel cr(tensor): cr channel Ouput: image(tensor): batch x height x width x 3 """ + def __init__(self): super(chroma_upsampling, self).__init__() - + def forward(self, y, cb, cr): def repeat(x, k=2): height, width = x.shape[1:3] @@ -110,35 +118,37 @@ class chroma_upsampling(nn.Module): cb = repeat(cb) cr = repeat(cr) - + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) class ycbcr_to_rgb_jpeg(nn.Module): - """ Converts YCbCr image to RGB JPEG + """Converts YCbCr image to RGB JPEG Input: image(tensor): batch x height x width x 3 Outpput: result(tensor): batch x 3 x height x width """ + def __init__(self): super(ycbcr_to_rgb_jpeg, self).__init__() matrix = np.array( - [[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], - dtype=np.float32).T - self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + [[1.0, 0.0, 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], + dtype=np.float32, + ).T + 
self.shift = nn.Parameter(torch.tensor([0, -128.0, -128.0])) self.matrix = nn.Parameter(torch.from_numpy(matrix)) def forward(self, image): result = torch.tensordot(image + self.shift, self.matrix, dims=1) - #result = torch.from_numpy(result) + # result = torch.from_numpy(result) result.view(image.shape) return result.permute(0, 3, 1, 2) class decompress_jpeg(nn.Module): - """ Full JPEG decompression algortihm + """Full JPEG decompression algortihm Input: compressed(dict(tensor)): batch x h*w/64 x 8 x 8 rounding(function): rounding function to use @@ -146,6 +156,7 @@ class decompress_jpeg(nn.Module): Ouput: image(tensor): batch x 3 x height x width """ + # def __init__(self, height, width, rounding=torch.round, factor=1): def __init__(self, rounding=torch.round, factor=1): super(decompress_jpeg, self).__init__() @@ -156,35 +167,35 @@ class decompress_jpeg(nn.Module): # comment this line if no subsampling self.chroma = chroma_upsampling() self.colors = ycbcr_to_rgb_jpeg() - + # self.height, self.width = height, width - + def forward(self, y, cb, cr, height, width): - components = {'y': y, 'cb': cb, 'cr': cr} + components = {"y": y, "cb": cb, "cr": cr} # height = y.shape[0] # width = y.shape[1] self.height = height self.width = width for k in components.keys(): - if k in ('cb', 'cr'): + if k in ("cb", "cr"): comp = self.c_dequantize(components[k]) # comment this line if no subsampling - height, width = int(self.height/2), int(self.width/2) + height, width = int(self.height / 2), int(self.width / 2) # height, width = int(self.height), int(self.width) - + else: - comp = self.y_dequantize(components[k]) - # comment this line if no subsampling - height, width = self.height, self.width - comp = self.idct(comp) - components[k] = self.merging(comp, height, width) - # - # comment this line if no subsampling - image = self.chroma(components['y'], components['cb'], components['cr']) - # image = torch.cat([components['y'].unsqueeze(3), components['cb'].unsqueeze(3), components['cr'].unsqueeze(3)], dim=3) + comp = self.y_dequantize(components[k]) + # comment this line if no subsampling + height, width = self.height, self.width + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + # comment this line if no subsampling + image = self.chroma(components["y"], components["cb"], components["cr"]) + # image = torch.cat([components['y'].unsqueeze(3), components['cb'].unsqueeze(3), components['cr'].unsqueeze(3)], dim=3) image = self.colors(image) - image = torch.min(255*torch.ones_like(image), - torch.max(torch.zeros_like(image), image)) - return image/255 - + image = torch.min( + 255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image) + ) + return image / 255 diff --git a/third_party/DarkFeat/datasets/gl3d/io.py b/third_party/DarkFeat/datasets/gl3d/io.py index 9e5b4b0459d6814ef6af17a0a322b59202037d4f..9b48a2be61ba799d567b7df45c9b9b011cbef4be 100644 --- a/third_party/DarkFeat/datasets/gl3d/io.py +++ b/third_party/DarkFeat/datasets/gl3d/io.py @@ -5,42 +5,42 @@ import numpy as np from ..utils.common import Notify + def read_list(list_path): """Read list.""" if list_path is None or not os.path.exists(list_path): - print(Notify.FAIL, 'Not exist', list_path, Notify.ENDC) + print(Notify.FAIL, "Not exist", list_path, Notify.ENDC) exit(-1) content = open(list_path).read().splitlines() return content def load_pfm(pfm_path): - with open(pfm_path, 'rb') as fin: + with open(pfm_path, "rb") as fin: color = None width = None height = None scale = None data_type = None - header = 
str(fin.readline().decode('UTF-8')).rstrip() + header = str(fin.readline().decode("UTF-8")).rstrip() - if header == 'PF': + if header == "PF": color = True - elif header == 'Pf': + elif header == "Pf": color = False else: - raise Exception('Not a PFM file.') + raise Exception("Not a PFM file.") - dim_match = re.match(r'^(\d+)\s(\d+)\s$', - fin.readline().decode('UTF-8')) + dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8")) if dim_match: width, height = map(int, dim_match.groups()) else: - raise Exception('Malformed PFM header.') - scale = float((fin.readline().decode('UTF-8')).rstrip()) + raise Exception("Malformed PFM header.") + scale = float((fin.readline().decode("UTF-8")).rstrip()) if scale < 0: # little-endian - data_type = ' 0: - img = cv2.resize( - img, (config['resize'], config['resize'])) + if config["resize"] > 0: + img = cv2.resize(img, (config["resize"], config["resize"])) return img def _parse_depth(depth_paths, idx, config): depth = load_pfm(depth_paths[idx]) - if config['resize'] > 0: - target_size = config['resize'] - if config['input_type'] == 'raw': - depth = cv2.resize(depth, (int(target_size/2), int(target_size/2))) + if config["resize"] > 0: + target_size = config["resize"] + if config["input_type"] == "raw": + depth = cv2.resize(depth, (int(target_size / 2), int(target_size / 2))) else: depth = cv2.resize(depth, (target_size, target_size)) return depth def _parse_kpts(kpts_paths, idx, config): - kpts = np.load(kpts_paths[idx])['pts'] + kpts = np.load(kpts_paths[idx])["pts"] # output: [N, 2] (W first H last) return kpts diff --git a/third_party/DarkFeat/datasets/gl3d_dataset.py b/third_party/DarkFeat/datasets/gl3d_dataset.py index db3d2db646ae7fce81424f5f72cdff7e6e34ba60..0dd9ea77f44bcc065a895c05a66cdc843632ddee 100644 --- a/third_party/DarkFeat/datasets/gl3d_dataset.py +++ b/third_party/DarkFeat/datasets/gl3d_dataset.py @@ -15,17 +15,18 @@ class GL3DDataset(Dataset): self.config = config self.is_training = is_training self.data_split = data_split - - self.match_set_list, self.global_img_list, \ - self.global_depth_list = self.prepare_match_sets() - pass + ( + self.match_set_list, + self.global_img_list, + self.global_depth_list, + ) = self.prepare_match_sets() + pass def __len__(self): return len(self.match_set_list) - def __getitem__(self, idx): match_set_path = self.match_set_list[idx] decoded = np.fromfile(match_set_path, dtype=np.float32) @@ -50,26 +51,24 @@ class GL3DDataset(Dataset): img1 = photaug(img1) return { - 'img0': img0 / 255., - 'img1': img1 / 255., - 'depth0': depth0, - 'depth1': depth1, - 'ori_img_size0': ori_img_size0, - 'ori_img_size1': ori_img_size1, - 'K0': K0, - 'K1': K1, - 'rel_pose': rel_pose, - 'inlier_num': inlier_num + "img0": img0 / 255.0, + "img1": img1 / 255.0, + "depth0": depth0, + "depth1": depth1, + "ori_img_size0": ori_img_size0, + "ori_img_size1": ori_img_size1, + "K0": K0, + "K1": K1, + "rel_pose": rel_pose, + "inlier_num": inlier_num, } - def points_to_2D(self, pnts, H, W): labels = np.zeros((H, W)) pnts = pnts.astype(int) labels[pnts[:, 1], pnts[:, 0]] = 1 return labels - def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60): """Get match sets. Args: @@ -81,20 +80,29 @@ class GL3DDataset(Dataset): global_context_feat_list: """ # get necessary lists. 
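# Illustrative usage sketch of the GL3DDataset defined above. The constructor
# argument order and the import path are assumptions inferred from the
# attributes assigned in __init__ (dataset_dir, config, is_training,
# data_split); the config keys mirror the _parse_img/_parse_depth helpers.
from torch.utils.data import DataLoader
from datasets.gl3d_dataset import GL3DDataset  # import path assumed

config = {"resize": 480, "input_type": "rgb"}  # illustrative values
dataset = GL3DDataset("/data/gl3d", config, is_training=True, data_split="comb")
loader = DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))
# Each sample is the dict returned by __getitem__ above:
# 'img0'/'img1' scaled to [0, 1], 'depth0'/'depth1', 'K0'/'K1', 'rel_pose', 'inlier_num'.
print(batch["img0"].shape, batch["rel_pose"].shape)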
- gl3d_list_folder = os.path.join(self.dataset_dir, 'list', self.data_split) - global_info = read_list(os.path.join( - gl3d_list_folder, 'image_index_offset.txt')) - global_img_list = [os.path.join(self.dataset_dir, i) for i in read_list( - os.path.join(gl3d_list_folder, 'image_list.txt'))] - global_depth_list = [os.path.join(self.dataset_dir, i) for i in read_list( - os.path.join(gl3d_list_folder, 'depth_list.txt'))] - - imageset_list_name = 'imageset_train.txt' if self.is_training else 'imageset_test.txt' - match_set_list = self.get_match_set_list(os.path.join( - gl3d_list_folder, imageset_list_name), q_diff_thld, rot_diff_thld) + gl3d_list_folder = os.path.join(self.dataset_dir, "list", self.data_split) + global_info = read_list( + os.path.join(gl3d_list_folder, "image_index_offset.txt") + ) + global_img_list = [ + os.path.join(self.dataset_dir, i) + for i in read_list(os.path.join(gl3d_list_folder, "image_list.txt")) + ] + global_depth_list = [ + os.path.join(self.dataset_dir, i) + for i in read_list(os.path.join(gl3d_list_folder, "depth_list.txt")) + ] + + imageset_list_name = ( + "imageset_train.txt" if self.is_training else "imageset_test.txt" + ) + match_set_list = self.get_match_set_list( + os.path.join(gl3d_list_folder, imageset_list_name), + q_diff_thld, + rot_diff_thld, + ) return match_set_list, global_img_list, global_depth_list - def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld): """Get the path list of match sets. Args: @@ -103,25 +111,25 @@ class GL3DDataset(Dataset): Returns: match_set_list: List of match set path. """ - imageset_list = [os.path.join(self.dataset_dir, 'data', i) - for i in read_list(imageset_list_path)] - print(Notify.INFO, 'Use # imageset', len(imageset_list), Notify.ENDC) + imageset_list = [ + os.path.join(self.dataset_dir, "data", i) + for i in read_list(imageset_list_path) + ] + print(Notify.INFO, "Use # imageset", len(imageset_list), Notify.ENDC) match_set_list = [] # discard image pairs whose image simiarity is beyond the threshold. 
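# Worked example (illustrative filename) of the .match_set filter implemented
# just below: only fields [2] and [3] of the underscore-separated stem are
# inspected, as q_diff and rot_diff respectively.
name = "0000012_0000034_5_20"  # stem of "0000012_0000034_5_20.match_set"
splits = name.split("_")
q_diff, rot_diff = int(splits[2]), int(splits[3])
# With the defaults q_diff_thld=3 and rot_diff_thld=60 this pair is kept:
assert q_diff >= 3 and rot_diff <= 60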
for i in imageset_list: - match_set_folder = os.path.join(i, 'match_sets') + match_set_folder = os.path.join(i, "match_sets") if os.path.exists(match_set_folder): match_set_files = os.listdir(match_set_folder) for val in match_set_files: name, ext = os.path.splitext(val) - if ext == '.match_set': - splits = name.split('_') + if ext == ".match_set": + splits = name.split("_") q_diff = int(splits[2]) rot_diff = int(splits[3]) if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld: - match_set_list.append( - os.path.join(match_set_folder, val)) + match_set_list.append(os.path.join(match_set_folder, val)) - print(Notify.INFO, 'Get # match sets', len(match_set_list), Notify.ENDC) + print(Notify.INFO, "Get # match sets", len(match_set_list), Notify.ENDC) return match_set_list - diff --git a/third_party/DarkFeat/datasets/noise.py b/third_party/DarkFeat/datasets/noise.py index aa68c98183186e9e9185e78e1a3e7335ac8d5bb1..a44c6a902c653f6c829a2536a49e5a3c9790e5de 100644 --- a/third_party/DarkFeat/datasets/noise.py +++ b/third_party/DarkFeat/datasets/noise.py @@ -3,31 +3,49 @@ import random from scipy.stats import tukeylambda camera_params = { - 'Kmin': 0.2181895124454343, - 'Kmax': 3.0, - 'G_shape': np.array([0.15714286, 0.14285714, 0.08571429, 0.08571429, 0.2 , - 0.2 , 0.1 , 0.08571429, 0.05714286, 0.07142857, - 0.02857143, 0.02857143, 0.01428571, 0.02857143, 0.08571429, - 0.07142857, 0.11428571, 0.11428571]), - 'Profile-1': { - 'R_scale': { - 'slope': 0.4712797750747537, - 'bias': -0.8078958947116487, - 'sigma': 0.2436176299944695 + "Kmin": 0.2181895124454343, + "Kmax": 3.0, + "G_shape": np.array( + [ + 0.15714286, + 0.14285714, + 0.08571429, + 0.08571429, + 0.2, + 0.2, + 0.1, + 0.08571429, + 0.05714286, + 0.07142857, + 0.02857143, + 0.02857143, + 0.01428571, + 0.02857143, + 0.08571429, + 0.07142857, + 0.11428571, + 0.11428571, + ] + ), + "Profile-1": { + "R_scale": { + "slope": 0.4712797750747537, + "bias": -0.8078958947116487, + "sigma": 0.2436176299944695, }, - 'g_scale': { - 'slope': 0.6771267783987617, - 'bias': 1.5121876510805845, - 'sigma': 0.24641096601611254 + "g_scale": { + "slope": 0.6771267783987617, + "bias": 1.5121876510805845, + "sigma": 0.24641096601611254, + }, + "G_scale": { + "slope": 0.6558756156508007, + "bias": 1.09268679594838, + "sigma": 0.28604721742277756, }, - 'G_scale': { - 'slope': 0.6558756156508007, - 'bias': 1.09268679594838, - 'sigma': 0.28604721742277756 - } }, - 'black_level': 2048, - 'max_value': 16383 + "black_level": 2048, + "max_value": 16383, } @@ -46,15 +64,18 @@ def addGStarNoise(img, K, G_shape, G_scale_param): rand_num = random.uniform(0, 1) idx = np.sum(np.cumsum(a) < rand_num) - lam = random.uniform(b[idx], b[idx+1]) + lam = random.uniform(b[idx], b[idx + 1]) # calculate scale parameter [G_scale] log_K = np.log(K) - log_G_scale = np.random.standard_normal() * G_scale_param['sigma'] * 1 +\ - G_scale_param['slope'] * log_K + G_scale_param['bias'] + log_G_scale = ( + np.random.standard_normal() * G_scale_param["sigma"] * 1 + + G_scale_param["slope"] * log_K + + G_scale_param["bias"] + ) G_scale = np.exp(log_G_scale) # print(f'G_scale: {G_scale}') - + return img + tukeylambda.rvs(lam, scale=G_scale, size=img.shape).astype(np.float32) @@ -63,11 +84,14 @@ def addGStarNoise(img, K, G_shape, G_scale_param): def addRowNoise(img, K, R_scale_param): # calculate scale parameter [R_scale] log_K = np.log(K) - log_R_scale = np.random.standard_normal() * R_scale_param['sigma'] * 1 +\ - R_scale_param['slope'] * log_K + R_scale_param['bias'] + log_R_scale = ( + 
np.random.standard_normal() * R_scale_param["sigma"] * 1 + + R_scale_param["slope"] * log_K + + R_scale_param["bias"] + ) R_scale = np.exp(log_R_scale) # print(f'R_scale: {R_scale}') - + row_noise = np.random.randn(img.shape[0], 1).astype(np.float32) * R_scale return img + np.tile(row_noise, (1, img.shape[1])) @@ -75,7 +99,7 @@ def addRowNoise(img, K, R_scale_param): # quantization noise # uniform distribution def addQuantNoise(img, q): - return img + np.random.uniform(low=-0.5*q, high=0.5*q, size=img.shape) + return img + np.random.uniform(low=-0.5 * q, high=0.5 * q, size=img.shape) def sampleK(Kmin, Kmax): diff --git a/third_party/DarkFeat/datasets/noise_simulator.py b/third_party/DarkFeat/datasets/noise_simulator.py index 17e21d3b3443aaa3585ae8460709f60b05835a84..8d7ff4ad00583b1a0879160d725a5de4dade4892 100644 --- a/third_party/DarkFeat/datasets/noise_simulator.py +++ b/third_party/DarkFeat/datasets/noise_simulator.py @@ -14,17 +14,28 @@ import colour_demosaicing from .InvISP.model.model import InvISPNet from .utils.common import Notify -from datasets.noise import camera_params, addGStarNoise, addPStarNoise, addQuantNoise, addRowNoise, sampleK +from datasets.noise import ( + camera_params, + addGStarNoise, + addPStarNoise, + addQuantNoise, + addRowNoise, + sampleK, +) class NoiseSimulator: - def __init__(self, device, ckpt_path='./datasets/InvISP/pretrained/canon.pth'): + def __init__(self, device, ckpt_path="./datasets/InvISP/pretrained/canon.pth"): self.device = device # load Invertible ISP Network - self.net = InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval() + self.net = ( + InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval() + ) self.net.load_state_dict(torch.load(ckpt_path), strict=False) - print(Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC) + print( + Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC + ) # white balance parameters self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0]) @@ -75,11 +86,11 @@ class NoiseSimulator: # input: [H, W] # output: [H, W, 3] def demosaic(self, img): - return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, 'RGGB') + return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, "RGGB") # load rgb image def path2rgb(self, path): - return torch.from_numpy(np.array(PILImage.open(path))/255.0) + return torch.from_numpy(np.array(PILImage.open(path)) / 255.0) # InvISP # input: rgb image [H, W, 3] @@ -89,21 +100,21 @@ class NoiseSimulator: if not batched: rgb = rgb.unsqueeze(0) - rgb = rgb.permute(0,3,1,2).float().to(self.device) + rgb = rgb.permute(0, 3, 1, 2).float().to(self.device) with torch.no_grad(): reconstruct_raw = self.net(rgb, rev=True) - pred_raw = reconstruct_raw.detach().permute(0,2,3,1) + pred_raw = reconstruct_raw.detach().permute(0, 2, 3, 1) pred_raw = torch.clamp(pred_raw, 0, 1) if not batched: pred_raw = pred_raw[0, ...] - + pred_raw = pred_raw.cpu().numpy() # 2. -> inv gamma - norm_value = np.power(16383, 1/2.2) - pred_raw *= norm_value + norm_value = np.power(16383, 1 / 2.2) + pred_raw *= norm_value pred_raw = np.power(pred_raw, 2.2) # 3. -> inv white balance @@ -111,7 +122,7 @@ class NoiseSimulator: pred_raw = pred_raw / wb[:-1] # 4. -> add black level - pred_raw += self.camera_params['black_level'] + pred_raw += self.camera_params["black_level"] # 5. 
-> inv demosaic if not batched: @@ -124,18 +135,24 @@ class NoiseSimulator: return pred_raw - def raw2noisyRaw(self, raw, ratio_dec=1, batched=False): if not batched: ratio = (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1 raw = raw.copy() / ratio - K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax']) - q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level']) + K = sampleK(self.camera_params["Kmin"], self.camera_params["Kmax"]) + q = 1 / ( + self.camera_params["max_value"] - self.camera_params["black_level"] + ) raw = addPStarNoise(raw, K) - raw = addGStarNoise(raw, K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale']) - raw = addRowNoise(raw, K, self.camera_params['Profile-1']['R_scale']) + raw = addGStarNoise( + raw, + K, + self.camera_params["G_shape"], + self.camera_params["Profile-1"]["G_scale"], + ) + raw = addRowNoise(raw, K, self.camera_params["Profile-1"]["R_scale"]) raw = addQuantNoise(raw, q) raw *= ratio return raw @@ -146,12 +163,21 @@ class NoiseSimulator: ratio = random.uniform(self.ratio_min, self.ratio_max) raw[i] /= ratio - K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax']) - q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level']) + K = sampleK(self.camera_params["Kmin"], self.camera_params["Kmax"]) + q = 1 / ( + self.camera_params["max_value"] - self.camera_params["black_level"] + ) raw[i] = addPStarNoise(raw[i], K) - raw[i] = addGStarNoise(raw[i], K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale']) - raw[i] = addRowNoise(raw[i], K, self.camera_params['Profile-1']['R_scale']) + raw[i] = addGStarNoise( + raw[i], + K, + self.camera_params["G_shape"], + self.camera_params["Profile-1"]["G_scale"], + ) + raw[i] = addRowNoise( + raw[i], K, self.camera_params["Profile-1"]["R_scale"] + ) raw[i] = addQuantNoise(raw[i], q) raw[i] *= ratio return raw @@ -167,29 +193,38 @@ class NoiseSimulator: raw = np.stack(raws, axis=0) # 2. -> substract black level - raw -= self.camera_params['black_level'] - raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level']) + raw -= self.camera_params["black_level"] + raw = np.clip( + raw, 0, self.camera_params["max_value"] - self.camera_params["black_level"] + ) # 3. -> white balance wb = self.wb / self.wb.max() raw = raw * wb[:-1] # 4. -> gamma - norm_value = np.power(16383, 1/2.2) - raw = np.power(raw, 1/2.2) + norm_value = np.power(16383, 1 / 2.2) + raw = np.power(raw, 1 / 2.2) raw /= norm_value # 5. -> ispnet if not batched: - input_raw_img = torch.Tensor(raw).permute(2,0,1).float().to(self.device)[np.newaxis, ...] + input_raw_img = ( + torch.Tensor(raw) + .permute(2, 0, 1) + .float() + .to(self.device)[np.newaxis, ...] + ) else: - input_raw_img = torch.Tensor(raw).permute(0,3,1,2).float().to(self.device) + input_raw_img = ( + torch.Tensor(raw).permute(0, 3, 1, 2).float().to(self.device) + ) with torch.no_grad(): reconstruct_rgb = self.net(input_raw_img) reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1) - pred_rgb = reconstruct_rgb.detach().permute(0,2,3,1) + pred_rgb = reconstruct_rgb.detach().permute(0, 2, 3, 1) if not batched: pred_rgb = pred_rgb[0, ...] @@ -197,12 +232,13 @@ class NoiseSimulator: return pred_rgb - def raw2packedRaw(self, raw, batched=False): # 1. 
-> substract black level - raw -= self.camera_params['black_level'] - raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level']) - raw /= self.camera_params['max_value'] + raw -= self.camera_params["black_level"] + raw = np.clip( + raw, 0, self.camera_params["max_value"] - self.camera_params["black_level"] + ) + raw /= self.camera_params["max_value"] # 2. pack if not batched: @@ -211,20 +247,30 @@ class NoiseSimulator: H = img_shape[0] W = img_shape[1] - out = np.concatenate((im[0:H:2, 0:W:2, :], - im[0:H:2, 1:W:2, :], - im[1:H:2, 1:W:2, :], - im[1:H:2, 0:W:2, :]), axis=2) + out = np.concatenate( + ( + im[0:H:2, 0:W:2, :], + im[0:H:2, 1:W:2, :], + im[1:H:2, 1:W:2, :], + im[1:H:2, 0:W:2, :], + ), + axis=2, + ) else: im = np.expand_dims(raw, axis=3) img_shape = im.shape H = img_shape[1] W = img_shape[2] - out = np.concatenate((im[:, 0:H:2, 0:W:2, :], - im[:, 0:H:2, 1:W:2, :], - im[:, 1:H:2, 1:W:2, :], - im[:, 1:H:2, 0:W:2, :]), axis=3) + out = np.concatenate( + ( + im[:, 0:H:2, 0:W:2, :], + im[:, 0:H:2, 1:W:2, :], + im[:, 1:H:2, 1:W:2, :], + im[:, 1:H:2, 0:W:2, :], + ), + axis=3, + ) return out def raw2demosaicRaw(self, raw, batched=False): @@ -238,7 +284,9 @@ class NoiseSimulator: raw = np.stack(raws, axis=0) # 2. -> substract black level - raw -= self.camera_params['black_level'] - raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level']) - raw /= self.camera_params['max_value'] + raw -= self.camera_params["black_level"] + raw = np.clip( + raw, 0, self.camera_params["max_value"] - self.camera_params["black_level"] + ) + raw /= self.camera_params["max_value"] return raw diff --git a/third_party/DarkFeat/datasets/utils/common.py b/third_party/DarkFeat/datasets/utils/common.py index 6433408a39e53fcedb634901268754ed1ba971b3..aa2007b0b31df0325c51f4112a259ab1e1d7f1aa 100644 --- a/third_party/DarkFeat/datasets/utils/common.py +++ b/third_party/DarkFeat/datasets/utils/common.py @@ -28,31 +28,30 @@ class Notify(object): @ClassProperty def HEADER(cls): - return str(datetime.now()) + ': \033[95m' + return str(datetime.now()) + ": \033[95m" @ClassProperty def INFO(cls): - return str(datetime.now()) + ': \033[92mI' + return str(datetime.now()) + ": \033[92mI" @ClassProperty def OKBLUE(cls): - return str(datetime.now()) + ': \033[94m' + return str(datetime.now()) + ": \033[94m" @ClassProperty def WARNING(cls): - return str(datetime.now()) + ': \033[93mW' + return str(datetime.now()) + ": \033[93mW" @ClassProperty def FAIL(cls): - return str(datetime.now()) + ': \033[91mF' + return str(datetime.now()) + ": \033[91mF" @ClassProperty def BOLD(cls): - return str(datetime.now()) + ': \033[1mB' + return str(datetime.now()) + ": \033[1mB" @ClassProperty def UNDERLINE(cls): - return str(datetime.now()) + ': \033[4mU' - ENDC = '\033[0m' - + return str(datetime.now()) + ": \033[4mU" + ENDC = "\033[0m" diff --git a/third_party/DarkFeat/datasets/utils/photaug.py b/third_party/DarkFeat/datasets/utils/photaug.py index 41f2278c720355470f00a881a1516cf1b71d2c4a..29b9130871f8cb96d714228fe22d8c6f4b6526e3 100644 --- a/third_party/DarkFeat/datasets/utils/photaug.py +++ b/third_party/DarkFeat/datasets/utils/photaug.py @@ -7,41 +7,45 @@ def random_brightness_np(image, max_abs_change=50): delta = random.uniform(-max_abs_change, max_abs_change) return np.clip(image + delta, 0, 255) + def random_contrast_np(image, strength_range=[0.3, 1.5]): delta = random.uniform(*strength_range) mean = image.mean() return np.clip((image - mean) * delta + mean, 0, 255) + def 
motion_blur_np(img, max_kernel_size=3): # Either vertial, hozirontal or diagonal blur - mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up']) - ksize = np.random.randint( - 0, (max_kernel_size+1)/2)*2 + 1 # make sure is odd - center = int((ksize-1)/2) + mode = np.random.choice(["h", "v", "diag_down", "diag_up"]) + ksize = np.random.randint(0, (max_kernel_size + 1) / 2) * 2 + 1 # make sure is odd + center = int((ksize - 1) / 2) kernel = np.zeros((ksize, ksize)) - if mode == 'h': - kernel[center, :] = 1. - elif mode == 'v': - kernel[:, center] = 1. - elif mode == 'diag_down': + if mode == "h": + kernel[center, :] = 1.0 + elif mode == "v": + kernel[:, center] = 1.0 + elif mode == "diag_down": kernel = np.eye(ksize) - elif mode == 'diag_up': + elif mode == "diag_up": kernel = np.flip(np.eye(ksize), 0) - var = ksize * ksize / 16. + var = ksize * ksize / 16.0 grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1) - gaussian = np.exp(-(np.square(grid-center) + - np.square(grid.T-center))/(2.*var)) + gaussian = np.exp( + -(np.square(grid - center) + np.square(grid.T - center)) / (2.0 * var) + ) kernel *= gaussian kernel /= np.sum(kernel) img = cv2.filter2D(img, -1, kernel) return np.clip(img, 0, 255) + def additive_gaussian_noise(image, stddev_range=[5, 95]): stddev = random.uniform(*stddev_range) noise = np.random.normal(size=image.shape, scale=stddev) noisy_image = np.clip(image + noise, 0, 255) return noisy_image + def photaug(img): img = random_brightness_np(img) img = random_contrast_np(img) diff --git a/third_party/DarkFeat/demo_darkfeat.py b/third_party/DarkFeat/demo_darkfeat.py index ca50ae5b892e7a90e75da7197c33bc0c06e699bf..be9a25c92f7e77da57ca111311dd96d426ba0c36 100644 --- a/third_party/DarkFeat/demo_darkfeat.py +++ b/third_party/DarkFeat/demo_darkfeat.py @@ -5,82 +5,106 @@ import matplotlib.cm as cm import torch import numpy as np from utils.nnmatching import NNMatching -from utils.misc import (AverageTimer, VideoStreamer, make_matching_plot_fast, frame2tensor) +from utils.misc import ( + AverageTimer, + VideoStreamer, + make_matching_plot_fast, + frame2tensor, +) torch.set_grad_enabled(False) def compute_essential(matched_kp1, matched_kp2, K): - pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) - pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) + pts1 = cv2.undistortPoints( + matched_kp1, + cameraMatrix=K, + distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0), + ) + pts2 = cv2.undistortPoints( + matched_kp2, + cameraMatrix=K, + distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0), + ) K_1 = np.eye(3) # Estimate the homography between the matches using RANSAC - ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000) - if ransac_inliers is None or ransac_model.shape != (3,3): + ransac_model, ransac_inliers = cv2.findEssentialMat( + pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000 + ) + if ransac_inliers is None or ransac_model.shape != (3, 3): ransac_inliers = np.array([]) ransac_model = None return ransac_model, ransac_inliers, pts1, pts2 sizer = (960, 640) -focallength_x = 4.504986436499113e+03/(6744/sizer[0]) -focallength_y = 4.513311442889859e+03/(4502/sizer[1]) +focallength_x = 4.504986436499113e03 / (6744 / sizer[0]) +focallength_y = 4.513311442889859e03 / (4502 / sizer[1]) K = np.eye(3) -K[0,0] = focallength_x -K[1,1] = 
focallength_y -K[0,2] = 3.363322177533149e+03/(6744/sizer[0])# * 0.5 -K[1,2] = 2.291824660547715e+03/(4502/sizer[1])# * 0.5 +K[0, 0] = focallength_x +K[1, 1] = focallength_y +K[0, 2] = 3.363322177533149e03 / (6744 / sizer[0]) # * 0.5 +K[1, 2] = 2.291824660547715e03 / (4502 / sizer[1]) # * 0.5 -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser( - description='DarkFeat demo', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + description="DarkFeat demo", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--input", type=str, help="path to an image directory") parser.add_argument( - '--input', type=str, - help='path to an image directory') - parser.add_argument( - '--output_dir', type=str, default=None, - help='Directory where to write output frames (If None, no output)') + "--output_dir", + type=str, + default=None, + help="Directory where to write output frames (If None, no output)", + ) parser.add_argument( - '--image_glob', type=str, nargs='+', default=['*.ARW'], - help='Glob if a directory of images is specified') + "--image_glob", + type=str, + nargs="+", + default=["*.ARW"], + help="Glob if a directory of images is specified", + ) parser.add_argument( - '--resize', type=int, nargs='+', default=[640, 480], - help='Resize the input image before running inference. If two numbers, ' - 'resize to the exact dimensions, if one number, resize the max ' - 'dimension, if -1, do not resize') + "--resize", + type=int, + nargs="+", + default=[640, 480], + help="Resize the input image before running inference. If two numbers, " + "resize to the exact dimensions, if one number, resize the max " + "dimension, if -1, do not resize", + ) parser.add_argument( - '--force_cpu', action='store_true', - help='Force pytorch to run in CPU mode.') - parser.add_argument('--model_path', type=str, - help='Path to the pretrained model') + "--force_cpu", action="store_true", help="Force pytorch to run in CPU mode." 
+ ) + parser.add_argument("--model_path", type=str, help="Path to the pretrained model") opt = parser.parse_args() print(opt) assert len(opt.resize) == 2 - print('Will resize to {}x{} (WxH)'.format(opt.resize[0], opt.resize[1])) + print("Will resize to {}x{} (WxH)".format(opt.resize[0], opt.resize[1])) - device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu' - print('Running inference on device \"{}\"'.format(device)) + device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu" + print('Running inference on device "{}"'.format(device)) matching = NNMatching(opt.model_path).eval().to(device) - keys = ['keypoints', 'scores', 'descriptors'] + keys = ["keypoints", "scores", "descriptors"] vs = VideoStreamer(opt.input, opt.resize, opt.image_glob) frame, ret = vs.next_frame() - assert ret, 'Error when reading the first frame (try different --input?)' + assert ret, "Error when reading the first frame (try different --input?)" frame_tensor = frame2tensor(frame, device) - last_data = matching.darkfeat({'image': frame_tensor}) - last_data = {k+'0': [last_data[k]] for k in keys} - last_data['image0'] = frame_tensor + last_data = matching.darkfeat({"image": frame_tensor}) + last_data = {k + "0": [last_data[k]] for k in keys} + last_data["image0"] = frame_tensor last_frame = frame last_image_id = 0 if opt.output_dir is not None: - print('==> Will write outputs to {}'.format(opt.output_dir)) + print("==> Will write outputs to {}".format(opt.output_dir)) Path(opt.output_dir).mkdir(exist_ok=True) timer = AverageTimer() @@ -88,37 +112,43 @@ if __name__ == '__main__': while True: frame, ret = vs.next_frame() if not ret: - print('Finished demo_darkfeat.py') + print("Finished demo_darkfeat.py") break - timer.update('data') + timer.update("data") stem0, stem1 = last_image_id, vs.i - 1 frame_tensor = frame2tensor(frame, device) - pred = matching({**last_data, 'image1': frame_tensor}) - kpts0 = last_data['keypoints0'][0].cpu().numpy() - kpts1 = pred['keypoints1'][0].cpu().numpy() - matches = pred['matches0'][0].cpu().numpy() - confidence = pred['matching_scores0'][0].cpu().numpy() - timer.update('forward') + pred = matching({**last_data, "image1": frame_tensor}) + kpts0 = last_data["keypoints0"][0].cpu().numpy() + kpts1 = pred["keypoints1"][0].cpu().numpy() + matches = pred["matches0"][0].cpu().numpy() + confidence = pred["matching_scores0"][0].cpu().numpy() + timer.update("forward") valid = matches > -1 mkpts0 = kpts0[valid] mkpts1 = kpts1[matches[valid]] E, inliers, pts1, pts2 = compute_essential(mkpts0, mkpts1, K) - color = cm.jet(np.clip(confidence[valid][inliers[:, 0].astype('bool')] * 2 - 1, -1, 1)) + color = cm.jet( + np.clip(confidence[valid][inliers[:, 0].astype("bool")] * 2 - 1, -1, 1) + ) - text = [ - 'DarkFeat', - 'Matches: {}'.format(inliers.sum()) - ] + text = ["DarkFeat", "Matches: {}".format(inliers.sum())] out = make_matching_plot_fast( - last_frame, frame, mkpts0[inliers[:, 0].astype('bool')], mkpts1[inliers[:, 0].astype('bool')], color, text, - path=None, small_text=' ') + last_frame, + frame, + mkpts0[inliers[:, 0].astype("bool")], + mkpts1[inliers[:, 0].astype("bool")], + color, + text, + path=None, + small_text=" ", + ) if opt.output_dir is not None: - stem = 'matches_{:06}_{:06}'.format(stem0, stem1) - out_file = str(Path(opt.output_dir, stem + '.png')) - print('Writing image to {}'.format(out_file)) + stem = "matches_{:06}_{:06}".format(stem0, stem1) + out_file = str(Path(opt.output_dir, stem + ".png")) + print("Writing image to 
{}".format(out_file)) cv2.imwrite(out_file, out) diff --git a/third_party/DarkFeat/export_features.py b/third_party/DarkFeat/export_features.py index c7caea5e57890948728f84cbb7e68e59d455e171..da54e3dc0a1fed98e832b9cc5d6961e713087b8b 100644 --- a/third_party/DarkFeat/export_features.py +++ b/third_party/DarkFeat/export_features.py @@ -11,6 +11,7 @@ import cv2 from darkfeat import DarkFeat from utils import matching + def darkfeat_pre(img, cuda): H, W = img.shape[0], img.shape[1] inp = img.copy() @@ -21,24 +22,25 @@ def darkfeat_pre(img, cuda): inp = inp.cuda() return inp -if __name__ == '__main__': + +if __name__ == "__main__": # Parse command line arguments. parser = argparse.ArgumentParser() - parser.add_argument('--H', type=int, default=int(640)) - parser.add_argument('--W', type=int, default=int(960)) - parser.add_argument('--histeq', action='store_true') - parser.add_argument('--model_path', type=str) - parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/') + parser.add_argument("--H", type=int, default=int(640)) + parser.add_argument("--W", type=int, default=int(960)) + parser.add_argument("--histeq", action="store_true") + parser.add_argument("--model_path", type=str) + parser.add_argument("--dataset_dir", type=str, default="/data/hyz/MID/") opt = parser.parse_args() sizer = (opt.W, opt.H) - focallength_x = 4.504986436499113e+03/(6744/sizer[0]) - focallength_y = 4.513311442889859e+03/(4502/sizer[1]) + focallength_x = 4.504986436499113e03 / (6744 / sizer[0]) + focallength_y = 4.513311442889859e03 / (4502 / sizer[1]) K = np.eye(3) - K[0,0] = focallength_x - K[1,1] = focallength_y - K[0,2] = 3.363322177533149e+03/(6744/sizer[0])# * 0.5 - K[1,2] = 2.291824660547715e+03/(4502/sizer[1])# * 0.5 + K[0, 0] = focallength_x + K[1, 1] = focallength_y + K[0, 2] = 3.363322177533149e03 / (6744 / sizer[0]) # * 0.5 + K[1, 2] = 2.291824660547715e03 / (4502 / sizer[1]) # * 0.5 Kinv = np.linalg.inv(K) Kinvt = np.transpose(Kinv) @@ -46,83 +48,111 @@ if __name__ == '__main__': if cuda: darkfeat = DarkFeat(opt.model_path).cuda().eval() - for scene in ['Indoor', 'Outdoor']: - base_save = './result/' + scene + '/' - dir_base = opt.dataset_dir + '/' + scene + '/' + for scene in ["Indoor", "Outdoor"]: + base_save = "./result/" + scene + "/" + dir_base = opt.dataset_dir + "/" + scene + "/" pair_list = sorted(os.listdir(dir_base)) for pair in tqdm.tqdm(pair_list): opention = 1 - if scene == 'Outdoor': + if scene == "Outdoor": pass else: if int(pair[4::]) <= 17: opention = 0 else: pass - name=[] - files = sorted(os.listdir(dir_base+pair)) + name = [] + files = sorted(os.listdir(dir_base + pair)) for file_ in files: - if file_.endswith('.cr2'): + if file_.endswith(".cr2"): name.append(file_[0:9]) - ISO = ['00100', '00200', '00400', '00800', '01600', '03200', '06400', '12800'] + ISO = [ + "00100", + "00200", + "00400", + "00800", + "01600", + "03200", + "06400", + "12800", + ] if opention == 1: - Shutter_speed = ['0.005','0.01','0.025','0.05','0.17','0.5'] + Shutter_speed = ["0.005", "0.01", "0.025", "0.05", "0.17", "0.5"] else: - Shutter_speed = ['0.01','0.02','0.05','0.1','0.3','1'] + Shutter_speed = ["0.01", "0.02", "0.05", "0.1", "0.3", "1"] - E_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'E_estimated.npy') - F_GT = np.dot(np.dot(Kinvt,E_GT),Kinv) - R_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'R_GT.npy') - t_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'T_GT.npy') + E_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "E_estimated.npy") + F_GT = np.dot(np.dot(Kinvt, E_GT), 
Kinv) + R_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "R_GT.npy") + t_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "T_GT.npy") - id0, id1 = sorted([ int(i.split('/')[-1]) for i in glob.glob(f'{dir_base+pair}/?????') ]) + id0, id1 = sorted( + [int(i.split("/")[-1]) for i in glob.glob(f"{dir_base+pair}/?????")] + ) cnt = 0 for iso in ISO: for ex in Shutter_speed: - dark_name1 = name[0] + iso+'_'+ex+'_'+scene+'.npy' - dark_name2 = name[1] + iso+'_'+ex+'_'+scene+'.npy' + dark_name1 = name[0] + iso + "_" + ex + "_" + scene + ".npy" + dark_name2 = name[1] + iso + "_" + ex + "_" + scene + ".npy" if not opt.histeq: - dst_T1_None = f'{dir_base}{pair}/{id0:05d}-npy-nohisteq/{dark_name1}' - dst_T2_None = f'{dir_base}{pair}/{id1:05d}-npy-nohisteq/{dark_name2}' + dst_T1_None = ( + f"{dir_base}{pair}/{id0:05d}-npy-nohisteq/{dark_name1}" + ) + dst_T2_None = ( + f"{dir_base}{pair}/{id1:05d}-npy-nohisteq/{dark_name2}" + ) img1_orig_None = np.load(dst_T1_None) img2_orig_None = np.load(dst_T2_None) - dir_save = base_save + pair + '/None/' + dir_save = base_save + pair + "/None/" - img_input1 = darkfeat_pre(img1_orig_None.astype('float32')/255.0, cuda) - img_input2 = darkfeat_pre(img2_orig_None.astype('float32')/255.0, cuda) + img_input1 = darkfeat_pre( + img1_orig_None.astype("float32") / 255.0, cuda + ) + img_input2 = darkfeat_pre( + img2_orig_None.astype("float32") / 255.0, cuda + ) else: - dst_T1_histeq = f'{dir_base}{pair}/{id0:05d}-npy/{dark_name1}' - dst_T2_histeq = f'{dir_base}{pair}/{id1:05d}-npy/{dark_name2}' + dst_T1_histeq = f"{dir_base}{pair}/{id0:05d}-npy/{dark_name1}" + dst_T2_histeq = f"{dir_base}{pair}/{id1:05d}-npy/{dark_name2}" img1_orig_histeq = np.load(dst_T1_histeq) img2_orig_histeq = np.load(dst_T2_histeq) - dir_save = base_save + pair + '/HistEQ/' + dir_save = base_save + pair + "/HistEQ/" - img_input1 = darkfeat_pre(img1_orig_histeq.astype('float32')/255.0, cuda) - img_input2 = darkfeat_pre(img2_orig_histeq.astype('float32')/255.0, cuda) + img_input1 = darkfeat_pre( + img1_orig_histeq.astype("float32") / 255.0, cuda + ) + img_input2 = darkfeat_pre( + img2_orig_histeq.astype("float32") / 255.0, cuda + ) - result1 = darkfeat({'image': img_input1}) - result2 = darkfeat({'image': img_input2}) + result1 = darkfeat({"image": img_input1}) + result2 = darkfeat({"image": img_input2}) mkpts0, mkpts1, _ = matching.match_descriptors( - cv2.KeyPoint_convert(result1['keypoints'].detach().cpu().float().numpy()), result1['descriptors'].detach().cpu().numpy(), - cv2.KeyPoint_convert(result2['keypoints'].detach().cpu().float().numpy()), result2['descriptors'].detach().cpu().numpy(), - ORB=False + cv2.KeyPoint_convert( + result1["keypoints"].detach().cpu().float().numpy() + ), + result1["descriptors"].detach().cpu().numpy(), + cv2.KeyPoint_convert( + result2["keypoints"].detach().cpu().float().numpy() + ), + result2["descriptors"].detach().cpu().numpy(), + ORB=False, ) - POINT_1_dir = dir_save+f'DarkFeat/POINT_1/' - POINT_2_dir = dir_save+f'DarkFeat/POINT_2/' - - subprocess.check_output(['mkdir', '-p', POINT_1_dir]) - subprocess.check_output(['mkdir', '-p', POINT_2_dir]) - np.save(POINT_1_dir+dark_name1[0:-3]+'npy',mkpts0) - np.save(POINT_2_dir+dark_name2[0:-3]+'npy',mkpts1) + POINT_1_dir = dir_save + f"DarkFeat/POINT_1/" + POINT_2_dir = dir_save + f"DarkFeat/POINT_2/" + subprocess.check_output(["mkdir", "-p", POINT_1_dir]) + subprocess.check_output(["mkdir", "-p", POINT_2_dir]) + np.save(POINT_1_dir + dark_name1[0:-3] + "npy", mkpts0) + np.save(POINT_2_dir + dark_name2[0:-3] + 
"npy", mkpts1) diff --git a/third_party/DarkFeat/nets/geom.py b/third_party/DarkFeat/nets/geom.py index 043ca6e8f5917c56defd6aa17c1ff236a431f8c0..d711ffdbf57aa023caa048adb3e7c8519aef7a3f 100644 --- a/third_party/DarkFeat/nets/geom.py +++ b/third_party/DarkFeat/nets/geom.py @@ -14,23 +14,25 @@ def rnd_sample(inputs, n_sample): def _grid_positions(h, w, bs): x_rng = torch.arange(0, w.int()) y_rng = torch.arange(0, h.int()) - xv, yv = torch.meshgrid(x_rng, y_rng, indexing='xy') - return torch.reshape( - torch.stack((yv, xv), axis=-1), - (1, -1, 2) - ).repeat(bs, 1, 1).float() + xv, yv = torch.meshgrid(x_rng, y_rng, indexing="xy") + return ( + torch.reshape(torch.stack((yv, xv), axis=-1), (1, -1, 2)) + .repeat(bs, 1, 1) + .float() + ) def getK(ori_img_size, cur_feat_size, K): # WARNING: cur_feat_size's order is [h, w] r = ori_img_size / cur_feat_size[[1, 0]] - r_K0 = torch.stack([K[:, 0] / r[:, 0][..., None], K[:, 1] / - r[:, 1][..., None], K[:, 2]], axis=1) + r_K0 = torch.stack( + [K[:, 0] / r[:, 0][..., None], K[:, 1] / r[:, 1][..., None], K[:, 2]], axis=1 + ) return r_K0 def gather_nd(params, indices): - """ The same as tf.gather_nd but batched gather is not supported yet. + """The same as tf.gather_nd but batched gather is not supported yet. indices is an k-dimensional integer tensor, best thought of as a (k-1)-dimensional tensor of indices into params, where each element defines a slice of params: output[\\(i_0, ..., i_{k-2}\\)] = params[indices[\\(i_0, ..., i_{k-2}\\)]] @@ -40,7 +42,7 @@ def gather_nd(params, indices): indices (Tensor): "k" dimensions. shape: [y_0,y_2,...,y_{k-2}, m]. m <= n. Returns: gathered Tensor. - shape [y_0,y_2,...y_{k-2}] + params.shape[m:] + shape [y_0,y_2,...y_{k-2}] + params.shape[m:] """ orig_shape = list(indices.shape) @@ -52,13 +54,14 @@ def gather_nd(params, indices): out_shape = orig_shape[:-1] + list(params.shape)[m:] else: raise ValueError( - f'the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}' + f"the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}" ) indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist() - output = params[indices] # (num_samples, ...) + output = params[indices] # (num_samples, ...) 
return output.reshape(out_shape).contiguous() + # input: pos [kpt_n, 2]; inputs [H, W, 128] / [H, W] # output: [kpt_n, 128] / [kpt_n] def interpolate(pos, inputs, nd=True): @@ -94,17 +97,21 @@ def interpolate(pos, inputs, nd=True): w_bottom_right = w_bottom_right[..., None] interpolated_val = ( - w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) + - w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) + - w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + - w_bottom_right * - gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) + w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) + + w_top_right + * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) + + w_bottom_left + * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + + w_bottom_right + * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) ) return interpolated_val -def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=None, nd=False): +def validate_and_interpolate( + pos, inputs, validate_corner=True, validate_val=None, nd=False +): if nd: h, w, c = inputs.shape else: @@ -135,7 +142,7 @@ def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=Non valid_corner = torch.logical_and( torch.logical_and(valid_top_left, valid_top_right), - torch.logical_and(valid_bottom_left, valid_bottom_right) + torch.logical_and(valid_bottom_left, valid_bottom_right), ) i_top_left = i_top_left[valid_corner] @@ -157,12 +164,16 @@ def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=Non valid_depth = torch.logical_and( torch.logical_and( gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) > 0, - gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) > 0 + gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) > 0, ), torch.logical_and( - gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) > 0, - gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) > 0 - ) + gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + > 0, + gather_nd( + inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1) + ) + > 0, + ), ) i_top_left = i_top_left[valid_depth] @@ -196,10 +207,13 @@ def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=Non w_bottom_right = w_bottom_right[..., None] interpolated_val = ( - w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) + - w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) + - w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + - w_bottom_right * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) + w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) + + w_top_right + * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) + + w_bottom_left + * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + + w_bottom_right + * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) ) pos = torch.stack([i, j], axis=1) @@ -218,10 +232,21 @@ def getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs): for i in range(bs): z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0) - uv0_homo = torch.cat([swap_axis(new_pos0), 
torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device)], axis=-1) + uv0_homo = torch.cat( + [ + swap_axis(new_pos0), + torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device), + ], + axis=-1, + ) xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t()) - xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo, - torch.ones((1, new_pos0.shape[0])).to(z0.device)], axis=0) + xyz0_homo = torch.cat( + [ + torch.unsqueeze(z0, 0) * xy0_homo, + torch.ones((1, new_pos0.shape[0])).to(z0.device), + ], + axis=0, + ) xyz1 = torch.matmul(rel_pose[i], xyz0_homo) xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0) @@ -229,7 +254,8 @@ def getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs): new_pos1 = swap_axis(uv1) annotated_depth, new_pos1, new_ids = validate_and_interpolate( - new_pos1, depth1[i], validate_val=0) + new_pos1, depth1[i], validate_val=0 + ) ids = ids[new_ids] new_pos0 = new_pos0[new_ids] @@ -256,10 +282,21 @@ def getWarpNoValidate(pos0, rel_pose, depth0, K0, depth1, K1, bs): for i in range(bs): z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0) - uv0_homo = torch.cat([swap_axis(new_pos0), torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device)], axis=-1) + uv0_homo = torch.cat( + [ + swap_axis(new_pos0), + torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device), + ], + axis=-1, + ) xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t()) - xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo, - torch.ones((1, new_pos0.shape[0])).to(z0.device)], axis=0) + xyz0_homo = torch.cat( + [ + torch.unsqueeze(z0, 0) * xy0_homo, + torch.ones((1, new_pos0.shape[0])).to(z0.device), + ], + axis=0, + ) xyz1 = torch.matmul(rel_pose[i], xyz0_homo) xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0) @@ -267,7 +304,8 @@ def getWarpNoValidate(pos0, rel_pose, depth0, K0, depth1, K1, bs): new_pos1 = swap_axis(uv1) _, new_pos1, new_ids = validate_and_interpolate( - new_pos1, depth1[i], validate_val=0) + new_pos1, depth1[i], validate_val=0 + ) ids = ids[new_ids] new_pos0 = new_pos0[new_ids] @@ -287,10 +325,17 @@ def getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1): z0 = interpolate(pos0, depth0, nd=False) - uv0_homo = torch.cat([swap_axis(pos0), torch.ones((pos0.shape[0], 1)).to(pos0.device)], axis=-1) + uv0_homo = torch.cat( + [swap_axis(pos0), torch.ones((pos0.shape[0], 1)).to(pos0.device)], axis=-1 + ) xy0_homo = torch.matmul(torch.linalg.inv(K0), uv0_homo.t()) - xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo, - torch.ones((1, pos0.shape[0])).to(z0.device)], axis=0) + xyz0_homo = torch.cat( + [ + torch.unsqueeze(z0, 0) * xy0_homo, + torch.ones((1, pos0.shape[0])).to(z0.device), + ], + axis=0, + ) xyz1 = torch.matmul(rel_pose, xyz0_homo) xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0) @@ -301,22 +346,18 @@ def getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1): return new_pos1 - def get_dist_mat(feat1, feat2, dist_type): eps = 1e-6 cos_dist_mat = torch.matmul(feat1, feat2.t()) - if dist_type == 'cosine_dist': + if dist_type == "cosine_dist": dist_mat = torch.clamp(cos_dist_mat, -1, 1) - elif dist_type == 'euclidean_dist': + elif dist_type == "euclidean_dist": dist_mat = torch.sqrt(torch.clamp(2 - 2 * cos_dist_mat, min=eps)) - elif dist_type == 'euclidean_dist_no_norm': + elif dist_type == "euclidean_dist_no_norm": norm1 = torch.sum(feat1 * feat1, axis=-1, keepdims=True) norm2 = torch.sum(feat2 * feat2, axis=-1, keepdims=True) dist_mat = torch.sqrt( - torch.clamp( - norm1 - 2 * cos_dist_mat + norm2.t(), - min=0. 
- ) + eps + torch.clamp(norm1 - 2 * cos_dist_mat + norm2.t(), min=0.0) + eps ) else: raise NotImplementedError() diff --git a/third_party/DarkFeat/nets/l2net.py b/third_party/DarkFeat/nets/l2net.py index e1ddfe8919bd4d5fe75215d253525123e1402952..b51dc0e9e983c7795924f75b2a814bea85fd08a0 100644 --- a/third_party/DarkFeat/nets/l2net.py +++ b/third_party/DarkFeat/nets/l2net.py @@ -7,9 +7,10 @@ from .score import peakiness_score class BaseNet(nn.Module): - """ Helper class to construct a fully-convolutional network that - extract a l2-normalized patch descriptor. + """Helper class to construct a fully-convolutional network that + extract a l2-normalized patch descriptor. """ + def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False): super(BaseNet, self).__init__() self.inchan = inchan @@ -22,27 +23,42 @@ class BaseNet(nn.Module): def _make_bn(self, outd): return nn.BatchNorm2d(outd, affine=self.bn_affine) - def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max', bias=False): + def _add_conv( + self, + outd, + k=3, + stride=1, + dilation=1, + bn=True, + relu=True, + k_pool=1, + pool_type="max", + bias=False, + ): # as in the original implementation, dilation is applied at the end of layer, so it will have impact only from next layer d = self.dilation * dilation - # if self.dilated: + # if self.dilated: # conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=1) # self.dilation *= stride # else: # conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride) - conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride, bias=bias) + conv_params = dict( + padding=((k - 1) * d) // 2, dilation=d, stride=stride, bias=bias + ) ops = nn.ModuleList([]) - ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) ) - if bn and self.bn: ops.append( self._make_bn(outd) ) - if relu: ops.append( nn.ReLU(inplace=True) ) + ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params)) + if bn and self.bn: + ops.append(self._make_bn(outd)) + if relu: + ops.append(nn.ReLU(inplace=True)) self.curchan = outd - + if k_pool > 1: - if pool_type == 'avg': + if pool_type == "avg": ops.append(torch.nn.AvgPool2d(kernel_size=k_pool)) - elif pool_type == 'max': + elif pool_type == "max": ops.append(torch.nn.MaxPool2d(kernel_size=k_pool)) else: print(f"Error, unknown pooling type {pool_type}...") @@ -51,29 +67,31 @@ class BaseNet(nn.Module): class Quad_L2Net(BaseNet): - """ Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs. 
- """ + """Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs.""" + def __init__(self, dim=128, mchan=4, relu22=False, **kw): BaseNet.__init__(self, **kw) - self.conv0 = self._add_conv( 8*mchan) - self.conv1 = self._add_conv( 8*mchan, bn=False) - self.bn1 = self._make_bn(8*mchan) - self.conv2 = self._add_conv( 16*mchan, stride=2) - self.conv3 = self._add_conv( 16*mchan, bn=False) - self.bn3 = self._make_bn(16*mchan) - self.conv4 = self._add_conv( 32*mchan, stride=2) - self.conv5 = self._add_conv( 32*mchan) + self.conv0 = self._add_conv(8 * mchan) + self.conv1 = self._add_conv(8 * mchan, bn=False) + self.bn1 = self._make_bn(8 * mchan) + self.conv2 = self._add_conv(16 * mchan, stride=2) + self.conv3 = self._add_conv(16 * mchan, bn=False) + self.bn3 = self._make_bn(16 * mchan) + self.conv4 = self._add_conv(32 * mchan, stride=2) + self.conv5 = self._add_conv(32 * mchan) # replace last 8x8 convolution with 3 3x3 convolutions - self.conv6_0 = self._add_conv( 32*mchan) - self.conv6_1 = self._add_conv( 32*mchan) + self.conv6_0 = self._add_conv(32 * mchan) + self.conv6_1 = self._add_conv(32 * mchan) self.conv6_2 = self._add_conv(dim, bn=False, relu=False) self.out_dim = dim - self.moving_avg_params = nn.ParameterList([ - Parameter(torch.tensor(1.), requires_grad=False), - Parameter(torch.tensor(1.), requires_grad=False), - Parameter(torch.tensor(1.), requires_grad=False) - ]) + self.moving_avg_params = nn.ParameterList( + [ + Parameter(torch.tensor(1.0), requires_grad=False), + Parameter(torch.tensor(1.0), requires_grad=False), + Parameter(torch.tensor(1.0), requires_grad=False), + ] + ) def forward(self, x): # x: [N, C, H, W] @@ -90,7 +108,7 @@ class Quad_L2Net(BaseNet): x6_2 = self.conv6_2(x6_1) # calculate score map - comb_weights = torch.tensor([1., 2., 3.], device=x.device) + comb_weights = torch.tensor([1.0, 2.0, 3.0], device=x.device) comb_weights /= torch.sum(comb_weights) ksize = [3, 2, 1] det_score_maps = [] @@ -98,15 +116,21 @@ class Quad_L2Net(BaseNet): for idx, xx in enumerate([x1, x3, x6_2]): if self.training: instance_max = torch.max(xx) - self.moving_avg_params[idx].data = self.moving_avg_params[idx] * 0.99 + instance_max.detach() * 0.01 + self.moving_avg_params[idx].data = ( + self.moving_avg_params[idx] * 0.99 + instance_max.detach() * 0.01 + ) else: pass - alpha, beta = peakiness_score(xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx]) + alpha, beta = peakiness_score( + xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx] + ) score_vol = alpha * beta det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0] - det_score_map = F.interpolate(det_score_map, size=x.shape[2:], mode='bilinear', align_corners=True) + det_score_map = F.interpolate( + det_score_map, size=x.shape[2:], mode="bilinear", align_corners=True + ) det_score_map = comb_weights[idx] * det_score_map det_score_maps.append(det_score_map) diff --git a/third_party/DarkFeat/nets/loss.py b/third_party/DarkFeat/nets/loss.py index 0dd42b4214d021137ddfe72771ccad0264d2321f..1440ef46f43108db0053cf48369e4014c348f98c 100644 --- a/third_party/DarkFeat/nets/loss.py +++ b/third_party/DarkFeat/nets/loss.py @@ -4,10 +4,20 @@ import torch.nn.functional as F from .geom import rnd_sample, interpolate, get_dist_mat -def make_detector_loss(pos0, pos1, dense_feat_map0, dense_feat_map1, - score_map0, score_map1, batch_size, num_corr, loss_type, config): - joint_loss = 0. - accuracy = 0. 
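# Minimal sketch of the multi-scale score fusion in Quad_L2Net.forward above
# (sizes illustrative). Each level's peakiness map is upsampled to the input
# resolution and weighted by the normalized [1, 2, 3] weights; the final sum
# over levels is not shown in the hunk above and is assumed here.
import torch
import torch.nn.functional as F

level_maps = [torch.rand(1, 1, s, s) for s in (64, 32, 16)]  # stand-ins for x1, x3, x6_2 scores
weights = torch.tensor([1.0, 2.0, 3.0])
weights = weights / weights.sum()
fused = sum(
    w * F.interpolate(m, size=(64, 64), mode="bilinear", align_corners=True)
    for w, m in zip(weights, level_maps)
)
print(fused.shape)  # torch.Size([1, 1, 64, 64])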
+def make_detector_loss( + pos0, + pos1, + dense_feat_map0, + dense_feat_map1, + score_map0, + score_map1, + batch_size, + num_corr, + loss_type, + config, +): + joint_loss = 0.0 + accuracy = 0.0 all_valid_pos0 = [] all_valid_pos1 = [] all_valid_match = [] @@ -22,36 +32,54 @@ def make_detector_loss(pos0, pos1, dense_feat_map0, dense_feat_map1, valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) valid_feat1 = F.normalize(valid_feat1, p=2, dim=-1) - valid_score0 = interpolate(valid_pos0, torch.squeeze(score_map0[i], dim=-1), nd=False) - valid_score1 = interpolate(valid_pos1, torch.squeeze(score_map1[i], dim=-1), nd=False) - - if config['network']['det']['corr_weight']: + valid_score0 = interpolate( + valid_pos0, torch.squeeze(score_map0[i], dim=-1), nd=False + ) + valid_score1 = interpolate( + valid_pos1, torch.squeeze(score_map1[i], dim=-1), nd=False + ) + + if config["network"]["det"]["corr_weight"]: corr_weight = valid_score0 * valid_score1 else: corr_weight = None - safe_radius = config['network']['det']['safe_radius'] + safe_radius = config["network"]["det"]["safe_radius"] if safe_radius > 0: radius_mask_row = get_dist_mat( - valid_pos1, valid_pos1, "euclidean_dist_no_norm") + valid_pos1, valid_pos1, "euclidean_dist_no_norm" + ) radius_mask_row = torch.le(radius_mask_row, safe_radius) radius_mask_col = get_dist_mat( - valid_pos0, valid_pos0, "euclidean_dist_no_norm") + valid_pos0, valid_pos0, "euclidean_dist_no_norm" + ) radius_mask_col = torch.le(radius_mask_col, safe_radius) - radius_mask_row = radius_mask_row.float() - torch.eye(valid_num, device=radius_mask_row.device) - radius_mask_col = radius_mask_col.float() - torch.eye(valid_num, device=radius_mask_col.device) + radius_mask_row = radius_mask_row.float() - torch.eye( + valid_num, device=radius_mask_row.device + ) + radius_mask_col = radius_mask_col.float() - torch.eye( + valid_num, device=radius_mask_col.device + ) else: radius_mask_row = None radius_mask_col = None if valid_num < 32: - si_loss, si_accuracy, matched_mask = 0., 1., torch.zeros((1, valid_num)).bool() + si_loss, si_accuracy, matched_mask = ( + 0.0, + 1.0, + torch.zeros((1, valid_num)).bool(), + ) else: si_loss, si_accuracy, matched_mask = make_structured_loss( - torch.unsqueeze(valid_feat0, 0), torch.unsqueeze(valid_feat1, 0), + torch.unsqueeze(valid_feat0, 0), + torch.unsqueeze(valid_feat1, 0), loss_type=loss_type, - radius_mask_row=radius_mask_row, radius_mask_col=radius_mask_col, - corr_weight=torch.unsqueeze(corr_weight, 0) if corr_weight is not None else None + radius_mask_row=radius_mask_row, + radius_mask_col=radius_mask_col, + corr_weight=torch.unsqueeze(corr_weight, 0) + if corr_weight is not None + else None, ) joint_loss += si_loss / batch_size @@ -63,10 +91,16 @@ def make_detector_loss(pos0, pos1, dense_feat_map0, dense_feat_map1, return joint_loss, accuracy -def make_structured_loss(feat_anc, feat_pos, - loss_type='RATIO', inlier_mask=None, - radius_mask_row=None, radius_mask_col=None, - corr_weight=None, dist_mat=None): +def make_structured_loss( + feat_anc, + feat_pos, + loss_type="RATIO", + inlier_mask=None, + radius_mask_row=None, + radius_mask_col=None, + corr_weight=None, + dist_mat=None, +): """ Structured loss construction. 
Args: @@ -82,23 +116,26 @@ def make_structured_loss(feat_anc, feat_pos, inlier_mask = torch.ones((batch_size, num_corr), device=feat_anc.device).bool() inlier_num = torch.count_nonzero(inlier_mask.float(), dim=-1) - if loss_type == 'L2NET' or loss_type == 'CIRCLE': - dist_type = 'cosine_dist' - elif loss_type.find('HARD') >= 0: - dist_type = 'euclidean_dist' + if loss_type == "L2NET" or loss_type == "CIRCLE": + dist_type = "cosine_dist" + elif loss_type.find("HARD") >= 0: + dist_type = "euclidean_dist" else: raise NotImplementedError() if dist_mat is None: - dist_mat = get_dist_mat(feat_anc.squeeze(0), feat_pos.squeeze(0), dist_type).unsqueeze(0) + dist_mat = get_dist_mat( + feat_anc.squeeze(0), feat_pos.squeeze(0), dist_type + ).unsqueeze(0) pos_vec = dist_mat[0].diag().unsqueeze(0) - if loss_type.find('HARD') >= 0: + if loss_type.find("HARD") >= 0: neg_margin = 1 - dist_mat_without_min_on_diag = dist_mat + \ - 10 * torch.unsqueeze(torch.eye(num_corr, device=dist_mat.device), dim=0) + dist_mat_without_min_on_diag = dist_mat + 10 * torch.unsqueeze( + torch.eye(num_corr, device=dist_mat.device), dim=0 + ) mask = torch.le(dist_mat_without_min_on_diag, 0.008).float() - dist_mat_without_min_on_diag += mask*10 + dist_mat_without_min_on_diag += mask * 10 if radius_mask_row is not None: hard_neg_dist_row = dist_mat_without_min_on_diag + 10 * radius_mask_row @@ -112,18 +149,18 @@ def make_structured_loss(feat_anc, feat_pos, hard_neg_dist_row = torch.min(hard_neg_dist_row, dim=-1)[0] hard_neg_dist_col = torch.min(hard_neg_dist_col, dim=-2)[0] - if loss_type == 'HARD_TRIPLET': + if loss_type == "HARD_TRIPLET": loss_row = torch.clamp(neg_margin + pos_vec - hard_neg_dist_row, min=0) loss_col = torch.clamp(neg_margin + pos_vec - hard_neg_dist_col, min=0) - elif loss_type == 'HARD_CONTRASTIVE': + elif loss_type == "HARD_CONTRASTIVE": pos_margin = 0.2 pos_loss = torch.clamp(pos_vec - pos_margin, min=0) loss_row = pos_loss + torch.clamp(neg_margin - hard_neg_dist_row, min=0) loss_col = pos_loss + torch.clamp(neg_margin - hard_neg_dist_col, min=0) else: raise NotImplementedError() - - elif loss_type == 'CIRCLE': + + elif loss_type == "CIRCLE": log_scale = 512 m = 0.1 neg_mask_row = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0) @@ -141,14 +178,26 @@ def make_structured_loss(feat_anc, feat_pos, neg_mat_row = dist_mat - 128 * neg_mask_row neg_mat_col = dist_mat - 128 * neg_mask_col - lse_positive = torch.logsumexp(-log_scale * (pos_vec[..., None] - pos_margin) * \ - torch.clamp(pos_optimal - pos_vec[..., None], min=0).detach(), dim=-1) - - lse_negative_row = torch.logsumexp(log_scale * (neg_mat_row - neg_margin) * \ - torch.clamp(neg_mat_row - neg_optimal, min=0).detach(), dim=-1) - - lse_negative_col = torch.logsumexp(log_scale * (neg_mat_col - neg_margin) * \ - torch.clamp(neg_mat_col - neg_optimal, min=0).detach(), dim=-2) + lse_positive = torch.logsumexp( + -log_scale + * (pos_vec[..., None] - pos_margin) + * torch.clamp(pos_optimal - pos_vec[..., None], min=0).detach(), + dim=-1, + ) + + lse_negative_row = torch.logsumexp( + log_scale + * (neg_mat_row - neg_margin) + * torch.clamp(neg_mat_row - neg_optimal, min=0).detach(), + dim=-1, + ) + + lse_negative_col = torch.logsumexp( + log_scale + * (neg_mat_col - neg_margin) + * torch.clamp(neg_mat_col - neg_optimal, min=0).detach(), + dim=-2, + ) loss_row = F.softplus(lse_positive + lse_negative_row) / log_scale loss_col = F.softplus(lse_positive + lse_negative_col) / log_scale @@ -156,10 +205,10 @@ def make_structured_loss(feat_anc, 
feat_pos, else: raise NotImplementedError() - if dist_type == 'cosine_dist': + if dist_type == "cosine_dist": err_row = dist_mat - torch.unsqueeze(pos_vec, -1) err_col = dist_mat - torch.unsqueeze(pos_vec, -2) - elif dist_type == 'euclidean_dist' or dist_type == 'euclidean_dist_no_norm': + elif dist_type == "euclidean_dist" or dist_type == "euclidean_dist_no_norm": err_row = torch.unsqueeze(pos_vec, -1) - dist_mat err_col = torch.unsqueeze(pos_vec, -2) - dist_mat else: @@ -180,17 +229,18 @@ def make_structured_loss(feat_anc, feat_pos, for i in range(batch_size): if corr_weight is not None: - loss += torch.sum(tot_loss[i][inlier_mask[i]]) / \ - (torch.sum(corr_weight[i][inlier_mask[i]]) + 1e-6) + loss += torch.sum(tot_loss[i][inlier_mask[i]]) / ( + torch.sum(corr_weight[i][inlier_mask[i]]) + 1e-6 + ) else: loss += torch.mean(tot_loss[i][inlier_mask[i]]) cnt_err_row = torch.count_nonzero(err_row[i][inlier_mask[i]]).float() cnt_err_col = torch.count_nonzero(err_col[i][inlier_mask[i]]).float() tot_err = cnt_err_row + cnt_err_col if inlier_num[i] != 0: - accuracy += 1. - tot_err / inlier_num[i] / batch_size / 2. + accuracy += 1.0 - tot_err / inlier_num[i] / batch_size / 2.0 else: - accuracy += 1. + accuracy += 1.0 matched_mask = torch.logical_and(torch.eq(err_row, 0), torch.eq(err_col, 0)) matched_mask = torch.logical_and(matched_mask, inlier_mask) @@ -205,11 +255,13 @@ def make_structured_loss(feat_anc, feat_pos, # for the rest, the noise image's score should less than normal image # input: score_map [batch_size, H, W, 1]; indices [2, k, 2] # output: loss [scalar] -def make_noise_score_map_loss(score_map, noise_score_map, indices, batch_size, thld=0.): +def make_noise_score_map_loss( + score_map, noise_score_map, indices, batch_size, thld=0.0 +): H, W = score_map.shape[1:3] loss = 0 for i in range(batch_size): - kpts_coords = indices[i].T # (2, num_kpts) + kpts_coords = indices[i].T # (2, num_kpts) mask = torch.zeros([H, W], device=score_map.device) mask[kpts_coords.cpu().numpy()] = 1 @@ -217,8 +269,13 @@ def make_noise_score_map_loss(score_map, noise_score_map, indices, batch_size, t kernel = torch.ones([1, 1, 3, 3], device=score_map.device) mask = F.conv2d(mask.unsqueeze(0).unsqueeze(0), kernel, padding=1)[0, 0] > 0 - loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask) - loss2 = torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() * torch.logical_not(mask)) / (H * W - torch.sum(mask)) + loss1 = torch.sum( + torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask + ) / torch.sum(mask) + loss2 = torch.sum( + torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() + * torch.logical_not(mask) + ) / (H * W - torch.sum(mask)) loss += loss1 loss += loss2 @@ -229,16 +286,28 @@ def make_noise_score_map_loss(score_map, noise_score_map, indices, batch_size, t return loss, first_mask -def make_noise_score_map_loss_labelmap(score_map, noise_score_map, labelmap, batch_size, thld=0.): +def make_noise_score_map_loss_labelmap( + score_map, noise_score_map, labelmap, batch_size, thld=0.0 +): H, W = score_map.shape[1:3] loss = 0 for i in range(batch_size): # using 3x3 kernel to put kpts' neightborhood area into the mask kernel = torch.ones([1, 1, 3, 3], device=score_map.device) - mask = F.conv2d(labelmap[i].unsqueeze(0).to(score_map.device).float(), kernel, padding=1)[0, 0] > 0 - - loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask) - loss2 = 
torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() * torch.logical_not(mask)) / (H * W - torch.sum(mask)) + mask = ( + F.conv2d( + labelmap[i].unsqueeze(0).to(score_map.device).float(), kernel, padding=1 + )[0, 0] + > 0 + ) + + loss1 = torch.sum( + torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask + ) / torch.sum(mask) + loss2 = torch.sum( + torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() + * torch.logical_not(mask) + ) / (H * W - torch.sum(mask)) loss += loss1 loss += loss2 diff --git a/third_party/DarkFeat/nets/multi_sampler.py b/third_party/DarkFeat/nets/multi_sampler.py index dc400fb2afeb50575cd81d3c01b605bea6db1121..862a6e9e785f826853021c27d5c0fc2cfa2c2f51 100644 --- a/third_party/DarkFeat/nets/multi_sampler.py +++ b/third_party/DarkFeat/nets/multi_sampler.py @@ -5,17 +5,28 @@ import numpy as np from .geom import rnd_sample, interpolate -class MultiSampler (nn.Module): - """ Similar to NghSampler, but doesnt warp the 2nd image. + +class MultiSampler(nn.Module): + """Similar to NghSampler, but doesnt warp the 2nd image. Distance to GT => 0 ... pos_d ... neg_d ... ngh Pixel label => + + + + + + 0 0 - - - - - - - - + Subsample on query side: if > 0, regular grid - < 0, random points + < 0, random points In both cases, the number of query points is = W*H/subq**2 """ - def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None, - maxpool_pos=True, subd_neg=0): + + def __init__( + self, + ngh, + subq=1, + subd=1, + pos_d=0, + neg_d=2, + border=None, + maxpool_pos=True, + subd_neg=0, + ): nn.Module.__init__(self) assert 0 <= pos_d < neg_d <= (ngh if ngh else 99) self.ngh = ngh @@ -26,8 +37,9 @@ class MultiSampler (nn.Module): self.sub_q = subq self.sub_d = subd self.sub_d_neg = subd_neg - if border is None: border = ngh - assert border >= ngh, 'border has to be larger than ngh' + if border is None: + border = ngh + assert border >= ngh, "border has to be larger than ngh" self.border = border self.maxpool_pos = maxpool_pos self.precompute_offsets() @@ -36,22 +48,37 @@ class MultiSampler (nn.Module): pos_d2 = self.pos_d**2 neg_d2 = self.neg_d**2 rad2 = self.ngh**2 - rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple + rad = (self.ngh // self.sub_d) * self.ngh # make an integer multiple pos = [] neg = [] - for j in range(-rad, rad+1, self.sub_d): - for i in range(-rad, rad+1, self.sub_d): - d2 = i*i + j*j - if d2 <= pos_d2: - pos.append( (i,j) ) - elif neg_d2 <= d2 <= rad2: - neg.append( (i,j) ) - - self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t()) - self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t()) - - - def forward(self, feat0, feat1, noise_feat0, noise_feat1, conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=2500): + for j in range(-rad, rad + 1, self.sub_d): + for i in range(-rad, rad + 1, self.sub_d): + d2 = i * i + j * j + if d2 <= pos_d2: + pos.append((i, j)) + elif neg_d2 <= d2 <= rad2: + neg.append((i, j)) + + self.register_buffer("pos_offsets", torch.LongTensor(pos).view(-1, 2).t()) + self.register_buffer("neg_offsets", torch.LongTensor(neg).view(-1, 2).t()) + + def forward( + self, + feat0, + feat1, + noise_feat0, + noise_feat1, + conf0, + conf1, + noise_conf0, + noise_conf1, + pos0, + pos1, + B, + H, + W, + N=2500, + ): pscores_ls, nscores_ls, distractors_ls = [], [], [] valid_feat0_ls = [] noise_pscores_ls, noise_nscores_ls, noise_distractors_ls = [], [], [] @@ -62,58 +89,103 @@ class MultiSampler (nn.Module): mask_ls = [] for i in 
range(B): - tmp_mask = (pos0[i][:, 1] >= self.border) * (pos0[i][:, 1] < W-self.border) \ - * (pos0[i][:, 0] >= self.border) * (pos0[i][:, 0] < H-self.border) + tmp_mask = ( + (pos0[i][:, 1] >= self.border) + * (pos0[i][:, 1] < W - self.border) + * (pos0[i][:, 0] >= self.border) + * (pos0[i][:, 0] < H - self.border) + ) selected_pos0 = pos0[i][tmp_mask] selected_pos1 = pos1[i][tmp_mask] valid_pos0, valid_pos1 = rnd_sample([selected_pos0, selected_pos1], N) # sample features from first image - valid_feat0 = interpolate(valid_pos0 / 4, feat0[i]) # [N, 128] - valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128] + valid_feat0 = interpolate(valid_pos0 / 4, feat0[i]) # [N, 128] + valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128] qconf = interpolate(valid_pos0 / 4, conf0[i]) - valid_noise_feat0 = interpolate(valid_pos0 / 4, noise_feat0[i]) # [N, 128] - valid_noise_feat0 = F.normalize(valid_noise_feat0, p=2, dim=-1) # [N, 128] + valid_noise_feat0 = interpolate(valid_pos0 / 4, noise_feat0[i]) # [N, 128] + valid_noise_feat0 = F.normalize(valid_noise_feat0, p=2, dim=-1) # [N, 128] noise_qconf = interpolate(valid_pos0 / 4, noise_conf0[i]) # sample GT from second image - mask = (valid_pos1[:, 1] >= 0) * (valid_pos1[:, 1] < W) \ - * (valid_pos1[:, 0] >= 0) * (valid_pos1[:, 0] < H) + mask = ( + (valid_pos1[:, 1] >= 0) + * (valid_pos1[:, 1] < W) + * (valid_pos1[:, 0] >= 0) + * (valid_pos1[:, 0] < H) + ) def clamp(xy): xy = xy - torch.clamp(xy[0], 0, H-1, out=xy[0]) - torch.clamp(xy[1], 0, W-1, out=xy[1]) + torch.clamp(xy[0], 0, H - 1, out=xy[0]) + torch.clamp(xy[1], 0, W - 1, out=xy[1]) return xy # compute positive scores - valid_pos1p = clamp(valid_pos1.t()[:,None,:] + self.pos_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N] - valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2] - valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128] - valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128] - valid_noise_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128] - valid_noise_feat1p = F.normalize(valid_noise_feat1p, p=2, dim=-1) # [29, N, 128] - - pscores = (valid_feat0[None,:,:] * valid_feat1p).sum(dim=-1).t() # [N, 29] + valid_pos1p = clamp( + valid_pos1.t()[:, None, :] + + self.pos_offsets[:, :, None].to(valid_pos1.device) + ) # [2, 29, N] + valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape( + -1, 2 + ) # [29, N, 2] -> [29*N, 2] + valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape( + self.pos_offsets.shape[-1], -1, 128 + ) # [29, N, 128] + valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128] + valid_noise_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape( + self.pos_offsets.shape[-1], -1, 128 + ) # [29, N, 128] + valid_noise_feat1p = F.normalize( + valid_noise_feat1p, p=2, dim=-1 + ) # [29, N, 128] + + pscores = ( + (valid_feat0[None, :, :] * valid_feat1p).sum(dim=-1).t() + ) # [N, 29] pscores, pos = pscores.max(dim=1, keepdim=True) - sel = clamp(valid_pos1.t() + self.pos_offsets[:,pos.view(-1)].to(valid_pos1.device)) - qconf = (qconf + interpolate(sel.t() / 4, conf1[i]))/2 - noise_pscores = (valid_noise_feat0[None,:,:] * valid_noise_feat1p).sum(dim=-1).t() # [N, 29] + sel = clamp( + valid_pos1.t() + self.pos_offsets[:, pos.view(-1)].to(valid_pos1.device) + ) + qconf = (qconf + interpolate(sel.t() / 4, conf1[i])) / 2 + noise_pscores = ( + (valid_noise_feat0[None, :, :] * 
valid_noise_feat1p).sum(dim=-1).t() + ) # [N, 29] noise_pscores, noise_pos = noise_pscores.max(dim=1, keepdim=True) - noise_sel = clamp(valid_pos1.t() + self.pos_offsets[:,noise_pos.view(-1)].to(valid_pos1.device)) - noise_qconf = (noise_qconf + interpolate(noise_sel.t() / 4, noise_conf1[i]))/2 + noise_sel = clamp( + valid_pos1.t() + + self.pos_offsets[:, noise_pos.view(-1)].to(valid_pos1.device) + ) + noise_qconf = ( + noise_qconf + interpolate(noise_sel.t() / 4, noise_conf1[i]) + ) / 2 # compute negative scores - valid_pos1n = clamp(valid_pos1.t()[:,None,:] + self.neg_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N] - valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2] - valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128] - valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128] - nscores = (valid_feat0[None,:,:] * valid_feat1n).sum(dim=-1).t() # [N, 29] - valid_noise_feat1n = interpolate(valid_pos1n / 4, noise_feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128] - valid_noise_feat1n = F.normalize(valid_noise_feat1n, p=2, dim=-1) # [29, N, 128] - noise_nscores = (valid_noise_feat0[None,:,:] * valid_noise_feat1n).sum(dim=-1).t() # [N, 29] + valid_pos1n = clamp( + valid_pos1.t()[:, None, :] + + self.neg_offsets[:, :, None].to(valid_pos1.device) + ) # [2, 29, N] + valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape( + -1, 2 + ) # [29, N, 2] -> [29*N, 2] + valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape( + self.neg_offsets.shape[-1], -1, 128 + ) # [29, N, 128] + valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128] + nscores = ( + (valid_feat0[None, :, :] * valid_feat1n).sum(dim=-1).t() + ) # [N, 29] + valid_noise_feat1n = interpolate(valid_pos1n / 4, noise_feat1[i]).reshape( + self.neg_offsets.shape[-1], -1, 128 + ) # [29, N, 128] + valid_noise_feat1n = F.normalize( + valid_noise_feat1n, p=2, dim=-1 + ) # [29, N, 128] + noise_nscores = ( + (valid_noise_feat0[None, :, :] * valid_noise_feat1n).sum(dim=-1).t() + ) # [N, 29] if self.sub_d_neg: valid_pos2 = rnd_sample([selected_pos1], N)[0] @@ -158,15 +230,17 @@ class MultiSampler (nn.Module): dscores = torch.matmul(valid_feat0, distractors.t()) noise_dscores = torch.matmul(valid_noise_feat0, noise_distractors.t()) - dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:,None])**2 + (valid_pos2[:, 0] - valid_pos1[:, 0][:,None])**2 - b = torch.arange(B, device=dscores.device)[:,None].expand(B, N).reshape(-1) - dis2 += (b != b[:,None]).long() * self.neg_d**2 + dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:, None]) ** 2 + ( + valid_pos2[:, 0] - valid_pos1[:, 0][:, None] + ) ** 2 + b = torch.arange(B, device=dscores.device)[:, None].expand(B, N).reshape(-1) + dis2 += (b != b[:, None]).long() * self.neg_d**2 dscores[dis2 < self.neg_d**2] = 0 noise_dscores[dis2 < self.neg_d**2] = 0 scores = torch.cat((pscores, nscores, dscores), dim=1) noise_scores = torch.cat((noise_pscores, noise_nscores, noise_dscores), dim=1) gt = scores.new_zeros(scores.shape, dtype=torch.uint8) - gt[:, :pscores.shape[1]] = 1 + gt[:, : pscores.shape[1]] = 1 return scores, noise_scores, gt, mask, qconf, noise_qconf diff --git a/third_party/DarkFeat/nets/noise_reliability_loss.py b/third_party/DarkFeat/nets/noise_reliability_loss.py index 9efddae149653c225ee7f2c1eb5fed5f92cef15c..cbd69bba727e38efc3ac356168b4041b30c48e05 100644 --- a/third_party/DarkFeat/nets/noise_reliability_loss.py +++ 
b/third_party/DarkFeat/nets/noise_reliability_loss.py @@ -3,14 +3,15 @@ import torch.nn as nn from .reliability_loss import APLoss -class MultiPixelAPLoss (nn.Module): - """ Computes the pixel-wise AP loss: - Given two images and ground-truth optical flow, computes the AP per pixel. - - feat1: (B, C, H, W) pixel-wise features extracted from img1 - feat2: (B, C, H, W) pixel-wise features extracted from img2 - aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2 +class MultiPixelAPLoss(nn.Module): + """Computes the pixel-wise AP loss: + Given two images and ground-truth optical flow, computes the AP per pixel. + + feat1: (B, C, H, W) pixel-wise features extracted from img1 + feat2: (B, C, H, W) pixel-wise features extracted from img2 + aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2 """ + def __init__(self, sampler, nq=20): nn.Module.__init__(self) self.aploss = APLoss(nq, min=0, max=1, euc=False) @@ -20,21 +21,54 @@ class MultiPixelAPLoss (nn.Module): def loss_from_ap(self, ap, rel, noise_ap, noise_rel): dec_ap = torch.clamp(ap - noise_ap, min=0, max=1) - return (1 - ap*noise_rel - (1-noise_rel)*self.base), (1. - dec_ap*(1-noise_rel) - noise_rel*self.dec_base) + return (1 - ap * noise_rel - (1 - noise_rel) * self.base), ( + 1.0 - dec_ap * (1 - noise_rel) - noise_rel * self.dec_base + ) - def forward(self, feat0, feat1, noise_feat0, noise_feat1, conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=1500): + def forward( + self, + feat0, + feat1, + noise_feat0, + noise_feat1, + conf0, + conf1, + noise_conf0, + noise_conf1, + pos0, + pos1, + B, + H, + W, + N=1500, + ): # subsample things - scores, noise_scores, gt, msk, qconf, noise_qconf = self.sampler(feat0, feat1, noise_feat0, noise_feat1, \ - conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=1500) - + scores, noise_scores, gt, msk, qconf, noise_qconf = self.sampler( + feat0, + feat1, + noise_feat0, + noise_feat1, + conf0, + conf1, + noise_conf0, + noise_conf1, + pos0, + pos1, + B, + H, + W, + N=1500, + ) + # compute pixel-wise AP n = qconf.numel() - if n == 0: return 0, 0 - scores, noise_scores, gt = scores.view(n,-1), noise_scores, gt.view(n,-1) + if n == 0: + return 0, 0 + scores, noise_scores, gt = scores.view(n, -1), noise_scores, gt.view(n, -1) ap = self.aploss(scores, gt).view(msk.shape) noise_ap = self.aploss(noise_scores, gt).view(msk.shape) pixel_loss = self.loss_from_ap(ap, qconf, noise_ap, noise_qconf) - + loss = pixel_loss[0][msk].mean(), pixel_loss[1][msk].mean() - return loss \ No newline at end of file + return loss diff --git a/third_party/DarkFeat/nets/reliability_loss.py b/third_party/DarkFeat/nets/reliability_loss.py index 527f9886a2d4785680bac52ff2fa20033b8d8920..bdb3b73f472d915c9fd4c4542cdcab162298de5e 100644 --- a/third_party/DarkFeat/nets/reliability_loss.py +++ b/third_party/DarkFeat/nets/reliability_loss.py @@ -3,15 +3,16 @@ import torch.nn as nn import numpy as np -class APLoss (nn.Module): - """ differentiable AP loss, through quantization. - - Input: (N, M) values in [min, max] - label: (N, M) values in {0, 1} - - Returns: list of query AP (for each n in {1..N}) - Note: typically, you want to minimize 1 - mean(AP) +class APLoss(nn.Module): + """differentiable AP loss, through quantization. 
+ + Input: (N, M) values in [min, max] + label: (N, M) values in {0, 1} + + Returns: list of query AP (for each n in {1..N}) + Note: typically, you want to minimize 1 - mean(AP) """ + def __init__(self, nq=25, min=0, max=1, euc=False): nn.Module.__init__(self) assert isinstance(nq, int) and 2 <= nq <= 100 @@ -21,16 +22,20 @@ class APLoss (nn.Module): self.euc = euc gap = max - min assert gap > 0 - + # init quantizer = non-learnable (fixed) convolution - self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True) - a = (nq-1) / gap - #1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1) + self.quantizer = q = nn.Conv1d(1, 2 * nq, kernel_size=1, bias=True) + a = (nq - 1) / gap + # 1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1) q.weight.data[:nq] = -a - q.bias.data[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1)) # b = 1 + a*(min+x) - #2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1) + q.bias.data[:nq] = torch.from_numpy( + a * min + np.arange(nq, 0, -1) + ) # b = 1 + a*(min+x) + # 2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1) q.weight.data[nq:] = a - q.bias.data[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min) # b = 1 - a*(min+x) + q.bias.data[nq:] = torch.from_numpy( + np.arange(2 - nq, 2, 1) - a * min + ) # b = 1 - a*(min+x) # first and last one are special: just horizontal straight line q.weight.data[0] = q.weight.data[-1] = 0 q.bias.data[0] = q.bias.data[-1] = 1 @@ -39,37 +44,42 @@ class APLoss (nn.Module): N, M = x.shape # print(x.shape, label.shape) if self.euc: # euclidean distance in same range than similarities - x = 1 - torch.sqrt(2.001 - 2*x) + x = 1 - torch.sqrt(2.001 - 2 * x) # quantize all predictions q = self.quantizer(x.unsqueeze(1)) - q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M [1600, 20, 1681] - - nbs = q.sum(dim=-1) # number of samples N x Q = c - rec = (q * label.view(N,1,M).float()).sum(dim=-1) # nb of correct samples = c+ N x Q - prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision - rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1] - - ap = (prec * rec).sum(dim=-1) # per-image AP + q = torch.min(q[:, : self.nq], q[:, self.nq :]).clamp( + min=0 + ) # N x Q x M [1600, 20, 1681] + + nbs = q.sum(dim=-1) # number of samples N x Q = c + rec = (q * label.view(N, 1, M).float()).sum( + dim=-1 + ) # nb of correct samples = c+ N x Q + prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision + rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1] + + ap = (prec * rec).sum(dim=-1) # per-image AP return ap def forward(self, x, label): - assert x.shape == label.shape # N x M + assert x.shape == label.shape # N x M return self.compute_AP(x, label) -class PixelAPLoss (nn.Module): - """ Computes the pixel-wise AP loss: - Given two images and ground-truth optical flow, computes the AP per pixel. - - feat1: (B, C, H, W) pixel-wise features extracted from img1 - feat2: (B, C, H, W) pixel-wise features extracted from img2 - aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2 +class PixelAPLoss(nn.Module): + """Computes the pixel-wise AP loss: + Given two images and ground-truth optical flow, computes the AP per pixel. 
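# A closed-form sketch of the quantised AP computed by APLoss above: scores are
# soft-assigned to nq triangular bins spanning [lo, hi] (highest-similarity bin first),
# per-bin counts of positives and of all samples give a precision/recall curve, and AP
# is the recall-weighted sum of precisions. This mirrors the fixed Conv1d quantiser in
# functional form (the saturated first/last bins are omitted); the function name is
# illustrative.
import torch

def soft_ap(scores, labels, nq=25, lo=0.0, hi=1.0):
    # scores: [N, M] similarities in [lo, hi]; labels: [N, M] in {0, 1}
    delta = (hi - lo) / (nq - 1)
    centers = torch.linspace(hi, lo, nq, device=scores.device)          # best bin first
    q = (1 - (scores[:, None, :] - centers[None, :, None]).abs() / delta).clamp(min=0)
    nbs = q.sum(dim=-1)                                                 # samples per bin   [N, nq]
    rec = (q * labels[:, None, :].float()).sum(dim=-1)                  # positives per bin [N, nq]
    prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1))            # precision at each cut-off
    rec = rec / rec.sum(dim=-1, keepdim=True).clamp(min=1e-16)          # recall increments
    return (prec * rec).sum(dim=-1)                                     # per-query AP      [N]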
+ + feat1: (B, C, H, W) pixel-wise features extracted from img1 + feat2: (B, C, H, W) pixel-wise features extracted from img2 + aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2 """ + def __init__(self, sampler, nq=20): nn.Module.__init__(self) self.aploss = APLoss(nq, min=0, max=1, euc=False) - self.name = 'pixAP' + self.name = "pixAP" self.sampler = sampler def loss_from_ap(self, ap, rel): @@ -77,29 +87,32 @@ class PixelAPLoss (nn.Module): def forward(self, feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200): # subsample things - scores, gt, msk, qconf = self.sampler(feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200) - + scores, gt, msk, qconf = self.sampler( + feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200 + ) + # compute pixel-wise AP n = qconf.numel() - if n == 0: return 0 - scores, gt = scores.view(n,-1), gt.view(n,-1) + if n == 0: + return 0 + scores, gt = scores.view(n, -1), gt.view(n, -1) ap = self.aploss(scores, gt).view(msk.shape) pixel_loss = self.loss_from_ap(ap, qconf) - + loss = pixel_loss[msk].mean() return loss -class ReliabilityLoss (PixelAPLoss): - """ same than PixelAPLoss, but also train a pixel-wise confidence - that this pixel is going to have a good AP. +class ReliabilityLoss(PixelAPLoss): + """same than PixelAPLoss, but also train a pixel-wise confidence + that this pixel is going to have a good AP. """ + def __init__(self, sampler, base=0.5, **kw): PixelAPLoss.__init__(self, sampler, **kw) assert 0 <= base < 1 self.base = base def loss_from_ap(self, ap, rel): - return 1 - ap*rel - (1-rel)*self.base - + return 1 - ap * rel - (1 - rel) * self.base diff --git a/third_party/DarkFeat/nets/sampler.py b/third_party/DarkFeat/nets/sampler.py index b732a3671872d5675be9826f76b0818d3b99d466..7686b24d78eb92b90ee3cafb95ad48966ee0f00f 100644 --- a/third_party/DarkFeat/nets/sampler.py +++ b/third_party/DarkFeat/nets/sampler.py @@ -5,17 +5,28 @@ import numpy as np from .geom import rnd_sample, interpolate -class NghSampler2 (nn.Module): - """ Similar to NghSampler, but doesnt warp the 2nd image. + +class NghSampler2(nn.Module): + """Similar to NghSampler, but doesnt warp the 2nd image. Distance to GT => 0 ... pos_d ... neg_d ... 
ngh Pixel label => + + + + + + 0 0 - - - - - - - - + Subsample on query side: if > 0, regular grid - < 0, random points + < 0, random points In both cases, the number of query points is = W*H/subq**2 """ - def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None, - maxpool_pos=True, subd_neg=0): + + def __init__( + self, + ngh, + subq=1, + subd=1, + pos_d=0, + neg_d=2, + border=None, + maxpool_pos=True, + subd_neg=0, + ): nn.Module.__init__(self) assert 0 <= pos_d < neg_d <= (ngh if ngh else 99) self.ngh = ngh @@ -26,8 +37,9 @@ class NghSampler2 (nn.Module): self.sub_q = subq self.sub_d = subd self.sub_d_neg = subd_neg - if border is None: border = ngh - assert border >= ngh, 'border has to be larger than ngh' + if border is None: + border = ngh + assert border >= ngh, "border has to be larger than ngh" self.border = border self.maxpool_pos = maxpool_pos self.precompute_offsets() @@ -36,39 +48,39 @@ class NghSampler2 (nn.Module): pos_d2 = self.pos_d**2 neg_d2 = self.neg_d**2 rad2 = self.ngh**2 - rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple + rad = (self.ngh // self.sub_d) * self.ngh # make an integer multiple pos = [] neg = [] - for j in range(-rad, rad+1, self.sub_d): - for i in range(-rad, rad+1, self.sub_d): - d2 = i*i + j*j - if d2 <= pos_d2: - pos.append( (i,j) ) - elif neg_d2 <= d2 <= rad2: - neg.append( (i,j) ) + for j in range(-rad, rad + 1, self.sub_d): + for i in range(-rad, rad + 1, self.sub_d): + d2 = i * i + j * j + if d2 <= pos_d2: + pos.append((i, j)) + elif neg_d2 <= d2 <= rad2: + neg.append((i, j)) - self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t()) - self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t()) + self.register_buffer("pos_offsets", torch.LongTensor(pos).view(-1, 2).t()) + self.register_buffer("neg_offsets", torch.LongTensor(neg).view(-1, 2).t()) def gen_grid(self, step, B, H, W, dev): b1 = torch.arange(B, device=dev) if step > 0: # regular grid - x1 = torch.arange(self.border, W-self.border, step, device=dev) - y1 = torch.arange(self.border, H-self.border, step, device=dev) + x1 = torch.arange(self.border, W - self.border, step, device=dev) + y1 = torch.arange(self.border, H - self.border, step, device=dev) H1, W1 = len(y1), len(x1) - x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1) - y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1) - b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1) + x1 = x1[None, None, :].expand(B, H1, W1).reshape(-1) + y1 = y1[None, :, None].expand(B, H1, W1).reshape(-1) + b1 = b1[:, None, None].expand(B, H1, W1).reshape(-1) shape = (B, H1, W1) else: # randomly spread - n = (H - 2*self.border) * (W - 2*self.border) // step**2 - x1 = torch.randint(self.border, W-self.border, (n,), device=dev) - y1 = torch.randint(self.border, H-self.border, (n,), device=dev) - x1 = x1[None,:].expand(B,n).reshape(-1) - y1 = y1[None,:].expand(B,n).reshape(-1) - b1 = b1[:,None].expand(B,n).reshape(-1) + n = (H - 2 * self.border) * (W - 2 * self.border) // step**2 + x1 = torch.randint(self.border, W - self.border, (n,), device=dev) + y1 = torch.randint(self.border, H - self.border, (n,), device=dev) + x1 = x1[None, :].expand(B, n).reshape(-1) + y1 = y1[None, :].expand(B, n).reshape(-1) + b1 = b1[:, None].expand(B, n).reshape(-1) shape = (B, n) return b1, y1, x1, shape @@ -81,45 +93,73 @@ class NghSampler2 (nn.Module): for i in range(B): # positions in the first image - tmp_mask = (pos0[i][:, 1] >= self.border) * (pos0[i][:, 1] < W-self.border) \ - * (pos0[i][:, 0] >= self.border) * 
(pos0[i][:, 0] < H-self.border) + tmp_mask = ( + (pos0[i][:, 1] >= self.border) + * (pos0[i][:, 1] < W - self.border) + * (pos0[i][:, 0] >= self.border) + * (pos0[i][:, 0] < H - self.border) + ) selected_pos0 = pos0[i][tmp_mask] selected_pos1 = pos1[i][tmp_mask] valid_pos0, valid_pos1 = rnd_sample([selected_pos0, selected_pos1], N) # sample features from first image - valid_feat0 = interpolate(valid_pos0 / 4, feat0[i]) # [N, 128] - valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128] + valid_feat0 = interpolate(valid_pos0 / 4, feat0[i]) # [N, 128] + valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128] qconf = interpolate(valid_pos0 / 4, conf0[i]) # sample GT from second image - mask = (valid_pos1[:, 1] >= 0) * (valid_pos1[:, 1] < W) \ - * (valid_pos1[:, 0] >= 0) * (valid_pos1[:, 0] < H) + mask = ( + (valid_pos1[:, 1] >= 0) + * (valid_pos1[:, 1] < W) + * (valid_pos1[:, 0] >= 0) + * (valid_pos1[:, 0] < H) + ) def clamp(xy): xy = xy - torch.clamp(xy[0], 0, H-1, out=xy[0]) - torch.clamp(xy[1], 0, W-1, out=xy[1]) + torch.clamp(xy[0], 0, H - 1, out=xy[0]) + torch.clamp(xy[1], 0, W - 1, out=xy[1]) return xy # compute positive scores - valid_pos1p = clamp(valid_pos1.t()[:,None,:] + self.pos_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N] - valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2] - valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128] - valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128] - - pscores = (valid_feat0[None,:,:] * valid_feat1p).sum(dim=-1).t() # [N, 29] + valid_pos1p = clamp( + valid_pos1.t()[:, None, :] + + self.pos_offsets[:, :, None].to(valid_pos1.device) + ) # [2, 29, N] + valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape( + -1, 2 + ) # [29, N, 2] -> [29*N, 2] + valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape( + self.pos_offsets.shape[-1], -1, 128 + ) # [29, N, 128] + valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128] + + pscores = ( + (valid_feat0[None, :, :] * valid_feat1p).sum(dim=-1).t() + ) # [N, 29] pscores, pos = pscores.max(dim=1, keepdim=True) - sel = clamp(valid_pos1.t() + self.pos_offsets[:,pos.view(-1)].to(valid_pos1.device)) - qconf = (qconf + interpolate(sel.t() / 4, conf1[i]))/2 + sel = clamp( + valid_pos1.t() + self.pos_offsets[:, pos.view(-1)].to(valid_pos1.device) + ) + qconf = (qconf + interpolate(sel.t() / 4, conf1[i])) / 2 # compute negative scores - valid_pos1n = clamp(valid_pos1.t()[:,None,:] + self.neg_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N] - valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2] - valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128] - valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128] - nscores = (valid_feat0[None,:,:] * valid_feat1n).sum(dim=-1).t() # [N, 29] + valid_pos1n = clamp( + valid_pos1.t()[:, None, :] + + self.neg_offsets[:, :, None].to(valid_pos1.device) + ) # [2, 29, N] + valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape( + -1, 2 + ) # [29, N, 2] -> [29*N, 2] + valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape( + self.neg_offsets.shape[-1], -1, 128 + ) # [29, N, 128] + valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128] + nscores = ( + (valid_feat0[None, :, :] * valid_feat1n).sum(dim=-1).t() + ) # [N, 29] if self.sub_d_neg: valid_pos2 = rnd_sample([selected_pos1], N)[0] @@ -148,13 +188,15 @@ class 
NghSampler2 (nn.Module): valid_pos2 = torch.cat([i[:N] for i in valid_pos2_ls], dim=0) dscores = torch.matmul(valid_feat0, distractors.t()) - dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:,None])**2 + (valid_pos2[:, 0] - valid_pos1[:, 0][:,None])**2 - b = torch.arange(B, device=dscores.device)[:,None].expand(B, N).reshape(-1) - dis2 += (b != b[:,None]).long() * self.neg_d**2 + dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:, None]) ** 2 + ( + valid_pos2[:, 0] - valid_pos1[:, 0][:, None] + ) ** 2 + b = torch.arange(B, device=dscores.device)[:, None].expand(B, N).reshape(-1) + dis2 += (b != b[:, None]).long() * self.neg_d**2 dscores[dis2 < self.neg_d**2] = 0 scores = torch.cat((pscores, nscores, dscores), dim=1) - + gt = scores.new_zeros(scores.shape, dtype=torch.uint8) - gt[:, :pscores.shape[1]] = 1 + gt[:, : pscores.shape[1]] = 1 return scores, gt, mask, qconf diff --git a/third_party/DarkFeat/nets/score.py b/third_party/DarkFeat/nets/score.py index a78cf1c893bc338c12803697d55e121a75171f2c..60b255b6d2c9572323460500efd89fb414dee29e 100644 --- a/third_party/DarkFeat/nets/score.py +++ b/third_party/DarkFeat/nets/score.py @@ -8,23 +8,20 @@ from .geom import gather_nd # output: [batch_size, C, H, W], [batch_size, C, H, W] def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1): inputs = inputs / moving_instance_max - + batch_size, C, H, W = inputs.shape pad_size = ksize // 2 + (dilation - 1) kernel = torch.ones([C, 1, ksize, ksize], device=inputs.device) / (ksize * ksize) - - pad_inputs = F.pad(inputs, [pad_size] * 4, mode='reflect') + + pad_inputs = F.pad(inputs, [pad_size] * 4, mode="reflect") avg_spatial_inputs = F.conv2d( - pad_inputs, - kernel, - stride=1, - dilation=dilation, - padding=0, - groups=C + pad_inputs, kernel, stride=1, dilation=dilation, padding=0, groups=C ) - avg_channel_inputs = torch.mean(inputs, axis=1, keepdim=True) # channel dimension is 1 + avg_channel_inputs = torch.mean( + inputs, axis=1, keepdim=True + ) # channel dimension is 1 alpha = F.softplus(inputs - avg_spatial_inputs) beta = F.softplus(inputs - avg_channel_inputs) @@ -40,11 +37,17 @@ def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_si mask = score_map > score_thld if nms_size > 0: - nms_mask = F.max_pool2d(score_map, kernel_size=nms_size, stride=1, padding=nms_size//2) + nms_mask = F.max_pool2d( + score_map, kernel_size=nms_size, stride=1, padding=nms_size // 2 + ) nms_mask = torch.eq(score_map, nms_mask) mask = torch.logical_and(nms_mask, mask) if eof_size > 0: - eof_mask = torch.ones((1, 1, h - 2 * eof_size, w - 2 * eof_size), dtype=torch.float32, device=score_map.device) + eof_mask = torch.ones( + (1, 1, h - 2 * eof_size, w - 2 * eof_size), + dtype=torch.float32, + device=score_map.device, + ) eof_mask = F.pad(eof_mask, [eof_size] * 4, value=0) eof_mask = eof_mask.bool() mask = torch.logical_and(eof_mask, mask) @@ -86,24 +89,29 @@ def edge_mask(inputs, n_channel, dilation=1, edge_thld=5): b, c, h, w = inputs.size() device = inputs.device - dii_filter = torch.tensor( - [[0, 1., 0], [0, -2., 0], [0, 1., 0]] - ).view(1, 1, 3, 3) + dii_filter = torch.tensor([[0, 1.0, 0], [0, -2.0, 0], [0, 1.0, 0]]).view(1, 1, 3, 3) dij_filter = 0.25 * torch.tensor( - [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]] - ).view(1, 1, 3, 3) - djj_filter = torch.tensor( - [[0, 0, 0], [1., -2., 1.], [0, 0, 0]] + [[1.0, 0, -1.0], [0, 0.0, 0], [-1.0, 0, 1.0]] ).view(1, 1, 3, 3) + djj_filter = torch.tensor([[0, 0, 0], [1.0, -2.0, 1.0], [0, 0, 0]]).view(1, 1, 3, 3) dii = F.conv2d( - inputs.view(-1, 1, h, 
w), dii_filter.to(device), padding=dilation, dilation=dilation + inputs.view(-1, 1, h, w), + dii_filter.to(device), + padding=dilation, + dilation=dilation, ).view(b, c, h, w) dij = F.conv2d( - inputs.view(-1, 1, h, w), dij_filter.to(device), padding=dilation, dilation=dilation + inputs.view(-1, 1, h, w), + dij_filter.to(device), + padding=dilation, + dilation=dilation, ).view(b, c, h, w) djj = F.conv2d( - inputs.view(-1, 1, h, w), djj_filter.to(device), padding=dilation, dilation=dilation + inputs.view(-1, 1, h, w), + djj_filter.to(device), + padding=dilation, + dilation=dilation, ).view(b, c, h, w) det = dii * djj - dij * dij diff --git a/third_party/DarkFeat/pose_estimation.py b/third_party/DarkFeat/pose_estimation.py index c87877191e7e31c3bc0a362d7d481dfd5d4b5757..d4ebe66700f895f0d1fac1b21d502b3a7de02325 100644 --- a/third_party/DarkFeat/pose_estimation.py +++ b/third_party/DarkFeat/pose_estimation.py @@ -8,18 +8,28 @@ from tqdm import tqdm def compute_essential(matched_kp1, matched_kp2, K): - pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) - pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) + pts1 = cv2.undistortPoints( + matched_kp1, + cameraMatrix=K, + distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0), + ) + pts2 = cv2.undistortPoints( + matched_kp2, + cameraMatrix=K, + distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0), + ) K_1 = np.eye(3) # Estimate the homography between the matches using RANSAC - ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000) - if ransac_inliers is None or ransac_model.shape != (3,3): + ransac_model, ransac_inliers = cv2.findEssentialMat( + pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000 + ) + if ransac_inliers is None or ransac_model.shape != (3, 3): ransac_inliers = np.array([]) ransac_model = None return ransac_model, ransac_inliers, pts1, pts2 -def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers): +def compute_error(R_GT, t_GT, E, pts1_norm, pts2_norm, inliers): """Compute the angular error between two rotation matrices and two translation vectors. Keyword arguments: R -- 2D numpy array containing an estimated rotation @@ -30,14 +40,14 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers): inliers = inliers.ravel() R = np.eye(3) - t = np.zeros((3,1)) + t = np.zeros((3, 1)) sst = True try: _, R, t, _ = cv2.recoverPose(E, pts1_norm, pts2_norm, np.eye(3), inliers) except: sst = False # calculate angle between provided rotations - # + # if sst: dR = np.matmul(R, np.transpose(R_GT)) dR = cv2.Rodrigues(dR)[0] @@ -48,10 +58,10 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers): dT /= float(np.linalg.norm(t_GT)) if dT > 1 or dT < -1: - print("Domain warning! dT:",dT) - dT = max(-1,min(1,dT)) + print("Domain warning! 
dT:", dT) + dT = max(-1, min(1, dT)) dT = math.acos(dT) * 180 / math.pi - dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation + dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation else: dR, dT = 180.0, 180.0 return dR, dT @@ -59,8 +69,8 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers): def pose_evaluation(result_base_dir, dark_name1, dark_name2, enhancer, K, R_GT, t_GT): try: - m_kp1 = np.load(result_base_dir+enhancer+'/DarkFeat/POINT_1/'+dark_name1) - m_kp2 = np.load(result_base_dir+enhancer+'/DarkFeat/POINT_2/'+dark_name2) + m_kp1 = np.load(result_base_dir + enhancer + "/DarkFeat/POINT_1/" + dark_name1) + m_kp2 = np.load(result_base_dir + enhancer + "/DarkFeat/POINT_2/" + dark_name2) except: return 180.0, 180.0 try: @@ -71,37 +81,37 @@ def pose_evaluation(result_base_dir, dark_name1, dark_name2, enhancer, K, R_GT, return dR, dT -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--histeq', action='store_true') - parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/') + parser.add_argument("--histeq", action="store_true") + parser.add_argument("--dataset_dir", type=str, default="/data/hyz/MID/") opt = parser.parse_args() - + sizer = (960, 640) - focallength_x = 4.504986436499113e+03/(6744/sizer[0]) - focallength_y = 4.513311442889859e+03/(4502/sizer[1]) + focallength_x = 4.504986436499113e03 / (6744 / sizer[0]) + focallength_y = 4.513311442889859e03 / (4502 / sizer[1]) K = np.eye(3) - K[0,0] = focallength_x - K[1,1] = focallength_y - K[0,2] = 3.363322177533149e+03/(6744/sizer[0]) - K[1,2] = 2.291824660547715e+03/(4502/sizer[1]) + K[0, 0] = focallength_x + K[1, 1] = focallength_y + K[0, 2] = 3.363322177533149e03 / (6744 / sizer[0]) + K[1, 2] = 2.291824660547715e03 / (4502 / sizer[1]) Kinv = np.linalg.inv(K) Kinvt = np.transpose(Kinv) PE_MT = np.zeros((6, 8)) - enhancer = 'None' if not opt.histeq else 'HistEQ' + enhancer = "None" if not opt.histeq else "HistEQ" - for scene in ['Indoor', 'Outdoor']: - dir_base = opt.dataset_dir + '/' + scene + '/' - base_save = 'result_errors/' + scene + '/' + for scene in ["Indoor", "Outdoor"]: + dir_base = opt.dataset_dir + "/" + scene + "/" + base_save = "result_errors/" + scene + "/" pair_list = sorted(os.listdir(dir_base)) os.makedirs(base_save, exist_ok=True) for pair in tqdm(pair_list): opention = 1 - if scene == 'Outdoor': + if scene == "Outdoor": pass else: if int(pair[4::]) <= 17: @@ -109,29 +119,43 @@ if __name__ == '__main__': else: pass name = [] - files = sorted(os.listdir(dir_base+pair)) + files = sorted(os.listdir(dir_base + pair)) for file_ in files: - if file_.endswith('.cr2'): + if file_.endswith(".cr2"): name.append(file_[0:9]) - ISO = ['00100', '00200', '00400', '00800', '01600', '03200', '06400', '12800'] + ISO = [ + "00100", + "00200", + "00400", + "00800", + "01600", + "03200", + "06400", + "12800", + ] if opention == 1: - Shutter_speed = ['0.005','0.01','0.025','0.05','0.17','0.5'] + Shutter_speed = ["0.005", "0.01", "0.025", "0.05", "0.17", "0.5"] else: - Shutter_speed = ['0.01','0.02','0.05','0.1','0.3','1'] + Shutter_speed = ["0.01", "0.02", "0.05", "0.1", "0.3", "1"] - E_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'E_estimated.npy') - F_GT = np.dot(np.dot(Kinvt,E_GT),Kinv) - R_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'R_GT.npy') - t_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'T_GT.npy') - result_base_dir ='result/' +scene+'/'+pair+'/' + E_GT = np.load(dir_base + pair + "/GT_Correspondence/" + 
"E_estimated.npy") + F_GT = np.dot(np.dot(Kinvt, E_GT), Kinv) + R_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "R_GT.npy") + t_GT = np.load(dir_base + pair + "/GT_Correspondence/" + "T_GT.npy") + result_base_dir = "result/" + scene + "/" + pair + "/" for iso in ISO: for ex in Shutter_speed: - dark_name1 = name[0]+iso+'_'+ex+'_'+scene+'.npy' - dark_name2 = name[1]+iso+'_'+ex+'_'+scene+'.npy' - - dr, dt = pose_evaluation(result_base_dir,dark_name1,dark_name2,enhancer,K,R_GT,t_GT) - PE_MT[Shutter_speed.index(ex),ISO.index(iso)] = max(dr, dt) - - subprocess.check_output(['mkdir', '-p', base_save + pair + f'/{enhancer}/']) - np.save(base_save + pair + f'/{enhancer}/Pose_error_DarkFeat.npy', PE_MT) - \ No newline at end of file + dark_name1 = name[0] + iso + "_" + ex + "_" + scene + ".npy" + dark_name2 = name[1] + iso + "_" + ex + "_" + scene + ".npy" + + dr, dt = pose_evaluation( + result_base_dir, dark_name1, dark_name2, enhancer, K, R_GT, t_GT + ) + PE_MT[Shutter_speed.index(ex), ISO.index(iso)] = max(dr, dt) + + subprocess.check_output( + ["mkdir", "-p", base_save + pair + f"/{enhancer}/"] + ) + np.save( + base_save + pair + f"/{enhancer}/Pose_error_DarkFeat.npy", PE_MT + ) diff --git a/third_party/DarkFeat/raw_preprocess.py b/third_party/DarkFeat/raw_preprocess.py index 226155a84e97f15782d3650f4ef6b3fa1880e07b..6f51bef8ae45114160214fbc22b1c5cc832c7d42 100644 --- a/third_party/DarkFeat/raw_preprocess.py +++ b/third_party/DarkFeat/raw_preprocess.py @@ -9,54 +9,78 @@ from tqdm import tqdm def process_raw(args, path, w_new, h_new): raw = rawpy.imread(str(path)).raw_image_visible - if '_00200_' in str(path) or '_00100_' in str(path): - raw = np.clip(raw.astype('float32') - 512, 0, 65535) + if "_00200_" in str(path) or "_00100_" in str(path): + raw = np.clip(raw.astype("float32") - 512, 0, 65535) else: - raw = np.clip(raw.astype('float32') - 2048, 0, 65535) - img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, 'RGGB').astype('float32') + raw = np.clip(raw.astype("float32") - 2048, 0, 65535) + img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, "RGGB").astype( + "float32" + ) img = np.clip(img, 0, 16383) # HistEQ start if args.histeq: img2 = np.zeros_like(img) for i in range(3): - hist,bins = np.histogram(img[..., i].flatten(),16384,[0,16384]) + hist, bins = np.histogram(img[..., i].flatten(), 16384, [0, 16384]) cdf = hist.cumsum() cdf_normalized = cdf * float(hist.max()) / cdf.max() - cdf_m = np.ma.masked_equal(cdf,0) - cdf_m = (cdf_m - cdf_m.min())*16383/(cdf_m.max()-cdf_m.min()) - cdf = np.ma.filled(cdf_m,0).astype('uint16') - img2[..., i] = cdf[img[..., i].astype('int16')] - img[..., i] = img2[..., i].astype('float32') + cdf_m = np.ma.masked_equal(cdf, 0) + cdf_m = (cdf_m - cdf_m.min()) * 16383 / (cdf_m.max() - cdf_m.min()) + cdf = np.ma.filled(cdf_m, 0).astype("uint16") + img2[..., i] = cdf[img[..., i].astype("int16")] + img[..., i] = img2[..., i].astype("float32") # HistEQ end m = img.mean() d = np.abs(img - img.mean()).mean() - img = (img - m + 2*d) / 4/d * 255 + img = (img - m + 2 * d) / 4 / d * 255 image = np.clip(img, 0, 255) - image = cv2.resize(image.astype('float32'), (w_new, h_new), interpolation=cv2.INTER_AREA) + image = cv2.resize( + image.astype("float32"), (w_new, h_new), interpolation=cv2.INTER_AREA + ) if args.histeq: - path=str(path) - os.makedirs('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy']), exist_ok=True) - np.save('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy']+[path.split('/')[-1].replace('cr2','npy')]), image) + path = 
str(path) + os.makedirs( + "/".join(path.split("/")[:-2] + [path.split("/")[-2] + "-npy"]), + exist_ok=True, + ) + np.save( + "/".join( + path.split("/")[:-2] + + [path.split("/")[-2] + "-npy"] + + [path.split("/")[-1].replace("cr2", "npy")] + ), + image, + ) else: - path=str(path) - os.makedirs('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy-nohisteq']), exist_ok=True) - np.save('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy-nohisteq']+[path.split('/')[-1].replace('cr2','npy')]), image) + path = str(path) + os.makedirs( + "/".join(path.split("/")[:-2] + [path.split("/")[-2] + "-npy-nohisteq"]), + exist_ok=True, + ) + np.save( + "/".join( + path.split("/")[:-2] + + [path.split("/")[-2] + "-npy-nohisteq"] + + [path.split("/")[-1].replace("cr2", "npy")] + ), + image, + ) -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser() - parser.add_argument('--H', type=int, default=int(640)) - parser.add_argument('--W', type=int, default=int(960)) - parser.add_argument('--histeq', action='store_true') - parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/') + parser.add_argument("--H", type=int, default=int(640)) + parser.add_argument("--W", type=int, default=int(960)) + parser.add_argument("--histeq", action="store_true") + parser.add_argument("--dataset_dir", type=str, default="/data/hyz/MID/") args = parser.parse_args() - path_ls = glob.glob(args.dataset_dir + '/*/pair*/?????/*') + path_ls = glob.glob(args.dataset_dir + "/*/pair*/?????/*") for path in tqdm(path_ls): process_raw(args, path, args.W, args.H) - diff --git a/third_party/DarkFeat/read_error.py b/third_party/DarkFeat/read_error.py index 406b92dbd3877a11e51aebc3a705cd8d8d17e173..9015dfd2954b21115458fa25a2fd278c7cd69596 100644 --- a/third_party/DarkFeat/read_error.py +++ b/third_party/DarkFeat/read_error.py @@ -1,56 +1,80 @@ -import os +import os import numpy as np import subprocess # def ratio(losses, thresholds=[1,2,3,4,5,6,7,8,9,10]): -def ratio(losses, thresholds=[5,10]): - return [ - '{:.3f}'.format(np.mean(losses < threshold)) - for threshold in thresholds - ] +def ratio(losses, thresholds=[5, 10]): + return ["{:.3f}".format(np.mean(losses < threshold)) for threshold in thresholds] -if __name__ == '__main__': - scene = 'Indoor' - dir_base = 'result_errors/Indoor/' - save_pt = 'resultfinal_errors/Indoor/' - subprocess.check_output(['mkdir', '-p', save_pt]) +if __name__ == "__main__": + scene = "Indoor" + dir_base = "result_errors/Indoor/" + save_pt = "resultfinal_errors/Indoor/" - with open(save_pt +'ratio_methods_'+scene+'.txt','w') as f: - f.write('5deg 10deg'+'\n') + subprocess.check_output(["mkdir", "-p", save_pt]) + + with open(save_pt + "ratio_methods_" + scene + ".txt", "w") as f: + f.write("5deg 10deg" + "\n") pair_list = os.listdir(dir_base) - enhancer = os.listdir(dir_base+'/pair9/') + enhancer = os.listdir(dir_base + "/pair9/") for method in enhancer: - pose_error_list = sorted(os.listdir(dir_base+'/pair9/'+method)) + pose_error_list = sorted(os.listdir(dir_base + "/pair9/" + method)) for pose_error in pose_error_list: - error_array = np.expand_dims(np.zeros((6, 8)),axis=2) + error_array = np.expand_dims(np.zeros((6, 8)), axis=2) for pair in pair_list: try: - error = np.expand_dims(np.load(dir_base+'/'+pair+'/'+method+'/'+pose_error),axis=2) + error = np.expand_dims( + np.load( + dir_base + "/" + pair + "/" + method + "/" + pose_error + ), + axis=2, + ) except: - print('error in', dir_base+'/'+pair+'/'+method+'/'+pose_error) + print( + 
"error in", + dir_base + "/" + pair + "/" + method + "/" + pose_error, + ) continue - error_array = np.concatenate((error_array,error),axis=2) - ratio_result = ratio(error_array[:,:,1::].flatten()) - f.write(method + '_' + pose_error[11:-4] +' '+' '.join([str(i) for i in ratio_result])+"\n") + error_array = np.concatenate((error_array, error), axis=2) + ratio_result = ratio(error_array[:, :, 1::].flatten()) + f.write( + method + + "_" + + pose_error[11:-4] + + " " + + " ".join([str(i) for i in ratio_result]) + + "\n" + ) - - scene = 'Outdoor' - dir_base = 'result_errors/Outdoor/' - save_pt = 'resultfinal_errors/Outdoor/' + scene = "Outdoor" + dir_base = "result_errors/Outdoor/" + save_pt = "resultfinal_errors/Outdoor/" - subprocess.check_output(['mkdir', '-p', save_pt]) + subprocess.check_output(["mkdir", "-p", save_pt]) - with open(save_pt +'ratio_methods_'+scene+'.txt','w') as f: - f.write('5deg 10deg'+'\n') + with open(save_pt + "ratio_methods_" + scene + ".txt", "w") as f: + f.write("5deg 10deg" + "\n") pair_list = os.listdir(dir_base) - enhancer = os.listdir(dir_base+'/pair9/') + enhancer = os.listdir(dir_base + "/pair9/") for method in enhancer: - pose_error_list = sorted(os.listdir(dir_base+'/pair9/'+method)) + pose_error_list = sorted(os.listdir(dir_base + "/pair9/" + method)) for pose_error in pose_error_list: - error_array = np.expand_dims(np.zeros((6, 8)),axis=2) + error_array = np.expand_dims(np.zeros((6, 8)), axis=2) for pair in pair_list: - error = np.expand_dims(np.load(dir_base+'/'+pair+'/'+method+'/'+pose_error),axis=2) - error_array = np.concatenate((error_array,error),axis=2) - ratio_result = ratio(error_array[:,:,1::].flatten()) - f.write(method + '_' + pose_error[11:-4] +' '+' '.join([str(i) for i in ratio_result])+"\n") + error = np.expand_dims( + np.load( + dir_base + "/" + pair + "/" + method + "/" + pose_error + ), + axis=2, + ) + error_array = np.concatenate((error_array, error), axis=2) + ratio_result = ratio(error_array[:, :, 1::].flatten()) + f.write( + method + + "_" + + pose_error[11:-4] + + " " + + " ".join([str(i) for i in ratio_result]) + + "\n" + ) diff --git a/third_party/DarkFeat/run.py b/third_party/DarkFeat/run.py index 0e4c87053d2970fc927d8991aa0dab208f3c4917..1cf463d4e0218d66dff0c3637346a12d327d9fda 100644 --- a/third_party/DarkFeat/run.py +++ b/third_party/DarkFeat/run.py @@ -10,39 +10,45 @@ from trainer_single_norel import SingleTrainerNoRel from trainer_single import SingleTrainer -if __name__ == '__main__': +if __name__ == "__main__": # add argument parser parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, default='./configs/config.yaml') - parser.add_argument('--dataset_dir', type=str, default='/mnt/nvme2n1/hyz/data/GL3D') - parser.add_argument('--data_split', type=str, default='comb') - parser.add_argument('--is_training', type=bool, default=True) - parser.add_argument('--job_name', type=str, default='') - parser.add_argument('--gpu', type=str, default='0') - parser.add_argument('--start_cnt', type=int, default=0) - parser.add_argument('--stage', type=int, default=1) + parser.add_argument("--config", type=str, default="./configs/config.yaml") + parser.add_argument("--dataset_dir", type=str, default="/mnt/nvme2n1/hyz/data/GL3D") + parser.add_argument("--data_split", type=str, default="comb") + parser.add_argument("--is_training", type=bool, default=True) + parser.add_argument("--job_name", type=str, default="") + parser.add_argument("--gpu", type=str, default="0") + parser.add_argument("--start_cnt", type=int, 
default=0) + parser.add_argument("--stage", type=int, default=1) args = parser.parse_args() # load global config - with open(args.config, 'r') as f: + with open(args.config, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) # setup dataloader - dataset = GL3DDataset(args.dataset_dir, config['network'], args.data_split, is_training=args.is_training) + dataset = GL3DDataset( + args.dataset_dir, + config["network"], + args.data_split, + is_training=args.is_training, + ) data_loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4) - os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu - + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.stage == 1: - trainer = SingleTrainerNoRel(config, f'cuda:0', data_loader, args.job_name, args.start_cnt) + trainer = SingleTrainerNoRel( + config, f"cuda:0", data_loader, args.job_name, args.start_cnt + ) elif args.stage == 2: - trainer = SingleTrainer(config, f'cuda:0', data_loader, args.job_name, args.start_cnt) + trainer = SingleTrainer( + config, f"cuda:0", data_loader, args.job_name, args.start_cnt + ) elif args.stage == 3: - trainer = Trainer(config, f'cuda:0', data_loader, args.job_name, args.start_cnt) + trainer = Trainer(config, f"cuda:0", data_loader, args.job_name, args.start_cnt) else: raise NotImplementedError() - - trainer.train() - \ No newline at end of file + trainer.train() diff --git a/third_party/DarkFeat/trainer.py b/third_party/DarkFeat/trainer.py index e6ff2af9608e934b6899058d756bb2ab7d0fee2d..1f3bed348f16adf81d3f48ef23563442c7d35fdc 100644 --- a/third_party/DarkFeat/trainer.py +++ b/third_party/DarkFeat/trainer.py @@ -23,23 +23,26 @@ class Trainer: self.config = config self.device = device self.loader = loader - + # tensorboard writer construction - os.makedirs('./runs/', exist_ok=True) - if job_name != '': - self.log_dir = f'runs/{job_name}' + os.makedirs("./runs/", exist_ok=True) + if job_name != "": + self.log_dir = f"runs/{job_name}" else: self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}' self.writer = SummaryWriter(self.log_dir) - with open(f'{self.log_dir}/config.yaml', 'w') as f: + with open(f"{self.log_dir}/config.yaml", "w") as f: yaml.dump(config, f) - if config['network']['input_type'] == 'gray': + if config["network"]["input_type"] == "gray": self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device) - elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic': + elif ( + config["network"]["input_type"] == "rgb" + or config["network"]["input_type"] == "raw-demosaic" + ): self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device) - elif config['network']['input_type'] == 'raw': + elif config["network"]["input_type"] == "raw": self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device) else: raise NotImplementedError() @@ -49,80 +52,104 @@ class Trainer: # reliability map conv self.model.clf = nn.Conv2d(128, 2, kernel_size=1).cuda() - + # load model self.cnt = 0 if start_cnt != 0: - self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth', map_location=device)) + self.model.load_state_dict( + torch.load( + f"{self.log_dir}/model_{start_cnt:06d}.pth", map_location=device + ) + ) self.cnt = start_cnt + 1 # sampler - sampler = MultiSampler(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16, - subd_neg=-8,maxpool_pos=True).to(device) + sampler = MultiSampler( + ngh=7, + subq=-8, + subd=1, + pos_d=3, + neg_d=5, + border=16, + subd_neg=-8, + maxpool_pos=True, + ).to(device) 
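# Standalone sketch of the offset rings that MultiSampler/NghSampler2 precompute:
# pixels within pos_d of the ground-truth match are treated as positives, pixels
# between neg_d and ngh as negatives, and the gap in between is ignored. With the
# values used just above (ngh=7, sub_d=1, pos_d=3, neg_d=5) this yields 29 positive
# offsets, matching the "[N, 29]" shape comments in the samplers. The helper name is
# illustrative.
def offset_rings(ngh=7, sub_d=1, pos_d=3, neg_d=5):
    pos, neg = [], []
    rad = (ngh // sub_d) * ngh
    for j in range(-rad, rad + 1, sub_d):
        for i in range(-rad, rad + 1, sub_d):
            d2 = i * i + j * j
            if d2 <= pos_d ** 2:
                pos.append((i, j))
            elif neg_d ** 2 <= d2 <= ngh ** 2:
                neg.append((i, j))
    return pos, neg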
self.reliability_relitive_loss = MultiPixelAPLoss(sampler, nq=20).to(device) - # optimizer and scheduler - if self.config['training']['optimizer'] == 'SGD': + if self.config["training"]["optimizer"] == "SGD": self.optimizer = torch.optim.SGD( - [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], - lr=self.config['training']['lr'], - momentum=self.config['training']['momentum'], - weight_decay=self.config['training']['weight_decay'], + [ + { + "params": self.model.parameters(), + "initial_lr": self.config["training"]["lr"], + } + ], + lr=self.config["training"]["lr"], + momentum=self.config["training"]["momentum"], + weight_decay=self.config["training"]["weight_decay"], ) - elif self.config['training']['optimizer'] == 'Adam': + elif self.config["training"]["optimizer"] == "Adam": self.optimizer = torch.optim.Adam( - [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], - lr=self.config['training']['lr'], - weight_decay=self.config['training']['weight_decay'] + [ + { + "params": self.model.parameters(), + "initial_lr": self.config["training"]["lr"], + } + ], + lr=self.config["training"]["lr"], + weight_decay=self.config["training"]["weight_decay"], ) else: raise NotImplementedError() self.lr_scheduler = torch.optim.lr_scheduler.StepLR( self.optimizer, - step_size=self.config['training']['lr_step'], - gamma=self.config['training']['lr_gamma'], - last_epoch=start_cnt + step_size=self.config["training"]["lr_step"], + gamma=self.config["training"]["lr_gamma"], + last_epoch=start_cnt, ) for param_tensor in self.model.state_dict(): print(param_tensor, "\t", self.model.state_dict()[param_tensor].size()) - def save(self, iter_num): - torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth') + torch.save(self.model.state_dict(), f"{self.log_dir}/model_{iter_num:06d}.pth") def load(self, path): self.model.load_state_dict(torch.load(path)) def train(self): self.model.train() - + for epoch in range(2): for batch_idx, inputs in enumerate(self.loader): self.optimizer.zero_grad() t = time.time() # preprocess and add noise - img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt) - img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt) + img0_ori, noise_img0_ori = self.preprocess_noise_pair( + inputs["img0"], self.cnt + ) + img1_ori, noise_img1_ori = self.preprocess_noise_pair( + inputs["img1"], self.cnt + ) img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device) img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device) noise_img0 = noise_img0_ori.permute(0, 3, 1, 2).float().to(self.device) noise_img1 = noise_img1_ori.permute(0, 3, 1, 2).float().to(self.device) - if self.config['network']['input_type'] == 'rgb': + if self.config["network"]["input_type"] == "rgb": # 3-channel rgb RGB_mean = [0.485, 0.456, 0.406] - RGB_std = [0.229, 0.224, 0.225] + RGB_std = [0.229, 0.224, 0.225] norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std) img0 = norm_RGB(img0) img1 = norm_RGB(img1) noise_img0 = norm_RGB(noise_img0) noise_img1 = norm_RGB(noise_img1) - elif self.config['network']['input_type'] == 'gray': + elif self.config["network"]["input_type"] == "gray": # 1-channel img0 = torch.mean(img0, dim=1, keepdim=True) img1 = torch.mean(img1, dim=1, keepdim=True) @@ -135,11 +162,11 @@ class Trainer: noise_img0 = norm_gray0(noise_img0) noise_img1 = norm_gray1(noise_img1) - elif self.config['network']['input_type'] == 'raw': + elif self.config["network"]["input_type"] == "raw": # 
4-channel pass - elif self.config['network']['input_type'] == 'raw-demosaic': + elif self.config["network"]["input_type"] == "raw-demosaic": # 3-channel pass @@ -149,14 +176,26 @@ class Trainer: desc0, score_map0, _, _ = self.model(img0) desc1, score_map1, _, _ = self.model(img1) - conf0 = F.softmax(self.model.clf(torch.abs(desc0)**2.0), dim=1)[:,1:2] - conf1 = F.softmax(self.model.clf(torch.abs(desc1)**2.0), dim=1)[:,1:2] + conf0 = F.softmax(self.model.clf(torch.abs(desc0) ** 2.0), dim=1)[ + :, 1:2 + ] + conf1 = F.softmax(self.model.clf(torch.abs(desc1) ** 2.0), dim=1)[ + :, 1:2 + ] - noise_desc0, noise_score_map0, noise_at0, noise_att0 = self.model(noise_img0) - noise_desc1, noise_score_map1, noise_at1, noise_att1 = self.model(noise_img1) + noise_desc0, noise_score_map0, noise_at0, noise_att0 = self.model( + noise_img0 + ) + noise_desc1, noise_score_map1, noise_at1, noise_att1 = self.model( + noise_img1 + ) - noise_conf0 = F.softmax(self.model.clf(torch.abs(noise_desc0)**2.0), dim=1)[:,1:2] - noise_conf1 = F.softmax(self.model.clf(torch.abs(noise_desc1)**2.0), dim=1)[:,1:2] + noise_conf0 = F.softmax( + self.model.clf(torch.abs(noise_desc0) ** 2.0), dim=1 + )[:, 1:2] + noise_conf1 = F.softmax( + self.model.clf(torch.abs(noise_desc1) ** 2.0), dim=1 + )[:, 1:2] cur_feat_size0 = torch.tensor(score_map0.shape[2:]) cur_feat_size1 = torch.tensor(score_map1.shape[2:]) @@ -174,71 +213,128 @@ class Trainer: noise_conf0 = noise_conf0.permute(0, 2, 3, 1) noise_conf1 = noise_conf1.permute(0, 2, 3, 1) - r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device) - r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device) - + r_K0 = getK(inputs["ori_img_size0"], cur_feat_size0, inputs["K0"]).to( + self.device + ) + r_K1 = getK(inputs["ori_img_size1"], cur_feat_size1, inputs["K1"]).to( + self.device + ) + pos0 = _grid_positions( - cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device) + cur_feat_size0[0], cur_feat_size0[1], img0.shape[0] + ).to(self.device) pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate( - pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), - r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + pos0, + inputs["rel_pose"].to(self.device), + inputs["depth0"].to(self.device), + r_K0, + inputs["depth1"].to(self.device), + r_K1, + img0.shape[0], + ) pos0, pos1, _ = getWarp( - pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), - r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + pos0, + inputs["rel_pose"].to(self.device), + inputs["depth0"].to(self.device), + r_K0, + inputs["depth1"].to(self.device), + r_K1, + img0.shape[0], + ) - reliab_loss_relative = self.reliability_relitive_loss(desc0, desc1, noise_desc0, noise_desc1, conf0, conf1, noise_conf0, noise_conf1, pos0_for_rel, pos1_for_rel, img0.shape[0], img0.shape[2], img0.shape[3]) + reliab_loss_relative = self.reliability_relitive_loss( + desc0, + desc1, + noise_desc0, + noise_desc1, + conf0, + conf1, + noise_conf0, + noise_conf1, + pos0_for_rel, + pos1_for_rel, + img0.shape[0], + img0.shape[2], + img0.shape[3], + ) det_structured_loss, det_accuracy = make_detector_loss( - pos0, pos1, desc0, desc1, - score_map0, score_map1, img0.shape[0], - self.config['network']['use_corr_n'], - self.config['network']['loss_type'], - self.config + pos0, + pos1, + desc0, + desc1, + score_map0, + score_map1, + img0.shape[0], + self.config["network"]["use_corr_n"], + self.config["network"]["loss_type"], + self.config, ) 
det_structured_loss_noise, det_accuracy_noise = make_detector_loss( - pos0, pos1, noise_desc0, noise_desc1, - noise_score_map0, noise_score_map1, img0.shape[0], - self.config['network']['use_corr_n'], - self.config['network']['loss_type'], - self.config + pos0, + pos1, + noise_desc0, + noise_desc1, + noise_score_map0, + noise_score_map1, + img0.shape[0], + self.config["network"]["use_corr_n"], + self.config["network"]["loss_type"], + self.config, ) indices0, scores0 = extract_kpts( score_map0.permute(0, 3, 1, 2), - k=self.config['network']['det']['kpt_n'], - score_thld=self.config['network']['det']['score_thld'], - nms_size=self.config['network']['det']['nms_size'], - eof_size=self.config['network']['det']['eof_size'], - edge_thld=self.config['network']['det']['edge_thld'] + k=self.config["network"]["det"]["kpt_n"], + score_thld=self.config["network"]["det"]["score_thld"], + nms_size=self.config["network"]["det"]["nms_size"], + eof_size=self.config["network"]["det"]["eof_size"], + edge_thld=self.config["network"]["det"]["edge_thld"], ) indices1, scores1 = extract_kpts( score_map1.permute(0, 3, 1, 2), - k=self.config['network']['det']['kpt_n'], - score_thld=self.config['network']['det']['score_thld'], - nms_size=self.config['network']['det']['nms_size'], - eof_size=self.config['network']['det']['eof_size'], - edge_thld=self.config['network']['det']['edge_thld'] + k=self.config["network"]["det"]["kpt_n"], + score_thld=self.config["network"]["det"]["score_thld"], + nms_size=self.config["network"]["det"]["nms_size"], + eof_size=self.config["network"]["det"]["eof_size"], + edge_thld=self.config["network"]["det"]["edge_thld"], ) - noise_score_loss0, mask0 = make_noise_score_map_loss(score_map0, noise_score_map0, indices0, img0.shape[0], thld=0.1) - noise_score_loss1, mask1 = make_noise_score_map_loss(score_map1, noise_score_map1, indices1, img1.shape[0], thld=0.1) + noise_score_loss0, mask0 = make_noise_score_map_loss( + score_map0, noise_score_map0, indices0, img0.shape[0], thld=0.1 + ) + noise_score_loss1, mask1 = make_noise_score_map_loss( + score_map1, noise_score_map1, indices1, img1.shape[0], thld=0.1 + ) total_loss = det_structured_loss + det_structured_loss_noise - total_loss += noise_score_loss0 / 2. * 1. - total_loss += noise_score_loss1 / 2. * 1. - total_loss += reliab_loss_relative[0] / 2. * 0.5 - total_loss += reliab_loss_relative[1] / 2. 
* 0.5 - + total_loss += noise_score_loss0 / 2.0 * 1.0 + total_loss += noise_score_loss1 / 2.0 * 1.0 + total_loss += reliab_loss_relative[0] / 2.0 * 0.5 + total_loss += reliab_loss_relative[1] / 2.0 * 0.5 + self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt) self.writer.add_scalar("acc/noise_acc", det_accuracy_noise, self.cnt) self.writer.add_scalar("loss/total_loss", total_loss, self.cnt) - self.writer.add_scalar("loss/noise_score_loss", (noise_score_loss0 + noise_score_loss1) / 2., self.cnt) - self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt) - self.writer.add_scalar("loss/det_loss_noise", det_structured_loss_noise, self.cnt) - print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t)) + self.writer.add_scalar( + "loss/noise_score_loss", + (noise_score_loss0 + noise_score_loss1) / 2.0, + self.cnt, + ) + self.writer.add_scalar( + "loss/det_loss_normal", det_structured_loss, self.cnt + ) + self.writer.add_scalar( + "loss/det_loss_noise", det_structured_loss_noise, self.cnt + ) + print( + "iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter".format( + self.cnt, total_loss, det_accuracy, time.time() - t + ) + ) # print(f'normal_loss: {det_structured_loss}, noise_loss: {det_structured_loss_noise}, reliab_loss: {reliab_loss_relative[0]}, {reliab_loss_relative[1]}') if det_structured_loss != 0: @@ -249,100 +345,162 @@ class Trainer: if self.cnt % 100 == 0: noise_indices0, noise_scores0 = extract_kpts( noise_score_map0.permute(0, 3, 1, 2), - k=self.config['network']['det']['kpt_n'], - score_thld=self.config['network']['det']['score_thld'], - nms_size=self.config['network']['det']['nms_size'], - eof_size=self.config['network']['det']['eof_size'], - edge_thld=self.config['network']['det']['edge_thld'] + k=self.config["network"]["det"]["kpt_n"], + score_thld=self.config["network"]["det"]["score_thld"], + nms_size=self.config["network"]["det"]["nms_size"], + eof_size=self.config["network"]["det"]["eof_size"], + edge_thld=self.config["network"]["det"]["edge_thld"], ) noise_indices1, noise_scores1 = extract_kpts( noise_score_map1.permute(0, 3, 1, 2), - k=self.config['network']['det']['kpt_n'], - score_thld=self.config['network']['det']['score_thld'], - nms_size=self.config['network']['det']['nms_size'], - eof_size=self.config['network']['det']['eof_size'], - edge_thld=self.config['network']['det']['edge_thld'] + k=self.config["network"]["det"]["kpt_n"], + score_thld=self.config["network"]["det"]["score_thld"], + nms_size=self.config["network"]["det"]["nms_size"], + eof_size=self.config["network"]["det"]["eof_size"], + edge_thld=self.config["network"]["det"]["edge_thld"], ) - if self.config['network']['input_type'] == 'raw': - kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0]) - kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0]) - noise_kpt_img0 = self.showKeyPoints(noise_img0_ori[0][..., :3] * 255., noise_indices0[0]) - noise_kpt_img1 = self.showKeyPoints(noise_img1_ori[0][..., :3] * 255., noise_indices1[0]) + if self.config["network"]["input_type"] == "raw": + kpt_img0 = self.showKeyPoints( + img0_ori[0][..., :3] * 255.0, indices0[0] + ) + kpt_img1 = self.showKeyPoints( + img1_ori[0][..., :3] * 255.0, indices1[0] + ) + noise_kpt_img0 = self.showKeyPoints( + noise_img0_ori[0][..., :3] * 255.0, noise_indices0[0] + ) + noise_kpt_img1 = self.showKeyPoints( + noise_img1_ori[0][..., :3] * 255.0, noise_indices1[0] + ) else: - kpt_img0 = self.showKeyPoints(img0_ori[0] * 
255., indices0[0]) - kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0]) - noise_kpt_img0 = self.showKeyPoints(noise_img0_ori[0] * 255., noise_indices0[0]) - noise_kpt_img1 = self.showKeyPoints(noise_img1_ori[0] * 255., noise_indices1[0]) - - self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC') - self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC') - self.writer.add_image('img0/noise_kpts', noise_kpt_img0, self.cnt, dataformats='HWC') - self.writer.add_image('img1/noise_kpts', noise_kpt_img1, self.cnt, dataformats='HWC') - self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC') - self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC') - self.writer.add_image('img0/noise_score_map', noise_score_map0[0], self.cnt, dataformats='HWC') - self.writer.add_image('img1/noise_score_map', noise_score_map1[0], self.cnt, dataformats='HWC') - self.writer.add_image('img0/kpt_mask', mask0.unsqueeze(2), self.cnt, dataformats='HWC') - self.writer.add_image('img1/kpt_mask', mask1.unsqueeze(2), self.cnt, dataformats='HWC') - self.writer.add_image('img0/conf', conf0[0], self.cnt, dataformats='HWC') - self.writer.add_image('img1/conf', conf1[0], self.cnt, dataformats='HWC') - self.writer.add_image('img0/noise_conf', noise_conf0[0], self.cnt, dataformats='HWC') - self.writer.add_image('img1/noise_conf', noise_conf1[0], self.cnt, dataformats='HWC') + kpt_img0 = self.showKeyPoints(img0_ori[0] * 255.0, indices0[0]) + kpt_img1 = self.showKeyPoints(img1_ori[0] * 255.0, indices1[0]) + noise_kpt_img0 = self.showKeyPoints( + noise_img0_ori[0] * 255.0, noise_indices0[0] + ) + noise_kpt_img1 = self.showKeyPoints( + noise_img1_ori[0] * 255.0, noise_indices1[0] + ) + + self.writer.add_image( + "img0/kpts", kpt_img0, self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/kpts", kpt_img1, self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img0/noise_kpts", noise_kpt_img0, self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/noise_kpts", noise_kpt_img1, self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img0/score_map", score_map0[0], self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/score_map", score_map1[0], self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img0/noise_score_map", + noise_score_map0[0], + self.cnt, + dataformats="HWC", + ) + self.writer.add_image( + "img1/noise_score_map", + noise_score_map1[0], + self.cnt, + dataformats="HWC", + ) + self.writer.add_image( + "img0/kpt_mask", mask0.unsqueeze(2), self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/kpt_mask", mask1.unsqueeze(2), self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img0/conf", conf0[0], self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/conf", conf1[0], self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img0/noise_conf", noise_conf0[0], self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/noise_conf", noise_conf1[0], self.cnt, dataformats="HWC" + ) if self.cnt % 5000 == 0: self.save(self.cnt) - - self.cnt += 1 + self.cnt += 1 def showKeyPoints(self, img, indices): key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1]) - img = img.numpy().astype('uint8') + img = img.numpy().astype("uint8") img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0)) return img - def preprocess(self, img, iter_idx): - if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']: + if ( 
+ not self.config["network"]["noise"] + and "raw" not in self.config["network"]["input_type"] + ): return img raw = self.noise_maker.rgb2raw(img, batched=True) - if self.config['network']['noise']: - ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] + if self.config["network"]["noise"]: + ratio_dec = ( + min(self.config["network"]["noise_maxstep"], iter_idx) + / self.config["network"]["noise_maxstep"] + ) raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) - if self.config['network']['input_type'] == 'raw': + if self.config["network"]["input_type"] == "raw": return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)) - if self.config['network']['input_type'] == 'raw-demosaic': + if self.config["network"]["input_type"] == "raw-demosaic": return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)) rgb = self.noise_maker.raw2rgb(raw, batched=True) - if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + if ( + self.config["network"]["input_type"] == "rgb" + or self.config["network"]["input_type"] == "gray" + ): return torch.tensor(rgb) raise NotImplementedError() - def preprocess_noise_pair(self, img, iter_idx): - assert self.config['network']['noise'] + assert self.config["network"]["noise"] raw = self.noise_maker.rgb2raw(img, batched=True) - ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] - noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) + ratio_dec = ( + min(self.config["network"]["noise_maxstep"], iter_idx) + / self.config["network"]["noise_maxstep"] + ) + noise_raw = self.noise_maker.raw2noisyRaw( + raw, ratio_dec=ratio_dec, batched=True + ) - if self.config['network']['input_type'] == 'raw': - return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \ - torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True)) + if self.config["network"]["input_type"] == "raw": + return torch.tensor( + self.noise_maker.raw2packedRaw(raw, batched=True) + ), torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True)) - if self.config['network']['input_type'] == 'raw-demosaic': - return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \ - torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)) + if self.config["network"]["input_type"] == "raw-demosaic": + return torch.tensor( + self.noise_maker.raw2demosaicRaw(raw, batched=True) + ), torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)) noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True) - if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + if ( + self.config["network"]["input_type"] == "rgb" + or self.config["network"]["input_type"] == "gray" + ): return img, torch.tensor(noise_rgb) raise NotImplementedError() diff --git a/third_party/DarkFeat/trainer_single.py b/third_party/DarkFeat/trainer_single.py index 65566e7e27cfd605eba000d308b6d3610f29e746..0b079d1fc376b3dbd45297902c4d1e195c267156 100644 --- a/third_party/DarkFeat/trainer_single.py +++ b/third_party/DarkFeat/trainer_single.py @@ -24,23 +24,29 @@ class SingleTrainer: self.config = config self.device = device self.loader = loader - + # tensorboard writer construction - os.makedirs('./runs/', exist_ok=True) - if job_name != '': - self.log_dir = f'runs/{job_name}' + os.makedirs("./runs/", exist_ok=True) + if 
job_name != "": + self.log_dir = f"runs/{job_name}" else: self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}' self.writer = SummaryWriter(self.log_dir) - with open(f'{self.log_dir}/config.yaml', 'w') as f: + with open(f"{self.log_dir}/config.yaml", "w") as f: yaml.dump(config, f) - if config['network']['input_type'] == 'gray' or config['network']['input_type'] == 'raw-gray': + if ( + config["network"]["input_type"] == "gray" + or config["network"]["input_type"] == "raw-gray" + ): self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device) - elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic': + elif ( + config["network"]["input_type"] == "rgb" + or config["network"]["input_type"] == "raw-demosaic" + ): self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device) - elif config['network']['input_type'] == 'raw': + elif config["network"]["input_type"] == "raw": self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device) else: raise NotImplementedError() @@ -51,75 +57,98 @@ class SingleTrainer: # load model self.cnt = 0 if start_cnt != 0: - self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth')) + self.model.load_state_dict( + torch.load(f"{self.log_dir}/model_{start_cnt:06d}.pth") + ) self.cnt = start_cnt + 1 # sampler - sampler = NghSampler2(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16, - subd_neg=-8,maxpool_pos=True).to(device) + sampler = NghSampler2( + ngh=7, + subq=-8, + subd=1, + pos_d=3, + neg_d=5, + border=16, + subd_neg=-8, + maxpool_pos=True, + ).to(device) self.reliability_loss = ReliabilityLoss(sampler, base=0.3, nq=20).to(device) # reliability map conv self.model.clf = nn.Conv2d(128, 2, kernel_size=1).cuda() # optimizer and scheduler - if self.config['training']['optimizer'] == 'SGD': + if self.config["training"]["optimizer"] == "SGD": self.optimizer = torch.optim.SGD( - [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], - lr=self.config['training']['lr'], - momentum=self.config['training']['momentum'], - weight_decay=self.config['training']['weight_decay'], + [ + { + "params": self.model.parameters(), + "initial_lr": self.config["training"]["lr"], + } + ], + lr=self.config["training"]["lr"], + momentum=self.config["training"]["momentum"], + weight_decay=self.config["training"]["weight_decay"], ) - elif self.config['training']['optimizer'] == 'Adam': + elif self.config["training"]["optimizer"] == "Adam": self.optimizer = torch.optim.Adam( - [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], - lr=self.config['training']['lr'], - weight_decay=self.config['training']['weight_decay'] + [ + { + "params": self.model.parameters(), + "initial_lr": self.config["training"]["lr"], + } + ], + lr=self.config["training"]["lr"], + weight_decay=self.config["training"]["weight_decay"], ) else: raise NotImplementedError() self.lr_scheduler = torch.optim.lr_scheduler.StepLR( self.optimizer, - step_size=self.config['training']['lr_step'], - gamma=self.config['training']['lr_gamma'], - last_epoch=start_cnt + step_size=self.config["training"]["lr_step"], + gamma=self.config["training"]["lr_gamma"], + last_epoch=start_cnt, ) for param_tensor in self.model.state_dict(): print(param_tensor, "\t", self.model.state_dict()[param_tensor].size()) - def save(self, iter_num): - torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth') + torch.save(self.model.state_dict(), 
f"{self.log_dir}/model_{iter_num:06d}.pth") def load(self, path): self.model.load_state_dict(torch.load(path)) def train(self): self.model.train() - + for epoch in range(2): for batch_idx, inputs in enumerate(self.loader): self.optimizer.zero_grad() t = time.time() # preprocess and add noise - img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt) - img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt) + img0_ori, noise_img0_ori = self.preprocess_noise_pair( + inputs["img0"], self.cnt + ) + img1_ori, noise_img1_ori = self.preprocess_noise_pair( + inputs["img1"], self.cnt + ) img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device) img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device) - if self.config['network']['input_type'] == 'rgb': + if self.config["network"]["input_type"] == "rgb": # 3-channel rgb RGB_mean = [0.485, 0.456, 0.406] - RGB_std = [0.229, 0.224, 0.225] + RGB_std = [0.229, 0.224, 0.225] norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std) img0 = norm_RGB(img0) img1 = norm_RGB(img1) noise_img0 = norm_RGB(noise_img0) noise_img1 = norm_RGB(noise_img1) - elif self.config['network']['input_type'] == 'gray': + elif self.config["network"]["input_type"] == "gray": # 1-channel img0 = torch.mean(img0, dim=1, keepdim=True) img1 = torch.mean(img1, dim=1, keepdim=True) @@ -132,11 +161,11 @@ class SingleTrainer: noise_img0 = norm_gray0(noise_img0) noise_img1 = norm_gray1(noise_img1) - elif self.config['network']['input_type'] == 'raw': + elif self.config["network"]["input_type"] == "raw": # 4-channel pass - elif self.config['network']['input_type'] == 'raw-demosaic': + elif self.config["network"]["input_type"] == "raw-demosaic": # 3-channel pass @@ -149,8 +178,12 @@ class SingleTrainer: cur_feat_size0 = torch.tensor(score_map0.shape[2:]) cur_feat_size1 = torch.tensor(score_map1.shape[2:]) - conf0 = F.softmax(self.model.clf(torch.abs(desc0)**2.0), dim=1)[:,1:2] - conf1 = F.softmax(self.model.clf(torch.abs(desc1)**2.0), dim=1)[:,1:2] + conf0 = F.softmax(self.model.clf(torch.abs(desc0) ** 2.0), dim=1)[ + :, 1:2 + ] + conf1 = F.softmax(self.model.clf(torch.abs(desc1) ** 2.0), dim=1)[ + :, 1:2 + ] desc0 = desc0.permute(0, 2, 3, 1) desc1 = desc1.permute(0, 2, 3, 1) @@ -159,39 +192,77 @@ class SingleTrainer: conf0 = conf0.permute(0, 2, 3, 1) conf1 = conf1.permute(0, 2, 3, 1) - r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device) - r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device) - + r_K0 = getK(inputs["ori_img_size0"], cur_feat_size0, inputs["K0"]).to( + self.device + ) + r_K1 = getK(inputs["ori_img_size1"], cur_feat_size1, inputs["K1"]).to( + self.device + ) + pos0 = _grid_positions( - cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device) + cur_feat_size0[0], cur_feat_size0[1], img0.shape[0] + ).to(self.device) pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate( - pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), - r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + pos0, + inputs["rel_pose"].to(self.device), + inputs["depth0"].to(self.device), + r_K0, + inputs["depth1"].to(self.device), + r_K1, + img0.shape[0], + ) pos0, pos1, _ = getWarp( - pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), - r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + pos0, + inputs["rel_pose"].to(self.device), + inputs["depth0"].to(self.device), + r_K0, + inputs["depth1"].to(self.device), + r_K1, + 
img0.shape[0], + ) - reliab_loss = self.reliability_loss(desc0, desc1, conf0, conf1, pos0_for_rel, pos1_for_rel, img0.shape[0], img0.shape[2], img0.shape[3]) + reliab_loss = self.reliability_loss( + desc0, + desc1, + conf0, + conf1, + pos0_for_rel, + pos1_for_rel, + img0.shape[0], + img0.shape[2], + img0.shape[3], + ) det_structured_loss, det_accuracy = make_detector_loss( - pos0, pos1, desc0, desc1, - score_map0, score_map1, img0.shape[0], - self.config['network']['use_corr_n'], - self.config['network']['loss_type'], - self.config + pos0, + pos1, + desc0, + desc1, + score_map0, + score_map1, + img0.shape[0], + self.config["network"]["use_corr_n"], + self.config["network"]["loss_type"], + self.config, ) total_loss = det_structured_loss - self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt) - + self.writer.add_scalar( + "loss/det_loss_normal", det_structured_loss, self.cnt + ) + total_loss += reliab_loss - + self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt) self.writer.add_scalar("loss/total_loss", total_loss, self.cnt) self.writer.add_scalar("loss/reliab_loss", reliab_loss, self.cnt) - print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t)) + print( + "iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter".format( + self.cnt, total_loss, det_accuracy, time.time() - t + ) + ) if det_structured_loss != 0: total_loss.backward() @@ -201,94 +272,133 @@ class SingleTrainer: if self.cnt % 100 == 0: indices0, scores0 = extract_kpts( score_map0.permute(0, 3, 1, 2), - k=self.config['network']['det']['kpt_n'], - score_thld=self.config['network']['det']['score_thld'], - nms_size=self.config['network']['det']['nms_size'], - eof_size=self.config['network']['det']['eof_size'], - edge_thld=self.config['network']['det']['edge_thld'] + k=self.config["network"]["det"]["kpt_n"], + score_thld=self.config["network"]["det"]["score_thld"], + nms_size=self.config["network"]["det"]["nms_size"], + eof_size=self.config["network"]["det"]["eof_size"], + edge_thld=self.config["network"]["det"]["edge_thld"], ) indices1, scores1 = extract_kpts( score_map1.permute(0, 3, 1, 2), - k=self.config['network']['det']['kpt_n'], - score_thld=self.config['network']['det']['score_thld'], - nms_size=self.config['network']['det']['nms_size'], - eof_size=self.config['network']['det']['eof_size'], - edge_thld=self.config['network']['det']['edge_thld'] + k=self.config["network"]["det"]["kpt_n"], + score_thld=self.config["network"]["det"]["score_thld"], + nms_size=self.config["network"]["det"]["nms_size"], + eof_size=self.config["network"]["det"]["eof_size"], + edge_thld=self.config["network"]["det"]["edge_thld"], ) - if self.config['network']['input_type'] == 'raw': - kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0]) - kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0]) + if self.config["network"]["input_type"] == "raw": + kpt_img0 = self.showKeyPoints( + img0_ori[0][..., :3] * 255.0, indices0[0] + ) + kpt_img1 = self.showKeyPoints( + img1_ori[0][..., :3] * 255.0, indices1[0] + ) else: - kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0]) - kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0]) + kpt_img0 = self.showKeyPoints(img0_ori[0] * 255.0, indices0[0]) + kpt_img1 = self.showKeyPoints(img1_ori[0] * 255.0, indices1[0]) - self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC') - self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC') - 
self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC') - self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC') - self.writer.add_image('img0/conf', conf0[0], self.cnt, dataformats='HWC') - self.writer.add_image('img1/conf', conf1[0], self.cnt, dataformats='HWC') + self.writer.add_image( + "img0/kpts", kpt_img0, self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/kpts", kpt_img1, self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img0/score_map", score_map0[0], self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/score_map", score_map1[0], self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img0/conf", conf0[0], self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/conf", conf1[0], self.cnt, dataformats="HWC" + ) if self.cnt % 10000 == 0: self.save(self.cnt) - - self.cnt += 1 + self.cnt += 1 def showKeyPoints(self, img, indices): key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1]) - img = img.numpy().astype('uint8') + img = img.numpy().astype("uint8") img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0)) return img - def preprocess(self, img, iter_idx): - if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']: + if ( + not self.config["network"]["noise"] + and "raw" not in self.config["network"]["input_type"] + ): return img raw = self.noise_maker.rgb2raw(img, batched=True) - if self.config['network']['noise']: - ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] + if self.config["network"]["noise"]: + ratio_dec = ( + min(self.config["network"]["noise_maxstep"], iter_idx) + / self.config["network"]["noise_maxstep"] + ) raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) - if self.config['network']['input_type'] == 'raw': + if self.config["network"]["input_type"] == "raw": return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)) - if self.config['network']['input_type'] == 'raw-demosaic': + if self.config["network"]["input_type"] == "raw-demosaic": return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)) rgb = self.noise_maker.raw2rgb(raw, batched=True) - if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + if ( + self.config["network"]["input_type"] == "rgb" + or self.config["network"]["input_type"] == "gray" + ): return torch.tensor(rgb) raise NotImplementedError() - def preprocess_noise_pair(self, img, iter_idx): - assert self.config['network']['noise'] + assert self.config["network"]["noise"] raw = self.noise_maker.rgb2raw(img, batched=True) - ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] - noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) + ratio_dec = ( + min(self.config["network"]["noise_maxstep"], iter_idx) + / self.config["network"]["noise_maxstep"] + ) + noise_raw = self.noise_maker.raw2noisyRaw( + raw, ratio_dec=ratio_dec, batched=True + ) - if self.config['network']['input_type'] == 'raw': - return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \ - torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True)) + if self.config["network"]["input_type"] == "raw": + return torch.tensor( + self.noise_maker.raw2packedRaw(raw, batched=True) + ), torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True)) - if 
self.config['network']['input_type'] == 'raw-demosaic': - return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \ - torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)) + if self.config["network"]["input_type"] == "raw-demosaic": + return torch.tensor( + self.noise_maker.raw2demosaicRaw(raw, batched=True) + ), torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)) - if self.config['network']['input_type'] == 'raw-gray': + if self.config["network"]["input_type"] == "raw-gray": factor = torch.tensor([0.299, 0.587, 0.114]).double() - return torch.matmul(torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), factor).unsqueeze(-1), \ - torch.matmul(torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)), factor).unsqueeze(-1) + return torch.matmul( + torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), + factor, + ).unsqueeze(-1), torch.matmul( + torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)), + factor, + ).unsqueeze( + -1 + ) noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True) - if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + if ( + self.config["network"]["input_type"] == "rgb" + or self.config["network"]["input_type"] == "gray" + ): return img, torch.tensor(noise_rgb) raise NotImplementedError() diff --git a/third_party/DarkFeat/trainer_single_norel.py b/third_party/DarkFeat/trainer_single_norel.py index a572e9c599adc30e5753e11e668d121cd378672a..5447a37dabba339183f4e50ef44381ebc7a34998 100644 --- a/third_party/DarkFeat/trainer_single_norel.py +++ b/third_party/DarkFeat/trainer_single_norel.py @@ -23,23 +23,29 @@ class SingleTrainerNoRel: self.config = config self.device = device self.loader = loader - + # tensorboard writer construction - os.makedirs('./runs/', exist_ok=True) - if job_name != '': - self.log_dir = f'runs/{job_name}' + os.makedirs("./runs/", exist_ok=True) + if job_name != "": + self.log_dir = f"runs/{job_name}" else: self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}' self.writer = SummaryWriter(self.log_dir) - with open(f'{self.log_dir}/config.yaml', 'w') as f: + with open(f"{self.log_dir}/config.yaml", "w") as f: yaml.dump(config, f) - if config['network']['input_type'] == 'gray' or config['network']['input_type'] == 'raw-gray': + if ( + config["network"]["input_type"] == "gray" + or config["network"]["input_type"] == "raw-gray" + ): self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device) - elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic': + elif ( + config["network"]["input_type"] == "rgb" + or config["network"]["input_type"] == "raw-demosaic" + ): self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device) - elif config['network']['input_type'] == 'raw': + elif config["network"]["input_type"] == "raw": self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device) else: raise NotImplementedError() @@ -50,68 +56,83 @@ class SingleTrainerNoRel: # load model self.cnt = 0 if start_cnt != 0: - self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth')) + self.model.load_state_dict( + torch.load(f"{self.log_dir}/model_{start_cnt:06d}.pth") + ) self.cnt = start_cnt + 1 # optimizer and scheduler - if self.config['training']['optimizer'] == 'SGD': + if self.config["training"]["optimizer"] == "SGD": self.optimizer = torch.optim.SGD( - [{'params': 
self.model.parameters(), 'initial_lr': self.config['training']['lr']}], - lr=self.config['training']['lr'], - momentum=self.config['training']['momentum'], - weight_decay=self.config['training']['weight_decay'], + [ + { + "params": self.model.parameters(), + "initial_lr": self.config["training"]["lr"], + } + ], + lr=self.config["training"]["lr"], + momentum=self.config["training"]["momentum"], + weight_decay=self.config["training"]["weight_decay"], ) - elif self.config['training']['optimizer'] == 'Adam': + elif self.config["training"]["optimizer"] == "Adam": self.optimizer = torch.optim.Adam( - [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], - lr=self.config['training']['lr'], - weight_decay=self.config['training']['weight_decay'] + [ + { + "params": self.model.parameters(), + "initial_lr": self.config["training"]["lr"], + } + ], + lr=self.config["training"]["lr"], + weight_decay=self.config["training"]["weight_decay"], ) else: raise NotImplementedError() self.lr_scheduler = torch.optim.lr_scheduler.StepLR( self.optimizer, - step_size=self.config['training']['lr_step'], - gamma=self.config['training']['lr_gamma'], - last_epoch=start_cnt + step_size=self.config["training"]["lr_step"], + gamma=self.config["training"]["lr_gamma"], + last_epoch=start_cnt, ) for param_tensor in self.model.state_dict(): print(param_tensor, "\t", self.model.state_dict()[param_tensor].size()) - def save(self, iter_num): - torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth') + torch.save(self.model.state_dict(), f"{self.log_dir}/model_{iter_num:06d}.pth") def load(self, path): self.model.load_state_dict(torch.load(path)) def train(self): self.model.train() - + for epoch in range(2): for batch_idx, inputs in enumerate(self.loader): self.optimizer.zero_grad() t = time.time() # preprocess and add noise - img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt) - img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt) + img0_ori, noise_img0_ori = self.preprocess_noise_pair( + inputs["img0"], self.cnt + ) + img1_ori, noise_img1_ori = self.preprocess_noise_pair( + inputs["img1"], self.cnt + ) img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device) img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device) - if self.config['network']['input_type'] == 'rgb': + if self.config["network"]["input_type"] == "rgb": # 3-channel rgb RGB_mean = [0.485, 0.456, 0.406] - RGB_std = [0.229, 0.224, 0.225] + RGB_std = [0.229, 0.224, 0.225] norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std) img0 = norm_RGB(img0) img1 = norm_RGB(img1) noise_img0 = norm_RGB(noise_img0) noise_img1 = norm_RGB(noise_img1) - elif self.config['network']['input_type'] == 'gray': + elif self.config["network"]["input_type"] == "gray": # 1-channel img0 = torch.mean(img0, dim=1, keepdim=True) img1 = torch.mean(img1, dim=1, keepdim=True) @@ -124,11 +145,11 @@ class SingleTrainerNoRel: noise_img0 = norm_gray0(noise_img0) noise_img1 = norm_gray1(noise_img1) - elif self.config['network']['input_type'] == 'raw': + elif self.config["network"]["input_type"] == "raw": # 4-channel pass - elif self.config['network']['input_type'] == 'raw-demosaic': + elif self.config["network"]["input_type"] == "raw-demosaic": # 3-channel pass @@ -146,30 +167,52 @@ class SingleTrainerNoRel: score_map0 = score_map0.permute(0, 2, 3, 1) score_map1 = score_map1.permute(0, 2, 3, 1) - r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device) - r_K1 = 
getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device) - + r_K0 = getK(inputs["ori_img_size0"], cur_feat_size0, inputs["K0"]).to( + self.device + ) + r_K1 = getK(inputs["ori_img_size1"], cur_feat_size1, inputs["K1"]).to( + self.device + ) + pos0 = _grid_positions( - cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device) + cur_feat_size0[0], cur_feat_size0[1], img0.shape[0] + ).to(self.device) pos0, pos1, _ = getWarp( - pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), - r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + pos0, + inputs["rel_pose"].to(self.device), + inputs["depth0"].to(self.device), + r_K0, + inputs["depth1"].to(self.device), + r_K1, + img0.shape[0], + ) det_structured_loss, det_accuracy = make_detector_loss( - pos0, pos1, desc0, desc1, - score_map0, score_map1, img0.shape[0], - self.config['network']['use_corr_n'], - self.config['network']['loss_type'], - self.config + pos0, + pos1, + desc0, + desc1, + score_map0, + score_map1, + img0.shape[0], + self.config["network"]["use_corr_n"], + self.config["network"]["loss_type"], + self.config, ) total_loss = det_structured_loss - + self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt) self.writer.add_scalar("loss/total_loss", total_loss, self.cnt) - self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt) - print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t)) + self.writer.add_scalar( + "loss/det_loss_normal", det_structured_loss, self.cnt + ) + print( + "iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter".format( + self.cnt, total_loss, det_accuracy, time.time() - t + ) + ) if det_structured_loss != 0: total_loss.backward() @@ -179,87 +222,115 @@ class SingleTrainerNoRel: if self.cnt % 100 == 0: indices0, scores0 = extract_kpts( score_map0.permute(0, 3, 1, 2), - k=self.config['network']['det']['kpt_n'], - score_thld=self.config['network']['det']['score_thld'], - nms_size=self.config['network']['det']['nms_size'], - eof_size=self.config['network']['det']['eof_size'], - edge_thld=self.config['network']['det']['edge_thld'] + k=self.config["network"]["det"]["kpt_n"], + score_thld=self.config["network"]["det"]["score_thld"], + nms_size=self.config["network"]["det"]["nms_size"], + eof_size=self.config["network"]["det"]["eof_size"], + edge_thld=self.config["network"]["det"]["edge_thld"], ) indices1, scores1 = extract_kpts( score_map1.permute(0, 3, 1, 2), - k=self.config['network']['det']['kpt_n'], - score_thld=self.config['network']['det']['score_thld'], - nms_size=self.config['network']['det']['nms_size'], - eof_size=self.config['network']['det']['eof_size'], - edge_thld=self.config['network']['det']['edge_thld'] + k=self.config["network"]["det"]["kpt_n"], + score_thld=self.config["network"]["det"]["score_thld"], + nms_size=self.config["network"]["det"]["nms_size"], + eof_size=self.config["network"]["det"]["eof_size"], + edge_thld=self.config["network"]["det"]["edge_thld"], ) - if self.config['network']['input_type'] == 'raw': - kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0]) - kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0]) + if self.config["network"]["input_type"] == "raw": + kpt_img0 = self.showKeyPoints( + img0_ori[0][..., :3] * 255.0, indices0[0] + ) + kpt_img1 = self.showKeyPoints( + img1_ori[0][..., :3] * 255.0, indices1[0] + ) else: - kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0]) - kpt_img1 = 
self.showKeyPoints(img1_ori[0] * 255., indices1[0]) + kpt_img0 = self.showKeyPoints(img0_ori[0] * 255.0, indices0[0]) + kpt_img1 = self.showKeyPoints(img1_ori[0] * 255.0, indices1[0]) - self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC') - self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC') - self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC') - self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC') + self.writer.add_image( + "img0/kpts", kpt_img0, self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/kpts", kpt_img1, self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img0/score_map", score_map0[0], self.cnt, dataformats="HWC" + ) + self.writer.add_image( + "img1/score_map", score_map1[0], self.cnt, dataformats="HWC" + ) if self.cnt % 10000 == 0: self.save(self.cnt) - - self.cnt += 1 + self.cnt += 1 def showKeyPoints(self, img, indices): key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1]) - img = img.numpy().astype('uint8') + img = img.numpy().astype("uint8") img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0)) return img - def preprocess(self, img, iter_idx): - if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']: + if ( + not self.config["network"]["noise"] + and "raw" not in self.config["network"]["input_type"] + ): return img raw = self.noise_maker.rgb2raw(img, batched=True) - if self.config['network']['noise']: - ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] + if self.config["network"]["noise"]: + ratio_dec = ( + min(self.config["network"]["noise_maxstep"], iter_idx) + / self.config["network"]["noise_maxstep"] + ) raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) - if self.config['network']['input_type'] == 'raw': + if self.config["network"]["input_type"] == "raw": return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)) - if self.config['network']['input_type'] == 'raw-demosaic': + if self.config["network"]["input_type"] == "raw-demosaic": return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)) rgb = self.noise_maker.raw2rgb(raw, batched=True) - if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + if ( + self.config["network"]["input_type"] == "rgb" + or self.config["network"]["input_type"] == "gray" + ): return torch.tensor(rgb) raise NotImplementedError() - def preprocess_noise_pair(self, img, iter_idx): - assert self.config['network']['noise'] + assert self.config["network"]["noise"] raw = self.noise_maker.rgb2raw(img, batched=True) - ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] - noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) + ratio_dec = ( + min(self.config["network"]["noise_maxstep"], iter_idx) + / self.config["network"]["noise_maxstep"] + ) + noise_raw = self.noise_maker.raw2noisyRaw( + raw, ratio_dec=ratio_dec, batched=True + ) - if self.config['network']['input_type'] == 'raw': - return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \ - torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True)) + if self.config["network"]["input_type"] == "raw": + return torch.tensor( + self.noise_maker.raw2packedRaw(raw, batched=True) + ), torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, 
batched=True)) - if self.config['network']['input_type'] == 'raw-demosaic': - return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \ - torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)) + if self.config["network"]["input_type"] == "raw-demosaic": + return torch.tensor( + self.noise_maker.raw2demosaicRaw(raw, batched=True) + ), torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)) noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True) - if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + if ( + self.config["network"]["input_type"] == "rgb" + or self.config["network"]["input_type"] == "gray" + ): return img, torch.tensor(noise_rgb) raise NotImplementedError() diff --git a/third_party/DarkFeat/utils/matching.py b/third_party/DarkFeat/utils/matching.py index ca091f418bb4dc4d278611e5126a930aa51e7f3f..78c2415cf54ec3942c94ded3afec381ba63b358a 100644 --- a/third_party/DarkFeat/utils/matching.py +++ b/third_party/DarkFeat/utils/matching.py @@ -2,24 +2,26 @@ import math import numpy as np import cv2 + def extract_ORB_keypoints_and_descriptors(img): # gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) detector = cv2.ORB_create(nfeatures=1000) kp, desc = detector.detectAndCompute(img, None) return kp, desc + def match_descriptors_NG(kp1, desc1, kp2, desc2): bf = cv2.BFMatcher() try: - matches = bf.knnMatch(desc1, desc2,k=2) + matches = bf.knnMatch(desc1, desc2, k=2) except: matches = [] - good_matches=[] + good_matches = [] image1_kp = [] image2_kp = [] ratios = [] try: - for (m1,m2) in matches: + for (m1, m2) in matches: if m1.distance < 0.8 * m2.distance: good_matches.append(m1) image2_kp.append(kp2[m1.trainIdx].pt) @@ -33,41 +35,42 @@ def match_descriptors_NG(kp1, desc1, kp2, desc2): ratios = np.expand_dims(ratios, 2) return image1_kp, image2_kp, good_matches, ratios + def match_descriptors(kp1, desc1, kp2, desc2, ORB): if ORB: bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) try: - matches = bf.match(desc1,desc2) - matches = sorted(matches, key = lambda x:x.distance) + matches = bf.match(desc1, desc2) + matches = sorted(matches, key=lambda x: x.distance) except: matches = [] - good_matches=[] + good_matches = [] image1_kp = [] image2_kp = [] count = 0 try: for m in matches: - count+=1 + count += 1 if count < 1000: good_matches.append(m) image2_kp.append(kp2[m.trainIdx].pt) - image1_kp.append(kp1[m.queryIdx].pt) + image1_kp.append(kp1[m.queryIdx].pt) except: pass else: # Match the keypoints with the warped_keypoints with nearest neighbor search bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) try: - matches = bf.match(desc1.transpose(1,0), desc2.transpose(1,0)) - matches = sorted(matches, key = lambda x:x.distance) + matches = bf.match(desc1.transpose(1, 0), desc2.transpose(1, 0)) + matches = sorted(matches, key=lambda x: x.distance) except: matches = [] - good_matches=[] + good_matches = [] image1_kp = [] image2_kp = [] try: for m in matches: - good_matches.append(m) + good_matches.append(m) image2_kp.append(kp2[m.trainIdx].pt) image1_kp.append(kp1[m.queryIdx].pt) except: @@ -79,18 +82,28 @@ def match_descriptors(kp1, desc1, kp2, desc2, ORB): def compute_essential(matched_kp1, matched_kp2, K): - pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) - pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) + pts1 = cv2.undistortPoints( + matched_kp1, + cameraMatrix=K, 
+ distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0), + ) + pts2 = cv2.undistortPoints( + matched_kp2, + cameraMatrix=K, + distCoeffs=(-0.117918271740560, 0.075246403574314, 0, 0), + ) K_1 = np.eye(3) # Estimate the homography between the matches using RANSAC - ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.FM_RANSAC, prob=0.999, threshold=0.001) - if ransac_inliers is None or ransac_model.shape != (3,3): + ransac_model, ransac_inliers = cv2.findEssentialMat( + pts1, pts2, K_1, method=cv2.FM_RANSAC, prob=0.999, threshold=0.001 + ) + if ransac_inliers is None or ransac_model.shape != (3, 3): ransac_inliers = np.array([]) ransac_model = None return ransac_model, ransac_inliers, pts1, pts2 -def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers): +def compute_error(R_GT, t_GT, E, pts1_norm, pts2_norm, inliers): """Compute the angular error between two rotation matrices and two translation vectors. Keyword arguments: R -- 2D numpy array containing an estimated rotation @@ -101,14 +114,14 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers): inliers = inliers.ravel() R = np.eye(3) - t = np.zeros((3,1)) + t = np.zeros((3, 1)) sst = True try: cv2.recoverPose(E, pts1_norm, pts2_norm, np.eye(3), R, t, inliers) except: sst = False # calculate angle between provided rotations - # + # if sst: dR = np.matmul(R, np.transpose(R_GT)) dR = cv2.Rodrigues(dR)[0] @@ -119,10 +132,10 @@ def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers): dT /= float(np.linalg.norm(t_GT)) if dT > 1 or dT < -1: - print("Domain warning! dT:",dT) - dT = max(-1,min(1,dT)) + print("Domain warning! dT:", dT) + dT = max(-1, min(1, dT)) dT = math.acos(dT) * 180 / math.pi - dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation + dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation else: - dR,dT = 180.0, 180.0 + dR, dT = 180.0, 180.0 return dR, dT diff --git a/third_party/DarkFeat/utils/misc.py b/third_party/DarkFeat/utils/misc.py index 1df6fdec97121486dbb94e0b32a2f66c85c48f7d..7d5ac3c8be8f8aacaaf4ec59f19b3278b963f572 100644 --- a/third_party/DarkFeat/utils/misc.py +++ b/third_party/DarkFeat/utils/misc.py @@ -9,7 +9,7 @@ import colour_demosaicing class AverageTimer: - """ Class to help manage printing simple timing of code execution. """ + """Class to help manage printing simple timing of code execution.""" def __init__(self, smoothing=0.3, newline=False): self.smoothing = smoothing @@ -25,7 +25,7 @@ class AverageTimer: for name in self.will_print: self.will_print[name] = False - def update(self, name='default'): + def update(self, name="default"): now = time.time() dt = now - self.last_time if name in self.times: @@ -34,19 +34,19 @@ class AverageTimer: self.will_print[name] = True self.last_time = now - def print(self, text='Timer'): - total = 0. 
- print('[{}]'.format(text), end=' ') + def print(self, text="Timer"): + total = 0.0 + print("[{}]".format(text), end=" ") for key in self.times: val = self.times[key] if self.will_print[key]: - print('%s=%.3f' % (key, val), end=' ') + print("%s=%.3f" % (key, val), end=" ") total += val - print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ') + print("total=%.3f sec {%.1f FPS}" % (total, 1.0 / total), end=" ") if self.newline: print(flush=True) else: - print(end='\r', flush=True) + print(end="\r", flush=True) self.reset() @@ -56,32 +56,36 @@ class VideoStreamer: self.resize = resize self.i = 0 if Path(basedir).is_dir(): - print('==> Processing image directory input: {}'.format(basedir)) + print("==> Processing image directory input: {}".format(basedir)) self.listing = list(Path(basedir).glob(image_glob[0])) for j in range(1, len(image_glob)): image_path = list(Path(basedir).glob(image_glob[j])) self.listing = self.listing + image_path self.listing.sort() if len(self.listing) == 0: - raise IOError('No images found (maybe bad \'image_glob\' ?)') + raise IOError("No images found (maybe bad 'image_glob' ?)") self.max_length = len(self.listing) else: - raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir)) + raise ValueError('VideoStreamer input "{}" not recognized.'.format(basedir)) def load_image(self, impath): raw = rawpy.imread(str(impath)).raw_image_visible - raw = np.clip(raw.astype('float32') - 512, 0, 65535) - img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, 'RGGB').astype('float32') + raw = np.clip(raw.astype("float32") - 512, 0, 65535) + img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, "RGGB").astype( + "float32" + ) img = np.clip(img, 0, 16383) m = img.mean() d = np.abs(img - img.mean()).mean() - img = (img - m + 2*d) / 4/d * 255 + img = (img - m + 2 * d) / 4 / d * 255 image = np.clip(img, 0, 255) w_new, h_new = self.resize[0], self.resize[1] - im = cv2.resize(image.astype('float32'), (w_new, h_new), interpolation=cv2.INTER_AREA) + im = cv2.resize( + image.astype("float32"), (w_new, h_new), interpolation=cv2.INTER_AREA + ) return im def next_frame(self): @@ -95,57 +99,103 @@ class VideoStreamer: def frame2tensor(frame, device): if len(frame.shape) == 2: - return torch.from_numpy(frame/255.).float()[None, None].to(device) + return torch.from_numpy(frame / 255.0).float()[None, None].to(device) else: - return torch.from_numpy(frame/255.).float().permute(2, 0, 1)[None].to(device) - - -def make_matching_plot_fast(image0, image1, mkpts0, mkpts1, - color, text, path=None, margin=10, - opencv_display=False, opencv_title='', - small_text=[]): + return torch.from_numpy(frame / 255.0).float().permute(2, 0, 1)[None].to(device) + + +def make_matching_plot_fast( + image0, + image1, + mkpts0, + mkpts1, + color, + text, + path=None, + margin=10, + opencv_display=False, + opencv_title="", + small_text=[], +): H0, W0 = image0.shape[:2] H1, W1 = image1.shape[:2] H, W = max(H0, H1), W0 + W1 + margin - out = 255*np.ones((H, W, 3), np.uint8) + out = 255 * np.ones((H, W, 3), np.uint8) out[:H0, :W0, :] = image0 - out[:H1, W0+margin:, :] = image1 + out[:H1, W0 + margin :, :] = image1 # Scale factor for consistent visualization across scales. - sc = min(H / 640., 2.0) + sc = min(H / 640.0, 2.0) # Big text. 
Ht = int(30 * sc) # text height txt_color_fg = (255, 255, 255) txt_color_bg = (0, 0, 0) - + for i, t in enumerate(text): - cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, - 1.0*sc, txt_color_bg, 2, cv2.LINE_AA) - cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, - 1.0*sc, txt_color_fg, 1, cv2.LINE_AA) + cv2.putText( + out, + t, + (int(8 * sc), Ht * (i + 1)), + cv2.FONT_HERSHEY_DUPLEX, + 1.0 * sc, + txt_color_bg, + 2, + cv2.LINE_AA, + ) + cv2.putText( + out, + t, + (int(8 * sc), Ht * (i + 1)), + cv2.FONT_HERSHEY_DUPLEX, + 1.0 * sc, + txt_color_fg, + 1, + cv2.LINE_AA, + ) out_backup = out.copy() mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int) - color = (np.array(color[:, :3])*255).astype(int)[:, ::-1] + color = (np.array(color[:, :3]) * 255).astype(int)[:, ::-1] for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color): c = c.tolist() - cv2.line(out, (x0, y0), (x1 + margin + W0, y1), - color=c, thickness=1, lineType=cv2.LINE_AA) + cv2.line( + out, + (x0, y0), + (x1 + margin + W0, y1), + color=c, + thickness=1, + lineType=cv2.LINE_AA, + ) # display line end-points as circles cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA) - cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1, - lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1, lineType=cv2.LINE_AA) # Small text. Ht = int(18 * sc) # text height for i, t in enumerate(reversed(small_text)): - cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, - 0.5*sc, txt_color_bg, 2, cv2.LINE_AA) - cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, - 0.5*sc, txt_color_fg, 1, cv2.LINE_AA) + cv2.putText( + out, + t, + (int(8 * sc), int(H - Ht * (i + 0.6))), + cv2.FONT_HERSHEY_DUPLEX, + 0.5 * sc, + txt_color_bg, + 2, + cv2.LINE_AA, + ) + cv2.putText( + out, + t, + (int(8 * sc), int(H - Ht * (i + 0.6))), + cv2.FONT_HERSHEY_DUPLEX, + 0.5 * sc, + txt_color_fg, + 1, + cv2.LINE_AA, + ) if path is not None: cv2.imwrite(str(path), out) @@ -153,6 +203,5 @@ def make_matching_plot_fast(image0, image1, mkpts0, mkpts1, if opencv_display: cv2.imshow(opencv_title, out) cv2.waitKey(1) - - return out / 2 + out_backup / 2 + return out / 2 + out_backup / 2 diff --git a/third_party/DarkFeat/utils/nn.py b/third_party/DarkFeat/utils/nn.py index 8a80631d6e12d848cceee3b636baf49deaa7647a..956256aeae1b83700044f8f2df18f8913348ebe7 100644 --- a/third_party/DarkFeat/utils/nn.py +++ b/third_party/DarkFeat/utils/nn.py @@ -7,8 +7,8 @@ class NN2(nn.Module): super().__init__() def forward(self, data): - desc1, desc2 = data['descriptors0'].cuda(), data['descriptors1'].cuda() - kpts1, kpts2 = data['keypoints0'].cuda(), data['keypoints1'].cuda() + desc1, desc2 = data["descriptors0"].cuda(), data["descriptors1"].cuda() + kpts1, kpts2 = data["keypoints0"].cuda(), data["keypoints1"].cuda() # torch.cuda.synchronize() # t = time.time() @@ -16,10 +16,10 @@ class NN2(nn.Module): if kpts1.shape[1] <= 1 or kpts2.shape[1] <= 1: # no keypoints shape0, shape1 = kpts1.shape[:-1], kpts2.shape[:-1] return { - 'matches0': kpts1.new_full(shape0, -1, dtype=torch.int), - 'matches1': kpts2.new_full(shape1, -1, dtype=torch.int), - 'matching_scores0': kpts1.new_zeros(shape0), - 'matching_scores1': kpts2.new_zeros(shape1), + "matches0": kpts1.new_full(shape0, -1, dtype=torch.int), + "matches1": kpts2.new_full(shape1, -1, dtype=torch.int), + "matching_scores0": kpts1.new_zeros(shape0), + "matching_scores1": kpts2.new_zeros(shape1), } sim = torch.matmul(desc1.squeeze().T, 
desc2.squeeze()) @@ -28,14 +28,16 @@ class NN2(nn.Module): nn21 = torch.argmax(sim, dim=0) mask = torch.eq(ids1, nn21[nn12]) - matches = torch.stack([torch.masked_select(ids1, mask), torch.masked_select(nn12, mask)]) + matches = torch.stack( + [torch.masked_select(ids1, mask), torch.masked_select(nn12, mask)] + ) # matches = torch.stack([ids1, nn12]) indices0 = torch.ones((1, desc1.shape[-1]), dtype=int) * -1 mscores0 = torch.ones((1, desc1.shape[-1]), dtype=float) * -1 # torch.cuda.synchronize() # print(time.time() - t) - + matches_0 = matches[0].cpu().int().numpy() matches_1 = matches[1].cpu().int() for i in range(matches.shape[-1]): @@ -43,8 +45,8 @@ class NN2(nn.Module): mscores0[0, matches_0[i]] = sim[matches_0[i], matches_1[i]] return { - 'matches0': indices0, # use -1 for invalid match - 'matches1': indices0, # use -1 for invalid match - 'matching_scores0': mscores0, - 'matching_scores1': mscores0, + "matches0": indices0, # use -1 for invalid match + "matches1": indices0, # use -1 for invalid match + "matching_scores0": mscores0, + "matching_scores1": mscores0, } diff --git a/third_party/DarkFeat/utils/nnmatching.py b/third_party/DarkFeat/utils/nnmatching.py index 7be6f98c050fc2e416ef48e25ca0f293106c1082..6289623c28989dc73dfbeb1763228f301c62831b 100644 --- a/third_party/DarkFeat/utils/nnmatching.py +++ b/third_party/DarkFeat/utils/nnmatching.py @@ -3,28 +3,28 @@ import torch from .nn import NN2 from darkfeat import DarkFeat + class NNMatching(torch.nn.Module): - def __init__(self, model_path=''): + def __init__(self, model_path=""): super().__init__() self.nn = NN2().eval() self.darkfeat = DarkFeat(model_path).eval() def forward(self, data): - """ Run DarkFeat and nearest neighborhood matching + """Run DarkFeat and nearest neighborhood matching Args: data: dictionary with minimal keys: ['image0', 'image1'] """ pred = {} # Extract DarkFeat (keypoints, scores, descriptors) - if 'keypoints0' not in data: - pred0 = self.darkfeat({'image': data['image0']}) + if "keypoints0" not in data: + pred0 = self.darkfeat({"image": data["image0"]}) # print({k+'0': v[0].shape for k, v in pred0.items()}) - pred = {**pred, **{k+'0': [v] for k, v in pred0.items()}} - if 'keypoints1' not in data: - pred1 = self.darkfeat({'image': data['image1']}) - pred = {**pred, **{k+'1': [v] for k, v in pred1.items()}} - + pred = {**pred, **{k + "0": [v] for k, v in pred0.items()}} + if "keypoints1" not in data: + pred1 = self.darkfeat({"image": data["image1"]}) + pred = {**pred, **{k + "1": [v] for k, v in pred1.items()}} # Batch all features # We should either have i) one image per batch, or diff --git a/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py b/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py index 52113027f2e7ddc144453df9f012f84d3b4ba95b..f428121d175af9f9786cfa9cf9c340b94a170521 100644 --- a/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py +++ b/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py @@ -1,3 +1,3 @@ from .num_inliers import NumInliersBenchmark from .mega_pose_est import MegaDepthPoseEstimationBenchmark -from .mega_pose_est_mnn import MegaDepthPoseMNNBenchmark \ No newline at end of file +from .mega_pose_est_mnn import MegaDepthPoseMNNBenchmark diff --git a/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py b/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py index 2104284b54d5fe339d6f12d9ae14dcdd3c0fb564..66292fe5a6efbdf328e5f27d806479616455cff7 100644 --- a/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py +++ b/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py @@ -5,8 +5,9 @@ 
from PIL import Image from tqdm import tqdm import torch.nn.functional as F + class MegaDepthPoseEstimationBenchmark: - def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + def __init__(self, data_root="data/megadepth", scene_names=None) -> None: if scene_names is None: self.scene_names = [ "0015_0.1_0.3.npz", @@ -23,14 +24,23 @@ class MegaDepthPoseEstimationBenchmark: ] self.data_root = data_root - def benchmark(self, keypoint_model, matching_model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True): - H,W = matching_model.get_output_resolution() + def benchmark( + self, + keypoint_model, + matching_model, + model_name=None, + resolution=None, + scale_intrinsics=True, + calibrated=True, + ): + H, W = matching_model.get_output_resolution() with torch.no_grad(): data_root = self.data_root tot_e_t, tot_e_R, tot_e_pose = [], [], [] thresholds = [5, 10, 20] for scene_ind in range(len(self.scenes)): import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] scene = self.scenes[scene_ind] pairs = scene["pair_infos"] @@ -47,14 +57,20 @@ class MegaDepthPoseEstimationBenchmark: T2 = poses[idx2].copy() R2, t2 = T2[:3, :3], T2[:3, 3] R, t = compute_relative_pose(R1, t1, R2, t2) - T1_to_2 = np.concatenate((R,t[:,None]), axis=-1) + T1_to_2 = np.concatenate((R, t[:, None]), axis=-1) im_A_path = f"{data_root}/{im_paths[idx1]}" im_B_path = f"{data_root}/{im_paths[idx2]}" - - keypoints_A = keypoint_model.detect_from_path(im_A_path, num_keypoints = 20_000)["keypoints"][0] - keypoints_B = keypoint_model.detect_from_path(im_B_path, num_keypoints = 20_000)["keypoints"][0] + + keypoints_A = keypoint_model.detect_from_path( + im_A_path, num_keypoints=20_000 + )["keypoints"][0] + keypoints_B = keypoint_model.detect_from_path( + im_B_path, num_keypoints=20_000 + )["keypoints"][0] warp, certainty = matching_model.match(im_A_path, im_B_path) - matches = matching_model.match_keypoints(keypoints_A, keypoints_B, warp, certainty, return_tuple = False) + matches = matching_model.match_keypoints( + keypoints_A, keypoints_B, warp, certainty, return_tuple=False + ) im_A = Image.open(im_A_path) w1, h1 = im_A.size im_B = Image.open(im_B_path) @@ -67,15 +83,20 @@ class MegaDepthPoseEstimationBenchmark: K1, K2 = K1.copy(), K2.copy() K1[:2] = K1[:2] * scale1 K2[:2] = K2[:2] * scale2 - kpts1, kpts2 = matching_model.to_pixel_coordinates(matches, h1, w1, h2, w2) + kpts1, kpts2 = matching_model.to_pixel_coordinates( + matches, h1, w1, h2, w2 + ) for _ in range(1): shuffling = np.random.permutation(np.arange(len(kpts1))) kpts1 = kpts1[shuffling] kpts2 = kpts2[shuffling] try: - threshold = 0.5 + threshold = 0.5 if calibrated: - norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + norm_threshold = threshold / ( + np.mean(np.abs(K1[:2, :2])) + + np.mean(np.abs(K2[:2, :2])) + ) R_est, t_est, mask = estimate_pose( kpts1.cpu().numpy(), kpts2.cpu().numpy(), @@ -111,4 +132,4 @@ class MegaDepthPoseEstimationBenchmark: "map_5": map_5, "map_10": map_10, "map_20": map_20, - } \ No newline at end of file + } diff --git a/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py b/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py index 15f4cdea05c601173fab765b92d5379e8f0bb349..e979bddfb09ff8760d83442b284662376a074998 100644 --- a/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py +++ b/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py @@ -5,8 +5,9 @@ from PIL import Image from tqdm import tqdm import 
torch.nn.functional as F + class MegaDepthPoseMNNBenchmark: - def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + def __init__(self, data_root="data/megadepth", scene_names=None) -> None: if scene_names is None: self.scene_names = [ "0015_0.1_0.3.npz", @@ -23,13 +24,23 @@ class MegaDepthPoseMNNBenchmark: ] self.data_root = data_root - def benchmark(self, detector_model, descriptor_model, matcher_model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True): + def benchmark( + self, + detector_model, + descriptor_model, + matcher_model, + model_name=None, + resolution=None, + scale_intrinsics=True, + calibrated=True, + ): with torch.no_grad(): data_root = self.data_root tot_e_t, tot_e_R, tot_e_pose = [], [], [] thresholds = [5, 10, 20] for scene_ind in range(len(self.scenes)): import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] scene = self.scenes[scene_ind] pairs = scene["pair_infos"] @@ -46,19 +57,36 @@ class MegaDepthPoseMNNBenchmark: T2 = poses[idx2].copy() R2, t2 = T2[:3, :3], T2[:3, 3] R, t = compute_relative_pose(R1, t1, R2, t2) - T1_to_2 = np.concatenate((R,t[:,None]), axis=-1) + T1_to_2 = np.concatenate((R, t[:, None]), axis=-1) im_A_path = f"{data_root}/{im_paths[idx1]}" im_B_path = f"{data_root}/{im_paths[idx2]}" detections_A = detector_model.detect_from_path(im_A_path) - keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"] + keypoints_A, P_A = ( + detections_A["keypoints"], + detections_A["confidence"], + ) detections_B = detector_model.detect_from_path(im_B_path) - keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"] - description_A = descriptor_model.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"] - description_B = descriptor_model.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"] - matches_A, matches_B, batch_ids = matcher_model.match(keypoints_A, description_A, - keypoints_B, description_B, - P_A = P_A, P_B = P_B, - normalize = True, inv_temp=20, threshold = 0.01) + keypoints_B, P_B = ( + detections_B["keypoints"], + detections_B["confidence"], + ) + description_A = descriptor_model.describe_keypoints_from_path( + im_A_path, keypoints_A + )["descriptions"] + description_B = descriptor_model.describe_keypoints_from_path( + im_B_path, keypoints_B + )["descriptions"] + matches_A, matches_B, batch_ids = matcher_model.match( + keypoints_A, + description_A, + keypoints_B, + description_B, + P_A=P_A, + P_B=P_B, + normalize=True, + inv_temp=20, + threshold=0.01, + ) im_A = Image.open(im_A_path) w1, h1 = im_A.size @@ -72,15 +100,20 @@ class MegaDepthPoseMNNBenchmark: K1, K2 = K1.copy(), K2.copy() K1[:2] = K1[:2] * scale1 K2[:2] = K2[:2] * scale2 - kpts1, kpts2 = matcher_model.to_pixel_coords(matches_A, matches_B, h1, w1, h2, w2) + kpts1, kpts2 = matcher_model.to_pixel_coords( + matches_A, matches_B, h1, w1, h2, w2 + ) for _ in range(1): shuffling = np.random.permutation(np.arange(len(kpts1))) kpts1 = kpts1[shuffling] kpts2 = kpts2[shuffling] try: - threshold = 0.5 + threshold = 0.5 if calibrated: - norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + norm_threshold = threshold / ( + np.mean(np.abs(K1[:2, :2])) + + np.mean(np.abs(K2[:2, :2])) + ) R_est, t_est, mask = estimate_pose( kpts1.cpu().numpy(), kpts2.cpu().numpy(), @@ -116,4 +149,4 @@ class MegaDepthPoseMNNBenchmark: "map_5": map_5, "map_10": map_10, "map_20": map_20, - } \ No newline at end of file + } diff --git 
a/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py b/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py index 24be32b2bc54f1d650836e5ab2f540e80fd3d5c0..f2b36f6a2b97b9c7010ef2455352531ffe3e4405 100644 --- a/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py +++ b/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py @@ -3,39 +3,56 @@ import torch.nn as nn from DeDoDe.utils import * import DeDoDe + class NumInliersBenchmark(nn.Module): - - def __init__(self, dataset, num_samples = 1000, batch_size = 8, num_keypoints = 10_000, device = "cuda") -> None: + def __init__( + self, + dataset, + num_samples=1000, + batch_size=8, + num_keypoints=10_000, + device="cuda", + ) -> None: super().__init__() sampler = torch.utils.data.WeightedRandomSampler( - torch.ones(len(dataset)), replacement=False, num_samples=num_samples - ) + torch.ones(len(dataset)), replacement=False, num_samples=num_samples + ) dataloader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler - ) + dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler + ) self.dataloader = dataloader self.tracked_metrics = {} self.batch_size = batch_size self.N = len(dataloader) self.num_keypoints = num_keypoints - - def compute_batch_metrics(self, outputs, batch, device = "cuda"): + + def compute_batch_metrics(self, outputs, batch, device="cuda"): kpts_A, kpts_B = outputs["keypoints_A"], outputs["keypoints_B"] B, K, H, W = batch["im_A"].shape - gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp( - batch["im_A_depth"], - batch["im_B_depth"], - batch["T_1to2"], - batch["K1"], - batch["K2"], - H=H, - W=W, - ) - kpts_A_to_B = F.grid_sample(gt_warp_A_to_B[...,2:].float().permute(0,3,1,2), kpts_A[...,None,:], - align_corners=False, mode = 'bilinear')[...,0].mT - legit_A_to_B = F.grid_sample(valid_mask_A_to_B.reshape(B,1,H,W), kpts_A[...,None,:], - align_corners=False, mode = 'bilinear')[...,0,:,0] - dists = (torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0.]).float() + gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp( + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + H=H, + W=W, + ) + kpts_A_to_B = F.grid_sample( + gt_warp_A_to_B[..., 2:].float().permute(0, 3, 1, 2), + kpts_A[..., None, :], + align_corners=False, + mode="bilinear", + )[..., 0].mT + legit_A_to_B = F.grid_sample( + valid_mask_A_to_B.reshape(B, 1, H, W), + kpts_A[..., None, :], + align_corners=False, + mode="bilinear", + )[..., 0, :, 0] + dists = ( + torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0.0] + ).float() if legit_A_to_B.sum() == 0: return percent_inliers_at_1 = (dists < 0.02).float().mean() @@ -44,33 +61,65 @@ class NumInliersBenchmark(nn.Module): percent_inliers_at_01 = (dists < 0.002).float().mean() percent_inliers_at_005 = (dists < 0.001).float().mean() - inlier_bins = torch.linspace(0, 0.002, steps = 100, device = device)[None] - inlier_counts = (dists[...,None] < inlier_bins).float().mean(dim=0) - self.tracked_metrics["inlier_counts"] = self.tracked_metrics.get("inlier_counts", 0) + 1/self.N * inlier_counts - self.tracked_metrics["percent_inliers_at_1"] = self.tracked_metrics.get("percent_inliers_at_1", 0) + 1/self.N * percent_inliers_at_1 - self.tracked_metrics["percent_inliers_at_05"] = self.tracked_metrics.get("percent_inliers_at_05", 0) + 1/self.N * percent_inliers_at_05 - self.tracked_metrics["percent_inliers_at_025"] = self.tracked_metrics.get("percent_inliers_at_025", 0) + 1/self.N * percent_inliers_at_025 - 
self.tracked_metrics["percent_inliers_at_01"] = self.tracked_metrics.get("percent_inliers_at_01", 0) + 1/self.N * percent_inliers_at_01 - self.tracked_metrics["percent_inliers_at_005"] = self.tracked_metrics.get("percent_inliers_at_005", 0) + 1/self.N * percent_inliers_at_005 + inlier_bins = torch.linspace(0, 0.002, steps=100, device=device)[None] + inlier_counts = (dists[..., None] < inlier_bins).float().mean(dim=0) + self.tracked_metrics["inlier_counts"] = ( + self.tracked_metrics.get("inlier_counts", 0) + 1 / self.N * inlier_counts + ) + self.tracked_metrics["percent_inliers_at_1"] = ( + self.tracked_metrics.get("percent_inliers_at_1", 0) + + 1 / self.N * percent_inliers_at_1 + ) + self.tracked_metrics["percent_inliers_at_05"] = ( + self.tracked_metrics.get("percent_inliers_at_05", 0) + + 1 / self.N * percent_inliers_at_05 + ) + self.tracked_metrics["percent_inliers_at_025"] = ( + self.tracked_metrics.get("percent_inliers_at_025", 0) + + 1 / self.N * percent_inliers_at_025 + ) + self.tracked_metrics["percent_inliers_at_01"] = ( + self.tracked_metrics.get("percent_inliers_at_01", 0) + + 1 / self.N * percent_inliers_at_01 + ) + self.tracked_metrics["percent_inliers_at_005"] = ( + self.tracked_metrics.get("percent_inliers_at_005", 0) + + 1 / self.N * percent_inliers_at_005 + ) def benchmark(self, detector): self.tracked_metrics = {} from tqdm import tqdm + print("Evaluating percent inliers...") - for idx, batch in tqdm(enumerate(self.dataloader), mininterval = 10.): + for idx, batch in tqdm(enumerate(self.dataloader), mininterval=10.0): batch = to_cuda(batch) - outputs = detector.detect(batch, num_keypoints = self.num_keypoints) - keypoints_A, keypoints_B = outputs["keypoints"][:self.batch_size], outputs["keypoints"][self.batch_size:] + outputs = detector.detect(batch, num_keypoints=self.num_keypoints) + keypoints_A, keypoints_B = ( + outputs["keypoints"][: self.batch_size], + outputs["keypoints"][self.batch_size :], + ) if isinstance(outputs["keypoints"], (tuple, list)): - keypoints_A, keypoints_B = torch.stack(keypoints_A), torch.stack(keypoints_B) + keypoints_A, keypoints_B = torch.stack(keypoints_A), torch.stack( + keypoints_B + ) outputs = {"keypoints_A": keypoints_A, "keypoints_B": keypoints_B} self.compute_batch_metrics(outputs, batch) import matplotlib.pyplot as plt - plt.plot(torch.linspace(0, 0.002, steps = 100), self.tracked_metrics["inlier_counts"].cpu()) + + plt.plot( + torch.linspace(0, 0.002, steps=100), + self.tracked_metrics["inlier_counts"].cpu(), + ) import numpy as np - x = np.linspace(0,0.002, 100) + + x = np.linspace(0, 0.002, 100) sigma = 0.52 * 2 / 512 - F = 1 - np.exp(-x**2 / (2*sigma**2)) + F = 1 - np.exp(-(x**2) / (2 * sigma**2)) plt.plot(x, F) plt.savefig("vis/inlier_counts") - [print(name, metric.item() * self.N / (idx+1)) for name, metric in self.tracked_metrics.items() if "percent" in name] \ No newline at end of file + [ + print(name, metric.item() * self.N / (idx + 1)) + for name, metric in self.tracked_metrics.items() + if "percent" in name + ] diff --git a/third_party/DeDoDe/DeDoDe/checkpoint.py b/third_party/DeDoDe/DeDoDe/checkpoint.py index 07d6f80ae09acf5702475504a8e8d61f40c21cd3..6429ca8b6999a133455bb9e271618f50be4a0ed8 100644 --- a/third_party/DeDoDe/DeDoDe/checkpoint.py +++ b/third_party/DeDoDe/DeDoDe/checkpoint.py @@ -6,6 +6,7 @@ import gc import DeDoDe + class CheckPoint: def __init__(self, dir=None, name="tmp"): self.name = name @@ -18,7 +19,7 @@ class CheckPoint: optimizer, lr_scheduler, n, - ): + ): if DeDoDe.RANK == 0: assert model is not None 
if isinstance(model, (DataParallel, DistributedDataParallel)): @@ -31,14 +32,14 @@ class CheckPoint: } torch.save(states, self.dir + self.name + f"_latest.pth") print(f"Saved states {list(states.keys())}, at step {n}") - + def load( self, model, optimizer, lr_scheduler, n, - ): + ): if os.path.exists(self.dir + self.name + f"_latest.pth") and DeDoDe.RANK == 0: states = torch.load(self.dir + self.name + f"_latest.pth") if "model" in states: @@ -56,4 +57,4 @@ class CheckPoint: del states gc.collect() torch.cuda.empty_cache() - return model, optimizer, lr_scheduler, n \ No newline at end of file + return model, optimizer, lr_scheduler, n diff --git a/third_party/DeDoDe/DeDoDe/datasets/megadepth.py b/third_party/DeDoDe/DeDoDe/datasets/megadepth.py index 7de9d9a8e270fb74a6591944878c0e5e70ddf650..70d76d471c0d0bd5b8545e28ea06a7d178a1abf6 100644 --- a/third_party/DeDoDe/DeDoDe/datasets/megadepth.py +++ b/third_party/DeDoDe/DeDoDe/datasets/megadepth.py @@ -10,6 +10,7 @@ from DeDoDe.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops import DeDoDe from DeDoDe.utils import * + class MegadepthScene: def __init__( self, @@ -23,14 +24,16 @@ class MegadepthScene: scene_info_detections=None, scene_info_detections3D=None, normalize=True, - max_num_pairs = 100_000, - scene_name = None, - use_horizontal_flip_aug = False, - grayscale = False, - clahe = False, + max_num_pairs=100_000, + scene_name=None, + use_horizontal_flip_aug=False, + grayscale=False, + clahe=False, ) -> None: self.data_root = data_root - self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}" + self.scene_name = ( + os.path.splitext(scene_name)[0] + f"_{min_overlap}_{max_overlap}" + ) self.image_paths = scene_info["image_paths"] self.depth_paths = scene_info["depth_paths"] self.intrinsics = scene_info["intrinsics"] @@ -49,7 +52,9 @@ class MegadepthScene: self.pairs = self.pairs[pairinds] self.overlaps = self.overlaps[pairinds] self.im_transform_ops = get_tuple_transform_ops( - resize=(ht, wt), normalize=normalize, clahe = clahe, + resize=(ht, wt), + normalize=normalize, + clahe=clahe, ) self.depth_transform_ops = get_depth_tuple_transform_ops( resize=(ht, wt), normalize=False @@ -62,17 +67,19 @@ class MegadepthScene: def load_im(self, im_B, crop=None): im = Image.open(im_B) return im - - def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): + + def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): im_A = im_A.flip(-1) im_B = im_B.flip(-1) - depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) - flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device) - K_A = flip_mat@K_A - K_B = flip_mat@K_B - + depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) + flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to( + K_A.device + ) + K_A = flip_mat @ K_A + K_B = flip_mat @ K_B + return im_A, im_B, depth_A, depth_B, K_A, K_B - + def load_depth(self, depth_ref, crop=None): depth = np.array(h5py.File(depth_ref, "r")["depth"]) return torch.from_numpy(depth) @@ -87,8 +94,8 @@ class MegadepthScene: def scale_detections(self, detections, wi, hi): sx, sy = self.wt / wi, self.ht / hi - return detections * torch.tensor([[sx,sy]]) - + return detections * torch.tensor([[sx, sy]]) + def rand_shake(self, *things): t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=(2)) return [ @@ -99,18 +106,27 @@ class MegadepthScene: def tracks_to_detections(self, tracks3D, pose, intrinsics, H, W): tracks3D = tracks3D.double() intrinsics = intrinsics.double() 
- bearing_vectors = pose[...,:3,:3] @ tracks3D.mT + pose[...,:3,3:] + bearing_vectors = pose[..., :3, :3] @ tracks3D.mT + pose[..., :3, 3:] hom_pixel_coords = (intrinsics @ bearing_vectors).mT - pixel_coords = hom_pixel_coords[...,:2] / (hom_pixel_coords[...,2:]+1e-12) - legit_detections = (pixel_coords > 0).prod(dim = -1) * (pixel_coords[...,0] < W - 1) * (pixel_coords[...,1] < H - 1) * (tracks3D != 0).prod(dim=-1) + pixel_coords = hom_pixel_coords[..., :2] / (hom_pixel_coords[..., 2:] + 1e-12) + legit_detections = ( + (pixel_coords > 0).prod(dim=-1) + * (pixel_coords[..., 0] < W - 1) + * (pixel_coords[..., 1] < H - 1) + * (tracks3D != 0).prod(dim=-1) + ) return pixel_coords.float(), legit_detections.bool() def __getitem__(self, pair_idx): try: # read intrinsics of original size idx1, idx2 = self.pairs[pair_idx] - K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3) - K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3) + K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape( + 3, 3 + ) + K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape( + 3, 3 + ) # read and compute relative poses T1 = self.poses[idx1] @@ -138,19 +154,23 @@ class MegadepthScene: detections2D_A = self.detections[idx1] detections2D_B = self.detections[idx2] - + K = 10000 - tracks3D_A = torch.zeros(K,3) - tracks3D_B = torch.zeros(K,3) - tracks3D_A[:len(detections2D_A)] = torch.tensor(self.tracks3D[detections2D_A[:K,-1].astype(np.int32)]) - tracks3D_B[:len(detections2D_B)] = torch.tensor(self.tracks3D[detections2D_B[:K,-1].astype(np.int32)]) - - #projs_A, _ = self.tracks_to_detections(tracks3D_A, T1, K1, W_A, H_A) - #tracks3D_B = torch.zeros(K,2) + tracks3D_A = torch.zeros(K, 3) + tracks3D_B = torch.zeros(K, 3) + tracks3D_A[: len(detections2D_A)] = torch.tensor( + self.tracks3D[detections2D_A[:K, -1].astype(np.int32)] + ) + tracks3D_B[: len(detections2D_B)] = torch.tensor( + self.tracks3D[detections2D_B[:K, -1].astype(np.int32)] + ) + + # projs_A, _ = self.tracks_to_detections(tracks3D_A, T1, K1, W_A, H_A) + # tracks3D_B = torch.zeros(K,2) K1 = self.scale_intrinsic(K1, W_A, H_A) K2 = self.scale_intrinsic(K2, W_B, H_B) - + # Process images im_A, im_B = self.im_transform_ops((im_A, im_B)) depth_A, depth_B = self.depth_transform_ops( @@ -159,34 +179,43 @@ class MegadepthScene: [im_A, depth_A], t_A = self.rand_shake(im_A, depth_A) [im_B, depth_B], t_B = self.rand_shake(im_B, depth_B) - detections_A = -torch.ones(K,2) - detections_B = -torch.ones(K,2) - detections_A[:len(self.detections[idx1])] = self.scale_detections(torch.tensor(detections2D_A[:K,:2]), W_A, H_A) + t_A - detections_B[:len(self.detections[idx2])] = self.scale_detections(torch.tensor(detections2D_B[:K,:2]), W_B, H_B) + t_B + detections_A = -torch.ones(K, 2) + detections_B = -torch.ones(K, 2) + detections_A[: len(self.detections[idx1])] = ( + self.scale_detections(torch.tensor(detections2D_A[:K, :2]), W_A, H_A) + + t_A + ) + detections_B[: len(self.detections[idx2])] = ( + self.scale_detections(torch.tensor(detections2D_B[:K, :2]), W_B, H_B) + + t_B + ) - K1[:2, 2] += t_A K2[:2, 2] += t_B - + if self.use_horizontal_flip_aug: if np.random.rand() > 0.5: - im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2) - detections_A[:,0] = W-detections_A - detections_B[:,0] = W-detections_B - + im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip( + im_A, im_B, depth_A, depth_B, K1, K2 + ) + detections_A[:, 0] = W - 
detections_A + detections_B[:, 0] = W - detections_B + if DeDoDe.DEBUG_MODE: - tensor_to_pil(im_A[0], unnormalize=True).save( - f"vis/im_A.jpg") - tensor_to_pil(im_B[0], unnormalize=True).save( - f"vis/im_B.jpg") + tensor_to_pil(im_A[0], unnormalize=True).save(f"vis/im_A.jpg") + tensor_to_pil(im_B[0], unnormalize=True).save(f"vis/im_B.jpg") if self.grayscale: - im_A = im_A.mean(dim=-3,keepdim=True) - im_B = im_B.mean(dim=-3,keepdim=True) + im_A = im_A.mean(dim=-3, keepdim=True) + im_B = im_B.mean(dim=-3, keepdim=True) data_dict = { "im_A": im_A, - "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0], + "im_A_identifier": self.image_paths[idx1] + .split("/")[-1] + .split(".jpg")[0], "im_B": im_B, - "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0], + "im_B_identifier": self.image_paths[idx2] + .split("/")[-1] + .split(".jpg")[0], "im_A_depth": depth_A[0, 0], "im_B_depth": depth_B[0, 0], "pose_A": T1, @@ -211,19 +240,48 @@ class MegadepthScene: class MegadepthBuilder: - def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None: + def __init__( + self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True + ) -> None: self.data_root = data_root self.scene_info_root = os.path.join(data_root, "prep_scene_info") self.all_scenes = os.listdir(self.scene_info_root) self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"] # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those - self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy']) - self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy']) + self.loftr_ignore_scenes = set( + [ + "0121.npy", + "0133.npy", + "0168.npy", + "0178.npy", + "0229.npy", + "0349.npy", + "0412.npy", + "0430.npy", + "0443.npy", + "1001.npy", + "5014.npy", + "5015.npy", + "5016.npy", + ] + ) + self.imc21_scenes = set( + [ + "0008.npy", + "0019.npy", + "0021.npy", + "0024.npy", + "0025.npy", + "0032.npy", + "0063.npy", + "1589.npy", + ] + ) self.test_scenes_loftr = ["0015.npy", "0022.npy"] self.loftr_ignore = loftr_ignore self.imc21_ignore = imc21_ignore - def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs): + def build_scenes(self, split="train", min_overlap=0.0, scene_names=None, **kwargs): if split == "train": scene_names = set(self.all_scenes) - set(self.test_scenes) elif split == "train_loftr": @@ -248,15 +306,27 @@ class MegadepthBuilder: os.path.join(self.scene_info_root, scene_name), allow_pickle=True ).item() scene_info_detections = np.load( - os.path.join(self.scene_info_root, "detections", f"detections_{scene_name}"), allow_pickle=True + os.path.join( + self.scene_info_root, "detections", f"detections_{scene_name}" + ), + allow_pickle=True, ).item() scene_info_detections3D = np.load( - os.path.join(self.scene_info_root, "detections3D", f"detections3D_{scene_name}"), allow_pickle=True + os.path.join( + self.scene_info_root, "detections3D", f"detections3D_{scene_name}" + ), + allow_pickle=True, ) scenes.append( MegadepthScene( - self.data_root, scene_info, scene_info_detections = scene_info_detections, scene_info_detections3D = scene_info_detections3D, min_overlap=min_overlap,scene_name = scene_name, **kwargs + self.data_root, + scene_info, + 
scene_info_detections=scene_info_detections, + scene_info_detections3D=scene_info_detections3D, + min_overlap=min_overlap, + scene_name=scene_name, + **kwargs, ) ) return scenes diff --git a/third_party/DeDoDe/DeDoDe/decoder.py b/third_party/DeDoDe/DeDoDe/decoder.py index 4e1b58fcc588e6ee12c591b5f446829a914bc611..76f6c3b86e309e9f18e5525e132128c2de08c747 100644 --- a/third_party/DeDoDe/DeDoDe/decoder.py +++ b/third_party/DeDoDe/DeDoDe/decoder.py @@ -4,19 +4,26 @@ import torchvision.models as tvm class Decoder(nn.Module): - def __init__(self, layers, *args, super_resolution = False, num_prototypes = 1, **kwargs) -> None: + def __init__( + self, layers, *args, super_resolution=False, num_prototypes=1, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.layers = layers self.scales = self.layers.keys() self.super_resolution = super_resolution self.num_prototypes = num_prototypes - def forward(self, features, context = None, scale = None): + + def forward(self, features, context=None, scale=None): if context is not None: - features = torch.cat((features, context), dim = 1) + features = torch.cat((features, context), dim=1) stuff = self.layers[scale](features) - logits, context = stuff[:,:self.num_prototypes], stuff[:,self.num_prototypes:] + logits, context = ( + stuff[:, : self.num_prototypes], + stuff[:, self.num_prototypes :], + ) return logits, context + class ConvRefiner(nn.Module): def __init__( self, @@ -26,13 +33,16 @@ class ConvRefiner(nn.Module): dw=True, kernel_size=5, hidden_blocks=5, - amp = True, - residual = False, - amp_dtype = torch.float16, + amp=True, + residual=False, + amp_dtype=torch.float16, ): super().__init__() self.block1 = self.create_block( - in_dim, hidden_dim, dw=False, kernel_size=1, + in_dim, + hidden_dim, + dw=False, + kernel_size=1, ) self.hidden_blocks = nn.Sequential( *[ @@ -50,15 +60,15 @@ class ConvRefiner(nn.Module): self.amp = amp self.amp_dtype = amp_dtype self.residual = residual - + def create_block( self, in_dim, out_dim, dw=True, kernel_size=5, - bias = True, - norm_type = nn.BatchNorm2d, + bias=True, + norm_type=nn.BatchNorm2d, ): num_groups = 1 if not dw else in_dim if dw: @@ -74,17 +84,21 @@ class ConvRefiner(nn.Module): groups=num_groups, bias=bias, ) - norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim) + norm = ( + norm_type(out_dim) + if norm_type is nn.BatchNorm2d + else norm_type(num_channels=out_dim) + ) relu = nn.ReLU(inplace=True) conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0) return nn.Sequential(conv1, norm, relu, conv2) - + def forward(self, feats): - b,c,hs,ws = feats.shape - with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + b, c, hs, ws = feats.shape + with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype): x0 = self.block1(feats) x = self.hidden_blocks(x0) if self.residual: - x = (x + x0)/1.4 + x = (x + x0) / 1.4 x = self.out_conv(x) return x diff --git a/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py b/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py index 6d949a1b8ed2a58140af49e8167eda4e4099d022..0f98368f1ee812275726e306f356fdfbefa1663b 100644 --- a/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py +++ b/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py @@ -5,14 +5,18 @@ import torchvision.models as tvm import torch.nn.functional as F import numpy as np + class DeDoDeDescriptor(nn.Module): def __init__(self, encoder, decoder, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.encoder = encoder 
self.decoder = decoder import torchvision.transforms as transforms - self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - + + self.normalizer = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + def forward( self, batch, @@ -26,24 +30,43 @@ class DeDoDeDescriptor(nn.Module): context = None scales = self.decoder.scales for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)): - delta_descriptor, context = self.decoder(feature_map, scale = scale, context = context) + delta_descriptor, context = self.decoder( + feature_map, scale=scale, context=context + ) descriptor = descriptor + delta_descriptor if idx < len(scales) - 1: - size = sizes[-(idx+2)] - descriptor = F.interpolate(descriptor, size = size, mode = "bilinear", align_corners = False) - context = F.interpolate(context, size = size, mode = "bilinear", align_corners = False) - return {"description_grid" : descriptor} - + size = sizes[-(idx + 2)] + descriptor = F.interpolate( + descriptor, size=size, mode="bilinear", align_corners=False + ) + context = F.interpolate( + context, size=size, mode="bilinear", align_corners=False + ) + return {"description_grid": descriptor} + @torch.inference_mode() def describe_keypoints(self, batch, keypoints): self.train(False) description_grid = self.forward(batch)["description_grid"] - described_keypoints = F.grid_sample(description_grid.float(), keypoints[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT + described_keypoints = F.grid_sample( + description_grid.float(), + keypoints[:, None], + mode="bilinear", + align_corners=False, + )[:, :, 0].mT return {"descriptions": described_keypoints} - - def read_image(self, im_path, H = 560, W = 560): - return self.normalizer(torch.from_numpy(np.array(Image.open(im_path).resize((W,H)))/255.).permute(2,0,1)).cuda().float()[None] - def describe_keypoints_from_path(self, im_path, keypoints, H = 768, W = 768): - batch = {"image": self.read_image(im_path, H = H, W = W)} - return self.describe_keypoints(batch, keypoints) \ No newline at end of file + def read_image(self, im_path, H=560, W=560): + return ( + self.normalizer( + torch.from_numpy( + np.array(Image.open(im_path).resize((W, H))) / 255.0 + ).permute(2, 0, 1) + ) + .cuda() + .float()[None] + ) + + def describe_keypoints_from_path(self, im_path, keypoints, H=768, W=768): + batch = {"image": self.read_image(im_path, H=H, W=W)} + return self.describe_keypoints(batch, keypoints) diff --git a/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py b/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py index 494d39ca124941e7a9f870b427c9d1317c01dafc..343ef0cde0fbccdf981634bbdbd2c6b8948d0ee7 100644 --- a/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py +++ b/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py @@ -6,70 +6,107 @@ import torch.nn.functional as F from DeDoDe.utils import * import DeDoDe + class DescriptorLoss(nn.Module): - - def __init__(self, detector, num_keypoints = 5000, normalize_descriptions = False, inv_temp = 1, device = "cuda") -> None: + def __init__( + self, + detector, + num_keypoints=5000, + normalize_descriptions=False, + inv_temp=1, + device="cuda", + ) -> None: super().__init__() self.detector = detector self.tracked_metrics = {} self.num_keypoints = num_keypoints self.normalize_descriptions = normalize_descriptions self.inv_temp = inv_temp - + def warp_from_depth(self, batch, kpts_A, kpts_B): - mask_A_to_B, kpts_A_to_B = warp_kpts(kpts_A, - batch["im_A_depth"], - 
batch["im_B_depth"], - batch["T_1to2"], - batch["K1"], - batch["K2"],) - mask_B_to_A, kpts_B_to_A = warp_kpts(kpts_B, - batch["im_B_depth"], - batch["im_A_depth"], - batch["T_1to2"].inverse(), - batch["K2"], - batch["K1"],) + mask_A_to_B, kpts_A_to_B = warp_kpts( + kpts_A, + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + ) + mask_B_to_A, kpts_B_to_A = warp_kpts( + kpts_B, + batch["im_B_depth"], + batch["im_A_depth"], + batch["T_1to2"].inverse(), + batch["K2"], + batch["K1"], + ) return (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A) - + def warp_from_homog(self, batch, kpts_A, kpts_B): kpts_A_to_B = homog_transform(batch["Homog_A_to_B"], kpts_A) kpts_B_to_A = homog_transform(batch["Homog_A_to_B"].inverse(), kpts_B) return (None, kpts_A_to_B), (None, kpts_B_to_A) def supervised_loss(self, outputs, batch): - kpts_A, kpts_B = self.detector.detect(batch, num_keypoints = self.num_keypoints)['keypoints'].clone().chunk(2) + kpts_A, kpts_B = ( + self.detector.detect(batch, num_keypoints=self.num_keypoints)["keypoints"] + .clone() + .chunk(2) + ) desc_grid_A, desc_grid_B = outputs["description_grid"].chunk(2) - desc_A = F.grid_sample(desc_grid_A.float(), kpts_A[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT - desc_B = F.grid_sample(desc_grid_B.float(), kpts_B[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT + desc_A = F.grid_sample( + desc_grid_A.float(), kpts_A[:, None], mode="bilinear", align_corners=False + )[:, :, 0].mT + desc_B = F.grid_sample( + desc_grid_B.float(), kpts_B[:, None], mode="bilinear", align_corners=False + )[:, :, 0].mT if "im_A_depth" in batch: - (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A) = self.warp_from_depth(batch, kpts_A, kpts_B) + (mask_A_to_B, kpts_A_to_B), ( + mask_B_to_A, + kpts_B_to_A, + ) = self.warp_from_depth(batch, kpts_A, kpts_B) elif "Homog_A_to_B" in batch: - (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A) = self.warp_from_homog(batch, kpts_A, kpts_B) - + (mask_A_to_B, kpts_A_to_B), ( + mask_B_to_A, + kpts_B_to_A, + ) = self.warp_from_homog(batch, kpts_A, kpts_B) + with torch.no_grad(): D_B = torch.cdist(kpts_A_to_B, kpts_B) D_A = torch.cdist(kpts_A, kpts_B_to_A) - inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values) - * (D_A == D_A.min(dim=-2, keepdim = True).values) - * (D_B < 0.01) - * (D_A < 0.01)) - - logP_A_B = dual_log_softmax_matcher(desc_A, desc_B, - normalize = self.normalize_descriptions, - inv_temperature = self.inv_temp) - neg_log_likelihood = -logP_A_B[inds[:,0], inds[:,1], inds[:,2]].mean() + inds = torch.nonzero( + (D_B == D_B.min(dim=-1, keepdim=True).values) + * (D_A == D_A.min(dim=-2, keepdim=True).values) + * (D_B < 0.01) + * (D_A < 0.01) + ) + + logP_A_B = dual_log_softmax_matcher( + desc_A, + desc_B, + normalize=self.normalize_descriptions, + inv_temperature=self.inv_temp, + ) + neg_log_likelihood = -logP_A_B[inds[:, 0], inds[:, 1], inds[:, 2]].mean() if False: import matplotlib.pyplot as plt - inds0 = inds[inds[:,0] == 0] - mnn_A = kpts_A[0,inds0[:,1]].detach().cpu() - mnn_B = kpts_B[0,inds0[:,2]].detach().cpu() - plt.scatter(mnn_A[:,0], -mnn_A[:,1], s = 0.5) + + inds0 = inds[inds[:, 0] == 0] + mnn_A = kpts_A[0, inds0[:, 1]].detach().cpu() + mnn_B = kpts_B[0, inds0[:, 2]].detach().cpu() + plt.scatter(mnn_A[:, 0], -mnn_A[:, 1], s=0.5) plt.savefig("vis/mnn_A.jpg") - self.tracked_metrics["neg_log_likelihood"] = (0.99 * self.tracked_metrics.get("neg_log_likelihood", neg_log_likelihood.detach().item()) + 0.01 * 
neg_log_likelihood.detach().item()) + self.tracked_metrics["neg_log_likelihood"] = ( + 0.99 + * self.tracked_metrics.get( + "neg_log_likelihood", neg_log_likelihood.detach().item() + ) + + 0.01 * neg_log_likelihood.detach().item() + ) if np.random.rand() > 0.99: print(self.tracked_metrics["neg_log_likelihood"]) return neg_log_likelihood - + def forward(self, outputs, batch): losses = self.supervised_loss(outputs, batch) - return losses \ No newline at end of file + return losses diff --git a/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py b/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py index a482d6ddfb8d44de4d00e815b3002f523700390e..dd68212099a2417ca89a562623f670f9f8526b04 100644 --- a/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py +++ b/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py @@ -8,15 +8,17 @@ import numpy as np from DeDoDe.utils import sample_keypoints, to_pixel_coords, to_normalized_coords - class DeDoDeDetector(nn.Module): def __init__(self, encoder, decoder, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.encoder = encoder self.decoder = decoder import torchvision.transforms as transforms - self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - + + self.normalizer = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + def forward( self, batch, @@ -30,24 +32,43 @@ class DeDoDeDetector(nn.Module): context = None scales = ["8", "4", "2", "1"] for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)): - delta_logits, context = self.decoder(feature_map, context = context, scale = scale) - logits = logits + delta_logits.float() # ensure float (need bf16 doesnt have f.interpolate) + delta_logits, context = self.decoder( + feature_map, context=context, scale=scale + ) + logits = ( + logits + delta_logits.float() + ) # ensure float (need bf16 doesnt have f.interpolate) if idx < len(scales) - 1: - size = sizes[-(idx+2)] - logits = F.interpolate(logits, size = size, mode = "bicubic", align_corners = False) - context = F.interpolate(context.float(), size = size, mode = "bilinear", align_corners = False) - return {"keypoint_logits" : logits.float()} - + size = sizes[-(idx + 2)] + logits = F.interpolate( + logits, size=size, mode="bicubic", align_corners=False + ) + context = F.interpolate( + context.float(), size=size, mode="bilinear", align_corners=False + ) + return {"keypoint_logits": logits.float()} + @torch.inference_mode() - def detect(self, batch, num_keypoints = 10_000): + def detect(self, batch, num_keypoints=10_000): self.train(False) keypoint_logits = self.forward(batch)["keypoint_logits"] - B,K,H,W = keypoint_logits.shape - keypoint_p = keypoint_logits.reshape(B, K*H*W).softmax(dim=-1).reshape(B, K, H*W).sum(dim=1) - keypoints, confidence = sample_keypoints(keypoint_p.reshape(B,H,W), - use_nms = False, sample_topk = True, num_samples = num_keypoints, - return_scoremap=True, sharpen = False, upsample = False, - increase_coverage=True) + B, K, H, W = keypoint_logits.shape + keypoint_p = ( + keypoint_logits.reshape(B, K * H * W) + .softmax(dim=-1) + .reshape(B, K, H * W) + .sum(dim=1) + ) + keypoints, confidence = sample_keypoints( + keypoint_p.reshape(B, H, W), + use_nms=False, + sample_topk=True, + num_samples=num_keypoints, + return_scoremap=True, + sharpen=False, + upsample=False, + increase_coverage=True, + ) return {"keypoints": keypoints, "confidence": confidence} @torch.inference_mode() @@ -56,20 +77,26 @@ class 
DeDoDeDetector(nn.Module): keypoint_logits = self.forward(batch)["keypoint_logits"] return {"dense_keypoint_logits": keypoint_logits} - def read_image(self, im_path, H = 560, W = 560): + def read_image(self, im_path, H=560, W=560): pil_im = Image.open(im_path).resize((W, H)) - standard_im = np.array(pil_im)/255. - return self.normalizer(torch.from_numpy(standard_im).permute(2,0,1)).cuda().float()[None] + standard_im = np.array(pil_im) / 255.0 + return ( + self.normalizer(torch.from_numpy(standard_im).permute(2, 0, 1)) + .cuda() + .float()[None] + ) - def detect_from_path(self, im_path, num_keypoints = 30_000, H = 768, W = 768, dense = False): - batch = {"image": self.read_image(im_path, H = H, W = W)} + def detect_from_path( + self, im_path, num_keypoints=30_000, H=768, W=768, dense=False + ): + batch = {"image": self.read_image(im_path, H=H, W=W)} if dense: return self.detect_dense(batch) else: - return self.detect(batch, num_keypoints = num_keypoints) + return self.detect(batch, num_keypoints=num_keypoints) def to_pixel_coords(self, x, H, W): return to_pixel_coords(x, H, W) - + def to_normalized_coords(self, x, H, W): - return to_normalized_coords(x, H, W) \ No newline at end of file + return to_normalized_coords(x, H, W) diff --git a/third_party/DeDoDe/DeDoDe/detectors/loss.py b/third_party/DeDoDe/DeDoDe/detectors/loss.py index 74d47058c82714729a05ea8a3b8433f352af2f4a..924bb896a66034ef45b11420ca6d48a462092ed1 100644 --- a/third_party/DeDoDe/DeDoDe/detectors/loss.py +++ b/third_party/DeDoDe/DeDoDe/detectors/loss.py @@ -5,27 +5,34 @@ import math from DeDoDe.utils import * import DeDoDe + class KeyPointLoss(nn.Module): - - def __init__(self, smoothing_size = 1, use_max_logit = False, entropy_target = 80, - num_matches = 1024, jacobian_density_adjustment = False, - matchability_weight = 1, device = "cuda") -> None: + def __init__( + self, + smoothing_size=1, + use_max_logit=False, + entropy_target=80, + num_matches=1024, + jacobian_density_adjustment=False, + matchability_weight=1, + device="cuda", + ) -> None: super().__init__() - X = torch.linspace(-1,1,smoothing_size, device = device) - G = (-X**2 / (2 *1/2**2)).exp() - G = G/G.sum() + X = torch.linspace(-1, 1, smoothing_size, device=device) + G = (-(X**2) / (2 * 1 / 2**2)).exp() + G = G / G.sum() self.use_max_logit = use_max_logit self.entropy_target = entropy_target - self.smoothing_kernel = G[None, None, None,:] + self.smoothing_kernel = G[None, None, None, :] self.smoothing_size = smoothing_size self.tracked_metrics = {} self.center = None self.num_matches = num_matches self.jacobian_density_adjustment = jacobian_density_adjustment self.matchability_weight = matchability_weight - - def compute_consistency(self, logits_A, logits_B_to_A, mask = None): - + + def compute_consistency(self, logits_A, logits_B_to_A, mask=None): + masked_logits_A = torch.full_like(logits_A, -torch.inf) masked_logits_A[mask] = logits_A[mask] @@ -36,129 +43,186 @@ class KeyPointLoss(nn.Module): log_p_B_to_A = masked_logits_B_to_A.log_softmax(dim=-1)[mask] return self.compute_jensen_shannon_div(log_p_A, log_p_B_to_A) - - def compute_joint_neg_log_likelihood(self, logits_A, logits_B_to_A, detections_A = None, detections_B_to_A = None, mask = None, device = "cuda", dtype = torch.float32, num_matches = None): + + def compute_joint_neg_log_likelihood( + self, + logits_A, + logits_B_to_A, + detections_A=None, + detections_B_to_A=None, + mask=None, + device="cuda", + dtype=torch.float32, + num_matches=None, + ): B, K, HW = logits_A.shape logits_A, logits_B_to_A = 
logits_A.to(dtype), logits_B_to_A.to(dtype) - mask = mask[:,None].expand(B, K, HW).reshape(B, K*HW) - log_p_B_to_A = self.masked_log_softmax(logits_B_to_A.reshape(B,K*HW), mask = mask) - log_p_A = self.masked_log_softmax(logits_A.reshape(B,K*HW), mask = mask) + mask = mask[:, None].expand(B, K, HW).reshape(B, K * HW) + log_p_B_to_A = self.masked_log_softmax( + logits_B_to_A.reshape(B, K * HW), mask=mask + ) + log_p_A = self.masked_log_softmax(logits_A.reshape(B, K * HW), mask=mask) log_p = log_p_A + log_p_B_to_A if detections_A is None: detections_A = torch.zeros_like(log_p_A) if detections_B_to_A is None: detections_B_to_A = torch.zeros_like(log_p_B_to_A) detections_A = detections_A.reshape(B, HW) - detections_A[~mask] = 0 + detections_A[~mask] = 0 detections_B_to_A = detections_B_to_A.reshape(B, HW) detections_B_to_A[~mask] = 0 - log_p_target = log_p.detach() + 50*detections_A + 50*detections_B_to_A + log_p_target = log_p.detach() + 50 * detections_A + 50 * detections_B_to_A num_matches = self.num_matches if num_matches is None else num_matches - best_k = -(-log_p_target).flatten().kthvalue(k = B * num_matches, dim=-1).values - p_target = (log_p_target > best_k[..., None]).float().reshape(B,K*HW)/num_matches - return self.compute_cross_entropy(log_p_A[mask], p_target[mask]) + self.compute_cross_entropy(log_p_B_to_A[mask], p_target[mask]) - + best_k = -(-log_p_target).flatten().kthvalue(k=B * num_matches, dim=-1).values + p_target = (log_p_target > best_k[..., None]).float().reshape( + B, K * HW + ) / num_matches + return self.compute_cross_entropy( + log_p_A[mask], p_target[mask] + ) + self.compute_cross_entropy(log_p_B_to_A[mask], p_target[mask]) + def compute_jensen_shannon_div(self, log_p, log_q): - return 1/2 * (self.compute_kl_div(log_p, log_q) + self.compute_kl_div(log_q, log_p)) - + return ( + 1 + / 2 + * (self.compute_kl_div(log_p, log_q) + self.compute_kl_div(log_q, log_p)) + ) + def compute_kl_div(self, log_p, log_q): - return (log_p.exp()*(log_p-log_q)).sum(dim=-1) - + return (log_p.exp() * (log_p - log_q)).sum(dim=-1) + def masked_log_softmax(self, logits, mask): masked_logits = torch.full_like(logits, -torch.inf) masked_logits[mask] = logits[mask] log_p = masked_logits.log_softmax(dim=-1) return log_p - + def masked_softmax(self, logits, mask): masked_logits = torch.full_like(logits, -torch.inf) masked_logits[mask] = logits[mask] log_p = masked_logits.softmax(dim=-1) return log_p - - def compute_entropy(self, logits, mask = None): + + def compute_entropy(self, logits, mask=None): p = self.masked_softmax(logits, mask)[mask] log_p = self.masked_log_softmax(logits, mask)[mask] - return -(log_p * p).sum(dim=-1) + return -(log_p * p).sum(dim=-1) - def compute_detection_img(self, detections, mask, B, H, W, device = "cuda"): + def compute_detection_img(self, detections, mask, B, H, W, device="cuda"): kernel_size = 5 - X = torch.linspace(-2,2,kernel_size, device = device) - G = (-X**2 / (2 * (1/2)**2)).exp() # half pixel std - G = G/G.sum() - det_smoothing_kernel = G[None, None, None,:] - det_img = torch.zeros((B,1,H,W), device = device) # add small epsilon for later logstuff + X = torch.linspace(-2, 2, kernel_size, device=device) + G = (-(X**2) / (2 * (1 / 2) ** 2)).exp() # half pixel std + G = G / G.sum() + det_smoothing_kernel = G[None, None, None, :] + det_img = torch.zeros( + (B, 1, H, W), device=device + ) # add small epsilon for later logstuff for b in range(B): valid_detections = (detections[b][mask[b]]).int() - det_img[b,0][valid_detections[:,1], valid_detections[:,0]] = 1 
- det_img = F.conv2d(det_img, weight = det_smoothing_kernel, padding = (kernel_size//2, 0)) - det_img = F.conv2d(det_img, weight = det_smoothing_kernel.mT, padding = (0, kernel_size//2)) + det_img[b, 0][valid_detections[:, 1], valid_detections[:, 0]] = 1 + det_img = F.conv2d( + det_img, weight=det_smoothing_kernel, padding=(kernel_size // 2, 0) + ) + det_img = F.conv2d( + det_img, weight=det_smoothing_kernel.mT, padding=(0, kernel_size // 2) + ) return det_img def compute_cross_entropy(self, log_p_hat, p): return -(log_p_hat * p).sum(dim=-1) - def compute_matchability(self, keypoint_p, has_depth, B, K, H, W, device = "cuda"): - smooth_keypoint_p = F.conv2d(keypoint_p.reshape(B,1,H,W), weight = self.smoothing_kernel, padding = (self.smoothing_size//2,0)) - smooth_keypoint_p = F.conv2d(smooth_keypoint_p, weight = self.smoothing_kernel.mT, padding = (0,self.smoothing_size//2)) - log_p_hat = (smooth_keypoint_p+1e-8).log().reshape(B,H*W).log_softmax(dim=-1) - smooth_has_depth = F.conv2d(has_depth.reshape(B,1,H,W), weight = self.smoothing_kernel, padding = (0,self.smoothing_size//2)) - smooth_has_depth = F.conv2d(smooth_has_depth, weight = self.smoothing_kernel.mT, padding = (self.smoothing_size//2,0)).reshape(B,H*W) - p = smooth_has_depth/smooth_has_depth.sum(dim=-1,keepdim=True) - return self.compute_cross_entropy(log_p_hat, p) - self.compute_cross_entropy((p+1e-12).log(), p) + def compute_matchability(self, keypoint_p, has_depth, B, K, H, W, device="cuda"): + smooth_keypoint_p = F.conv2d( + keypoint_p.reshape(B, 1, H, W), + weight=self.smoothing_kernel, + padding=(self.smoothing_size // 2, 0), + ) + smooth_keypoint_p = F.conv2d( + smooth_keypoint_p, + weight=self.smoothing_kernel.mT, + padding=(0, self.smoothing_size // 2), + ) + log_p_hat = ( + (smooth_keypoint_p + 1e-8).log().reshape(B, H * W).log_softmax(dim=-1) + ) + smooth_has_depth = F.conv2d( + has_depth.reshape(B, 1, H, W), + weight=self.smoothing_kernel, + padding=(0, self.smoothing_size // 2), + ) + smooth_has_depth = F.conv2d( + smooth_has_depth, + weight=self.smoothing_kernel.mT, + padding=(self.smoothing_size // 2, 0), + ).reshape(B, H * W) + p = smooth_has_depth / smooth_has_depth.sum(dim=-1, keepdim=True) + return self.compute_cross_entropy(log_p_hat, p) - self.compute_cross_entropy( + (p + 1e-12).log(), p + ) def tracks_to_detections(self, tracks3D, pose, intrinsics, H, W): tracks3D = tracks3D.double() intrinsics = intrinsics.double() - bearing_vectors = pose[:,:3,:3] @ tracks3D.mT + pose[:,:3,3:] + bearing_vectors = pose[:, :3, :3] @ tracks3D.mT + pose[:, :3, 3:] hom_pixel_coords = (intrinsics @ bearing_vectors).mT - pixel_coords = hom_pixel_coords[...,:2] / (hom_pixel_coords[...,2:]+1e-12) - legit_detections = (pixel_coords > 0).prod(dim = -1) * (pixel_coords[...,0] < W - 1) * (pixel_coords[...,1] < H - 1) * (tracks3D != 0).prod(dim=-1) + pixel_coords = hom_pixel_coords[..., :2] / (hom_pixel_coords[..., 2:] + 1e-12) + legit_detections = ( + (pixel_coords > 0).prod(dim=-1) + * (pixel_coords[..., 0] < W - 1) + * (pixel_coords[..., 1] < H - 1) + * (tracks3D != 0).prod(dim=-1) + ) return pixel_coords.float(), legit_detections.bool() - + def self_supervised_loss(self, outputs, batch): keypoint_logits_A, keypoint_logits_B = outputs["keypoint_logits"].chunk(2) B, K, H, W = keypoint_logits_A.shape - keypoint_logits_A = keypoint_logits_A.reshape(B, K, H*W) - keypoint_logits_B = keypoint_logits_B.reshape(B, K, H*W) + keypoint_logits_A = keypoint_logits_A.reshape(B, K, H * W) + keypoint_logits_B = keypoint_logits_B.reshape(B, K, H * W) 
keypoint_logits = torch.cat((keypoint_logits_A, keypoint_logits_B)) - warp_A_to_B, mask_A_to_B = get_homog_warp( - batch["Homog_A_to_B"], H, W - ) + warp_A_to_B, mask_A_to_B = get_homog_warp(batch["Homog_A_to_B"], H, W) warp_B_to_A, mask_B_to_A = get_homog_warp( torch.linalg.inv(batch["Homog_A_to_B"]), H, W ) - B = 2*B - - warp = torch.cat((warp_A_to_B, warp_B_to_A)).reshape(B, H*W, 4) - mask = torch.cat((mask_A_to_B, mask_B_to_A)).reshape(B,H*W) - - keypoint_logits_backwarped = F.grid_sample(torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B,K,H,W), - warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic") - - keypoint_logits_backwarped = keypoint_logits_backwarped.reshape(B,K,H*W) - joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood(keypoint_logits, keypoint_logits_backwarped, - mask = mask.bool(), num_matches = 5_000).mean() + B = 2 * B + + warp = torch.cat((warp_A_to_B, warp_B_to_A)).reshape(B, H * W, 4) + mask = torch.cat((mask_A_to_B, mask_B_to_A)).reshape(B, H * W) + + keypoint_logits_backwarped = F.grid_sample( + torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B, K, H, W), + warp[..., -2:].reshape(B, H, W, 2).float(), + align_corners=False, + mode="bicubic", + ) + + keypoint_logits_backwarped = keypoint_logits_backwarped.reshape(B, K, H * W) + joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood( + keypoint_logits, + keypoint_logits_backwarped, + mask=mask.bool(), + num_matches=5_000, + ).mean() return joint_log_likelihood_loss - + def supervised_loss(self, outputs, batch): keypoint_logits_A, keypoint_logits_B = outputs["keypoint_logits"].chunk(2) B, K, H, W = keypoint_logits_A.shape detections_A, detections_B = batch["detections_A"], batch["detections_B"] - + tracks3D_A, tracks3D_B = batch["tracks3D_A"], batch["tracks3D_B"] - gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp( - batch["im_A_depth"], - batch["im_B_depth"], - batch["T_1to2"], - batch["K1"], - batch["K2"], - H=H, - W=W, - ) - gt_warp_B_to_A, valid_mask_B_to_A = get_gt_warp( + gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp( + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + H=H, + W=W, + ) + gt_warp_B_to_A, valid_mask_B_to_A = get_gt_warp( batch["im_B_depth"], batch["im_A_depth"], batch["T_1to2"].inverse(), @@ -167,103 +231,216 @@ class KeyPointLoss(nn.Module): H=H, W=W, ) - keypoint_logits_A = keypoint_logits_A.reshape(B, K, H*W) - keypoint_logits_B = keypoint_logits_B.reshape(B, K, H*W) + keypoint_logits_A = keypoint_logits_A.reshape(B, K, H * W) + keypoint_logits_B = keypoint_logits_B.reshape(B, K, H * W) keypoint_logits = torch.cat((keypoint_logits_A, keypoint_logits_B)) - B = 2*B + B = 2 * B gt_warp = torch.cat((gt_warp_A_to_B, gt_warp_B_to_A)) valid_mask = torch.cat((valid_mask_A_to_B, valid_mask_B_to_A)) - valid_mask = valid_mask.reshape(B,H*W) + valid_mask = valid_mask.reshape(B, H * W) binary_mask = valid_mask == 1 if self.jacobian_density_adjustment: - j_logdet = jacobi_determinant(gt_warp.reshape(B,H,W,4), valid_mask.reshape(B,H,W).float())[:,None] + j_logdet = jacobi_determinant( + gt_warp.reshape(B, H, W, 4), valid_mask.reshape(B, H, W).float() + )[:, None] else: j_logdet = 0 tracks3D = torch.cat((tracks3D_A, tracks3D_B)) - - #detections, legit_detections = self.tracks_to_detections(tracks3D, torch.cat((batch["pose_A"],batch["pose_B"])), torch.cat((batch["K1"],batch["K2"])), H, W) - #detections_backwarped, legit_backwarped_detections = self.tracks_to_detections(torch.cat((tracks3D_B, 
tracks3D_A)), torch.cat((batch["pose_A"],batch["pose_B"])), torch.cat((batch["K1"],batch["K2"])), H, W) + + # detections, legit_detections = self.tracks_to_detections(tracks3D, torch.cat((batch["pose_A"],batch["pose_B"])), torch.cat((batch["K1"],batch["K2"])), H, W) + # detections_backwarped, legit_backwarped_detections = self.tracks_to_detections(torch.cat((tracks3D_B, tracks3D_A)), torch.cat((batch["pose_A"],batch["pose_B"])), torch.cat((batch["K1"],batch["K2"])), H, W) detections = torch.cat((detections_A, detections_B)) - legit_detections = ((detections > 0).prod(dim = -1) * (detections[...,0] < W) * (detections[...,1] < H)).bool() - det_imgs_A, det_imgs_B = self.compute_detection_img(detections, legit_detections, B, H, W).chunk(2) + legit_detections = ( + (detections > 0).prod(dim=-1) + * (detections[..., 0] < W) + * (detections[..., 1] < H) + ).bool() + det_imgs_A, det_imgs_B = self.compute_detection_img( + detections, legit_detections, B, H, W + ).chunk(2) det_imgs = torch.cat((det_imgs_A, det_imgs_B)) - #det_imgs_backwarped = self.compute_detection_img(detections_backwarped, legit_backwarped_detections, B, H, W) - det_imgs_backwarped = F.grid_sample(torch.cat((det_imgs_B, det_imgs_A)).reshape(B,1,H,W), - gt_warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic") + # det_imgs_backwarped = self.compute_detection_img(detections_backwarped, legit_backwarped_detections, B, H, W) + det_imgs_backwarped = F.grid_sample( + torch.cat((det_imgs_B, det_imgs_A)).reshape(B, 1, H, W), + gt_warp[..., -2:].reshape(B, H, W, 2).float(), + align_corners=False, + mode="bicubic", + ) + + keypoint_logits_backwarped = F.grid_sample( + torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B, K, H, W), + gt_warp[..., -2:].reshape(B, H, W, 2).float(), + align_corners=False, + mode="bicubic", + ) - keypoint_logits_backwarped = F.grid_sample(torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B,K,H,W), - gt_warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic") - # Note: Below step should be taken, but seems difficult to get it to work well. 
- #keypoint_logits_B_to_A = keypoint_logits_B_to_A + j_logdet_A_to_B # adjust for the viewpoint by log jacobian of warp - keypoint_logits_backwarped = (keypoint_logits_backwarped + j_logdet).reshape(B,K,H*W) - - - depth = F.interpolate(torch.cat((batch["im_A_depth"][:,None],batch["im_B_depth"][:,None]),dim=0), size = (H,W), mode = "bilinear", align_corners=False) - has_depth = (depth > 0).float().reshape(B,H*W) - - joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood(keypoint_logits, keypoint_logits_backwarped, - mask = binary_mask, detections_A = det_imgs, - detections_B_to_A = det_imgs_backwarped).mean() - keypoint_p = keypoint_logits.reshape(B, K*H*W).softmax(dim=-1).reshape(B, K, H*W).sum(dim=1) - matchability_loss = self.compute_matchability(keypoint_p, has_depth, B, K, H, W).mean() - - #peakiness_loss = self.compute_negative_peakiness(keypoint_logits.reshape(B,H,W), mask = binary_mask) - #mnn_loss = self.compute_mnn_loss(keypoint_logits_A, keypoint_logits_B, gt_warp_A_to_B, valid_mask_A_to_B, B, H, W) - B = B//2 + # keypoint_logits_B_to_A = keypoint_logits_B_to_A + j_logdet_A_to_B # adjust for the viewpoint by log jacobian of warp + keypoint_logits_backwarped = (keypoint_logits_backwarped + j_logdet).reshape( + B, K, H * W + ) + + depth = F.interpolate( + torch.cat( + (batch["im_A_depth"][:, None], batch["im_B_depth"][:, None]), dim=0 + ), + size=(H, W), + mode="bilinear", + align_corners=False, + ) + has_depth = (depth > 0).float().reshape(B, H * W) + + joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood( + keypoint_logits, + keypoint_logits_backwarped, + mask=binary_mask, + detections_A=det_imgs, + detections_B_to_A=det_imgs_backwarped, + ).mean() + keypoint_p = ( + keypoint_logits.reshape(B, K * H * W) + .softmax(dim=-1) + .reshape(B, K, H * W) + .sum(dim=1) + ) + matchability_loss = self.compute_matchability( + keypoint_p, has_depth, B, K, H, W + ).mean() + + # peakiness_loss = self.compute_negative_peakiness(keypoint_logits.reshape(B,H,W), mask = binary_mask) + # mnn_loss = self.compute_mnn_loss(keypoint_logits_A, keypoint_logits_B, gt_warp_A_to_B, valid_mask_A_to_B, B, H, W) + B = B // 2 import matplotlib.pyplot as plt - kpts_A = sample_keypoints(keypoint_p[:B].reshape(B,H,W), - use_nms = False, sample_topk = True, num_samples = 4*2048) - kpts_B = sample_keypoints(keypoint_p[B:].reshape(B,H,W), - use_nms = False, sample_topk = True, num_samples = 4*2048) - kpts_A_to_B = F.grid_sample(gt_warp_A_to_B[...,2:].float().permute(0,3,1,2), kpts_A[...,None,:], - align_corners=False, mode = 'bilinear')[...,0].mT - legit_A_to_B = F.grid_sample(valid_mask_A_to_B.reshape(B,1,H,W), kpts_A[...,None,:], - align_corners=False, mode = 'bilinear')[...,0,:,0] - percent_inliers = (torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0] < 0.01).float().mean() - self.tracked_metrics["mega_percent_inliers"] = (0.9 * self.tracked_metrics.get("mega_percent_inliers", percent_inliers) + 0.1 * percent_inliers) + + kpts_A = sample_keypoints( + keypoint_p[:B].reshape(B, H, W), + use_nms=False, + sample_topk=True, + num_samples=4 * 2048, + ) + kpts_B = sample_keypoints( + keypoint_p[B:].reshape(B, H, W), + use_nms=False, + sample_topk=True, + num_samples=4 * 2048, + ) + kpts_A_to_B = F.grid_sample( + gt_warp_A_to_B[..., 2:].float().permute(0, 3, 1, 2), + kpts_A[..., None, :], + align_corners=False, + mode="bilinear", + )[..., 0].mT + legit_A_to_B = F.grid_sample( + valid_mask_A_to_B.reshape(B, 1, H, W), + kpts_A[..., None, :], + align_corners=False, + mode="bilinear", 
+ )[..., 0, :, 0] + percent_inliers = ( + ( + torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0] + < 0.01 + ) + .float() + .mean() + ) + self.tracked_metrics["mega_percent_inliers"] = ( + 0.9 * self.tracked_metrics.get("mega_percent_inliers", percent_inliers) + + 0.1 * percent_inliers + ) if torch.rand(1) > 0.995: keypoint_logits_A_to_B = keypoint_logits_backwarped[:B] import matplotlib.pyplot as plt import os - os.makedirs("vis",exist_ok = True) + + os.makedirs("vis", exist_ok=True) for b in range(0, B, 2): - #import cv2 - plt.scatter(kpts_A_to_B[b,:,0].cpu(),-kpts_A_to_B[b,:,1].cpu(), s = 1) - plt.scatter(kpts_B[b,:,0].cpu(),-kpts_B[b,:,1].cpu(), s = 1) - plt.xlim(-1,1) - plt.ylim(-1,1) + # import cv2 + plt.scatter( + kpts_A_to_B[b, :, 0].cpu(), -kpts_A_to_B[b, :, 1].cpu(), s=1 + ) + plt.scatter(kpts_B[b, :, 0].cpu(), -kpts_B[b, :, 1].cpu(), s=1) + plt.xlim(-1, 1) + plt.ylim(-1, 1) plt.savefig(f"vis/keypoints_A_to_B_vs_B_{b}.png") plt.close() - tensor_to_pil(keypoint_logits_A[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), - autoscale = True).save(f"vis/logits_A_{b}.png") - tensor_to_pil(keypoint_logits_B[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), - autoscale = True).save(f"vis/logits_B_{b}.png") - tensor_to_pil(keypoint_logits_A_to_B[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), - autoscale = True).save(f"vis/logits_A_to_B{b}.png") - tensor_to_pil(keypoint_logits_A[b].softmax(dim=-1).reshape(1,H,W).expand(3,H,W).detach().cpu(), - autoscale = True).save(f"vis/keypoint_p_A_{b}.png") - tensor_to_pil(keypoint_logits_B[b].softmax(dim=-1).reshape(1,H,W).expand(3,H,W).detach().cpu(), - autoscale = True).save(f"vis/keypoint_p_B_{b}.png") - tensor_to_pil(has_depth[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), autoscale=True).save(f"vis/has_depth_A_{b}.png") - tensor_to_pil(valid_mask_A_to_B[b].reshape(1,H,W).expand(3,H,W).detach().cpu(), autoscale=True).save(f"vis/valid_mask_A_to_B_{b}.png") - tensor_to_pil(batch['im_A'][b], unnormalize=True).save( - f"vis/im_A_{b}.jpg") - tensor_to_pil(batch['im_B'][b], unnormalize=True).save( - f"vis/im_B_{b}.jpg") + tensor_to_pil( + keypoint_logits_A[b] + .reshape(1, H, W) + .expand(3, H, W) + .detach() + .cpu(), + autoscale=True, + ).save(f"vis/logits_A_{b}.png") + tensor_to_pil( + keypoint_logits_B[b] + .reshape(1, H, W) + .expand(3, H, W) + .detach() + .cpu(), + autoscale=True, + ).save(f"vis/logits_B_{b}.png") + tensor_to_pil( + keypoint_logits_A_to_B[b] + .reshape(1, H, W) + .expand(3, H, W) + .detach() + .cpu(), + autoscale=True, + ).save(f"vis/logits_A_to_B{b}.png") + tensor_to_pil( + keypoint_logits_A[b] + .softmax(dim=-1) + .reshape(1, H, W) + .expand(3, H, W) + .detach() + .cpu(), + autoscale=True, + ).save(f"vis/keypoint_p_A_{b}.png") + tensor_to_pil( + keypoint_logits_B[b] + .softmax(dim=-1) + .reshape(1, H, W) + .expand(3, H, W) + .detach() + .cpu(), + autoscale=True, + ).save(f"vis/keypoint_p_B_{b}.png") + tensor_to_pil( + has_depth[b].reshape(1, H, W).expand(3, H, W).detach().cpu(), + autoscale=True, + ).save(f"vis/has_depth_A_{b}.png") + tensor_to_pil( + valid_mask_A_to_B[b] + .reshape(1, H, W) + .expand(3, H, W) + .detach() + .cpu(), + autoscale=True, + ).save(f"vis/valid_mask_A_to_B_{b}.png") + tensor_to_pil(batch["im_A"][b], unnormalize=True).save( + f"vis/im_A_{b}.jpg" + ) + tensor_to_pil(batch["im_B"][b], unnormalize=True).save( + f"vis/im_B_{b}.jpg" + ) plt.close() - tot_loss = joint_log_likelihood_loss + self.matchability_weight * matchability_loss# - #tot_loss = tot_loss + 
(-2*consistency_loss).detach().exp()*compression_loss + tot_loss = ( + joint_log_likelihood_loss + self.matchability_weight * matchability_loss + ) # + # tot_loss = tot_loss + (-2*consistency_loss).detach().exp()*compression_loss if torch.rand(1) > 1: - print(f"Precent Inlier: {self.tracked_metrics.get('mega_percent_inliers', 0)}") + print( + f"Precent Inlier: {self.tracked_metrics.get('mega_percent_inliers', 0)}" + ) print(f"{joint_log_likelihood_loss=} {matchability_loss=}") print(f"Total Loss: {tot_loss.item()}") - return tot_loss - + return tot_loss + def forward(self, outputs, batch): - + if not isinstance(outputs, list): outputs = [outputs] losses = 0 @@ -272,4 +449,4 @@ class KeyPointLoss(nn.Module): losses = losses + self.self_supervised_loss(output, batch) else: losses = losses + self.supervised_loss(output, batch) - return losses \ No newline at end of file + return losses diff --git a/third_party/DeDoDe/DeDoDe/encoder.py b/third_party/DeDoDe/DeDoDe/encoder.py index faf56c4b6629ce7147b46272ae1f4715e4d10740..2aebb1c5ac890c77d01774ab74caed460c2ff028 100644 --- a/third_party/DeDoDe/DeDoDe/encoder.py +++ b/third_party/DeDoDe/DeDoDe/encoder.py @@ -4,7 +4,7 @@ import torchvision.models as tvm class VGG19(nn.Module): - def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None: + def __init__(self, pretrained=False, amp=False, amp_dtype=torch.float16) -> None: super().__init__() self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) # Maxpool layers: 6, 13, 26, 39 @@ -12,7 +12,7 @@ class VGG19(nn.Module): self.amp_dtype = amp_dtype def forward(self, x, **kwargs): - with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype): feats = [] sizes = [] for layer in self.layers: @@ -22,21 +22,30 @@ class VGG19(nn.Module): x = layer(x) return feats, sizes + class VGG(nn.Module): - def __init__(self, size = "19", pretrained=False, amp = False, amp_dtype = torch.float16) -> None: + def __init__( + self, size="19", pretrained=False, amp=False, amp_dtype=torch.float16 + ) -> None: super().__init__() if size == "11": - self.layers = nn.ModuleList(tvm.vgg11_bn(pretrained=pretrained).features[:22]) - elif size == "13": - self.layers = nn.ModuleList(tvm.vgg13_bn(pretrained=pretrained).features[:28]) - elif size == "19": - self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) + self.layers = nn.ModuleList( + tvm.vgg11_bn(pretrained=pretrained).features[:22] + ) + elif size == "13": + self.layers = nn.ModuleList( + tvm.vgg13_bn(pretrained=pretrained).features[:28] + ) + elif size == "19": + self.layers = nn.ModuleList( + tvm.vgg19_bn(pretrained=pretrained).features[:40] + ) # Maxpool layers: 6, 13, 26, 39 self.amp = amp self.amp_dtype = amp_dtype def forward(self, x, **kwargs): - with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype): feats = [] sizes = [] for layer in self.layers: diff --git a/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py b/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py index 5cc76cad77ee403d7d5ab729c786982a47fbe6e9..5927cff63be726b842e74647f2beae081d803dca 100644 --- a/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py +++ b/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py @@ -6,33 +6,59 @@ import torch.nn.functional as F import numpy as np from DeDoDe.utils import dual_softmax_matcher, to_pixel_coords, 
to_normalized_coords -class DualSoftMaxMatcher(nn.Module): + +class DualSoftMaxMatcher(nn.Module): @torch.inference_mode() - def match(self, keypoints_A, descriptions_A, - keypoints_B, descriptions_B, P_A = None, P_B = None, - normalize = False, inv_temp = 1, threshold = 0.0): + def match( + self, + keypoints_A, + descriptions_A, + keypoints_B, + descriptions_B, + P_A=None, + P_B=None, + normalize=False, + inv_temp=1, + threshold=0.0, + ): if isinstance(descriptions_A, list): - matches = [self.match(k_A[None], d_A[None], k_B[None], d_B[None], normalize = normalize, - inv_temp = inv_temp, threshold = threshold) - for k_A,d_A,k_B,d_B in - zip(keypoints_A, descriptions_A, keypoints_B, descriptions_B)] + matches = [ + self.match( + k_A[None], + d_A[None], + k_B[None], + d_B[None], + normalize=normalize, + inv_temp=inv_temp, + threshold=threshold, + ) + for k_A, d_A, k_B, d_B in zip( + keypoints_A, descriptions_A, keypoints_B, descriptions_B + ) + ] matches_A = torch.cat([m[0] for m in matches]) matches_B = torch.cat([m[1] for m in matches]) inds = torch.cat([m[2] + b for b, m in enumerate(matches)]) return matches_A, matches_B, inds - - P = dual_softmax_matcher(descriptions_A, descriptions_B, - normalize = normalize, inv_temperature=inv_temp, - ) - inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values) - * (P == P.max(dim=-2, keepdim = True).values) * (P > threshold)) - batch_inds = inds[:,0] - matches_A = keypoints_A[batch_inds, inds[:,1]] - matches_B = keypoints_B[batch_inds, inds[:,2]] + + P = dual_softmax_matcher( + descriptions_A, + descriptions_B, + normalize=normalize, + inv_temperature=inv_temp, + ) + inds = torch.nonzero( + (P == P.max(dim=-1, keepdim=True).values) + * (P == P.max(dim=-2, keepdim=True).values) + * (P > threshold) + ) + batch_inds = inds[:, 0] + matches_A = keypoints_A[batch_inds, inds[:, 1]] + matches_B = keypoints_B[batch_inds, inds[:, 2]] return matches_A, matches_B, batch_inds def to_pixel_coords(self, x_A, x_B, H_A, W_A, H_B, W_B): return to_pixel_coords(x_A, H_A, W_A), to_pixel_coords(x_B, H_B, W_B) - + def to_normalized_coords(self, x_A, x_B, H_A, W_A, H_B, W_B): - return to_normalized_coords(x_A, H_A, W_A), to_normalized_coords(x_B, H_B, W_B) \ No newline at end of file + return to_normalized_coords(x_A, H_A, W_A), to_normalized_coords(x_B, H_B, W_B) diff --git a/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py b/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py index b500da585ffd0216e2e434a2179f3045f485dbfb..6296a2833d1dd18c9d52ba45dc6649ff383dfb6f 100644 --- a/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py +++ b/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py @@ -1,3 +1 @@ from .dedode_models import dedode_detector_B, dedode_detector_L, dedode_descriptor_B - - \ No newline at end of file diff --git a/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py b/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py index f43dd22f0d59dabd18eef4beae4a3637dcd8912b..8c6d93d4b6d3a7c0daaf767fa53cd021f248dacd 100644 --- a/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py +++ b/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py @@ -7,8 +7,7 @@ from DeDoDe.decoder import ConvRefiner, Decoder from DeDoDe.encoder import VGG19, VGG - -def dedode_detector_B(device = "cuda", weights = None): +def dedode_detector_B(device="cuda", weights=None): residual = True hidden_blocks = 5 amp_dtype = torch.float16 @@ -20,55 +19,55 @@ def dedode_detector_B(device = "cuda", weights = None): 512, 512, 256 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp 
= amp, - amp_dtype = amp_dtype, + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), "4": ConvRefiner( - 256+256, + 256 + 256, 256, 128 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, - + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), "2": ConvRefiner( - 128+128, + 128 + 128, 64, 32 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, - + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), "1": ConvRefiner( 64 + 32, 32, 1 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), } ) - encoder = VGG19(pretrained = False, amp = amp, amp_dtype = amp_dtype) + encoder = VGG19(pretrained=False, amp=amp, amp_dtype=amp_dtype) decoder = Decoder(conv_refiner) - model = DeDoDeDetector(encoder = encoder, decoder = decoder).to(device) + model = DeDoDeDetector(encoder=encoder, decoder=decoder).to(device) if weights is not None: model.load_state_dict(weights) return model -def dedode_detector_L(device = "cuda", weights = None): +def dedode_detector_L(device="cuda", weights=None): NUM_PROTOTYPES = 1 residual = True hidden_blocks = 8 - amp_dtype = torch.float16#torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + amp_dtype = ( + torch.float16 + ) # torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 amp = True conv_refiner = nn.ModuleDict( { @@ -76,56 +75,55 @@ def dedode_detector_L(device = "cuda", weights = None): 512, 512, 256 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), "4": ConvRefiner( - 256+256, + 256 + 256, 256, 128 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, - + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), "2": ConvRefiner( - 128+128, + 128 + 128, 128, 64 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, - + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), "1": ConvRefiner( 64 + 64, 64, 1 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), } ) - encoder = VGG19(pretrained = False, amp = amp, amp_dtype = amp_dtype) + encoder = VGG19(pretrained=False, amp=amp, amp_dtype=amp_dtype) decoder = Decoder(conv_refiner) - model = DeDoDeDetector(encoder = encoder, decoder = decoder).to(device) + model = DeDoDeDetector(encoder=encoder, decoder=decoder).to(device) if weights is not None: model.load_state_dict(weights) return model - -def dedode_descriptor_B(device = "cuda", weights = None): - NUM_PROTOTYPES = 256 # == descriptor size +def dedode_descriptor_B(device="cuda", weights=None): + NUM_PROTOTYPES = 256 # == descriptor size residual = True hidden_blocks = 5 - amp_dtype = torch.float16#torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + amp_dtype = ( + torch.float16 + ) # torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 amp = True conv_refiner = nn.ModuleDict( 
{ @@ -133,45 +131,43 @@ def dedode_descriptor_B(device = "cuda", weights = None): 512, 512, 256 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), "4": ConvRefiner( - 256+256, + 256 + 256, 256, 128 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, - + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), "2": ConvRefiner( - 128+128, + 128 + 128, 64, 32 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, - + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), "1": ConvRefiner( 64 + 32, 32, 1 + NUM_PROTOTYPES, - hidden_blocks = hidden_blocks, - residual = residual, - amp = amp, - amp_dtype = amp_dtype, + hidden_blocks=hidden_blocks, + residual=residual, + amp=amp, + amp_dtype=amp_dtype, ), } ) - encoder = VGG(size = "19", pretrained = False, amp = amp, amp_dtype = amp_dtype) + encoder = VGG(size="19", pretrained=False, amp=amp, amp_dtype=amp_dtype) decoder = Decoder(conv_refiner, num_prototypes=NUM_PROTOTYPES) - model = DeDoDeDescriptor(encoder = encoder, decoder = decoder).to(device) + model = DeDoDeDescriptor(encoder=encoder, decoder=decoder).to(device) if weights is not None: model.load_state_dict(weights) return model diff --git a/third_party/DeDoDe/DeDoDe/train.py b/third_party/DeDoDe/DeDoDe/train.py index 348f268d6f7752bdf2ad45ba1851ec13a57825a0..2572e3a726d16ffef1bb734feeba0a7a19f4d354 100644 --- a/third_party/DeDoDe/DeDoDe/train.py +++ b/third_party/DeDoDe/DeDoDe/train.py @@ -3,7 +3,7 @@ from tqdm import tqdm from DeDoDe.utils import to_cuda -def train_step(train_batch, model, objective, optimizer, grad_scaler = None,**kwargs): +def train_step(train_batch, model, objective, optimizer, grad_scaler=None, **kwargs): optimizer.zero_grad() out = model(train_batch) l = objective(out, train_batch) @@ -20,9 +20,17 @@ def train_step(train_batch, model, objective, optimizer, grad_scaler = None,**kw def train_k_steps( - n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler = None, progress_bar=True + n_0, + k, + dataloader, + model, + objective, + optimizer, + lr_scheduler, + grad_scaler=None, + progress_bar=True, ): - for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar, mininterval = 10.): + for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar, mininterval=10.0): batch = next(dataloader) model.train(True) batch = to_cuda(batch) @@ -33,7 +41,7 @@ def train_k_steps( optimizer=optimizer, lr_scheduler=lr_scheduler, n=n, - grad_scaler = grad_scaler, + grad_scaler=grad_scaler, ) lr_scheduler.step() diff --git a/third_party/DeDoDe/DeDoDe/utils.py b/third_party/DeDoDe/DeDoDe/utils.py index 183c35f5606301720adffa2b7b25e7996404e1a1..1076a06b98ac5ce74f847e75fff86d2a913f9348 100644 --- a/third_party/DeDoDe/DeDoDe/utils.py +++ b/third_party/DeDoDe/DeDoDe/utils.py @@ -11,13 +11,14 @@ from einops import rearrange import torch from time import perf_counter + def recover_pose(E, kpts0, kpts1, K0, K1, mask): best_num_inliers = 0 - K0inv = np.linalg.inv(K0[:2,:2]) - K1inv = np.linalg.inv(K1[:2,:2]) + K0inv = np.linalg.inv(K0[:2, :2]) + K1inv = np.linalg.inv(K1[:2, :2]) - kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T - kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T + kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T + kpts1_n = (K1inv @ (kpts1 - 
K1[None, :2, 2]).T).T for _E in np.split(E, len(E) / 3): n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) @@ -27,17 +28,16 @@ def recover_pose(E, kpts0, kpts1, K0, K1, mask): return ret - # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py # --- GEOMETRY --- def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): if len(kpts0) < 5: return None - K0inv = np.linalg.inv(K0[:2,:2]) - K1inv = np.linalg.inv(K1[:2,:2]) + K0inv = np.linalg.inv(K0[:2, :2]) + K1inv = np.linalg.inv(K1[:2, :2]) - kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T - kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T + kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T E, mask = cv2.findEssentialMat( kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf ) @@ -54,150 +54,213 @@ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): return ret -def get_grid(B,H,W, device = "cuda"): +def get_grid(B, H, W, device="cuda"): x1_n = torch.meshgrid( - *[ - torch.linspace( - -1 + 1 / n, 1 - 1 / n, n, device=device - ) - for n in (B, H, W) - ] + *[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (B, H, W)] ) x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) return x1_n + @torch.no_grad() -def finite_diff_hessian(f: tuple(["B", "H", "W"]), device = "cuda"): - dxx = torch.tensor([[0,0,0],[1,-2,1],[0,0,0]], device = device)[None,None]/2 - dxy = torch.tensor([[1,0,-1],[0,0,0],[-1,0,1]], device = device)[None,None]/4 +def finite_diff_hessian(f: tuple(["B", "H", "W"]), device="cuda"): + dxx = ( + torch.tensor([[0, 0, 0], [1, -2, 1], [0, 0, 0]], device=device)[None, None] / 2 + ) + dxy = ( + torch.tensor([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], device=device)[None, None] / 4 + ) dyy = dxx.mT - Hxx = F.conv2d(f[:,None], dxx, padding = 1)[:,0] - Hxy = F.conv2d(f[:,None], dxy, padding = 1)[:,0] - Hyy = F.conv2d(f[:,None], dyy, padding = 1)[:,0] - H = torch.stack((Hxx, Hxy, Hxy, Hyy), dim = -1).reshape(*f.shape,2,2) + Hxx = F.conv2d(f[:, None], dxx, padding=1)[:, 0] + Hxy = F.conv2d(f[:, None], dxy, padding=1)[:, 0] + Hyy = F.conv2d(f[:, None], dyy, padding=1)[:, 0] + H = torch.stack((Hxx, Hxy, Hxy, Hyy), dim=-1).reshape(*f.shape, 2, 2) return H -def finite_diff_grad(f: tuple(["B", "H", "W"]), device = "cuda"): - dx = torch.tensor([[0,0,0],[-1,0,1],[0,0,0]],device = device)[None,None]/2 + +def finite_diff_grad(f: tuple(["B", "H", "W"]), device="cuda"): + dx = torch.tensor([[0, 0, 0], [-1, 0, 1], [0, 0, 0]], device=device)[None, None] / 2 dy = dx.mT - gx = F.conv2d(f[:,None], dx, padding = 1) - gy = F.conv2d(f[:,None], dy, padding = 1) - g = torch.cat((gx, gy), dim = 1) + gx = F.conv2d(f[:, None], dx, padding=1) + gy = F.conv2d(f[:, None], dy, padding=1) + g = torch.cat((gx, gy), dim=1) return g -def fast_inv_2x2(matrix: tuple[...,2,2], eps = 1e-10): - return 1/(torch.linalg.det(matrix)[...,None,None]+eps) * torch.stack((matrix[...,1,1],-matrix[...,0,1], - -matrix[...,1,0],matrix[...,0,0]),dim=-1).reshape(*matrix.shape) -def newton_step(f:tuple["B","H","W"], inds, device = "cuda"): - B,H,W = f.shape - Hess = finite_diff_hessian(f).reshape(B,H*W,2,2) - Hess = torch.gather(Hess, dim = 1, index = inds[...,None].expand(B,-1,2,2)) - grad = finite_diff_grad(f).reshape(B,H*W,2) - grad = torch.gather(grad, dim = 1, index = inds) - Hessinv = fast_inv_2x2(Hess-torch.eye(2, device = device)[None,None]) - step = (Hessinv @ grad[...,None]) - 
return step[...,0] +def fast_inv_2x2(matrix: tuple[..., 2, 2], eps=1e-10): + return ( + 1 + / (torch.linalg.det(matrix)[..., None, None] + eps) + * torch.stack( + ( + matrix[..., 1, 1], + -matrix[..., 0, 1], + -matrix[..., 1, 0], + matrix[..., 0, 0], + ), + dim=-1, + ).reshape(*matrix.shape) + ) + + +def newton_step(f: tuple["B", "H", "W"], inds, device="cuda"): + B, H, W = f.shape + Hess = finite_diff_hessian(f).reshape(B, H * W, 2, 2) + Hess = torch.gather(Hess, dim=1, index=inds[..., None].expand(B, -1, 2, 2)) + grad = finite_diff_grad(f).reshape(B, H * W, 2) + grad = torch.gather(grad, dim=1, index=inds) + Hessinv = fast_inv_2x2(Hess - torch.eye(2, device=device)[None, None]) + step = Hessinv @ grad[..., None] + return step[..., 0] + @torch.no_grad() -def sample_keypoints(scoremap, num_samples = 8192, device = "cuda", use_nms = True, - sample_topk = False, return_scoremap = False, sharpen = False, upsample = False, - increase_coverage = False,): - #scoremap = scoremap**2 - log_scoremap = (scoremap+1e-10).log() +def sample_keypoints( + scoremap, + num_samples=8192, + device="cuda", + use_nms=True, + sample_topk=False, + return_scoremap=False, + sharpen=False, + upsample=False, + increase_coverage=False, +): + # scoremap = scoremap**2 + log_scoremap = (scoremap + 1e-10).log() if upsample: - log_scoremap = F.interpolate(log_scoremap[:,None], scale_factor = 3, mode = "bicubic", align_corners = False)[:,0]#.clamp(min = 0) + log_scoremap = F.interpolate( + log_scoremap[:, None], scale_factor=3, mode="bicubic", align_corners=False + )[ + :, 0 + ] # .clamp(min = 0) scoremap = log_scoremap.exp() - B,H,W = scoremap.shape + B, H, W = scoremap.shape if increase_coverage: - weights = (-torch.linspace(-2, 2, steps = 51, device = device)**2).exp()[None,None] + weights = (-torch.linspace(-2, 2, steps=51, device=device) ** 2).exp()[ + None, None + ] # 10000 is just some number for maybe numerical stability, who knows. 
:), result is invariant anyway - local_density_x = F.conv2d((scoremap[:,None]+1e-6)*10000,weights[...,None,:], padding = (0,51//2)) - local_density = F.conv2d(local_density_x, weights[...,None], padding = (51//2,0))[:,0] - scoremap = scoremap * (local_density+1e-8)**(-1/2) - grid = get_grid(B,H,W, device=device).reshape(B,H*W,2) + local_density_x = F.conv2d( + (scoremap[:, None] + 1e-6) * 10000, + weights[..., None, :], + padding=(0, 51 // 2), + ) + local_density = F.conv2d( + local_density_x, weights[..., None], padding=(51 // 2, 0) + )[:, 0] + scoremap = scoremap * (local_density + 1e-8) ** (-1 / 2) + grid = get_grid(B, H, W, device=device).reshape(B, H * W, 2) if sharpen: - laplace_operator = torch.tensor([[[[0,1,0],[1,-4,1],[0,1,0]]]], device = device)/4 - scoremap = scoremap[:,None] - 0.5 * F.conv2d(scoremap[:,None], weight = laplace_operator, padding = 1) - scoremap = scoremap[:,0].clamp(min = 0) + laplace_operator = ( + torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], device=device) / 4 + ) + scoremap = scoremap[:, None] - 0.5 * F.conv2d( + scoremap[:, None], weight=laplace_operator, padding=1 + ) + scoremap = scoremap[:, 0].clamp(min=0) if use_nms: - scoremap = scoremap * (scoremap == F.max_pool2d(scoremap, (3, 3), stride = 1, padding = 1)) + scoremap = scoremap * ( + scoremap == F.max_pool2d(scoremap, (3, 3), stride=1, padding=1) + ) if sample_topk: - inds = torch.topk(scoremap.reshape(B,H*W), k = num_samples).indices + inds = torch.topk(scoremap.reshape(B, H * W), k=num_samples).indices else: - inds = torch.multinomial(scoremap.reshape(B,H*W), num_samples = num_samples, replacement=False) - kps = torch.gather(grid, dim = 1, index = inds[...,None].expand(B,num_samples,2)) + inds = torch.multinomial( + scoremap.reshape(B, H * W), num_samples=num_samples, replacement=False + ) + kps = torch.gather(grid, dim=1, index=inds[..., None].expand(B, num_samples, 2)) if return_scoremap: - return kps, torch.gather(scoremap.reshape(B,H*W), dim = 1, index = inds) + return kps, torch.gather(scoremap.reshape(B, H * W), dim=1, index=inds) return kps + @torch.no_grad() -def jacobi_determinant(warp, certainty, R = 3, device = "cuda", dtype = torch.float32): +def jacobi_determinant(warp, certainty, R=3, device="cuda", dtype=torch.float32): t = perf_counter() *dims, _ = warp.shape warp = warp.to(dtype) certainty = certainty.to(dtype) - + dtype = warp.dtype - match_regions = torch.zeros((*dims, 4, R, R), device = device).to(dtype) - match_regions[:,1:-1, 1:-1] = warp.unfold(1,R,1).unfold(2,R,1) - match_regions = rearrange(match_regions,"B H W D R1 R2 -> B H W (R1 R2) D") - warp[...,None,:] - - match_regions_cert = torch.zeros((*dims, R, R), device = device).to(dtype) - match_regions_cert[:,1:-1, 1:-1] = certainty.unfold(1,R,1).unfold(2,R,1) - match_regions_cert = rearrange(match_regions_cert,"B H W R1 R2 -> B H W (R1 R2)")[..., None] - - #print("Time for unfold", perf_counter()-t) - #t = perf_counter() + match_regions = torch.zeros((*dims, 4, R, R), device=device).to(dtype) + match_regions[:, 1:-1, 1:-1] = warp.unfold(1, R, 1).unfold(2, R, 1) + match_regions = ( + rearrange(match_regions, "B H W D R1 R2 -> B H W (R1 R2) D") + - warp[..., None, :] + ) + + match_regions_cert = torch.zeros((*dims, R, R), device=device).to(dtype) + match_regions_cert[:, 1:-1, 1:-1] = certainty.unfold(1, R, 1).unfold(2, R, 1) + match_regions_cert = rearrange(match_regions_cert, "B H W R1 R2 -> B H W (R1 R2)")[ + ..., None + ] + + # print("Time for unfold", perf_counter()-t) + # t = perf_counter() *dims, N, D = 
match_regions.shape # standardize: - mu, sigma = match_regions.mean(dim=(-2,-1), keepdim = True), match_regions.std(dim=(-2,-1),keepdim=True) - match_regions = (match_regions-mu)/(sigma+1e-6) - x_a, x_b = match_regions.chunk(2,-1) - + mu, sigma = match_regions.mean(dim=(-2, -1), keepdim=True), match_regions.std( + dim=(-2, -1), keepdim=True + ) + match_regions = (match_regions - mu) / (sigma + 1e-6) + x_a, x_b = match_regions.chunk(2, -1) - A = torch.zeros((*dims,2*x_a.shape[-2],4), device = device).to(dtype) - A[...,::2,:2] = x_a * match_regions_cert - A[...,1::2,2:] = x_a * match_regions_cert + A = torch.zeros((*dims, 2 * x_a.shape[-2], 4), device=device).to(dtype) + A[..., ::2, :2] = x_a * match_regions_cert + A[..., 1::2, 2:] = x_a * match_regions_cert - a_block = A[...,::2,:2] + a_block = A[..., ::2, :2] ata = a_block.mT @ a_block - #print("Time for ata", perf_counter()-t) - #t = perf_counter() + # print("Time for ata", perf_counter()-t) + # t = perf_counter() - #atainv = torch.linalg.inv(ata+1e-5*torch.eye(2,device=device).to(dtype)) + # atainv = torch.linalg.inv(ata+1e-5*torch.eye(2,device=device).to(dtype)) atainv = fast_inv_2x2(ata) - ATA_inv = torch.zeros((*dims, 4, 4), device = device, dtype = dtype) - ATA_inv[...,:2,:2] = atainv - ATA_inv[...,2:,2:] = atainv - atb = A.mT @ (match_regions_cert*x_b).reshape(*dims,N*2,1) - theta = ATA_inv @ atb - #print("Time for theta", perf_counter()-t) - #t = perf_counter() + ATA_inv = torch.zeros((*dims, 4, 4), device=device, dtype=dtype) + ATA_inv[..., :2, :2] = atainv + ATA_inv[..., 2:, 2:] = atainv + atb = A.mT @ (match_regions_cert * x_b).reshape(*dims, N * 2, 1) + theta = ATA_inv @ atb + # print("Time for theta", perf_counter()-t) + # t = perf_counter() J = theta.reshape(*dims, 2, 2) - abs_J_det = torch.linalg.det(J+1e-8*torch.eye(2,2,device = device).expand(*dims,2,2)).abs() # Note: This should always be positive for correct warps, but still taking abs here - abs_J_logdet = (abs_J_det+1e-12).log() + abs_J_det = torch.linalg.det( + J + 1e-8 * torch.eye(2, 2, device=device).expand(*dims, 2, 2) + ).abs() # Note: This should always be positive for correct warps, but still taking abs here + abs_J_logdet = (abs_J_det + 1e-12).log() B = certainty.shape[0] # Handle outliers - robust_abs_J_logdet = abs_J_logdet.clamp(-3, 3) # Shouldn't be more that exp(3) \approx 8 times zoom - #print("Time for logdet", perf_counter()-t) - #t = perf_counter() + robust_abs_J_logdet = abs_J_logdet.clamp( + -3, 3 + ) # Shouldn't be more that exp(3) \approx 8 times zoom + # print("Time for logdet", perf_counter()-t) + # t = perf_counter() return robust_abs_J_logdet -def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None): - + +def get_gt_warp( + depth1, + depth2, + T_1to2, + K1, + K2, + depth_interpolation_mode="bilinear", + relative_depth_error_threshold=0.05, + H=None, + W=None, +): + if H is None: - B,H,W = depth1.shape + B, H, W = depth1.shape else: B = depth1.shape[0] with torch.no_grad(): x1_n = torch.meshgrid( *[ - torch.linspace( - -1 + 1 / n, 1 - 1 / n, n, device=depth1.device - ) + torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=depth1.device) for n in (B, H, W) ] ) @@ -209,20 +272,21 @@ def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bili T_1to2.double(), K1.double(), K2.double(), - depth_interpolation_mode = depth_interpolation_mode, - relative_depth_error_threshold = relative_depth_error_threshold, + 
depth_interpolation_mode=depth_interpolation_mode, + relative_depth_error_threshold=relative_depth_error_threshold, ) prob = mask.float().reshape(B, H, W) x2 = x2.reshape(B, H, W, 2) - return torch.cat((x1_n.reshape(B,H,W,2),x2),dim=-1), prob + return torch.cat((x1_n.reshape(B, H, W, 2), x2), dim=-1), prob + def recover_pose(E, kpts0, kpts1, K0, K1, mask): best_num_inliers = 0 - K0inv = np.linalg.inv(K0[:2,:2]) - K1inv = np.linalg.inv(K1[:2,:2]) + K0inv = np.linalg.inv(K0[:2, :2]) + K1inv = np.linalg.inv(K1[:2, :2]) - kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T - kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T + kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T + kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T for _E in np.split(E, len(E) / 3): n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) @@ -232,17 +296,23 @@ def recover_pose(E, kpts0, kpts1, K0, K1, mask): return ret - # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py # --- GEOMETRY --- -def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999, ): +def estimate_pose( + kpts0, + kpts1, + K0, + K1, + norm_thresh, + conf=0.99999, +): if len(kpts0) < 5: return None - K0inv = np.linalg.inv(K0[:2,:2]) - K1inv = np.linalg.inv(K1[:2,:2]) + K0inv = np.linalg.inv(K0[:2, :2]) + K1inv = np.linalg.inv(K1[:2, :2]) - kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T - kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T + kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T method = cv2.USAC_ACCURATE E, mask = cv2.findEssentialMat( kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=method @@ -259,31 +329,40 @@ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999, ): ret = (R, t, mask.ravel() > 0) return ret + def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): if len(kpts0) < 5: return None method = cv2.USAC_ACCURATE F, mask = cv2.findFundamentalMat( - kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000 + kpts0, + kpts1, + ransacReprojThreshold=norm_thresh, + confidence=conf, + method=method, + maxIters=10000, ) - E = K1.T@F@K0 + E = K1.T @ F @ K0 ret = None if E is not None: best_num_inliers = 0 - K0inv = np.linalg.inv(K0[:2,:2]) - K1inv = np.linalg.inv(K1[:2,:2]) + K0inv = np.linalg.inv(K0[:2, :2]) + K1inv = np.linalg.inv(K1[:2, :2]) + + kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T + kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T - kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T - kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T - for _E in np.split(E, len(E) / 3): - n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) + n, R, t, _ = cv2.recoverPose( + _E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask + ) if n > best_num_inliers: best_num_inliers = n ret = (R, t, mask.ravel() > 0) return ret -def unnormalize_coords(x_n,h,w): + +def unnormalize_coords(x_n, h, w): x = torch.stack( (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1 ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] @@ -316,6 +395,7 @@ def scale_intrinsics(K, scales): scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0]) return np.dot(scales, K) + def angle_error_mat(R1, R2): cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds @@ -355,14 +435,16 @@ def pose_auc(errors, thresholds): def get_depth_tuple_transform_ops(resize=None, 
normalize=True, unscale=False): ops = [] if resize: - ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR, antialias = False)) + ops.append( + TupleResize(resize, mode=InterpolationMode.BILINEAR, antialias=False) + ) return TupleCompose(ops) -def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False): +def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe=False): ops = [] if resize: - ops.append(TupleResize(resize, antialias = True)) + ops.append(TupleResize(resize, antialias=True)) if clahe: ops.append(TupleClahe()) if normalize: @@ -377,22 +459,27 @@ def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = ops.append(TupleToTensorScaled()) return TupleCompose(ops) + class Clahe: - def __init__(self, cliplimit = 2, blocksize = 8) -> None: - self.clahe = cv2.createCLAHE(cliplimit,(blocksize,blocksize)) + def __init__(self, cliplimit=2, blocksize=8) -> None: + self.clahe = cv2.createCLAHE(cliplimit, (blocksize, blocksize)) + def __call__(self, im): - im_hsv = cv2.cvtColor(np.array(im),cv2.COLOR_RGB2HSV) - im_v = self.clahe.apply(im_hsv[:,:,2]) - im_hsv[...,2] = im_v - im_clahe = cv2.cvtColor(im_hsv,cv2.COLOR_HSV2RGB) + im_hsv = cv2.cvtColor(np.array(im), cv2.COLOR_RGB2HSV) + im_v = self.clahe.apply(im_hsv[:, :, 2]) + im_hsv[..., 2] = im_v + im_clahe = cv2.cvtColor(im_hsv, cv2.COLOR_HSV2RGB) return Image.fromarray(im_clahe) + class TupleClahe: - def __init__(self, cliplimit = 8, blocksize = 8) -> None: - self.clahe = Clahe(cliplimit,blocksize) + def __init__(self, cliplimit=8, blocksize=8) -> None: + self.clahe = Clahe(cliplimit, blocksize) + def __call__(self, ims): return [self.clahe(im) for im in ims] + class ToTensorScaled(object): """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" @@ -443,9 +530,9 @@ class TupleToTensorUnscaled(object): class TupleResize(object): - def __init__(self, size, mode=InterpolationMode.BICUBIC, antialias = None): + def __init__(self, size, mode=InterpolationMode.BICUBIC, antialias=None): self.size = size - self.resize = transforms.Resize(size, mode, antialias = antialias) + self.resize = transforms.Resize(size, mode, antialias=antialias) def __call__(self, im_tuple): return [self.resize(im) for im in im_tuple] @@ -453,11 +540,12 @@ class TupleResize(object): def __repr__(self): return "TupleResize(size={})".format(self.size) + class Normalize: - def __call__(self,im): - mean = im.mean(dim=(1,2), keepdims=True) - std = im.std(dim=(1,2), keepdims=True) - return (im-mean)/std + def __call__(self, im): + mean = im.mean(dim=(1, 2), keepdims=True) + std = im.std(dim=(1, 2), keepdims=True) + return (im - mean) / std class TupleNormalize(object): @@ -467,7 +555,7 @@ class TupleNormalize(object): self.normalize = transforms.Normalize(mean=mean, std=std) def __call__(self, im_tuple): - c,h,w = im_tuple[0].shape + c, h, w = im_tuple[0].shape if c > 3: warnings.warn(f"Number of channels {c=} > 3, assuming first 3 are rgb") return [self.normalize(im[:3]) for im in im_tuple] @@ -495,7 +583,18 @@ class TupleCompose(object): @torch.no_grad() -def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05): +def warp_kpts( + kpts0, + depth0, + depth1, + T_0to1, + K0, + K1, + smooth_mask=False, + return_relative_depth_error=False, + depth_interpolation_mode="bilinear", + relative_depth_error_threshold=0.05, +): """Warp kpts0 from I0 to I1 
with depth, K and Rt Also check covisibility and depth consistency. Depth is consistent if relative error < 0.2 (hard-coded). @@ -520,26 +619,44 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation if smooth_mask: raise NotImplementedError("Combined bilinear and NN warp not implemented") - valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, - smooth_mask = smooth_mask, - return_relative_depth_error = return_relative_depth_error, - depth_interpolation_mode = "bilinear", - relative_depth_error_threshold = relative_depth_error_threshold) - valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, - smooth_mask = smooth_mask, - return_relative_depth_error = return_relative_depth_error, - depth_interpolation_mode = "nearest-exact", - relative_depth_error_threshold = relative_depth_error_threshold) - nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) + valid_bilinear, warp_bilinear = warp_kpts( + kpts0, + depth0, + depth1, + T_0to1, + K0, + K1, + smooth_mask=smooth_mask, + return_relative_depth_error=return_relative_depth_error, + depth_interpolation_mode="bilinear", + relative_depth_error_threshold=relative_depth_error_threshold, + ) + valid_nearest, warp_nearest = warp_kpts( + kpts0, + depth0, + depth1, + T_0to1, + K0, + K1, + smooth_mask=smooth_mask, + return_relative_depth_error=return_relative_depth_error, + depth_interpolation_mode="nearest-exact", + relative_depth_error_threshold=relative_depth_error_threshold, + ) + nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) warp = warp_bilinear.clone() - warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid] + warp[nearest_valid_bilinear_invalid] = warp_nearest[ + nearest_valid_bilinear_invalid + ] valid = valid_bilinear | valid_nearest return valid, warp - - - kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[ - :, 0, :, 0 - ] + + kpts0_depth = F.grid_sample( + depth0[:, None], + kpts0[:, :, None], + mode=depth_interpolation_mode, + align_corners=False, + )[:, 0, :, 0] kpts0 = torch.stack( (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] @@ -578,22 +695,26 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return # w_kpts0[~covisible_mask, :] = -5 # xd w_kpts0_depth = F.grid_sample( - depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False + depth1[:, None], + w_kpts0[:, :, None], + mode=depth_interpolation_mode, + align_corners=False, )[:, 0, :, 0] - + relative_depth_error = ( (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth ).abs() if not smooth_mask: consistent_mask = relative_depth_error < relative_depth_error_threshold else: - consistent_mask = (-relative_depth_error/smooth_mask).exp() + consistent_mask = (-relative_depth_error / smooth_mask).exp() valid_mask = nonzero_mask * covisible_mask * consistent_mask if return_relative_depth_error: return relative_depth_error, w_kpts0 else: return valid_mask, w_kpts0 + imagenet_mean = torch.tensor([0.485, 0.456, 0.406]) imagenet_std = torch.tensor([0.229, 0.224, 0.225]) @@ -611,15 +732,17 @@ def numpy_to_pil(x: np.ndarray): return Image.fromarray(x) -def tensor_to_pil(x, unnormalize=False, autoscale = False): +def tensor_to_pil(x, 
unnormalize=False, autoscale=False): if unnormalize: - x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device)) + x = x * (imagenet_std[:, None, None].to(x.device)) + ( + imagenet_mean[:, None, None].to(x.device) + ) if autoscale: if x.max() == x.min(): warnings.warn("x max == x min, cant autoscale") else: - x = (x-x.min())/(x.max()-x.min()) - + x = (x - x.min()) / (x.max() - x.min()) + x = x.detach().permute(1, 2, 0).cpu().numpy() x = np.clip(x, 0.0, 1.0) return numpy_to_pil(x) @@ -649,61 +772,57 @@ def compute_relative_pose(R1, t1, R2, t2): trans = -rots @ t1 + t2 return rots, trans + def to_pixel_coords(flow, h1, w1): - flow = ( - torch.stack( - ( - w1 * (flow[..., 0] + 1) / 2, - h1 * (flow[..., 1] + 1) / 2, - ), - axis=-1, - ) + flow = torch.stack( + ( + w1 * (flow[..., 0] + 1) / 2, + h1 * (flow[..., 1] + 1) / 2, + ), + axis=-1, ) return flow + def to_normalized_coords(flow, h1, w1): - flow = ( - torch.stack( - ( - 2 * (flow[..., 0]) / w1 - 1, - 2 * (flow[..., 1]) / h1 - 1, - ), - axis=-1, - ) + flow = torch.stack( + ( + 2 * (flow[..., 0]) / w1 - 1, + 2 * (flow[..., 1]) / h1 - 1, + ), + axis=-1, ) return flow def warp_to_pixel_coords(warp, h1, w1, h2, w2): warp1 = warp[..., :2] - warp1 = ( - torch.stack( - ( - w1 * (warp1[..., 0] + 1) / 2, - h1 * (warp1[..., 1] + 1) / 2, - ), - axis=-1, - ) + warp1 = torch.stack( + ( + w1 * (warp1[..., 0] + 1) / 2, + h1 * (warp1[..., 1] + 1) / 2, + ), + axis=-1, ) warp2 = warp[..., 2:] - warp2 = ( - torch.stack( - ( - w2 * (warp2[..., 0] + 1) / 2, - h2 * (warp2[..., 1] + 1) / 2, - ), - axis=-1, - ) + warp2 = torch.stack( + ( + w2 * (warp2[..., 0] + 1) / 2, + h2 * (warp2[..., 1] + 1) / 2, + ), + axis=-1, ) - return torch.cat((warp1,warp2), dim=-1) + return torch.cat((warp1, warp2), dim=-1) def to_homogeneous(x): - ones = torch.ones_like(x[...,-1:]) - return torch.cat((x, ones), dim = -1) + ones = torch.ones_like(x[..., -1:]) + return torch.cat((x, ones), dim=-1) + + +def from_homogeneous(xh, eps=1e-12): + return xh[..., :-1] / (xh[..., -1:] + eps) -def from_homogeneous(xh, eps = 1e-12): - return xh[...,:-1] / (xh[...,-1:]+eps) def homog_transform(Homog, x): xh = to_homogeneous(x) @@ -711,49 +830,71 @@ def homog_transform(Homog, x): y = from_homogeneous(yh) return y -def get_homog_warp(Homog, H, W, device = "cuda"): - grid = torch.meshgrid(torch.linspace(-1+1/H,1-1/H,H, device = device), torch.linspace(-1+1/W,1-1/W,W, device = device)) - - x_A = torch.stack((grid[1], grid[0]), dim = -1)[None] + +def get_homog_warp(Homog, H, W, device="cuda"): + grid = torch.meshgrid( + torch.linspace(-1 + 1 / H, 1 - 1 / H, H, device=device), + torch.linspace(-1 + 1 / W, 1 - 1 / W, W, device=device), + ) + + x_A = torch.stack((grid[1], grid[0]), dim=-1)[None] x_A_to_B = homog_transform(Homog, x_A) mask = ((x_A_to_B > -1) * (x_A_to_B < 1)).prod(dim=-1).float() - return torch.cat((x_A.expand(*x_A_to_B.shape), x_A_to_B),dim=-1), mask + return torch.cat((x_A.expand(*x_A_to_B.shape), x_A_to_B), dim=-1), mask -def dual_log_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False): + +def dual_log_softmax_matcher( + desc_A: tuple["B", "N", "C"], + desc_B: tuple["B", "M", "C"], + inv_temperature=1, + normalize=False, +): B, N, C = desc_A.shape if normalize: - desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True) - desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True) + desc_A = desc_A / desc_A.norm(dim=-1, keepdim=True) + desc_B = desc_B / desc_B.norm(dim=-1, keepdim=True) corr = torch.einsum("b n 
c, b m c -> b n m", desc_A, desc_B) * inv_temperature else: corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature - logP = corr.log_softmax(dim = -2) + corr.log_softmax(dim= -1) + logP = corr.log_softmax(dim=-2) + corr.log_softmax(dim=-1) return logP -def dual_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False): + +def dual_softmax_matcher( + desc_A: tuple["B", "N", "C"], + desc_B: tuple["B", "M", "C"], + inv_temperature=1, + normalize=False, +): if len(desc_A.shape) < 3: desc_A, desc_B = desc_A[None], desc_B[None] B, N, C = desc_A.shape if normalize: - desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True) - desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True) + desc_A = desc_A / desc_A.norm(dim=-1, keepdim=True) + desc_B = desc_B / desc_B.norm(dim=-1, keepdim=True) corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature else: corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature - P = corr.softmax(dim = -2) * corr.softmax(dim= -1) + P = corr.softmax(dim=-2) * corr.softmax(dim=-1) return P -def conditional_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False): + +def conditional_softmax_matcher( + desc_A: tuple["B", "N", "C"], + desc_B: tuple["B", "M", "C"], + inv_temperature=1, + normalize=False, +): if len(desc_A.shape) < 3: desc_A, desc_B = desc_A[None], desc_B[None] B, N, C = desc_A.shape if normalize: - desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True) - desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True) + desc_A = desc_A / desc_A.norm(dim=-1, keepdim=True) + desc_B = desc_B / desc_B.norm(dim=-1, keepdim=True) corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature else: corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature - P_B_cond_A = corr.softmax(dim = -1) - P_A_cond_B = corr.softmax(dim = -2) - - return P_A_cond_B, P_B_cond_A \ No newline at end of file + P_B_cond_A = corr.softmax(dim=-1) + P_A_cond_B = corr.softmax(dim=-2) + + return P_A_cond_B, P_B_cond_A diff --git a/third_party/DeDoDe/data_prep/prep_keypoints.py b/third_party/DeDoDe/data_prep/prep_keypoints.py index 25713ed7573babadc3a42daa544d85052fc37421..616f91b875879f726218efdfe4bb6dc95297b33a 100644 --- a/third_party/DeDoDe/data_prep/prep_keypoints.py +++ b/third_party/DeDoDe/data_prep/prep_keypoints.py @@ -9,70 +9,64 @@ import os base_path = "data/megadepth" # Remove the trailing / if need be. 
-if base_path[-1] in ['/', '\\']: - base_path = base_path[: - 1] +if base_path[-1] in ["/", "\\"]: + base_path = base_path[:-1] -base_depth_path = os.path.join( - base_path, 'phoenix/S6/zl548/MegaDepth_v1' -) -base_undistorted_sfm_path = os.path.join( - base_path, 'Undistorted_SfM' -) +base_depth_path = os.path.join(base_path, "phoenix/S6/zl548/MegaDepth_v1") +base_undistorted_sfm_path = os.path.join(base_path, "Undistorted_SfM") scene_ids = os.listdir(base_undistorted_sfm_path) for scene_id in scene_ids: - if os.path.exists(f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy"): + if os.path.exists( + f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy" + ): print(f"skipping {scene_id} as it exists") continue undistorted_sparse_path = os.path.join( - base_undistorted_sfm_path, scene_id, 'sparse-txt' + base_undistorted_sfm_path, scene_id, "sparse-txt" ) if not os.path.exists(undistorted_sparse_path): print("sparse path doesnt exist") continue - depths_path = os.path.join( - base_depth_path, scene_id, 'dense0', 'depths' - ) + depths_path = os.path.join(base_depth_path, scene_id, "dense0", "depths") if not os.path.exists(depths_path): print("depths doesnt exist") - + continue - images_path = os.path.join( - base_undistorted_sfm_path, scene_id, 'images' - ) + images_path = os.path.join(base_undistorted_sfm_path, scene_id, "images") if not os.path.exists(images_path): print("images path doesnt exist") continue # Process cameras.txt - if not os.path.exists(os.path.join(undistorted_sparse_path, 'cameras.txt')): + if not os.path.exists(os.path.join(undistorted_sparse_path, "cameras.txt")): print("no cameras") continue - with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f: - raw = f.readlines()[3 :] # skip the header + with open(os.path.join(undistorted_sparse_path, "cameras.txt"), "r") as f: + raw = f.readlines()[3:] # skip the header camera_intrinsics = {} for camera in raw: - camera = camera.split(' ') - camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]] + camera = camera.split(" ") + camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2:]] # Process points3D.txt - with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f: - raw = f.readlines()[3 :] # skip the header + with open(os.path.join(undistorted_sparse_path, "points3D.txt"), "r") as f: + raw = f.readlines()[3:] # skip the header points3D = {} for point3D in raw: - point3D = point3D.split(' ') - points3D[int(point3D[0])] = np.array([ - float(point3D[1]), float(point3D[2]), float(point3D[3]) - ]) - + point3D = point3D.split(" ") + points3D[int(point3D[0])] = np.array( + [float(point3D[1]), float(point3D[2]), float(point3D[3])] + ) + # Process images.txt - with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f: - raw = f.readlines()[4 :] # skip the header + with open(os.path.join(undistorted_sparse_path, "images.txt"), "r") as f: + raw = f.readlines()[4:] # skip the header image_id_to_idx = {} image_names = [] @@ -81,20 +75,22 @@ for scene_id in scene_ids: points3D_id_to_2D = [] n_points3D = [] id_to_detections = {} - for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])): - image = image.split(' ') - points = points.split(' ') + for idx, (image, points) in enumerate(zip(raw[::2], raw[1::2])): + image = image.split(" ") + points = points.split(" ") image_id_to_idx[int(image[0])] = idx - image_name = image[-1].strip('\n') + image_name = image[-1].strip("\n") image_names.append(image_name) - 
raw_pose.append([float(elem) for elem in image[1 : -2]]) + raw_pose.append([float(elem) for elem in image[1:-2]]) camera.append(int(image[-2])) - points_np = np.array(points).astype(np.float32).reshape(len(points)//3, 3) - visible_points = points_np[points_np[:,2] != -1] + points_np = np.array(points).astype(np.float32).reshape(len(points) // 3, 3) + visible_points = points_np[points_np[:, 2] != -1] id_to_detections[idx] = visible_points - np.save(f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy", - id_to_detections) - print(f"{scene_id} done") \ No newline at end of file + np.save( + f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy", + id_to_detections, + ) + print(f"{scene_id} done") diff --git a/third_party/DeDoDe/demo/demo_kpts.py b/third_party/DeDoDe/demo/demo_kpts.py index 270a23b12e2148ce7a438a68ab3ef1135a93a9e6..f0ae36aa4bbe3439e96d7b45bfa809c48b6ebf45 100644 --- a/third_party/DeDoDe/demo/demo_kpts.py +++ b/third_party/DeDoDe/demo/demo_kpts.py @@ -4,17 +4,19 @@ import numpy as np from PIL import Image from DeDoDe import dedode_detector_L -def draw_kpts(im, kpts): - kpts = [cv2.KeyPoint(x,y,1.) for x,y in kpts.cpu().numpy()] + +def draw_kpts(im, kpts): + kpts = [cv2.KeyPoint(x, y, 1.0) for x, y in kpts.cpu().numpy()] im = np.array(im) ret = cv2.drawKeypoints(im, kpts, None) return ret -detector = dedode_detector_L(weights = torch.load("dedode_detector_l.pth")) + +detector = dedode_detector_L(weights=torch.load("dedode_detector_l.pth")) im_path = "assets/im_A.jpg" im = Image.open(im_path) -out = detector.detect_from_path(im_path, num_keypoints = 10_000) -W,H = im.size +out = detector.detect_from_path(im_path, num_keypoints=10_000) +W, H = im.size kps = out["keypoints"] kps = detector.to_pixel_coords(kps, H, W) -Image.fromarray(draw_kpts(im, kps[0])).save("demo/keypoints.png") \ No newline at end of file +Image.fromarray(draw_kpts(im, kps[0])).save("demo/keypoints.png") diff --git a/third_party/DeDoDe/demo/demo_match.py b/third_party/DeDoDe/demo/demo_match.py index 6492392d07a49fcdb7e287b619b404df84521ca8..2ddecc453e1e3d0beb5e832819833209ad431048 100644 --- a/third_party/DeDoDe/demo/demo_match.py +++ b/third_party/DeDoDe/demo/demo_match.py @@ -5,17 +5,18 @@ from DeDoDe.utils import * from PIL import Image import cv2 -def draw_matches(im_A, kpts_A, im_B, kpts_B): - kpts_A = [cv2.KeyPoint(x,y,1.) for x,y in kpts_A.cpu().numpy()] - kpts_B = [cv2.KeyPoint(x,y,1.) for x,y in kpts_B.cpu().numpy()] - matches_A_to_B = [cv2.DMatch(idx, idx, 0.) 
for idx in range(len(kpts_A))] + +def draw_matches(im_A, kpts_A, im_B, kpts_B): + kpts_A = [cv2.KeyPoint(x, y, 1.0) for x, y in kpts_A.cpu().numpy()] + kpts_B = [cv2.KeyPoint(x, y, 1.0) for x, y in kpts_B.cpu().numpy()] + matches_A_to_B = [cv2.DMatch(idx, idx, 0.0) for idx in range(len(kpts_A))] im_A, im_B = np.array(im_A), np.array(im_B) - ret = cv2.drawMatches(im_A, kpts_A, im_B, kpts_B, - matches_A_to_B, None) + ret = cv2.drawMatches(im_A, kpts_A, im_B, kpts_B, matches_A_to_B, None) return ret -detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth")) -descriptor = dedode_descriptor_B(weights = torch.load("dedode_descriptor_B.pth")) + +detector = dedode_detector_L(weights=torch.load("dedode_detector_L.pth")) +descriptor = dedode_descriptor_B(weights=torch.load("dedode_descriptor_B.pth")) matcher = DualSoftMaxMatcher() im_A_path = "assets/im_A.jpg" @@ -26,20 +27,33 @@ W_A, H_A = im_A.size W_B, H_B = im_B.size -detections_A = detector.detect_from_path(im_A_path, num_keypoints = 10_000) +detections_A = detector.detect_from_path(im_A_path, num_keypoints=10_000) keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"] -detections_B = detector.detect_from_path(im_B_path, num_keypoints = 10_000) +detections_B = detector.detect_from_path(im_B_path, num_keypoints=10_000) keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"] -description_A = descriptor.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"] -description_B = descriptor.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"] -matches_A, matches_B, batch_ids = matcher.match(keypoints_A, description_A, - keypoints_B, description_B, - P_A = P_A, P_B = P_B, - normalize = True, inv_temp=20, threshold = 0.1)#Increasing threshold -> fewer matches, fewer outliers +description_A = descriptor.describe_keypoints_from_path(im_A_path, keypoints_A)[ + "descriptions" +] +description_B = descriptor.describe_keypoints_from_path(im_B_path, keypoints_B)[ + "descriptions" +] +matches_A, matches_B, batch_ids = matcher.match( + keypoints_A, + description_A, + keypoints_B, + description_B, + P_A=P_A, + P_B=P_B, + normalize=True, + inv_temp=20, + threshold=0.1, +) # Increasing threshold -> fewer matches, fewer outliers matches_A, matches_B = matcher.to_pixel_coords(matches_A, matches_B, H_A, W_A, H_B, W_B) import cv2 import numpy as np -Image.fromarray(draw_matches(im_A, matches_A[::5], im_B, matches_B[::5])).save("demo/matches.png") \ No newline at end of file +Image.fromarray(draw_matches(im_A, matches_A[::5], im_B, matches_B[::5])).save( + "demo/matches.png" +) diff --git a/third_party/DeDoDe/demo/demo_scoremap.py b/third_party/DeDoDe/demo/demo_scoremap.py index 68af499dbb58e275e227bbdc979b4d1923902df0..1a0a2b2470783c69753960725aee1b689b0cb2cc 100644 --- a/third_party/DeDoDe/demo/demo_scoremap.py +++ b/third_party/DeDoDe/demo/demo_scoremap.py @@ -5,16 +5,20 @@ import numpy as np from DeDoDe import dedode_detector_L from DeDoDe.utils import tensor_to_pil -detector = dedode_detector_L(weights = torch.load("dedode_detector_l.pth")) +detector = dedode_detector_L(weights=torch.load("dedode_detector_l.pth")) H, W = 768, 768 im_path = "assets/im_A.jpg" -out = detector.detect_from_path(im_path, dense = True, H = H, W = W) +out = detector.detect_from_path(im_path, dense=True, H=H, W=W) logit_map = out["dense_keypoint_logits"].clone() min = logit_map.max() - 3 logit_map[logit_map < min] = min -logit_map = (logit_map-min)/(logit_map.max()-min) -logit_map = 
logit_map.cpu()[0].expand(3,H,W) -im_A = torch.tensor(np.array(Image.open(im_path).resize((W,H)))/255.).permute(2,0,1) -tensor_to_pil(logit_map * logit_map + 0.15 * (1-logit_map) * im_A).save("demo/dense_logits.png") +logit_map = (logit_map - min) / (logit_map.max() - min) +logit_map = logit_map.cpu()[0].expand(3, H, W) +im_A = torch.tensor(np.array(Image.open(im_path).resize((W, H))) / 255.0).permute( + 2, 0, 1 +) +tensor_to_pil(logit_map * logit_map + 0.15 * (1 - logit_map) * im_A).save( + "demo/dense_logits.png" +) diff --git a/third_party/DeDoDe/setup.py b/third_party/DeDoDe/setup.py index 18a43e0b69131d3f91229f5a59e9b1d48411d890..94d1fd8ed2e5ac769222afce4f084ac19029a2a4 100644 --- a/third_party/DeDoDe/setup.py +++ b/third_party/DeDoDe/setup.py @@ -3,8 +3,8 @@ from setuptools import setup, find_packages setup( name="DeDoDe", - packages=find_packages(include= ["DeDoDe*"]), + packages=find_packages(include=["DeDoDe*"]), install_requires=open("requirements.txt", "r").read().split("\n"), version="0.0.1", author="Johan Edstedt", -) \ No newline at end of file +) diff --git a/third_party/GlueStick/gluestick/__init__.py b/third_party/GlueStick/gluestick/__init__.py index d3051821ecfb2e18f4b9b4dfb50f35064106eb57..4eaf01e90440afeb485a4743f181dac348ede63d 100644 --- a/third_party/GlueStick/gluestick/__init__.py +++ b/third_party/GlueStick/gluestick/__init__.py @@ -8,11 +8,12 @@ GLUESTICK_ROOT = Path(__file__).parent.parent def get_class(mod_name, base_path, BaseClass): """Get the class object which inherits from BaseClass and is defined in - the module named mod_name, child of base_path. + the module named mod_name, child of base_path. """ import inspect - mod_path = '{}.{}'.format(base_path, mod_name) - mod = __import__(mod_path, fromlist=['']) + + mod_path = "{}.{}".format(base_path, mod_name) + mod = __import__(mod_path, fromlist=[""]) classes = inspect.getmembers(mod, inspect.isclass) # Filter classes defined in the module classes = [c for c in classes if c[1].__module__ == mod_path] @@ -24,7 +25,8 @@ def get_class(mod_name, base_path, BaseClass): def get_model(name): from .models.base_model import BaseModel - return get_class('models.' + name, __name__, BaseModel) + + return get_class("models." + name, __name__, BaseModel) def numpy_image_to_torch(image): @@ -34,8 +36,8 @@ def numpy_image_to_torch(image): elif image.ndim == 2: image = image[None] # add channel axis else: - raise ValueError(f'Not an image: {image.shape}') - return torch.from_numpy(image / 255.).float() + raise ValueError(f"Not an image: {image.shape}") + return torch.from_numpy(image / 255.0).float() def map_tensor(input_, func): diff --git a/third_party/GlueStick/gluestick/drawing.py b/third_party/GlueStick/gluestick/drawing.py index 8e6d24b6bfedc93449142647410057d978d733ef..8365b7e1f91adedcd190c49b2a38cbcd817d84c2 100644 --- a/third_party/GlueStick/gluestick/drawing.py +++ b/third_party/GlueStick/gluestick/drawing.py @@ -4,8 +4,7 @@ import numpy as np import seaborn as sns -def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5, - adaptive=True): +def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True): """Plot a set of images horizontally. Args: imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). 
@@ -23,7 +22,8 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
     ratios = [4 / 3] * n
     figsize = [sum(ratios) * 4.5, 4.5]
     fig, ax = plt.subplots(
-        1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
+        1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+    )
     if n == 1:
         ax = [ax]
     for i in range(n):
@@ -39,7 +39,7 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
     return ax


-def plot_keypoints(kpts, colors='lime', ps=4, alpha=1):
+def plot_keypoints(kpts, colors="lime", ps=4, alpha=1):
     """Plot keypoints for existing images.
     Args:
         kpts: list of ndarrays of size (N, 2).
@@ -53,7 +53,7 @@ def plot_keypoints(kpts, colors='lime', ps=4, alpha=1):
         a.scatter(k[:, 0], k[:, 1], c=c, s=ps, alpha=alpha, linewidths=0)


-def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
     """Plot matches for a pair of existing images.
     Args:
         kpts0, kpts1: corresponding keypoints of size (N, 2).
@@ -80,11 +80,18 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
     transFigure = fig.transFigure.inverted()
     fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
     fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
-    fig.lines += [matplotlib.lines.Line2D(
-        (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
-        zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
-        alpha=a)
-        for i in range(len(kpts0))]
+    fig.lines += [
+        matplotlib.lines.Line2D(
+            (fkpts0[i, 0], fkpts1[i, 0]),
+            (fkpts0[i, 1], fkpts1[i, 1]),
+            zorder=1,
+            transform=fig.transFigure,
+            c=color[i],
+            linewidth=lw,
+            alpha=a,
+        )
+        for i in range(len(kpts0))
+    ]

     # freeze the axes to prevent the transform to change
     ax0.autoscale(enable=False)
@@ -95,9 +102,16 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
     ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)


-def plot_lines(lines, line_colors='orange', point_colors='cyan',
-               ps=4, lw=2, alpha=1., indices=(0, 1)):
-    """ Plot lines and endpoints for existing images.
+def plot_lines(
+    lines,
+    line_colors="orange",
+    point_colors="cyan",
+    ps=4,
+    lw=2,
+    alpha=1.0,
+    indices=(0, 1),
+):
+    """Plot lines and endpoints for existing images.
     Args:
         lines: list of ndarrays of size (N, 2, 2).
         colors: string, or list of list of tuples (one for each keypoints).
@@ -120,18 +134,20 @@ def plot_lines(lines, line_colors='orange', point_colors='cyan',
     # Plot the lines and junctions
     for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
         for i in range(len(l)):
-            line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]),
-                                           (l[i, 0, 1], l[i, 1, 1]),
-                                           zorder=1, c=lc, linewidth=lw,
-                                           alpha=alpha)
+            line = matplotlib.lines.Line2D(
+                (l[i, 0, 0], l[i, 1, 0]),
+                (l[i, 0, 1], l[i, 1, 1]),
+                zorder=1,
+                c=lc,
+                linewidth=lw,
+                alpha=alpha,
+            )
             a.add_line(line)
         pts = l.reshape(-1, 2)
-        a.scatter(pts[:, 0], pts[:, 1],
-                  c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha)
+        a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha)


-def plot_color_line_matches(lines, correct_matches=None,
-                            lw=2, indices=(0, 1)):
+def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
     """Plot line matches for existing images with multiple colors.
     Args:
         lines: list of ndarrays of size (N, 2, 2).
@@ -140,7 +156,7 @@ def plot_color_line_matches(lines, correct_matches=None,
         indices: indices of the images to draw the matches on.
""" n_lines = len(lines[0]) - colors = sns.color_palette('husl', n_colors=n_lines) + colors = sns.color_palette("husl", n_colors=n_lines) np.random.shuffle(colors) alphas = np.ones(n_lines) # If correct_matches is not None, display wrong matches with a low alpha @@ -159,8 +175,15 @@ def plot_color_line_matches(lines, correct_matches=None, transFigure = fig.transFigure.inverted() endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) - fig.lines += [matplotlib.lines.Line2D( - (endpoint0[i, 0], endpoint1[i, 0]), - (endpoint0[i, 1], endpoint1[i, 1]), - zorder=1, transform=fig.transFigure, c=colors[i], - alpha=alphas[i], linewidth=lw) for i in range(n_lines)] + fig.lines += [ + matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=colors[i], + alpha=alphas[i], + linewidth=lw, + ) + for i in range(n_lines) + ] diff --git a/third_party/GlueStick/gluestick/geometry.py b/third_party/GlueStick/gluestick/geometry.py index 97853c4807d319eb9ea0377db7385e9a72fb400b..0cdd232e74aeda84e1683dcb8e51385cc2497c37 100644 --- a/third_party/GlueStick/gluestick/geometry.py +++ b/third_party/GlueStick/gluestick/geometry.py @@ -21,7 +21,7 @@ def to_homogeneous(points): raise ValueError -def from_homogeneous(points, eps=0.): +def from_homogeneous(points, eps=0.0): """Remove the homogeneous dimension of N-dimensional points. Args: points: torch.Tensor or numpy.ndarray with size (..., N+1). @@ -32,14 +32,22 @@ def from_homogeneous(points, eps=0.): def skew_symmetric(v): - """Create a skew-symmetric matrix from a (batched) vector of size (..., 3). - """ + """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).""" z = torch.zeros_like(v[..., 0]) - M = torch.stack([ - z, -v[..., 2], v[..., 1], - v[..., 2], z, -v[..., 0], - -v[..., 1], v[..., 0], z, - ], dim=-1).reshape(v.shape[:-1] + (3, 3)) + M = torch.stack( + [ + z, + -v[..., 2], + v[..., 1], + v[..., 2], + z, + -v[..., 0], + -v[..., 1], + v[..., 0], + z, + ], + dim=-1, + ).reshape(v.shape[:-1] + (3, 3)) return M @@ -67,7 +75,7 @@ def warp_points_torch(points, H, inverse=True): H_mat = torch.cat([H, torch.ones_like(H[..., :1])], axis=-1).reshape(out_shape) if inverse: H_mat = torch.inverse(H_mat) - warped_points = torch.einsum('...nj,...ji->...ni', points, H_mat.transpose(-2, -1)) + warped_points = torch.einsum("...nj,...ji->...ni", points, H_mat.transpose(-2, -1)) warped_points = from_homogeneous(warped_points, eps=1e-5) @@ -76,18 +84,27 @@ def warp_points_torch(points, H, inverse=True): def seg_equation(segs): # calculate list of start, end and midpoints points from both lists - start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous(segs[..., 1, :]) + start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous( + segs[..., 1, :] + ) # Compute the line equations as ax + by + c = 0 , where x^2 + y^2 = 1 lines = torch.cross(start_points, end_points, dim=-1) - lines_norm = (torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None]) - assert torch.all(lines_norm > 0), 'Error: trying to compute the equation of a line with a single point' + lines_norm = torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None] + assert torch.all( + lines_norm > 0 + ), "Error: trying to compute the equation of a line with a single point" lines = lines / lines_norm return lines def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]): h, w = img_shape - 
return (pts >= 0).all(dim=-1) & (pts[..., 0] < w) & (pts[..., 1] < h) & (~torch.isinf(pts).any(dim=-1)) + return ( + (pts >= 0).all(dim=-1) + & (pts[..., 0] < w) + & (pts[..., 1] < h) + & (~torch.isinf(pts).any(dim=-1)) + ) def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor: @@ -102,7 +119,9 @@ def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch. # Project the segments to the reference image segs = segs.clone() eqs = seg_equation(segs) - x0, y0 = torch.tensor([1., 0, 0.], device=device), torch.tensor([0., 1, 0], device=device) + x0, y0 = torch.tensor([1.0, 0, 0.0], device=device), torch.tensor( + [0.0, 1, 0], device=device + ) x0 = x0.repeat(eqs.shape[:-1] + (1,)) y0 = y0.repeat(eqs.shape[:-1] + (1,)) pt_x0s = torch.cross(eqs, x0, dim=-1) @@ -112,7 +131,9 @@ def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch. pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1] pt_y0s_valid = is_inside_img(pt_y0s, img_shape) - xW, yH = torch.tensor([1., 0, EPS - w], device=device), torch.tensor([0., 1, EPS - h], device=device) + xW, yH = torch.tensor([1.0, 0, EPS - w], device=device), torch.tensor( + [0.0, 1, EPS - h], device=device + ) xW = xW.repeat(eqs.shape[:-1] + (1,)) yH = yH.repeat(eqs.shape[:-1] + (1,)) pt_xWs = torch.cross(eqs, xW, dim=-1) @@ -143,11 +164,17 @@ def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch. mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid segs[mask, 1, :] = pt_yHs[mask] - assert torch.all(segs >= 0) and torch.all(segs[..., 0] < w) and torch.all(segs[..., 1] < h) + assert ( + torch.all(segs >= 0) + and torch.all(segs[..., 0] < w) + and torch.all(segs[..., 1] < h) + ) return segs -def warp_lines_torch(lines, H, inverse=True, dst_shape: Tuple[int, int] = None) -> Tuple[torch.Tensor, torch.Tensor]: +def warp_lines_torch( + lines, H, inverse=True, dst_shape: Tuple[int, int] = None +) -> Tuple[torch.Tensor, torch.Tensor]: """ :param lines: A tensor of shape (B, N, 2, 2) where B is the batch size, N the number of lines. :param H: The homography used to convert the lines. batched or not (shapes (B, 8) and (8,) respectively). 
@@ -156,12 +183,16 @@ def warp_lines_torch(lines, H, inverse=True, dst_shape: Tuple[int, int] = None)
     """
     device = lines.device
     batch_size, n = lines.shape[:2]
-    lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(lines.shape)
+    lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(
+        lines.shape
+    )

     if dst_shape is None:
         return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device)

-    out_img = torch.any((lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1)
+    out_img = torch.any(
+        (lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1
+    )
     valid = ~out_img.all(-1)
     any_out_of_img = out_img.any(-1)
     lines_to_trim = valid & any_out_of_img
diff --git a/third_party/GlueStick/gluestick/models/base_model.py b/third_party/GlueStick/gluestick/models/base_model.py
index 30ca991655a28ca88074b42312c33b360f655fab..ef326bbb9e7deb78ee59d7cf9b2a76a5234106b4 100644
--- a/third_party/GlueStick/gluestick/models/base_model.py
+++ b/third_party/GlueStick/gluestick/models/base_model.py
@@ -13,7 +13,7 @@ class MetaModel(ABCMeta):
     def __prepare__(name, bases, **kwds):
         total_conf = OmegaConf.create()
         for base in bases:
-            for key in ('base_default_conf', 'default_conf'):
+            for key in ("base_default_conf", "default_conf"):
                 update = getattr(base, key, {})
                 if isinstance(update, dict):
                     update = OmegaConf.create(update)
@@ -49,10 +49,11 @@ class BaseModel(nn.Module, metaclass=MetaModel):
         metrics(self, pred, data): method that returns a dictionary of metrics,
         each as a batch of scalars.
     """
+
     default_conf = {
-        'name': None,
-        'trainable': True,  # if false: do not optimize this model parameters
-        'freeze_batch_normalization': False,  # use test-time statistics
+        "name": None,
+        "trainable": True,  # if false: do not optimize this model parameters
+        "freeze_batch_normalization": False,  # use test-time statistics
     }
     required_data_keys = []
     strict_conf = True
@@ -61,15 +62,16 @@ class BaseModel(nn.Module, metaclass=MetaModel):
         """Perform some logic and call the _init method of the child model."""
         super().__init__()
         default_conf = OmegaConf.merge(
-            self.base_default_conf, OmegaConf.create(self.default_conf))
+            self.base_default_conf, OmegaConf.create(self.default_conf)
+        )
         if self.strict_conf:
             OmegaConf.set_struct(default_conf, True)

         # fixme: backward compatibility
-        if 'pad' in conf and 'pad' not in default_conf:  # backward compat.
+        if "pad" in conf and "pad" not in default_conf:  # backward compat.
with omegaconf.read_write(conf): with omegaconf.open_dict(conf): - conf['interpolation'] = {'pad': conf.pop('pad')} + conf["interpolation"] = {"pad": conf.pop("pad")} if isinstance(conf, dict): conf = OmegaConf.create(conf) @@ -89,6 +91,7 @@ class BaseModel(nn.Module, metaclass=MetaModel): def freeze_bn(module): if isinstance(module, nn.modules.batchnorm._BatchNorm): module.eval() + if self.conf.freeze_batch_normalization: self.apply(freeze_bn) @@ -96,9 +99,10 @@ class BaseModel(nn.Module, metaclass=MetaModel): def forward(self, data): """Check the data and call the _forward method of the child model.""" + def recursive_key_check(expected, given): for key in expected: - assert key in given, f'Missing key {key} in data' + assert key in given, f"Missing key {key} in data" if isinstance(expected, dict): recursive_key_check(expected[key], given[key]) diff --git a/third_party/GlueStick/gluestick/models/gluestick.py b/third_party/GlueStick/gluestick/models/gluestick.py index c2a6c477eebecc2c43feea007f99c2115aa7c216..8179f8ff779401f535260b930a3f5e4d957af614 100644 --- a/third_party/GlueStick/gluestick/models/gluestick.py +++ b/third_party/GlueStick/gluestick/models/gluestick.py @@ -12,139 +12,178 @@ ETH_EPS = 1e-8 class GlueStick(BaseModel): default_conf = { - 'input_dim': 256, - 'descriptor_dim': 256, - 'bottleneck_dim': None, - 'weights': None, - 'keypoint_encoder': [32, 64, 128, 256], - 'GNN_layers': ['self', 'cross'] * 9, - 'num_line_iterations': 1, - 'line_attention': False, - 'filter_threshold': 0.2, - 'checkpointed': False, - 'skip_init': False, - 'inter_supervision': None, - 'loss': { - 'nll_weight': 1., - 'nll_balancing': 0.5, - 'reward_weight': 0., - 'bottleneck_l2_weight': 0., - 'dense_nll_weight': 0., - 'inter_supervision': [0.3, 0.6], + "input_dim": 256, + "descriptor_dim": 256, + "bottleneck_dim": None, + "weights": None, + "keypoint_encoder": [32, 64, 128, 256], + "GNN_layers": ["self", "cross"] * 9, + "num_line_iterations": 1, + "line_attention": False, + "filter_threshold": 0.2, + "checkpointed": False, + "skip_init": False, + "inter_supervision": None, + "loss": { + "nll_weight": 1.0, + "nll_balancing": 0.5, + "reward_weight": 0.0, + "bottleneck_l2_weight": 0.0, + "dense_nll_weight": 0.0, + "inter_supervision": [0.3, 0.6], }, } required_data_keys = [ - 'keypoints0', 'keypoints1', - 'descriptors0', 'descriptors1', - 'keypoint_scores0', 'keypoint_scores1'] - - DEFAULT_LOSS_CONF = {'nll_weight': 1., 'nll_balancing': 0.5, 'reward_weight': 0., 'bottleneck_l2_weight': 0.} + "keypoints0", + "keypoints1", + "descriptors0", + "descriptors1", + "keypoint_scores0", + "keypoint_scores1", + ] + + DEFAULT_LOSS_CONF = { + "nll_weight": 1.0, + "nll_balancing": 0.5, + "reward_weight": 0.0, + "bottleneck_l2_weight": 0.0, + } def _init(self, conf): if conf.bottleneck_dim is not None: self.bottleneck_down = nn.Conv1d( - conf.input_dim, conf.bottleneck_dim, kernel_size=1) + conf.input_dim, conf.bottleneck_dim, kernel_size=1 + ) self.bottleneck_up = nn.Conv1d( - conf.bottleneck_dim, conf.input_dim, kernel_size=1) + conf.bottleneck_dim, conf.input_dim, kernel_size=1 + ) nn.init.constant_(self.bottleneck_down.bias, 0.0) nn.init.constant_(self.bottleneck_up.bias, 0.0) if conf.input_dim != conf.descriptor_dim: self.input_proj = nn.Conv1d( - conf.input_dim, conf.descriptor_dim, kernel_size=1) + conf.input_dim, conf.descriptor_dim, kernel_size=1 + ) nn.init.constant_(self.input_proj.bias, 0.0) - self.kenc = KeypointEncoder(conf.descriptor_dim, - conf.keypoint_encoder) + self.kenc = 
KeypointEncoder(conf.descriptor_dim, conf.keypoint_encoder) self.lenc = EndPtEncoder(conf.descriptor_dim, conf.keypoint_encoder) - self.gnn = AttentionalGNN(conf.descriptor_dim, conf.GNN_layers, - checkpointed=conf.checkpointed, - inter_supervision=conf.inter_supervision, - num_line_iterations=conf.num_line_iterations, - line_attention=conf.line_attention) - self.final_proj = nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, - kernel_size=1) + self.gnn = AttentionalGNN( + conf.descriptor_dim, + conf.GNN_layers, + checkpointed=conf.checkpointed, + inter_supervision=conf.inter_supervision, + num_line_iterations=conf.num_line_iterations, + line_attention=conf.line_attention, + ) + self.final_proj = nn.Conv1d( + conf.descriptor_dim, conf.descriptor_dim, kernel_size=1 + ) nn.init.constant_(self.final_proj.bias, 0.0) nn.init.orthogonal_(self.final_proj.weight, gain=1) self.final_line_proj = nn.Conv1d( - conf.descriptor_dim, conf.descriptor_dim, kernel_size=1) + conf.descriptor_dim, conf.descriptor_dim, kernel_size=1 + ) nn.init.constant_(self.final_line_proj.bias, 0.0) nn.init.orthogonal_(self.final_line_proj.weight, gain=1) if conf.inter_supervision is not None: self.inter_line_proj = nn.ModuleList( - [nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1) - for _ in conf.inter_supervision]) + [ + nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1) + for _ in conf.inter_supervision + ] + ) self.layer2idx = {} for i, l in enumerate(conf.inter_supervision): nn.init.constant_(self.inter_line_proj[i].bias, 0.0) nn.init.orthogonal_(self.inter_line_proj[i].weight, gain=1) self.layer2idx[l] = i - bin_score = torch.nn.Parameter(torch.tensor(1.)) - self.register_parameter('bin_score', bin_score) - line_bin_score = torch.nn.Parameter(torch.tensor(1.)) - self.register_parameter('line_bin_score', line_bin_score) + bin_score = torch.nn.Parameter(torch.tensor(1.0)) + self.register_parameter("bin_score", bin_score) + line_bin_score = torch.nn.Parameter(torch.tensor(1.0)) + self.register_parameter("line_bin_score", line_bin_score) if conf.weights: assert isinstance(conf.weights, str) - state_dict = torch.load(conf.weights, map_location='cpu') - if 'model' in state_dict: - state_dict = {k.replace('matcher.', ''): v for k, v in state_dict['model'].items() if 'matcher.' in k} - state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} + state_dict = torch.load(conf.weights, map_location="cpu") + if "model" in state_dict: + state_dict = { + k.replace("matcher.", ""): v + for k, v in state_dict["model"].items() + if "matcher." 
in k + } + state_dict = { + k.replace("module.", ""): v for k, v in state_dict.items() + } self.load_state_dict(state_dict) def _forward(self, data): - device = data['keypoints0'].device - b_size = len(data['keypoints0']) - image_size0 = (data['image_size0'] if 'image_size0' in data - else data['image0'].shape) - image_size1 = (data['image_size1'] if 'image_size1' in data - else data['image1'].shape) + device = data["keypoints0"].device + b_size = len(data["keypoints0"]) + image_size0 = ( + data["image_size0"] if "image_size0" in data else data["image0"].shape + ) + image_size1 = ( + data["image_size1"] if "image_size1" in data else data["image1"].shape + ) pred = {} - desc0, desc1 = data['descriptors0'], data['descriptors1'] - kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + desc0, desc1 = data["descriptors0"], data["descriptors1"] + kpts0, kpts1 = data["keypoints0"], data["keypoints1"] n_kpts0, n_kpts1 = kpts0.shape[1], kpts1.shape[1] - n_lines0, n_lines1 = data['lines0'].shape[1], data['lines1'].shape[1] + n_lines0, n_lines1 = data["lines0"].shape[1], data["lines1"].shape[1] if n_kpts0 == 0 or n_kpts1 == 0: # No detected keypoints nor lines - pred['log_assignment'] = torch.zeros( - b_size, n_kpts0, n_kpts1, dtype=torch.float, device=device) - pred['matches0'] = torch.full( - (b_size, n_kpts0), -1, device=device, dtype=torch.int64) - pred['matches1'] = torch.full( - (b_size, n_kpts1), -1, device=device, dtype=torch.int64) - pred['match_scores0'] = torch.zeros( - (b_size, n_kpts0), device=device, dtype=torch.float32) - pred['match_scores1'] = torch.zeros( - (b_size, n_kpts1), device=device, dtype=torch.float32) - pred['line_log_assignment'] = torch.zeros(b_size, n_lines0, n_lines1, - dtype=torch.float, device=device) - pred['line_matches0'] = torch.full((b_size, n_lines0), -1, - device=device, dtype=torch.int64) - pred['line_matches1'] = torch.full((b_size, n_lines1), -1, - device=device, dtype=torch.int64) - pred['line_match_scores0'] = torch.zeros( - (b_size, n_lines0), device=device, dtype=torch.float32) - pred['line_match_scores1'] = torch.zeros( - (b_size, n_kpts1), device=device, dtype=torch.float32) + pred["log_assignment"] = torch.zeros( + b_size, n_kpts0, n_kpts1, dtype=torch.float, device=device + ) + pred["matches0"] = torch.full( + (b_size, n_kpts0), -1, device=device, dtype=torch.int64 + ) + pred["matches1"] = torch.full( + (b_size, n_kpts1), -1, device=device, dtype=torch.int64 + ) + pred["match_scores0"] = torch.zeros( + (b_size, n_kpts0), device=device, dtype=torch.float32 + ) + pred["match_scores1"] = torch.zeros( + (b_size, n_kpts1), device=device, dtype=torch.float32 + ) + pred["line_log_assignment"] = torch.zeros( + b_size, n_lines0, n_lines1, dtype=torch.float, device=device + ) + pred["line_matches0"] = torch.full( + (b_size, n_lines0), -1, device=device, dtype=torch.int64 + ) + pred["line_matches1"] = torch.full( + (b_size, n_lines1), -1, device=device, dtype=torch.int64 + ) + pred["line_match_scores0"] = torch.zeros( + (b_size, n_lines0), device=device, dtype=torch.float32 + ) + pred["line_match_scores1"] = torch.zeros( + (b_size, n_kpts1), device=device, dtype=torch.float32 + ) return pred - lines0 = data['lines0'].flatten(1, 2) - lines1 = data['lines1'].flatten(1, 2) - lines_junc_idx0 = data['lines_junc_idx0'].flatten(1, 2) # [b_size, num_lines * 2] - lines_junc_idx1 = data['lines_junc_idx1'].flatten(1, 2) + lines0 = data["lines0"].flatten(1, 2) + lines1 = data["lines1"].flatten(1, 2) + lines_junc_idx0 = data["lines_junc_idx0"].flatten( + 1, 2 + ) # 
[b_size, num_lines * 2] + lines_junc_idx1 = data["lines_junc_idx1"].flatten(1, 2) if self.conf.bottleneck_dim is not None: - pred['down_descriptors0'] = desc0 = self.bottleneck_down(desc0) - pred['down_descriptors1'] = desc1 = self.bottleneck_down(desc1) + pred["down_descriptors0"] = desc0 = self.bottleneck_down(desc0) + pred["down_descriptors1"] = desc1 = self.bottleneck_down(desc1) desc0 = self.bottleneck_up(desc0) desc1 = self.bottleneck_up(desc1) desc0 = nn.functional.normalize(desc0, p=2, dim=1) desc1 = nn.functional.normalize(desc1, p=2, dim=1) - pred['bottleneck_descriptors0'] = desc0 - pred['bottleneck_descriptors1'] = desc1 + pred["bottleneck_descriptors0"] = desc0 + pred["bottleneck_descriptors1"] = desc1 if self.conf.loss.nll_weight == 0: desc0 = desc0.detach() desc1 = desc1.detach() @@ -158,79 +197,113 @@ class GlueStick(BaseModel): assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1) assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1) - desc0 = desc0 + self.kenc(kpts0, data['keypoint_scores0']) - desc1 = desc1 + self.kenc(kpts1, data['keypoint_scores1']) + desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"]) + desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"]) if n_lines0 != 0 and n_lines1 != 0: # Pre-compute the line encodings lines0 = normalize_keypoints(lines0, image_size0).reshape( - b_size, n_lines0, 2, 2) + b_size, n_lines0, 2, 2 + ) lines1 = normalize_keypoints(lines1, image_size1).reshape( - b_size, n_lines1, 2, 2) - line_enc0 = self.lenc(lines0, data['line_scores0']) - line_enc1 = self.lenc(lines1, data['line_scores1']) + b_size, n_lines1, 2, 2 + ) + line_enc0 = self.lenc(lines0, data["line_scores0"]) + line_enc1 = self.lenc(lines1, data["line_scores1"]) else: line_enc0 = torch.zeros( - b_size, self.conf.descriptor_dim, n_lines0 * 2, - dtype=torch.float, device=device) + b_size, + self.conf.descriptor_dim, + n_lines0 * 2, + dtype=torch.float, + device=device, + ) line_enc1 = torch.zeros( - b_size, self.conf.descriptor_dim, n_lines1 * 2, - dtype=torch.float, device=device) + b_size, + self.conf.descriptor_dim, + n_lines1 * 2, + dtype=torch.float, + device=device, + ) - desc0, desc1 = self.gnn(desc0, desc1, line_enc0, line_enc1, - lines_junc_idx0, lines_junc_idx1) + desc0, desc1 = self.gnn( + desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1 + ) # Match all points (KP and line junctions) mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) - kp_scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) - kp_scores = kp_scores / self.conf.descriptor_dim ** .5 + kp_scores = torch.einsum("bdn,bdm->bnm", mdesc0, mdesc1) + kp_scores = kp_scores / self.conf.descriptor_dim**0.5 kp_scores = log_double_softmax(kp_scores, self.bin_score) m0, m1, mscores0, mscores1 = self._get_matches(kp_scores) - pred['log_assignment'] = kp_scores - pred['matches0'] = m0 - pred['matches1'] = m1 - pred['match_scores0'] = mscores0 - pred['match_scores1'] = mscores1 + pred["log_assignment"] = kp_scores + pred["matches0"] = m0 + pred["matches1"] = m1 + pred["match_scores0"] = mscores0 + pred["match_scores1"] = mscores1 # Match the lines if n_lines0 > 0 and n_lines1 > 0: - (line_scores, m0_lines, m1_lines, mscores0_lines, - mscores1_lines, raw_line_scores) = self._get_line_matches( - desc0[:, :, :2 * n_lines0], desc1[:, :, :2 * n_lines1], - lines_junc_idx0, lines_junc_idx1, self.final_line_proj) + ( + line_scores, + m0_lines, + m1_lines, + mscores0_lines, + mscores1_lines, + raw_line_scores, + ) = self._get_line_matches( + desc0[:, :, : 2 * n_lines0], + 
desc1[:, :, : 2 * n_lines1], + lines_junc_idx0, + lines_junc_idx1, + self.final_line_proj, + ) if self.conf.inter_supervision: for l in self.conf.inter_supervision: - (line_scores_i, m0_lines_i, m1_lines_i, mscores0_lines_i, - mscores1_lines_i) = self._get_line_matches( - self.gnn.inter_layers[l][0][:, :, :2 * n_lines0], - self.gnn.inter_layers[l][1][:, :, :2 * n_lines1], - lines_junc_idx0, lines_junc_idx1, - self.inter_line_proj[self.layer2idx[l]]) - pred[f'line_{l}_log_assignment'] = line_scores_i - pred[f'line_{l}_matches0'] = m0_lines_i - pred[f'line_{l}_matches1'] = m1_lines_i - pred[f'line_{l}_match_scores0'] = mscores0_lines_i - pred[f'line_{l}_match_scores1'] = mscores1_lines_i + ( + line_scores_i, + m0_lines_i, + m1_lines_i, + mscores0_lines_i, + mscores1_lines_i, + ) = self._get_line_matches( + self.gnn.inter_layers[l][0][:, :, : 2 * n_lines0], + self.gnn.inter_layers[l][1][:, :, : 2 * n_lines1], + lines_junc_idx0, + lines_junc_idx1, + self.inter_line_proj[self.layer2idx[l]], + ) + pred[f"line_{l}_log_assignment"] = line_scores_i + pred[f"line_{l}_matches0"] = m0_lines_i + pred[f"line_{l}_matches1"] = m1_lines_i + pred[f"line_{l}_match_scores0"] = mscores0_lines_i + pred[f"line_{l}_match_scores1"] = mscores1_lines_i else: - line_scores = torch.zeros(b_size, n_lines0, n_lines1, - dtype=torch.float, device=device) - m0_lines = torch.full((b_size, n_lines0), -1, - device=device, dtype=torch.int64) - m1_lines = torch.full((b_size, n_lines1), -1, - device=device, dtype=torch.int64) + line_scores = torch.zeros( + b_size, n_lines0, n_lines1, dtype=torch.float, device=device + ) + m0_lines = torch.full( + (b_size, n_lines0), -1, device=device, dtype=torch.int64 + ) + m1_lines = torch.full( + (b_size, n_lines1), -1, device=device, dtype=torch.int64 + ) mscores0_lines = torch.zeros( - (b_size, n_lines0), device=device, dtype=torch.float32) + (b_size, n_lines0), device=device, dtype=torch.float32 + ) mscores1_lines = torch.zeros( - (b_size, n_lines1), device=device, dtype=torch.float32) - raw_line_scores = torch.zeros(b_size, n_lines0, n_lines1, - dtype=torch.float, device=device) - pred['line_log_assignment'] = line_scores - pred['line_matches0'] = m0_lines - pred['line_matches1'] = m1_lines - pred['line_match_scores0'] = mscores0_lines - pred['line_match_scores1'] = mscores1_lines - pred['raw_line_scores'] = raw_line_scores + (b_size, n_lines1), device=device, dtype=torch.float32 + ) + raw_line_scores = torch.zeros( + b_size, n_lines0, n_lines1, dtype=torch.float, device=device + ) + pred["line_log_assignment"] = line_scores + pred["line_matches0"] = m0_lines + pred["line_matches1"] = m1_lines + pred["line_match_scores0"] = mscores0_lines + pred["line_match_scores1"] = mscores1_lines + pred["raw_line_scores"] = raw_line_scores return pred @@ -249,35 +322,47 @@ class GlueStick(BaseModel): m1 = torch.where(valid1, m1, m1.new_tensor(-1)) return m0, m1, mscores0, mscores1 - def _get_line_matches(self, ldesc0, ldesc1, lines_junc_idx0, - lines_junc_idx1, final_proj): + def _get_line_matches( + self, ldesc0, ldesc1, lines_junc_idx0, lines_junc_idx1, final_proj + ): mldesc0 = final_proj(ldesc0) mldesc1 = final_proj(ldesc1) - line_scores = torch.einsum('bdn,bdm->bnm', mldesc0, mldesc1) - line_scores = line_scores / self.conf.descriptor_dim ** .5 + line_scores = torch.einsum("bdn,bdm->bnm", mldesc0, mldesc1) + line_scores = line_scores / self.conf.descriptor_dim**0.5 # Get the line representation from the junction descriptors n2_lines0 = lines_junc_idx0.shape[1] n2_lines1 = lines_junc_idx1.shape[1] 
line_scores = torch.gather( - line_scores, dim=2, - index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1)) + line_scores, + dim=2, + index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1), + ) line_scores = torch.gather( - line_scores, dim=1, - index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1)) - line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2, - n2_lines1 // 2, 2)) + line_scores, + dim=1, + index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1), + ) + line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2, n2_lines1 // 2, 2)) # Match either in one direction or the other raw_line_scores = 0.5 * torch.maximum( line_scores[:, :, 0, :, 0] + line_scores[:, :, 1, :, 1], - line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0]) + line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0], + ) line_scores = log_double_softmax(raw_line_scores, self.line_bin_score) m0_lines, m1_lines, mscores0_lines, mscores1_lines = self._get_matches( - line_scores) - return (line_scores, m0_lines, m1_lines, mscores0_lines, - mscores1_lines, raw_line_scores) + line_scores + ) + return ( + line_scores, + m0_lines, + m1_lines, + mscores0_lines, + mscores1_lines, + raw_line_scores, + ) def loss(self, pred, data): raise NotImplementedError() @@ -290,8 +375,7 @@ def MLP(channels, do_bn=True): n = len(channels) layers = [] for i in range(1, n): - layers.append( - nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) + layers.append(nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) if i < (n - 1): if do_bn: layers.append(nn.BatchNorm1d(channels[i])) @@ -338,17 +422,20 @@ class EndPtEncoder(nn.Module): endpt_offset = (endpoints[:, :, 1] - endpoints[:, :, 0]).unsqueeze(2) endpt_offset = torch.cat([endpt_offset, -endpt_offset], dim=2) endpt_offset = endpt_offset.reshape(b_size, 2 * n_pts, 2).transpose(1, 2) - inputs = [endpoints.flatten(1, 2).transpose(1, 2), - endpt_offset, scores.repeat(1, 2).unsqueeze(1)] + inputs = [ + endpoints.flatten(1, 2).transpose(1, 2), + endpt_offset, + scores.repeat(1, 2).unsqueeze(1), + ] return self.encoder(torch.cat(inputs, dim=1)) @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) def attention(query, key, value): dim = query.shape[1] - scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5 + scores = torch.einsum("bdhn,bdhm->bhnm", query, key) / dim**0.5 prob = torch.nn.functional.softmax(scores, dim=-1) - return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob + return torch.einsum("bhnm,bdhm->bdhn", prob, value), prob class MultiHeadedAttention(nn.Module): @@ -363,8 +450,10 @@ class MultiHeadedAttention(nn.Module): def forward(self, query, key, value): b = query.size(0) - query, key, value = [l(x).view(b, self.dim, self.h, -1) - for l, x in zip(self.proj, (query, key, value))] + query, key, value = [ + l(x).view(b, self.dim, self.h, -1) + for l, x in zip(self.proj, (query, key, value)) + ] x, prob = attention(query, key, value) # self.prob.append(prob.mean(dim=1)) return self.merge(x.contiguous().view(b, self.dim * self.h, -1)) @@ -377,9 +466,9 @@ class AttentionalPropagation(nn.Module): self.mlp = MLP([num_dim * 2, num_dim * 2, num_dim], do_bn=True) nn.init.constant_(self.mlp[-1].bias, 0.0) if skip_init: - self.register_parameter('scaling', nn.Parameter(torch.tensor(0.))) + self.register_parameter("scaling", nn.Parameter(torch.tensor(0.0))) else: - self.scaling = 1. 
+            self.scaling = 1.0

     def forward(self, x, source):
         message = self.attn(x, source, source)
@@ -389,14 +478,14 @@ class GNNLayer(nn.Module):
     def __init__(self, feature_dim, layer_type, skip_init):
         super().__init__()
-        assert layer_type in ['cross', 'self']
+        assert layer_type in ["cross", "self"]
         self.type = layer_type
         self.update = AttentionalPropagation(feature_dim, 4, skip_init)

     def forward(self, desc0, desc1):
-        if self.type == 'cross':
+        if self.type == "cross":
             src0, src1 = desc1, desc0
-        elif self.type == 'self':
+        elif self.type == "self":
             src0, src1 = desc0, desc1
         else:
             raise ValueError("Unknown layer type: " + self.type)
@@ -422,11 +511,19 @@ class LineLayer(nn.Module):
         # Create one message per line endpoint
         b_size = lines_junc_idx.shape[0]
         line_desc = torch.gather(
-            ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1))
-        message = torch.cat([
-            line_desc,
-            line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(),
-            line_enc], dim=1)
+            ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1)
+        )
+        message = torch.cat(
+            [
+                line_desc,
+                line_desc.reshape(b_size, self.dim, -1, 2)
+                .flip([-1])
+                .flatten(2, 3)
+                .clone(),
+                line_enc,
+            ],
+            dim=1,
+        )
         return self.mlp(message)  # [b_size, D, n_lines * 2]

     def get_endpoint_attention(self, ldesc, line_enc, lines_junc_idx):
@@ -442,22 +539,32 @@ class LineLayer(nn.Module):

         # Key: combination of neighboring desc and line encodings
         line_desc = torch.gather(ldesc, 2, expanded_lines_junc_idx)
-        key = self.proj_neigh(torch.cat([
-            line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(),
-            line_enc], dim=1))  # [b_size, D, n_lines * 2]
+        key = self.proj_neigh(
+            torch.cat(
+                [
+                    line_desc.reshape(b_size, self.dim, -1, 2)
+                    .flip([-1])
+                    .flatten(2, 3)
+                    .clone(),
+                    line_enc,
+                ],
+                dim=1,
+            )
+        )  # [b_size, D, n_lines * 2]

         # Compute the attention weights with a custom softmax per junction
-        prob = (query * key).sum(dim=1) / self.dim ** .5  # [b_size, n_lines * 2]
+        prob = (query * key).sum(dim=1) / self.dim**0.5  # [b_size, n_lines * 2]
         prob = torch.exp(prob - prob.max())
         denom = torch.zeros_like(ldesc[:, 0]).scatter_reduce_(
-            dim=1, index=lines_junc_idx,
-            src=prob, reduce='sum', include_self=False)  # [b_size, n_junc]
+            dim=1, index=lines_junc_idx, src=prob, reduce="sum", include_self=False
+        )  # [b_size, n_junc]
         denom = torch.gather(denom, 1, lines_junc_idx)  # [b_size, n_lines * 2]
         prob = prob / (denom + ETH_EPS)
         return prob  # [b_size, n_lines * 2]

-    def forward(self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0,
-                lines_junc_idx1):
+    def forward(
+        self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
+    ):
         # Gather the endpoint updates
         lupdate0 = self.get_endpoint_update(ldesc0, line_enc0, lines_junc_idx0)
         lupdate1 = self.get_endpoint_update(ldesc1, line_enc1, lines_junc_idx1)
@@ -466,26 +573,40 @@ class LineLayer(nn.Module):
         dim = ldesc0.shape[1]
         if self.line_attention:
             # Compute an attention for each neighbor and do a weighted average
-            prob0 = self.get_endpoint_attention(ldesc0, line_enc0,
-                                                lines_junc_idx0)
+            prob0 = self.get_endpoint_attention(ldesc0, line_enc0, lines_junc_idx0)
             lupdate0 = lupdate0 * prob0[:, None]
             update0 = update0.scatter_reduce_(
-                dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1),
-                src=lupdate0, reduce='sum', include_self=False)
-            prob1 = self.get_endpoint_attention(ldesc1, line_enc1,
-                                                lines_junc_idx1)
+                dim=2,
+                index=lines_junc_idx0[:, None].repeat(1, dim, 1),
+                src=lupdate0,
+                reduce="sum",
+ include_self=False, + ) + prob1 = self.get_endpoint_attention(ldesc1, line_enc1, lines_junc_idx1) lupdate1 = lupdate1 * prob1[:, None] update1 = update1.scatter_reduce_( - dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1), - src=lupdate1, reduce='sum', include_self=False) + dim=2, + index=lines_junc_idx1[:, None].repeat(1, dim, 1), + src=lupdate1, + reduce="sum", + include_self=False, + ) else: # Average the updates for each junction (requires torch > 1.12) update0 = update0.scatter_reduce_( - dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1), - src=lupdate0, reduce='mean', include_self=False) + dim=2, + index=lines_junc_idx0[:, None].repeat(1, dim, 1), + src=lupdate0, + reduce="mean", + include_self=False, + ) update1 = update1.scatter_reduce_( - dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1), - src=lupdate1, reduce='mean', include_self=False) + dim=2, + index=lines_junc_idx1[:, None].repeat(1, dim, 1), + src=lupdate1, + reduce="mean", + include_self=False, + ) # Update ldesc0 = ldesc0 + update0 @@ -495,47 +616,75 @@ class LineLayer(nn.Module): class AttentionalGNN(nn.Module): - def __init__(self, feature_dim, layer_types, checkpointed=False, - skip=False, inter_supervision=None, num_line_iterations=1, - line_attention=False): + def __init__( + self, + feature_dim, + layer_types, + checkpointed=False, + skip=False, + inter_supervision=None, + num_line_iterations=1, + line_attention=False, + ): super().__init__() self.checkpointed = checkpointed self.inter_supervision = inter_supervision self.num_line_iterations = num_line_iterations self.inter_layers = {} - self.layers = nn.ModuleList([ - GNNLayer(feature_dim, layer_type, skip) - for layer_type in layer_types]) + self.layers = nn.ModuleList( + [GNNLayer(feature_dim, layer_type, skip) for layer_type in layer_types] + ) self.line_layers = nn.ModuleList( - [LineLayer(feature_dim, line_attention) - for _ in range(len(layer_types) // 2)]) - - def forward(self, desc0, desc1, line_enc0, line_enc1, - lines_junc_idx0, lines_junc_idx1): + [ + LineLayer(feature_dim, line_attention) + for _ in range(len(layer_types) // 2) + ] + ) + + def forward( + self, desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1 + ): for i, layer in enumerate(self.layers): if self.checkpointed: desc0, desc1 = torch.utils.checkpoint.checkpoint( - layer, desc0, desc1, preserve_rng_state=False) + layer, desc0, desc1, preserve_rng_state=False + ) else: desc0, desc1 = layer(desc0, desc1) - if (layer.type == 'self' and lines_junc_idx0.shape[1] > 0 - and lines_junc_idx1.shape[1] > 0): + if ( + layer.type == "self" + and lines_junc_idx0.shape[1] > 0 + and lines_junc_idx1.shape[1] > 0 + ): # Add line self attention layers after every self layer for _ in range(self.num_line_iterations): if self.checkpointed: desc0, desc1 = torch.utils.checkpoint.checkpoint( - self.line_layers[i // 2], desc0, desc1, line_enc0, - line_enc1, lines_junc_idx0, lines_junc_idx1, - preserve_rng_state=False) + self.line_layers[i // 2], + desc0, + desc1, + line_enc0, + line_enc1, + lines_junc_idx0, + lines_junc_idx1, + preserve_rng_state=False, + ) else: desc0, desc1 = self.line_layers[i // 2]( - desc0, desc1, line_enc0, line_enc1, - lines_junc_idx0, lines_junc_idx1) + desc0, + desc1, + line_enc0, + line_enc1, + lines_junc_idx0, + lines_junc_idx1, + ) # Optionally store the line descriptor at intermediate layers - if (self.inter_supervision is not None - and (i // 2) in self.inter_supervision - and layer.type == 'cross'): + if ( + self.inter_supervision is not None + and 
(i // 2) in self.inter_supervision + and layer.type == "cross" + ): self.inter_layers[i // 2] = (desc0.clone(), desc1.clone()) return desc0, desc1 diff --git a/third_party/GlueStick/gluestick/models/superpoint.py b/third_party/GlueStick/gluestick/models/superpoint.py index 0e0948a90cf5c858ddd14cc498231479fa10d6e3..19e66cdba41749a765829cce0ead608afb04964c 100644 --- a/third_party/GlueStick/gluestick/models/superpoint.py +++ b/third_party/GlueStick/gluestick/models/superpoint.py @@ -25,7 +25,8 @@ def simple_nms(scores, radius): def max_pool(x): return torch.nn.functional.max_pool2d( - x, kernel_size=radius * 2 + 1, stride=1, padding=radius) + x, kernel_size=radius * 2 + 1, stride=1, padding=radius + ) zeros = torch.zeros_like(scores) max_mask = scores == max_pool(scores) @@ -54,33 +55,35 @@ def top_k_keypoints(keypoints, scores, k): def sample_descriptors(keypoints, descriptors, s): b, c, h, w = descriptors.shape keypoints = keypoints - s / 2 + 0.5 - keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], - ).to(keypoints)[None] + keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],).to( + keypoints + )[None] keypoints = keypoints * 2 - 1 # normalize to (-1, 1) - args = {'align_corners': True} if torch.__version__ >= '1.3' else {} + args = {"align_corners": True} if torch.__version__ >= "1.3" else {} descriptors = torch.nn.functional.grid_sample( - descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) + descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args + ) descriptors = torch.nn.functional.normalize( - descriptors.reshape(b, c, -1), p=2, dim=1) + descriptors.reshape(b, c, -1), p=2, dim=1 + ) return descriptors class SuperPoint(BaseModel): default_conf = { - 'has_detector': True, - 'has_descriptor': True, - 'descriptor_dim': 256, - + "has_detector": True, + "has_descriptor": True, + "descriptor_dim": 256, # Inference - 'return_all': False, - 'sparse_outputs': True, - 'nms_radius': 4, - 'detection_threshold': 0.005, - 'max_num_keypoints': -1, - 'force_num_keypoints': False, - 'remove_borders': 4, + "return_all": False, + "sparse_outputs": True, + "nms_radius": 4, + "detection_threshold": 0.005, + "max_num_keypoints": -1, + "force_num_keypoints": False, + "remove_borders": 4, } - required_data_keys = ['image'] + required_data_keys = ["image"] def _init(self, conf): self.relu = nn.ReLU(inplace=True) @@ -103,13 +106,14 @@ class SuperPoint(BaseModel): if conf.has_descriptor: self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) self.convDb = nn.Conv2d( - c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0) + c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0 + ) - path = GLUESTICK_ROOT / 'resources' / 'weights' / 'superpoint_v1.pth' + path = GLUESTICK_ROOT / "resources" / "weights" / "superpoint_v1.pth" self.load_state_dict(torch.load(str(path)), strict=False) def _forward(self, data): - image = data['image'] + image = data["image"] if image.shape[1] == 3: # RGB scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) image = (image * scale).sum(1, keepdim=True) @@ -136,22 +140,24 @@ class SuperPoint(BaseModel): b, c, h, w = scores.shape scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) - pred['keypoint_scores'] = dense_scores = scores + pred["keypoint_scores"] = dense_scores = scores if self.conf.has_descriptor: # Compute the dense descriptors cDa = self.relu(self.convDa(x)) all_desc = self.convDb(cDa) all_desc = 
torch.nn.functional.normalize(all_desc, p=2, dim=1) - pred['descriptors'] = all_desc + pred["descriptors"] = all_desc if self.conf.max_num_keypoints == 0: # Predict dense descriptors only b_size = len(image) device = image.device return { - 'keypoints': torch.empty(b_size, 0, 2, device=device), - 'keypoint_scores': torch.empty(b_size, 0, device=device), - 'descriptors': torch.empty(b_size, self.conf.descriptor_dim, 0, device=device), - 'all_descriptors': all_desc + "keypoints": torch.empty(b_size, 0, 2, device=device), + "keypoint_scores": torch.empty(b_size, 0, device=device), + "descriptors": torch.empty( + b_size, self.conf.descriptor_dim, 0, device=device + ), + "all_descriptors": all_desc, } if self.conf.sparse_outputs: @@ -161,26 +167,36 @@ class SuperPoint(BaseModel): # Extract keypoints keypoints = [ - torch.nonzero(s > self.conf.detection_threshold) - for s in scores] + torch.nonzero(s > self.conf.detection_threshold) for s in scores + ] scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] # Discard keypoints near the image borders - keypoints, scores = list(zip(*[ - remove_borders(k, s, self.conf.remove_borders, h * 8, w * 8) - for k, s in zip(keypoints, scores)])) + keypoints, scores = list( + zip( + *[ + remove_borders(k, s, self.conf.remove_borders, h * 8, w * 8) + for k, s in zip(keypoints, scores) + ] + ) + ) # Keep the k keypoints with highest score if self.conf.max_num_keypoints > 0: - keypoints, scores = list(zip(*[ - top_k_keypoints(k, s, self.conf.max_num_keypoints) - for k, s in zip(keypoints, scores)])) + keypoints, scores = list( + zip( + *[ + top_k_keypoints(k, s, self.conf.max_num_keypoints) + for k, s in zip(keypoints, scores) + ] + ) + ) # Convert (h, w) to (x, y) keypoints = [torch.flip(k, [1]).float() for k in keypoints] if self.conf.force_num_keypoints: - _, _, h, w = data['image'].shape + _, _, h, w = data["image"].shape assert self.conf.max_num_keypoints > 0 scores = list(scores) for i in range(len(keypoints)): @@ -194,8 +210,10 @@ class SuperPoint(BaseModel): scores[i] = torch.cat([s, new_s], 0) # Extract descriptors - desc = [sample_descriptors(k[None], d[None], 8)[0] - for k, d in zip(keypoints, all_desc)] + desc = [ + sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, all_desc) + ] if (len(keypoints) == 1) or self.conf.force_num_keypoints: keypoints = torch.stack(keypoints, 0) @@ -203,14 +221,14 @@ class SuperPoint(BaseModel): desc = torch.stack(desc, 0) pred = { - 'keypoints': keypoints, - 'keypoint_scores': scores, - 'descriptors': desc, + "keypoints": keypoints, + "keypoint_scores": scores, + "descriptors": desc, } if self.conf.return_all: - pred['all_descriptors'] = all_desc - pred['dense_score'] = dense_scores + pred["all_descriptors"] = all_desc + pred["dense_score"] = dense_scores else: del all_desc torch.cuda.empty_cache() diff --git a/third_party/GlueStick/gluestick/models/two_view_pipeline.py b/third_party/GlueStick/gluestick/models/two_view_pipeline.py index e0e21c1f62e2bd4ad573ebb87ea5635742b5032e..07a7bf06ea8c7ad2abba5fac2568ebcaffd497b0 100644 --- a/third_party/GlueStick/gluestick/models/two_view_pipeline.py +++ b/third_party/GlueStick/gluestick/models/two_view_pipeline.py @@ -22,10 +22,12 @@ def keep_quadrant_kp_subset(keypoints, scores, descs, h, w): h2, w2 = h // 2, w // 2 w_x = np.random.choice([0, w2]) w_y = np.random.choice([0, h2]) - valid_mask = ((keypoints[..., 0] >= w_x) - & (keypoints[..., 0] < w_x + w2) - & (keypoints[..., 1] >= w_y) - & (keypoints[..., 1] < w_y + h2)) + valid_mask = ( + 
(keypoints[..., 0] >= w_x) + & (keypoints[..., 0] < w_x + w2) + & (keypoints[..., 1] >= w_y) + & (keypoints[..., 1] < w_y + h2) + ) keypoints = keypoints[valid_mask][None] scores = scores[valid_mask][None] descs = descs.permute(0, 2, 1)[valid_mask].t()[None] @@ -46,47 +48,44 @@ def keep_best_kp_subset(keypoints, scores, descs, num_selected): """Keep the top num_selected best keypoints.""" sorted_indices = torch.sort(scores, dim=1)[1] selected_kp = sorted_indices[:, -num_selected:] - keypoints = torch.gather(keypoints, 1, - selected_kp[:, :, None].repeat(1, 1, 2)) + keypoints = torch.gather(keypoints, 1, selected_kp[:, :, None].repeat(1, 1, 2)) scores = torch.gather(scores, 1, selected_kp) - descs = torch.gather(descs, 2, - selected_kp[:, None].repeat(1, descs.shape[1], 1)) + descs = torch.gather(descs, 2, selected_kp[:, None].repeat(1, descs.shape[1], 1)) return keypoints, scores, descs class TwoViewPipeline(BaseModel): default_conf = { - 'extractor': { - 'name': 'superpoint', - 'trainable': False, + "extractor": { + "name": "superpoint", + "trainable": False, }, - 'use_lines': False, - 'use_points': True, - 'randomize_num_kp': False, - 'detector': {'name': None}, - 'descriptor': {'name': None}, - 'matcher': {'name': 'nearest_neighbor_matcher'}, - 'filter': {'name': None}, - 'solver': {'name': None}, - 'ground_truth': { - 'from_pose_depth': False, - 'from_homography': False, - 'th_positive': 3, - 'th_negative': 5, - 'reward_positive': 1, - 'reward_negative': -0.25, - 'is_likelihood_soft': True, - 'p_random_occluders': 0, - 'n_line_sampled_pts': 50, - 'line_perp_dist_th': 5, - 'overlap_th': 0.2, - 'min_visibility_th': 0.5 + "use_lines": False, + "use_points": True, + "randomize_num_kp": False, + "detector": {"name": None}, + "descriptor": {"name": None}, + "matcher": {"name": "nearest_neighbor_matcher"}, + "filter": {"name": None}, + "solver": {"name": None}, + "ground_truth": { + "from_pose_depth": False, + "from_homography": False, + "th_positive": 3, + "th_negative": 5, + "reward_positive": 1, + "reward_negative": -0.25, + "is_likelihood_soft": True, + "p_random_occluders": 0, + "n_line_sampled_pts": 50, + "line_perp_dist_th": 5, + "overlap_th": 0.2, + "min_visibility_th": 0.5, }, } - required_data_keys = ['image0', 'image1'] + required_data_keys = ["image0", "image1"] strict_conf = False # need to pass new confs to children models - components = [ - 'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver'] + components = ["extractor", "detector", "descriptor", "matcher", "filter", "solver"] def _init(self, conf): if conf.extractor.name: @@ -95,17 +94,16 @@ class TwoViewPipeline(BaseModel): if self.conf.detector.name: self.detector = get_model(conf.detector.name)(conf.detector) else: - self.required_data_keys += ['keypoints0', 'keypoints1'] + self.required_data_keys += ["keypoints0", "keypoints1"] if self.conf.descriptor.name: - self.descriptor = get_model(conf.descriptor.name)( - conf.descriptor) + self.descriptor = get_model(conf.descriptor.name)(conf.descriptor) else: - self.required_data_keys += ['descriptors0', 'descriptors1'] + self.required_data_keys += ["descriptors0", "descriptors1"] if conf.matcher.name: self.matcher = get_model(conf.matcher.name)(conf.matcher) else: - self.required_data_keys += ['matches0'] + self.required_data_keys += ["matches0"] if conf.filter.name: self.filter = get_model(conf.filter.name)(conf.filter) @@ -114,7 +112,6 @@ class TwoViewPipeline(BaseModel): self.solver = get_model(conf.solver.name)(conf.solver) def _forward(self, data): - def 
process_siamese(data, i): data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i} if self.conf.extractor.name: @@ -124,21 +121,28 @@ class TwoViewPipeline(BaseModel): if self.conf.detector.name: pred_i = self.detector(data_i) else: - for k in ['keypoints', 'keypoint_scores', 'descriptors', - 'lines', 'line_scores', 'line_descriptors', - 'valid_lines']: + for k in [ + "keypoints", + "keypoint_scores", + "descriptors", + "lines", + "line_scores", + "line_descriptors", + "valid_lines", + ]: if k in data_i: pred_i[k] = data_i[k] if self.conf.descriptor.name: - pred_i = { - **pred_i, **self.descriptor({**data_i, **pred_i})} + pred_i = {**pred_i, **self.descriptor({**data_i, **pred_i})} return pred_i - pred0 = process_siamese(data, '0') - pred1 = process_siamese(data, '1') + pred0 = process_siamese(data, "0") + pred1 = process_siamese(data, "1") - pred = {**{k + '0': v for k, v in pred0.items()}, - **{k + '1': v for k, v in pred1.items()}} + pred = { + **{k + "0": v for k, v in pred0.items()}, + **{k + "1": v for k, v in pred1.items()}, + } if self.conf.matcher.name: pred = {**pred, **self.matcher({**data, **pred})} @@ -161,8 +165,8 @@ class TwoViewPipeline(BaseModel): except NotImplementedError: continue losses = {**losses, **losses_} - total = losses_['total'] + total - return {**losses, 'total': total} + total = losses_["total"] + total + return {**losses, "total": total} def metrics(self, pred, data): metrics = {} diff --git a/third_party/GlueStick/gluestick/models/wireframe.py b/third_party/GlueStick/gluestick/models/wireframe.py index 0e3dd9873c6fdb4edcb4c75a103673ee2cb3b3fa..9da539387c6da8a5a8df6c677af69803ccdb54b4 100644 --- a/third_party/GlueStick/gluestick/models/wireframe.py +++ b/third_party/GlueStick/gluestick/models/wireframe.py @@ -9,7 +9,7 @@ from ..geometry import warp_lines_torch def lines_to_wireframe(lines, line_scores, all_descs, conf): - """ Given a set of lines, their score and dense descriptors, + """Given a set of lines, their score and dense descriptors, merge close-by endpoints and compute a wireframe defined by its junctions and connectivity. 
Returns: @@ -26,29 +26,41 @@ def lines_to_wireframe(lines, line_scores, all_descs, conf): device = lines.device endpoints = lines.reshape(b_size, -1, 2) - (junctions, junc_scores, junc_descs, connectivity, new_lines, - lines_junc_idx, num_true_junctions) = [], [], [], [], [], [], [] + ( + junctions, + junc_scores, + junc_descs, + connectivity, + new_lines, + lines_junc_idx, + num_true_junctions, + ) = ([], [], [], [], [], [], []) for bs in range(b_size): # Cluster the junctions that are close-by - db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit( - endpoints[bs].cpu().numpy()) + db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit(endpoints[bs].cpu().numpy()) clusters = db.labels_ n_clusters = len(set(clusters)) num_true_junctions.append(n_clusters) # Compute the average junction and score for each cluster - clusters = torch.tensor(clusters, dtype=torch.long, - device=device) - new_junc = torch.zeros(n_clusters, 2, dtype=torch.float, - device=device) - new_junc.scatter_reduce_(0, clusters[:, None].repeat(1, 2), - endpoints[bs], reduce='mean', - include_self=False) + clusters = torch.tensor(clusters, dtype=torch.long, device=device) + new_junc = torch.zeros(n_clusters, 2, dtype=torch.float, device=device) + new_junc.scatter_reduce_( + 0, + clusters[:, None].repeat(1, 2), + endpoints[bs], + reduce="mean", + include_self=False, + ) junctions.append(new_junc) new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device) new_scores.scatter_reduce_( - 0, clusters, torch.repeat_interleave(line_scores[bs], 2), - reduce='mean', include_self=False) + 0, + clusters, + torch.repeat_interleave(line_scores[bs], 2), + reduce="mean", + include_self=False, + ) junc_scores.append(new_scores) # Compute the new lines @@ -56,50 +68,56 @@ def lines_to_wireframe(lines, line_scores, all_descs, conf): lines_junc_idx.append(clusters.reshape(-1, 2)) # Compute the junction connectivity - junc_connect = torch.eye(n_clusters, dtype=torch.bool, - device=device) + junc_connect = torch.eye(n_clusters, dtype=torch.bool, device=device) pairs = clusters.reshape(-1, 2) # these pairs are connected by a line junc_connect[pairs[:, 0], pairs[:, 1]] = True junc_connect[pairs[:, 1], pairs[:, 0]] = True connectivity.append(junc_connect) # Interpolate the new junction descriptors - junc_descs.append(sample_descriptors( - junctions[-1][None], all_descs[bs:(bs + 1)], 8)[0]) + junc_descs.append( + sample_descriptors(junctions[-1][None], all_descs[bs : (bs + 1)], 8)[0] + ) new_lines = torch.stack(new_lines, dim=0) lines_junc_idx = torch.stack(lines_junc_idx, dim=0) - return (junctions, junc_scores, junc_descs, connectivity, - new_lines, lines_junc_idx, num_true_junctions) + return ( + junctions, + junc_scores, + junc_descs, + connectivity, + new_lines, + lines_junc_idx, + num_true_junctions, + ) class SPWireframeDescriptor(BaseModel): default_conf = { - 'sp_params': { - 'has_detector': True, - 'has_descriptor': True, - 'descriptor_dim': 256, - 'trainable': False, - + "sp_params": { + "has_detector": True, + "has_descriptor": True, + "descriptor_dim": 256, + "trainable": False, # Inference - 'return_all': True, - 'sparse_outputs': True, - 'nms_radius': 4, - 'detection_threshold': 0.005, - 'max_num_keypoints': 1000, - 'force_num_keypoints': True, - 'remove_borders': 4, + "return_all": True, + "sparse_outputs": True, + "nms_radius": 4, + "detection_threshold": 0.005, + "max_num_keypoints": 1000, + "force_num_keypoints": True, + "remove_borders": 4, }, - 'wireframe_params': { - 'merge_points': True, - 'merge_line_endpoints': True, 
- 'nms_radius': 3, - 'max_n_junctions': 500, + "wireframe_params": { + "merge_points": True, + "merge_line_endpoints": True, + "nms_radius": 3, + "max_n_junctions": 500, }, - 'max_n_lines': 250, - 'min_length': 15, + "max_n_lines": 250, + "min_length": 15, } - required_data_keys = ['image'] + required_data_keys = ["image"] def _init(self, conf): self.conf = conf @@ -139,78 +157,108 @@ class SPWireframeDescriptor(BaseModel): return lines, scores, valid_lines def _forward(self, data): - b_size, _, h, w = data['image'].shape - device = data['image'].device + b_size, _, h, w = data["image"].shape + device = data["image"].device if not self.conf.sp_params.force_num_keypoints: assert b_size == 1, "Only batch size of 1 accepted for non padded inputs" # Line detection - if 'lines' not in data or 'line_scores' not in data: - if 'original_img' in data: + if "lines" not in data or "line_scores" not in data: + if "original_img" in data: # Detect more lines, because when projecting them to the image most of them will be discarded lines, line_scores, valid_lines = self.detect_lsd_lines( - data['original_img'], self.conf.max_n_lines * 3) + data["original_img"], self.conf.max_n_lines * 3 + ) # Apply the same transformation that is applied in homography_adaptation - lines, valid_lines2 = warp_lines_torch(lines, data['H'], False, data['image'].shape[-2:]) + lines, valid_lines2 = warp_lines_torch( + lines, data["H"], False, data["image"].shape[-2:] + ) valid_lines = valid_lines & valid_lines2 lines[~valid_lines] = -1 line_scores[~valid_lines] = 0 # Re-sort the line segments to pick the ones that are inside the image and have bigger score - sorted_scores, sorting_indices = torch.sort(line_scores, dim=-1, descending=True) - line_scores = sorted_scores[:, :self.conf.max_n_lines] - sorting_indices = sorting_indices[:, :self.conf.max_n_lines] + sorted_scores, sorting_indices = torch.sort( + line_scores, dim=-1, descending=True + ) + line_scores = sorted_scores[:, : self.conf.max_n_lines] + sorting_indices = sorting_indices[:, : self.conf.max_n_lines] lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1) valid_lines = torch.take_along_dim(valid_lines, sorting_indices, 1) else: - lines, line_scores, valid_lines = self.detect_lsd_lines(data['image']) + lines, line_scores, valid_lines = self.detect_lsd_lines(data["image"]) else: - lines, line_scores, valid_lines = data['lines'], data['line_scores'], data['valid_lines'] + lines, line_scores, valid_lines = ( + data["lines"], + data["line_scores"], + data["valid_lines"], + ) if line_scores.shape[-1] != 0: - line_scores /= (line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None]) + line_scores /= ( + line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None] + ) # SuperPoint prediction pred = self.sp(data) # Remove keypoints that are too close to line endpoints if self.conf.wireframe_params.merge_points: - kp = pred['keypoints'] + kp = pred["keypoints"] line_endpts = lines.reshape(b_size, -1, 2) - dist_pt_lines = torch.norm( - kp[:, :, None] - line_endpts[:, None], dim=-1) + dist_pt_lines = torch.norm(kp[:, :, None] - line_endpts[:, None], dim=-1) # For each keypoint, mark it as valid or to remove pts_to_remove = torch.any( - dist_pt_lines < self.conf.sp_params.nms_radius, dim=2) + dist_pt_lines < self.conf.sp_params.nms_radius, dim=2 + ) # Simply remove them (we assume batch_size = 1 here) assert len(kp) == 1 - pred['keypoints'] = pred['keypoints'][0][~pts_to_remove[0]][None] - pred['keypoint_scores'] = 
pred['keypoint_scores'][0][~pts_to_remove[0]][None] - pred['descriptors'] = pred['descriptors'][0].T[~pts_to_remove[0]].T[None] + pred["keypoints"] = pred["keypoints"][0][~pts_to_remove[0]][None] + pred["keypoint_scores"] = pred["keypoint_scores"][0][~pts_to_remove[0]][ + None + ] + pred["descriptors"] = pred["descriptors"][0].T[~pts_to_remove[0]].T[None] # Connect the lines together to form a wireframe orig_lines = lines.clone() if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0: # Merge first close-by endpoints to connect lines - (line_points, line_pts_scores, line_descs, line_association, - lines, lines_junc_idx, num_true_junctions) = lines_to_wireframe( - lines, line_scores, pred['all_descriptors'], - conf=self.conf.wireframe_params) + ( + line_points, + line_pts_scores, + line_descs, + line_association, + lines, + lines_junc_idx, + num_true_junctions, + ) = lines_to_wireframe( + lines, + line_scores, + pred["all_descriptors"], + conf=self.conf.wireframe_params, + ) # Add the keypoints to the junctions and fill the rest with random keypoints - (all_points, all_scores, all_descs, - pl_associativity) = [], [], [], [] + (all_points, all_scores, all_descs, pl_associativity) = [], [], [], [] for bs in range(b_size): - all_points.append(torch.cat( - [line_points[bs], pred['keypoints'][bs]], dim=0)) - all_scores.append(torch.cat( - [line_pts_scores[bs], pred['keypoint_scores'][bs]], dim=0)) - all_descs.append(torch.cat( - [line_descs[bs], pred['descriptors'][bs]], dim=1)) - - associativity = torch.eye(len(all_points[-1]), dtype=torch.bool, device=device) - associativity[:num_true_junctions[bs], :num_true_junctions[bs]] = \ - line_association[bs][:num_true_junctions[bs], :num_true_junctions[bs]] + all_points.append( + torch.cat([line_points[bs], pred["keypoints"][bs]], dim=0) + ) + all_scores.append( + torch.cat([line_pts_scores[bs], pred["keypoint_scores"][bs]], dim=0) + ) + all_descs.append( + torch.cat([line_descs[bs], pred["descriptors"][bs]], dim=1) + ) + + associativity = torch.eye( + len(all_points[-1]), dtype=torch.bool, device=device + ) + associativity[ + : num_true_junctions[bs], : num_true_junctions[bs] + ] = line_association[bs][ + : num_true_junctions[bs], : num_true_junctions[bs] + ] pl_associativity.append(associativity) all_points = torch.stack(all_points, dim=0) @@ -219,38 +267,55 @@ class SPWireframeDescriptor(BaseModel): pl_associativity = torch.stack(pl_associativity, dim=0) else: # Lines are independent - all_points = torch.cat([lines.reshape(b_size, -1, 2), - pred['keypoints']], dim=1) + all_points = torch.cat( + [lines.reshape(b_size, -1, 2), pred["keypoints"]], dim=1 + ) n_pts = all_points.shape[1] num_lines = lines.shape[1] num_true_junctions = [num_lines * 2] * b_size - all_scores = torch.cat([ - torch.repeat_interleave(line_scores, 2, dim=1), - pred['keypoint_scores']], dim=1) - pred['line_descriptors'] = self.endpoints_pooling( - lines, pred['all_descriptors'], (h, w)) - all_descs = torch.cat([ - pred['line_descriptors'].reshape(b_size, self.conf.sp_params.descriptor_dim, -1), - pred['descriptors']], dim=2) - pl_associativity = torch.eye( - n_pts, dtype=torch.bool, - device=device)[None].repeat(b_size, 1, 1) - lines_junc_idx = torch.arange( - num_lines * 2, device=device).reshape(1, -1, 2).repeat(b_size, 1, 1) - - del pred['all_descriptors'] # Remove dense descriptors to save memory + all_scores = torch.cat( + [ + torch.repeat_interleave(line_scores, 2, dim=1), + pred["keypoint_scores"], + ], + dim=1, + ) + pred["line_descriptors"] = 
self.endpoints_pooling( + lines, pred["all_descriptors"], (h, w) + ) + all_descs = torch.cat( + [ + pred["line_descriptors"].reshape( + b_size, self.conf.sp_params.descriptor_dim, -1 + ), + pred["descriptors"], + ], + dim=2, + ) + pl_associativity = torch.eye(n_pts, dtype=torch.bool, device=device)[ + None + ].repeat(b_size, 1, 1) + lines_junc_idx = ( + torch.arange(num_lines * 2, device=device) + .reshape(1, -1, 2) + .repeat(b_size, 1, 1) + ) + + del pred["all_descriptors"] # Remove dense descriptors to save memory torch.cuda.empty_cache() - return {'keypoints': all_points, - 'keypoint_scores': all_scores, - 'descriptors': all_descs, - 'pl_associativity': pl_associativity, - 'num_junctions': torch.tensor(num_true_junctions), - 'lines': lines, - 'orig_lines': orig_lines, - 'lines_junc_idx': lines_junc_idx, - 'line_scores': line_scores, - 'valid_lines': valid_lines} + return { + "keypoints": all_points, + "keypoint_scores": all_scores, + "descriptors": all_descs, + "pl_associativity": pl_associativity, + "num_junctions": torch.tensor(num_true_junctions), + "lines": lines, + "orig_lines": orig_lines, + "lines_junc_idx": lines_junc_idx, + "line_scores": line_scores, + "valid_lines": valid_lines, + } @staticmethod def endpoints_pooling(segs, all_descriptors, img_shape): @@ -259,11 +324,21 @@ class SPWireframeDescriptor(BaseModel): scale_x = filter_shape[1] / img_shape[1] scale_y = filter_shape[0] / img_shape[0] - scaled_segs = torch.round(segs * torch.tensor([scale_x, scale_y]).to(segs)).long() + scaled_segs = torch.round( + segs * torch.tensor([scale_x, scale_y]).to(segs) + ).long() scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1) scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1) - line_descriptors = [all_descriptors[None, b, ..., torch.squeeze(b_segs[..., 1]), torch.squeeze(b_segs[..., 0])] - for b, b_segs in enumerate(scaled_segs)] + line_descriptors = [ + all_descriptors[ + None, + b, + ..., + torch.squeeze(b_segs[..., 1]), + torch.squeeze(b_segs[..., 0]), + ] + for b, b_segs in enumerate(scaled_segs) + ] line_descriptors = torch.cat(line_descriptors) return line_descriptors # Shape (1, 256, 308, 2) diff --git a/third_party/GlueStick/gluestick/run.py b/third_party/GlueStick/gluestick/run.py index 6baa88834f0b4dfde769ebe6c671e4ec49d4ed10..89569b878cca84fc48ef0b772f71b07befeb45a6 100644 --- a/third_party/GlueStick/gluestick/run.py +++ b/third_party/GlueStick/gluestick/run.py @@ -7,49 +7,58 @@ import torch from matplotlib import pyplot as plt from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT -from .drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches +from .drawing import ( + plot_images, + plot_lines, + plot_color_line_matches, + plot_keypoints, + plot_matches, +) from .models.two_view_pipeline import TwoViewPipeline def main(): # Parse input parameters parser = argparse.ArgumentParser( - prog='GlueStick Demo', - description='Demo app to show the point and line matches obtained by GlueStick') - parser.add_argument('-img1', default=join('resources' + os.path.sep + 'img1.jpg')) - parser.add_argument('-img2', default=join('resources' + os.path.sep + 'img2.jpg')) - parser.add_argument('--max_pts', type=int, default=1000) - parser.add_argument('--max_lines', type=int, default=300) - parser.add_argument('--skip-imshow', default=False, action='store_true') + prog="GlueStick Demo", + description="Demo app to show the point and line matches obtained by GlueStick", + ) + 
parser.add_argument("-img1", default=join("resources" + os.path.sep + "img1.jpg")) + parser.add_argument("-img2", default=join("resources" + os.path.sep + "img2.jpg")) + parser.add_argument("--max_pts", type=int, default=1000) + parser.add_argument("--max_lines", type=int, default=300) + parser.add_argument("--skip-imshow", default=False, action="store_true") args = parser.parse_args() # Evaluation config conf = { - 'name': 'two_view_pipeline', - 'use_lines': True, - 'extractor': { - 'name': 'wireframe', - 'sp_params': { - 'force_num_keypoints': False, - 'max_num_keypoints': args.max_pts, + "name": "two_view_pipeline", + "use_lines": True, + "extractor": { + "name": "wireframe", + "sp_params": { + "force_num_keypoints": False, + "max_num_keypoints": args.max_pts, }, - 'wireframe_params': { - 'merge_points': True, - 'merge_line_endpoints': True, + "wireframe_params": { + "merge_points": True, + "merge_line_endpoints": True, }, - 'max_n_lines': args.max_lines, + "max_n_lines": args.max_lines, }, - 'matcher': { - 'name': 'gluestick', - 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'), - 'trainable': False, + "matcher": { + "name": "gluestick", + "weights": str( + GLUESTICK_ROOT / "resources" / "weights" / "checkpoint_GlueStick_MD.tar" + ), + "trainable": False, + }, + "ground_truth": { + "from_pose_depth": False, }, - 'ground_truth': { - 'from_pose_depth': False, - } } - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" pipeline_model = TwoViewPipeline(conf).to(device).eval() @@ -57,8 +66,11 @@ def main(): gray1 = cv2.imread(args.img2, 0) torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1) - torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None] - x = {'image0': torch_gray0, 'image1': torch_gray1} + torch_gray0, torch_gray1 = ( + torch_gray0.to(device)[None], + torch_gray1.to(device)[None], + ) + x = {"image0": torch_gray0, "image1": torch_gray1} pred = pipeline_model(x) pred = batch_to_np(pred) @@ -79,29 +91,51 @@ def main(): matched_lines1 = line_seg1[match_indices] # Plot the matches - img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR) - plot_images([img0, img1], ['Image 1 - detected lines', 'Image 2 - detected lines'], dpi=200, pad=2.0) + img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor( + gray1, cv2.COLOR_GRAY2BGR + ) + plot_images( + [img0, img1], + ["Image 1 - detected lines", "Image 2 - detected lines"], + dpi=200, + pad=2.0, + ) plot_lines([line_seg0, line_seg1], ps=4, lw=2) - plt.gcf().canvas.manager.set_window_title('Detected Lines') - plt.savefig('detected_lines.png') - - plot_images([img0, img1], ['Image 1 - detected points', 'Image 2 - detected points'], dpi=200, pad=2.0) - plot_keypoints([kp0, kp1], colors='c') - plt.gcf().canvas.manager.set_window_title('Detected Points') - plt.savefig('detected_points.png') - - plot_images([img0, img1], ['Image 1 - line matches', 'Image 2 - line matches'], dpi=200, pad=2.0) + plt.gcf().canvas.manager.set_window_title("Detected Lines") + plt.savefig("detected_lines.png") + + plot_images( + [img0, img1], + ["Image 1 - detected points", "Image 2 - detected points"], + dpi=200, + pad=2.0, + ) + plot_keypoints([kp0, kp1], colors="c") + plt.gcf().canvas.manager.set_window_title("Detected Points") + plt.savefig("detected_points.png") + + plot_images( + [img0, img1], + ["Image 1 - line matches", "Image 2 - line matches"], + 
dpi=200, + pad=2.0, + ) plot_color_line_matches([matched_lines0, matched_lines1], lw=2) - plt.gcf().canvas.manager.set_window_title('Line Matches') - plt.savefig('line_matches.png') - - plot_images([img0, img1], ['Image 1 - point matches', 'Image 2 - point matches'], dpi=200, pad=2.0) - plot_matches(matched_kps0, matched_kps1, 'green', lw=1, ps=0) - plt.gcf().canvas.manager.set_window_title('Point Matches') - plt.savefig('detected_points.png') + plt.gcf().canvas.manager.set_window_title("Line Matches") + plt.savefig("line_matches.png") + + plot_images( + [img0, img1], + ["Image 1 - point matches", "Image 2 - point matches"], + dpi=200, + pad=2.0, + ) + plot_matches(matched_kps0, matched_kps1, "green", lw=1, ps=0) + plt.gcf().canvas.manager.set_window_title("Point Matches") + plt.savefig("point_matches.png") if not args.skip_imshow: plt.show() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/third_party/GlueStick/setup.py b/third_party/GlueStick/setup.py index f0caa063e99cf6d7784fe7d54af08dbb66811627..c1a9df947ac2b788597e3028226f8efbdcd21b94 100644 --- a/third_party/GlueStick/setup.py +++ b/third_party/GlueStick/setup.py @@ -1,3 +1,3 @@ from setuptools import setup -setup(name='gluestick', version="0.0", packages=['gluestick']) +setup(name="gluestick", version="0.0", packages=["gluestick"]) diff --git a/third_party/LightGlue/lightglue/__init__.py b/third_party/LightGlue/lightglue/__init__.py index 97ad123d1dd573770da8ce2c4025386b8c70e1a3..aed9fbee8abe8562a5821893e8a219e2f9a38171 100644 --- a/third_party/LightGlue/lightglue/__init__.py +++ b/third_party/LightGlue/lightglue/__init__.py @@ -1,4 +1,4 @@ from .lightglue import LightGlue from .superpoint import SuperPoint from .disk import DISK -from .utils import match_pair \ No newline at end of file +from .utils import match_pair diff --git a/third_party/LightGlue/lightglue/disk.py b/third_party/LightGlue/lightglue/disk.py index 0fd0dec1049299bb53861f359ef63b12578bc0dd..c3e6e63ba76a018709e3332cdf432d06f4cda081 100644 --- a/third_party/LightGlue/lightglue/disk.py +++ b/third_party/LightGlue/lightglue/disk.py @@ -7,21 +7,21 @@ from .utils import ImagePreprocessor class DISK(nn.Module): default_conf = { - 'weights': 'depth', - 'max_num_keypoints': None, - 'desc_dim': 128, - 'nms_window_size': 5, - 'detection_threshold': 0.0, - 'pad_if_not_divisible': True, + "weights": "depth", + "max_num_keypoints": None, + "desc_dim": 128, + "nms_window_size": 5, + "detection_threshold": 0.0, + "pad_if_not_divisible": True, } preprocess_conf = { **ImagePreprocessor.default_conf, - 'resize': 1024, - 'grayscale': False, + "resize": 1024, + "grayscale": False, } - required_data_keys = ['image'] + required_data_keys = ["image"] def __init__(self, **conf) -> None: super().__init__() @@ -30,16 +30,16 @@ class DISK(nn.Module): self.model = kornia.feature.DISK.from_pretrained(self.conf.weights) def forward(self, data: dict) -> dict: - """ Compute keypoints, scores, descriptors for image """ + """Compute keypoints, scores, descriptors for image""" for key in self.required_data_keys: - assert key in data, f'Missing key {key} in data' - image = data['image'] + assert key in data, f"Missing key {key} in data" + image = data["image"] features = self.model( image, n=self.conf.max_num_keypoints, window_size=self.conf.nms_window_size, score_threshold=self.conf.detection_threshold, - pad_if_not_divisible=self.conf.pad_if_not_divisible + pad_if_not_divisible=self.conf.pad_if_not_divisible, ) keypoints = [f.keypoints for f in features] scores =
[f.detection_scores for f in features] @@ -51,20 +51,19 @@ class DISK(nn.Module): descriptors = torch.stack(descriptors, 0) return { - 'keypoints': keypoints.to(image), - 'keypoint_scores': scores.to(image), - 'descriptors': descriptors.to(image), + "keypoints": keypoints.to(image), + "keypoint_scores": scores.to(image), + "descriptors": descriptors.to(image), } def extract(self, img: torch.Tensor, **conf) -> dict: - """ Perform extraction with online resizing""" + """Perform extraction with online resizing""" if img.dim() == 3: img = img[None] # add batch dim assert img.dim() == 4 and img.shape[0] == 1 shape = img.shape[-2:][::-1] - img, scales = ImagePreprocessor( - **{**self.preprocess_conf, **conf})(img) - feats = self.forward({'image': img}) - feats['image_size'] = torch.tensor(shape)[None].to(img).float() - feats['keypoints'] = (feats['keypoints'] + .5) / scales[None] - .5 + img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img) + feats = self.forward({"image": img}) + feats["image_size"] = torch.tensor(shape)[None].to(img).float() + feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5 return feats diff --git a/third_party/LightGlue/lightglue/lightglue.py b/third_party/LightGlue/lightglue/lightglue.py index 3dc872bdc902bb71f640ae8749c07240924c5540..4b20300bf9068267e7b4d334dc2d3e85114ddd3e 100644 --- a/third_party/LightGlue/lightglue/lightglue.py +++ b/third_party/LightGlue/lightglue/lightglue.py @@ -12,7 +12,7 @@ try: except ModuleNotFoundError: FlashCrossAttention = None -if FlashCrossAttention or hasattr(F, 'scaled_dot_product_attention'): +if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"): FLASH_AVAILABLE = True else: FLASH_AVAILABLE = False @@ -21,9 +21,7 @@ torch.backends.cudnn.deterministic = True @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) -def normalize_keypoints( - kpts: torch.Tensor, - size: torch.Tensor) -> torch.Tensor: +def normalize_keypoints(kpts: torch.Tensor, size: torch.Tensor) -> torch.Tensor: if isinstance(size, torch.Size): size = torch.tensor(size)[None] shift = size.float().to(kpts) / 2 @@ -38,22 +36,20 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) -def apply_cached_rotary_emb( - freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: +def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: return (t * freqs[0]) + (rotate_half(t) * freqs[1]) class LearnableFourierPositionalEncoding(nn.Module): - def __init__(self, M: int, dim: int, F_dim: int = None, - gamma: float = 1.0) -> None: + def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None: super().__init__() F_dim = F_dim if F_dim is not None else dim self.gamma = gamma self.Wr = nn.Linear(M, F_dim // 2, bias=False) - nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2) def forward(self, x: torch.Tensor) -> torch.Tensor: - """ encode position vector """ + """encode position vector""" projected = self.Wr(x) cosines, sines = torch.cos(projected), torch.sin(projected) emb = torch.stack([cosines, sines], 0).unsqueeze(-3) @@ -63,16 +59,14 @@ class LearnableFourierPositionalEncoding(nn.Module): class TokenConfidence(nn.Module): def __init__(self, dim: int) -> None: super().__init__() - self.token = nn.Sequential( - nn.Linear(dim, 1), - nn.Sigmoid() - ) + self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid()) def forward(self, desc0: torch.Tensor, desc1: 
torch.Tensor): - """ get confidence tokens """ + """get confidence tokens""" return ( self.token(desc0.detach().float()).squeeze(-1), - self.token(desc1.detach().float()).squeeze(-1)) + self.token(desc1.detach().float()).squeeze(-1), + ) class Attention(nn.Module): @@ -80,8 +74,8 @@ class Attention(nn.Module): super().__init__() if allow_flash and not FLASH_AVAILABLE: warnings.warn( - 'FlashAttention is not available. For optimal speed, ' - 'consider installing torch >= 2.0 or flash-attn.', + "FlashAttention is not available. For optimal speed, " + "consider installing torch >= 2.0 or flash-attn.", stacklevel=2, ) self.enable_flash = allow_flash and FLASH_AVAILABLE @@ -89,7 +83,7 @@ class Attention(nn.Module): self.flash_ = FlashCrossAttention() def forward(self, q, k, v) -> torch.Tensor: - if self.enable_flash and q.device.type == 'cuda': + if self.enable_flash and q.device.type == "cuda": if FlashCrossAttention: q, k, v = [x.transpose(-2, -3) for x in [q, k, v]] m = self.flash_(q.half(), torch.stack([k, v], 2).half()) @@ -98,35 +92,35 @@ class Attention(nn.Module): args = [x.half().contiguous() for x in [q, k, v]] with torch.backends.cuda.sdp_kernel(enable_flash=True): return F.scaled_dot_product_attention(*args).to(q.dtype) - elif hasattr(F, 'scaled_dot_product_attention'): + elif hasattr(F, "scaled_dot_product_attention"): args = [x.contiguous() for x in [q, k, v]] return F.scaled_dot_product_attention(*args).to(q.dtype) else: s = q.shape[-1] ** -0.5 - attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1) - return torch.einsum('...ij,...jd->...id', attn, v) + attn = F.softmax(torch.einsum("...id,...jd->...ij", q, k) * s, -1) + return torch.einsum("...ij,...jd->...id", attn, v) class Transformer(nn.Module): - def __init__(self, embed_dim: int, num_heads: int, - flash: bool = False, bias: bool = True) -> None: + def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads assert self.embed_dim % num_heads == 0 self.head_dim = self.embed_dim // num_heads - self.Wqkv = nn.Linear(embed_dim, 3*embed_dim, bias=bias) + self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) self.inner_attn = Attention(flash) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.ffn = nn.Sequential( - nn.Linear(2*embed_dim, 2*embed_dim), - nn.LayerNorm(2*embed_dim, elementwise_affine=True), + nn.Linear(2 * embed_dim, 2 * embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), nn.GELU(), - nn.Linear(2*embed_dim, embed_dim) + nn.Linear(2 * embed_dim, embed_dim), ) - def _forward(self, x: torch.Tensor, - encoding: Optional[torch.Tensor] = None): + def _forward(self, x: torch.Tensor, encoding: Optional[torch.Tensor] = None): qkv = self.Wqkv(x) qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] @@ -134,8 +128,7 @@ class Transformer(nn.Module): q = apply_cached_rotary_emb(encoding, q) k = apply_cached_rotary_emb(encoding, k) context = self.inner_attn(q, k, v) - message = self.out_proj( - context.transpose(1, 2).flatten(start_dim=-2)) + message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2)) return x + self.ffn(torch.cat([x, message], -1)) def forward(self, x0, x1, encoding0=None, encoding1=None): @@ -143,21 +136,22 @@ class Transformer(nn.Module): class CrossTransformer(nn.Module): - def __init__(self, embed_dim: int, num_heads: int, - flash: bool = False, bias: bool = True) -> None: + 
def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: super().__init__() self.heads = num_heads dim_head = embed_dim // num_heads - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 inner_dim = dim_head * num_heads self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias) self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias) self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias) self.ffn = nn.Sequential( - nn.Linear(2*embed_dim, 2*embed_dim), - nn.LayerNorm(2*embed_dim, elementwise_affine=True), + nn.Linear(2 * embed_dim, 2 * embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), nn.GELU(), - nn.Linear(2*embed_dim, embed_dim) + nn.Linear(2 * embed_dim, embed_dim), ) if flash and FLASH_AVAILABLE: @@ -173,19 +167,19 @@ class CrossTransformer(nn.Module): v0, v1 = self.map_(self.to_v, x0, x1) qk0, qk1, v0, v1 = map( lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2), - (qk0, qk1, v0, v1)) + (qk0, qk1, v0, v1), + ) if self.flash is not None: m0 = self.flash(qk0, qk1, v1) m1 = self.flash(qk1, qk0, v0) else: qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5 - sim = torch.einsum('b h i d, b h j d -> b h i j', qk0, qk1) + sim = torch.einsum("b h i d, b h j d -> b h i j", qk0, qk1) attn01 = F.softmax(sim, dim=-1) attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) - m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1) - m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0) - m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), - m0, m1) + m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1) + m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0) + m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1) m0, m1 = self.map_(self.to_out, m0, m1) x0 = x0 + self.ffn(torch.cat([x0, m0], -1)) x1 = x1 + self.ffn(torch.cat([x1, m1], -1)) @@ -193,15 +187,15 @@ class CrossTransformer(nn.Module): def sigmoid_log_double_softmax( - sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor: - """ create the log assignment matrix from logits and similarity""" + sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" b, m, n = sim.shape certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2) scores0 = F.log_softmax(sim, 2) - scores1 = F.log_softmax( - sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) - scores = sim.new_full((b, m+1, n+1), 0) - scores[:, :m, :n] = (scores0 + scores1 + certainties) + scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = sim.new_full((b, m + 1, n + 1), 0) + scores[:, :m, :n] = scores0 + scores1 + certainties scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1)) scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1)) return scores @@ -215,11 +209,11 @@ class MatchAssignment(nn.Module): self.final_proj = nn.Linear(dim, dim, bias=True) def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): - """ build assignment matrix from descriptors """ + """build assignment matrix from descriptors""" mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) _, _, d = mdesc0.shape - mdesc0, mdesc1 = mdesc0 / d**.25, mdesc1 / d**.25 - sim = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1) + mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25 + sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1) z0 = self.matchability(desc0) z1 = self.matchability(desc1) scores = 
sigmoid_log_double_softmax(sim, z0, z1) @@ -232,7 +226,7 @@ class MatchAssignment(nn.Module): def filter_matches(scores: torch.Tensor, th: float): - """ obtain matches from a log assignment matrix [Bx M+1 x N+1]""" + """obtain matches from a log assignment matrix [Bx M+1 x N+1]""" max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) m0, m1 = max0.indices, max1.indices mutual0 = torch.arange(m0.shape[1]).to(m0)[None] == m1.gather(1, m0) @@ -253,42 +247,39 @@ def filter_matches(scores: torch.Tensor, th: float): class LightGlue(nn.Module): default_conf = { - 'name': 'lightglue', # just for interfacing - 'input_dim': 256, # input descriptor dimension (autoselected from weights) - 'descriptor_dim': 256, - 'n_layers': 9, - 'num_heads': 4, - 'flash': True, # enable FlashAttention if available. - 'mp': False, # enable mixed precision - 'depth_confidence': 0.95, # early stopping, disable with -1 - 'width_confidence': 0.99, # point pruning, disable with -1 - 'filter_threshold': 0.1, # match threshold - 'weights': None, + "name": "lightglue", # just for interfacing + "input_dim": 256, # input descriptor dimension (autoselected from weights) + "descriptor_dim": 256, + "n_layers": 9, + "num_heads": 4, + "flash": True, # enable FlashAttention if available. + "mp": False, # enable mixed precision + "depth_confidence": 0.95, # early stopping, disable with -1 + "width_confidence": 0.99, # point pruning, disable with -1 + "filter_threshold": 0.1, # match threshold + "weights": None, } - required_data_keys = [ - 'image0', 'image1'] + required_data_keys = ["image0", "image1"] version = "v0.1_arxiv" url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth" features = { - 'superpoint': ('superpoint_lightglue', 256), - 'disk': ('disk_lightglue', 128) + "superpoint": ("superpoint_lightglue", 256), + "disk": ("disk_lightglue", 128), } - def __init__(self, features='superpoint', **conf) -> None: + def __init__(self, features="superpoint", **conf) -> None: super().__init__() self.conf = {**self.default_conf, **conf} if features is not None: - assert (features in list(self.features.keys())) - self.conf['weights'], self.conf['input_dim'] = \ - self.features[features] + assert features in list(self.features.keys()) + self.conf["weights"], self.conf["input_dim"] = self.features[features] self.conf = conf = SimpleNamespace(**self.conf) if conf.input_dim != conf.descriptor_dim: - self.input_proj = nn.Linear( - conf.input_dim, conf.descriptor_dim, bias=True) + self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True) else: self.input_proj = nn.Identity() @@ -297,26 +288,29 @@ class LightGlue(nn.Module): h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim self.self_attn = nn.ModuleList( - [Transformer(d, h, conf.flash) for _ in range(n)]) + [Transformer(d, h, conf.flash) for _ in range(n)] + ) self.cross_attn = nn.ModuleList( - [CrossTransformer(d, h, conf.flash) for _ in range(n)]) - self.log_assignment = nn.ModuleList( - [MatchAssignment(d) for _ in range(n)]) - self.token_confidence = nn.ModuleList([ - TokenConfidence(d) for _ in range(n-1)]) + [CrossTransformer(d, h, conf.flash) for _ in range(n)] + ) + self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)]) + self.token_confidence = nn.ModuleList( + [TokenConfidence(d) for _ in range(n - 1)] + ) if features is not None: - fname = f'{conf.weights}_{self.version}.pth'.replace('.', '-') + fname = f"{conf.weights}_{self.version}.pth".replace(".", "-") state_dict = 
torch.hub.load_state_dict_from_url( - self.url.format(self.version, features), file_name=fname) + self.url.format(self.version, features), file_name=fname + ) self.load_state_dict(state_dict, strict=False) elif conf.weights is not None: path = Path(__file__).parent - path = path / 'weights/{}.pth'.format(self.conf.weights) - state_dict = torch.load(str(path), map_location='cpu') + path = path / "weights/{}.pth".format(self.conf.weights) + state_dict = torch.load(str(path), map_location="cpu") self.load_state_dict(state_dict, strict=False) - print('Loaded LightGlue model') + print("Loaded LightGlue model") def forward(self, data: dict) -> dict: """ @@ -339,27 +333,27 @@ class LightGlue(nn.Module): matching_scores1: [B x N] matches: List[[Si x 2]], scores: List[[Si]] """ - with torch.autocast(enabled=self.conf.mp, device_type='cuda'): + with torch.autocast(enabled=self.conf.mp, device_type="cuda"): return self._forward(data) def _forward(self, data: dict) -> dict: for key in self.required_data_keys: - assert key in data, f'Missing key {key} in data' - data0, data1 = data['image0'], data['image1'] - kpts0_, kpts1_ = data0['keypoints'], data1['keypoints'] + assert key in data, f"Missing key {key} in data" + data0, data1 = data["image0"], data["image1"] + kpts0_, kpts1_ = data0["keypoints"], data1["keypoints"] b, m, _ = kpts0_.shape b, n, _ = kpts1_.shape - size0, size1 = data0.get('image_size'), data1.get('image_size') - size0 = size0 if size0 is not None else data0['image'].shape[-2:][::-1] - size1 = size1 if size1 is not None else data1['image'].shape[-2:][::-1] + size0, size1 = data0.get("image_size"), data1.get("image_size") + size0 = size0 if size0 is not None else data0["image"].shape[-2:][::-1] + size1 = size1 if size1 is not None else data1["image"].shape[-2:][::-1] kpts0 = normalize_keypoints(kpts0_, size=size0) kpts1 = normalize_keypoints(kpts1_, size=size1) assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1) assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1) - desc0 = data0['descriptors'].detach() - desc1 = data1['descriptors'].detach() + desc0 = data0["descriptors"].detach() + desc1 = data1["descriptors"].detach() assert desc0.shape[-1] == self.conf.input_dim assert desc1.shape[-1] == self.conf.input_dim @@ -384,19 +378,18 @@ class LightGlue(nn.Module): token0, token1 = None, None for i in range(self.conf.n_layers): # self+cross attention - desc0, desc1 = self.self_attn[i]( - desc0, desc1, encoding0, encoding1) + desc0, desc1 = self.self_attn[i](desc0, desc1, encoding0, encoding1) desc0, desc1 = self.cross_attn[i](desc0, desc1) if i == self.conf.n_layers - 1: continue # no early stopping or adaptive width at last layer if dec > 0: # early stopping token0, token1 = self.token_confidence[i](desc0, desc1) - if self.stop(token0, token1, self.conf_th(i), dec, m+n): + if self.stop(token0, token1, self.conf_th(i), dec, m + n): break if wic > 0: # point pruning match0, match1 = self.log_assignment[i].scores(desc0, desc1) - mask0 = self.get_mask(token0, match0, self.conf_th(i), 1-wic) - mask1 = self.get_mask(token1, match1, self.conf_th(i), 1-wic) + mask0 = self.get_mask(token0, match0, self.conf_th(i), 1 - wic) + mask1 = self.get_mask(token1, match1, self.conf_th(i), 1 - wic) ind0, ind1 = ind0[mask0][None], ind1[mask1][None] desc0, desc1 = desc0[mask0][None], desc1[mask1][None] if desc0.shape[-2] == 0 or desc1.shape[-2] == 0: @@ -409,17 +402,16 @@ class LightGlue(nn.Module): if wic > 0: # scatter with indices after pruning scores_, _ = self.log_assignment[i](desc0, desc1) dt, dev 
= scores_.dtype, scores_.device - scores = torch.zeros(b, m+1, n+1, dtype=dt, device=dev) + scores = torch.zeros(b, m + 1, n + 1, dtype=dt, device=dev) scores[:, :-1, :-1] = -torch.inf scores[:, ind0[0], -1] = scores_[:, :-1, -1] scores[:, -1, ind1[0]] = scores_[:, -1, :-1] - x, y = torch.meshgrid(ind0[0], ind1[0], indexing='ij') + x, y = torch.meshgrid(ind0[0], ind1[0], indexing="ij") scores[:, x, y] = scores_[:, :-1, :-1] else: scores, _ = self.log_assignment[i](desc0, desc1) - m0, m1, mscores0, mscores1 = filter_matches( - scores, self.conf.filter_threshold) + m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold) matches, mscores = [], [] for k in range(b): @@ -428,36 +420,48 @@ class LightGlue(nn.Module): mscores.append(mscores0[k][valid]) return { - 'log_assignment': scores, - 'matches0': m0, - 'matches1': m1, - 'matching_scores0': mscores0, - 'matching_scores1': mscores1, - 'stop': i+1, - 'prune0': prune0, - 'prune1': prune1, - 'matches': matches, - 'scores': mscores, + "log_assignment": scores, + "matches0": m0, + "matches1": m1, + "matching_scores0": mscores0, + "matching_scores1": mscores1, + "stop": i + 1, + "prune0": prune0, + "prune1": prune1, + "matches": matches, + "scores": mscores, } def conf_th(self, i: int) -> float: - """ scaled confidence threshold """ - return np.clip( - 0.8 + 0.1 * np.exp(-4.0 * i / self.conf.n_layers), 0, 1) - - def get_mask(self, confidence: torch.Tensor, match: torch.Tensor, - conf_th: float, match_th: float) -> torch.Tensor: - """ mask points which should be removed """ + """scaled confidence threshold""" + return np.clip(0.8 + 0.1 * np.exp(-4.0 * i / self.conf.n_layers), 0, 1) + + def get_mask( + self, + confidence: torch.Tensor, + match: torch.Tensor, + conf_th: float, + match_th: float, + ) -> torch.Tensor: + """mask points which should be removed""" if conf_th and confidence is not None: - mask = torch.where(confidence > conf_th, match, - match.new_tensor(1.0)) > match_th + mask = ( + torch.where(confidence > conf_th, match, match.new_tensor(1.0)) + > match_th + ) else: mask = match > match_th return mask - def stop(self, token0: torch.Tensor, token1: torch.Tensor, - conf_th: float, inl_th: float, seql: int) -> torch.Tensor: - """ evaluate stopping condition""" + def stop( + self, + token0: torch.Tensor, + token1: torch.Tensor, + conf_th: float, + inl_th: float, + seql: int, + ) -> torch.Tensor: + """evaluate stopping condition""" tokens = torch.cat([token0, token1], -1) if conf_th: pos = 1.0 - (tokens < conf_th).float().sum() / seql diff --git a/third_party/LightGlue/lightglue/superpoint.py b/third_party/LightGlue/lightglue/superpoint.py index abe7539767c9b2fe788e376e872d2844386b1a4a..1b7ce40f698bda6b2aca34d4ee504bd725933005 100644 --- a/third_party/LightGlue/lightglue/superpoint.py +++ b/third_party/LightGlue/lightglue/superpoint.py @@ -48,12 +48,13 @@ from .utils import ImagePreprocessor def simple_nms(scores, nms_radius: int): - """ Fast Non-maximum suppression to remove nearby points """ - assert (nms_radius >= 0) + """Fast Non-maximum suppression to remove nearby points""" + assert nms_radius >= 0 def max_pool(x): return torch.nn.functional.max_pool2d( - x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) zeros = torch.zeros_like(scores) max_mask = scores == max_pool(scores) @@ -73,17 +74,20 @@ def top_k_keypoints(keypoints, scores, k): def sample_descriptors(keypoints, descriptors, s: int = 8): - """ Interpolate descriptors at 
keypoint locations """ + """Interpolate descriptors at keypoint locations""" b, c, h, w = descriptors.shape keypoints = keypoints - s / 2 + 0.5 - keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], - ).to(keypoints)[None] - keypoints = keypoints*2 - 1 # normalize to (-1, 1) - args = {'align_corners': True} if torch.__version__ >= '1.3' else {} + keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],).to( + keypoints + )[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + args = {"align_corners": True} if torch.__version__ >= "1.3" else {} descriptors = torch.nn.functional.grid_sample( - descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) + descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args + ) descriptors = torch.nn.functional.normalize( - descriptors.reshape(b, c, -1), p=2, dim=1) + descriptors.reshape(b, c, -1), p=2, dim=1 + ) return descriptors @@ -95,21 +99,22 @@ class SuperPoint(nn.Module): Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629 """ + default_conf = { - 'descriptor_dim': 256, - 'nms_radius': 4, - 'max_num_keypoints': None, - 'detection_threshold': 0.0005, - 'remove_borders': 4, + "descriptor_dim": 256, + "nms_radius": 4, + "max_num_keypoints": None, + "detection_threshold": 0.0005, + "remove_borders": 4, } preprocess_conf = { **ImagePreprocessor.default_conf, - 'resize': 1024, - 'grayscale': True, + "resize": 1024, + "grayscale": True, } - required_data_keys = ['image'] + required_data_keys = ["image"] def __init__(self, **conf): super().__init__() @@ -133,26 +138,26 @@ class SuperPoint(nn.Module): self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) self.convDb = nn.Conv2d( - c5, self.conf['descriptor_dim'], - kernel_size=1, stride=1, padding=0) + c5, self.conf["descriptor_dim"], kernel_size=1, stride=1, padding=0 + ) url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" self.load_state_dict(torch.hub.load_state_dict_from_url(url)) - mk = self.conf['max_num_keypoints'] + mk = self.conf["max_num_keypoints"] if mk is not None and mk <= 0: - raise ValueError('max_num_keypoints must be positive or None') + raise ValueError("max_num_keypoints must be positive or None") - print('Loaded SuperPoint model') + print("Loaded SuperPoint model") def forward(self, data: dict) -> dict: - """ Compute keypoints, scores, descriptors for image """ + """Compute keypoints, scores, descriptors for image""" for key in self.required_data_keys: - assert key in data, f'Missing key {key} in data' - image = data['image'] + assert key in data, f"Missing key {key} in data" + image = data["image"] if image.shape[1] == 3: # RGB scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) - image = (image*scale).sum(1, keepdim=True) + image = (image * scale).sum(1, keepdim=True) # Shared Encoder x = self.relu(self.conv1a(image)) x = self.relu(self.conv1b(x)) @@ -172,31 +177,37 @@ class SuperPoint(nn.Module): scores = torch.nn.functional.softmax(scores, 1)[:, :-1] b, _, h, w = scores.shape scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) - scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) - scores = simple_nms(scores, self.conf['nms_radius']) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + scores = simple_nms(scores, self.conf["nms_radius"]) # Discard keypoints near the image borders - if self.conf['remove_borders']: - pad = self.conf['remove_borders'] + if self.conf["remove_borders"]: + pad = self.conf["remove_borders"] scores[:, 
:pad] = -1 scores[:, :, :pad] = -1 scores[:, -pad:] = -1 scores[:, :, -pad:] = -1 # Extract keypoints - best_kp = torch.where(scores > self.conf['detection_threshold']) + best_kp = torch.where(scores > self.conf["detection_threshold"]) scores = scores[best_kp] # Separate into batches - keypoints = [torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] - for i in range(b)] + keypoints = [ + torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b) + ] scores = [scores[best_kp[0] == i] for i in range(b)] # Keep the k keypoints with highest score - if self.conf['max_num_keypoints'] is not None: - keypoints, scores = list(zip(*[ - top_k_keypoints(k, s, self.conf['max_num_keypoints']) - for k, s in zip(keypoints, scores)])) + if self.conf["max_num_keypoints"] is not None: + keypoints, scores = list( + zip( + *[ + top_k_keypoints(k, s, self.conf["max_num_keypoints"]) + for k, s in zip(keypoints, scores) + ] + ) + ) # Convert (h, w) to (x, y) keypoints = [torch.flip(k, [1]).float() for k in keypoints] @@ -207,24 +218,25 @@ class SuperPoint(nn.Module): descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) # Extract descriptors - descriptors = [sample_descriptors(k[None], d[None], 8)[0] - for k, d in zip(keypoints, descriptors)] + descriptors = [ + sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, descriptors) + ] return { - 'keypoints': torch.stack(keypoints, 0), - 'keypoint_scores': torch.stack(scores, 0), - 'descriptors': torch.stack(descriptors, 0).transpose(-1, -2), + "keypoints": torch.stack(keypoints, 0), + "keypoint_scores": torch.stack(scores, 0), + "descriptors": torch.stack(descriptors, 0).transpose(-1, -2), } def extract(self, img: torch.Tensor, **conf) -> dict: - """ Perform extraction with online resizing""" + """Perform extraction with online resizing""" if img.dim() == 3: img = img[None] # add batch dim assert img.dim() == 4 and img.shape[0] == 1 shape = img.shape[-2:][::-1] - img, scales = ImagePreprocessor( - **{**self.preprocess_conf, **conf})(img) - feats = self.forward({'image': img}) - feats['image_size'] = torch.tensor(shape)[None].to(img).float() - feats['keypoints'] = (feats['keypoints'] + .5) / scales[None] - .5 + img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img) + feats = self.forward({"image": img}) + feats["image_size"] = torch.tensor(shape)[None].to(img).float() + feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5 return feats diff --git a/third_party/LightGlue/lightglue/utils.py b/third_party/LightGlue/lightglue/utils.py index 3e06184948948670db1425ed22f5cbb86061a332..e8d30803931aad89e16e9b543959f76fda87389e 100644 --- a/third_party/LightGlue/lightglue/utils.py +++ b/third_party/LightGlue/lightglue/utils.py @@ -10,12 +10,12 @@ from types import SimpleNamespace class ImagePreprocessor: default_conf = { - 'resize': None, # target edge length, None for no resizing - 'side': 'long', - 'interpolation': 'bilinear', - 'align_corners': None, - 'antialias': True, - 'grayscale': False, # convert rgb to grayscale + "resize": None, # target edge length, None for no resizing + "side": "long", + "interpolation": "bilinear", + "align_corners": None, + "antialias": True, + "grayscale": False, # convert rgb to grayscale } def __init__(self, **conf) -> None: @@ -28,9 +28,12 @@ class ImagePreprocessor: h, w = img.shape[-2:] if self.conf.resize is not None: img = kornia.geometry.transform.resize( - img, self.conf.resize, side=self.conf.side, + img, + self.conf.resize, + side=self.conf.side, 
antialias=self.conf.antialias, - align_corners=self.conf.align_corners) + align_corners=self.conf.align_corners, + ) scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img) if self.conf.grayscale and img.shape[-3] == 3: img = kornia.color.rgb_to_grayscale(img) @@ -53,28 +56,31 @@ def map_tensor(input_, func: Callable): return input_ -def batch_to_device(batch: dict, device: str = 'cpu', - non_blocking: bool = True): +def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True): """Move batch (dict) to device""" + def _func(tensor): return tensor.to(device=device, non_blocking=non_blocking).detach() + return map_tensor(batch, _func) def rbd(data: dict) -> dict: """Remove batch dimension from elements in data""" - return {k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v - for k, v in data.items()} + return { + k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v + for k, v in data.items() + } def read_image(path: Path, grayscale: bool = False) -> np.ndarray: """Read an image from path as RGB or grayscale""" if not Path(path).exists(): - raise FileNotFoundError(f'No image at path {path}.') + raise FileNotFoundError(f"No image at path {path}.") mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR image = cv2.imread(str(path), mode) if image is None: - raise IOError(f'Could not read image at {path}.') + raise IOError(f"Could not read image at {path}.") if not grayscale: image = image[..., ::-1] return image @@ -87,31 +93,35 @@ def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor: elif image.ndim == 2: image = image[None] # add channel axis else: - raise ValueError(f'Not an image: {image.shape}') - return torch.tensor(image / 255., dtype=torch.float) + raise ValueError(f"Not an image: {image.shape}") + return torch.tensor(image / 255.0, dtype=torch.float) -def resize_image(image: np.ndarray, size: Union[List[int], int], - fn: str = 'max', interp: Optional[str] = 'area', - ) -> np.ndarray: +def resize_image( + image: np.ndarray, + size: Union[List[int], int], + fn: str = "max", + interp: Optional[str] = "area", +) -> np.ndarray: """Resize an image to a fixed size, or according to max or min edge.""" h, w = image.shape[:2] - fn = {'max': max, 'min': min}[fn] + fn = {"max": max, "min": min}[fn] if isinstance(size, int): scale = size / fn(h, w) - h_new, w_new = int(round(h*scale)), int(round(w*scale)) + h_new, w_new = int(round(h * scale)), int(round(w * scale)) scale = (w_new / w, h_new / h) elif isinstance(size, (tuple, list)): h_new, w_new = size scale = (w_new / w, h_new / h) else: - raise ValueError(f'Incorrect new size: {size}') + raise ValueError(f"Incorrect new size: {size}") mode = { - 'linear': cv2.INTER_LINEAR, - 'cubic': cv2.INTER_CUBIC, - 'nearest': cv2.INTER_NEAREST, - 'area': cv2.INTER_AREA}[interp] + "linear": cv2.INTER_LINEAR, + "cubic": cv2.INTER_CUBIC, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + }[interp] return cv2.resize(image, (w_new, h_new), interpolation=mode), scale @@ -122,13 +132,18 @@ def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor: return numpy_image_to_torch(image) -def match_pair(extractor, matcher, - image0: torch.Tensor, image1: torch.Tensor, - device: str = 'cpu', **preprocess): +def match_pair( + extractor, + matcher, + image0: torch.Tensor, + image1: torch.Tensor, + device: str = "cpu", + **preprocess, +): """Match a pair of images (image0, image1) with an extractor and matcher""" feats0 = extractor.extract(image0, **preprocess) feats1 = 
extractor.extract(image1, **preprocess) - matches01 = matcher({'image0': feats0, 'image1': feats1}) + matches01 = matcher({"image0": feats0, "image1": feats1}) data = [feats0, feats1, matches01] # remove batch dim and move to target device feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data] diff --git a/third_party/LightGlue/lightglue/viz2d.py b/third_party/LightGlue/lightglue/viz2d.py index 3b8e65b45c8424a0a1747b6f81f6b1d5bb928471..4999a76fd0001b0b7570ba38639fcf0a30b0c915 100644 --- a/third_party/LightGlue/lightglue/viz2d.py +++ b/third_party/LightGlue/lightglue/viz2d.py @@ -14,33 +14,32 @@ import torch def cm_RdGn(x): """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" - x = np.clip(x, 0, 1)[..., None]*2 - c = x*np.array([[0, 1., 0]]) + (2-x)*np.array([[1., 0, 0]]) + x = np.clip(x, 0, 1)[..., None] * 2 + c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]]) return np.clip(c, 0, 1) def cm_BlRdGn(x_): """Custom colormap: blue (-1) -> red (0.0) -> green (1).""" - x = np.clip(x_, 0, 1)[..., None]*2 - c = x*np.array([[0, 1., 0, 1.]]) + (2-x)*np.array([[1., 0, 0, 1.]]) + x = np.clip(x_, 0, 1)[..., None] * 2 + c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]]) - xn = -np.clip(x_, -1, 0)[..., None]*2 - cn = xn*np.array([[0, 0.1, 1, 1.]]) + (2-xn)*np.array([[1., 0, 0, 1.]]) + xn = -np.clip(x_, -1, 0)[..., None] * 2 + cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]]) out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1) return out def cm_prune(x_): - """ Custom colormap to visualize pruning """ + """Custom colormap to visualize pruning""" if isinstance(x_, torch.Tensor): x_ = x_.cpu().numpy() max_i = max(x_) - norm_x = np.where(x_ == max_i, -1, (x_-1) / 9) + norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9) return cm_BlRdGn(norm_x) -def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5, - adaptive=True): +def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True): """Plot a set of images horizontally. Args: imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W). @@ -49,9 +48,12 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5, adaptive: whether the figure size should fit the image aspect ratios. """ # conversion to (H, W, 3) for torch.Tensor - imgs = [img.permute(1, 2, 0).cpu().numpy() - if (isinstance(img, torch.Tensor) and img.dim() == 3) else img - for img in imgs] + imgs = [ + img.permute(1, 2, 0).cpu().numpy() + if (isinstance(img, torch.Tensor) and img.dim() == 3) + else img + for img in imgs + ] n = len(imgs) if not isinstance(cmaps, (list, tuple)): @@ -60,10 +62,11 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5, if adaptive: ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H else: - ratios = [4/3] * n - figsize = [sum(ratios)*4.5, 4.5] + ratios = [4 / 3] * n + figsize = [sum(ratios) * 4.5, 4.5] fig, ax = plt.subplots( - 1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios}) + 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} + ) if n == 1: ax = [ax] for i in range(n): @@ -78,7 +81,7 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5, fig.tight_layout(pad=pad) -def plot_keypoints(kpts, colors='lime', ps=4, axes=None, a=1.0): +def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0): """Plot keypoints for existing images. Args: kpts: list of ndarrays of size (N, 2). 
@@ -97,8 +100,7 @@ def plot_keypoints(kpts, colors='lime', ps=4, axes=None, a=1.0): ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha) -def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1., labels=None, - axes=None): +def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None): """Plot matches for a pair of existing images. Args: kpts0, kpts1: corresponding keypoints of size (N, 2). @@ -127,12 +129,20 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1., labels=None, if lw > 0: for i in range(len(kpts0)): line = matplotlib.patches.ConnectionPatch( - xyA=(kpts0[i, 0], kpts0[i, 1]), xyB=(kpts1[i, 0], kpts1[i, 1]), - coordsA=ax0.transData, coordsB=ax1.transData, - axesA=ax0, axesB=ax1, - zorder=1, color=color[i], linewidth=lw, clip_on=True, - alpha=a, label=None if labels is None else labels[i], - picker=5.0) + xyA=(kpts0[i, 0], kpts0[i, 1]), + xyB=(kpts1[i, 0], kpts1[i, 1]), + coordsA=ax0.transData, + coordsB=ax1.transData, + axesA=ax0, + axesB=ax1, + zorder=1, + color=color[i], + linewidth=lw, + clip_on=True, + alpha=a, + label=None if labels is None else labels[i], + picker=5.0, + ) line.set_annotation_clip(True) fig.add_artist(line) @@ -145,17 +155,30 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1., labels=None, ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) -def add_text(idx, text, pos=(0.01, 0.99), fs=15, color='w', - lcolor='k', lwidth=2, ha='left', va='top'): +def add_text( + idx, + text, + pos=(0.01, 0.99), + fs=15, + color="w", + lcolor="k", + lwidth=2, + ha="left", + va="top", +): ax = plt.gcf().axes[idx] - t = ax.text(*pos, text, fontsize=fs, ha=ha, va=va, - color=color, transform=ax.transAxes) + t = ax.text( + *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes + ) if lcolor is not None: - t.set_path_effects([ - path_effects.Stroke(linewidth=lwidth, foreground=lcolor), - path_effects.Normal()]) + t.set_path_effects( + [ + path_effects.Stroke(linewidth=lwidth, foreground=lcolor), + path_effects.Normal(), + ] + ) def save_plot(path, **kw): """Save the current figure without any white margin.""" - plt.savefig(path, bbox_inches='tight', pad_inches=0, **kw) + plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) diff --git a/third_party/LightGlue/setup.py b/third_party/LightGlue/setup.py index fc349143002bf0860762a0341ceed47667759a10..2b012e92a208d09e4983317c4eb3c1d8093177e8 100644 --- a/third_party/LightGlue/setup.py +++ b/third_party/LightGlue/setup.py @@ -1,24 +1,24 @@ from pathlib import Path from setuptools import setup -description = ['LightGlue'] +description = ["LightGlue"] -with open(str(Path(__file__).parent / 'README.md'), 'r', encoding='utf-8') as f: +with open(str(Path(__file__).parent / "README.md"), "r", encoding="utf-8") as f: readme = f.read() -with open(str(Path(__file__).parent / 'requirements.txt'), 'r') as f: - dependencies = f.read().split('\n') +with open(str(Path(__file__).parent / "requirements.txt"), "r") as f: + dependencies = f.read().split("\n") setup( - name='lightglue', - version='0.0', - packages=['lightglue'], - python_requires='>=3.6', + name="lightglue", + version="0.0", + packages=["lightglue"], + python_requires=">=3.6", install_requires=dependencies, - author='Philipp Lindenberger, Paul-Edouard Sarlin', + author="Philipp Lindenberger, Paul-Edouard Sarlin", description=description, long_description=readme, long_description_content_type="text/markdown", - url='https://github.com/cvg/LightGlue/', + 
url="https://github.com/cvg/LightGlue/", classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", diff --git a/third_party/Roma/demo/demo_fundamental.py b/third_party/Roma/demo/demo_fundamental.py index 31618d4b06cd56fdd4be9065fb00b826a19e10f9..a71fd5532412fb4c65eb109e8e9f83813c11fd85 100644 --- a/third_party/Roma/demo/demo_fundamental.py +++ b/third_party/Roma/demo/demo_fundamental.py @@ -3,11 +3,12 @@ import torch import cv2 from roma import roma_outdoor -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if __name__ == "__main__": from argparse import ArgumentParser + parser = ArgumentParser() parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) @@ -19,7 +20,6 @@ if __name__ == "__main__": # Create model roma_model = roma_outdoor(device=device) - W_A, H_A = Image.open(im1_path).size W_B, H_B = Image.open(im2_path).size @@ -27,7 +27,12 @@ if __name__ == "__main__": warp, certainty = roma_model.match(im1_path, im2_path, device=device) # Sample matches for estimation matches, certainty = roma_model.sample(warp, certainty) - kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) + kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) F, mask = cv2.findFundamentalMat( - kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000 - ) \ No newline at end of file + kpts1.cpu().numpy(), + kpts2.cpu().numpy(), + ransacReprojThreshold=0.2, + method=cv2.USAC_MAGSAC, + confidence=0.999999, + maxIters=10000, + ) diff --git a/third_party/Roma/demo/demo_match.py b/third_party/Roma/demo/demo_match.py index 46413bb2b336e2ef2c0bc48315821e4de0fcb982..69eb07ffb0b480db99252bbb03a9858964e8d5f0 100644 --- a/third_party/Roma/demo/demo_match.py +++ b/third_party/Roma/demo/demo_match.py @@ -6,15 +6,18 @@ from roma.utils.utils import tensor_to_pil from roma import roma_indoor -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if __name__ == "__main__": from argparse import ArgumentParser + parser = ArgumentParser() parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) - parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str) + parser.add_argument( + "--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str + ) args, _ = parser.parse_known_args() im1_path = args.im_A_path @@ -36,12 +39,12 @@ if __name__ == "__main__": x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1) im2_transfer_rgb = F.grid_sample( - x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False + x2[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False )[0] im1_transfer_rgb = F.grid_sample( - x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False + x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False )[0] - warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2) - white_im = torch.ones((H,2*W),device=device) + warp_im = torch.cat((im2_transfer_rgb, im1_transfer_rgb), dim=2) + white_im = torch.ones((H, 2 * W), device=device) vis_im = certainty * warp_im + (1 - certainty) * white_im 
- tensor_to_pil(vis_im, unnormalize=False).save(save_path) \ No newline at end of file + tensor_to_pil(vis_im, unnormalize=False).save(save_path) diff --git a/third_party/Roma/roma/__init__.py b/third_party/Roma/roma/__init__.py index a7c96481e0a808b68c7b3054a3e34fa0b5c45ab9..a3c12d5247b93a83882edfb45bd127db794e791f 100644 --- a/third_party/Roma/roma/__init__.py +++ b/third_party/Roma/roma/__init__.py @@ -2,7 +2,7 @@ import os from .models import roma_outdoor, roma_indoor DEBUG_MODE = False -RANK = int(os.environ.get('RANK', default = 0)) +RANK = int(os.environ.get("RANK", default=0)) GLOBAL_STEP = 0 STEP_SIZE = 1 -LOCAL_RANK = -1 \ No newline at end of file +LOCAL_RANK = -1 diff --git a/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py index 2154a471c73d9e883c3ba8ed1b90d708f4950a63..6417d4d54798360a027a0d11d50fc65cdfae015a 100644 --- a/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py +++ b/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py @@ -53,7 +53,7 @@ class HpatchesHomogBenchmark: ) return im_A_coords, im_A_to_im_B - def benchmark(self, model, model_name = None): + def benchmark(self, model, model_name=None): n_matches = [] homog_dists = [] for seq_idx, seq_name in tqdm( @@ -69,9 +69,7 @@ class HpatchesHomogBenchmark: H = np.loadtxt( os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx)) ) - dense_matches, dense_certainty = model.match( - im_A_path, im_B_path - ) + dense_matches, dense_certainty = model.match(im_A_path, im_B_path) good_matches, _ = model.sample(dense_matches, dense_certainty, 5000) pos_a, pos_b = self.convert_coordinates( good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2 @@ -80,9 +78,9 @@ class HpatchesHomogBenchmark: H_pred, inliers = cv2.findHomography( pos_a, pos_b, - method = cv2.RANSAC, - confidence = 0.99999, - ransacReprojThreshold = 3 * min(w2, h2) / 480, + method=cv2.RANSAC, + confidence=0.99999, + ransacReprojThreshold=3 * min(w2, h2) / 480, ) except: H_pred = None diff --git a/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py b/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py index 0600d354b1d0dfa7f8e2b0f8882a4cc08fafeed9..f51a77e15510572b8f594dbc7713a0f348a33fd8 100644 --- a/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py +++ b/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py @@ -6,8 +6,11 @@ from roma.utils import warp_kpts from torch.utils.data import ConcatDataset import roma + class MegadepthDenseBenchmark: - def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None: + def __init__( + self, data_root="data/megadepth", h=384, w=512, num_samples=2000 + ) -> None: mega = MegadepthBuilder(data_root=data_root) self.dataset = ConcatDataset( mega.build_scenes(split="test_loftr", ht=h, wt=w) @@ -49,13 +52,15 @@ class MegadepthDenseBenchmark: pck_3_tot = 0.0 pck_5_tot = 0.0 sampler = torch.utils.data.WeightedRandomSampler( - torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples + torch.ones(len(self.dataset)), + replacement=False, + num_samples=self.num_samples, ) B = batch_size dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler ) - for idx, data in tqdm.tqdm(enumerate(dataloader), disable = roma.RANK > 0): + for idx, data in tqdm.tqdm(enumerate(dataloader), disable=roma.RANK > 0): im_A, im_B, depth1, depth2, T_1to2, K1, K2 = ( data["im_A"], 
data["im_B"], @@ -72,25 +77,36 @@ class MegadepthDenseBenchmark: if roma.DEBUG_MODE: from roma.utils.utils import tensor_to_pil import torch.nn.functional as F + path = "vis" H, W = model.get_output_resolution() - white_im = torch.ones((B,1,H,W),device="cuda") + white_im = torch.ones((B, 1, H, W), device="cuda") im_B_transfer_rgb = F.grid_sample( - im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False + im_B.cuda(), + matches[:, :, :W, 2:], + mode="bilinear", + align_corners=False, ) warp_im = im_B_transfer_rgb - c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None] + c_b = certainty[ + :, None + ] # (certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None] vis_im = c_b * warp_im + (1 - c_b) * white_im for b in range(B): import os - os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True) + + os.makedirs( + f"{path}/{model.name}/{idx}_{b}_{H}_{W}", exist_ok=True + ) tensor_to_pil(vis_im[b], unnormalize=True).save( - f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg") + f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg" + ) tensor_to_pil(im_A[b].cuda(), unnormalize=True).save( - f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg") + f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg" + ) tensor_to_pil(im_B[b].cuda(), unnormalize=True).save( - f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg") - + f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg" + ) gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = ( gd_tot + gd.mean(), diff --git a/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py b/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py index 8007fe8ecad09c33401450ad6b7af1f3dad043d2..5d936a07d550763d0378a23ea83c79cec5d373fe 100644 --- a/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py +++ b/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py @@ -7,8 +7,9 @@ import torch.nn.functional as F import roma import kornia.geometry.epipolar as kepi + class MegaDepthPoseEstimationBenchmark: - def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + def __init__(self, data_root="data/megadepth", scene_names=None) -> None: if scene_names is None: self.scene_names = [ "0015_0.1_0.3.npz", @@ -25,14 +26,22 @@ class MegaDepthPoseEstimationBenchmark: ] self.data_root = data_root - def benchmark(self, model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True): - H,W = model.get_output_resolution() + def benchmark( + self, + model, + model_name=None, + resolution=None, + scale_intrinsics=True, + calibrated=True, + ): + H, W = model.get_output_resolution() with torch.no_grad(): data_root = self.data_root tot_e_t, tot_e_R, tot_e_pose = [], [], [] thresholds = [5, 10, 20] for scene_ind in range(len(self.scenes)): import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] scene = self.scenes[scene_ind] pairs = scene["pair_infos"] @@ -49,16 +58,16 @@ class MegaDepthPoseEstimationBenchmark: T2 = poses[idx2].copy() R2, t2 = T2[:3, :3], T2[:3, 3] R, t = compute_relative_pose(R1, t1, R2, t2) - T1_to_2 = np.concatenate((R,t[:,None]), axis=-1) + T1_to_2 = np.concatenate((R, t[:, None]), axis=-1) im_A_path = f"{data_root}/{im_paths[idx1]}" im_B_path = f"{data_root}/{im_paths[idx2]}" dense_matches, dense_certainty = model.match( im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy() ) - sparse_matches,_ = model.sample( + sparse_matches, _ = model.sample( dense_matches, dense_certainty, 5000 ) - + im_A = Image.open(im_A_path) w1, h1 
= im_A.size im_B = Image.open(im_B_path) @@ -74,24 +83,20 @@ class MegaDepthPoseEstimationBenchmark: K2[:2] = K2[:2] * scale2 kpts1 = sparse_matches[:, :2] - kpts1 = ( - np.stack( - ( - w1 * (kpts1[:, 0] + 1) / 2, - h1 * (kpts1[:, 1] + 1) / 2, - ), - axis=-1, - ) + kpts1 = np.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2, + h1 * (kpts1[:, 1] + 1) / 2, + ), + axis=-1, ) kpts2 = sparse_matches[:, 2:] - kpts2 = ( - np.stack( - ( - w2 * (kpts2[:, 0] + 1) / 2, - h2 * (kpts2[:, 1] + 1) / 2, - ), - axis=-1, - ) + kpts2 = np.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2, + h2 * (kpts2[:, 1] + 1) / 2, + ), + axis=-1, ) for _ in range(5): @@ -99,9 +104,12 @@ class MegaDepthPoseEstimationBenchmark: kpts1 = kpts1[shuffling] kpts2 = kpts2[shuffling] try: - threshold = 0.5 + threshold = 0.5 if calibrated: - norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + norm_threshold = threshold / ( + np.mean(np.abs(K1[:2, :2])) + + np.mean(np.abs(K2[:2, :2])) + ) R_est, t_est, mask = estimate_pose( kpts1, kpts2, diff --git a/third_party/Roma/roma/benchmarks/scannet_benchmark.py b/third_party/Roma/roma/benchmarks/scannet_benchmark.py index 853af0d0ebef4dfefe2632eb49e4156ea791ee76..3187c2acf79f5af8f64397f55f6df40af327945b 100644 --- a/third_party/Roma/roma/benchmarks/scannet_benchmark.py +++ b/third_party/Roma/roma/benchmarks/scannet_benchmark.py @@ -10,7 +10,7 @@ class ScanNetBenchmark: def __init__(self, data_root="data/scannet") -> None: self.data_root = data_root - def benchmark(self, model, model_name = None): + def benchmark(self, model, model_name=None): model.train(False) with torch.no_grad(): data_root = self.data_root @@ -24,20 +24,20 @@ class ScanNetBenchmark: scene = pairs[pairind] scene_name = f"scene0{scene[0]}_00" im_A_path = osp.join( - self.data_root, - "scans_test", - scene_name, - "color", - f"{scene[2]}.jpg", - ) + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[2]}.jpg", + ) im_A = Image.open(im_A_path) im_B_path = osp.join( - self.data_root, - "scans_test", - scene_name, - "color", - f"{scene[3]}.jpg", - ) + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[3]}.jpg", + ) im_B = Image.open(im_B_path) T_gt = rel_pose[pairind].reshape(3, 4) R, t = T_gt[:3, :3], T_gt[:3, 3] @@ -76,24 +76,20 @@ class ScanNetBenchmark: offset = 0.5 kpts1 = sparse_matches[:, :2] - kpts1 = ( - np.stack( - ( - w1 * (kpts1[:, 0] + 1) / 2 - offset, - h1 * (kpts1[:, 1] + 1) / 2 - offset, - ), - axis=-1, - ) + kpts1 = np.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2 - offset, + h1 * (kpts1[:, 1] + 1) / 2 - offset, + ), + axis=-1, ) kpts2 = sparse_matches[:, 2:] - kpts2 = ( - np.stack( - ( - w2 * (kpts2[:, 0] + 1) / 2 - offset, - h2 * (kpts2[:, 1] + 1) / 2 - offset, - ), - axis=-1, - ) + kpts2 = np.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2 - offset, + h2 * (kpts2[:, 1] + 1) / 2 - offset, + ), + axis=-1, ) for _ in range(5): shuffling = np.random.permutation(np.arange(len(kpts1))) @@ -101,7 +97,8 @@ class ScanNetBenchmark: kpts2 = kpts2[shuffling] try: norm_threshold = 0.5 / ( - np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])) + ) R_est, t_est, mask = estimate_pose( kpts1, kpts2, diff --git a/third_party/Roma/roma/checkpointing/checkpoint.py b/third_party/Roma/roma/checkpointing/checkpoint.py index 8995efeb54f4d558127ea63423fa958c64e9088f..6372d89fe86c00c7acedf015886717bfeca7bb1f 100644 --- a/third_party/Roma/roma/checkpointing/checkpoint.py +++ b/third_party/Roma/roma/checkpointing/checkpoint.py @@ 
-7,6 +7,7 @@ import gc import roma + class CheckPoint: def __init__(self, dir=None, name="tmp"): self.name = name @@ -19,7 +20,7 @@ class CheckPoint: optimizer, lr_scheduler, n, - ): + ): if roma.RANK == 0: assert model is not None if isinstance(model, (DataParallel, DistributedDataParallel)): @@ -32,14 +33,14 @@ class CheckPoint: } torch.save(states, self.dir + self.name + f"_latest.pth") logger.info(f"Saved states {list(states.keys())}, at step {n}") - + def load( self, model, optimizer, lr_scheduler, n, - ): + ): if os.path.exists(self.dir + self.name + f"_latest.pth") and roma.RANK == 0: states = torch.load(self.dir + self.name + f"_latest.pth") if "model" in states: @@ -57,4 +58,4 @@ class CheckPoint: del states gc.collect() torch.cuda.empty_cache() - return model, optimizer, lr_scheduler, n \ No newline at end of file + return model, optimizer, lr_scheduler, n diff --git a/third_party/Roma/roma/datasets/__init__.py b/third_party/Roma/roma/datasets/__init__.py index b60c709926a4a7bd019b73eac10879063a996c90..6a11f122e222f0a9eded4afd3dd0b900826063e8 100644 --- a/third_party/Roma/roma/datasets/__init__.py +++ b/third_party/Roma/roma/datasets/__init__.py @@ -1,2 +1,2 @@ from .megadepth import MegadepthBuilder -from .scannet import ScanNetBuilder \ No newline at end of file +from .scannet import ScanNetBuilder diff --git a/third_party/Roma/roma/datasets/megadepth.py b/third_party/Roma/roma/datasets/megadepth.py index 5deee5ac30c439a9f300c0ad2271f141931020c0..75cb72ded02c80d1ad6bce0d0269626ee49a9275 100644 --- a/third_party/Roma/roma/datasets/megadepth.py +++ b/third_party/Roma/roma/datasets/megadepth.py @@ -10,6 +10,7 @@ import roma from roma.utils import * import math + class MegadepthScene: def __init__( self, @@ -22,18 +23,20 @@ class MegadepthScene: shake_t=0, rot_prob=0.0, normalize=True, - max_num_pairs = 100_000, - scene_name = None, - use_horizontal_flip_aug = False, - use_single_horizontal_flip_aug = False, - colorjiggle_params = None, - random_eraser = None, - use_randaug = False, - randaug_params = None, - randomize_size = False, + max_num_pairs=100_000, + scene_name=None, + use_horizontal_flip_aug=False, + use_single_horizontal_flip_aug=False, + colorjiggle_params=None, + random_eraser=None, + use_randaug=False, + randaug_params=None, + randomize_size=False, ) -> None: self.data_root = data_root - self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}" + self.scene_name = ( + os.path.splitext(scene_name)[0] + f"_{min_overlap}_{max_overlap}" + ) self.image_paths = scene_info["image_paths"] self.depth_paths = scene_info["depth_paths"] self.intrinsics = scene_info["intrinsics"] @@ -51,18 +54,18 @@ class MegadepthScene: self.overlaps = self.overlaps[pairinds] if randomize_size: area = ht * wt - s = int(16 * (math.sqrt(area)//16)) - sizes = ((ht,wt), (s,s), (wt,ht)) + s = int(16 * (math.sqrt(area) // 16)) + sizes = ((ht, wt), (s, s), (wt, ht)) choice = roma.RANK % 3 - ht, wt = sizes[choice] + ht, wt = sizes[choice] # counts, bins = np.histogram(self.overlaps,20) # print(counts) self.im_transform_ops = get_tuple_transform_ops( - resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params, + resize=(ht, wt), + normalize=normalize, + colorjiggle_params=colorjiggle_params, ) - self.depth_transform_ops = get_depth_tuple_transform_ops( - resize=(ht, wt) - ) + self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt)) self.wt, self.ht = wt, ht self.shake_t = shake_t self.random_eraser = random_eraser @@ -75,17 +78,19 @@ class 
MegadepthScene: def load_im(self, im_path): im = Image.open(im_path) return im - - def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): + + def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): im_A = im_A.flip(-1) im_B = im_B.flip(-1) - depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) - flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device) - K_A = flip_mat@K_A - K_B = flip_mat@K_B - + depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) + flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to( + K_A.device + ) + K_A = flip_mat @ K_A + K_B = flip_mat @ K_B + return im_A, im_B, depth_A, depth_B, K_A, K_B - + def load_depth(self, depth_ref, crop=None): depth = np.array(h5py.File(depth_ref, "r")["depth"]) return torch.from_numpy(depth) @@ -140,29 +145,31 @@ class MegadepthScene: depth_A, depth_B = self.depth_transform_ops( (depth_A[None, None], depth_B[None, None]) ) - - [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B) + + [im_A, im_B, depth_A, depth_B], t = self.rand_shake( + im_A, im_B, depth_A, depth_B + ) K1[:2, 2] += t K2[:2, 2] += t - + im_A, im_B = im_A[None], im_B[None] if self.random_eraser is not None: im_A, depth_A = self.random_eraser(im_A, depth_A) im_B, depth_B = self.random_eraser(im_B, depth_B) - + if self.use_horizontal_flip_aug: if np.random.rand() > 0.5: - im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2) + im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip( + im_A, im_B, depth_A, depth_B, K1, K2 + ) if self.use_single_horizontal_flip_aug: if np.random.rand() > 0.5: im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2) - + if roma.DEBUG_MODE: - tensor_to_pil(im_A[0], unnormalize=True).save( - f"vis/im_A.jpg") - tensor_to_pil(im_B[0], unnormalize=True).save( - f"vis/im_B.jpg") - + tensor_to_pil(im_A[0], unnormalize=True).save(f"vis/im_A.jpg") + tensor_to_pil(im_B[0], unnormalize=True).save(f"vis/im_B.jpg") + data_dict = { "im_A": im_A[0], "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0], @@ -175,25 +182,53 @@ class MegadepthScene: "T_1to2": T_1to2, "im_A_path": im_A_ref, "im_B_path": im_B_ref, - } return data_dict class MegadepthBuilder: - def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None: + def __init__( + self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True + ) -> None: self.data_root = data_root self.scene_info_root = os.path.join(data_root, "prep_scene_info") self.all_scenes = os.listdir(self.scene_info_root) self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"] # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those - self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy']) - self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy']) + self.loftr_ignore_scenes = set( + [ + "0121.npy", + "0133.npy", + "0168.npy", + "0178.npy", + "0229.npy", + "0349.npy", + "0412.npy", + "0430.npy", + "0443.npy", + "1001.npy", + "5014.npy", + "5015.npy", + "5016.npy", + ] + ) + self.imc21_scenes = set( + [ + "0008.npy", + "0019.npy", + "0021.npy", + "0024.npy", + "0025.npy", + "0032.npy", + "0063.npy", + "1589.npy", + ] + ) self.test_scenes_loftr = ["0015.npy", 
"0022.npy"] self.loftr_ignore = loftr_ignore self.imc21_ignore = imc21_ignore - def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs): + def build_scenes(self, split="train", min_overlap=0.0, scene_names=None, **kwargs): if split == "train": scene_names = set(self.all_scenes) - set(self.test_scenes) elif split == "train_loftr": @@ -217,7 +252,11 @@ class MegadepthBuilder: ).item() scenes.append( MegadepthScene( - self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs + self.data_root, + scene_info, + min_overlap=min_overlap, + scene_name=scene_name, + **kwargs, ) ) return scenes diff --git a/third_party/Roma/roma/datasets/scannet.py b/third_party/Roma/roma/datasets/scannet.py index 704ea57259afdfbbca627ad143bee97a0a79d41c..91bea57c9d1ae2773c11a9c8d47f31026a2c227b 100644 --- a/third_party/Roma/roma/datasets/scannet.py +++ b/third_party/Roma/roma/datasets/scannet.py @@ -5,10 +5,7 @@ import cv2 import h5py import numpy as np import torch -from torch.utils.data import ( - Dataset, - DataLoader, - ConcatDataset) +from torch.utils.data import Dataset, DataLoader, ConcatDataset import torchvision.transforms.functional as tvf import kornia.augmentation as K @@ -19,22 +16,36 @@ from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops from roma.utils.transforms import GeometricSequential from tqdm import tqdm + class ScanNetScene: - def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False, -) -> None: - self.scene_root = osp.join(data_root,"scans","scans_train") - self.data_names = scene_info['name'] - self.overlaps = scene_info['score'] + def __init__( + self, + data_root, + scene_info, + ht=384, + wt=512, + min_overlap=0.0, + shake_t=0, + rot_prob=0.0, + use_horizontal_flip_aug=False, + ) -> None: + self.scene_root = osp.join(data_root, "scans", "scans_train") + self.data_names = scene_info["name"] + self.overlaps = scene_info["score"] # Only sample 10s - valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0 + valid = (self.data_names[:, -2:] % 10).sum(axis=-1) == 0 self.overlaps = self.overlaps[valid] self.data_names = self.data_names[valid] if len(self.data_names) > 10000: - pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False) + pairinds = np.random.choice( + np.arange(0, len(self.data_names)), 10000, replace=False + ) self.data_names = self.data_names[pairinds] self.overlaps = self.overlaps[pairinds] self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True) - self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False) + self.depth_transform_ops = get_depth_tuple_transform_ops( + resize=(ht, wt), normalize=False + ) self.wt, self.ht = wt, ht self.shake_t = shake_t self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob)) @@ -43,7 +54,7 @@ class ScanNetScene: def load_im(self, im_B, crop=None): im = Image.open(im_B) return im - + def load_depth(self, depth_ref, crop=None): depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED) depth = depth / 1000 @@ -52,64 +63,73 @@ class ScanNetScene: def __len__(self): return len(self.data_names) - + def scale_intrinsic(self, K, wi, hi): - sx, sy = self.wt / wi, self.ht / hi - sK = torch.tensor([[sx, 0, 0], - [0, sy, 0], - [0, 0, 1]]) - return sK@K + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]]) + return sK @ K - def horizontal_flip(self, im_A, 
im_B, depth_A, depth_B, K_A, K_B): + def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): im_A = im_A.flip(-1) im_B = im_B.flip(-1) - depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) - flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device) - K_A = flip_mat@K_A - K_B = flip_mat@K_B - + depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) + flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to( + K_A.device + ) + K_A = flip_mat @ K_A + K_B = flip_mat @ K_B + return im_A, im_B, depth_A, depth_B, K_A, K_B - def read_scannet_pose(self,path): - """ Read ScanNet's Camera2World pose and transform it to World2Camera. - + + def read_scannet_pose(self, path): + """Read ScanNet's Camera2World pose and transform it to World2Camera. + Returns: pose_w2c (np.ndarray): (4, 4) """ - cam2world = np.loadtxt(path, delimiter=' ') + cam2world = np.loadtxt(path, delimiter=" ") world2cam = np.linalg.inv(cam2world) return world2cam - - def read_scannet_intrinsic(self,path): - """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. - """ - intrinsic = np.loadtxt(path, delimiter=' ') - return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float) + def read_scannet_intrinsic(self, path): + """Read ScanNet's intrinsic matrix and return the 3x3 matrix.""" + intrinsic = np.loadtxt(path, delimiter=" ") + return torch.tensor(intrinsic[:-1, :-1], dtype=torch.float) def __getitem__(self, pair_idx): # read intrinsics of original size data_name = self.data_names[pair_idx] scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name - scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' - + scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}" + # read the intrinsic of depthmap - K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root, - scene_name, - 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter + K1 = K2 = self.read_scannet_intrinsic( + osp.join(self.scene_root, scene_name, "intrinsic", "intrinsic_color.txt") + ) # the depth K is not the same, but doesnt really matter # read and compute relative poses - T1 = self.read_scannet_pose(osp.join(self.scene_root, - scene_name, - 'pose', f'{stem_name_1}.txt')) - T2 = self.read_scannet_pose(osp.join(self.scene_root, - scene_name, - 'pose', f'{stem_name_2}.txt')) - T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4) + T1 = self.read_scannet_pose( + osp.join(self.scene_root, scene_name, "pose", f"{stem_name_1}.txt") + ) + T2 = self.read_scannet_pose( + osp.join(self.scene_root, scene_name, "pose", f"{stem_name_2}.txt") + ) + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[ + :4, :4 + ] # (4, 4) # Load positive pair data - im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg') - im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg') - depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png') - depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png') + im_A_ref = os.path.join( + self.scene_root, scene_name, "color", f"{stem_name_1}.jpg" + ) + im_B_ref = os.path.join( + self.scene_root, scene_name, "color", f"{stem_name_2}.jpg" + ) + depth_A_ref = os.path.join( + self.scene_root, scene_name, "depth", f"{stem_name_1}.png" + ) + depth_B_ref = os.path.join( + self.scene_root, scene_name, "depth", f"{stem_name_2}.png" + ) im_A = self.load_im(im_A_ref) im_B = self.load_im(im_B_ref) 
@@ -121,40 +141,51 @@ class ScanNetScene: K2 = self.scale_intrinsic(K2, im_B.width, im_B.height) # Process images im_A, im_B = self.im_transform_ops((im_A, im_B)) - depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None])) + depth_A, depth_B = self.depth_transform_ops( + (depth_A[None, None], depth_B[None, None]) + ) if self.use_horizontal_flip_aug: if np.random.rand() > 0.5: - im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2) - - data_dict = {'im_A': im_A, - 'im_B': im_B, - 'im_A_depth': depth_A[0,0], - 'im_B_depth': depth_B[0,0], - 'K1': K1, - 'K2': K2, - 'T_1to2':T_1to2, - } + im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip( + im_A, im_B, depth_A, depth_B, K1, K2 + ) + + data_dict = { + "im_A": im_A, + "im_B": im_B, + "im_A_depth": depth_A[0, 0], + "im_B_depth": depth_B[0, 0], + "K1": K1, + "K2": K2, + "T_1to2": T_1to2, + } return data_dict class ScanNetBuilder: - def __init__(self, data_root = 'data/scannet') -> None: + def __init__(self, data_root="data/scannet") -> None: self.data_root = data_root - self.scene_info_root = os.path.join(data_root,'scannet_indices') + self.scene_info_root = os.path.join(data_root, "scannet_indices") self.all_scenes = os.listdir(self.scene_info_root) - - def build_scenes(self, split = 'train', min_overlap=0., **kwargs): + + def build_scenes(self, split="train", min_overlap=0.0, **kwargs): # Note: split doesn't matter here as we always use same scannet_train scenes scene_names = self.all_scenes scenes = [] - for scene_name in tqdm(scene_names, disable = roma.RANK > 0): - scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True) - scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs)) + for scene_name in tqdm(scene_names, disable=roma.RANK > 0): + scene_info = np.load( + os.path.join(self.scene_info_root, scene_name), allow_pickle=True + ) + scenes.append( + ScanNetScene( + self.data_root, scene_info, min_overlap=min_overlap, **kwargs + ) + ) return scenes - - def weight_scenes(self, concat_dataset, alpha=.5): + + def weight_scenes(self, concat_dataset, alpha=0.5): ns = [] for d in concat_dataset.datasets: ns.append(len(d)) - ws = torch.cat([torch.ones(n)/n**alpha for n in ns]) + ws = torch.cat([torch.ones(n) / n**alpha for n in ns]) return ws diff --git a/third_party/Roma/roma/losses/__init__.py b/third_party/Roma/roma/losses/__init__.py index 2e08abacfc0f83d7de0f2ddc0583766a80bf53cf..12cb6d40b90ca3ccf712321f78c033401db865fb 100644 --- a/third_party/Roma/roma/losses/__init__.py +++ b/third_party/Roma/roma/losses/__init__.py @@ -1 +1 @@ -from .robust_loss import RobustLosses \ No newline at end of file +from .robust_loss import RobustLosses diff --git a/third_party/Roma/roma/losses/robust_loss.py b/third_party/Roma/roma/losses/robust_loss.py index b932b2706f619c083485e1be0d86eec44ead83ef..cd9fd5bbc9c2d01bb6dd40823e350b588bd598b3 100644 --- a/third_party/Roma/roma/losses/robust_loss.py +++ b/third_party/Roma/roma/losses/robust_loss.py @@ -7,6 +7,7 @@ import wandb import roma import math + class RobustLosses(nn.Module): def __init__( self, @@ -17,12 +18,12 @@ class RobustLosses(nn.Module): local_loss=True, local_dist=4.0, local_largest_scale=8, - smooth_mask = False, - depth_interpolation_mode = "bilinear", - mask_depth_loss = False, - relative_depth_error_threshold = 0.05, - alpha = 1., - c = 1e-3, + smooth_mask=False, + depth_interpolation_mode="bilinear", + mask_depth_loss=False, + 
relative_depth_error_threshold=0.05, + alpha=1.0, + c=1e-3, ): super().__init__() self.robust = robust # measured in pixels @@ -45,68 +46,103 @@ class RobustLosses(nn.Module): B, C, H, W = scale_gm_cls.shape device = x2.device cls_res = round(math.sqrt(C)) - G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)]) - G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) - GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices - cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99] + G = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device + ) + for _ in range(2) + ] + ) + G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2) + GT = ( + (G[None, :, None, None, :] - x2[:, None]) + .norm(dim=-1) + .min(dim=1) + .indices + ) + cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction="none")[prob > 0.99] if not torch.any(cls_loss): - cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere + cls_loss = certainty_loss * 0.0 # Prevent issues where prob is 0 everywhere - certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob) + certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:, 0], prob) losses = { f"gm_certainty_loss_{scale}": certainty_loss.mean(), f"gm_cls_loss_{scale}": cls_loss.mean(), } - wandb.log(losses, step = roma.GLOBAL_STEP) + wandb.log(losses, step=roma.GLOBAL_STEP) return losses - def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale): + def delta_cls_loss( + self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale + ): with torch.no_grad(): B, C, H, W = delta_cls.shape device = x2.device cls_res = round(math.sqrt(C)) - G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)]) - G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale - GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices - cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99] + G = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device + ) + for _ in range(2) + ] + ) + G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2) * offset_scale + GT = ( + (G[None, :, None, None, :] + flow_pre_delta[:, None] - x2[:, None]) + .norm(dim=-1) + .min(dim=1) + .indices + ) + cls_loss = F.cross_entropy(delta_cls, GT, reduction="none")[prob > 0.99] if not torch.any(cls_loss): - cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere - certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob) + cls_loss = certainty_loss * 0.0 # Prevent issues where prob is 0 everywhere + certainty_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob) losses = { f"delta_certainty_loss_{scale}": certainty_loss.mean(), f"delta_cls_loss_{scale}": cls_loss.mean(), } - wandb.log(losses, step = roma.GLOBAL_STEP) + wandb.log(losses, step=roma.GLOBAL_STEP) return losses - def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"): - epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1) + def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"): + epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1) if scale == 1: - pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean() - wandb.log({"train_pck_05": pck_05}, step = 
roma.GLOBAL_STEP) + pck_05 = (epe[prob > 0.99] < 0.5 * (2 / 512)).float().mean() + wandb.log({"train_pck_05": pck_05}, step=roma.GLOBAL_STEP) ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob) a = self.alpha cs = self.c * scale x = epe[prob > 0.99] - reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2) + reg_loss = cs**a * ((x / (cs)) ** 2 + 1**2) ** (a / 2) if not torch.any(reg_loss): - reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere + reg_loss = ce_loss * 0.0 # Prevent issues where prob is 0 everywhere losses = { f"{mode}_certainty_loss_{scale}": ce_loss.mean(), f"{mode}_regression_loss_{scale}": reg_loss.mean(), } - wandb.log(losses, step = roma.GLOBAL_STEP) + wandb.log(losses, step=roma.GLOBAL_STEP) return losses def forward(self, corresps, batch): scales = list(corresps.keys()) tot_loss = 0.0 # scale_weights due to differences in scale for regression gradients and classification gradients - scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1} + scale_weights = {1: 1, 2: 1, 4: 1, 8: 1, 16: 1} for scale in scales: scale_corresps = corresps[scale] - scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = ( + ( + scale_certainty, + flow_pre_delta, + delta_cls, + offset_scale, + scale_gm_cls, + scale_gm_certainty, + flow, + scale_gm_flow, + ) = ( scale_corresps["certainty"], scale_corresps["flow_pre_delta"], scale_corresps.get("delta_cls"), @@ -115,43 +151,72 @@ class RobustLosses(nn.Module): scale_corresps.get("gm_certainty"), scale_corresps["flow"], scale_corresps.get("gm_flow"), - ) flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d") b, h, w, d = flow_pre_delta.shape - gt_warp, gt_prob = get_gt_warp( - batch["im_A_depth"], - batch["im_B_depth"], - batch["T_1to2"], - batch["K1"], - batch["K2"], - H=h, - W=w, - ) + gt_warp, gt_prob = get_gt_warp( + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + H=h, + W=w, + ) x2 = gt_warp.float() prob = gt_prob - + if self.local_largest_scale >= scale: prob = prob * ( - F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0] - < (2 / 512) * (self.local_dist[scale] * scale)) - + F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[ + :, 0 + ] + < (2 / 512) * (self.local_dist[scale] * scale) + ) + if scale_gm_cls is not None: - gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale) - gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"] + gm_cls_losses = self.gm_cls_loss( + x2, prob, scale_gm_cls, scale_gm_certainty, scale + ) + gm_loss = ( + self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + + gm_cls_losses[f"gm_cls_loss_{scale}"] + ) tot_loss = tot_loss + scale_weights[scale] * gm_loss elif scale_gm_flow is not None: - gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm") - gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"] + gm_flow_losses = self.regression_loss( + x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode="gm" + ) + gm_loss = ( + self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + + gm_flow_losses[f"gm_regression_loss_{scale}"] + ) tot_loss = tot_loss + scale_weights[scale] * gm_loss - + if delta_cls is not None: - delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale) - 
delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"] + delta_cls_losses = self.delta_cls_loss( + x2, + prob, + flow_pre_delta, + delta_cls, + scale_certainty, + scale, + offset_scale, + ) + delta_cls_loss = ( + self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + + delta_cls_losses[f"delta_cls_loss_{scale}"] + ) tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss else: - delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale) - reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"] + delta_regression_losses = self.regression_loss( + x2, prob, flow, scale_certainty, scale + ) + reg_loss = ( + self.ce_weight + * delta_regression_losses[f"delta_certainty_loss_{scale}"] + + delta_regression_losses[f"delta_regression_loss_{scale}"] + ) tot_loss = tot_loss + scale_weights[scale] * reg_loss - prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach() + prev_epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1).detach() return tot_loss diff --git a/third_party/Roma/roma/models/__init__.py b/third_party/Roma/roma/models/__init__.py index 5f20461e2f3a1722e558cefab94c5164be8842c3..3918d67063b9ab7a8ced80c22a5e74f95ff7fd4a 100644 --- a/third_party/Roma/roma/models/__init__.py +++ b/third_party/Roma/roma/models/__init__.py @@ -1 +1 @@ -from .model_zoo import roma_outdoor, roma_indoor \ No newline at end of file +from .model_zoo import roma_outdoor, roma_indoor diff --git a/third_party/Roma/roma/models/encoders.py b/third_party/Roma/roma/models/encoders.py index 69b488743b91905aca6adc3e4d3439421d492051..923a56d7ca30d73884ac5f313d44614998540dc3 100644 --- a/third_party/Roma/roma/models/encoders.py +++ b/third_party/Roma/roma/models/encoders.py @@ -8,35 +8,52 @@ import gc class ResNet50(nn.Module): - def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None: + def __init__( + self, + pretrained=False, + high_res=False, + weights=None, + dilation=None, + freeze_bn=True, + anti_aliased=False, + early_exit=False, + amp=False, + ) -> None: super().__init__() if dilation is None: - dilation = [False,False,False] + dilation = [False, False, False] if anti_aliased: pass else: if weights is not None: - self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation) + self.net = tvm.resnet50( + weights=weights, replace_stride_with_dilation=dilation + ) else: - self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation) - + self.net = tvm.resnet50( + pretrained=pretrained, replace_stride_with_dilation=dilation + ) + self.high_res = high_res self.freeze_bn = freeze_bn self.early_exit = early_exit self.amp = amp - self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + self.amp_dtype = torch.bfloat16 + else: + self.amp_dtype = torch.float16 def forward(self, x, **kwargs): - with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype): net = self.net - feats = {1:x} + feats = {1: x} x = net.conv1(x) x = net.bn1(x) x = net.relu(x) - feats[2] = x + feats[2] = x x = net.maxpool(x) x = net.layer1(x) - feats[4] = x + feats[4] = x x = net.layer2(x) feats[8] = x if self.early_exit: 
@@ -55,35 +72,45 @@ class ResNet50(nn.Module): m.eval() pass + class VGG19(nn.Module): - def __init__(self, pretrained=False, amp = False) -> None: + def __init__(self, pretrained=False, amp=False) -> None: super().__init__() self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) self.amp = amp - self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + self.amp_dtype = torch.bfloat16 + else: + self.amp_dtype = torch.float16 def forward(self, x, **kwargs): - with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype): feats = {} scale = 1 for layer in self.layers: if isinstance(layer, nn.MaxPool2d): feats[scale] = x - scale = scale*2 + scale = scale * 2 x = layer(x) return feats + class CNNandDinov2(nn.Module): - def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None): + def __init__(self, cnn_kwargs=None, amp=False, use_vgg=False, dinov2_weights=None): super().__init__() if dinov2_weights is None: - dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu") + dinov2_weights = torch.hub.load_state_dict_from_url( + "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", + map_location="cpu", + ) from .transformer import vit_large - vit_kwargs = dict(img_size= 518, - patch_size= 14, - init_values = 1.0, - ffn_layer = "mlp", - block_chunks = 0, + + vit_kwargs = dict( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp", + block_chunks=0, ) dinov2_vitl14 = vit_large(**vit_kwargs).eval() @@ -94,25 +121,35 @@ class CNNandDinov2(nn.Module): else: self.cnn = VGG19(**cnn_kwargs) self.amp = amp - self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + self.amp_dtype = torch.bfloat16 + else: + self.amp_dtype = torch.float16 if self.amp: dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype) - self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP - - + self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP + def train(self, mode: bool = True): return self.cnn.train(mode) - - def forward(self, x, upsample = False): - B,C,H,W = x.shape + + def forward(self, x, upsample=False): + B, C, H, W = x.shape feature_pyramid = self.cnn(x) - + if not upsample: with torch.no_grad(): if self.dinov2_vitl14[0].device != x.device: - self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype) - dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype)) - features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14) + self.dinov2_vitl14[0] = ( + self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype) + ) + dinov2_features_16 = self.dinov2_vitl14[0].forward_features( + x.to(self.amp_dtype) + ) + features_16 = ( + dinov2_features_16["x_norm_patchtokens"] + .permute(0, 2, 1) + .reshape(B, 1024, H // 14, W // 14) + ) del dinov2_features_16 feature_pyramid[16] = features_16 - return feature_pyramid \ No newline at end of file + return feature_pyramid diff --git a/third_party/Roma/roma/models/matcher.py b/third_party/Roma/roma/models/matcher.py index c06e1ba3aebe8dec7ee9f1800a6f4ba55ac8f0d9..3e1cee16b586ef1ff5f18e74b203d20aa1f16b1c 100644 --- 
a/third_party/Roma/roma/models/matcher.py +++ b/third_party/Roma/roma/models/matcher.py @@ -14,6 +14,7 @@ from roma.utils.local_correlation import local_correlation from roma.utils.utils import cls_to_flow_refine from roma.utils.kde import kde + class ConvRefiner(nn.Module): def __init__( self, @@ -23,25 +24,29 @@ class ConvRefiner(nn.Module): dw=False, kernel_size=5, hidden_blocks=3, - displacement_emb = None, - displacement_emb_dim = None, - local_corr_radius = None, - corr_in_other = None, - no_im_B_fm = False, - amp = False, - concat_logits = False, - use_bias_block_1 = True, - use_cosine_corr = False, - disable_local_corr_grad = False, - is_classifier = False, - sample_mode = "bilinear", - norm_type = nn.BatchNorm2d, - bn_momentum = 0.1, + displacement_emb=None, + displacement_emb_dim=None, + local_corr_radius=None, + corr_in_other=None, + no_im_B_fm=False, + amp=False, + concat_logits=False, + use_bias_block_1=True, + use_cosine_corr=False, + disable_local_corr_grad=False, + is_classifier=False, + sample_mode="bilinear", + norm_type=nn.BatchNorm2d, + bn_momentum=0.1, ): super().__init__() self.bn_momentum = bn_momentum self.block1 = self.create_block( - in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1, + in_dim, + hidden_dim, + dw=dw, + kernel_size=kernel_size, + bias=use_bias_block_1, ) self.hidden_blocks = nn.Sequential( *[ @@ -59,7 +64,7 @@ class ConvRefiner(nn.Module): self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0) if displacement_emb: self.has_displacement_emb = True - self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0) + self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0) else: self.has_displacement_emb = False self.local_corr_radius = local_corr_radius @@ -71,16 +76,19 @@ class ConvRefiner(nn.Module): self.disable_local_corr_grad = disable_local_corr_grad self.is_classifier = is_classifier self.sample_mode = sample_mode - self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + self.amp_dtype = torch.bfloat16 + else: + self.amp_dtype = torch.float16 + def create_block( self, in_dim, out_dim, dw=False, kernel_size=5, - bias = True, - norm_type = nn.BatchNorm2d, + bias=True, + norm_type=nn.BatchNorm2d, ): num_groups = 1 if not dw else in_dim if dw: @@ -96,38 +104,56 @@ class ConvRefiner(nn.Module): groups=num_groups, bias=bias, ) - norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim) + norm = ( + norm_type(out_dim, momentum=self.bn_momentum) + if norm_type is nn.BatchNorm2d + else norm_type(num_channels=out_dim) + ) relu = nn.ReLU(inplace=True) conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0) return nn.Sequential(conv1, norm, relu, conv2) - - def forward(self, x, y, flow, scale_factor = 1, logits = None): - b,c,hs,ws = x.shape - with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + + def forward(self, x, y, flow, scale_factor=1, logits=None): + b, c, hs, ws = x.shape + with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype): with torch.no_grad(): - x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode) + x_hat = F.grid_sample( + y, + flow.permute(0, 2, 3, 1), + align_corners=False, + mode=self.sample_mode, + ) if self.has_displacement_emb: im_A_coords = torch.meshgrid( - ( - torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"), - torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"), - 
) + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"), + ) ) im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0])) im_A_coords = im_A_coords[None].expand(b, 2, hs, ws) - in_displacement = flow-im_A_coords - emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement) + in_displacement = flow - im_A_coords + emb_in_displacement = self.disp_emb( + 40 / 32 * scale_factor * in_displacement + ) if self.local_corr_radius: if self.corr_in_other: # Corr in other means take a kxk grid around the predicted coordinate in other image - local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow, - sample_mode = self.sample_mode) + local_corr = local_correlation( + x, + y, + local_radius=self.local_corr_radius, + flow=flow, + sample_mode=self.sample_mode, + ) else: - raise NotImplementedError("Local corr in own frame should not be used.") + raise NotImplementedError( + "Local corr in own frame should not be used." + ) if self.no_im_B_fm: x_hat = torch.zeros_like(x) d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1) - else: + else: d = torch.cat((x, x_hat, emb_in_displacement), dim=1) else: if self.no_im_B_fm: @@ -141,6 +167,7 @@ class ConvRefiner(nn.Module): displacement, certainty = d[:, :-1], d[:, -1:] return displacement, certainty + class CosKernel(nn.Module): # similar to softmax kernel def __init__(self, T, learn_temperature=False): super().__init__() @@ -161,6 +188,7 @@ class CosKernel(nn.Module): # similar to softmax kernel K = ((c - 1.0) / T).exp() return K + class GP(nn.Module): def __init__( self, @@ -174,7 +202,7 @@ class GP(nn.Module): only_nearest_neighbour=False, sigma_noise=0.1, no_cov=False, - predict_features = False, + predict_features=False, ): super().__init__() self.K = kernel(T=T, learn_temperature=learn_temperature) @@ -262,7 +290,9 @@ class GP(nn.Module): mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1) if not self.no_cov: cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx)) - cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1) + cov_x = rearrange( + cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1 + ) local_cov_x = self.get_local_cov(cov_x) local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w") gp_feats = torch.cat((mu_x, local_cov_x), dim=1) @@ -270,11 +300,22 @@ class GP(nn.Module): gp_feats = mu_x return gp_feats + class Decoder(nn.Module): def __init__( - self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None, - num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0, - flow_upsample_mode = "bilinear" + self, + embedding_decoder, + gps, + proj, + conv_refiner, + detach=False, + scales="all", + pos_embeddings=None, + num_refinement_steps_per_scale=1, + warp_noise_std=0.0, + displacement_dropout_p=0.0, + gm_warp_dropout_p=0.0, + flow_upsample_mode="bilinear", ): super().__init__() self.embedding_decoder = embedding_decoder @@ -296,8 +337,11 @@ class Decoder(nn.Module): self.displacement_dropout_p = displacement_dropout_p self.gm_warp_dropout_p = gm_warp_dropout_p self.flow_upsample_mode = flow_upsample_mode - self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + self.amp_dtype = torch.bfloat16 + else: + self.amp_dtype = torch.float16 + def get_placeholder_flow(self, b, h, w, device): 
coarse_coords = torch.meshgrid( ( @@ -310,8 +354,8 @@ class Decoder(nn.Module): ].expand(b, h, w, 2) coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") return coarse_coords - - def get_positional_embedding(self, b, h ,w, device): + + def get_positional_embedding(self, b, h, w, device): coarse_coords = torch.meshgrid( ( torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), @@ -326,16 +370,29 @@ class Decoder(nn.Module): coarse_embedded_coords = self.pos_embedding(coarse_coords) return coarse_embedded_coords - def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1): + def forward( + self, + f1, + f2, + gt_warp=None, + gt_prob=None, + upsample=False, + flow=None, + certainty=None, + scale_factor=1, + ): coarse_scales = self.embedding_decoder.scales() - all_scales = self.scales if not upsample else ["8", "4", "2", "1"] + all_scales = self.scales if not upsample else ["8", "4", "2", "1"] sizes = {scale: f1[scale].shape[-2:] for scale in f1} h, w = sizes[1] b = f1[1].shape[0] device = f1[1].device coarsest_scale = int(all_scales[0]) old_stuff = torch.zeros( - b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device + b, + self.embedding_decoder.hidden_dim, + *sizes[coarsest_scale], + device=f1[coarsest_scale].device, ) corresps = {} if not upsample: @@ -343,17 +400,17 @@ class Decoder(nn.Module): certainty = 0.0 else: flow = F.interpolate( - flow, - size=sizes[coarsest_scale], - align_corners=False, - mode="bilinear", - ) + flow, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) certainty = F.interpolate( - certainty, - size=sizes[coarsest_scale], - align_corners=False, - mode="bilinear", - ) + certainty, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) displacement = 0.0 for new_scale in all_scales: ins = int(new_scale) @@ -371,32 +428,59 @@ class Decoder(nn.Module): gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder( gp_posterior, f1_s, old_stuff, new_scale ) - + if self.embedding_decoder.is_classifier: flow = cls_to_flow_refine( gm_warp_or_cls, - ).permute(0,3,1,2) - corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None + ).permute(0, 3, 1, 2) + corresps[ins].update( + { + "gm_cls": gm_warp_or_cls, + "gm_certainty": certainty, + } + ) if self.training else None else: - corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None + corresps[ins].update( + { + "gm_flow": gm_warp_or_cls, + "gm_certainty": certainty, + } + ) if self.training else None flow = gm_warp_or_cls.detach() - + if new_scale in self.conv_refiner: - corresps[ins].update({"flow_pre_delta": flow}) if self.training else None + corresps[ins].update( + {"flow_pre_delta": flow} + ) if self.training else None delta_flow, delta_certainty = self.conv_refiner[new_scale]( - f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty, - ) - corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None - displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w), - delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,) + f1_s, + f2_s, + flow, + scale_factor=scale_factor, + logits=certainty, + ) + corresps[ins].update( + { + "delta_flow": delta_flow, + } + ) if self.training else None + displacement = ins * torch.stack( + ( + delta_flow[:, 0].float() / (self.refine_init * w), + delta_flow[:, 1].float() / 
(self.refine_init * h), + ), + dim=1, + ) flow = flow + displacement certainty = ( certainty + delta_certainty ) # predict both certainty and displacement - corresps[ins].update({ - "certainty": certainty, - "flow": flow, - }) + corresps[ins].update( + { + "certainty": certainty, + "flow": flow, + } + ) if new_scale != "1": flow = F.interpolate( flow, @@ -411,7 +495,7 @@ class Decoder(nn.Module): if self.detach: flow = flow.detach() certainty = certainty.detach() - #torch.cuda.empty_cache() + # torch.cuda.empty_cache() return corresps @@ -422,11 +506,11 @@ class RegressionMatcher(nn.Module): decoder, h=448, w=448, - sample_mode = "threshold", - upsample_preds = False, - symmetric = False, - name = None, - attenuate_cert = None, + sample_mode="threshold", + upsample_preds=False, + symmetric=False, + name=None, + attenuate_cert=None, ): super().__init__() self.attenuate_cert = attenuate_cert @@ -438,24 +522,26 @@ class RegressionMatcher(nn.Module): self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True) self.sample_mode = sample_mode self.upsample_preds = upsample_preds - self.upsample_res = (14*16*6, 14*16*6) + self.upsample_res = (14 * 16 * 6, 14 * 16 * 6) self.symmetric = symmetric self.sample_thresh = 0.05 - + def get_output_resolution(self): if not self.upsample_preds: return self.h_resized, self.w_resized else: return self.upsample_res - - def extract_backbone_features(self, batch, batched = True, upsample = False): + + def extract_backbone_features(self, batch, batched=True, upsample=False): x_q = batch["im_A"] x_s = batch["im_B"] if batched: - X = torch.cat((x_q, x_s), dim = 0) - feature_pyramid = self.encoder(X, upsample = upsample) + X = torch.cat((x_q, x_s), dim=0) + feature_pyramid = self.encoder(X, upsample=upsample) else: - feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample) + feature_pyramid = self.encoder(x_q, upsample=upsample), self.encoder( + x_s, upsample=upsample + ) return feature_pyramid def sample( @@ -473,22 +559,28 @@ class RegressionMatcher(nn.Module): certainty.reshape(-1), ) expansion_factor = 4 if "balanced" in self.sample_mode else 1 - good_samples = torch.multinomial(certainty, - num_samples = min(expansion_factor*num, len(certainty)), - replacement=False) + good_samples = torch.multinomial( + certainty, + num_samples=min(expansion_factor * num, len(certainty)), + replacement=False, + ) good_matches, good_certainty = matches[good_samples], certainty[good_samples] if "balanced" not in self.sample_mode: return good_matches, good_certainty density = kde(good_matches, std=0.1) - p = 1 / (density+1) - p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones - balanced_samples = torch.multinomial(p, - num_samples = min(num,len(good_certainty)), - replacement=False) + p = 1 / (density + 1) + p[ + density < 10 + ] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones + balanced_samples = torch.multinomial( + p, num_samples=min(num, len(good_certainty)), replacement=False + ) return good_matches[balanced_samples], good_certainty[balanced_samples] - def forward(self, batch, batched = True, upsample = False, scale_factor = 1): - feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample) + def forward(self, batch, batched=True, upsample=False, scale_factor=1): + feature_pyramid = self.extract_backbone_features( + batch, batched=batched, upsample=upsample + ) if batched: f_q_pyramid = { scale: 
f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items() @@ -498,32 +590,42 @@ class RegressionMatcher(nn.Module): } else: f_q_pyramid, f_s_pyramid = feature_pyramid - corresps = self.decoder(f_q_pyramid, - f_s_pyramid, - upsample = upsample, - **(batch["corresps"] if "corresps" in batch else {}), - scale_factor=scale_factor) - + corresps = self.decoder( + f_q_pyramid, + f_s_pyramid, + upsample=upsample, + **(batch["corresps"] if "corresps" in batch else {}), + scale_factor=scale_factor, + ) + return corresps - def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1): - feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample) + def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1): + feature_pyramid = self.extract_backbone_features( + batch, batched=batched, upsample=upsample + ) f_q_pyramid = feature_pyramid f_s_pyramid = { - scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0) + scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0) for scale, f_scale in feature_pyramid.items() } - corresps = self.decoder(f_q_pyramid, - f_s_pyramid, - upsample = upsample, - **(batch["corresps"] if "corresps" in batch else {}), - scale_factor=scale_factor) + corresps = self.decoder( + f_q_pyramid, + f_s_pyramid, + upsample=upsample, + **(batch["corresps"] if "corresps" in batch else {}), + scale_factor=scale_factor, + ) return corresps - + def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B): - kpts_A, kpts_B = matches[...,:2], matches[...,2:] - kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1) - kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1) + kpts_A, kpts_B = matches[..., :2], matches[..., 2:] + kpts_A = torch.stack( + (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), axis=-1 + ) + kpts_B = torch.stack( + (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), axis=-1 + ) return kpts_A, kpts_B def match( @@ -532,11 +634,12 @@ class RegressionMatcher(nn.Module): im_B_path, *args, batched=False, - device = None, + device=None, ): if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") from PIL import Image + if isinstance(im_A_path, (str, os.PathLike)): im_A, im_B = Image.open(im_A_path), Image.open(im_B_path) else: @@ -552,9 +655,9 @@ class RegressionMatcher(nn.Module): # Get images in good format ws = self.w_resized hs = self.h_resized - + test_transform = get_tuple_transform_ops( - resize=(hs, ws), normalize=True, clahe = False + resize=(hs, ws), normalize=True, clahe=False ) im_A, im_B = test_transform((im_A, im_B)) batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)} @@ -564,25 +667,32 @@ class RegressionMatcher(nn.Module): assert w == w2 and h == h2, "For batched images we assume same size" batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)} if h != self.h_resized or self.w_resized != w: - warn("Model resolution and batch resolution differ, may produce unexpected results") + warn( + "Model resolution and batch resolution differ, may produce unexpected results" + ) hs, ws = h, w finest_scale = 1 # Run matcher if symmetric: - corresps = self.forward_symmetric(batch) + corresps = self.forward_symmetric(batch) else: - corresps = self.forward(batch, batched = True) + corresps = self.forward(batch, batched=True) if self.upsample_preds: hs, 
ws = self.upsample_res - + if self.attenuate_cert: low_res_certainty = F.interpolate( - corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + corresps[16]["certainty"], + size=(hs, ws), + align_corners=False, + mode="bilinear", ) cert_clamp = 0 factor = 0.5 - low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp) + low_res_certainty = ( + factor * low_res_certainty * (low_res_certainty < cert_clamp) + ) if self.upsample_preds: finest_corresps = corresps[finest_scale] @@ -593,25 +703,33 @@ class RegressionMatcher(nn.Module): im_A, im_B = Image.open(im_A_path), Image.open(im_B_path) im_A, im_B = test_transform((im_A, im_B)) im_A, im_B = im_A[None].to(device), im_B[None].to(device) - scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized)) + scale_factor = math.sqrt( + self.upsample_res[0] + * self.upsample_res[1] + / (self.w_resized * self.h_resized) + ) batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps} if symmetric: - corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor) + corresps = self.forward_symmetric( + batch, upsample=True, batched=True, scale_factor=scale_factor + ) else: - corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor) - - im_A_to_im_B = corresps[finest_scale]["flow"] - certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0) + corresps = self.forward( + batch, batched=True, upsample=True, scale_factor=scale_factor + ) + + im_A_to_im_B = corresps[finest_scale]["flow"] + certainty = corresps[finest_scale]["certainty"] - ( + low_res_certainty if self.attenuate_cert else 0 + ) if finest_scale != 1: im_A_to_im_B = F.interpolate( - im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear" + im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear" ) certainty = F.interpolate( - certainty, size=(hs, ws), align_corners=False, mode="bilinear" - ) - im_A_to_im_B = im_A_to_im_B.permute( - 0, 2, 3, 1 + certainty, size=(hs, ws), align_corners=False, mode="bilinear" ) + im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1) # Create im_A meshgrid im_A_coords = torch.meshgrid( ( @@ -625,25 +743,21 @@ class RegressionMatcher(nn.Module): im_A_coords = im_A_coords.permute(0, 2, 3, 1) if (im_A_to_im_B.abs() > 1).any() and True: wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0 - certainty[wrong[:,None]] = 0 + certainty[wrong[:, None]] = 0 im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1) if symmetric: A_to_B, B_to_A = im_A_to_im_B.chunk(2) q_warp = torch.cat((im_A_coords, A_to_B), dim=-1) im_B_coords = im_A_coords s_warp = torch.cat((B_to_A, im_B_coords), dim=-1) - warp = torch.cat((q_warp, s_warp),dim=2) + warp = torch.cat((q_warp, s_warp), dim=2) certainty = torch.cat(certainty.chunk(2), dim=3) else: warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1) if batched: - return ( - warp, - certainty[:, 0] - ) + return (warp, certainty[:, 0]) else: return ( warp[0], certainty[0, 0], ) - diff --git a/third_party/Roma/roma/models/model_zoo/__init__.py b/third_party/Roma/roma/models/model_zoo/__init__.py index 91edd4e69f2b39f18d62545a95f2774324ff404b..2ef0b6cf03473500d4198521764cd6dc9ccba784 100644 --- a/third_party/Roma/roma/models/model_zoo/__init__.py +++ b/third_party/Roma/roma/models/model_zoo/__init__.py @@ -6,25 +6,41 @@ weight_urls = { "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth", "indoor": 
"https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth", }, - "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D + "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", # hopefully this doesnt change :D } + def roma_outdoor(device, weights=None, dinov2_weights=None): if weights is None: - weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"], - map_location=device) + weights = torch.hub.load_state_dict_from_url( + weight_urls["roma"]["outdoor"], map_location=device + ) if dinov2_weights is None: - dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"], - map_location=device) - return roma_model(resolution=(14*8*6,14*8*6), upsample_preds=True, - weights=weights,dinov2_weights = dinov2_weights,device=device) + dinov2_weights = torch.hub.load_state_dict_from_url( + weight_urls["dinov2"], map_location=device + ) + return roma_model( + resolution=(14 * 8 * 6, 14 * 8 * 6), + upsample_preds=True, + weights=weights, + dinov2_weights=dinov2_weights, + device=device, + ) + def roma_indoor(device, weights=None, dinov2_weights=None): if weights is None: - weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"], - map_location=device) + weights = torch.hub.load_state_dict_from_url( + weight_urls["roma"]["indoor"], map_location=device + ) if dinov2_weights is None: - dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"], - map_location=device) - return roma_model(resolution=(14*8*5,14*8*5), upsample_preds=False, - weights=weights,dinov2_weights = dinov2_weights,device=device) + dinov2_weights = torch.hub.load_state_dict_from_url( + weight_urls["dinov2"], map_location=device + ) + return roma_model( + resolution=(14 * 8 * 5, 14 * 8 * 5), + upsample_preds=False, + weights=weights, + dinov2_weights=dinov2_weights, + device=device, + ) diff --git a/third_party/Roma/roma/models/model_zoo/roma_models.py b/third_party/Roma/roma/models/model_zoo/roma_models.py index dfb0ff7264880d25f0feb0802e582bf29c84b051..f98ee44f5e2ebd7e43a8e4b17f99b6ed0e85c93a 100644 --- a/third_party/Roma/roma/models/model_zoo/roma_models.py +++ b/third_party/Roma/roma/models/model_zoo/roma_models.py @@ -4,87 +4,95 @@ from roma.models.matcher import * from roma.models.transformer import Block, TransformerDecoder, MemEffAttention from roma.models.encoders import * -def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, **kwargs): + +def roma_model( + resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, **kwargs +): # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters - torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul - torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + warnings.filterwarnings( + "ignore", category=UserWarning, message="TypedStorage is deprecated" + ) gp_dim = 512 feat_dim = 512 decoder_dim = gp_dim + feat_dim cls_to_coord_res = 64 coordinate_decoder = TransformerDecoder( - nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), - decoder_dim, + nn.Sequential( + *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in 
range(5)] + ), + decoder_dim, cls_to_coord_res**2 + 1, is_classifier=True, - amp = True, - pos_enc = False,) + amp=True, + pos_enc=False, + ) dw = True hidden_blocks = 8 kernel_size = 5 displacement_emb = "linear" disable_local_corr_grad = True - + conv_refiner = nn.ModuleDict( { "16": ConvRefiner( - 2 * 512+128+(2*7+1)**2, - 2 * 512+128+(2*7+1)**2, + 2 * 512 + 128 + (2 * 7 + 1) ** 2, + 2 * 512 + 128 + (2 * 7 + 1) ** 2, 2 + 1, kernel_size=kernel_size, dw=dw, hidden_blocks=hidden_blocks, displacement_emb=displacement_emb, displacement_emb_dim=128, - local_corr_radius = 7, - corr_in_other = True, - amp = True, - disable_local_corr_grad = disable_local_corr_grad, - bn_momentum = 0.01, + local_corr_radius=7, + corr_in_other=True, + amp=True, + disable_local_corr_grad=disable_local_corr_grad, + bn_momentum=0.01, ), "8": ConvRefiner( - 2 * 512+64+(2*3+1)**2, - 2 * 512+64+(2*3+1)**2, + 2 * 512 + 64 + (2 * 3 + 1) ** 2, + 2 * 512 + 64 + (2 * 3 + 1) ** 2, 2 + 1, kernel_size=kernel_size, dw=dw, hidden_blocks=hidden_blocks, displacement_emb=displacement_emb, displacement_emb_dim=64, - local_corr_radius = 3, - corr_in_other = True, - amp = True, - disable_local_corr_grad = disable_local_corr_grad, - bn_momentum = 0.01, + local_corr_radius=3, + corr_in_other=True, + amp=True, + disable_local_corr_grad=disable_local_corr_grad, + bn_momentum=0.01, ), "4": ConvRefiner( - 2 * 256+32+(2*2+1)**2, - 2 * 256+32+(2*2+1)**2, + 2 * 256 + 32 + (2 * 2 + 1) ** 2, + 2 * 256 + 32 + (2 * 2 + 1) ** 2, 2 + 1, kernel_size=kernel_size, dw=dw, hidden_blocks=hidden_blocks, displacement_emb=displacement_emb, displacement_emb_dim=32, - local_corr_radius = 2, - corr_in_other = True, - amp = True, - disable_local_corr_grad = disable_local_corr_grad, - bn_momentum = 0.01, + local_corr_radius=2, + corr_in_other=True, + amp=True, + disable_local_corr_grad=disable_local_corr_grad, + bn_momentum=0.01, ), "2": ConvRefiner( - 2 * 64+16, - 128+16, + 2 * 64 + 16, + 128 + 16, 2 + 1, kernel_size=kernel_size, dw=dw, hidden_blocks=hidden_blocks, displacement_emb=displacement_emb, displacement_emb_dim=16, - amp = True, - disable_local_corr_grad = disable_local_corr_grad, - bn_momentum = 0.01, + amp=True, + disable_local_corr_grad=disable_local_corr_grad, + bn_momentum=0.01, ), "1": ConvRefiner( 2 * 9 + 6, @@ -92,12 +100,12 @@ def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_w 2 + 1, kernel_size=kernel_size, dw=dw, - hidden_blocks = hidden_blocks, - displacement_emb = displacement_emb, - displacement_emb_dim = 6, - amp = True, - disable_local_corr_grad = disable_local_corr_grad, - bn_momentum = 0.01, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=6, + amp=True, + disable_local_corr_grad=disable_local_corr_grad, + bn_momentum=0.01, ), } ) @@ -122,36 +130,46 @@ def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_w proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256)) proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64)) proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9)) - proj = nn.ModuleDict({ - "16": proj16, - "8": proj8, - "4": proj4, - "2": proj2, - "1": proj1, - }) + proj = nn.ModuleDict( + { + "16": proj16, + "8": proj8, + "4": proj4, + "2": proj2, + "1": proj1, + } + ) displacement_dropout_p = 0.0 gm_warp_dropout_p = 0.0 - decoder = Decoder(coordinate_decoder, - gps, - proj, - conv_refiner, - detach=True, - scales=["16", "8", "4", "2", "1"], - displacement_dropout_p = displacement_dropout_p, - 
gm_warp_dropout_p = gm_warp_dropout_p) - + decoder = Decoder( + coordinate_decoder, + gps, + proj, + conv_refiner, + detach=True, + scales=["16", "8", "4", "2", "1"], + displacement_dropout_p=displacement_dropout_p, + gm_warp_dropout_p=gm_warp_dropout_p, + ) + encoder = CNNandDinov2( - cnn_kwargs = dict( - pretrained=False, - amp = True), - amp = True, - use_vgg = True, - dinov2_weights = dinov2_weights + cnn_kwargs=dict(pretrained=False, amp=True), + amp=True, + use_vgg=True, + dinov2_weights=dinov2_weights, ) - h,w = resolution + h, w = resolution symmetric = True attenuate_cert = True - matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds, - symmetric = symmetric, attenuate_cert=attenuate_cert, **kwargs).to(device) + matcher = RegressionMatcher( + encoder, + decoder, + h=h, + w=w, + upsample_preds=upsample_preds, + symmetric=symmetric, + attenuate_cert=attenuate_cert, + **kwargs + ).to(device) matcher.load_state_dict(weights) return matcher diff --git a/third_party/Roma/roma/models/transformer/__init__.py b/third_party/Roma/roma/models/transformer/__init__.py index 4770ebb19f111df14f1539fa3696553d96d4e48b..a4b45d163d7e693b62edb5322a56387f82b27e04 100644 --- a/third_party/Roma/roma/models/transformer/__init__.py +++ b/third_party/Roma/roma/models/transformer/__init__.py @@ -7,9 +7,21 @@ from .layers.block import Block from .layers.attention import MemEffAttention from .dinov2 import vit_large + class TransformerDecoder(nn.Module): - def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args, - amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, **kwargs) -> None: + def __init__( + self, + blocks, + hidden_dim, + out_dim, + is_classifier=False, + *args, + amp=False, + pos_enc=True, + learned_embeddings=False, + embedding_dim=None, + **kwargs + ) -> None: super().__init__(*args, **kwargs) self.blocks = blocks self.to_out = nn.Linear(hidden_dim, out_dim) @@ -18,30 +30,44 @@ class TransformerDecoder(nn.Module): self._scales = [16] self.is_classifier = is_classifier self.amp = amp - self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + self.amp_dtype = torch.bfloat16 + else: + self.amp_dtype = torch.float16 self.pos_enc = pos_enc self.learned_embeddings = learned_embeddings if self.learned_embeddings: - self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim)))) + self.learned_pos_embeddings = nn.Parameter( + nn.init.kaiming_normal_( + torch.empty((1, hidden_dim, embedding_dim, embedding_dim)) + ) + ) def scales(self): return self._scales.copy() def forward(self, gp_posterior, features, old_stuff, new_scale): with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp): - B,C,H,W = gp_posterior.shape - x = torch.cat((gp_posterior, features), dim = 1) - B,C,H,W = x.shape - grid = get_grid(B, H, W, x.device).reshape(B,H*W,2) + B, C, H, W = gp_posterior.shape + x = torch.cat((gp_posterior, features), dim=1) + B, C, H, W = x.shape + grid = get_grid(B, H, W, x.device).reshape(B, H * W, 2) if self.learned_embeddings: - pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C) + pos_enc = ( + F.interpolate( + self.learned_pos_embeddings, + size=(H, W), + mode="bilinear", + align_corners=False, + ) + .permute(0, 2, 3, 1) + .reshape(1, H * W, C) + ) else: pos_enc = 0 - 
tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc + tokens = x.reshape(B, C, H * W).permute(0, 2, 1) + pos_enc z = self.blocks(tokens) out = self.to_out(z) - out = out.permute(0,2,1).reshape(B, self.out_dim, H, W) + out = out.permute(0, 2, 1).reshape(B, self.out_dim, H, W) warp, certainty = out[:, :-1], out[:, -1:] return warp, certainty, None - - diff --git a/third_party/Roma/roma/models/transformer/dinov2.py b/third_party/Roma/roma/models/transformer/dinov2.py index b556c63096d17239c8603d5fe626c331963099fd..1c27c65b5061cc0113792e40b96eaf7f4266ce18 100644 --- a/third_party/Roma/roma/models/transformer/dinov2.py +++ b/third_party/Roma/roma/models/transformer/dinov2.py @@ -18,16 +18,29 @@ import torch.nn as nn import torch.utils.checkpoint from torch.nn.init import trunc_normal_ -from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block - - - -def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: +from .layers import ( + Mlp, + PatchEmbed, + SwiGLUFFNFused, + MemEffAttention, + NestedTensorBlock as Block, +) + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name - named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) if depth_first and include_root: fn(module=module, name=name) return module @@ -87,22 +100,33 @@ class DinoVisionTransformer(nn.Module): super().__init__() norm_layer = partial(nn.LayerNorm, eps=1e-6) - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models self.num_tokens = 1 self.n_blocks = depth self.num_heads = num_heads self.patch_size = patch_size - self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) if drop_path_uniform is True: dpr = [drop_path_rate] * depth else: - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule if ffn_layer == "mlp": ffn_layer = Mlp @@ -139,7 +163,9 @@ class DinoVisionTransformer(nn.Module): chunksize = depth // block_chunks for i in range(0, depth, chunksize): # this is to keep the block index consistent if we chunk the block list - chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + chunked_blocks.append( + [nn.Identity()] * i + blocks_list[i : i + chunksize] + ) self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) else: self.chunked_blocks = False @@ -153,7 +179,7 @@ class DinoVisionTransformer(nn.Module): self.init_weights() for param 
in self.parameters(): param.requires_grad = False - + @property def device(self): return self.cls_token.device @@ -180,20 +206,29 @@ class DinoVisionTransformer(nn.Module): w0, h0 = w0 + 0.1, h0 + 0.1 patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + patch_pos_embed.reshape( + 1, int(math.sqrt(N)), int(math.sqrt(N)), dim + ).permute(0, 3, 1, 2), scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), mode="bicubic", ) - assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + assert ( + int(w0) == patch_pos_embed.shape[-2] + and int(h0) == patch_pos_embed.shape[-1] + ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) def prepare_tokens_with_masks(self, x, masks=None): B, nc, w, h = x.shape x = self.patch_embed(x) if masks is not None: - x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + x = torch.where( + masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x + ) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.interpolate_pos_encoding(x, w, h) @@ -201,7 +236,10 @@ class DinoVisionTransformer(nn.Module): return x def forward_features_list(self, x_list, masks_list): - x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + x = [ + self.prepare_tokens_with_masks(x, masks) + for x, masks in zip(x_list, masks_list) + ] for blk in self.blocks: x = blk(x) @@ -240,26 +278,34 @@ class DinoVisionTransformer(nn.Module): x = self.prepare_tokens_with_masks(x) # If n is an int, take the n last blocks. If it's a list, take them output, total_block_len = [], len(self.blocks) - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) for i, blk in enumerate(self.blocks): x = blk(x) if i in blocks_to_take: output.append(x) - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + assert len(output) == len( + blocks_to_take + ), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def _get_intermediate_layers_chunked(self, x, n=1): x = self.prepare_tokens_with_masks(x) output, i, total_block_len = [], 0, len(self.blocks[-1]) # If n is an int, take the n last blocks. 
If it's a list, take them - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) for block_chunk in self.blocks: for blk in block_chunk[i:]: # Passing the nn.Identity() x = blk(x) if i in blocks_to_take: output.append(x) i += 1 - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + assert len(output) == len( + blocks_to_take + ), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def get_intermediate_layers( @@ -281,7 +327,9 @@ class DinoVisionTransformer(nn.Module): if reshape: B, _, w, h = x.shape outputs = [ - out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + out.reshape(B, w // self.patch_size, h // self.patch_size, -1) + .permute(0, 3, 1, 2) + .contiguous() for out in outputs ] if return_class_token: @@ -356,4 +404,4 @@ def vit_giant2(patch_size=16, **kwargs): block_fn=partial(Block, attn_class=MemEffAttention), **kwargs, ) - return model \ No newline at end of file + return model diff --git a/third_party/Roma/roma/models/transformer/layers/attention.py b/third_party/Roma/roma/models/transformer/layers/attention.py index 1f9b0c94b40967dfdff4f261c127cbd21328c905..12f388719bf5f171d59aee238d902bb7915f864b 100644 --- a/third_party/Roma/roma/models/transformer/layers/attention.py +++ b/third_party/Roma/roma/models/transformer/layers/attention.py @@ -48,7 +48,11 @@ class Attention(nn.Module): def forward(self, x: Tensor) -> Tensor: B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] attn = q @ k.transpose(-2, -1) diff --git a/third_party/Roma/roma/models/transformer/layers/block.py b/third_party/Roma/roma/models/transformer/layers/block.py index 25488f57cc0ad3c692f86b62555f6668e2a66db1..1b5f5158f073788d3d5fe3e09742d4485ef26441 100644 --- a/third_party/Roma/roma/models/transformer/layers/block.py +++ b/third_party/Roma/roma/models/transformer/layers/block.py @@ -62,7 +62,9 @@ class Block(nn.Module): attn_drop=attn_drop, proj_drop=drop, ) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) @@ -74,7 +76,9 @@ class Block(nn.Module): drop=drop, bias=ffn_bias, ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.sample_drop_ratio = drop_path @@ -127,7 +131,9 @@ def drop_add_residual_stochastic_depth( residual_scale_factor = b / sample_subset_size # 3) add the residual - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) return x_plus_residual.view_as(x) @@ -143,10 +149,16 @@ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None if scaling_vector is None: x_flat = x.flatten(1) residual 
= residual.flatten(1) - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) else: x_plus_residual = scaled_index_add( - x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, ) return x_plus_residual @@ -158,7 +170,11 @@ def get_attn_bias_and_cat(x_list, branges=None): """ this will perform the index select, cat the tensors, and provide the attn_bias from cache """ - batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) if all_shapes not in attn_bias_cache.keys(): seqlens = [] @@ -170,7 +186,9 @@ def get_attn_bias_and_cat(x_list, branges=None): attn_bias_cache[all_shapes] = attn_bias if branges is not None: - cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) else: tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) cat_tensors = torch.cat(tensors_bs1, dim=1) @@ -185,7 +203,9 @@ def drop_add_residual_stochastic_depth_list( scaling_vector=None, ) -> Tensor: # 1) generate random set of indices for dropping samples in the batch - branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] branges = [s[0] for s in branges_scales] residual_scale_factors = [s[1] for s in branges_scales] @@ -196,8 +216,14 @@ def drop_add_residual_stochastic_depth_list( residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore outputs = [] - for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): - outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) return outputs @@ -220,13 +246,17 @@ class NestedTensorBlock(Block): x_list, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + scaling_vector=self.ls1.gamma + if isinstance(self.ls1, LayerScale) + else None, ) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + scaling_vector=self.ls2.gamma + if isinstance(self.ls1, LayerScale) + else None, ) return x_list else: @@ -246,7 +276,9 @@ class NestedTensorBlock(Block): if isinstance(x_or_x_list, Tensor): return super().forward(x_or_x_list) elif isinstance(x_or_x_list, list): - assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + assert ( + XFORMERS_AVAILABLE + ), "Please install xFormers for nested tensors usage" return 
self.forward_nested(x_or_x_list) else: raise AssertionError diff --git a/third_party/Roma/roma/models/transformer/layers/dino_head.py b/third_party/Roma/roma/models/transformer/layers/dino_head.py index 7212db92a4fd8d4c7230e284e551a0234e9d8623..1147dd3a3c046aee8d427b42b1055f38a218275b 100644 --- a/third_party/Roma/roma/models/transformer/layers/dino_head.py +++ b/third_party/Roma/roma/models/transformer/layers/dino_head.py @@ -23,7 +23,14 @@ class DINOHead(nn.Module): ): super().__init__() nlayers = max(nlayers, 1) - self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.mlp = _build_mlp( + nlayers, + in_dim, + bottleneck_dim, + hidden_dim=hidden_dim, + use_bn=use_bn, + bias=mlp_bias, + ) self.apply(self._init_weights) self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) self.last_layer.weight_g.data.fill_(1) @@ -42,7 +49,9 @@ class DINOHead(nn.Module): return x -def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): +def _build_mlp( + nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True +): if nlayers == 1: return nn.Linear(in_dim, bottleneck_dim, bias=bias) else: diff --git a/third_party/Roma/roma/models/transformer/layers/drop_path.py b/third_party/Roma/roma/models/transformer/layers/drop_path.py index af05625984dd14682cc96a63bf0c97bab1f123b1..a23ba7325d0fd154d5885573770956042ce2311d 100644 --- a/third_party/Roma/roma/models/transformer/layers/drop_path.py +++ b/third_party/Roma/roma/models/transformer/layers/drop_path.py @@ -16,7 +16,9 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False): if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0: random_tensor.div_(keep_prob) diff --git a/third_party/Roma/roma/models/transformer/layers/patch_embed.py b/third_party/Roma/roma/models/transformer/layers/patch_embed.py index 574abe41175568d700a389b8b96d1ba554914779..837f952cf9a463444feeb146e0d5b539102ee26c 100644 --- a/third_party/Roma/roma/models/transformer/layers/patch_embed.py +++ b/third_party/Roma/roma/models/transformer/layers/patch_embed.py @@ -63,15 +63,21 @@ class PatchEmbed(nn.Module): self.flatten_embedding = flatten_embedding - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x: Tensor) -> Tensor: _, _, H, W = x.shape patch_H, patch_W = self.patch_size - assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" - assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + assert ( + H % patch_H == 0 + ), f"Input image height {H} is not a multiple of patch height {patch_H}" + assert ( + W % patch_W == 0 + ), f"Input image width {W} is not a multiple of patch width: {patch_W}" x = self.proj(x) # B C H W H, W = x.size(2), x.size(3) @@ -83,7 +89,13 @@ class PatchEmbed(nn.Module): def flops(self) -> float: Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + flops = ( + Ho + * Wo + * 
self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops diff --git a/third_party/Roma/roma/train/train.py b/third_party/Roma/roma/train/train.py index 5556f7ebf9b6378e1395c125dde093f5e55e7141..eb3deaf1792a315d1cce77a2ee0fd50ae9e98ac1 100644 --- a/third_party/Roma/roma/train/train.py +++ b/third_party/Roma/roma/train/train.py @@ -4,41 +4,62 @@ import roma import torch import wandb -def log_param_statistics(named_parameters, norm_type = 2): + +def log_param_statistics(named_parameters, norm_type=2): named_parameters = list(named_parameters) grads = [p.grad for n, p in named_parameters if p.grad is not None] - weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None] - names = [n for n,p in named_parameters if p.grad is not None] + weight_norms = [ + p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None + ] + names = [n for n, p in named_parameters if p.grad is not None] param_norm = torch.stack(weight_norms).norm(p=norm_type) device = grads[0].device - grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]) + grad_norms = torch.stack( + [torch.norm(g.detach(), norm_type).to(device) for g in grads] + ) nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms) nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf] total_grad_norm = torch.norm(grad_norms, norm_type) if torch.any(nans_or_infs): print(f"These params have nan or inf grads: {nan_inf_names}") - wandb.log({"grad_norm": total_grad_norm.item()}, step = roma.GLOBAL_STEP) - wandb.log({"param_norm": param_norm.item()}, step = roma.GLOBAL_STEP) + wandb.log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP) + wandb.log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP) + -def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs): +def train_step( + train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1.0, **kwargs +): optimizer.zero_grad() out = model(train_batch) l = objective(out, train_batch) grad_scaler.scale(l).backward() grad_scaler.unscale_(optimizer) log_param_statistics(model.named_parameters()) - torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be? + torch.nn.utils.clip_grad_norm_( + model.parameters(), grad_clip_norm + ) # what should max norm be? 
grad_scaler.step(optimizer) grad_scaler.update() - wandb.log({"grad_scale": grad_scaler._scale.item()}, step = roma.GLOBAL_STEP) - if grad_scaler._scale < 1.: - grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale) - roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE # increment global step + wandb.log({"grad_scale": grad_scaler._scale.item()}, step=roma.GLOBAL_STEP) + if grad_scaler._scale < 1.0: + grad_scaler._scale = torch.tensor(1.0).to(grad_scaler._scale) + roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE # increment global step return {"train_out": out, "train_loss": l.item()} def train_k_steps( - n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None, + n_0, + k, + dataloader, + model, + objective, + optimizer, + lr_scheduler, + grad_scaler, + progress_bar=True, + grad_clip_norm=1.0, + warmup=None, + ema_model=None, ): for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0): batch = next(dataloader) @@ -52,7 +73,7 @@ def train_k_steps( lr_scheduler=lr_scheduler, grad_scaler=grad_scaler, n=n, - grad_clip_norm = grad_clip_norm, + grad_clip_norm=grad_clip_norm, ) if ema_model is not None: ema_model.update() @@ -61,7 +82,10 @@ def train_k_steps( lr_scheduler.step() else: lr_scheduler.step() - [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())] + [ + wandb.log({f"lr_group_{grp}": lr}) + for grp, lr in enumerate(lr_scheduler.get_last_lr()) + ] def train_epoch( diff --git a/third_party/Roma/roma/utils/kde.py b/third_party/Roma/roma/utils/kde.py index 90a058fb68253cfe23c2a7f21b213bea8e06cfe3..eff7c72dad4a3f90f5ff79d2630427de89838fc5 100644 --- a/third_party/Roma/roma/utils/kde.py +++ b/third_party/Roma/roma/utils/kde.py @@ -1,8 +1,9 @@ import torch -def kde(x, std = 0.1): + +def kde(x, std=0.1): # use a gaussian kernel to estimate density - x = x.half() # Do it in half precision - scores = (-torch.cdist(x,x)**2/(2*std**2)).exp() + x = x.half() # Do it in half precision + scores = (-torch.cdist(x, x) ** 2 / (2 * std**2)).exp() density = scores.sum(dim=-1) - return density \ No newline at end of file + return density diff --git a/third_party/Roma/roma/utils/local_correlation.py b/third_party/Roma/roma/utils/local_correlation.py index 586eef5f154a95968b253ad9701933b55b3a4dd6..84a13c63b52db979000916bcb9511e1d3a5ca7fa 100644 --- a/third_party/Roma/roma/utils/local_correlation.py +++ b/third_party/Roma/roma/utils/local_correlation.py @@ -1,47 +1,66 @@ import torch import torch.nn.functional as F + def local_correlation( feature0, feature1, local_radius, padding_mode="zeros", - flow = None, - sample_mode = "bilinear", + flow=None, + sample_mode="bilinear", ): r = local_radius - K = (2*r+1)**2 + K = (2 * r + 1) ** 2 B, c, h, w = feature0.size() feature0 = feature0.half() feature1 = feature1.half() - corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype) + corr = torch.empty((B, K, h, w), device=feature0.device, dtype=feature0.dtype) if flow is None: # If flow is None, assume feature0 and feature1 are aligned coords = torch.meshgrid( - ( - torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"), - torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"), - )) - coords = torch.stack((coords[1], coords[0]), dim=-1)[ - None - ].expand(B, h, w, 2) + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"), + ) + ) + coords = torch.stack((coords[1], 
coords[0]), dim=-1)[None].expand(B, h, w, 2) else: - coords = flow.permute(0,2,3,1) # If using flow, sample around flow target. + coords = flow.permute(0, 2, 3, 1) # If using flow, sample around flow target. local_window = torch.meshgrid( - ( - torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device="cuda"), - torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device="cuda"), - )) - local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[ - None - ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2) + ( + torch.linspace( + -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device="cuda" + ), + torch.linspace( + -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device="cuda" + ), + ) + ) + local_window = ( + torch.stack((local_window[1], local_window[0]), dim=-1)[None] + .expand(1, 2 * r + 1, 2 * r + 1, 2) + .reshape(1, (2 * r + 1) ** 2, 2) + ) for _ in range(B): with torch.no_grad(): - local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2).float() + local_window_coords = ( + (coords[_, :, :, None] + local_window[:, None, None]) + .reshape(1, h, w * (2 * r + 1) ** 2, 2) + .float() + ) window_feature = F.grid_sample( - feature1[_:_+1].float(), local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, # + feature1[_ : _ + 1].float(), + local_window_coords, + padding_mode=padding_mode, + align_corners=False, + mode=sample_mode, # ) - window_feature = window_feature.reshape(c,h,w,(2*r+1)**2) - corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1) + window_feature = window_feature.reshape(c, h, w, (2 * r + 1) ** 2) + corr[_] = ( + (feature0[_, ..., None] / (c**0.5) * window_feature) + .sum(dim=0) + .permute(2, 0, 1) + ) torch.cuda.empty_cache() - return corr \ No newline at end of file + return corr diff --git a/third_party/Roma/roma/utils/transforms.py b/third_party/Roma/roma/utils/transforms.py index ea6476bd816a31df36f7d1b5417853637b65474b..b33c3f30f422bca6a81aa201952b7bb2d3d906bf 100644 --- a/third_party/Roma/roma/utils/transforms.py +++ b/third_party/Roma/roma/utils/transforms.py @@ -16,7 +16,9 @@ class GeometricSequential: for t in self.transforms: if np.random.rand() < t.p: M = M.matmul( - t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None) + t.compute_transformation( + x, t.generate_parameters((b, c, h, w)), None + ) ) return ( warp_perspective( @@ -104,15 +106,14 @@ class RandomPerspective(K.RandomPerspective): return dict(start_points=start_points, end_points=end_points) - class RandomErasing: - def __init__(self, p = 0., scale = 0.) 
-> None: + def __init__(self, p=0.0, scale=0.0) -> None: self.p = p self.scale = scale - self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p) + self.random_eraser = K.RandomErasing(scale=(0.02, scale), p=p) + def __call__(self, image, depth): if self.p > 0: image = self.random_eraser(image) depth = self.random_eraser(depth, params=self.random_eraser._params) return image, depth - \ No newline at end of file diff --git a/third_party/Roma/roma/utils/utils.py b/third_party/Roma/roma/utils/utils.py index d673f679823c833688e2548dd40bf50943796a71..969e1003419f3b7f05874830b79de73363017f01 100644 --- a/third_party/Roma/roma/utils/utils.py +++ b/third_party/Roma/roma/utils/utils.py @@ -9,13 +9,14 @@ import torch.nn.functional as F from PIL import Image import kornia + def recover_pose(E, kpts0, kpts1, K0, K1, mask): best_num_inliers = 0 - K0inv = np.linalg.inv(K0[:2,:2]) - K1inv = np.linalg.inv(K1[:2,:2]) + K0inv = np.linalg.inv(K0[:2, :2]) + K1inv = np.linalg.inv(K1[:2, :2]) - kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T - kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T + kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T + kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T for _E in np.split(E, len(E) / 3): n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) @@ -25,17 +26,16 @@ def recover_pose(E, kpts0, kpts1, K0, K1, mask): return ret - # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py # --- GEOMETRY --- def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): if len(kpts0) < 5: return None - K0inv = np.linalg.inv(K0[:2,:2]) - K1inv = np.linalg.inv(K1[:2,:2]) + K0inv = np.linalg.inv(K0[:2, :2]) + K1inv = np.linalg.inv(K1[:2, :2]) - kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T - kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T + kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T E, mask = cv2.findEssentialMat( kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf ) @@ -51,31 +51,40 @@ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): ret = (R, t, mask.ravel() > 0) return ret + def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): if len(kpts0) < 5: return None method = cv2.USAC_ACCURATE F, mask = cv2.findFundamentalMat( - kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000 + kpts0, + kpts1, + ransacReprojThreshold=norm_thresh, + confidence=conf, + method=method, + maxIters=10000, ) - E = K1.T@F@K0 + E = K1.T @ F @ K0 ret = None if E is not None: best_num_inliers = 0 - K0inv = np.linalg.inv(K0[:2,:2]) - K1inv = np.linalg.inv(K1[:2,:2]) + K0inv = np.linalg.inv(K0[:2, :2]) + K1inv = np.linalg.inv(K1[:2, :2]) + + kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T + kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T - kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T - kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T - for _E in np.split(E, len(E) / 3): - n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) + n, R, t, _ = cv2.recoverPose( + _E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask + ) if n > best_num_inliers: best_num_inliers = n ret = (R, t, mask.ravel() > 0) return ret -def unnormalize_coords(x_n,h,w): + +def unnormalize_coords(x_n, h, w): x = torch.stack( (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1 ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] @@ -155,6 +164,7 @@ def 
get_depth_tuple_transform_ops_nearest_exact(resize=None): ops.append(TupleResizeNearestExact(resize)) return TupleCompose(ops) + def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): ops = [] if resize: @@ -162,7 +172,9 @@ def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): return TupleCompose(ops) -def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None): +def get_tuple_transform_ops( + resize=None, normalize=True, unscale=False, clahe=False, colorjiggle_params=None +): ops = [] if resize: ops.append(TupleResize(resize)) @@ -173,6 +185,7 @@ def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = ) # Imagenet mean/std return TupleCompose(ops) + class ToTensorScaled(object): """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" @@ -221,11 +234,15 @@ class TupleToTensorUnscaled(object): def __repr__(self): return "TupleToTensorUnscaled()" + class TupleResizeNearestExact: def __init__(self, size): self.size = size + def __call__(self, im_tuple): - return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple] + return [ + F.interpolate(im, size=self.size, mode="nearest-exact") for im in im_tuple + ] def __repr__(self): return "TupleResizeNearestExact(size={})".format(self.size) @@ -235,17 +252,19 @@ class TupleResize(object): def __init__(self, size, mode=InterpolationMode.BICUBIC): self.size = size self.resize = transforms.Resize(size, mode) + def __call__(self, im_tuple): return [self.resize(im) for im in im_tuple] def __repr__(self): return "TupleResize(size={})".format(self.size) - + + class Normalize: - def __call__(self,im): - mean = im.mean(dim=(1,2), keepdims=True) - std = im.std(dim=(1,2), keepdims=True) - return (im-mean)/std + def __call__(self, im): + mean = im.mean(dim=(1, 2), keepdims=True) + std = im.std(dim=(1, 2), keepdims=True) + return (im - mean) / std class TupleNormalize(object): @@ -255,7 +274,7 @@ class TupleNormalize(object): self.normalize = transforms.Normalize(mean=mean, std=std) def __call__(self, im_tuple): - c,h,w = im_tuple[0].shape + c, h, w = im_tuple[0].shape if c > 3: warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb") return [self.normalize(im[:3]) for im in im_tuple] @@ -281,50 +300,82 @@ class TupleCompose(object): format_string += "\n)" return format_string + @torch.no_grad() -def cls_to_flow(cls, deterministic_sampling = True): - B,C,H,W = cls.shape +def cls_to_flow(cls, deterministic_sampling=True): + B, C, H, W = cls.shape device = cls.device res = round(math.sqrt(C)) - G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)]) - G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) + G = torch.meshgrid( + *[ + torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device) + for _ in range(2) + ] + ) + G = torch.stack([G[1], G[0]], dim=-1).reshape(C, 2) if deterministic_sampling: sampled_cls = cls.max(dim=1).indices else: - sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W) + sampled_cls = torch.multinomial( + cls.permute(0, 2, 3, 1).reshape(B * H * W, C).softmax(dim=-1), 1 + ).reshape(B, H, W) flow = G[sampled_cls] return flow + @torch.no_grad() def cls_to_flow_refine(cls): - B,C,H,W = cls.shape + B, C, H, W = cls.shape device = cls.device res = round(math.sqrt(C)) - G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = 
device) for _ in range(2)]) - G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) + G = torch.meshgrid( + *[ + torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device) + for _ in range(2) + ] + ) + G = torch.stack([G[1], G[0]], dim=-1).reshape(C, 2) cls = cls.softmax(dim=1) mode = cls.max(dim=1).indices - - index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long() - neighbours = torch.gather(cls, dim = 1, index = index)[...,None] - flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]] - tot_prob = neighbours.sum(dim=1) + + index = ( + torch.stack((mode - 1, mode, mode + 1, mode - res, mode + res), dim=1) + .clamp(0, C - 1) + .long() + ) + neighbours = torch.gather(cls, dim=1, index=index)[..., None] + flow = ( + neighbours[:, 0] * G[index[:, 0]] + + neighbours[:, 1] * G[index[:, 1]] + + neighbours[:, 2] * G[index[:, 2]] + + neighbours[:, 3] * G[index[:, 3]] + + neighbours[:, 4] * G[index[:, 4]] + ) + tot_prob = neighbours.sum(dim=1) flow = flow / tot_prob return flow -def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None): - +def get_gt_warp( + depth1, + depth2, + T_1to2, + K1, + K2, + depth_interpolation_mode="bilinear", + relative_depth_error_threshold=0.05, + H=None, + W=None, +): + if H is None: - B,H,W = depth1.shape + B, H, W = depth1.shape else: B = depth1.shape[0] with torch.no_grad(): x1_n = torch.meshgrid( *[ - torch.linspace( - -1 + 1 / n, 1 - 1 / n, n, device=depth1.device - ) + torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=depth1.device) for n in (B, H, W) ] ) @@ -336,15 +387,27 @@ def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bili T_1to2.double(), K1.double(), K2.double(), - depth_interpolation_mode = depth_interpolation_mode, - relative_depth_error_threshold = relative_depth_error_threshold, + depth_interpolation_mode=depth_interpolation_mode, + relative_depth_error_threshold=relative_depth_error_threshold, ) prob = mask.float().reshape(B, H, W) x2 = x2.reshape(B, H, W, 2) return x2, prob + @torch.no_grad() -def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05): +def warp_kpts( + kpts0, + depth0, + depth1, + T_0to1, + K0, + K1, + smooth_mask=False, + return_relative_depth_error=False, + depth_interpolation_mode="bilinear", + relative_depth_error_threshold=0.05, +): """Warp kpts0 from I0 to I1 with depth, K and Rt Also check covisibility and depth consistency. Depth is consistent if relative error < 0.2 (hard-coded). 
@@ -369,26 +432,44 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation if smooth_mask: raise NotImplementedError("Combined bilinear and NN warp not implemented") - valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, - smooth_mask = smooth_mask, - return_relative_depth_error = return_relative_depth_error, - depth_interpolation_mode = "bilinear", - relative_depth_error_threshold = relative_depth_error_threshold) - valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, - smooth_mask = smooth_mask, - return_relative_depth_error = return_relative_depth_error, - depth_interpolation_mode = "nearest-exact", - relative_depth_error_threshold = relative_depth_error_threshold) - nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) + valid_bilinear, warp_bilinear = warp_kpts( + kpts0, + depth0, + depth1, + T_0to1, + K0, + K1, + smooth_mask=smooth_mask, + return_relative_depth_error=return_relative_depth_error, + depth_interpolation_mode="bilinear", + relative_depth_error_threshold=relative_depth_error_threshold, + ) + valid_nearest, warp_nearest = warp_kpts( + kpts0, + depth0, + depth1, + T_0to1, + K0, + K1, + smooth_mask=smooth_mask, + return_relative_depth_error=return_relative_depth_error, + depth_interpolation_mode="nearest-exact", + relative_depth_error_threshold=relative_depth_error_threshold, + ) + nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) warp = warp_bilinear.clone() - warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid] + warp[nearest_valid_bilinear_invalid] = warp_nearest[ + nearest_valid_bilinear_invalid + ] valid = valid_bilinear | valid_nearest return valid, warp - - - kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[ - :, 0, :, 0 - ] + + kpts0_depth = F.grid_sample( + depth0[:, None], + kpts0[:, :, None], + mode=depth_interpolation_mode, + align_corners=False, + )[:, 0, :, 0] kpts0 = torch.stack( (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] @@ -427,22 +508,26 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return # w_kpts0[~covisible_mask, :] = -5 # xd w_kpts0_depth = F.grid_sample( - depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False + depth1[:, None], + w_kpts0[:, :, None], + mode=depth_interpolation_mode, + align_corners=False, )[:, 0, :, 0] - + relative_depth_error = ( (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth ).abs() if not smooth_mask: consistent_mask = relative_depth_error < relative_depth_error_threshold else: - consistent_mask = (-relative_depth_error/smooth_mask).exp() + consistent_mask = (-relative_depth_error / smooth_mask).exp() valid_mask = nonzero_mask * covisible_mask * consistent_mask if return_relative_depth_error: return relative_depth_error, w_kpts0 else: return valid_mask, w_kpts0 + imagenet_mean = torch.tensor([0.485, 0.456, 0.406]) imagenet_std = torch.tensor([0.229, 0.224, 0.225]) @@ -462,7 +547,9 @@ def numpy_to_pil(x: np.ndarray): def tensor_to_pil(x, unnormalize=False): if unnormalize: - x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device)) + x = x * (imagenet_std[:, None, None].to(x.device)) + ( + imagenet_mean[:, None, 
None].to(x.device) + ) x = x.detach().permute(1, 2, 0).cpu().numpy() x = np.clip(x, 0.0, 1.0) return numpy_to_pil(x) @@ -492,70 +579,63 @@ def compute_relative_pose(R1, t1, R2, t2): trans = -rots @ t1 + t2 return rots, trans + @torch.no_grad() def reset_opt(opt): for group in opt.param_groups: - for p in group['params']: + for p in group["params"]: if p.requires_grad: state = opt.state[p] # State initialization # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) # Exponential moving average of gradient difference - state['exp_avg_diff'] = torch.zeros_like(p) + state["exp_avg_diff"] = torch.zeros_like(p) def flow_to_pixel_coords(flow, h1, w1): - flow = ( - torch.stack( - ( - w1 * (flow[..., 0] + 1) / 2, - h1 * (flow[..., 1] + 1) / 2, - ), - axis=-1, - ) + flow = torch.stack( + ( + w1 * (flow[..., 0] + 1) / 2, + h1 * (flow[..., 1] + 1) / 2, + ), + axis=-1, ) return flow + def flow_to_normalized_coords(flow, h1, w1): - flow = ( - torch.stack( - ( - 2 * (flow[..., 0]) / w1 - 1, - 2 * (flow[..., 1]) / h1 - 1, - ), - axis=-1, - ) + flow = torch.stack( + ( + 2 * (flow[..., 0]) / w1 - 1, + 2 * (flow[..., 1]) / h1 - 1, + ), + axis=-1, ) return flow def warp_to_pixel_coords(warp, h1, w1, h2, w2): warp1 = warp[..., :2] - warp1 = ( - torch.stack( - ( - w1 * (warp1[..., 0] + 1) / 2, - h1 * (warp1[..., 1] + 1) / 2, - ), - axis=-1, - ) + warp1 = torch.stack( + ( + w1 * (warp1[..., 0] + 1) / 2, + h1 * (warp1[..., 1] + 1) / 2, + ), + axis=-1, ) warp2 = warp[..., 2:] - warp2 = ( - torch.stack( - ( - w2 * (warp2[..., 0] + 1) / 2, - h2 * (warp2[..., 1] + 1) / 2, - ), - axis=-1, - ) + warp2 = torch.stack( + ( + w2 * (warp2[..., 0] + 1) / 2, + h2 * (warp2[..., 1] + 1) / 2, + ), + axis=-1, ) - return torch.cat((warp1,warp2), dim=-1) - + return torch.cat((warp1, warp2), dim=-1) def signed_point_line_distance(point, line, eps: float = 1e-9): @@ -576,7 +656,9 @@ def signed_point_line_distance(point, line, eps: float = 1e-9): if not line.shape[-1] == 3: raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}") - numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2]) + numerator = ( + line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2] + ) denominator = line[..., :2].norm(dim=-1) return numerator / (denominator + eps) @@ -600,6 +682,7 @@ def signed_left_to_right_epipolar_distance(pts1, pts2, Fm): the computed Symmetrical distance with shape :math:`(*, N)`. """ import kornia + if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3): raise ValueError(f"Fm must be a (*, 3, 3) tensor. 
Got {Fm.shape}") @@ -611,12 +694,10 @@ def signed_left_to_right_epipolar_distance(pts1, pts2, Fm): return signed_point_line_distance(pts2, line1_in_2) + def get_grid(b, h, w, device): grid = torch.meshgrid( - *[ - torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) - for n in (b, h, w) - ] + *[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (b, h, w)] ) grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2) return grid diff --git a/third_party/SGMNet/components/__init__.py b/third_party/SGMNet/components/__init__.py index c10d2027efcf985c68abf7185f28b947012cae45..a3a974825d770263feafa99fb09b7b656602584d 100644 --- a/third_party/SGMNet/components/__init__.py +++ b/third_party/SGMNet/components/__init__.py @@ -1,3 +1,3 @@ -from . import extractors +from . import extractors from . import matchers -from .load_component import load_component \ No newline at end of file +from .load_component import load_component diff --git a/third_party/SGMNet/components/evaluators.py b/third_party/SGMNet/components/evaluators.py index 59bf0bd7ce3dd085dc86072fc41bad24b9805991..a59af1a1614cfa217b6c50be9826e0ee1832191c 100644 --- a/third_party/SGMNet/components/evaluators.py +++ b/third_party/SGMNet/components/evaluators.py @@ -1,127 +1,181 @@ import numpy as np import sys import os + ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT_DIR) -from utils import evaluation_utils,metrics,fm_utils +from utils import evaluation_utils, metrics, fm_utils import cv2 + class auc_eval: - def __init__(self,config): - self.config=config - self.err_r,self.err_t,self.err=[],[],[] - self.ms=[] - self.precision=[] - - def run(self,info): - E,r_gt,t_gt=info['e'],info['r_gt'],info['t_gt'] - K1,K2,img1,img2=info['K1'],info['K2'],info['img1'],info['img2'] - corr1,corr2=info['corr1'],info['corr2'] - corr1,corr2=evaluation_utils.normalize_intrinsic(corr1,K1),evaluation_utils.normalize_intrinsic(corr2,K2) - size1,size2=max(img1.shape),max(img2.shape) - scale1,scale2=self.config['rescale']/size1,self.config['rescale']/size2 - #ransac - ransac_th=4./((K1[0,0]+K1[1,1])*scale1+(K2[0,0]+K2[1,1])*scale2) - R_hat,t_hat,E_hat=self.estimate(corr1,corr2,ransac_th) - #get pose error - err_r, err_t=metrics.evaluate_R_t(r_gt,t_gt,R_hat,t_hat) - err=max(err_r,err_t) - - if len(corr1)>1: - inlier_mask=metrics.compute_epi_inlier(corr1,corr2,E,self.config['inlier_th']) - precision=inlier_mask.mean() - ms=inlier_mask.sum()/len(info['x1']) + def __init__(self, config): + self.config = config + self.err_r, self.err_t, self.err = [], [], [] + self.ms = [] + self.precision = [] + + def run(self, info): + E, r_gt, t_gt = info["e"], info["r_gt"], info["t_gt"] + K1, K2, img1, img2 = info["K1"], info["K2"], info["img1"], info["img2"] + corr1, corr2 = info["corr1"], info["corr2"] + corr1, corr2 = evaluation_utils.normalize_intrinsic( + corr1, K1 + ), evaluation_utils.normalize_intrinsic(corr2, K2) + size1, size2 = max(img1.shape), max(img2.shape) + scale1, scale2 = self.config["rescale"] / size1, self.config["rescale"] / size2 + # ransac + ransac_th = 4.0 / ( + (K1[0, 0] + K1[1, 1]) * scale1 + (K2[0, 0] + K2[1, 1]) * scale2 + ) + R_hat, t_hat, E_hat = self.estimate(corr1, corr2, ransac_th) + # get pose error + err_r, err_t = metrics.evaluate_R_t(r_gt, t_gt, R_hat, t_hat) + err = max(err_r, err_t) + + if len(corr1) > 1: + inlier_mask = metrics.compute_epi_inlier( + corr1, corr2, E, self.config["inlier_th"] + ) + precision = inlier_mask.mean() + ms = inlier_mask.sum() / len(info["x1"]) 
else: - ms=precision=0 - - return {'err_r':err_r,'err_t':err_t,'err':err,'ms':ms,'precision':precision} - - def res_inqueue(self,res): - self.err_r.append(res['err_r']),self.err_t.append(res['err_t']),self.err.append(res['err']) - self.ms.append(res['ms']),self.precision.append(res['precision']) - - def estimate(self,corr1,corr2,th): + ms = precision = 0 + + return { + "err_r": err_r, + "err_t": err_t, + "err": err, + "ms": ms, + "precision": precision, + } + + def res_inqueue(self, res): + self.err_r.append(res["err_r"]), self.err_t.append( + res["err_t"] + ), self.err.append(res["err"]) + self.ms.append(res["ms"]), self.precision.append(res["precision"]) + + def estimate(self, corr1, corr2, th): num_inlier = -1 if corr1.shape[0] >= 5: - E, mask_new = cv2.findEssentialMat(corr1, corr2,method=cv2.RANSAC, threshold=th,prob=1-1e-5) + E, mask_new = cv2.findEssentialMat( + corr1, corr2, method=cv2.RANSAC, threshold=th, prob=1 - 1e-5 + ) if E is None: - E=[np.eye(3)] + E = [np.eye(3)] for _E in np.split(E, len(E) / 3): - _num_inlier, _R, _t, _ = cv2.recoverPose(_E, corr1, corr2,np.eye(3), 1e9,mask=mask_new) + _num_inlier, _R, _t, _ = cv2.recoverPose( + _E, corr1, corr2, np.eye(3), 1e9, mask=mask_new + ) if _num_inlier > num_inlier: num_inlier = _num_inlier R = _R t = _t E = _E else: - E,R,t=np.eye(3),np.eye(3),np.zeros(3) - return R,t,E + E, R, t = np.eye(3), np.eye(3), np.zeros(3) + return R, t, E def parse(self): ths = np.arange(7) * 5 - approx_auc=metrics.approx_pose_auc(self.err,ths) - exact_auc=metrics.pose_auc(self.err,ths) - mean_pre,mean_ms=np.mean(np.asarray(self.precision)),np.mean(np.asarray(self.ms)) - - print('auc th: ',ths[1:]) - print('approx auc: ',approx_auc) - print('exact auc: ', exact_auc) - print('mean match score: ',mean_ms*100) - print('mean precision: ',mean_pre*100) - - - -class FMbench_eval: + approx_auc = metrics.approx_pose_auc(self.err, ths) + exact_auc = metrics.pose_auc(self.err, ths) + mean_pre, mean_ms = np.mean(np.asarray(self.precision)), np.mean( + np.asarray(self.ms) + ) - def __init__(self,config): - self.config=config - self.pre,self.pre_post,self.sgd=[],[],[] - self.num_corr,self.num_corr_post=[],[] + print("auc th: ", ths[1:]) + print("approx auc: ", approx_auc) + print("exact auc: ", exact_auc) + print("mean match score: ", mean_ms * 100) + print("mean precision: ", mean_pre * 100) - def run(self,info): - corr1,corr2=info['corr1'],info['corr2'] - F=info['f'] - img1,img2=info['img1'],info['img2'] - if len(corr1)>1: - pre_bf=fm_utils.compute_inlier_rate(corr1,corr2,np.flip(img1.shape[:2]),np.flip(img2.shape[:2]),F,th=self.config['inlier_th']).mean() - F_hat,mask_F=cv2.findFundamentalMat(corr1,corr2,method=cv2.FM_RANSAC,ransacReprojThreshold=1,confidence=1-1e-5) +class FMbench_eval: + def __init__(self, config): + self.config = config + self.pre, self.pre_post, self.sgd = [], [], [] + self.num_corr, self.num_corr_post = [], [] + + def run(self, info): + corr1, corr2 = info["corr1"], info["corr2"] + F = info["f"] + img1, img2 = info["img1"], info["img2"] + + if len(corr1) > 1: + pre_bf = fm_utils.compute_inlier_rate( + corr1, + corr2, + np.flip(img1.shape[:2]), + np.flip(img2.shape[:2]), + F, + th=self.config["inlier_th"], + ).mean() + F_hat, mask_F = cv2.findFundamentalMat( + corr1, + corr2, + method=cv2.FM_RANSAC, + ransacReprojThreshold=1, + confidence=1 - 1e-5, + ) if F_hat is None: - F_hat=np.ones([3,3]) - mask_F=np.ones([len(corr1)]).astype(bool) + F_hat = np.ones([3, 3]) + mask_F = np.ones([len(corr1)]).astype(bool) else: - 
mask_F=mask_F.squeeze().astype(bool) - F_hat=F_hat[:3] - pre_af=fm_utils.compute_inlier_rate(corr1[mask_F],corr2[mask_F],np.flip(img1.shape[:2]),np.flip(img2.shape[:2]),F,th=self.config['inlier_th']).mean() - num_corr_af=mask_F.sum() - num_corr=len(corr1) - sgd=fm_utils.compute_SGD(F,F_hat,np.flip(img1.shape[:2]),np.flip(img2.shape[:2])) + mask_F = mask_F.squeeze().astype(bool) + F_hat = F_hat[:3] + pre_af = fm_utils.compute_inlier_rate( + corr1[mask_F], + corr2[mask_F], + np.flip(img1.shape[:2]), + np.flip(img2.shape[:2]), + F, + th=self.config["inlier_th"], + ).mean() + num_corr_af = mask_F.sum() + num_corr = len(corr1) + sgd = fm_utils.compute_SGD( + F, F_hat, np.flip(img1.shape[:2]), np.flip(img2.shape[:2]) + ) else: - pre_bf,pre_af,sgd=0,0,1e8 - num_corr,num_corr_af=0,0 - return {'pre':pre_bf,'pre_post':pre_af,'sgd':sgd,'num_corr':num_corr,'num_corr_post':num_corr_af} - - - def res_inqueue(self,res): - self.pre.append(res['pre']),self.pre_post.append(res['pre_post']),self.sgd.append(res['sgd']) - self.num_corr.append(res['num_corr']),self.num_corr_post.append(res['num_corr_post']) + pre_bf, pre_af, sgd = 0, 0, 1e8 + num_corr, num_corr_af = 0, 0 + return { + "pre": pre_bf, + "pre_post": pre_af, + "sgd": sgd, + "num_corr": num_corr, + "num_corr_post": num_corr_af, + } + + def res_inqueue(self, res): + self.pre.append(res["pre"]), self.pre_post.append( + res["pre_post"] + ), self.sgd.append(res["sgd"]) + self.num_corr.append(res["num_corr"]), self.num_corr_post.append( + res["num_corr_post"] + ) def parse(self): - for seq_index in range(len(self.config['seq'])): - seq=self.config['seq'][seq_index] - offset=seq_index*1000 - pre=np.asarray(self.pre)[offset:offset+1000].mean() - pre_post=np.asarray(self.pre_post)[offset:offset+1000].mean() - num_corr=np.asarray(self.num_corr)[offset:offset+1000].mean() - num_corr_post=np.asarray(self.num_corr_post)[offset:offset+1000].mean() - f_recall=(np.asarray(self.sgd)[offset:offset+1000]self.p_th,index[:,0],index2.squeeze(0) - mask_mc=index2[index] == torch.arange(len(p)).cuda() - mask=mask_th&mask_mc - index1,index2=torch.nonzero(mask).squeeze(1),index[mask] - return index1,index2 + res = self.model(feed_data, test_mode=True) + p = res["p"] + index1, index2 = self.match_p(p[0, :-1, :-1]) + corr1, corr2 = ( + test_data["x1"][:, :2][index1.cpu()], + test_data["x2"][:, :2][index2.cpu()], + ) + if len(corr1.shape) == 1: + corr1, corr2 = corr1[np.newaxis], corr2[np.newaxis] + return corr1, corr2 + def match_p(self, p): # p N*M + score, index = torch.topk(p, k=1, dim=-1) + _, index2 = torch.topk(p, k=1, dim=-2) + mask_th, index, index2 = score[:, 0] > self.p_th, index[:, 0], index2.squeeze(0) + mask_mc = index2[index] == torch.arange(len(p)).cuda() + mask = mask_th & mask_mc + index1, index2 = torch.nonzero(mask).squeeze(1), index[mask] + return index1, index2 -class NN_Matcher(object): - def __init__(self,config): - config=namedtuple('config',config.keys())(*config.values()) - self.mutual_check=config.mutual_check - self.ratio_th=config.ratio_th +class NN_Matcher(object): + def __init__(self, config): + config = namedtuple("config", config.keys())(*config.values()) + self.mutual_check = config.mutual_check + self.ratio_th = config.ratio_th - def run(self,test_data): - desc1,desc2,x1,x2=test_data['desc1'],test_data['desc2'],test_data['x1'],test_data['x2'] - desc_mat=np.sqrt(abs((desc1**2).sum(-1)[:,np.newaxis]+(desc2**2).sum(-1)[np.newaxis]-2*desc1@desc2.T)) - nn_index=np.argpartition(desc_mat,kth=(1,2),axis=-1) - 
dis_value12=np.take_along_axis(desc_mat,nn_index, axis=-1) - ratio_score=dis_value12[:,0]/dis_value12[:,1] - nn_index1=nn_index[:,0] - nn_index2=np.argmin(desc_mat,axis=0) - mask_ratio,mask_mutual=ratio_scoreself.config['angle_th'][0],angle_listself.config['overlap_th'][0],overlap_score self.config["angle_th"][0], + angle_list < self.config["angle_th"][1], + ), + np.logical_and( + overlap_score > self.config["overlap_th"][0], + overlap_score < self.config["overlap_th"][1], + ), + ) + pair_list = pair_list[mask_survive] + if len(pair_list) < 100: + print(seq, len(pair_list)) + # sample pairs + shuffled_pair_list = np.random.permutation(pair_list) + sample_target = min(self.config["pairs_per_seq"], len(shuffled_pair_list)) + sample_number = 0 + + info = { + "dR": [], + "dt": [], + "K1": [], + "K2": [], + "img_path1": [], + "img_path2": [], + "fea_path1": [], + "fea_path2": [], + "size1": [], + "size2": [], + "corr": [], + "incorr1": [], + "incorr2": [], + "pair_num": [], + } for cur_pair in shuffled_pair_list: - pair_index1,pair_index2=cur_pair[0],cur_pair[1] - geo1,geo2=geom_dict[pair_index1],geom_dict[pair_index2] - dR = np.dot(geo2['R'], geo1['R'].T) + pair_index1, pair_index2 = cur_pair[0], cur_pair[1] + geo1, geo2 = geom_dict[pair_index1], geom_dict[pair_index2] + dR = np.dot(geo2["R"], geo1["R"].T) t1, t2 = geo1["T"].reshape([3, 1]), geo2["T"].reshape([3, 1]) dt = t2 - np.dot(dR, t1) - K1,K2=geo1['K'],geo2['K'] - size1,size2=geo1['size'],geo2['size'] - - basename1,basename2=basename_list[pair_index1],basename_list[pair_index2] - img_path1,img_path2=os.path.join(seq,'undist_images',basename1+'.jpg'),os.path.join(seq,'undist_images',basename2+'.jpg') - fea_path1,fea_path2=os.path.join(seq,basename1+'.jpg'+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\ - os.path.join(seq,basename2+'.jpg'+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])+'.hdf5') - - with h5py.File(os.path.join(self.config['feature_dump_dir'],fea_path1),'r') as fea1, \ - h5py.File(os.path.join(self.config['feature_dump_dir'],fea_path2),'r') as fea2: - desc1,desc2=fea1['descriptors'][()],fea2['descriptors'][()] - kpt1,kpt2=fea1['keypoints'][()],fea2['keypoints'][()] - depth_path1,depth_path2=os.path.join(self.config['rawdata_dir'],'data',seq,'depths',basename1+'.pfm'),\ - os.path.join(self.config['rawdata_dir'],'data',seq,'depths',basename2+'.pfm') - depth1,depth2=self.load_depth(depth_path1),self.load_depth(depth_path2) - corr_index,incorr_index1,incorr_index2=data_utils.make_corr(kpt1[:,:2],kpt2[:,:2],desc1,desc2,depth1,depth2,K1,K2,dR,dt,size1,size2, - self.config['corr_th'],self.config['incorr_th'],self.config['check_desc']) - - if len(corr_index)>self.config['min_corr'] and len(incorr_index1)>self.config['min_incorr'] and len(incorr_index2)>self.config['min_incorr']: - info['corr'].append(corr_index),info['incorr1'].append(incorr_index1),info['incorr2'].append(incorr_index2) - info['dR'].append(dR),info['dt'].append(dt),info['K1'].append(K1),info['K2'].append(K2),info['img_path1'].append(img_path1),info['img_path2'].append(img_path2) - info['fea_path1'].append(fea_path1),info['fea_path2'].append(fea_path2),info['size1'].append(size1),info['size2'].append(size2) - sample_number+=1 - if sample_number==sample_target: + K1, K2 = geo1["K"], geo2["K"] + size1, size2 = geo1["size"], geo2["size"] + + basename1, basename2 = ( + basename_list[pair_index1], + basename_list[pair_index2], + ) + img_path1, img_path2 = os.path.join( + seq, "undist_images", 
basename1 + ".jpg" + ), os.path.join(seq, "undist_images", basename2 + ".jpg") + fea_path1, fea_path2 = os.path.join( + seq, + basename1 + + ".jpg" + + "_" + + self.config["extractor"]["name"] + + "_" + + str(self.config["extractor"]["num_kpt"]) + + ".hdf5", + ), os.path.join( + seq, + basename2 + + ".jpg" + + "_" + + self.config["extractor"]["name"] + + "_" + + str(self.config["extractor"]["num_kpt"]) + + ".hdf5", + ) + + with h5py.File( + os.path.join(self.config["feature_dump_dir"], fea_path1), "r" + ) as fea1, h5py.File( + os.path.join(self.config["feature_dump_dir"], fea_path2), "r" + ) as fea2: + desc1, desc2 = fea1["descriptors"][()], fea2["descriptors"][()] + kpt1, kpt2 = fea1["keypoints"][()], fea2["keypoints"][()] + depth_path1, depth_path2 = os.path.join( + self.config["rawdata_dir"], + "data", + seq, + "depths", + basename1 + ".pfm", + ), os.path.join( + self.config["rawdata_dir"], + "data", + seq, + "depths", + basename2 + ".pfm", + ) + depth1, depth2 = self.load_depth(depth_path1), self.load_depth( + depth_path2 + ) + corr_index, incorr_index1, incorr_index2 = data_utils.make_corr( + kpt1[:, :2], + kpt2[:, :2], + desc1, + desc2, + depth1, + depth2, + K1, + K2, + dR, + dt, + size1, + size2, + self.config["corr_th"], + self.config["incorr_th"], + self.config["check_desc"], + ) + + if ( + len(corr_index) > self.config["min_corr"] + and len(incorr_index1) > self.config["min_incorr"] + and len(incorr_index2) > self.config["min_incorr"] + ): + info["corr"].append(corr_index), info["incorr1"].append( + incorr_index1 + ), info["incorr2"].append(incorr_index2) + info["dR"].append(dR), info["dt"].append(dt), info["K1"].append( + K1 + ), info["K2"].append(K2), info["img_path1"].append(img_path1), info[ + "img_path2" + ].append( + img_path2 + ) + info["fea_path1"].append(fea_path1), info["fea_path2"].append( + fea_path2 + ), info["size1"].append(size1), info["size2"].append(size2) + sample_number += 1 + if sample_number == sample_target: break - info['pair_num']=sample_number - #dump info - self.dump_info(seq,info) + info["pair_num"] = sample_number + # dump info + self.dump_info(seq, info) - def collect_meta(self): - print('collecting meta info...') - dump_path,seq_list=[],[] - if self.config['dump_train']: - dump_path.append(os.path.join(self.config['dataset_dump_dir'],'train')) + print("collecting meta info...") + dump_path, seq_list = [], [] + if self.config["dump_train"]: + dump_path.append(os.path.join(self.config["dataset_dump_dir"], "train")) seq_list.append(self.train_list) - if self.config['dump_valid']: - dump_path.append(os.path.join(self.config['dataset_dump_dir'],'valid')) + if self.config["dump_valid"]: + dump_path.append(os.path.join(self.config["dataset_dump_dir"], "valid")) seq_list.append(self.valid_list) - for pth,seqs in zip(dump_path,seq_list): + for pth, seqs in zip(dump_path, seq_list): if not os.path.exists(pth): os.mkdir(pth) - pair_num_list,total_pair=[],0 - for seq_index in range(len(seqs)): - seq=seqs[seq_index] - pair_num=np.loadtxt(os.path.join(self.config['dataset_dump_dir'],seq,'pair_num.txt'),dtype=int) + pair_num_list, total_pair = [], 0 + for seq_index in range(len(seqs)): + seq = seqs[seq_index] + pair_num = np.loadtxt( + os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt"), + dtype=int, + ) pair_num_list.append(str(pair_num)) - total_pair+=pair_num - pair_num_list=np.stack([np.asarray(seqs,dtype=str),np.asarray(pair_num_list,dtype=str)],axis=1) - 
pair_num_list=np.concatenate([np.asarray([['total',str(total_pair)]]),pair_num_list],axis=0) - np.savetxt(os.path.join(pth,'pair_num.txt'),pair_num_list,fmt='%s') - + total_pair += pair_num + pair_num_list = np.stack( + [np.asarray(seqs, dtype=str), np.asarray(pair_num_list, dtype=str)], + axis=1, + ) + pair_num_list = np.concatenate( + [np.asarray([["total", str(total_pair)]]), pair_num_list], axis=0 + ) + np.savetxt(os.path.join(pth, "pair_num.txt"), pair_num_list, fmt="%s") + def format_dump_data(self): - print('Formatting data...') - iteration_num=len(self.seq_list)//self.config['num_process'] - if len(self.seq_list)%self.config['num_process']!=0: - iteration_num+=1 - pool=Pool(self.config['num_process']) + print("Formatting data...") + iteration_num = len(self.seq_list) // self.config["num_process"] + if len(self.seq_list) % self.config["num_process"] != 0: + iteration_num += 1 + pool = Pool(self.config["num_process"]) for index in trange(iteration_num): - indices=range(index*self.config['num_process'],min((index+1)*self.config['num_process'],len(self.seq_list))) - pool.map(self.format_seq,indices) + indices = range( + index * self.config["num_process"], + min((index + 1) * self.config["num_process"], len(self.seq_list)), + ) + pool.map(self.format_seq, indices) pool.close() pool.join() - self.collect_meta() \ No newline at end of file + self.collect_meta() diff --git a/third_party/SGMNet/datadump/dumper/scannet.py b/third_party/SGMNet/datadump/dumper/scannet.py index 2556f727fcc9b4c621e44d9ee5cb4e99cb19b7e8..ac45f41e3530fea49191188146187bcef7bd514d 100644 --- a/third_party/SGMNet/datadump/dumper/scannet.py +++ b/third_party/SGMNet/datadump/dumper/scannet.py @@ -7,66 +7,137 @@ import h5py from .base_dumper import BaseDumper import sys + ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) sys.path.insert(0, ROOT_DIR) import utils + class scannet(BaseDumper): def get_seqs(self): - self.pair_list=np.loadtxt('../assets/scannet_eval_list.txt',dtype=str) - self.seq_list=np.unique(np.asarray([path.split('/')[0] for path in self.pair_list[:,0]],dtype=str)) - self.dump_seq,self.img_seq=[],[] + self.pair_list = np.loadtxt("../assets/scannet_eval_list.txt", dtype=str) + self.seq_list = np.unique( + np.asarray([path.split("/")[0] for path in self.pair_list[:, 0]], dtype=str) + ) + self.dump_seq, self.img_seq = [], [] for seq in self.seq_list: - dump_dir=os.path.join(self.config['feature_dump_dir'],seq) - cur_img_seq=glob.glob(os.path.join(os.path.join(self.config['rawdata_dir'],seq,'img','*.jpg'))) - cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\ - +'.hdf5' for path in cur_img_seq] - self.img_seq+=cur_img_seq - self.dump_seq+=cur_dump_seq + dump_dir = os.path.join(self.config["feature_dump_dir"], seq) + cur_img_seq = glob.glob( + os.path.join( + os.path.join(self.config["rawdata_dir"], seq, "img", "*.jpg") + ) + ) + cur_dump_seq = [ + os.path.join(dump_dir, path.split("/")[-1]) + + "_" + + self.config["extractor"]["name"] + + "_" + + str(self.config["extractor"]["num_kpt"]) + + ".hdf5" + for path in cur_img_seq + ] + self.img_seq += cur_img_seq + self.dump_seq += cur_dump_seq def format_dump_folder(self): - if not os.path.exists(self.config['feature_dump_dir']): - os.mkdir(self.config['feature_dump_dir']) + if not os.path.exists(self.config["feature_dump_dir"]): + os.mkdir(self.config["feature_dump_dir"]) for seq in self.seq_list: - 
seq_dir=os.path.join(self.config['feature_dump_dir'],seq) + seq_dir = os.path.join(self.config["feature_dump_dir"], seq) if not os.path.exists(seq_dir): os.mkdir(seq_dir) def format_dump_data(self): - print('Formatting data...') - self.data={'K1':[],'K2':[],'R':[],'T':[],'e':[],'f':[],'fea_path1':[],'fea_path2':[],'img_path1':[],'img_path2':[]} + print("Formatting data...") + self.data = { + "K1": [], + "K2": [], + "R": [], + "T": [], + "e": [], + "f": [], + "fea_path1": [], + "fea_path2": [], + "img_path1": [], + "img_path2": [], + } for pair in self.pair_list: - img_path1,img_path2=pair[0],pair[1] - seq=img_path1.split('/')[0] - index1,index2=int(img_path1.split('/')[-1][:-4]),int(img_path2.split('/')[-1][:-4]) - ex1,ex2=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'extrinsic',str(index1)+'.txt'),dtype=float),\ - np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'extrinsic',str(index2)+'.txt'),dtype=float) - K1,K2=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'intrinsic',str(index1)+'.txt'),dtype=float),\ - np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'intrinsic',str(index2)+'.txt'),dtype=float) - + img_path1, img_path2 = pair[0], pair[1] + seq = img_path1.split("/")[0] + index1, index2 = int(img_path1.split("/")[-1][:-4]), int( + img_path2.split("/")[-1][:-4] + ) + ex1, ex2 = np.loadtxt( + os.path.join( + self.config["rawdata_dir"], seq, "extrinsic", str(index1) + ".txt" + ), + dtype=float, + ), np.loadtxt( + os.path.join( + self.config["rawdata_dir"], seq, "extrinsic", str(index2) + ".txt" + ), + dtype=float, + ) + K1, K2 = np.loadtxt( + os.path.join( + self.config["rawdata_dir"], seq, "intrinsic", str(index1) + ".txt" + ), + dtype=float, + ), np.loadtxt( + os.path.join( + self.config["rawdata_dir"], seq, "intrinsic", str(index2) + ".txt" + ), + dtype=float, + ) - relative_extrinsic=np.matmul(np.linalg.inv(ex2),ex1) - dR,dt=relative_extrinsic[:3,:3],relative_extrinsic[:3,3] + relative_extrinsic = np.matmul(np.linalg.inv(ex2), ex1) + dR, dt = relative_extrinsic[:3, :3], relative_extrinsic[:3, 3] dt /= np.sqrt(np.sum(dt**2)) - - e_gt_unnorm = np.reshape(np.matmul( - np.reshape(utils.evaluation_utils.np_skew_symmetric(dt.astype('float64').reshape(1, 3)), (3, 3)), - np.reshape(dR.astype('float64'), (3, 3))), (3, 3)) + + e_gt_unnorm = np.reshape( + np.matmul( + np.reshape( + utils.evaluation_utils.np_skew_symmetric( + dt.astype("float64").reshape(1, 3) + ), + (3, 3), + ), + np.reshape(dR.astype("float64"), (3, 3)), + ), + (3, 3), + ) e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm) - f_gt_unnorm=np.linalg.inv(K2.T)@e_gt@np.linalg.inv(K1) + f_gt_unnorm = np.linalg.inv(K2.T) @ e_gt @ np.linalg.inv(K1) f_gt = f_gt_unnorm / np.linalg.norm(f_gt_unnorm) - self.data['K1'].append(K1),self.data['K2'].append(K2) - self.data['R'].append(dR),self.data['T'].append(dt) - self.data['e'].append(e_gt),self.data['f'].append(f_gt) - - dump_seq_dir=os.path.join(self.config['feature_dump_dir'],seq) - fea_path1,fea_path2=os.path.join(dump_seq_dir,img_path1.split('/')[-1]+'_'+self.config['extractor']['name'] - +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\ - os.path.join(dump_seq_dir,img_path2.split('/')[-1]+'_'+self.config['extractor']['name'] - +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5') - self.data['img_path1'].append(img_path1),self.data['img_path2'].append(img_path2) - self.data['fea_path1'].append(fea_path1),self.data['fea_path2'].append(fea_path2) + self.data["K1"].append(K1), self.data["K2"].append(K2) + self.data["R"].append(dR), 
self.data["T"].append(dt) + self.data["e"].append(e_gt), self.data["f"].append(f_gt) + + dump_seq_dir = os.path.join(self.config["feature_dump_dir"], seq) + fea_path1, fea_path2 = os.path.join( + dump_seq_dir, + img_path1.split("/")[-1] + + "_" + + self.config["extractor"]["name"] + + "_" + + str(self.config["extractor"]["num_kpt"]) + + ".hdf5", + ), os.path.join( + dump_seq_dir, + img_path2.split("/")[-1] + + "_" + + self.config["extractor"]["name"] + + "_" + + str(self.config["extractor"]["num_kpt"]) + + ".hdf5", + ) + self.data["img_path1"].append(img_path1), self.data["img_path2"].append( + img_path2 + ) + self.data["fea_path1"].append(fea_path1), self.data["fea_path2"].append( + fea_path2 + ) self.form_standard_dataset() diff --git a/third_party/SGMNet/datadump/dumper/yfcc.py b/third_party/SGMNet/datadump/dumper/yfcc.py index 0c52e4324bba3e5ed424fe58af7a94fd3132b1e5..be1efe71775aef04a6e720751d637a093e28c06a 100644 --- a/third_party/SGMNet/datadump/dumper/yfcc.py +++ b/third_party/SGMNet/datadump/dumper/yfcc.py @@ -6,82 +6,145 @@ import h5py from .base_dumper import BaseDumper import sys + ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) sys.path.insert(0, ROOT_DIR) import utils + class yfcc(BaseDumper): - def get_seqs(self): - data_dir=os.path.join(self.config['rawdata_dir'],'yfcc100m') - for seq in self.config['data_seq']: - for split in self.config['data_split']: - split_dir=os.path.join(data_dir,seq,split) - dump_dir=os.path.join(self.config['feature_dump_dir'],seq,split) - cur_img_seq=glob.glob(os.path.join(split_dir,'images','*.jpg')) - cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\ - +'.hdf5' for path in cur_img_seq] - self.img_seq+=cur_img_seq - self.dump_seq+=cur_dump_seq + data_dir = os.path.join(self.config["rawdata_dir"], "yfcc100m") + for seq in self.config["data_seq"]: + for split in self.config["data_split"]: + split_dir = os.path.join(data_dir, seq, split) + dump_dir = os.path.join(self.config["feature_dump_dir"], seq, split) + cur_img_seq = glob.glob(os.path.join(split_dir, "images", "*.jpg")) + cur_dump_seq = [ + os.path.join(dump_dir, path.split("/")[-1]) + + "_" + + self.config["extractor"]["name"] + + "_" + + str(self.config["extractor"]["num_kpt"]) + + ".hdf5" + for path in cur_img_seq + ] + self.img_seq += cur_img_seq + self.dump_seq += cur_dump_seq def format_dump_folder(self): - if not os.path.exists(self.config['feature_dump_dir']): - os.mkdir(self.config['feature_dump_dir']) - for seq in self.config['data_seq']: - seq_dir=os.path.join(self.config['feature_dump_dir'],seq) + if not os.path.exists(self.config["feature_dump_dir"]): + os.mkdir(self.config["feature_dump_dir"]) + for seq in self.config["data_seq"]: + seq_dir = os.path.join(self.config["feature_dump_dir"], seq) if not os.path.exists(seq_dir): os.mkdir(seq_dir) - for split in self.config['data_split']: - split_dir=os.path.join(seq_dir,split) + for split in self.config["data_split"]: + split_dir = os.path.join(seq_dir, split) if not os.path.exists(split_dir): os.mkdir(split_dir) def format_dump_data(self): - print('Formatting data...') - pair_path=os.path.join(self.config['rawdata_dir'],'pairs') - self.data={'K1':[],'K2':[],'R':[],'T':[],'e':[],'f':[],'fea_path1':[],'fea_path2':[],'img_path1':[],'img_path2':[]} + print("Formatting data...") + pair_path = os.path.join(self.config["rawdata_dir"], "pairs") + self.data = { + "K1": [], + "K2": [], + "R": [], + "T": [], + "e": [], + "f": 
[], + "fea_path1": [], + "fea_path2": [], + "img_path1": [], + "img_path2": [], + } + + for seq in self.config["data_seq"]: + pair_name = os.path.join(pair_path, seq + "-te-1000-pairs.pkl") + with open(pair_name, "rb") as f: + pairs = pickle.load(f) - for seq in self.config['data_seq']: - pair_name=os.path.join(pair_path,seq+'-te-1000-pairs.pkl') - with open(pair_name, 'rb') as f: - pairs=pickle.load(f) - - #generate id list - seq_dir=os.path.join(self.config['rawdata_dir'],'yfcc100m',seq,'test') - name_list=np.loadtxt(os.path.join(seq_dir,'images.txt'),dtype=str) - cam_name_list=np.loadtxt(os.path.join(seq_dir,'calibration.txt'),dtype=str) + # generate id list + seq_dir = os.path.join(self.config["rawdata_dir"], "yfcc100m", seq, "test") + name_list = np.loadtxt(os.path.join(seq_dir, "images.txt"), dtype=str) + cam_name_list = np.loadtxt( + os.path.join(seq_dir, "calibration.txt"), dtype=str + ) for cur_pair in pairs: - index1,index2=cur_pair[0],cur_pair[1] - cam1,cam2=h5py.File(os.path.join(seq_dir,cam_name_list[index1]),'r'),h5py.File(os.path.join(seq_dir,cam_name_list[index2]),'r') - K1,K2=cam1['K'][()],cam2['K'][()] - [w1,h1],[w2,h2]=cam1['imsize'][()][0],cam2['imsize'][()][0] - cx1,cy1,cx2,cy2 = (w1 - 1.0) * 0.5,(h1 - 1.0) * 0.5, (w2 - 1.0) * 0.5,(h2 - 1.0) * 0.5 - K1[0,2],K1[1,2],K2[0,2],K2[1,2]=cx1,cy1,cx2,cy2 + index1, index2 = cur_pair[0], cur_pair[1] + cam1, cam2 = h5py.File( + os.path.join(seq_dir, cam_name_list[index1]), "r" + ), h5py.File(os.path.join(seq_dir, cam_name_list[index2]), "r") + K1, K2 = cam1["K"][()], cam2["K"][()] + [w1, h1], [w2, h2] = cam1["imsize"][()][0], cam2["imsize"][()][0] + cx1, cy1, cx2, cy2 = ( + (w1 - 1.0) * 0.5, + (h1 - 1.0) * 0.5, + (w2 - 1.0) * 0.5, + (h2 - 1.0) * 0.5, + ) + K1[0, 2], K1[1, 2], K2[0, 2], K2[1, 2] = cx1, cy1, cx2, cy2 - R1,R2,t1,t2=cam1['R'][()],cam2['R'][()],cam1['T'][()].reshape([3,1]),cam2['T'][()].reshape([3,1]) + R1, R2, t1, t2 = ( + cam1["R"][()], + cam2["R"][()], + cam1["T"][()].reshape([3, 1]), + cam2["T"][()].reshape([3, 1]), + ) dR = np.dot(R2, R1.T) dt = t2 - np.dot(dR, t1) dt /= np.sqrt(np.sum(dt**2)) - - e_gt_unnorm = np.reshape(np.matmul( - np.reshape(utils.evaluation_utils.np_skew_symmetric(dt.astype('float64').reshape(1, 3)), (3, 3)), - np.reshape(dR.astype('float64'), (3, 3))), (3, 3)) + + e_gt_unnorm = np.reshape( + np.matmul( + np.reshape( + utils.evaluation_utils.np_skew_symmetric( + dt.astype("float64").reshape(1, 3) + ), + (3, 3), + ), + np.reshape(dR.astype("float64"), (3, 3)), + ), + (3, 3), + ) e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm) - f_gt_unnorm=np.linalg.inv(K2.T)@e_gt@np.linalg.inv(K1) + f_gt_unnorm = np.linalg.inv(K2.T) @ e_gt @ np.linalg.inv(K1) f_gt = f_gt_unnorm / np.linalg.norm(f_gt_unnorm) - self.data['K1'].append(K1),self.data['K2'].append(K2) - self.data['R'].append(dR),self.data['T'].append(dt) - self.data['e'].append(e_gt),self.data['f'].append(f_gt) - - img_path1,img_path2=os.path.join('yfcc100m',seq,'test',name_list[index1]),os.path.join('yfcc100m',seq,'test',name_list[index2]) - dump_seq_dir=os.path.join(self.config['feature_dump_dir'],seq,'test') - fea_path1,fea_path2=os.path.join(dump_seq_dir,name_list[index1].split('/')[-1]+'_'+self.config['extractor']['name'] - +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\ - os.path.join(dump_seq_dir,name_list[index2].split('/')[-1]+'_'+self.config['extractor']['name'] - +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5') - self.data['img_path1'].append(img_path1),self.data['img_path2'].append(img_path2) - 
self.data['fea_path1'].append(fea_path1),self.data['fea_path2'].append(fea_path2) + self.data["K1"].append(K1), self.data["K2"].append(K2) + self.data["R"].append(dR), self.data["T"].append(dt) + self.data["e"].append(e_gt), self.data["f"].append(f_gt) + + img_path1, img_path2 = os.path.join( + "yfcc100m", seq, "test", name_list[index1] + ), os.path.join("yfcc100m", seq, "test", name_list[index2]) + dump_seq_dir = os.path.join( + self.config["feature_dump_dir"], seq, "test" + ) + fea_path1, fea_path2 = os.path.join( + dump_seq_dir, + name_list[index1].split("/")[-1] + + "_" + + self.config["extractor"]["name"] + + "_" + + str(self.config["extractor"]["num_kpt"]) + + ".hdf5", + ), os.path.join( + dump_seq_dir, + name_list[index2].split("/")[-1] + + "_" + + self.config["extractor"]["name"] + + "_" + + str(self.config["extractor"]["num_kpt"]) + + ".hdf5", + ) + self.data["img_path1"].append(img_path1), self.data["img_path2"].append( + img_path2 + ) + self.data["fea_path1"].append(fea_path1), self.data["fea_path2"].append( + fea_path2 + ) self.form_standard_dataset() diff --git a/third_party/SGMNet/demo/demo.py b/third_party/SGMNet/demo/demo.py index cbe277e26d09121f5517854a7ea014b0797a2bde..835b20485698fbccb055a8f08024014142666377 100644 --- a/third_party/SGMNet/demo/demo.py +++ b/third_party/SGMNet/demo/demo.py @@ -10,36 +10,56 @@ from components import load_component from utils import evaluation_utils import argparse + parser = argparse.ArgumentParser() -parser.add_argument('--config_path', type=str, default='configs/sgm_config.yaml', - help='number of processes.') -parser.add_argument('--img1_path', type=str, default='demo_1.jpg', - help='number of processes.') -parser.add_argument('--img2_path', type=str, default='demo_2.jpg', - help='number of processes.') +parser.add_argument( + "--config_path", + type=str, + default="configs/sgm_config.yaml", + help="number of processes.", +) +parser.add_argument( + "--img1_path", type=str, default="demo_1.jpg", help="number of processes." +) +parser.add_argument( + "--img2_path", type=str, default="demo_2.jpg", help="number of processes." 
+) args = parser.parse_args() -if __name__=='__main__': - with open(args.config_path, 'r') as f: - demo_config = yaml.load(f) +if __name__ == "__main__": + with open(args.config_path, "r") as f: + demo_config = yaml.load(f) + + extractor = load_component( + "extractor", demo_config["extractor"]["name"], demo_config["extractor"] + ) - extractor=load_component('extractor',demo_config['extractor']['name'],demo_config['extractor']) + img1, img2 = cv2.imread(args.img1_path), cv2.imread(args.img2_path) + size1, size2 = np.flip(np.asarray(img1.shape[:2])), np.flip( + np.asarray(img2.shape[:2]) + ) + kpt1, desc1 = extractor.run(args.img1_path) + kpt2, desc2 = extractor.run(args.img2_path) - img1,img2=cv2.imread(args.img1_path),cv2.imread(args.img2_path) - size1,size2=np.flip(np.asarray(img1.shape[:2])),np.flip(np.asarray(img2.shape[:2])) - kpt1,desc1=extractor.run(args.img1_path) - kpt2,desc2=extractor.run(args.img2_path) - - matcher=load_component('matcher',demo_config['matcher']['name'],demo_config['matcher']) - test_data={'x1':kpt1,'x2':kpt2,'desc1':desc1,'desc2':desc2,'size1':size1,'size2':size2} - corr1,corr2= matcher.run(test_data) + matcher = load_component( + "matcher", demo_config["matcher"]["name"], demo_config["matcher"] + ) + test_data = { + "x1": kpt1, + "x2": kpt2, + "desc1": desc1, + "desc2": desc2, + "size1": size1, + "size2": size2, + } + corr1, corr2 = matcher.run(test_data) - #draw points + # draw points dis_points_1 = evaluation_utils.draw_points(img1, kpt1) - dis_points_2 = evaluation_utils.draw_points(img2, kpt2) + dis_points_2 = evaluation_utils.draw_points(img2, kpt2) - #visualize match - display=evaluation_utils.draw_match(dis_points_1,dis_points_2,corr1,corr2) - cv2.imwrite('match.png',display) + # visualize match + display = evaluation_utils.draw_match(dis_points_1, dis_points_2, corr1, corr2) + cv2.imwrite("match.png", display) diff --git a/third_party/SGMNet/evaluation/eval_cost.py b/third_party/SGMNet/evaluation/eval_cost.py index dd3f88abc93290c96ed3d7fa8624c3534e006911..972b4c226c84c3f24dfb2b76e0a31b12719166b0 100644 --- a/third_party/SGMNet/evaluation/eval_cost.py +++ b/third_party/SGMNet/evaluation/eval_cost.py @@ -1,9 +1,10 @@ import torch import yaml import time -from collections import OrderedDict,namedtuple +from collections import OrderedDict, namedtuple import os import sys + ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT_DIR) @@ -12,49 +13,59 @@ from superglue import matcher as SG_Model import argparse + parser = argparse.ArgumentParser() -parser.add_argument('--matcher_name', type=str, default='SGM', - help='number of processes.') -parser.add_argument('--config_path', type=str, default='configs/cost/sgm_cost.yaml', - help='number of processes.') -parser.add_argument('--num_kpt', type=int, default=4000, - help='keypoint number, default:100') -parser.add_argument('--iter_num', type=int, default=100, - help='keypoint number, default:100') +parser.add_argument( + "--matcher_name", type=str, default="SGM", help="number of processes." 
+) +parser.add_argument( + "--config_path", + type=str, + default="configs/cost/sgm_cost.yaml", + help="number of processes.", +) +parser.add_argument( + "--num_kpt", type=int, default=4000, help="keypoint number, default:100" +) +parser.add_argument( + "--iter_num", type=int, default=100, help="keypoint number, default:100" +) -def test_cost(test_data,model): +def test_cost(test_data, model): with torch.no_grad(): - #warm up call - _=model(test_data) + # warm up call + _ = model(test_data) torch.cuda.synchronize() - a=time.time() + a = time.time() for _ in range(int(args.iter_num)): - _=model(test_data) + _ = model(test_data) torch.cuda.synchronize() - b=time.time() - print('Average time per run(ms): ',(b-a)/args.iter_num*1e3) - print('Peak memory(MB): ',torch.cuda.max_memory_allocated()/1e6) + b = time.time() + print("Average time per run(ms): ", (b - a) / args.iter_num * 1e3) + print("Peak memory(MB): ", torch.cuda.max_memory_allocated() / 1e6) -if __name__=='__main__': - torch.backends.cudnn.benchmark=False +if __name__ == "__main__": + torch.backends.cudnn.benchmark = False args = parser.parse_args() - with open(args.config_path, 'r') as f: - model_config = yaml.load(f) - model_config=namedtuple('model_config',model_config.keys())(*model_config.values()) - - if args.matcher_name=='SGM': - model = SGM_Model(model_config) - elif args.matcher_name=='SG': - model = SG_Model(model_config) - model.cuda(),model.eval() - + with open(args.config_path, "r") as f: + model_config = yaml.load(f) + model_config = namedtuple("model_config", model_config.keys())( + *model_config.values() + ) + + if args.matcher_name == "SGM": + model = SGM_Model(model_config) + elif args.matcher_name == "SG": + model = SG_Model(model_config) + model.cuda(), model.eval() + test_data = { - 'x1':torch.rand(1,args.num_kpt,2).cuda()-0.5, - 'x2':torch.rand(1,args.num_kpt,2).cuda()-0.5, - 'desc1': torch.rand(1,args.num_kpt,128).cuda(), - 'desc2': torch.rand(1,args.num_kpt,128).cuda() - } + "x1": torch.rand(1, args.num_kpt, 2).cuda() - 0.5, + "x2": torch.rand(1, args.num_kpt, 2).cuda() - 0.5, + "desc1": torch.rand(1, args.num_kpt, 128).cuda(), + "desc2": torch.rand(1, args.num_kpt, 128).cuda(), + } - test_cost(test_data,model) + test_cost(test_data, model) diff --git a/third_party/SGMNet/evaluation/evaluate.py b/third_party/SGMNet/evaluation/evaluate.py index dd5229375caa03b2763bf37a266fb76e80f8e25e..ec6c3ed2aa907838ed3d1cc0ed15710bcd5a6e5f 100644 --- a/third_party/SGMNet/evaluation/evaluate.py +++ b/third_party/SGMNet/evaluation/evaluate.py @@ -1,5 +1,5 @@ import os -from torch.multiprocessing import Process,Manager,set_start_method,Pool +from torch.multiprocessing import Process, Manager, set_start_method, Pool import functools import argparse import yaml @@ -7,111 +7,144 @@ import numpy as np import sys import cv2 from tqdm import trange -set_start_method('spawn',force=True) + +set_start_method("spawn", force=True) ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT_DIR) from components import load_component -from utils import evaluation_utils,metrics - -parser = argparse.ArgumentParser(description='dump eval data.') -parser.add_argument('--config_path', type=str, default='configs/eval/scannet_eval_sgm.yaml') -parser.add_argument('--num_process_match', type=int, default=4) -parser.add_argument('--num_process_eval', type=int, default=4) -parser.add_argument('--vis_folder',type=str,default=None) -args=parser.parse_args() - -def feed_match(info,matcher): - 
x1,x2,desc1,desc2,size1,size2=info['x1'],info['x2'],info['desc1'],info['desc2'],info['img1'].shape[:2],info['img2'].shape[:2] - test_data = {'x1': x1,'x2': x2,'desc1': desc1,'desc2': desc2,'size1':np.flip(np.asarray(size1)),'size2':np.flip(np.asarray(size2)) } - corr1,corr2=matcher.run(test_data) - return [corr1,corr2] - - -def reader_handler(config,read_que): - reader=load_component('reader',config['name'],config) - for index in range(len(reader)): - index+=0 - info=reader.run(index) - read_que.put(info) - read_que.put('over') - - -def match_handler(config,read_que,match_que): - matcher=load_component('matcher',config['name'],config) - match_func=functools.partial(feed_match,matcher=matcher) - pool = Pool(args.num_process_match) - cache=[] - while True: - item=read_que.get() - #clear cache - if item=='over': - if len(cache)!=0: - results=pool.map(match_func,cache) - for cur_item,cur_result in zip(cache,results): - cur_item['corr1'],cur_item['corr2']=cur_result[0],cur_result[1] - match_que.put(cur_item) - match_que.put('over') - break - cache.append(item) - #print(len(cache)) - if len(cache)==args.num_process_match: - #matching in parallel - results=pool.map(match_func,cache) - for cur_item,cur_result in zip(cache,results): - cur_item['corr1'],cur_item['corr2']=cur_result[0],cur_result[1] - match_que.put(cur_item) - cache=[] - pool.close() - pool.join() - - -def evaluate_handler(config,match_que): - evaluator=load_component('evaluator',config['name'],config) - pool = Pool(args.num_process_eval) - cache=[] - for _ in trange(config['num_pair']): - item=match_que.get() - if item=='over': - if len(cache)!=0: - results=pool.map(evaluator.run,cache) - for cur_res in results: - evaluator.res_inqueue(cur_res) - break - cache.append(item) - if len(cache)==args.num_process_eval: - results=pool.map(evaluator.run,cache) - for cur_res in results: - evaluator.res_inqueue(cur_res) - cache=[] - if args.vis_folder is not None: - #dump visualization - corr1_norm,corr2_norm=evaluation_utils.normalize_intrinsic(item['corr1'],item['K1']),\ - evaluation_utils.normalize_intrinsic(item['corr2'],item['K2']) - inlier_mask=metrics.compute_epi_inlier(corr1_norm,corr2_norm,item['e'],config['inlier_th']) - display=evaluation_utils.draw_match(item['img1'],item['img2'],item['corr1'],item['corr2'],inlier_mask) - cv2.imwrite(os.path.join(args.vis_folder,str(item['index'])+'.png'),display) - evaluator.parse() - - -if __name__=='__main__': - with open(args.config_path, 'r') as f: - config = yaml.load(f) - if args.vis_folder is not None and not os.path.exists(args.vis_folder): - os.mkdir(args.vis_folder) - - read_que,match_que,estimate_que=Manager().Queue(maxsize=100),Manager().Queue(maxsize=100),Manager().Queue(maxsize=100) - - read_process=Process(target=reader_handler,args=(config['reader'],read_que)) - match_process=Process(target=match_handler,args=(config['matcher'],read_que,match_que)) - evaluate_process=Process(target=evaluate_handler,args=(config['evaluator'],match_que)) - - read_process.start() - match_process.start() - evaluate_process.start() - - read_process.join() - match_process.join() - evaluate_process.join() \ No newline at end of file +from utils import evaluation_utils, metrics + +parser = argparse.ArgumentParser(description="dump eval data.") +parser.add_argument( + "--config_path", type=str, default="configs/eval/scannet_eval_sgm.yaml" +) +parser.add_argument("--num_process_match", type=int, default=4) +parser.add_argument("--num_process_eval", type=int, default=4) +parser.add_argument("--vis_folder", 
type=str, default=None)
+args = parser.parse_args()
+
+
+def feed_match(info, matcher):
+    x1, x2, desc1, desc2, size1, size2 = (
+        info["x1"],
+        info["x2"],
+        info["desc1"],
+        info["desc2"],
+        info["img1"].shape[:2],
+        info["img2"].shape[:2],
+    )
+    test_data = {
+        "x1": x1,
+        "x2": x2,
+        "desc1": desc1,
+        "desc2": desc2,
+        "size1": np.flip(np.asarray(size1)),
+        "size2": np.flip(np.asarray(size2)),
+    }
+    corr1, corr2 = matcher.run(test_data)
+    return [corr1, corr2]
+
+
+def reader_handler(config, read_que):
+    reader = load_component("reader", config["name"], config)
+    for index in range(len(reader)):
+        index += 0
+        info = reader.run(index)
+        read_que.put(info)
+    read_que.put("over")
+
+
+def match_handler(config, read_que, match_que):
+    matcher = load_component("matcher", config["name"], config)
+    match_func = functools.partial(feed_match, matcher=matcher)
+    pool = Pool(args.num_process_match)
+    cache = []
+    while True:
+        item = read_que.get()
+        # clear cache
+        if item == "over":
+            if len(cache) != 0:
+                results = pool.map(match_func, cache)
+                for cur_item, cur_result in zip(cache, results):
+                    cur_item["corr1"], cur_item["corr2"] = cur_result[0], cur_result[1]
+                    match_que.put(cur_item)
+            match_que.put("over")
+            break
+        cache.append(item)
+        # print(len(cache))
+        if len(cache) == args.num_process_match:
+            # matching in parallel
+            results = pool.map(match_func, cache)
+            for cur_item, cur_result in zip(cache, results):
+                cur_item["corr1"], cur_item["corr2"] = cur_result[0], cur_result[1]
+                match_que.put(cur_item)
+            cache = []
+    pool.close()
+    pool.join()
+
+
+def evaluate_handler(config, match_que):
+    evaluator = load_component("evaluator", config["name"], config)
+    pool = Pool(args.num_process_eval)
+    cache = []
+    for _ in trange(config["num_pair"]):
+        item = match_que.get()
+        if item == "over":
+            if len(cache) != 0:
+                results = pool.map(evaluator.run, cache)
+                for cur_res in results:
+                    evaluator.res_inqueue(cur_res)
+            break
+        cache.append(item)
+        if len(cache) == args.num_process_eval:
+            results = pool.map(evaluator.run, cache)
+            for cur_res in results:
+                evaluator.res_inqueue(cur_res)
+            cache = []
+        if args.vis_folder is not None:
+            # dump visualization
+            corr1_norm, corr2_norm = evaluation_utils.normalize_intrinsic(
+                item["corr1"], item["K1"]
+            ), evaluation_utils.normalize_intrinsic(item["corr2"], item["K2"])
+            inlier_mask = metrics.compute_epi_inlier(
+                corr1_norm, corr2_norm, item["e"], config["inlier_th"]
+            )
+            display = evaluation_utils.draw_match(
+                item["img1"], item["img2"], item["corr1"], item["corr2"], inlier_mask
+            )
+            cv2.imwrite(
+                os.path.join(args.vis_folder, str(item["index"]) + ".png"), display
+            )
+    evaluator.parse()
+
+
+if __name__ == "__main__":
+    with open(args.config_path, "r") as f:
+        config = yaml.load(f)
+    if args.vis_folder is not None and not os.path.exists(args.vis_folder):
+        os.mkdir(args.vis_folder)
+
+    read_que, match_que, estimate_que = (
+        Manager().Queue(maxsize=100),
+        Manager().Queue(maxsize=100),
+        Manager().Queue(maxsize=100),
+    )
+
+    read_process = Process(target=reader_handler, args=(config["reader"], read_que))
+    match_process = Process(
+        target=match_handler, args=(config["matcher"], read_que, match_que)
+    )
+    evaluate_process = Process(
+        target=evaluate_handler, args=(config["evaluator"], match_que)
+    )
+
+    read_process.start()
+    match_process.start()
+    evaluate_process.start()
+
+    read_process.join()
+    match_process.join()
+    evaluate_process.join()
diff --git a/third_party/SGMNet/sgmnet/__init__.py b/third_party/SGMNet/sgmnet/__init__.py
index 
828543beceebb10d05fd9d5fdfcc4b1c91e5af6b..fabeccd0fe21eb5be637602f2b2eb3cfd944d11b 100644 --- a/third_party/SGMNet/sgmnet/__init__.py +++ b/third_party/SGMNet/sgmnet/__init__.py @@ -1 +1 @@ -from .match_model import matcher \ No newline at end of file +from .match_model import matcher diff --git a/third_party/SGMNet/sgmnet/match_model.py b/third_party/SGMNet/sgmnet/match_model.py index 1e55fa5d042b010f8d9a99e006002563a3961ae7..c758cf5d6537fb3c47a2de00cc279857755943ef 100644 --- a/third_party/SGMNet/sgmnet/match_model.py +++ b/third_party/SGMNet/sgmnet/match_model.py @@ -1,9 +1,10 @@ import torch import torch.nn as nn -eps=1e-8 +eps = 1e-8 -def sinkhorn(M,r,c,iteration): + +def sinkhorn(M, r, c, iteration): p = torch.softmax(M, dim=-1) u = torch.ones_like(r) v = torch.ones_like(c) @@ -13,46 +14,79 @@ def sinkhorn(M,r,c,iteration): p = p * u.unsqueeze(-1) * v.unsqueeze(-2) return p -def sink_algorithm(M,dustbin,iteration): + +def sink_algorithm(M, dustbin, iteration): M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) - r = torch.ones([M.shape[0], M.shape[1] - 1],device='cuda') - r = torch.cat([r, torch.ones([M.shape[0], 1],device='cuda') * M.shape[1]], dim=-1) - c = torch.ones([M.shape[0], M.shape[2] - 1],device='cuda') - c = torch.cat([c, torch.ones([M.shape[0], 1],device='cuda') * M.shape[2]], dim=-1) - p=sinkhorn(M,r,c,iteration) + r = torch.ones([M.shape[0], M.shape[1] - 1], device="cuda") + r = torch.cat([r, torch.ones([M.shape[0], 1], device="cuda") * M.shape[1]], dim=-1) + c = torch.ones([M.shape[0], M.shape[2] - 1], device="cuda") + c = torch.cat([c, torch.ones([M.shape[0], 1], device="cuda") * M.shape[2]], dim=-1) + p = sinkhorn(M, r, c, iteration) return p - -def seeding(nn_index1,nn_index2,x1,x2,topk,match_score,confbar,nms_radius,use_mc=True,test=False): - - #apply mutual check before nms + +def seeding( + nn_index1, + nn_index2, + x1, + x2, + topk, + match_score, + confbar, + nms_radius, + use_mc=True, + test=False, +): + + # apply mutual check before nms if use_mc: - mask_not_mutual=nn_index2.gather(dim=-1,index=nn_index1)!=torch.arange(nn_index1.shape[1],device='cuda') - match_score[mask_not_mutual]=-1 - #NMS - pos_dismat1=((x1.norm(p=2,dim=-1)**2).unsqueeze_(-1)+(x1.norm(p=2,dim=-1)**2).unsqueeze_(-2)-2*(x1@x1.transpose(1,2))).abs_().sqrt_() - x2=x2.gather(index=nn_index1.unsqueeze(-1).expand(-1,-1,2),dim=1) - pos_dismat2=((x2.norm(p=2,dim=-1)**2).unsqueeze_(-1)+(x2.norm(p=2,dim=-1)**2).unsqueeze_(-2)-2*(x2@x2.transpose(1,2))).abs_().sqrt_() - radius1, radius2 = nms_radius * pos_dismat1.mean(dim=(1,2),keepdim=True), nms_radius * pos_dismat2.mean(dim=(1,2),keepdim=True) + mask_not_mutual = nn_index2.gather(dim=-1, index=nn_index1) != torch.arange( + nn_index1.shape[1], device="cuda" + ) + match_score[mask_not_mutual] = -1 + # NMS + pos_dismat1 = ( + ( + (x1.norm(p=2, dim=-1) ** 2).unsqueeze_(-1) + + (x1.norm(p=2, dim=-1) ** 2).unsqueeze_(-2) + - 2 * (x1 @ x1.transpose(1, 2)) + ) + .abs_() + .sqrt_() + ) + x2 = x2.gather(index=nn_index1.unsqueeze(-1).expand(-1, -1, 2), dim=1) + pos_dismat2 = ( + ( + (x2.norm(p=2, dim=-1) ** 2).unsqueeze_(-1) + + (x2.norm(p=2, dim=-1) ** 2).unsqueeze_(-2) + - 2 * (x2 @ x2.transpose(1, 2)) + ) + .abs_() + .sqrt_() + ) + radius1, radius2 = nms_radius * pos_dismat1.mean( + dim=(1, 2), keepdim=True + ), nms_radius * pos_dismat2.mean(dim=(1, 2), keepdim=True) nms_mask = (pos_dismat1 >= radius1) & (pos_dismat2 >= radius2) - 
mask_not_local_max=(match_score.unsqueeze(-1)>=match_score.unsqueeze(-2))|nms_mask - mask_not_local_max=~(mask_not_local_max.min(dim=-1).values) + mask_not_local_max = ( + match_score.unsqueeze(-1) >= match_score.unsqueeze(-2) + ) | nms_mask + mask_not_local_max = ~(mask_not_local_max.min(dim=-1).values) match_score[mask_not_local_max] = -1 - - #confidence bar - match_score[match_score0 - if test: - topk=min(mask_survive.sum(dim=1)[0]+2,topk) - _,topindex = torch.topk(match_score,topk,dim=-1)#b*k - seed_index1,seed_index2=topindex,nn_index1.gather(index=topindex,dim=-1) - return seed_index1,seed_index2 + # confidence bar + match_score[match_score < confbar] = -1 + mask_survive = match_score > 0 + if test: + topk = min(mask_survive.sum(dim=1)[0] + 2, topk) + _, topindex = torch.topk(match_score, topk, dim=-1) # b*k + seed_index1, seed_index2 = topindex, nn_index1.gather(index=topindex, dim=-1) + return seed_index1, seed_index2 class PointCN(nn.Module): - def __init__(self, channels,out_channels): + def __init__(self, channels, out_channels): nn.Module.__init__(self) self.shot_cut = nn.Conv1d(channels, out_channels, kernel_size=1) self.conv = nn.Sequential( @@ -63,7 +97,7 @@ class PointCN(nn.Module): nn.InstanceNorm1d(channels, eps=1e-3), nn.SyncBatchNorm(channels), nn.ReLU(), - nn.Conv1d(channels, out_channels, kernel_size=1) + nn.Conv1d(channels, out_channels, kernel_size=1), ) def forward(self, x): @@ -71,152 +105,254 @@ class PointCN(nn.Module): class attention_propagantion(nn.Module): - - def __init__(self,channel,head): + def __init__(self, channel, head): nn.Module.__init__(self) - self.head=head - self.head_dim=channel//head - self.query_filter,self.key_filter,self.value_filter=nn.Conv1d(channel,channel,kernel_size=1),nn.Conv1d(channel,channel,kernel_size=1),\ - nn.Conv1d(channel,channel,kernel_size=1) - self.mh_filter=nn.Conv1d(channel,channel,kernel_size=1) - self.cat_filter=nn.Sequential(nn.Conv1d(2*channel,2*channel, kernel_size=1), nn.SyncBatchNorm(2*channel), nn.ReLU(), - nn.Conv1d(2*channel, channel, kernel_size=1)) - - def forward(self,desc1,desc2,weight_v=None): - #desc1(q) attend to desc2(k,v) - batch_size=desc1.shape[0] - query,key,value=self.query_filter(desc1).view(batch_size,self.head,self.head_dim,-1),self.key_filter(desc2).view(batch_size,self.head,self.head_dim,-1),\ - self.value_filter(desc2).view(batch_size,self.head,self.head_dim,-1) + self.head = head + self.head_dim = channel // head + self.query_filter, self.key_filter, self.value_filter = ( + nn.Conv1d(channel, channel, kernel_size=1), + nn.Conv1d(channel, channel, kernel_size=1), + nn.Conv1d(channel, channel, kernel_size=1), + ) + self.mh_filter = nn.Conv1d(channel, channel, kernel_size=1) + self.cat_filter = nn.Sequential( + nn.Conv1d(2 * channel, 2 * channel, kernel_size=1), + nn.SyncBatchNorm(2 * channel), + nn.ReLU(), + nn.Conv1d(2 * channel, channel, kernel_size=1), + ) + + def forward(self, desc1, desc2, weight_v=None): + # desc1(q) attend to desc2(k,v) + batch_size = desc1.shape[0] + query, key, value = ( + self.query_filter(desc1).view(batch_size, self.head, self.head_dim, -1), + self.key_filter(desc2).view(batch_size, self.head, self.head_dim, -1), + self.value_filter(desc2).view(batch_size, self.head, self.head_dim, -1), + ) if weight_v is not None: - value=value*weight_v.view(batch_size,1,1,-1) - score=torch.softmax(torch.einsum('bhdn,bhdm->bhnm',query,key)/ self.head_dim ** 0.5,dim=-1) - add_value=torch.einsum('bhnm,bhdm->bhdn',score,value).reshape(batch_size,self.head_dim*self.head,-1) - 
add_value=self.mh_filter(add_value) - desc1_new=desc1+self.cat_filter(torch.cat([desc1,add_value],dim=1)) + value = value * weight_v.view(batch_size, 1, 1, -1) + score = torch.softmax( + torch.einsum("bhdn,bhdm->bhnm", query, key) / self.head_dim**0.5, dim=-1 + ) + add_value = torch.einsum("bhnm,bhdm->bhdn", score, value).reshape( + batch_size, self.head_dim * self.head, -1 + ) + add_value = self.mh_filter(add_value) + desc1_new = desc1 + self.cat_filter(torch.cat([desc1, add_value], dim=1)) return desc1_new class hybrid_block(nn.Module): - def __init__(self,channel,head): + def __init__(self, channel, head): nn.Module.__init__(self) - self.head=head - self.channel=channel + self.head = head + self.channel = channel self.attention_block_down = attention_propagantion(channel, head) - self.cluster_filter=nn.Sequential(nn.Conv1d(2*channel,2*channel, kernel_size=1), nn.SyncBatchNorm(2*channel), nn.ReLU(), - nn.Conv1d(2*channel, 2*channel, kernel_size=1)) - self.cross_filter=attention_propagantion(channel,head) - self.confidence_filter=PointCN(2*channel,1) - self.attention_block_self=attention_propagantion(channel,head) - self.attention_block_up=attention_propagantion(channel,head) - - def forward(self,desc1,desc2,seed_index1,seed_index2): - cluster1, cluster2 = desc1.gather(dim=-1, index=seed_index1.unsqueeze(1).expand(-1, self.channel, -1)), \ - desc2.gather(dim=-1, index=seed_index2.unsqueeze(1).expand(-1, self.channel, -1)) - - #pooling - cluster1, cluster2 = self.attention_block_down(cluster1, desc1), self.attention_block_down(cluster2, desc2) - concate_cluster=self.cluster_filter(torch.cat([cluster1,cluster2],dim=1)) - #filtering - cluster1,cluster2=self.cross_filter(concate_cluster[:,:self.channel],concate_cluster[:,self.channel:]),\ - self.cross_filter(concate_cluster[:,self.channel:],concate_cluster[:,:self.channel]) - cluster1,cluster2=self.attention_block_self(cluster1,cluster1),self.attention_block_self(cluster2,cluster2) - #unpooling - seed_weight=self.confidence_filter(torch.cat([cluster1,cluster2],dim=1)) - seed_weight=torch.sigmoid(seed_weight).squeeze(1) - desc1_new,desc2_new=self.attention_block_up(desc1,cluster1,seed_weight),self.attention_block_up(desc2,cluster2,seed_weight) - return desc1_new,desc2_new,seed_weight + self.cluster_filter = nn.Sequential( + nn.Conv1d(2 * channel, 2 * channel, kernel_size=1), + nn.SyncBatchNorm(2 * channel), + nn.ReLU(), + nn.Conv1d(2 * channel, 2 * channel, kernel_size=1), + ) + self.cross_filter = attention_propagantion(channel, head) + self.confidence_filter = PointCN(2 * channel, 1) + self.attention_block_self = attention_propagantion(channel, head) + self.attention_block_up = attention_propagantion(channel, head) + def forward(self, desc1, desc2, seed_index1, seed_index2): + cluster1, cluster2 = desc1.gather( + dim=-1, index=seed_index1.unsqueeze(1).expand(-1, self.channel, -1) + ), desc2.gather( + dim=-1, index=seed_index2.unsqueeze(1).expand(-1, self.channel, -1) + ) + + # pooling + cluster1, cluster2 = self.attention_block_down( + cluster1, desc1 + ), self.attention_block_down(cluster2, desc2) + concate_cluster = self.cluster_filter(torch.cat([cluster1, cluster2], dim=1)) + # filtering + cluster1, cluster2 = self.cross_filter( + concate_cluster[:, : self.channel], concate_cluster[:, self.channel :] + ), self.cross_filter( + concate_cluster[:, self.channel :], concate_cluster[:, : self.channel] + ) + cluster1, cluster2 = self.attention_block_self( + cluster1, cluster1 + ), self.attention_block_self(cluster2, cluster2) + # unpooling + 
seed_weight = self.confidence_filter(torch.cat([cluster1, cluster2], dim=1)) + seed_weight = torch.sigmoid(seed_weight).squeeze(1) + desc1_new, desc2_new = self.attention_block_up( + desc1, cluster1, seed_weight + ), self.attention_block_up(desc2, cluster2, seed_weight) + return desc1_new, desc2_new, seed_weight class matcher(nn.Module): - def __init__(self,config): + def __init__(self, config): nn.Module.__init__(self) - self.seed_top_k=config.seed_top_k - self.conf_bar=config.conf_bar - self.seed_radius_coe=config.seed_radius_coe - self.use_score_encoding=config.use_score_encoding - self.detach_iter=config.detach_iter - self.seedlayer=config.seedlayer - self.layer_num=config.layer_num - self.sink_iter=config.sink_iter - - self.position_encoder = nn.Sequential(nn.Conv1d(3, 32, kernel_size=1) if config.use_score_encoding else nn.Conv1d(2, 32, kernel_size=1), - nn.SyncBatchNorm(32),nn.ReLU(), - nn.Conv1d(32, 64, kernel_size=1), nn.SyncBatchNorm(64),nn.ReLU(), - nn.Conv1d(64, 128, kernel_size=1), nn.SyncBatchNorm(128),nn.ReLU(), - nn.Conv1d(128, 256, kernel_size=1), nn.SyncBatchNorm(256),nn.ReLU(), - nn.Conv1d(256, config.net_channels, kernel_size=1)) - - - self.hybrid_block=nn.Sequential(*[hybrid_block(config.net_channels, config.head) for _ in range(config.layer_num)]) - self.final_project = nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1) - self.dustbin=nn.Parameter(torch.tensor(1.5,dtype=torch.float32)) - - #if reseeding - if len(config.seedlayer)!=1: - self.mid_dustbin=nn.ParameterDict({str(i):nn.Parameter(torch.tensor(2,dtype=torch.float32)) for i in config.seedlayer[1:]}) - self.mid_final_project = nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1) - - def forward(self,data,test_mode=True): - x1, x2, desc1, desc2 = data['x1'][:,:,:2], data['x2'][:,:,:2], data['desc1'], data['desc2'] - desc1, desc2 = torch.nn.functional.normalize(desc1,dim=-1), torch.nn.functional.normalize(desc2,dim=-1) + self.seed_top_k = config.seed_top_k + self.conf_bar = config.conf_bar + self.seed_radius_coe = config.seed_radius_coe + self.use_score_encoding = config.use_score_encoding + self.detach_iter = config.detach_iter + self.seedlayer = config.seedlayer + self.layer_num = config.layer_num + self.sink_iter = config.sink_iter + + self.position_encoder = nn.Sequential( + nn.Conv1d(3, 32, kernel_size=1) + if config.use_score_encoding + else nn.Conv1d(2, 32, kernel_size=1), + nn.SyncBatchNorm(32), + nn.ReLU(), + nn.Conv1d(32, 64, kernel_size=1), + nn.SyncBatchNorm(64), + nn.ReLU(), + nn.Conv1d(64, 128, kernel_size=1), + nn.SyncBatchNorm(128), + nn.ReLU(), + nn.Conv1d(128, 256, kernel_size=1), + nn.SyncBatchNorm(256), + nn.ReLU(), + nn.Conv1d(256, config.net_channels, kernel_size=1), + ) + + self.hybrid_block = nn.Sequential( + *[ + hybrid_block(config.net_channels, config.head) + for _ in range(config.layer_num) + ] + ) + self.final_project = nn.Conv1d( + config.net_channels, config.net_channels, kernel_size=1 + ) + self.dustbin = nn.Parameter(torch.tensor(1.5, dtype=torch.float32)) + + # if reseeding + if len(config.seedlayer) != 1: + self.mid_dustbin = nn.ParameterDict( + { + str(i): nn.Parameter(torch.tensor(2, dtype=torch.float32)) + for i in config.seedlayer[1:] + } + ) + self.mid_final_project = nn.Conv1d( + config.net_channels, config.net_channels, kernel_size=1 + ) + + def forward(self, data, test_mode=True): + x1, x2, desc1, desc2 = ( + data["x1"][:, :, :2], + data["x2"][:, :, :2], + data["desc1"], + data["desc2"], + ) + desc1, desc2 = torch.nn.functional.normalize( + 
desc1, dim=-1 + ), torch.nn.functional.normalize(desc2, dim=-1) if test_mode: - encode_x1,encode_x2=data['x1'],data['x2'] + encode_x1, encode_x2 = data["x1"], data["x2"] else: - encode_x1,encode_x2=data['aug_x1'], data['aug_x2'] - - #preparation - desc_dismat=(2-2*torch.matmul(desc1,desc2.transpose(1,2))).sqrt_() - values,nn_index=torch.topk(desc_dismat,k=2,largest=False,dim=-1,sorted=True) - nn_index2=torch.min(desc_dismat,dim=1).indices.squeeze(1) - inverse_ratio_score,nn_index1=values[:,:,1]/values[:,:,0],nn_index[:,:,0]#get inverse score - - #initial seeding - seed_index1,seed_index2=seeding(nn_index1,nn_index2,x1,x2,self.seed_top_k[0],inverse_ratio_score,self.conf_bar[0],\ - self.seed_radius_coe,test=test_mode) - - #position encoding - desc1,desc2=desc1.transpose(1,2),desc2.transpose(1,2) + encode_x1, encode_x2 = data["aug_x1"], data["aug_x2"] + + # preparation + desc_dismat = (2 - 2 * torch.matmul(desc1, desc2.transpose(1, 2))).sqrt_() + values, nn_index = torch.topk( + desc_dismat, k=2, largest=False, dim=-1, sorted=True + ) + nn_index2 = torch.min(desc_dismat, dim=1).indices.squeeze(1) + inverse_ratio_score, nn_index1 = ( + values[:, :, 1] / values[:, :, 0], + nn_index[:, :, 0], + ) # get inverse score + + # initial seeding + seed_index1, seed_index2 = seeding( + nn_index1, + nn_index2, + x1, + x2, + self.seed_top_k[0], + inverse_ratio_score, + self.conf_bar[0], + self.seed_radius_coe, + test=test_mode, + ) + + # position encoding + desc1, desc2 = desc1.transpose(1, 2), desc2.transpose(1, 2) if not self.use_score_encoding: - encode_x1,encode_x2=encode_x1[:,:,:2],encode_x2[:,:,:2] - encode_x1,encode_x2=encode_x1.transpose(1,2),encode_x2.transpose(1,2) - x1_pos_embedding, x2_pos_embedding = self.position_encoder(encode_x1), self.position_encoder(encode_x2) + encode_x1, encode_x2 = encode_x1[:, :, :2], encode_x2[:, :, :2] + encode_x1, encode_x2 = encode_x1.transpose(1, 2), encode_x2.transpose(1, 2) + x1_pos_embedding, x2_pos_embedding = self.position_encoder( + encode_x1 + ), self.position_encoder(encode_x2) aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding + desc2 - - seed_weight_tower,mid_p_tower,seed_index_tower,nn_index_tower=[],[],[],[] - seed_index_tower.append(torch.stack([seed_index1, seed_index2],dim=-1)) + + seed_weight_tower, mid_p_tower, seed_index_tower, nn_index_tower = ( + [], + [], + [], + [], + ) + seed_index_tower.append(torch.stack([seed_index1, seed_index2], dim=-1)) nn_index_tower.append(nn_index1) - seed_para_index=0 + seed_para_index = 0 for i in range(self.layer_num): - #mid seeding - if i in self.seedlayer and i!= 0: - seed_para_index+=1 - aug_desc1,aug_desc2=self.mid_final_project(aug_desc1),self.mid_final_project(aug_desc2) - M=torch.matmul(aug_desc1.transpose(1,2),aug_desc2) - p=sink_algorithm(M,self.mid_dustbin[str(i)],self.sink_iter[seed_para_index-1]) + # mid seeding + if i in self.seedlayer and i != 0: + seed_para_index += 1 + aug_desc1, aug_desc2 = self.mid_final_project( + aug_desc1 + ), self.mid_final_project(aug_desc2) + M = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2) + p = sink_algorithm( + M, self.mid_dustbin[str(i)], self.sink_iter[seed_para_index - 1] + ) mid_p_tower.append(p) - #rematching with p - values,nn_index=torch.topk(p[:,:-1,:-1],k=1,dim=-1) - nn_index2=torch.max(p[:,:-1,:-1],dim=1).indices.squeeze(1) - p_match_score,nn_index1=values[:,:,0],nn_index[:,:,0] - #reseeding - seed_index1, seed_index2 = seeding(nn_index1,nn_index2,x1,x2,self.seed_top_k[seed_para_index],p_match_score,\ - 
self.conf_bar[seed_para_index],self.seed_radius_coe,test=test_mode) - seed_index_tower.append(torch.stack([seed_index1, seed_index2],dim=-1)), nn_index_tower.append(nn_index1) - if not test_mode and data['step']bhnm',query1,key1)/self.head_dim**0.5,dim=-1),\ - torch.softmax(torch.einsum('bdhn,bdhm->bhnm',query2,key2)/self.head_dim**0.5,dim=-1) - add_value1, add_value2 = torch.einsum('bhnm,bdhm->bdhn', score1, value1), torch.einsum('bhnm,bdhm->bdhn',score2, value2) + self.head = head + self.type = type + self.head_dim = channels // head + self.query_filter = nn.Conv1d(channels, channels, kernel_size=1) + self.key_filter = nn.Conv1d(channels, channels, kernel_size=1) + self.value_filter = nn.Conv1d(channels, channels, kernel_size=1) + self.attention_filter = nn.Sequential( + nn.Conv1d(2 * channels, 2 * channels, kernel_size=1), + nn.SyncBatchNorm(2 * channels), + nn.ReLU(), + nn.Conv1d(2 * channels, channels, kernel_size=1), + ) + self.mh_filter = nn.Conv1d(channels, channels, kernel_size=1) + + def forward(self, fea1, fea2): + batch_size, n, m = fea1.shape[0], fea1.shape[2], fea2.shape[2] + query1, key1, value1 = ( + self.query_filter(fea1).view(batch_size, self.head_dim, self.head, -1), + self.key_filter(fea1).view(batch_size, self.head_dim, self.head, -1), + self.value_filter(fea1).view(batch_size, self.head_dim, self.head, -1), + ) + query2, key2, value2 = ( + self.query_filter(fea2).view(batch_size, self.head_dim, self.head, -1), + self.key_filter(fea2).view(batch_size, self.head_dim, self.head, -1), + self.value_filter(fea2).view(batch_size, self.head_dim, self.head, -1), + ) + if self.type == "self": + score1, score2 = torch.softmax( + torch.einsum("bdhn,bdhm->bhnm", query1, key1) / self.head_dim**0.5, + dim=-1, + ), torch.softmax( + torch.einsum("bdhn,bdhm->bhnm", query2, key2) / self.head_dim**0.5, + dim=-1, + ) + add_value1, add_value2 = torch.einsum( + "bhnm,bdhm->bdhn", score1, value1 + ), torch.einsum("bhnm,bdhm->bdhn", score2, value2) else: - score1,score2 = torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query1, key2) / self.head_dim ** 0.5,dim=-1), \ - torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query2, key1) / self.head_dim ** 0.5, dim=-1) - add_value1, add_value2 =torch.einsum('bhnm,bdhm->bdhn',score1,value2),torch.einsum('bhnm,bdhm->bdhn',score2,value1) - add_value1,add_value2=self.mh_filter(add_value1.contiguous().view(batch_size,self.head*self.head_dim,n)),self.mh_filter(add_value2.contiguous().view(batch_size,self.head*self.head_dim,m)) - fea11, fea22 = torch.cat([fea1, add_value1], dim=1), torch.cat([fea2, add_value2], dim=1) - fea1, fea2 = fea1+self.attention_filter(fea11), fea2+self.attention_filter(fea22) - - return fea1,fea2 + score1, score2 = torch.softmax( + torch.einsum("bdhn,bdhm->bhnm", query1, key2) / self.head_dim**0.5, + dim=-1, + ), torch.softmax( + torch.einsum("bdhn,bdhm->bhnm", query2, key1) / self.head_dim**0.5, + dim=-1, + ) + add_value1, add_value2 = torch.einsum( + "bhnm,bdhm->bdhn", score1, value2 + ), torch.einsum("bhnm,bdhm->bdhn", score2, value1) + add_value1, add_value2 = self.mh_filter( + add_value1.contiguous().view(batch_size, self.head * self.head_dim, n) + ), self.mh_filter( + add_value2.contiguous().view(batch_size, self.head * self.head_dim, m) + ) + fea11, fea22 = torch.cat([fea1, add_value1], dim=1), torch.cat( + [fea2, add_value2], dim=1 + ) + fea1, fea2 = fea1 + self.attention_filter(fea11), fea2 + self.attention_filter( + fea22 + ) + + return fea1, fea2 class matcher(nn.Module): def __init__(self, config): nn.Module.__init__(self) - 
self.use_score_encoding=config.use_score_encoding - self.layer_num=config.layer_num - self.sink_iter=config.sink_iter - self.position_encoder = nn.Sequential(nn.Conv1d(3, 32, kernel_size=1) if config.use_score_encoding else nn.Conv1d(2, 32, kernel_size=1), - nn.SyncBatchNorm(32), nn.ReLU(), - nn.Conv1d(32, 64, kernel_size=1), nn.SyncBatchNorm(64),nn.ReLU(), - nn.Conv1d(64, 128, kernel_size=1), nn.SyncBatchNorm(128), nn.ReLU(), - nn.Conv1d(128, 256, kernel_size=1), nn.SyncBatchNorm(256), nn.ReLU(), - nn.Conv1d(256, config.net_channels, kernel_size=1)) - - self.dustbin=nn.Parameter(torch.tensor(1,dtype=torch.float32,device='cuda')) - self.self_attention_block=nn.Sequential(*[attention_block(config.net_channels,config.head,'self') for _ in range(config.layer_num)]) - self.cross_attention_block=nn.Sequential(*[attention_block(config.net_channels,config.head,'cross') for _ in range(config.layer_num)]) - self.final_project=nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1) - - def forward(self,data,test_mode=True): - desc1, desc2 = data['desc1'], data['desc2'] - desc1, desc2 = torch.nn.functional.normalize(desc1,dim=-1), torch.nn.functional.normalize(desc2,dim=-1) - desc1,desc2=desc1.transpose(1,2),desc2.transpose(1,2) + self.use_score_encoding = config.use_score_encoding + self.layer_num = config.layer_num + self.sink_iter = config.sink_iter + self.position_encoder = nn.Sequential( + nn.Conv1d(3, 32, kernel_size=1) + if config.use_score_encoding + else nn.Conv1d(2, 32, kernel_size=1), + nn.SyncBatchNorm(32), + nn.ReLU(), + nn.Conv1d(32, 64, kernel_size=1), + nn.SyncBatchNorm(64), + nn.ReLU(), + nn.Conv1d(64, 128, kernel_size=1), + nn.SyncBatchNorm(128), + nn.ReLU(), + nn.Conv1d(128, 256, kernel_size=1), + nn.SyncBatchNorm(256), + nn.ReLU(), + nn.Conv1d(256, config.net_channels, kernel_size=1), + ) + + self.dustbin = nn.Parameter(torch.tensor(1, dtype=torch.float32, device="cuda")) + self.self_attention_block = nn.Sequential( + *[ + attention_block(config.net_channels, config.head, "self") + for _ in range(config.layer_num) + ] + ) + self.cross_attention_block = nn.Sequential( + *[ + attention_block(config.net_channels, config.head, "cross") + for _ in range(config.layer_num) + ] + ) + self.final_project = nn.Conv1d( + config.net_channels, config.net_channels, kernel_size=1 + ) + + def forward(self, data, test_mode=True): + desc1, desc2 = data["desc1"], data["desc2"] + desc1, desc2 = torch.nn.functional.normalize( + desc1, dim=-1 + ), torch.nn.functional.normalize(desc2, dim=-1) + desc1, desc2 = desc1.transpose(1, 2), desc2.transpose(1, 2) if test_mode: - encode_x1,encode_x2=data['x1'],data['x2'] + encode_x1, encode_x2 = data["x1"], data["x2"] else: - encode_x1,encode_x2=data['aug_x1'], data['aug_x2'] + encode_x1, encode_x2 = data["aug_x1"], data["aug_x2"] if not self.use_score_encoding: - encode_x1,encode_x2=encode_x1[:,:,:2],encode_x2[:,:,:2] + encode_x1, encode_x2 = encode_x1[:, :, :2], encode_x2[:, :, :2] - encode_x1,encode_x2=encode_x1.transpose(1,2),encode_x2.transpose(1,2) + encode_x1, encode_x2 = encode_x1.transpose(1, 2), encode_x2.transpose(1, 2) - x1_pos_embedding, x2_pos_embedding = self.position_encoder(encode_x1), self.position_encoder(encode_x2) - aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding+desc2 + x1_pos_embedding, x2_pos_embedding = self.position_encoder( + encode_x1 + ), self.position_encoder(encode_x2) + aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding + desc2 for i in range(self.layer_num): - 
aug_desc1,aug_desc2=self.self_attention_block[i](aug_desc1,aug_desc2) - aug_desc1,aug_desc2=self.cross_attention_block[i](aug_desc1,aug_desc2) + aug_desc1, aug_desc2 = self.self_attention_block[i](aug_desc1, aug_desc2) + aug_desc1, aug_desc2 = self.cross_attention_block[i](aug_desc1, aug_desc2) - aug_desc1,aug_desc2=self.final_project(aug_desc1),self.final_project(aug_desc2) + aug_desc1, aug_desc2 = self.final_project(aug_desc1), self.final_project( + aug_desc2 + ) desc_mat = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2) - p = sink_algorithm(desc_mat, self.dustbin,self.sink_iter[0]) - return {'p':p} - - + p = sink_algorithm(desc_mat, self.dustbin, self.sink_iter[0]) + return {"p": p} diff --git a/third_party/SGMNet/superpoint/__init__.py b/third_party/SGMNet/superpoint/__init__.py index 111c8882a7bc7512c6191ca86a0e71c3b1404233..f1127dfc54047e2d0d877da1d3eb5c2ed569b85e 100644 --- a/third_party/SGMNet/superpoint/__init__.py +++ b/third_party/SGMNet/superpoint/__init__.py @@ -1 +1 @@ -from .superpoint import SuperPoint \ No newline at end of file +from .superpoint import SuperPoint diff --git a/third_party/SGMNet/superpoint/superpoint.py b/third_party/SGMNet/superpoint/superpoint.py index d4e3ce481409264a3188270ad01aa62b1614377f..38b839cbc731460e487c9359c6e0edcaec7be7c9 100644 --- a/third_party/SGMNet/superpoint/superpoint.py +++ b/third_party/SGMNet/superpoint/superpoint.py @@ -3,11 +3,12 @@ from torch import nn def simple_nms(scores, nms_radius): - assert(nms_radius >= 0) + assert nms_radius >= 0 def max_pool(x): return torch.nn.functional.max_pool2d( - x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) zeros = torch.zeros_like(scores) max_mask = scores == max_pool(scores) @@ -36,19 +37,21 @@ def top_k_keypoints(keypoints, scores, k): def sample_descriptors(keypoints, descriptors, s): b, c, h, w = descriptors.shape keypoints = keypoints - s / 2 + 0.5 - keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], - ).to(keypoints)[None] - keypoints = keypoints*2 - 1 # normalize to (-1, 1) - args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} + keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],).to( + keypoints + )[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + args = {"align_corners": True} if int(torch.__version__[2]) > 2 else {} descriptors = torch.nn.functional.grid_sample( - descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) + descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args + ) descriptors = torch.nn.functional.normalize( - descriptors.reshape(b, c, -1), p=2, dim=1) + descriptors.reshape(b, c, -1), p=2, dim=1 + ) return descriptors class SuperPoint(nn.Module): - def __init__(self, config): super().__init__() self.config = {**config} @@ -71,16 +74,16 @@ class SuperPoint(nn.Module): self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) self.convDb = nn.Conv2d( - c5, self.config['descriptor_dim'], - kernel_size=1, stride=1, padding=0) + c5, self.config["descriptor_dim"], kernel_size=1, stride=1, padding=0 + ) - self.load_state_dict(torch.load(config['model_path'])) + self.load_state_dict(torch.load(config["model_path"])) - mk = self.config['max_keypoints'] + mk = self.config["max_keypoints"] if mk == 0 or mk < -1: - raise ValueError('\"max_keypoints\" must be positive or \"-1\"') + raise ValueError('"max_keypoints" must be positive or "-1"') - print('Loaded SuperPoint model') + print("Loaded 
SuperPoint model") def forward(self, data): # Shared Encoder @@ -101,25 +104,35 @@ class SuperPoint(nn.Module): scores = torch.nn.functional.softmax(scores, 1)[:, :-1] b, c, h, w = scores.shape scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) - scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) - scores = simple_nms(scores, self.config['nms_radius']) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + scores = simple_nms(scores, self.config["nms_radius"]) # Extract keypoints keypoints = [ - torch.nonzero(s > self.config['detection_threshold']) - for s in scores] + torch.nonzero(s > self.config["detection_threshold"]) for s in scores + ] scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] # Discard keypoints near the image borders - keypoints, scores = list(zip(*[ - remove_borders(k, s, self.config['remove_borders'], h*8, w*8) - for k, s in zip(keypoints, scores)])) + keypoints, scores = list( + zip( + *[ + remove_borders(k, s, self.config["remove_borders"], h * 8, w * 8) + for k, s in zip(keypoints, scores) + ] + ) + ) # Keep the k keypoints with highest score - if self.config['max_keypoints'] >= 0: - keypoints, scores = list(zip(*[ - top_k_keypoints(k, s, self.config['max_keypoints']) - for k, s in zip(keypoints, scores)])) + if self.config["max_keypoints"] >= 0: + keypoints, scores = list( + zip( + *[ + top_k_keypoints(k, s, self.config["max_keypoints"]) + for k, s in zip(keypoints, scores) + ] + ) + ) # Convert (h, w) to (x, y) keypoints = [torch.flip(k, [1]).float() for k in keypoints] @@ -130,11 +143,13 @@ class SuperPoint(nn.Module): descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) # Extract descriptors - descriptors = [sample_descriptors(k[None], d[None], 8)[0] - for k, d in zip(keypoints, descriptors)] + descriptors = [ + sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, descriptors) + ] return { - 'keypoints': keypoints, - 'scores': scores, - 'descriptors': descriptors, + "keypoints": keypoints, + "scores": scores, + "descriptors": descriptors, } diff --git a/third_party/SGMNet/train/config.py b/third_party/SGMNet/train/config.py index 31c4c1c6deef3d6dd568897f4202d96456586376..3610e40ff0628b1c5c4a2bc2a73d38a6d2cd65b1 100644 --- a/third_party/SGMNet/train/config.py +++ b/third_party/SGMNet/train/config.py @@ -1,5 +1,6 @@ import argparse + def str2bool(v): return v.lower() in ("true", "1") @@ -18,102 +19,111 @@ def add_argument_group(name): # Network net_arg = add_argument_group("Network") net_arg.add_argument( - "--model_name", type=str,default='SGM', help="" - "model for training") + "--model_name", type=str, default="SGM", help="" "model for training" +) net_arg.add_argument( - "--config_path", type=str,default='configs/sgm.yaml', help="" - "config path for model") + "--config_path", + type=str, + default="configs/sgm.yaml", + help="" "config path for model", +) # ----------------------------------------------------------------------------- # Data data_arg = add_argument_group("Data") data_arg.add_argument( - "--rawdata_path", type=str, default='rawdata', help="" - "path for rawdata") + "--rawdata_path", type=str, default="rawdata", help="" "path for rawdata" +) data_arg.add_argument( - "--dataset_path", type=str, default='dataset', help="" - "path for dataset") + "--dataset_path", type=str, default="dataset", help="" "path for dataset" +) data_arg.add_argument( - "--desc_path", type=str, default='desc', help="" - "path for descriptor(kpt) dir") + "--desc_path", type=str, default="desc", 
help="" "path for descriptor(kpt) dir" +) data_arg.add_argument( - "--num_kpt", type=int, default=1000, help="" - "number of kpt for training") + "--num_kpt", type=int, default=1000, help="" "number of kpt for training" +) data_arg.add_argument( - "--input_normalize", type=str, default='img', help="" - "normalize type for input kpt, img or intrinsic") + "--input_normalize", + type=str, + default="img", + help="" "normalize type for input kpt, img or intrinsic", +) data_arg.add_argument( - "--data_aug", type=str2bool, default=True, help="" - "apply kpt coordinate homography augmentation") + "--data_aug", + type=str2bool, + default=True, + help="" "apply kpt coordinate homography augmentation", +) data_arg.add_argument( - "--desc_suffix", type=str, default='suffix', help="" - "desc file suffix") + "--desc_suffix", type=str, default="suffix", help="" "desc file suffix" +) # ----------------------------------------------------------------------------- # Loss loss_arg = add_argument_group("loss") +loss_arg.add_argument("--momentum", type=float, default=0.9, help="" "momentum") loss_arg.add_argument( - "--momentum", type=float, default=0.9, help="" - "momentum") -loss_arg.add_argument( - "--seed_loss_weight", type=float, default=250, help="" - "confidence loss weight for sgm") + "--seed_loss_weight", + type=float, + default=250, + help="" "confidence loss weight for sgm", +) loss_arg.add_argument( - "--mid_loss_weight", type=float, default=1, help="" - "midseeding loss weight for sgm") + "--mid_loss_weight", type=float, default=1, help="" "midseeding loss weight for sgm" +) loss_arg.add_argument( - "--inlier_th", type=float, default=5e-3, help="" - "inlier threshold for epipolar distance (for sgm and visualization)") + "--inlier_th", + type=float, + default=5e-3, + help="" "inlier threshold for epipolar distance (for sgm and visualization)", +) # ----------------------------------------------------------------------------- # Training train_arg = add_argument_group("Train") +train_arg.add_argument("--train_lr", type=float, default=1e-4, help="" "learning rate") +train_arg.add_argument("--train_batch_size", type=int, default=16, help="" "batch size") train_arg.add_argument( - "--train_lr", type=float, default=1e-4, help="" - "learning rate") -train_arg.add_argument( - "--train_batch_size", type=int, default=16, help="" - "batch size") -train_arg.add_argument( - "--gpu_id", type=str,default='0', help='id(s) for CUDA_VISIBLE_DEVICES') -train_arg.add_argument( - "--train_iter", type=int, default=1000000, help="" - "training iterations to perform") -train_arg.add_argument( - "--log_base", type=str, default="./log/", help="" - "log path") + "--gpu_id", type=str, default="0", help="id(s) for CUDA_VISIBLE_DEVICES" +) train_arg.add_argument( - "--val_intv", type=int, default=20000, help="" - "validation interval") + "--train_iter", type=int, default=1000000, help="" "training iterations to perform" +) +train_arg.add_argument("--log_base", type=str, default="./log/", help="" "log path") train_arg.add_argument( - "--save_intv", type=int, default=1000, help="" - "summary interval") + "--val_intv", type=int, default=20000, help="" "validation interval" +) train_arg.add_argument( - "--log_intv", type=int, default=100, help="" - "log interval") + "--save_intv", type=int, default=1000, help="" "summary interval" +) +train_arg.add_argument("--log_intv", type=int, default=100, help="" "log interval") train_arg.add_argument( - "--decay_rate", type=float, default=0.999996, help="" - "lr decay rate") + 
"--decay_rate", type=float, default=0.999996, help="" "lr decay rate" +) train_arg.add_argument( - "--decay_iter", type=float, default=300000, help="" - "lr decay iter") + "--decay_iter", type=float, default=300000, help="" "lr decay iter" +) train_arg.add_argument( - "--local_rank", type=int, default=0, help="" - "local rank for ddp") + "--local_rank", type=int, default=0, help="" "local rank for ddp" +) train_arg.add_argument( - "--train_vis_folder", type=str, default='.', help="" - "visualization folder during training") + "--train_vis_folder", + type=str, + default=".", + help="" "visualization folder during training", +) # ----------------------------------------------------------------------------- # Visualization -vis_arg = add_argument_group('Visualization') +vis_arg = add_argument_group("Visualization") vis_arg.add_argument( - "--tqdm_width", type=int, default=79, help="" - "width of the tqdm bar" + "--tqdm_width", type=int, default=79, help="" "width of the tqdm bar" ) + def get_config(): config, unparsed = parser.parse_known_args() return config, unparsed @@ -122,5 +132,6 @@ def get_config(): def print_usage(): parser.print_usage() + # -# config.py ends here \ No newline at end of file +# config.py ends here diff --git a/third_party/SGMNet/train/dataset.py b/third_party/SGMNet/train/dataset.py index d07a84e9588b755a86119363f08860187d1668c0..37a97fd6204240e636d4b234f6c855f948c76b99 100644 --- a/third_party/SGMNet/train/dataset.py +++ b/third_party/SGMNet/train/dataset.py @@ -7,137 +7,278 @@ import h5py import random import sys + ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) sys.path.insert(0, ROOT_DIR) -from utils import train_utils,evaluation_utils +from utils import train_utils, evaluation_utils -torch.multiprocessing.set_sharing_strategy('file_system') +torch.multiprocessing.set_sharing_strategy("file_system") class Offline_Dataset(data.Dataset): - def __init__(self,config,mode): - assert mode=='train' or mode=='valid' + def __init__(self, config, mode): + assert mode == "train" or mode == "valid" self.config = config self.mode = mode - metadir=os.path.join(config.dataset_path,'valid') if mode=='valid' else os.path.join(config.dataset_path,'train') - - pair_num_list=np.loadtxt(os.path.join(metadir,'pair_num.txt'),dtype=str) - self.total_pairs=int(pair_num_list[0,1]) - self.pair_seq_list,self.accu_pair_num=train_utils.parse_pair_seq(pair_num_list) + metadir = ( + os.path.join(config.dataset_path, "valid") + if mode == "valid" + else os.path.join(config.dataset_path, "train") + ) + pair_num_list = np.loadtxt(os.path.join(metadir, "pair_num.txt"), dtype=str) + self.total_pairs = int(pair_num_list[0, 1]) + self.pair_seq_list, self.accu_pair_num = train_utils.parse_pair_seq( + pair_num_list + ) def collate_fn(self, batch): - batch_size, num_pts = len(batch), batch[0]['x1'].shape[0] - + batch_size, num_pts = len(batch), batch[0]["x1"].shape[0] + data = {} - dtype=['x1','x2','kpt1','kpt2','desc1','desc2','num_corr','num_incorr1','num_incorr2','e_gt','pscore1','pscore2','img_path1','img_path2'] + dtype = [ + "x1", + "x2", + "kpt1", + "kpt2", + "desc1", + "desc2", + "num_corr", + "num_incorr1", + "num_incorr2", + "e_gt", + "pscore1", + "pscore2", + "img_path1", + "img_path2", + ] for key in dtype: - data[key]=[] + data[key] = [] for sample in batch: for key in dtype: data[key].append(sample[key]) - - for key in ['x1', 'x2','kpt1','kpt2', 'desc1', 'desc2','e_gt','pscore1','pscore2']: + + for key in [ + "x1", + "x2", + "kpt1", + "kpt2", + "desc1", + "desc2", 
+ "e_gt", + "pscore1", + "pscore2", + ]: data[key] = torch.from_numpy(np.stack(data[key])).float() - for key in ['num_corr', 'num_incorr1', 'num_incorr2']: + for key in ["num_corr", "num_incorr1", "num_incorr2"]: data[key] = torch.from_numpy(np.stack(data[key])).int() # kpt augmentation with random homography - if (self.mode == 'train' and self.config.data_aug): - homo_mat = torch.from_numpy(train_utils.get_rnd_homography(batch_size)).unsqueeze(1) - aug_seed=random.random() - if aug_seed<0.5: - x1_homo = torch.cat([data['x1'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1) + if self.mode == "train" and self.config.data_aug: + homo_mat = torch.from_numpy( + train_utils.get_rnd_homography(batch_size) + ).unsqueeze(1) + aug_seed = random.random() + if aug_seed < 0.5: + x1_homo = torch.cat( + [data["x1"], torch.ones([batch_size, num_pts, 1])], dim=-1 + ).unsqueeze(-1) x1_homo = torch.matmul(homo_mat.float(), x1_homo.float()).squeeze(-1) - data['aug_x1'] = x1_homo[:, :, :2] / x1_homo[:, :, 2].unsqueeze(-1) - data['aug_x2']=data['x2'] + data["aug_x1"] = x1_homo[:, :, :2] / x1_homo[:, :, 2].unsqueeze(-1) + data["aug_x2"] = data["x2"] else: - x2_homo = torch.cat([data['x2'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1) + x2_homo = torch.cat( + [data["x2"], torch.ones([batch_size, num_pts, 1])], dim=-1 + ).unsqueeze(-1) x2_homo = torch.matmul(homo_mat.float(), x2_homo.float()).squeeze(-1) - data['aug_x2'] = x2_homo[:, :, :2] / x2_homo[:, :, 2].unsqueeze(-1) - data['aug_x1']=data['x1'] + data["aug_x2"] = x2_homo[:, :, :2] / x2_homo[:, :, 2].unsqueeze(-1) + data["aug_x1"] = data["x1"] else: - data['aug_x1'],data['aug_x2']=data['x1'],data['x2'] + data["aug_x1"], data["aug_x2"] = data["x1"], data["x2"] return data - def __getitem__(self, index): - seq=self.pair_seq_list[index] - index_within_seq=index-self.accu_pair_num[seq] + seq = self.pair_seq_list[index] + index_within_seq = index - self.accu_pair_num[seq] - with h5py.File(os.path.join(self.config.dataset_path,seq,'info.h5py'),'r') as data: - R,t = data['dR'][str(index_within_seq)][()], data['dt'][str(index_within_seq)][()] - egt = np.reshape(np.matmul(np.reshape(evaluation_utils.np_skew_symmetric(t.astype('float64').reshape(1, 3)), (3, 3)),np.reshape(R.astype('float64'), (3, 3))), (3, 3)) + with h5py.File( + os.path.join(self.config.dataset_path, seq, "info.h5py"), "r" + ) as data: + R, t = ( + data["dR"][str(index_within_seq)][()], + data["dt"][str(index_within_seq)][()], + ) + egt = np.reshape( + np.matmul( + np.reshape( + evaluation_utils.np_skew_symmetric( + t.astype("float64").reshape(1, 3) + ), + (3, 3), + ), + np.reshape(R.astype("float64"), (3, 3)), + ), + (3, 3), + ) egt = egt / np.linalg.norm(egt) - K1, K2 = data['K1'][str(index_within_seq)][()],data['K2'][str(index_within_seq)][()] - size1,size2=data['size1'][str(index_within_seq)][()],data['size2'][str(index_within_seq)][()] - - img_path1,img_path2=data['img_path1'][str(index_within_seq)][()][0].decode(),data['img_path2'][str(index_within_seq)][()][0].decode() - img_name1,img_name2=img_path1.split('/')[-1],img_path2.split('/')[-1] - img_path1,img_path2=os.path.join(self.config.rawdata_path,img_path1),os.path.join(self.config.rawdata_path,img_path2) - fea_path1,fea_path2=os.path.join(self.config.desc_path,seq,img_name1+self.config.desc_suffix),\ - os.path.join(self.config.desc_path,seq,img_name2+self.config.desc_suffix) - with h5py.File(fea_path1,'r') as fea1, h5py.File(fea_path2,'r') as fea2: - 
desc1,kpt1,pscore1=fea1['descriptors'][()],fea1['keypoints'][()][:,:2],fea1['keypoints'][()][:,2] - desc2,kpt2,pscore2=fea2['descriptors'][()],fea2['keypoints'][()][:,:2],fea2['keypoints'][()][:,2] - kpt1,kpt2,desc1,desc2=kpt1[:self.config.num_kpt],kpt2[:self.config.num_kpt],desc1[:self.config.num_kpt],desc2[:self.config.num_kpt] + K1, K2 = ( + data["K1"][str(index_within_seq)][()], + data["K2"][str(index_within_seq)][()], + ) + size1, size2 = ( + data["size1"][str(index_within_seq)][()], + data["size2"][str(index_within_seq)][()], + ) + + img_path1, img_path2 = ( + data["img_path1"][str(index_within_seq)][()][0].decode(), + data["img_path2"][str(index_within_seq)][()][0].decode(), + ) + img_name1, img_name2 = img_path1.split("/")[-1], img_path2.split("/")[-1] + img_path1, img_path2 = os.path.join( + self.config.rawdata_path, img_path1 + ), os.path.join(self.config.rawdata_path, img_path2) + fea_path1, fea_path2 = os.path.join( + self.config.desc_path, seq, img_name1 + self.config.desc_suffix + ), os.path.join( + self.config.desc_path, seq, img_name2 + self.config.desc_suffix + ) + with h5py.File(fea_path1, "r") as fea1, h5py.File(fea_path2, "r") as fea2: + desc1, kpt1, pscore1 = ( + fea1["descriptors"][()], + fea1["keypoints"][()][:, :2], + fea1["keypoints"][()][:, 2], + ) + desc2, kpt2, pscore2 = ( + fea2["descriptors"][()], + fea2["keypoints"][()][:, :2], + fea2["keypoints"][()][:, 2], + ) + kpt1, kpt2, desc1, desc2 = ( + kpt1[: self.config.num_kpt], + kpt2[: self.config.num_kpt], + desc1[: self.config.num_kpt], + desc2[: self.config.num_kpt], + ) # normalize kpt - if self.config.input_normalize=='intrinsic': - x1, x2 = np.concatenate([kpt1, np.ones([kpt1.shape[0], 1])], axis=-1), np.concatenate( - [kpt2, np.ones([kpt2.shape[0], 1])], axis=-1) - x1, x2 = np.matmul(np.linalg.inv(K1), x1.T).T[:, :2], np.matmul(np.linalg.inv(K2), x2.T).T[:, :2] - elif self.config.input_normalize=='img' : - x1,x2=(kpt1-size1/2)/size1,(kpt2-size2/2)/size2 - S1_inv,S2_inv=np.asarray([[size1[0],0,0.5*size1[0]],[0,size1[1],0.5*size1[1]],[0,0,1]]),\ - np.asarray([[size2[0],0,0.5*size2[0]],[0,size2[1],0.5*size2[1]],[0,0,1]]) - M1,M2=np.matmul(np.linalg.inv(K1),S1_inv),np.matmul(np.linalg.inv(K2),S2_inv) - egt=np.matmul(np.matmul(M2.transpose(),egt),M1) + if self.config.input_normalize == "intrinsic": + x1, x2 = np.concatenate( + [kpt1, np.ones([kpt1.shape[0], 1])], axis=-1 + ), np.concatenate([kpt2, np.ones([kpt2.shape[0], 1])], axis=-1) + x1, x2 = ( + np.matmul(np.linalg.inv(K1), x1.T).T[:, :2], + np.matmul(np.linalg.inv(K2), x2.T).T[:, :2], + ) + elif self.config.input_normalize == "img": + x1, x2 = (kpt1 - size1 / 2) / size1, (kpt2 - size2 / 2) / size2 + S1_inv, S2_inv = np.asarray( + [ + [size1[0], 0, 0.5 * size1[0]], + [0, size1[1], 0.5 * size1[1]], + [0, 0, 1], + ] + ), np.asarray( + [ + [size2[0], 0, 0.5 * size2[0]], + [0, size2[1], 0.5 * size2[1]], + [0, 0, 1], + ] + ) + M1, M2 = np.matmul(np.linalg.inv(K1), S1_inv), np.matmul( + np.linalg.inv(K2), S2_inv + ) + egt = np.matmul(np.matmul(M2.transpose(), egt), M1) egt = egt / np.linalg.norm(egt) else: raise NotImplementedError - corr=data['corr'][str(index_within_seq)][()] - incorr1,incorr2=data['incorr1'][str(index_within_seq)][()],data['incorr2'][str(index_within_seq)][()] - - #permute kpt - valid_corr=corr[corr.max(axis=-1)= cur_kpt1): - sub_idx1 =np.random.choice(len(invalid_index1), cur_kpt1,replace=False) - if (invalid_index2.shape[0] < cur_kpt2): - sub_idx2 = 
np.concatenate([np.arange(len(invalid_index2)),np.random.randint(len(invalid_index2),size=cur_kpt2-len(invalid_index2))]) - if (invalid_index2.shape[0] >= cur_kpt2): - sub_idx2 = np.random.choice(len(invalid_index2), cur_kpt2,replace=False) - - per_idx1,per_idx2=np.concatenate([valid_corr[:,0],valid_incorr1,invalid_index1[sub_idx1]]),\ - np.concatenate([valid_corr[:,1],valid_incorr2,invalid_index2[sub_idx2]]) - - pscore1,pscore2=pscore1[per_idx1][:,np.newaxis],pscore2[per_idx2][:,np.newaxis] - x1,x2=x1[per_idx1][:,:2],x2[per_idx2][:,:2] - desc1,desc2=desc1[per_idx1],desc2[per_idx2] - kpt1,kpt2=kpt1[per_idx1],kpt2[per_idx2] - - return {'x1': x1, 'x2': x2, 'kpt1':kpt1,'kpt2':kpt2,'desc1': desc1, 'desc2': desc2, 'num_corr': num_corr, 'num_incorr1': num_incorr1,'num_incorr2': num_incorr2,'e_gt':egt,\ - 'pscore1':pscore1,'pscore2':pscore2,'img_path1':img_path1,'img_path2':img_path2} + if invalid_index1.shape[0] < cur_kpt1: + sub_idx1 = np.concatenate( + [ + np.arange(len(invalid_index1)), + np.random.randint( + len(invalid_index1), size=cur_kpt1 - len(invalid_index1) + ), + ] + ) + if invalid_index1.shape[0] >= cur_kpt1: + sub_idx1 = np.random.choice(len(invalid_index1), cur_kpt1, replace=False) + if invalid_index2.shape[0] < cur_kpt2: + sub_idx2 = np.concatenate( + [ + np.arange(len(invalid_index2)), + np.random.randint( + len(invalid_index2), size=cur_kpt2 - len(invalid_index2) + ), + ] + ) + if invalid_index2.shape[0] >= cur_kpt2: + sub_idx2 = np.random.choice(len(invalid_index2), cur_kpt2, replace=False) - def __len__(self): - return self.total_pairs + per_idx1, per_idx2 = np.concatenate( + [valid_corr[:, 0], valid_incorr1, invalid_index1[sub_idx1]] + ), np.concatenate([valid_corr[:, 1], valid_incorr2, invalid_index2[sub_idx2]]) + + pscore1, pscore2 = ( + pscore1[per_idx1][:, np.newaxis], + pscore2[per_idx2][:, np.newaxis], + ) + x1, x2 = x1[per_idx1][:, :2], x2[per_idx2][:, :2] + desc1, desc2 = desc1[per_idx1], desc2[per_idx2] + kpt1, kpt2 = kpt1[per_idx1], kpt2[per_idx2] + return { + "x1": x1, + "x2": x2, + "kpt1": kpt1, + "kpt2": kpt2, + "desc1": desc1, + "desc2": desc2, + "num_corr": num_corr, + "num_incorr1": num_incorr1, + "num_incorr2": num_incorr2, + "e_gt": egt, + "pscore1": pscore1, + "pscore2": pscore2, + "img_path1": img_path1, + "img_path2": img_path2, + } + def __len__(self): + return self.total_pairs diff --git a/third_party/SGMNet/train/loss.py b/third_party/SGMNet/train/loss.py index fad4234fc5827321c31e72c08ad4a3466bad1c30..227f7c5d237be292e552a25ea899940ec54fc923 100644 --- a/third_party/SGMNet/train/loss.py +++ b/third_party/SGMNet/train/loss.py @@ -4,122 +4,195 @@ import numpy as np def batch_episym(x1, x2, F): batch_size, num_pts = x1.shape[0], x1.shape[1] - x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, num_pts,3,1) - x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, num_pts,3,1) - F = F.reshape(-1,1,3,3).repeat(1,num_pts,1,1) - x2Fx1 = torch.matmul(x2.transpose(2,3), torch.matmul(F, x1)).reshape(batch_size,num_pts) - Fx1 = torch.matmul(F,x1).reshape(batch_size,num_pts,3) - Ftx2 = torch.matmul(F.transpose(2,3),x2).reshape(batch_size,num_pts,3) - ys = (x2Fx1**2 * ( - 1.0 / (Fx1[:, :, 0]**2 + Fx1[:, :, 1]**2 + 1e-15) + - 1.0 / (Ftx2[:, :, 0]**2 + Ftx2[:, :, 1]**2 + 1e-15))).sqrt() + x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts, 1)], dim=-1).reshape( + batch_size, num_pts, 3, 1 + ) + x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts, 1)], dim=-1).reshape( + batch_size, num_pts, 3, 1 + ) + F 
= F.reshape(-1, 1, 3, 3).repeat(1, num_pts, 1, 1) + x2Fx1 = torch.matmul(x2.transpose(2, 3), torch.matmul(F, x1)).reshape( + batch_size, num_pts + ) + Fx1 = torch.matmul(F, x1).reshape(batch_size, num_pts, 3) + Ftx2 = torch.matmul(F.transpose(2, 3), x2).reshape(batch_size, num_pts, 3) + ys = ( + x2Fx1**2 + * ( + 1.0 / (Fx1[:, :, 0] ** 2 + Fx1[:, :, 1] ** 2 + 1e-15) + + 1.0 / (Ftx2[:, :, 0] ** 2 + Ftx2[:, :, 1] ** 2 + 1e-15) + ) + ).sqrt() return ys - - -def CELoss(seed_x1,seed_x2,e,confidence,inlier_th,batch_mask=1): - #seed_x: b*k*2 - ys=batch_episym(seed_x1,seed_x2,e) - mask_pos,mask_neg=(ys<=inlier_th).float(),(ys>inlier_th).float() - num_pos,num_neg=torch.relu(torch.sum(mask_pos, dim=1) - 1.0) + 1.0,torch.relu(torch.sum(mask_neg, dim=1) - 1.0) + 1.0 - loss_pos,loss_neg=-torch.log(abs(confidence) + 1e-8)*mask_pos,-torch.log(abs(1-confidence)+1e-8)*mask_neg - classif_loss = torch.mean(loss_pos * 0.5 / num_pos.unsqueeze(-1) + loss_neg * 0.5 / num_neg.unsqueeze(-1),dim=-1) - classif_loss =classif_loss*batch_mask - classif_loss=classif_loss.mean() + + +def CELoss(seed_x1, seed_x2, e, confidence, inlier_th, batch_mask=1): + # seed_x: b*k*2 + ys = batch_episym(seed_x1, seed_x2, e) + mask_pos, mask_neg = (ys <= inlier_th).float(), (ys > inlier_th).float() + num_pos, num_neg = ( + torch.relu(torch.sum(mask_pos, dim=1) - 1.0) + 1.0, + torch.relu(torch.sum(mask_neg, dim=1) - 1.0) + 1.0, + ) + loss_pos, loss_neg = ( + -torch.log(abs(confidence) + 1e-8) * mask_pos, + -torch.log(abs(1 - confidence) + 1e-8) * mask_neg, + ) + classif_loss = torch.mean( + loss_pos * 0.5 / num_pos.unsqueeze(-1) + loss_neg * 0.5 / num_neg.unsqueeze(-1), + dim=-1, + ) + classif_loss = classif_loss * batch_mask + classif_loss = classif_loss.mean() precision = torch.mean( - torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) / - (torch.sum((confidence > 0.5).type(confidence.type()), dim=1)+1e-8) + torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) + / (torch.sum((confidence > 0.5).type(confidence.type()), dim=1) + 1e-8) ) recall = torch.mean( - torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) / - num_pos + torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) + / num_pos ) - return classif_loss,precision,recall + return classif_loss, precision, recall -def CorrLoss(desc_mat,batch_num_corr,batch_num_incorr1,batch_num_incorr2): - total_loss_corr,total_loss_incorr=0,0 - total_acc_corr,total_acc_incorr=0,0 +def CorrLoss(desc_mat, batch_num_corr, batch_num_incorr1, batch_num_incorr2): + total_loss_corr, total_loss_incorr = 0, 0 + total_acc_corr, total_acc_incorr = 0, 0 batch_size = desc_mat.shape[0] - log_p=torch.log(abs(desc_mat)+1e-8) + log_p = torch.log(abs(desc_mat) + 1e-8) for i in range(batch_size): - cur_log_p=log_p[i] - num_corr=batch_num_corr[i] - num_incorr1,num_incorr2=batch_num_incorr1[i],batch_num_incorr2[i] - - #loss and acc + cur_log_p = log_p[i] + num_corr = batch_num_corr[i] + num_incorr1, num_incorr2 = batch_num_incorr1[i], batch_num_incorr2[i] + + # loss and acc loss_corr = -torch.diag(cur_log_p)[:num_corr].mean() - loss_incorr=(-cur_log_p[num_corr:num_corr+num_incorr1,-1].mean()-cur_log_p[-1,num_corr:num_corr+num_incorr2].mean())/2 + loss_incorr = ( + -cur_log_p[num_corr : num_corr + num_incorr1, -1].mean() + - cur_log_p[-1, num_corr : num_corr + num_incorr2].mean() + ) / 2 - value_row, row_index = torch.max(desc_mat[i,:-1,:-1], dim=-1) - value_col, col_index = torch.max(desc_mat[i,:-1,:-1], dim=-2) - 
acc_incorr=((value_row[num_corr:num_corr+num_incorr1]<0.2).float().mean()+ - (value_col[num_corr:num_corr+num_incorr2]<0.2).float().mean())/2 + value_row, row_index = torch.max(desc_mat[i, :-1, :-1], dim=-1) + value_col, col_index = torch.max(desc_mat[i, :-1, :-1], dim=-2) + acc_incorr = ( + (value_row[num_corr : num_corr + num_incorr1] < 0.2).float().mean() + + (value_col[num_corr : num_corr + num_incorr2] < 0.2).float().mean() + ) / 2 acc_row_mask = row_index[:num_corr] == torch.arange(num_corr).cuda() acc_col_mask = col_index[:num_corr] == torch.arange(num_corr).cuda() acc = (acc_col_mask & acc_row_mask).float().mean() - - total_loss_corr+=loss_corr - total_loss_incorr+=loss_incorr + + total_loss_corr += loss_corr + total_loss_incorr += loss_incorr total_acc_corr += acc - total_acc_incorr+=acc_incorr + total_acc_incorr += acc_incorr - total_acc_corr/=batch_size - total_acc_incorr/=batch_size - total_loss_corr/=batch_size - total_loss_incorr/=batch_size - return total_loss_corr,total_loss_incorr,total_acc_corr,total_acc_incorr + total_acc_corr /= batch_size + total_acc_incorr /= batch_size + total_loss_corr /= batch_size + total_loss_incorr /= batch_size + return total_loss_corr, total_loss_incorr, total_acc_corr, total_acc_incorr class SGMLoss: - def __init__(self,config,model_config): - self.config=config - self.model_config=model_config - - def run(self,data,result): - loss_corr,loss_incorr,acc_corr,acc_incorr=CorrLoss(result['p'],data['num_corr'],data['num_incorr1'],data['num_incorr2']) - loss_mid_corr_tower,loss_mid_incorr_tower,acc_mid_tower=[],[],[] - - #mid loss - for i in range(len(result['mid_p'])): - mid_p=result['mid_p'][i] - loss_mid_corr,loss_mid_incorr,mid_acc_corr,mid_acc_incorr=CorrLoss(mid_p,data['num_corr'],data['num_incorr1'],data['num_incorr2']) - loss_mid_corr_tower.append(loss_mid_corr),loss_mid_incorr_tower.append(loss_mid_incorr),acc_mid_tower.append(mid_acc_corr) - if len(result['mid_p']) != 0: - loss_mid_corr_tower,loss_mid_incorr_tower, acc_mid_tower = torch.stack(loss_mid_corr_tower), torch.stack(loss_mid_incorr_tower), torch.stack(acc_mid_tower) + def __init__(self, config, model_config): + self.config = config + self.model_config = model_config + + def run(self, data, result): + loss_corr, loss_incorr, acc_corr, acc_incorr = CorrLoss( + result["p"], data["num_corr"], data["num_incorr1"], data["num_incorr2"] + ) + loss_mid_corr_tower, loss_mid_incorr_tower, acc_mid_tower = [], [], [] + + # mid loss + for i in range(len(result["mid_p"])): + mid_p = result["mid_p"][i] + loss_mid_corr, loss_mid_incorr, mid_acc_corr, mid_acc_incorr = CorrLoss( + mid_p, data["num_corr"], data["num_incorr1"], data["num_incorr2"] + ) + loss_mid_corr_tower.append(loss_mid_corr), loss_mid_incorr_tower.append( + loss_mid_incorr + ), acc_mid_tower.append(mid_acc_corr) + if len(result["mid_p"]) != 0: + loss_mid_corr_tower, loss_mid_incorr_tower, acc_mid_tower = ( + torch.stack(loss_mid_corr_tower), + torch.stack(loss_mid_incorr_tower), + torch.stack(acc_mid_tower), + ) else: - loss_mid_corr_tower,loss_mid_incorr_tower, acc_mid_tower= torch.zeros(1).cuda(), torch.zeros(1).cuda(),torch.zeros(1).cuda() - - #seed confidence loss - classif_loss_tower,classif_precision_tower,classif_recall_tower=[],[],[] - for layer in range(len(result['seed_conf'])): - confidence=result['seed_conf'][layer] - seed_index=result['seed_index'][(np.asarray(self.model_config.seedlayer)<=layer).nonzero()[0][-1]] - seed_x1,seed_x2=data['x1'].gather(dim=1, index=seed_index[:,:,0,None].expand(-1, -1,2)),\ - 
data['x2'].gather(dim=1, index=seed_index[:,:,1,None].expand(-1, -1,2)) - classif_loss,classif_precision,classif_recall=CELoss(seed_x1,seed_x2,data['e_gt'],confidence,self.config.inlier_th) - classif_loss_tower.append(classif_loss), classif_precision_tower.append(classif_precision), classif_recall_tower.append(classif_recall) - classif_loss, classif_precision_tower, classif_recall_tower=torch.stack(classif_loss_tower).mean(),torch.stack(classif_precision_tower), \ - torch.stack(classif_recall_tower) - - - classif_loss*=self.config.seed_loss_weight - loss_mid_corr_tower*=self.config.mid_loss_weight - loss_mid_incorr_tower*=self.config.mid_loss_weight - total_loss=loss_corr+loss_incorr+classif_loss+loss_mid_corr_tower.sum()+loss_mid_incorr_tower.sum() - - return {'loss_corr':loss_corr,'loss_incorr':loss_incorr,'acc_corr':acc_corr,'acc_incorr':acc_incorr,'loss_seed_conf':classif_loss, - 'pre_seed_conf':classif_precision_tower,'recall_seed_conf':classif_recall_tower,'loss_corr_mid':loss_mid_corr_tower, - 'loss_incorr_mid':loss_mid_incorr_tower,'mid_acc_corr':acc_mid_tower,'total_loss':total_loss} - + loss_mid_corr_tower, loss_mid_incorr_tower, acc_mid_tower = ( + torch.zeros(1).cuda(), + torch.zeros(1).cuda(), + torch.zeros(1).cuda(), + ) + + # seed confidence loss + classif_loss_tower, classif_precision_tower, classif_recall_tower = [], [], [] + for layer in range(len(result["seed_conf"])): + confidence = result["seed_conf"][layer] + seed_index = result["seed_index"][ + (np.asarray(self.model_config.seedlayer) <= layer).nonzero()[0][-1] + ] + seed_x1, seed_x2 = data["x1"].gather( + dim=1, index=seed_index[:, :, 0, None].expand(-1, -1, 2) + ), data["x2"].gather( + dim=1, index=seed_index[:, :, 1, None].expand(-1, -1, 2) + ) + classif_loss, classif_precision, classif_recall = CELoss( + seed_x1, seed_x2, data["e_gt"], confidence, self.config.inlier_th + ) + classif_loss_tower.append(classif_loss), classif_precision_tower.append( + classif_precision + ), classif_recall_tower.append(classif_recall) + classif_loss, classif_precision_tower, classif_recall_tower = ( + torch.stack(classif_loss_tower).mean(), + torch.stack(classif_precision_tower), + torch.stack(classif_recall_tower), + ) + + classif_loss *= self.config.seed_loss_weight + loss_mid_corr_tower *= self.config.mid_loss_weight + loss_mid_incorr_tower *= self.config.mid_loss_weight + total_loss = ( + loss_corr + + loss_incorr + + classif_loss + + loss_mid_corr_tower.sum() + + loss_mid_incorr_tower.sum() + ) + + return { + "loss_corr": loss_corr, + "loss_incorr": loss_incorr, + "acc_corr": acc_corr, + "acc_incorr": acc_incorr, + "loss_seed_conf": classif_loss, + "pre_seed_conf": classif_precision_tower, + "recall_seed_conf": classif_recall_tower, + "loss_corr_mid": loss_mid_corr_tower, + "loss_incorr_mid": loss_mid_incorr_tower, + "mid_acc_corr": acc_mid_tower, + "total_loss": total_loss, + } + + class SGLoss: - def __init__(self,config,model_config): - self.config=config - self.model_config=model_config - - def run(self,data,result): - loss_corr,loss_incorr,acc_corr,acc_incorr=CorrLoss(result['p'],data['num_corr'],data['num_incorr1'],data['num_incorr2']) - total_loss=loss_corr+loss_incorr - return {'loss_corr':loss_corr,'loss_incorr':loss_incorr,'acc_corr':acc_corr,'acc_incorr':acc_incorr,'total_loss':total_loss} - \ No newline at end of file + def __init__(self, config, model_config): + self.config = config + self.model_config = model_config + + def run(self, data, result): + loss_corr, loss_incorr, acc_corr, acc_incorr = CorrLoss( + 
result["p"], data["num_corr"], data["num_incorr1"], data["num_incorr2"] + ) + total_loss = loss_corr + loss_incorr + return { + "loss_corr": loss_corr, + "loss_incorr": loss_incorr, + "acc_corr": acc_corr, + "acc_incorr": acc_incorr, + "total_loss": total_loss, + } diff --git a/third_party/SGMNet/train/main.py b/third_party/SGMNet/train/main.py index 9d4c8fff432a3b2d58c82b9e5f2897a4e702b2dd..00e1bf699a92057c445d4b5f83eb46794d6fb7f7 100644 --- a/third_party/SGMNet/train/main.py +++ b/third_party/SGMNet/train/main.py @@ -11,51 +11,72 @@ from train import train from config import get_config, print_usage -def main(config,model_config): +def main(config, model_config): """The main function.""" # Initialize network - if config.model_name=='SGM': + if config.model_name == "SGM": model = SGM_Model(model_config) - elif config.model_name=='SG': - model= SG_Model(model_config) + elif config.model_name == "SG": + model = SG_Model(model_config) else: raise NotImplementedError - #initialize ddp + # initialize ddp torch.cuda.set_device(config.local_rank) - device = torch.device(f'cuda:{config.local_rank}') + device = torch.device(f"cuda:{config.local_rank}") model.to(device) - dist.init_process_group(backend='nccl',init_method='env://') - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.local_rank]) - - if config.local_rank==0: - os.system('nvidia-smi') - - #initialize dataset - train_dataset = Offline_Dataset(config,'train') - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,shuffle=True) - train_loader=torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size//torch.distributed.get_world_size(), - num_workers=8//dist.get_world_size(), pin_memory=False,sampler=train_sampler,collate_fn=train_dataset.collate_fn) - - valid_dataset = Offline_Dataset(config,'valid') - valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,shuffle=False) - valid_loader=torch.utils.data.DataLoader(valid_dataset, batch_size=config.train_batch_size, - num_workers=8//dist.get_world_size(), pin_memory=False,collate_fn=valid_dataset.collate_fn,sampler=valid_sampler) - - if config.local_rank==0: - print('start training .....') - train(model,train_loader, valid_loader, config,model_config) + dist.init_process_group(backend="nccl", init_method="env://") + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[config.local_rank] + ) + + if config.local_rank == 0: + os.system("nvidia-smi") + + # initialize dataset + train_dataset = Offline_Dataset(config, "train") + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, shuffle=True + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=config.train_batch_size // torch.distributed.get_world_size(), + num_workers=8 // dist.get_world_size(), + pin_memory=False, + sampler=train_sampler, + collate_fn=train_dataset.collate_fn, + ) + + valid_dataset = Offline_Dataset(config, "valid") + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_dataset, shuffle=False + ) + valid_loader = torch.utils.data.DataLoader( + valid_dataset, + batch_size=config.train_batch_size, + num_workers=8 // dist.get_world_size(), + pin_memory=False, + collate_fn=valid_dataset.collate_fn, + sampler=valid_sampler, + ) + + if config.local_rank == 0: + print("start training .....") + train(model, train_loader, valid_loader, config, model_config) + if __name__ == "__main__": # ---------------------------------------- # Parse configuration 
config, unparsed = get_config() - with open(config.config_path, 'r') as f: + with open(config.config_path, "r") as f: model_config = yaml.load(f) - model_config=namedtuple('model_config',model_config.keys())(*model_config.values()) + model_config = namedtuple("model_config", model_config.keys())( + *model_config.values() + ) # If we have unparsed arguments, print usage and exit if len(unparsed) > 0: print_usage() exit(1) - main(config,model_config) + main(config, model_config) diff --git a/third_party/SGMNet/train/train.py b/third_party/SGMNet/train/train.py index 31e848e1d2e5f028d4ff3abaf0cc446be7d89c65..b012b7bf231de77972f443ab6979038151d2cfce 100644 --- a/third_party/SGMNet/train/train.py +++ b/third_party/SGMNet/train/train.py @@ -5,156 +5,226 @@ import os from tensorboardX import SummaryWriter import numpy as np import cv2 -from loss import SGMLoss,SGLoss -from valid import valid,dump_train_vis +from loss import SGMLoss, SGLoss +from valid import valid, dump_train_vis import sys + ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT_DIR) from utils import train_utils -def train_step(optimizer, model, match_loss, data,step,pre_avg_loss): - data['step']=step - result=model(data,test_mode=False) - loss_res=match_loss.run(data,result) - + +def train_step(optimizer, model, match_loss, data, step, pre_avg_loss): + data["step"] = step + result = model(data, test_mode=False) + loss_res = match_loss.run(data, result) + optimizer.zero_grad() - loss_res['total_loss'].backward() - #apply reduce on all record tensor + loss_res["total_loss"].backward() + # apply reduce on all record tensor for key in loss_res.keys(): - loss_res[key]=train_utils.reduce_tensor(loss_res[key],'mean') - - if loss_res['total_loss']<7*pre_avg_loss or step<200 or pre_avg_loss==0: + loss_res[key] = train_utils.reduce_tensor(loss_res[key], "mean") + + if loss_res["total_loss"] < 7 * pre_avg_loss or step < 200 or pre_avg_loss == 0: optimizer.step() - unusual_loss=False + unusual_loss = False else: optimizer.zero_grad() - unusual_loss=True - return loss_res,unusual_loss + unusual_loss = True + return loss_res, unusual_loss -def train(model, train_loader, valid_loader, config,model_config): +def train(model, train_loader, valid_loader, config, model_config): model.train() optimizer = optim.Adam(model.parameters(), lr=config.train_lr) - - if config.model_name=='SGM': - match_loss = SGMLoss(config,model_config) - elif config.model_name=='SG': - match_loss= SGLoss(config,model_config) + + if config.model_name == "SGM": + match_loss = SGMLoss(config, model_config) + elif config.model_name == "SG": + match_loss = SGLoss(config, model_config) else: raise NotImplementedError - - checkpoint_path = os.path.join(config.log_base, 'checkpoint.pth') + + checkpoint_path = os.path.join(config.log_base, "checkpoint.pth") config.resume = os.path.isfile(checkpoint_path) if config.resume: - if config.local_rank==0: - print('==> Resuming from checkpoint..') - checkpoint = torch.load(checkpoint_path,map_location='cuda:{}'.format(config.local_rank)) - model.load_state_dict(checkpoint['state_dict']) - best_acc = checkpoint['best_acc'] - start_step = checkpoint['step'] - optimizer.load_state_dict(checkpoint['optimizer']) + if config.local_rank == 0: + print("==> Resuming from checkpoint..") + checkpoint = torch.load( + checkpoint_path, map_location="cuda:{}".format(config.local_rank) + ) + model.load_state_dict(checkpoint["state_dict"]) + best_acc = checkpoint["best_acc"] + start_step = checkpoint["step"] + 
optimizer.load_state_dict(checkpoint["optimizer"]) else: best_acc = -1 start_step = 0 train_loader_iter = iter(train_loader) - - if config.local_rank==0: - writer=SummaryWriter(os.path.join(config.log_base,'log_file')) - - train_loader.sampler.set_epoch(start_step*config.train_batch_size//len(train_loader.dataset)) - pre_avg_loss=0 - - progress_bar=trange(start_step, config.train_iter,ncols=config.tqdm_width) if config.local_rank==0 else range(start_step, config.train_iter) + + if config.local_rank == 0: + writer = SummaryWriter(os.path.join(config.log_base, "log_file")) + + train_loader.sampler.set_epoch( + start_step * config.train_batch_size // len(train_loader.dataset) + ) + pre_avg_loss = 0 + + progress_bar = ( + trange(start_step, config.train_iter, ncols=config.tqdm_width) + if config.local_rank == 0 + else range(start_step, config.train_iter) + ) for step in progress_bar: try: train_data = next(train_loader_iter) except StopIteration: - if config.local_rank==0: - print('epoch: ',step*config.train_batch_size//len(train_loader.dataset)) - train_loader.sampler.set_epoch(step*config.train_batch_size//len(train_loader.dataset)) + if config.local_rank == 0: + print( + "epoch: ", + step * config.train_batch_size // len(train_loader.dataset), + ) + train_loader.sampler.set_epoch( + step * config.train_batch_size // len(train_loader.dataset) + ) train_loader_iter = iter(train_loader) train_data = next(train_loader_iter) - + train_data = train_utils.tocuda(train_data) - lr=min(config.train_lr*config.decay_rate**(step-config.decay_iter),config.train_lr) + lr = min( + config.train_lr * config.decay_rate ** (step - config.decay_iter), + config.train_lr, + ) for param_group in optimizer.param_groups: - param_group['lr'] = lr + param_group["lr"] = lr # run training - loss_res,unusual_loss = train_step(optimizer, model, match_loss, train_data,step-start_step,pre_avg_loss) - if (step-start_step)<=200: - pre_avg_loss=loss_res['total_loss'].data - if (step-start_step)>200 and not unusual_loss: - pre_avg_loss=pre_avg_loss.data*0.9+loss_res['total_loss'].data*0.1 - if unusual_loss and config.local_rank==0: - print('unusual loss! pre_avg_loss: ',pre_avg_loss,'cur_loss: ',loss_res['total_loss'].data) - #log - if config.local_rank==0 and step%config.log_intv==0 and not unusual_loss: - writer.add_scalar('TotalLoss',loss_res['total_loss'],step) - writer.add_scalar('CorrLoss',loss_res['loss_corr'],step) - writer.add_scalar('InCorrLoss', loss_res['loss_incorr'], step) - writer.add_scalar('dustbin', model.module.dustbin, step) - - if config.model_name=='SGM': - writer.add_scalar('SeedConfLoss', loss_res['loss_seed_conf'], step) - writer.add_scalar('MidCorrLoss', loss_res['loss_corr_mid'].sum(), step) - writer.add_scalar('MidInCorrLoss', loss_res['loss_incorr_mid'].sum(), step) - + loss_res, unusual_loss = train_step( + optimizer, model, match_loss, train_data, step - start_step, pre_avg_loss + ) + if (step - start_step) <= 200: + pre_avg_loss = loss_res["total_loss"].data + if (step - start_step) > 200 and not unusual_loss: + pre_avg_loss = pre_avg_loss.data * 0.9 + loss_res["total_loss"].data * 0.1 + if unusual_loss and config.local_rank == 0: + print( + "unusual loss! 
pre_avg_loss: ", + pre_avg_loss, + "cur_loss: ", + loss_res["total_loss"].data, + ) + # log + if config.local_rank == 0 and step % config.log_intv == 0 and not unusual_loss: + writer.add_scalar("TotalLoss", loss_res["total_loss"], step) + writer.add_scalar("CorrLoss", loss_res["loss_corr"], step) + writer.add_scalar("InCorrLoss", loss_res["loss_incorr"], step) + writer.add_scalar("dustbin", model.module.dustbin, step) + + if config.model_name == "SGM": + writer.add_scalar("SeedConfLoss", loss_res["loss_seed_conf"], step) + writer.add_scalar("MidCorrLoss", loss_res["loss_corr_mid"].sum(), step) + writer.add_scalar( + "MidInCorrLoss", loss_res["loss_incorr_mid"].sum(), step + ) # valid ans save b_save = ((step + 1) % config.save_intv) == 0 b_validate = ((step + 1) % config.val_intv) == 0 if b_validate: - total_loss,acc_corr,acc_incorr,seed_precision_tower,seed_recall_tower,acc_mid=valid(valid_loader, model, match_loss, config,model_config) - if config.local_rank==0: - writer.add_scalar('ValidAcc', acc_corr, step) - writer.add_scalar('ValidLoss', total_loss, step) - - if config.model_name=='SGM': + ( + total_loss, + acc_corr, + acc_incorr, + seed_precision_tower, + seed_recall_tower, + acc_mid, + ) = valid(valid_loader, model, match_loss, config, model_config) + if config.local_rank == 0: + writer.add_scalar("ValidAcc", acc_corr, step) + writer.add_scalar("ValidLoss", total_loss, step) + + if config.model_name == "SGM": for i in range(len(seed_recall_tower)): - writer.add_scalar('seed_conf_pre_%d'%i,seed_precision_tower[i],step) - writer.add_scalar('seed_conf_recall_%d' % i, seed_precision_tower[i], step) + writer.add_scalar( + "seed_conf_pre_%d" % i, seed_precision_tower[i], step + ) + writer.add_scalar( + "seed_conf_recall_%d" % i, seed_precision_tower[i], step + ) for i in range(len(acc_mid)): - writer.add_scalar('acc_mid%d'%i,acc_mid[i],step) - print('acc_corr: ',acc_corr.data,'acc_incorr: ',acc_incorr.data,'seed_conf_pre: ',seed_precision_tower.mean().data, - 'seed_conf_recall: ',seed_recall_tower.mean().data,'acc_mid: ',acc_mid.mean().data) + writer.add_scalar("acc_mid%d" % i, acc_mid[i], step) + print( + "acc_corr: ", + acc_corr.data, + "acc_incorr: ", + acc_incorr.data, + "seed_conf_pre: ", + seed_precision_tower.mean().data, + "seed_conf_recall: ", + seed_recall_tower.mean().data, + "acc_mid: ", + acc_mid.mean().data, + ) else: - print('acc_corr: ',acc_corr.data,'acc_incorr: ',acc_incorr.data) - - #saving best + print("acc_corr: ", acc_corr.data, "acc_incorr: ", acc_incorr.data) + + # saving best if acc_corr > best_acc: print("Saving best model with va_res = {}".format(acc_corr)) best_acc = acc_corr - save_dict={'step': step + 1, - 'state_dict': model.state_dict(), - 'best_acc': best_acc, - 'optimizer' : optimizer.state_dict()} + save_dict = { + "step": step + 1, + "state_dict": model.state_dict(), + "best_acc": best_acc, + "optimizer": optimizer.state_dict(), + } save_dict.update(save_dict) - torch.save(save_dict, os.path.join(config.log_base, 'model_best.pth')) + torch.save( + save_dict, os.path.join(config.log_base, "model_best.pth") + ) if b_save: - if config.local_rank==0: - save_dict={'step': step + 1, - 'state_dict': model.state_dict(), - 'best_acc': best_acc, - 'optimizer' : optimizer.state_dict()} + if config.local_rank == 0: + save_dict = { + "step": step + 1, + "state_dict": model.state_dict(), + "best_acc": best_acc, + "optimizer": optimizer.state_dict(), + } torch.save(save_dict, checkpoint_path) - - #draw match results + + # draw match results model.eval() with 
torch.no_grad(): - if config.local_rank==0: - if not os.path.exists(os.path.join(config.train_vis_folder,'train_vis')): - os.mkdir(os.path.join(config.train_vis_folder,'train_vis')) - if not os.path.exists(os.path.join(config.train_vis_folder,'train_vis',config.log_base)): - os.mkdir(os.path.join(config.train_vis_folder,'train_vis',config.log_base)) - os.mkdir(os.path.join(config.train_vis_folder,'train_vis',config.log_base,str(step))) - res=model(train_data) - dump_train_vis(res,train_data,step,config) + if config.local_rank == 0: + if not os.path.exists( + os.path.join(config.train_vis_folder, "train_vis") + ): + os.mkdir(os.path.join(config.train_vis_folder, "train_vis")) + if not os.path.exists( + os.path.join( + config.train_vis_folder, "train_vis", config.log_base + ) + ): + os.mkdir( + os.path.join( + config.train_vis_folder, "train_vis", config.log_base + ) + ) + os.mkdir( + os.path.join( + config.train_vis_folder, + "train_vis", + config.log_base, + str(step), + ) + ) + res = model(train_data) + dump_train_vis(res, train_data, step, config) model.train() - - if config.local_rank==0: + + if config.local_rank == 0: writer.close() diff --git a/third_party/SGMNet/train/valid.py b/third_party/SGMNet/train/valid.py index 443694d85104730cd50aeb342326ce593dc5684d..b9873f9b34ff77462d87aaad8c128e3b497fa39a 100644 --- a/third_party/SGMNet/train/valid.py +++ b/third_party/SGMNet/train/valid.py @@ -6,72 +6,119 @@ from loss import batch_episym from tqdm import tqdm import sys + ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT_DIR) -from utils import evaluation_utils,train_utils +from utils import evaluation_utils, train_utils -def valid(valid_loader, model,match_loss, config,model_config): +def valid(valid_loader, model, match_loss, config, model_config): model.eval() loader_iter = iter(valid_loader) num_pair = 0 - total_loss,total_acc_corr,total_acc_incorr=0,0,0 - total_precision,total_recall=torch.zeros(model_config.layer_num ,device='cuda'),\ - torch.zeros(model_config.layer_num ,device='cuda') - total_acc_mid=torch.zeros(len(model_config.seedlayer)-1,device='cuda') + total_loss, total_acc_corr, total_acc_incorr = 0, 0, 0 + total_precision, total_recall = torch.zeros( + model_config.layer_num, device="cuda" + ), torch.zeros(model_config.layer_num, device="cuda") + total_acc_mid = torch.zeros(len(model_config.seedlayer) - 1, device="cuda") with torch.no_grad(): - if config.local_rank==0: - loader_iter=tqdm(loader_iter) - print('validating...') + if config.local_rank == 0: + loader_iter = tqdm(loader_iter) + print("validating...") for test_data in loader_iter: - num_pair+= 1 + num_pair += 1 test_data = train_utils.tocuda(test_data) - res= model(test_data) - loss_res=match_loss.run(test_data,res) - - total_acc_corr+=loss_res['acc_corr'] - total_acc_incorr+=loss_res['acc_incorr'] - total_loss+=loss_res['total_loss'] + res = model(test_data) + loss_res = match_loss.run(test_data, res) + + total_acc_corr += loss_res["acc_corr"] + total_acc_incorr += loss_res["acc_incorr"] + total_loss += loss_res["total_loss"] - if config.model_name=='SGM': - total_acc_mid+=loss_res['mid_acc_corr'] - total_precision,total_recall=total_precision+loss_res['pre_seed_conf'],total_recall+loss_res['recall_seed_conf'] - - total_acc_corr/=num_pair + if config.model_name == "SGM": + total_acc_mid += loss_res["mid_acc_corr"] + total_precision, total_recall = ( + total_precision + loss_res["pre_seed_conf"], + total_recall + loss_res["recall_seed_conf"], + ) + + total_acc_corr 
/= num_pair total_acc_incorr /= num_pair - total_precision/=num_pair - total_recall/=num_pair - total_acc_mid/=num_pair + total_precision /= num_pair + total_recall /= num_pair + total_acc_mid /= num_pair - #apply tensor reduction - total_loss,total_acc_corr,total_acc_incorr,total_precision,total_recall,total_acc_mid=train_utils.reduce_tensor(total_loss,'sum'),\ - train_utils.reduce_tensor(total_acc_corr,'mean'),train_utils.reduce_tensor(total_acc_incorr,'mean'),\ - train_utils.reduce_tensor(total_precision,'mean'),train_utils.reduce_tensor(total_recall,'mean'),train_utils.reduce_tensor(total_acc_mid,'mean') + # apply tensor reduction + ( + total_loss, + total_acc_corr, + total_acc_incorr, + total_precision, + total_recall, + total_acc_mid, + ) = ( + train_utils.reduce_tensor(total_loss, "sum"), + train_utils.reduce_tensor(total_acc_corr, "mean"), + train_utils.reduce_tensor(total_acc_incorr, "mean"), + train_utils.reduce_tensor(total_precision, "mean"), + train_utils.reduce_tensor(total_recall, "mean"), + train_utils.reduce_tensor(total_acc_mid, "mean"), + ) model.train() - return total_loss,total_acc_corr,total_acc_incorr,total_precision,total_recall,total_acc_mid - + return ( + total_loss, + total_acc_corr, + total_acc_incorr, + total_precision, + total_recall, + total_acc_mid, + ) -def dump_train_vis(res,data,step,config): - #batch matching - p=res['p'][:,:-1,:-1] - score,index1=torch.max(p,dim=-1) - _,index2=torch.max(p,dim=-2) - mask_th=score>0.2 - mask_mc=index2.gather(index=index1,dim=1) == torch.arange(len(p[0])).cuda()[None] - mask_p=mask_th&mask_mc#B*N +def dump_train_vis(res, data, step, config): + # batch matching + p = res["p"][:, :-1, :-1] + score, index1 = torch.max(p, dim=-1) + _, index2 = torch.max(p, dim=-2) + mask_th = score > 0.2 + mask_mc = index2.gather(index=index1, dim=1) == torch.arange(len(p[0])).cuda()[None] + mask_p = mask_th & mask_mc # B*N - corr1,corr2=data['x1'],data['x2'].gather(index=index1[:,:,None].expand(-1,-1,2),dim=1) - corr1_kpt,corr2_kpt=data['kpt1'],data['kpt2'].gather(index=index1[:,:,None].expand(-1,-1,2),dim=1) - epi_dis=batch_episym(corr1,corr2,data['e_gt']) - mask_inlier=epi_dis0,i0,j 0, - depth_top_right > 0 - ), - np.logical_and( - depth_down_left > 0, - depth_down_left > 0 - ) - ) - ids=ids[valid_depth] - depth_top_left,depth_top_right,depth_down_left,depth_down_right=depth_top_left[valid_depth],depth_top_right[valid_depth],\ - depth_down_left[valid_depth],depth_down_right[valid_depth] - - i,j,i_top_left,j_top_left=i[valid_depth],j[valid_depth],i_top_left[valid_depth],j_top_left[valid_depth] - - # Interpolation - dist_i_top_left = i - i_top_left.astype(np.float32) - dist_j_top_left = j - j_top_left.astype(np.float32) - w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) - w_top_right = (1 - dist_i_top_left) * dist_j_top_left - w_bottom_left = dist_i_top_left * (1 - dist_j_top_left) - w_bottom_right = dist_i_top_left * dist_j_top_left - - interpolated_depth = ( - w_top_left * depth_top_left + - w_top_right * depth_top_right+ - w_bottom_left * depth_down_left + - w_bottom_right * depth_down_right - ) - return [interpolated_depth, ids] - -def reprojection(depth_map,kpt,dR,dt,K1_img2depth,K1,K2): - #warp kpt from img1 to img2 +def interpolate_depth(pos, depth): + # pos:[y,x] + ids = np.array(range(0, pos.shape[0])) + + h, w = depth.shape + + i = pos[:, 0] + j = pos[:, 1] + valid_corner = np.logical_and( + np.logical_and(i > 0, i < h - 1), np.logical_and(j > 0, j < w - 1) + ) + i, j = i[valid_corner], j[valid_corner] + ids = 
ids[valid_corner] + + i_top_left = np.floor(i).astype(np.int32) + j_top_left = np.floor(j).astype(np.int32) + + i_top_right = np.floor(i).astype(np.int32) + j_top_right = np.ceil(j).astype(np.int32) + + i_bottom_left = np.ceil(i).astype(np.int32) + j_bottom_left = np.floor(j).astype(np.int32) + + i_bottom_right = np.ceil(i).astype(np.int32) + j_bottom_right = np.ceil(j).astype(np.int32) + + # Valid depth + depth_top_left, depth_top_right, depth_down_left, depth_down_right = ( + depth[i_top_left, j_top_left], + depth[i_top_right, j_top_right], + depth[i_bottom_left, j_bottom_left], + depth[i_bottom_right, j_bottom_right], + ) + + valid_depth = np.logical_and( + np.logical_and(depth_top_left > 0, depth_top_right > 0), + np.logical_and(depth_down_left > 0, depth_down_left > 0), + ) + ids = ids[valid_depth] + depth_top_left, depth_top_right, depth_down_left, depth_down_right = ( + depth_top_left[valid_depth], + depth_top_right[valid_depth], + depth_down_left[valid_depth], + depth_down_right[valid_depth], + ) + + i, j, i_top_left, j_top_left = ( + i[valid_depth], + j[valid_depth], + i_top_left[valid_depth], + j_top_left[valid_depth], + ) + + # Interpolation + dist_i_top_left = i - i_top_left.astype(np.float32) + dist_j_top_left = j - j_top_left.astype(np.float32) + w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) + w_top_right = (1 - dist_i_top_left) * dist_j_top_left + w_bottom_left = dist_i_top_left * (1 - dist_j_top_left) + w_bottom_right = dist_i_top_left * dist_j_top_left + + interpolated_depth = ( + w_top_left * depth_top_left + + w_top_right * depth_top_right + + w_bottom_left * depth_down_left + + w_bottom_right * depth_down_right + ) + return [interpolated_depth, ids] + + +def reprojection(depth_map, kpt, dR, dt, K1_img2depth, K1, K2): + # warp kpt from img1 to img2 def swap_axis(data): return np.stack([data[:, 1], data[:, 0]], axis=-1) - kp_depth = unnorm_kp(K1_img2depth,kpt) + kp_depth = unnorm_kp(K1_img2depth, kpt) uv_depth = swap_axis(kp_depth) - z,valid_idx = interpolate_depth(uv_depth, depth_map) + z, valid_idx = interpolate_depth(uv_depth, depth_map) - norm_kp=norm_kpt(K1,kpt) - norm_kp_valid = np.concatenate([norm_kp[valid_idx, :], np.ones((len(valid_idx), 1))], axis=-1) + norm_kp = norm_kpt(K1, kpt) + norm_kp_valid = np.concatenate( + [norm_kp[valid_idx, :], np.ones((len(valid_idx), 1))], axis=-1 + ) xyz_valid = norm_kp_valid * z.reshape(-1, 1) xyz2 = np.matmul(xyz_valid, dR.T) + dt.reshape(1, 3) xy2 = xyz2[:, :2] / xyz2[:, 2:] kp2, valid = np.ones(kpt.shape) * 1e5, np.zeros(kpt.shape[0]) - kp2[valid_idx] = unnorm_kp(K2,xy2) + kp2[valid_idx] = unnorm_kp(K2, xy2) valid[valid_idx] = 1 return kp2, valid.astype(bool) -def reprojection_2s(kp1, kp2,depth1, depth2, K1, K2, dR, dt, size1,size2): - #size:H*W - depth_size1,depth_size2 = [depth1.shape[0], depth1.shape[1]], [depth2.shape[0], depth2.shape[1]] - scale_1= [float(depth_size1[0]) / size1[0], float(depth_size1[1]) / size1[1], 1] - scale_2= [float(depth_size2[0]) / size2[0], float(depth_size2[1]) / size2[1], 1] - K1_img2depth, K2_img2depth = np.diag(np.asarray(scale_1)), np.diag(np.asarray(scale_2)) - kp1_2_proj, valid1_2 = reprojection(depth1, kp1, dR, dt, K1_img2depth,K1,K2) - kp2_1_proj, valid2_1 = reprojection(depth2, kp2, dR.T, -np.matmul(dR.T, dt), K2_img2depth,K2,K1) - return [kp1_2_proj,kp2_1_proj],[valid1_2,valid2_1] - -def make_corr(kp1,kp2,desc1,desc2,depth1,depth2,K1,K2,dR,dt,size1,size2,corr_th,incorr_th,check_desc=False): - #make reprojection - 
[kp1_2,kp2_1],[valid1_2,valid2_1]=reprojection_2s(kp1,kp2,depth1,depth2,K1,K2,dR,dt,size1,size2) + +def reprojection_2s(kp1, kp2, depth1, depth2, K1, K2, dR, dt, size1, size2): + # size:H*W + depth_size1, depth_size2 = [depth1.shape[0], depth1.shape[1]], [ + depth2.shape[0], + depth2.shape[1], + ] + scale_1 = [float(depth_size1[0]) / size1[0], float(depth_size1[1]) / size1[1], 1] + scale_2 = [float(depth_size2[0]) / size2[0], float(depth_size2[1]) / size2[1], 1] + K1_img2depth, K2_img2depth = np.diag(np.asarray(scale_1)), np.diag( + np.asarray(scale_2) + ) + kp1_2_proj, valid1_2 = reprojection(depth1, kp1, dR, dt, K1_img2depth, K1, K2) + kp2_1_proj, valid2_1 = reprojection( + depth2, kp2, dR.T, -np.matmul(dR.T, dt), K2_img2depth, K2, K1 + ) + return [kp1_2_proj, kp2_1_proj], [valid1_2, valid2_1] + + +def make_corr( + kp1, + kp2, + desc1, + desc2, + depth1, + depth2, + K1, + K2, + dR, + dt, + size1, + size2, + corr_th, + incorr_th, + check_desc=False, +): + # make reprojection + [kp1_2, kp2_1], [valid1_2, valid2_1] = reprojection_2s( + kp1, kp2, depth1, depth2, K1, K2, dR, dt, size1, size2 + ) num_pts1, num_pts2 = kp1.shape[0], kp2.shape[0] - #reprojection error - dis_mat1=np.sqrt(abs((kp1 ** 2).sum(1,keepdims=True) + (kp2_1 ** 2).sum(1,keepdims=False)[np.newaxis] - 2 * np.matmul(kp1, kp2_1.T))) - dis_mat2 =np.sqrt(abs((kp2 ** 2).sum(1,keepdims=True) + (kp1_2 ** 2).sum(1,keepdims=False)[np.newaxis] - 2 * np.matmul(kp2,kp1_2.T))) - repro_error = np.maximum(dis_mat1,dis_mat2.T) #n1*n2 - + # reprojection error + dis_mat1 = np.sqrt( + abs( + (kp1**2).sum(1, keepdims=True) + + (kp2_1**2).sum(1, keepdims=False)[np.newaxis] + - 2 * np.matmul(kp1, kp2_1.T) + ) + ) + dis_mat2 = np.sqrt( + abs( + (kp2**2).sum(1, keepdims=True) + + (kp1_2**2).sum(1, keepdims=False)[np.newaxis] + - 2 * np.matmul(kp2, kp1_2.T) + ) + ) + repro_error = np.maximum(dis_mat1, dis_mat2.T) # n1*n2 + # find corr index nn_sort1 = np.argmin(repro_error, axis=1) nn_sort2 = np.argmin(repro_error, axis=0) mask_mutual = nn_sort2[nn_sort1] == np.arange(kp1.shape[0]) - mask_inlier=np.take_along_axis(repro_error,indices=nn_sort1[:,np.newaxis],axis=-1).squeeze(1)1,mask_samepos2.sum(-1)>1) - duplicated_index=np.nonzero(duplicated_mask)[0] + mask_samepos1 = np.logical_and( + x1_valid[:, 0, np.newaxis] == kp1[np.newaxis, :, 0], + x1_valid[:, 1, np.newaxis] == kp1[np.newaxis, :, 1], + ) + mask_samepos2 = np.logical_and( + x2_valid[:, 0, np.newaxis] == kp2[np.newaxis, :, 0], + x2_valid[:, 1, np.newaxis] == kp2[np.newaxis, :, 1], + ) + duplicated_mask = np.logical_or( + mask_samepos1.sum(-1) > 1, mask_samepos2.sum(-1) > 1 + ) + duplicated_index = np.nonzero(duplicated_mask)[0] - unique_corr_index=corr_index[~duplicated_mask] - clean_duplicated_corr=[] + unique_corr_index = corr_index[~duplicated_mask] + clean_duplicated_corr = [] for index in duplicated_index: - cur_desc1, cur_desc2 = desc1[mask_samepos1[index]], desc2[mask_samepos2[index]] + cur_desc1, cur_desc2 = ( + desc1[mask_samepos1[index]], + desc2[mask_samepos2[index]], + ) cur_desc_mat = np.matmul(cur_desc1, cur_desc2.T) - cur_max_index =[np.argmax(cur_desc_mat)//cur_desc_mat.shape[1],np.argmax(cur_desc_mat)%cur_desc_mat.shape[1]] - clean_duplicated_corr.append(np.stack([np.arange(num_pts1)[mask_samepos1[index]][cur_max_index[0]], - np.arange(num_pts2)[mask_samepos2[index]][cur_max_index[1]]])) - - clean_corr_index=unique_corr_index - if len(clean_duplicated_corr)!=0: - clean_duplicated_corr=np.stack(clean_duplicated_corr,axis=0) - 
clean_corr_index=np.concatenate([clean_corr_index,clean_duplicated_corr],axis=0) + cur_max_index = [ + np.argmax(cur_desc_mat) // cur_desc_mat.shape[1], + np.argmax(cur_desc_mat) % cur_desc_mat.shape[1], + ] + clean_duplicated_corr.append( + np.stack( + [ + np.arange(num_pts1)[mask_samepos1[index]][cur_max_index[0]], + np.arange(num_pts2)[mask_samepos2[index]][cur_max_index[1]], + ] + ) + ) + + clean_corr_index = unique_corr_index + if len(clean_duplicated_corr) != 0: + clean_duplicated_corr = np.stack(clean_duplicated_corr, axis=0) + clean_corr_index = np.concatenate( + [clean_corr_index, clean_duplicated_corr], axis=0 + ) else: - clean_corr_index=corr_index + clean_corr_index = corr_index # find incorr mask_incorr1 = np.min(dis_mat2.T[valid1_2], axis=-1) > incorr_th mask_incorr2 = np.min(dis_mat1.T[valid2_1], axis=-1) > incorr_th - incorr_index1, incorr_index2 = np.arange(num_pts1)[valid1_2][mask_incorr1.squeeze()], \ - np.arange(num_pts2)[valid2_1][mask_incorr2.squeeze()] - - return clean_corr_index,incorr_index1,incorr_index2 + incorr_index1, incorr_index2 = ( + np.arange(num_pts1)[valid1_2][mask_incorr1.squeeze()], + np.arange(num_pts2)[valid2_1][mask_incorr2.squeeze()], + ) + return clean_corr_index, incorr_index1, incorr_index2 diff --git a/third_party/SGMNet/utils/evaluation_utils.py b/third_party/SGMNet/utils/evaluation_utils.py index 82c4715a192d3c361c849896b035cd91ee56dc42..a65a3075791857f586cc4f537dcb67eecc3ef681 100644 --- a/third_party/SGMNet/utils/evaluation_utils.py +++ b/third_party/SGMNet/utils/evaluation_utils.py @@ -2,57 +2,110 @@ import numpy as np import h5py import cv2 -def normalize_intrinsic(x,K): - #print(x,K) - return (x-K[:2,2])/np.diag(K)[:2] -def normalize_size(x,size,scale=1): - size=size.reshape([1,2]) - norm_fac=size.max() - return (x-size/2+0.5)/(norm_fac*scale) +def normalize_intrinsic(x, K): + # print(x,K) + return (x - K[:2, 2]) / np.diag(K)[:2] + + +def normalize_size(x, size, scale=1): + size = size.reshape([1, 2]) + norm_fac = size.max() + return (x - size / 2 + 0.5) / (norm_fac * scale) + def np_skew_symmetric(v): zero = np.zeros_like(v[:, 0]) - M = np.stack([ - zero, -v[:, 2], v[:, 1], - v[:, 2], zero, -v[:, 0], - -v[:, 1], v[:, 0], zero, - ], axis=1) + M = np.stack( + [ + zero, + -v[:, 2], + v[:, 1], + v[:, 2], + zero, + -v[:, 0], + -v[:, 1], + v[:, 0], + zero, + ], + axis=1, + ) return M -def draw_points(img,points,color=(0,255,0),radius=3): + +def draw_points(img, points, color=(0, 255, 0), radius=3): dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])] for i in range(points.shape[0]): - cv2.circle(img, dp[i],radius=radius,color=color) + cv2.circle(img, dp[i], radius=radius, color=color) return img - -def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None): + +def draw_match( + img1, + img2, + corr1, + corr2, + inlier=[True], + color=None, + radius1=1, + radius2=1, + resize=None, +): if resize is not None: - scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]] - img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA) - corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis] - corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])] - corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])] + scale1, scale2 = [img1.shape[1] / resize[0], img1.shape[0] / 
resize[1]], [ + img2.shape[1] / resize[0], + img2.shape[0] / resize[1], + ] + img1, img2 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA), cv2.resize( + img2, resize, interpolation=cv2.INTER_AREA + ) + corr1, corr2 = ( + corr1 / np.asarray(scale1)[np.newaxis], + corr2 / np.asarray(scale2)[np.newaxis], + ) + corr1_key = [ + cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0]) + ] + corr2_key = [ + cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0]) + ] assert len(corr1) == len(corr2) draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))] if color is None: - color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier] - if len(color)==1: - display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None, - matchColor=color[0], - singlePointColor=color[0], - flags=4 - ) + color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier] + if len(color) == 1: + display = cv2.drawMatches( + img1, + corr1_key, + img2, + corr2_key, + draw_matches, + None, + matchColor=color[0], + singlePointColor=color[0], + flags=4, + ) else: - height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1] - display=np.zeros([height,width,3],np.uint8) - display[:img1.shape[0],:img1.shape[1]]=img1 - display[:img2.shape[0],img1.shape[1]:]=img2 + height, width = max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1] + display = np.zeros([height, width, 3], np.uint8) + display[: img1.shape[0], : img1.shape[1]] = img1 + display[: img2.shape[0], img1.shape[1] :] = img2 for i in range(len(corr1)): - left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1]) - cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2])) - cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA) - return display \ No newline at end of file + left_x, left_y, right_x, right_y = ( + int(corr1[i][0]), + int(corr1[i][1]), + int(corr2[i][0] + img1.shape[1]), + int(corr2[i][1]), + ) + cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2])) + cv2.line( + display, + (left_x, left_y), + (right_x, right_y), + cur_color, + 1, + lineType=cv2.LINE_AA, + ) + return display diff --git a/third_party/SGMNet/utils/fm_utils.py b/third_party/SGMNet/utils/fm_utils.py index f9cbbeefe5d6b59c1ae1fa26cdaa42146ad22a74..900b73c42723cd9c5bcbef5c758deadcd0b309df 100644 --- a/third_party/SGMNet/utils/fm_utils.py +++ b/third_party/SGMNet/utils/fm_utils.py @@ -1,95 +1,100 @@ import numpy as np -def line_to_border(line,size): - #line:(a,b,c), ax+by+c=0 - #size:(W,H) - H,W=size[1],size[0] - a,b,c=line[0],line[1],line[2] - epsa=1e-8 if a>=0 else -1e-8 - epsb=1e-8 if b>=0 else -1e-8 - intersection_list=[] - - y_left=-c/(b+epsb) - y_right=(-c-a*(W-1))/(b+epsb) - x_top=-c/(a+epsa) - x_down=(-c-b*(H-1))/(a+epsa) - - if y_left>=0 and y_left<=H-1: - intersection_list.append([0,y_left]) - if y_right>=0 and y_right<=H-1: - intersection_list.append([W-1,y_right]) - if x_top>=0 and x_top<=W-1: - intersection_list.append([x_top,0]) - if x_down>=0 and x_down<=W-1: - intersection_list.append([x_down,H-1]) - if len(intersection_list)!=2: +def line_to_border(line, size): + # line:(a,b,c), ax+by+c=0 + # size:(W,H) + H, W = size[1], size[0] + a, b, c = line[0], line[1], line[2] + epsa = 1e-8 if a >= 0 else -1e-8 + epsb = 1e-8 if b >= 0 else -1e-8 + intersection_list = [] + + y_left = -c / (b + epsb) + y_right = (-c - a * (W - 1)) / (b + epsb) + x_top = -c / (a + 
epsa) + x_down = (-c - b * (H - 1)) / (a + epsa) + + if y_left >= 0 and y_left <= H - 1: + intersection_list.append([0, y_left]) + if y_right >= 0 and y_right <= H - 1: + intersection_list.append([W - 1, y_right]) + if x_top >= 0 and x_top <= W - 1: + intersection_list.append([x_top, 0]) + if x_down >= 0 and x_down <= W - 1: + intersection_list.append([x_down, H - 1]) + if len(intersection_list) != 2: return None - intersection_list=np.asarray(intersection_list) + intersection_list = np.asarray(intersection_list) return intersection_list + def find_point_in_line(end_point): - x_span,y_span=end_point[1,0]-end_point[0,0],end_point[1,1]-end_point[0,1] - mv=np.random.uniform() - point=np.asarray([end_point[0,0]+x_span*mv,end_point[0,1]+y_span*mv]) + x_span, y_span = ( + end_point[1, 0] - end_point[0, 0], + end_point[1, 1] - end_point[0, 1], + ) + mv = np.random.uniform() + point = np.asarray([end_point[0, 0] + x_span * mv, end_point[0, 1] + y_span * mv]) return point -def epi_line(point,F): - homo=np.concatenate([point,np.ones([len(point),1])],axis=-1) - epi=np.matmul(homo,F.T) + +def epi_line(point, F): + homo = np.concatenate([point, np.ones([len(point), 1])], axis=-1) + epi = np.matmul(homo, F.T) return epi -def dis_point_to_line(line,point): - homo=np.concatenate([point,np.ones([len(point),1])],axis=-1) - dis=line*homo - dis=dis.sum(axis=-1)/(np.linalg.norm(line[:,:2],axis=-1)+1e-8) + +def dis_point_to_line(line, point): + homo = np.concatenate([point, np.ones([len(point), 1])], axis=-1) + dis = line * homo + dis = dis.sum(axis=-1) / (np.linalg.norm(line[:, :2], axis=-1) + 1e-8) return abs(dis) -def SGD_oneiter(F1,F2,size1,size2): - H1,W1=size1[1],size1[0] + +def SGD_oneiter(F1, F2, size1, size2): + H1, W1 = size1[1], size1[0] factor1 = 1 / np.linalg.norm(size1) factor2 = 1 / np.linalg.norm(size2) - p0=np.asarray([(W1-1)*np.random.uniform(),(H1-1)*np.random.uniform()]) - epi1=epi_line(p0[np.newaxis],F1)[0] - border_point1=line_to_border(epi1,size2) + p0 = np.asarray([(W1 - 1) * np.random.uniform(), (H1 - 1) * np.random.uniform()]) + epi1 = epi_line(p0[np.newaxis], F1)[0] + border_point1 = line_to_border(epi1, size2) if border_point1 is None: return -1 - - p1=find_point_in_line(border_point1) - epi2=epi_line(p0[np.newaxis],F2) - d1=dis_point_to_line(epi2,p1[np.newaxis])[0]*factor2 - epi3=epi_line(p1[np.newaxis],F2.T) - d2=dis_point_to_line(epi3,p0[np.newaxis])[0]*factor1 - return (d1+d2)/2 - -def compute_SGD(F1,F2,size1,size2): + + p1 = find_point_in_line(border_point1) + epi2 = epi_line(p0[np.newaxis], F2) + d1 = dis_point_to_line(epi2, p1[np.newaxis])[0] * factor2 + epi3 = epi_line(p1[np.newaxis], F2.T) + d2 = dis_point_to_line(epi3, p0[np.newaxis])[0] * factor1 + return (d1 + d2) / 2 + + +def compute_SGD(F1, F2, size1, size2): np.random.seed(1234) - N=1000 - max_iter=N*10 - count,sgd=0,0 + N = 1000 + max_iter = N * 10 + count, sgd = 0, 0 for i in range(max_iter): - d1=SGD_oneiter(F1,F2,size1,size2) - if d1<0: + d1 = SGD_oneiter(F1, F2, size1, size2) + if d1 < 0: continue - d2=SGD_oneiter(F2,F1,size1,size2) - if d2<0: + d2 = SGD_oneiter(F2, F1, size1, size2) + if d2 < 0: continue - count+=1 - sgd+=(d1+d2)/2 - if count==N: + count += 1 + sgd += (d1 + d2) / 2 + if count == N: break - if count==0: + if count == 0: return 1 else: - return sgd/count - -def compute_inlier_rate(x1,x2,size1,size2,F_gt,th=0.003): - t1,t2=np.linalg.norm(size1)*th,np.linalg.norm(size2)*th - epi1,epi2=epi_line(x1,F_gt),epi_line(x2,F_gt.T) - dis1,dis2=dis_point_to_line(epi1,x2),dis_point_to_line(epi2,x1) - 
mask_inlier=np.logical_and(dis1 1e-8: - sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2] + sina = (R[1, 0] + (cosa - 1.0) * direction[0] * direction[1]) / direction[2] elif abs(direction[1]) > 1e-8: - sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1] + sina = (R[0, 2] + (cosa - 1.0) * direction[0] * direction[2]) / direction[1] else: - sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0] + sina = (R[2, 1] + (cosa - 1.0) * direction[1] * direction[2]) / direction[0] angle = math.atan2(sina, cosa) return angle, direction, point @@ -458,8 +462,7 @@ def scale_from_matrix(matrix): return factor, origin, direction -def projection_matrix(point, normal, direction=None, - perspective=None, pseudo=False): +def projection_matrix(point, normal, direction=None, perspective=None, pseudo=False): """Return matrix to project onto plane defined by point and normal. Using either perspective point, projection direction, or none of both. @@ -495,14 +498,13 @@ def projection_matrix(point, normal, direction=None, normal = unit_vector(normal[:3]) if perspective is not None: # perspective projection - perspective = numpy.array(perspective[:3], dtype=numpy.float64, - copy=False) - M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal) + perspective = numpy.array(perspective[:3], dtype=numpy.float64, copy=False) + M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective - point, normal) M[:3, :3] -= numpy.outer(perspective, normal) if pseudo: # preserve relative depth M[:3, :3] -= numpy.outer(normal, normal) - M[:3, 3] = numpy.dot(point, normal) * (perspective+normal) + M[:3, 3] = numpy.dot(point, normal) * (perspective + normal) else: M[:3, 3] = numpy.dot(point, normal) * perspective M[3, :3] = -normal @@ -582,11 +584,10 @@ def projection_from_matrix(matrix, pseudo=False): # perspective projection i = numpy.where(abs(numpy.real(w)) > 1e-8)[0] if not len(i): - raise ValueError( - "no eigenvector not corresponding to eigenvalue 0") + raise ValueError("no eigenvector not corresponding to eigenvalue 0") point = numpy.real(V[:, i[-1]]).squeeze() point /= point[3] - normal = - M[3, :3] + normal = -M[3, :3] perspective = M[:3, 3] / numpy.dot(point[:3], normal) if pseudo: perspective -= normal @@ -633,15 +634,19 @@ def clip_matrix(left, right, bottom, top, near, far, perspective=False): if near <= _EPS: raise ValueError("invalid frustum: near <= 0") t = 2.0 * near - M = [[t/(left-right), 0.0, (right+left)/(right-left), 0.0], - [0.0, t/(bottom-top), (top+bottom)/(top-bottom), 0.0], - [0.0, 0.0, (far+near)/(near-far), t*far/(far-near)], - [0.0, 0.0, -1.0, 0.0]] + M = [ + [t / (left - right), 0.0, (right + left) / (right - left), 0.0], + [0.0, t / (bottom - top), (top + bottom) / (top - bottom), 0.0], + [0.0, 0.0, (far + near) / (near - far), t * far / (far - near)], + [0.0, 0.0, -1.0, 0.0], + ] else: - M = [[2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)], - [0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)], - [0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)], - [0.0, 0.0, 0.0, 1.0]] + M = [ + [2.0 / (right - left), 0.0, 0.0, (right + left) / (left - right)], + [0.0, 2.0 / (top - bottom), 0.0, (top + bottom) / (bottom - top)], + [0.0, 0.0, 2.0 / (far - near), (far + near) / (near - far)], + [0.0, 0.0, 0.0, 1.0], + ] return numpy.array(M) @@ -761,7 +766,7 @@ def decompose_matrix(matrix): if not numpy.linalg.det(P): raise ValueError("matrix is singular") - scale = numpy.zeros((3, )) + scale = numpy.zeros((3,)) shear = [0.0, 0.0, 0.0] angles = 
[0.0, 0.0, 0.0] @@ -799,15 +804,16 @@ def decompose_matrix(matrix): angles[0] = math.atan2(row[1, 2], row[2, 2]) angles[2] = math.atan2(row[0, 1], row[0, 0]) else: - #angles[0] = math.atan2(row[1, 0], row[1, 1]) + # angles[0] = math.atan2(row[1, 0], row[1, 1]) angles[0] = math.atan2(-row[2, 1], row[1, 1]) angles[2] = 0.0 return scale, shear, angles, translate, perspective -def compose_matrix(scale=None, shear=None, angles=None, translate=None, - perspective=None): +def compose_matrix( + scale=None, shear=None, angles=None, translate=None, perspective=None +): """Return transformation matrix from sequence of transformations. This is the inverse of the decompose_matrix function. @@ -841,7 +847,7 @@ def compose_matrix(scale=None, shear=None, angles=None, translate=None, T[:3, 3] = translate[:3] M = numpy.dot(M, T) if angles is not None: - R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz') + R = euler_matrix(angles[0], angles[1], angles[2], "sxyz") M = numpy.dot(M, R) if shear is not None: Z = numpy.identity(4) @@ -879,11 +885,14 @@ def orthogonalization_matrix(lengths, angles): sina, sinb, _ = numpy.sin(angles) cosa, cosb, cosg = numpy.cos(angles) co = (cosa * cosb - cosg) / (sina * sinb) - return numpy.array([ - [ a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0], - [-a*sinb*co, b*sina, 0.0, 0.0], - [ a*cosb, b*cosa, c, 0.0], - [ 0.0, 0.0, 0.0, 1.0]]) + return numpy.array( + [ + [a * sinb * math.sqrt(1.0 - co * co), 0.0, 0.0, 0.0], + [-a * sinb * co, b * sina, 0.0, 0.0], + [a * cosb, b * cosa, c, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True): @@ -936,11 +945,11 @@ def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True): # move centroids to origin t0 = -numpy.mean(v0, axis=1) - M0 = numpy.identity(ndims+1) + M0 = numpy.identity(ndims + 1) M0[:ndims, ndims] = t0 v0 += t0.reshape(ndims, 1) t1 = -numpy.mean(v1, axis=1) - M1 = numpy.identity(ndims+1) + M1 = numpy.identity(ndims + 1) M1[:ndims, ndims] = t1 v1 += t1.reshape(ndims, 1) @@ -950,10 +959,10 @@ def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True): u, s, vh = numpy.linalg.svd(A.T) vh = vh[:ndims].T B = vh[:ndims] - C = vh[ndims:2*ndims] + C = vh[ndims : 2 * ndims] t = numpy.dot(C, numpy.linalg.pinv(B)) t = numpy.concatenate((t, numpy.zeros((ndims, 1))), axis=1) - M = numpy.vstack((t, ((0.0,)*ndims) + (1.0,))) + M = numpy.vstack((t, ((0.0,) * ndims) + (1.0,))) elif usesvd or ndims != 3: # Rigid transformation via SVD of covariance matrix u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T)) @@ -961,10 +970,10 @@ def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True): R = numpy.dot(u, vh) if numpy.linalg.det(R) < 0.0: # R does not constitute right handed system - R -= numpy.outer(u[:, ndims-1], vh[ndims-1, :]*2.0) + R -= numpy.outer(u[:, ndims - 1], vh[ndims - 1, :] * 2.0) s[-1] *= -1.0 # homogeneous transformation matrix - M = numpy.identity(ndims+1) + M = numpy.identity(ndims + 1) M[:ndims, :ndims] = R else: # Rigid transformation matrix via quaternion @@ -972,10 +981,12 @@ def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True): xx, yy, zz = numpy.sum(v0 * v1, axis=1) xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1) xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1) - N = [[xx+yy+zz, 0.0, 0.0, 0.0], - [yz-zy, xx-yy-zz, 0.0, 0.0], - [zx-xz, xy+yx, yy-xx-zz, 0.0], - [xy-yx, zx+xz, yz+zy, zz-xx-yy]] + N = [ + [xx + yy + zz, 0.0, 0.0, 0.0], + [yz - zy, xx - yy - zz, 0.0, 
0.0], + [zx - xz, xy + yx, yy - xx - zz, 0.0], + [xy - yx, zx + xz, yz + zy, zz - xx - yy], + ] # quaternion: eigenvector corresponding to most positive eigenvalue w, V = numpy.linalg.eigh(N) q = V[:, numpy.argmax(w)] @@ -1042,11 +1053,10 @@ def superimposition_matrix(v0, v1, scale=False, usesvd=True): """ v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3] v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3] - return affine_matrix_from_points(v0, v1, shear=False, - scale=scale, usesvd=usesvd) + return affine_matrix_from_points(v0, v1, shear=False, scale=scale, usesvd=usesvd) -def euler_matrix(ai, aj, ak, axes='sxyz'): +def euler_matrix(ai, aj, ak, axes="sxyz"): """Return homogeneous rotation matrix from Euler angles and axis sequence. ai, aj, ak : Euler's roll, pitch and yaw angles @@ -1072,8 +1082,8 @@ def euler_matrix(ai, aj, ak, axes='sxyz'): firstaxis, parity, repetition, frame = axes i = firstaxis - j = _NEXT_AXIS[i+parity] - k = _NEXT_AXIS[i-parity+1] + j = _NEXT_AXIS[i + parity] + k = _NEXT_AXIS[i - parity + 1] if frame: ai, ak = ak, ai @@ -1082,34 +1092,34 @@ def euler_matrix(ai, aj, ak, axes='sxyz'): si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak) ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak) - cc, cs = ci*ck, ci*sk - sc, ss = si*ck, si*sk + cc, cs = ci * ck, ci * sk + sc, ss = si * ck, si * sk M = numpy.identity(4) if repetition: M[i, i] = cj - M[i, j] = sj*si - M[i, k] = sj*ci - M[j, i] = sj*sk - M[j, j] = -cj*ss+cc - M[j, k] = -cj*cs-sc - M[k, i] = -sj*ck - M[k, j] = cj*sc+cs - M[k, k] = cj*cc-ss + M[i, j] = sj * si + M[i, k] = sj * ci + M[j, i] = sj * sk + M[j, j] = -cj * ss + cc + M[j, k] = -cj * cs - sc + M[k, i] = -sj * ck + M[k, j] = cj * sc + cs + M[k, k] = cj * cc - ss else: - M[i, i] = cj*ck - M[i, j] = sj*sc-cs - M[i, k] = sj*cc+ss - M[j, i] = cj*sk - M[j, j] = sj*ss+cc - M[j, k] = sj*cs-sc + M[i, i] = cj * ck + M[i, j] = sj * sc - cs + M[i, k] = sj * cc + ss + M[j, i] = cj * sk + M[j, j] = sj * ss + cc + M[j, k] = sj * cs - sc M[k, i] = -sj - M[k, j] = cj*si - M[k, k] = cj*ci + M[k, j] = cj * si + M[k, k] = cj * ci return M -def euler_from_matrix(matrix, axes='sxyz'): +def euler_from_matrix(matrix, axes="sxyz"): """Return Euler angles from rotation matrix for specified axis sequence. 
axes : One of 24 axis sequences as string or encoded tuple @@ -1135,29 +1145,29 @@ def euler_from_matrix(matrix, axes='sxyz'): firstaxis, parity, repetition, frame = axes i = firstaxis - j = _NEXT_AXIS[i+parity] - k = _NEXT_AXIS[i-parity+1] + j = _NEXT_AXIS[i + parity] + k = _NEXT_AXIS[i - parity + 1] M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3] if repetition: - sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k]) + sy = math.sqrt(M[i, j] * M[i, j] + M[i, k] * M[i, k]) if sy > _EPS: - ax = math.atan2( M[i, j], M[i, k]) - ay = math.atan2( sy, M[i, i]) - az = math.atan2( M[j, i], -M[k, i]) + ax = math.atan2(M[i, j], M[i, k]) + ay = math.atan2(sy, M[i, i]) + az = math.atan2(M[j, i], -M[k, i]) else: - ax = math.atan2(-M[j, k], M[j, j]) - ay = math.atan2( sy, M[i, i]) + ax = math.atan2(-M[j, k], M[j, j]) + ay = math.atan2(sy, M[i, i]) az = 0.0 else: - cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i]) + cy = math.sqrt(M[i, i] * M[i, i] + M[j, i] * M[j, i]) if cy > _EPS: - ax = math.atan2( M[k, j], M[k, k]) - ay = math.atan2(-M[k, i], cy) - az = math.atan2( M[j, i], M[i, i]) + ax = math.atan2(M[k, j], M[k, k]) + ay = math.atan2(-M[k, i], cy) + az = math.atan2(M[j, i], M[i, i]) else: - ax = math.atan2(-M[j, k], M[j, j]) - ay = math.atan2(-M[k, i], cy) + ax = math.atan2(-M[j, k], M[j, j]) + ay = math.atan2(-M[k, i], cy) az = 0.0 if parity: @@ -1167,7 +1177,7 @@ def euler_from_matrix(matrix, axes='sxyz'): return ax, ay, az -def euler_from_quaternion(quaternion, axes='sxyz'): +def euler_from_quaternion(quaternion, axes="sxyz"): """Return Euler angles from quaternion for specified axis sequence. >>> angles = euler_from_quaternion([0.99810947, 0.06146124, 0, 0]) @@ -1178,7 +1188,7 @@ def euler_from_quaternion(quaternion, axes='sxyz'): return euler_from_matrix(quaternion_matrix(quaternion), axes) -def quaternion_from_euler(ai, aj, ak, axes='sxyz'): +def quaternion_from_euler(ai, aj, ak, axes="sxyz"): """Return quaternion from Euler angles and axis sequence. 
ai, aj, ak : Euler's roll, pitch and yaw angles @@ -1196,8 +1206,8 @@ def quaternion_from_euler(ai, aj, ak, axes='sxyz'): firstaxis, parity, repetition, frame = axes i = firstaxis + 1 - j = _NEXT_AXIS[i+parity-1] + 1 - k = _NEXT_AXIS[i-parity] + 1 + j = _NEXT_AXIS[i + parity - 1] + 1 + k = _NEXT_AXIS[i - parity] + 1 if frame: ai, ak = ak, ai @@ -1213,22 +1223,22 @@ def quaternion_from_euler(ai, aj, ak, axes='sxyz'): sj = math.sin(aj) ck = math.cos(ak) sk = math.sin(ak) - cc = ci*ck - cs = ci*sk - sc = si*ck - ss = si*sk + cc = ci * ck + cs = ci * sk + sc = si * ck + ss = si * sk - q = numpy.empty((4, )) + q = numpy.empty((4,)) if repetition: - q[0] = cj*(cc - ss) - q[i] = cj*(cs + sc) - q[j] = sj*(cc + ss) - q[k] = sj*(cs - sc) + q[0] = cj * (cc - ss) + q[i] = cj * (cs + sc) + q[j] = sj * (cc + ss) + q[k] = sj * (cs - sc) else: - q[0] = cj*cc + sj*ss - q[i] = cj*sc - sj*cs - q[j] = cj*ss + sj*cc - q[k] = cj*cs - sj*sc + q[0] = cj * cc + sj * ss + q[i] = cj * sc - sj * cs + q[j] = cj * ss + sj * cc + q[k] = cj * cs - sj * sc if parity: q[j] *= -1.0 @@ -1246,8 +1256,8 @@ def quaternion_about_axis(angle, axis): q = numpy.array([0.0, axis[0], axis[1], axis[2]]) qlen = vector_norm(q) if qlen > _EPS: - q *= math.sin(angle/2.0) / qlen - q[0] = math.cos(angle/2.0) + q *= math.sin(angle / 2.0) / qlen + q[0] = math.cos(angle / 2.0) return q @@ -1271,11 +1281,14 @@ def quaternion_matrix(quaternion): return numpy.identity(4) q *= math.sqrt(2.0 / n) q = numpy.outer(q, q) - return numpy.array([ - [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0], - [ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0], - [ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0], - [ 0.0, 0.0, 0.0, 1.0]]) + return numpy.array( + [ + [1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0], + [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0], + [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) def quaternion_from_matrix(matrix, isprecise=False): @@ -1316,7 +1329,7 @@ def quaternion_from_matrix(matrix, isprecise=False): """ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4] if isprecise: - q = numpy.empty((4, )) + q = numpy.empty((4,)) t = numpy.trace(M) if t > M[3, 3]: q[0] = t @@ -1346,10 +1359,14 @@ def quaternion_from_matrix(matrix, isprecise=False): m21 = M[2, 1] m22 = M[2, 2] # symmetric matrix K - K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0], - [m01+m10, m11-m00-m22, 0.0, 0.0], - [m02+m20, m12+m21, m22-m00-m11, 0.0], - [m21-m12, m02-m20, m10-m01, m00+m11+m22]]) + K = numpy.array( + [ + [m00 - m11 - m22, 0.0, 0.0, 0.0], + [m01 + m10, m11 - m00 - m22, 0.0, 0.0], + [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0], + [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22], + ] + ) K /= 3.0 # quaternion is eigenvector of K that corresponds to largest eigenvalue w, V = numpy.linalg.eigh(K) @@ -1369,10 +1386,15 @@ def quaternion_multiply(quaternion1, quaternion0): """ w0, x0, y0, z0 = quaternion0 w1, x1, y1, z1 = quaternion1 - return numpy.array([-x1*x0 - y1*y0 - z1*z0 + w1*w0, - x1*w0 + y1*z0 - z1*y0 + w1*x0, - -x1*z0 + y1*w0 + z1*x0 + w1*y0, - x1*y0 - y1*x0 + z1*w0 + w1*z0], dtype=numpy.float64) + return numpy.array( + [ + -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0, + x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0, + -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0, + x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0, + ], + dtype=numpy.float64, + ) def quaternion_conjugate(quaternion): @@ -1488,8 +1510,9 @@ def random_quaternion(rand=None): pi2 
= math.pi * 2.0 t1 = pi2 * rand[1] t2 = pi2 * rand[2] - return numpy.array([numpy.cos(t2)*r2, numpy.sin(t1)*r1, - numpy.cos(t1)*r1, numpy.sin(t2)*r2]) + return numpy.array( + [numpy.cos(t2) * r2, numpy.sin(t1) * r1, numpy.cos(t1) * r1, numpy.sin(t2) * r2] + ) def random_rotation_matrix(rand=None): @@ -1530,6 +1553,7 @@ class Arcball(object): >>> ball.next() """ + def __init__(self, initial=None): """Initialize virtual trackball control. @@ -1548,7 +1572,7 @@ class Arcball(object): initial = numpy.array(initial, dtype=numpy.float64) if initial.shape == (4, 4): self._qdown = quaternion_from_matrix(initial) - elif initial.shape == (4, ): + elif initial.shape == (4,): initial /= vector_norm(initial) self._qdown = initial else: @@ -1610,7 +1634,7 @@ class Arcball(object): def next(self, acceleration=0.0): """Continue rotation in direction of last drag.""" - q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False) + q = quaternion_slerp(self._qpre, self._qnow, 2.0 + acceleration, False) self._qpre, self._qnow = self._qnow, q def matrix(self): @@ -1622,11 +1646,11 @@ def arcball_map_to_sphere(point, center, radius): """Return unit sphere coordinates from window coordinates.""" v0 = (point[0] - center[0]) / radius v1 = (center[1] - point[1]) / radius - n = v0*v0 + v1*v1 + n = v0 * v0 + v1 * v1 if n > 1.0: # position outside of sphere n = math.sqrt(n) - return numpy.array([v0/n, v1/n, 0.0]) + return numpy.array([v0 / n, v1 / n, 0.0]) else: return numpy.array([v0, v1, math.sqrt(1.0 - n)]) @@ -1668,14 +1692,31 @@ _NEXT_AXIS = [1, 2, 0, 1] # map axes strings to/from tuples of inner axis, parity, repetition, frame _AXES2TUPLE = { - 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0), - 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0), - 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0), - 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0), - 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1), - 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1), - 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1), - 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)} + "sxyz": (0, 0, 0, 0), + "sxyx": (0, 0, 1, 0), + "sxzy": (0, 1, 0, 0), + "sxzx": (0, 1, 1, 0), + "syzx": (1, 0, 0, 0), + "syzy": (1, 0, 1, 0), + "syxz": (1, 1, 0, 0), + "syxy": (1, 1, 1, 0), + "szxy": (2, 0, 0, 0), + "szxz": (2, 0, 1, 0), + "szyx": (2, 1, 0, 0), + "szyz": (2, 1, 1, 0), + "rzyx": (0, 0, 0, 1), + "rxyx": (0, 0, 1, 1), + "ryzx": (0, 1, 0, 1), + "rxzx": (0, 1, 1, 1), + "rxzy": (1, 0, 0, 1), + "ryzy": (1, 0, 1, 1), + "rzxy": (1, 1, 0, 1), + "ryxy": (1, 1, 1, 1), + "ryxz": (2, 0, 0, 1), + "rzxz": (2, 0, 1, 1), + "rxyz": (2, 1, 0, 1), + "rzyz": (2, 1, 1, 1), +} _TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items()) @@ -1754,7 +1795,7 @@ def unit_vector(data, axis=None, out=None): if out is not data: out[:] = numpy.array(data, copy=False) data = out - length = numpy.atleast_1d(numpy.sum(data*data, axis)) + length = numpy.atleast_1d(numpy.sum(data * data, axis)) numpy.sqrt(length, length) if axis is not None: length = numpy.expand_dims(length, axis) @@ -1878,7 +1919,7 @@ def is_same_transform(matrix0, matrix1): return numpy.allclose(matrix0, matrix1) -def _import_module(name, package=None, warn=True, prefix='_py_', ignore='_'): +def _import_module(name, package=None, warn=True, prefix="_py_", ignore="_"): """Try import all public attributes from module into global namespace. 
Existing attributes with name clashes are renamed with prefix. @@ -1889,14 +1930,15 @@ def _import_module(name, package=None, warn=True, prefix='_py_', ignore='_'): """ import warnings from importlib import import_module + try: if not package: module = import_module(name) else: - module = import_module('.' + name, package=package) + module = import_module("." + name, package=package) except ImportError: if warn: - #warnings.warn("failed to import module %s" % name) + # warnings.warn("failed to import module %s" % name) pass else: for attr in dir(module): @@ -1911,11 +1953,11 @@ def _import_module(name, package=None, warn=True, prefix='_py_', ignore='_'): return True -_import_module('_transformations') +_import_module("_transformations") if __name__ == "__main__": import doctest import random # used in doctests + numpy.set_printoptions(suppress=True, precision=5) doctest.testmod() - diff --git a/third_party/SOLD2/setup.py b/third_party/SOLD2/setup.py index 69f72fecdc54cf9b43a7fc55144470e83c5a862d..e6c9cdcb47bdd73758cbd2d5b125dcb91306705f 100644 --- a/third_party/SOLD2/setup.py +++ b/third_party/SOLD2/setup.py @@ -1,4 +1,4 @@ from setuptools import setup -setup(name='sold2', version="0.0", packages=['sold2']) +setup(name="sold2", version="0.0", packages=["sold2"]) diff --git a/third_party/SOLD2/sold2/config/project_config.py b/third_party/SOLD2/sold2/config/project_config.py index 42ed00d1c1900e71568d1b06ff4f9d19a295232d..6846b4451e038b1c517043ea6db08f3029b79852 100644 --- a/third_party/SOLD2/sold2/config/project_config.py +++ b/third_party/SOLD2/sold2/config/project_config.py @@ -5,26 +5,29 @@ import os class Config(object): - """ Datasets and experiments folders for the whole project. """ + """Datasets and experiments folders for the whole project.""" + ##################### ## Dataset setting ## ##################### - DATASET_ROOT = os.getenv("DATASET_ROOT", "./datasets/") # TODO: path to your datasets folder + DATASET_ROOT = os.getenv( + "DATASET_ROOT", "./datasets/" + ) # TODO: path to your datasets folder if not os.path.exists(DATASET_ROOT): os.makedirs(DATASET_ROOT) - + # Synthetic shape dataset synthetic_dataroot = os.path.join(DATASET_ROOT, "synthetic_shapes") synthetic_cache_path = os.path.join(DATASET_ROOT, "synthetic_shapes") if not os.path.exists(synthetic_dataroot): os.makedirs(synthetic_dataroot) - + # Exported predictions dataset export_dataroot = os.path.join(DATASET_ROOT, "export_datasets") export_cache_path = os.path.join(DATASET_ROOT, "export_datasets") if not os.path.exists(export_dataroot): os.makedirs(export_dataroot) - + # Wireframe dataset wireframe_dataroot = os.path.join(DATASET_ROOT, "wireframe") wireframe_cache_path = os.path.join(DATASET_ROOT, "wireframe") @@ -32,10 +35,12 @@ class Config(object): # Holicity dataset holicity_dataroot = os.path.join(DATASET_ROOT, "Holicity") holicity_cache_path = os.path.join(DATASET_ROOT, "Holicity") - + ######################## ## Experiment Setting ## ######################## - EXP_PATH = os.getenv("EXP_PATH", "./experiments/") # TODO: path to your experiments folder + EXP_PATH = os.getenv( + "EXP_PATH", "./experiments/" + ) # TODO: path to your experiments folder if not os.path.exists(EXP_PATH): os.makedirs(EXP_PATH) diff --git a/third_party/SOLD2/sold2/dataset/dataset_util.py b/third_party/SOLD2/sold2/dataset/dataset_util.py index 50439ef3e2958d82719da0f6d10f4a7d98322f9a..67271bc915e6975cad005e9001d2bb430a8baa14 100644 --- a/third_party/SOLD2/sold2/dataset/dataset_util.py +++ 
b/third_party/SOLD2/sold2/dataset/dataset_util.py @@ -8,53 +8,50 @@ from .merge_dataset import MergeDataset def get_dataset(mode="train", dataset_cfg=None): - """ Initialize different dataset based on a configuration. """ + """Initialize different dataset based on a configuration.""" # Check dataset config is given if dataset_cfg is None: raise ValueError("[Error] The dataset config is required!") # Synthetic dataset if dataset_cfg["dataset_name"] == "synthetic_shape": - dataset = SyntheticShapes( - mode, dataset_cfg - ) + dataset = SyntheticShapes(mode, dataset_cfg) # Get the collate_fn from .synthetic_dataset import synthetic_collate_fn + collate_fn = synthetic_collate_fn # Wireframe dataset elif dataset_cfg["dataset_name"] == "wireframe": - dataset = WireframeDataset( - mode, dataset_cfg - ) + dataset = WireframeDataset(mode, dataset_cfg) # Get the collate_fn from .wireframe_dataset import wireframe_collate_fn + collate_fn = wireframe_collate_fn - + # Holicity dataset elif dataset_cfg["dataset_name"] == "holicity": - dataset = HolicityDataset( - mode, dataset_cfg - ) + dataset = HolicityDataset(mode, dataset_cfg) # Get the collate_fn from .holicity_dataset import holicity_collate_fn + collate_fn = holicity_collate_fn - + # Dataset merging several datasets in one elif dataset_cfg["dataset_name"] == "merge": - dataset = MergeDataset( - mode, dataset_cfg - ) + dataset = MergeDataset(mode, dataset_cfg) # Get the collate_fn from .holicity_dataset import holicity_collate_fn + collate_fn = holicity_collate_fn else: raise ValueError( - "[Error] The dataset '%s' is not supported" % dataset_cfg["dataset_name"]) + "[Error] The dataset '%s' is not supported" % dataset_cfg["dataset_name"] + ) return dataset, collate_fn diff --git a/third_party/SOLD2/sold2/dataset/holicity_dataset.py b/third_party/SOLD2/sold2/dataset/holicity_dataset.py index e4437f37bda366983052de902a41467ca01412bd..af182c5ef46d68d595da4c3dda76c1f631d56fcc 100644 --- a/third_party/SOLD2/sold2/dataset/holicity_dataset.py +++ b/third_party/SOLD2/sold2/dataset/holicity_dataset.py @@ -26,12 +26,19 @@ from ..misc.train_utils import parse_h5_data def holicity_collate_fn(batch): - """ Customized collate_fn. """ - batch_keys = ["image", "junction_map", "valid_mask", "heatmap", - "heatmap_pos", "heatmap_neg", "homography", - "line_points", "line_indices"] - list_keys = ["junctions", "line_map", "line_map_pos", - "line_map_neg", "file_key"] + """Customized collate_fn.""" + batch_keys = [ + "image", + "junction_map", + "valid_mask", + "heatmap", + "heatmap_pos", + "heatmap_neg", + "homography", + "line_points", + "line_indices", + ] + list_keys = ["junctions", "line_map", "line_map_pos", "line_map_neg", "file_key"] outputs = {} for data_key in batch[0].keys(): @@ -40,14 +47,16 @@ def holicity_collate_fn(batch): # print(batch_match, list_match) if batch_match > 0 and list_match == 0: outputs[data_key] = torch_loader.default_collate( - [b[data_key] for b in batch]) + [b[data_key] for b in batch] + ) elif batch_match == 0 and list_match > 0: outputs[data_key] = [b[data_key] for b in batch] elif batch_match == 0 and list_match == 0: continue else: raise ValueError( - "[Error] A key matches batch keys and list keys simultaneously.") + "[Error] A key matches batch keys and list keys simultaneously." + ) return outputs @@ -57,7 +66,8 @@ class HolicityDataset(Dataset): super(HolicityDataset, self).__init__() if not mode in ["train", "test"]: raise ValueError( - "[Error] Unknown mode for Holicity dataset. 
Only 'train' and 'test'.") + "[Error] Unknown mode for Holicity dataset. Only 'train' and 'test'." + ) self.mode = mode if config is None: @@ -71,17 +81,18 @@ class HolicityDataset(Dataset): self.dataset_name = self.get_dataset_name() self.cache_name = self.get_cache_name() self.cache_path = cfg.holicity_cache_path - + # Get the ground truth source if it exists self.gt_source = None - if "gt_source_%s"%(self.mode) in self.config: - self.gt_source = self.config.get("gt_source_%s"%(self.mode)) + if "gt_source_%s" % (self.mode) in self.config: + self.gt_source = self.config.get("gt_source_%s" % (self.mode)) self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source) # Check the full path exists if not os.path.exists(self.gt_source): raise ValueError( - "[Error] The specified ground truth source does not exist.") - + "[Error] The specified ground truth source does not exist." + ) + # Get the filename dataset print("[Info] Initializing Holicity dataset...") self.filename_dataset, self.datapoints = self.construct_dataset() @@ -92,22 +103,22 @@ class HolicityDataset(Dataset): # Print some info print("[Info] Successfully initialized dataset") print("\t Name: Holicity") - print("\t Mode: %s" %(self.mode)) - print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode), - "None"))) - print("\t Counts: %d" %(self.dataset_length)) + print("\t Mode: %s" % (self.mode)) + print("\t Gt: %s" % (self.config.get("gt_source_%s" % (self.mode), "None"))) + print("\t Counts: %d" % (self.dataset_length)) print("----------------------------------------") ####################################### ## Dataset construction related APIs ## ####################################### def construct_dataset(self): - """ Construct the dataset (from scratch or from cache). """ + """Construct the dataset (from scratch or from cache).""" # Check if the filename cache exists # If cache exists, load from cache if self.check_dataset_cache(): - print("\t Found filename cache %s at %s"%(self.cache_name, - self.cache_path)) + print( + "\t Found filename cache %s at %s" % (self.cache_name, self.cache_path) + ) print("\t Load filename cache...") filename_dataset, datapoints = self.get_filename_dataset_from_cache() # If not, initialize dataset from scratch @@ -117,56 +128,56 @@ class HolicityDataset(Dataset): filename_dataset, datapoints = self.get_filename_dataset() print("\t Create filename dataset cache...") self.create_filename_dataset_cache(filename_dataset, datapoints) - + return filename_dataset, datapoints - + def create_filename_dataset_cache(self, filename_dataset, datapoints): - """ Create filename dataset cache for faster initialization. """ + """Create filename dataset cache for faster initialization.""" # Check cache path exists if not os.path.exists(self.cache_path): os.makedirs(self.cache_path) cache_file_path = os.path.join(self.cache_path, self.cache_name) - data = { - "filename_dataset": filename_dataset, - "datapoints": datapoints - } + data = {"filename_dataset": filename_dataset, "datapoints": datapoints} with open(cache_file_path, "wb") as f: pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) - + def get_filename_dataset_from_cache(self): - """ Get filename dataset from cache. """ + """Get filename dataset from cache.""" # Load from pkl cache cache_file_path = os.path.join(self.cache_path, self.cache_name) with open(cache_file_path, "rb") as f: data = pickle.load(f) - + return data["filename_dataset"], data["datapoints"] def get_filename_dataset(self): - """ Get the path to the dataset. 
""" + """Get the path to the dataset.""" if self.mode == "train": # Contains 5720 or 11872 images - dataset_path = [os.path.join(cfg.holicity_dataroot, p) - for p in self.config["train_splits"]] + dataset_path = [ + os.path.join(cfg.holicity_dataroot, p) + for p in self.config["train_splits"] + ] else: # Test mode - Contains 520 images dataset_path = [os.path.join(cfg.holicity_dataroot, "2018-03")] - + # Get paths to all image files image_paths = [] for folder in dataset_path: - image_paths += [os.path.join(folder, img) - for img in os.listdir(folder) - if os.path.splitext(img)[-1] == ".jpg"] + image_paths += [ + os.path.join(folder, img) + for img in os.listdir(folder) + if os.path.splitext(img)[-1] == ".jpg" + ] image_paths = sorted(image_paths) # Verify all the images exist for idx in range(len(image_paths)): image_path = image_paths[idx] if not (os.path.exists(image_path)): - raise ValueError( - "[Error] The image does not exist. %s"%(image_path)) + raise ValueError("[Error] The image does not exist. %s" % (image_path)) # Construct the filename dataset num_pad = int(math.ceil(math.log10(len(image_paths))) + 1) @@ -176,82 +187,77 @@ class HolicityDataset(Dataset): key = self.get_padded_filename(num_pad, idx) filename_dataset[key] = {"image": image_paths[idx]} - + # Get the datapoints datapoints = list(sorted(filename_dataset.keys())) return filename_dataset, datapoints - + def get_dataset_name(self): - """ Get dataset name from dataset config / default config. """ - dataset_name = self.config.get("dataset_name", - self.default_config["dataset_name"]) + """Get dataset name from dataset config / default config.""" + dataset_name = self.config.get( + "dataset_name", self.default_config["dataset_name"] + ) dataset_name = dataset_name + "_%s" % self.mode return dataset_name - + def get_cache_name(self): - """ Get cache name from dataset config / default config. """ - dataset_name = self.config.get("dataset_name", - self.default_config["dataset_name"]) + """Get cache name from dataset config / default config.""" + dataset_name = self.config.get( + "dataset_name", self.default_config["dataset_name"] + ) dataset_name = dataset_name + "_%s" % self.mode # Compose cache name cache_name = dataset_name + "_cache.pkl" return cache_name def check_dataset_cache(self): - """ Check if dataset cache exists. """ + """Check if dataset cache exists.""" cache_file_path = os.path.join(self.cache_path, self.cache_name) if os.path.exists(cache_file_path): return True else: return False - + @staticmethod def get_padded_filename(num_pad, idx): - """ Get the padded filename using adaptive padding. """ + """Get the padded filename using adaptive padding.""" file_len = len("%d" % (idx)) filename = "0" * (num_pad - file_len) + "%d" % (idx) return filename def get_default_config(self): - """ Get the default configuration. 
""" + """Get the default configuration.""" return { "dataset_name": "holicity", "train_split": "2018-01", "add_augmentation_to_all_splits": False, - "preprocessing": { - "resize": [512, 512], - "blur_size": 11 - }, - "augmentation":{ - "photometric":{ - "enable": False - }, - "homographic":{ - "enable": False - }, + "preprocessing": {"resize": [512, 512], "blur_size": 11}, + "augmentation": { + "photometric": {"enable": False}, + "homographic": {"enable": False}, }, } - + ############################################ ## Pytorch and preprocessing related APIs ## ############################################ @staticmethod def get_data_from_path(data_path): - """ Get data from the information from filename dataset. """ + """Get data from the information from filename dataset.""" output = {} # Get image data image_path = data_path["image"] image = imread(image_path) output["image"] = image - + return output - + @staticmethod def convert_line_map(lcnn_line_map, num_junctions): - """ Convert the line_pos or line_neg - (represented by two junction indexes) to our line map. """ + """Convert the line_pos or line_neg + (represented by two junction indexes) to our line map.""" # Initialize empty line map line_map = np.zeros([num_junctions, num_junctions]) @@ -262,59 +268,60 @@ class HolicityDataset(Dataset): line_map[index1, index2] = 1 line_map[index2, index1] = 1 - + return line_map @staticmethod def junc_to_junc_map(junctions, image_size): - """ Convert junction points to junction maps. """ + """Convert junction points to junction maps.""" junctions = np.round(junctions).astype(np.int) # Clip the boundary by image size - junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1) - junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1) + junctions[:, 0] = np.clip(junctions[:, 0], 0.0, image_size[0] - 1) + junctions[:, 1] = np.clip(junctions[:, 1], 0.0, image_size[1] - 1) # Create junction map junc_map = np.zeros([image_size[0], image_size[1]]) junc_map[junctions[:, 0], junctions[:, 1]] = 1 return junc_map[..., None].astype(np.int) - + def parse_transforms(self, names, all_transforms): - """ Parse the transform. """ - trans = all_transforms if (names == 'all') \ + """Parse the transform.""" + trans = ( + all_transforms + if (names == "all") else (names if isinstance(names, list) else [names]) + ) assert set(trans) <= set(all_transforms) return trans def get_photo_transform(self): - """ Get list of photometric transforms (according to the config). """ + """Get list of photometric transforms (according to the config).""" # Get the photometric transform config photo_config = self.config["augmentation"]["photometric"] if not photo_config["enable"]: - raise ValueError( - "[Error] Photometric augmentation is not enabled.") - + raise ValueError("[Error] Photometric augmentation is not enabled.") + # Parse photometric transforms - trans_lst = self.parse_transforms(photo_config["primitives"], - photoaug.available_augmentations) - trans_config_lst = [photo_config["params"].get(p, {}) - for p in trans_lst] + trans_lst = self.parse_transforms( + photo_config["primitives"], photoaug.available_augmentations + ) + trans_config_lst = [photo_config["params"].get(p, {}) for p in trans_lst] # List of photometric augmentation photometric_trans_lst = [ - getattr(photoaug, trans)(**conf) \ + getattr(photoaug, trans)(**conf) for (trans, conf) in zip(trans_lst, trans_config_lst) ] return photometric_trans_lst def get_homo_transform(self): - """ Get homographic transforms (according to the config). 
""" + """Get homographic transforms (according to the config).""" # Get homographic transforms for image homo_config = self.config["augmentation"]["homographic"]["params"] if not self.config["augmentation"]["homographic"]["enable"]: - raise ValueError( - "[Error] Homographic augmentation is not enabled") + raise ValueError("[Error] Homographic augmentation is not enabled") # Parse the homographic transforms image_shape = self.config["preprocessing"]["resize"] @@ -324,30 +331,33 @@ class HolicityDataset(Dataset): min_label_tmp = self.config["generation"]["min_label_len"] except: min_label_tmp = None - + # float label len => fraction - if isinstance(min_label_tmp, float): # Skip if not provided + if isinstance(min_label_tmp, float): # Skip if not provided min_label_len = min_label_tmp * min(image_shape) # int label len => length in pixel elif isinstance(min_label_tmp, int): - scale_ratio = (self.config["preprocessing"]["resize"] - / self.config["generation"]["image_size"][0]) - min_label_len = (self.config["generation"]["min_label_len"] - * scale_ratio) + scale_ratio = ( + self.config["preprocessing"]["resize"] + / self.config["generation"]["image_size"][0] + ) + min_label_len = self.config["generation"]["min_label_len"] * scale_ratio # if none => no restriction else: min_label_len = 0 - + # Initialize the transform homographic_trans = homoaug.homography_transform( - image_shape, homo_config, 0, min_label_len) + image_shape, homo_config, 0, min_label_len + ) return homographic_trans - def get_line_points(self, junctions, line_map, H1=None, H2=None, - img_size=None, warp=False): - """ Sample evenly points along each line segments - and keep track of line idx. """ + def get_line_points( + self, junctions, line_map, H1=None, H2=None, img_size=None, warp=False + ): + """Sample evenly points along each line segments + and keep track of line idx.""" if np.sum(line_map) == 0: # No segment detected in the image line_indices = np.zeros(self.config["max_pts"], dtype=int) @@ -356,35 +366,38 @@ class HolicityDataset(Dataset): # Extract all pairs of connected junctions junc_indices = np.array( - [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i]) - line_segments = np.stack([junctions[junc_indices[:, 0]], - junctions[junc_indices[:, 1]]], axis=1) + [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i] + ) + line_segments = np.stack( + [junctions[junc_indices[:, 0]], junctions[junc_indices[:, 1]]], axis=1 + ) # line_segments is (num_lines, 2, 2) - line_lengths = np.linalg.norm( - line_segments[:, 0] - line_segments[:, 1], axis=1) + line_lengths = np.linalg.norm(line_segments[:, 0] - line_segments[:, 1], axis=1) # Sample the points separated by at least min_dist_pts along each line # The number of samples depends on the length of the line - num_samples = np.minimum(line_lengths // self.config["min_dist_pts"], - self.config["max_num_samples"]) + num_samples = np.minimum( + line_lengths // self.config["min_dist_pts"], self.config["max_num_samples"] + ) line_points = [] line_indices = [] cur_line_idx = 1 for n in np.arange(2, self.config["max_num_samples"] + 1): # Consider all lines where we can fit up to n points cur_line_seg = line_segments[num_samples == n] - line_points_x = np.linspace(cur_line_seg[:, 0, 0], - cur_line_seg[:, 1, 0], - n, axis=-1).flatten() - line_points_y = np.linspace(cur_line_seg[:, 0, 1], - cur_line_seg[:, 1, 1], - n, axis=-1).flatten() + line_points_x = np.linspace( + cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], n, axis=-1 + ).flatten() + line_points_y = np.linspace( + 
cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], n, axis=-1 + ).flatten() jitter = self.config.get("jittering", 0) if jitter: # Add a small random jittering of all points along the line angles = np.arctan2( cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0], - cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n) + cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1], + ).repeat(n) jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter line_points_x += jitter_hyp * np.sin(angles) line_points_y += jitter_hyp * np.cos(angles) @@ -394,10 +407,8 @@ class HolicityDataset(Dataset): line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines) line_indices.append(line_idx.repeat(n)) cur_line_idx += num_cur_lines - line_points = np.concatenate(line_points, - axis=0)[:self.config["max_pts"]] - line_indices = np.concatenate(line_indices, - axis=0)[:self.config["max_pts"]] + line_points = np.concatenate(line_points, axis=0)[: self.config["max_pts"]] + line_indices = np.concatenate(line_indices, axis=0)[: self.config["max_pts"]] # Warp the points if need be, and filter unvalid ones # If the other view is also warped @@ -419,37 +430,43 @@ class HolicityDataset(Dataset): mask = mask_points(warped_points, img_size) line_points = line_points[mask] line_indices = line_indices[mask] - + # Pad the line points to a fixed length # Index of 0 means padded line - line_indices = np.concatenate([line_indices, np.zeros( - self.config["max_pts"] - len(line_indices))], axis=0) + line_indices = np.concatenate( + [line_indices, np.zeros(self.config["max_pts"] - len(line_indices))], axis=0 + ) line_points = np.concatenate( - [line_points, - np.zeros((self.config["max_pts"] - len(line_points), 2), - dtype=float)], axis=0) - + [ + line_points, + np.zeros((self.config["max_pts"] - len(line_points), 2), dtype=float), + ], + axis=0, + ) + return line_points, line_indices def export_preprocessing(self, data, numpy=False): - """ Preprocess the exported data. """ + """Preprocess the exported data.""" # Fetch the corresponding entries image = data["image"] image_size = image.shape[:2] # Resize the image before photometric and homographical augmentations - if not(list(image_size) == self.config["preprocessing"]["resize"]): + if not (list(image_size) == self.config["preprocessing"]["resize"]): # Resize the image and the point location. - size_old = list(image.shape)[:2] # Only H and W dimensions + size_old = list(image.shape)[:2] # Only H and W dimensions image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) - + # Optionally convert the image to grayscale if self.config["gray_scale"]: - image = (color.rgb2gray(image) * 255.).astype(np.uint8) + image = (color.rgb2gray(image) * 255.0).astype(np.uint8) image = photoaug.normalize_image()(image) @@ -459,11 +476,21 @@ class HolicityDataset(Dataset): return {"image": to_tensor(image)} else: return {"image": image} - + def train_preprocessing_exported( - self, data, numpy=False, disable_homoaug=False, desc_training=False, - H1=None, H1_scale=None, H2=None, scale=1., h_crop=None, w_crop=None): - """ Train preprocessing for the exported labels. 
""" + self, + data, + numpy=False, + disable_homoaug=False, + desc_training=False, + H1=None, + H1_scale=None, + H2=None, + scale=1.0, + h_crop=None, + w_crop=None, + ): + """Train preprocessing for the exported labels.""" data = copy.deepcopy(data) # Fetch the corresponding entries image = data["image"] @@ -483,13 +510,15 @@ class HolicityDataset(Dataset): w_crop = np.random.randint(W_scale - W) # Resize the image before photometric and homographical augmentations - if not(list(image_size) == self.config["preprocessing"]["resize"]): + if not (list(image_size) == self.config["preprocessing"]["resize"]): # Resize the image and the point location. - size_old = list(image.shape)[:2] # Only H and W dimensions + size_old = list(image.shape)[:2] # Only H and W dimensions image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) # # In HW format @@ -504,7 +533,7 @@ class HolicityDataset(Dataset): # Optionally convert the image to grayscale if self.config["gray_scale"]: - image = (color.rgb2gray(image) * 255.).astype(np.uint8) + image = (color.rgb2gray(image) * 255.0).astype(np.uint8) # Check if we need to apply augmentations # In training mode => yes. @@ -514,16 +543,17 @@ class HolicityDataset(Dataset): ### Image transform ### np.random.shuffle(photo_trans_lst) image_transform = transforms.Compose( - photo_trans_lst + [photoaug.normalize_image()]) + photo_trans_lst + [photoaug.normalize_image()] + ) else: image_transform = photoaug.normalize_image() image = image_transform(image) # Perform the random scaling - if scale != 1.: + if scale != 1.0: image, junctions, line_map, valid_mask = random_scaling( - image, junctions, line_map, scale, - h_crop=h_crop, w_crop=w_crop) + image, junctions, line_map, scale, h_crop=h_crop, w_crop=w_crop + ) else: # Declare default valid mask (all ones) valid_mask = np.ones(image_size) @@ -534,20 +564,28 @@ class HolicityDataset(Dataset): to_tensor = transforms.ToTensor() # Check homographic augmentation - warp = (self.config["augmentation"]["homographic"]["enable"] - and disable_homoaug == False) + warp = ( + self.config["augmentation"]["homographic"]["enable"] + and disable_homoaug == False + ) if warp: homo_trans = self.get_homo_transform() # Perform homographic transform if H1 is None: - homo_outputs = homo_trans(image, junctions, line_map, - valid_mask=valid_mask) + homo_outputs = homo_trans( + image, junctions, line_map, valid_mask=valid_mask + ) else: homo_outputs = homo_trans( - image, junctions, line_map, homo=H1, scale=H1_scale, - valid_mask=valid_mask) + image, + junctions, + line_map, + homo=H1, + scale=H1_scale, + valid_mask=valid_mask, + ) homography_mat = homo_outputs["homo"] - + # Give the warp of the other view if H1 is None: H1 = homo_outputs["homo"] @@ -555,8 +593,8 @@ class HolicityDataset(Dataset): # Sample points along each line segments for the descriptor if desc_training: line_points, line_indices = self.get_line_points( - junctions, line_map, H1=H1, H2=H2, - img_size=image_size, warp=warp) + junctions, line_map, H1=H1, H2=H2, img_size=image_size, warp=warp + ) # Record the warped results if warp: @@ -565,52 +603,59 @@ class HolicityDataset(Dataset): line_map = homo_outputs["line_map"] valid_mask = homo_outputs["valid_mask"] # Same for pos and neg heatmap = homo_outputs["warped_heatmap"] - + # Optionally put warping information first. 
if not numpy: - outputs["homography_mat"] = to_tensor( - homography_mat).to(torch.float32)[0, ...] + outputs["homography_mat"] = to_tensor(homography_mat).to(torch.float32)[ + 0, ... + ] else: outputs["homography_mat"] = homography_mat.astype(np.float32) junction_map = self.junc_to_junc_map(junctions, image_size) - + if not numpy: - outputs.update({ - "image": to_tensor(image), - "junctions": to_tensor(junctions).to(torch.float32)[0, ...], - "junction_map": to_tensor(junction_map).to(torch.int), - "line_map": to_tensor(line_map).to(torch.int32)[0, ...], - "heatmap": to_tensor(heatmap).to(torch.int32), - "valid_mask": to_tensor(valid_mask).to(torch.int32) - }) + outputs.update( + { + "image": to_tensor(image), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map": to_tensor(line_map).to(torch.int32)[0, ...], + "heatmap": to_tensor(heatmap).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32), + } + ) if desc_training: - outputs.update({ - "line_points": to_tensor( - line_points).to(torch.float32)[0], - "line_indices": torch.tensor(line_indices, - dtype=torch.int) - }) + outputs.update( + { + "line_points": to_tensor(line_points).to(torch.float32)[0], + "line_indices": torch.tensor(line_indices, dtype=torch.int), + } + ) else: - outputs.update({ - "image": image, - "junctions": junctions.astype(np.float32), - "junction_map": junction_map.astype(np.int32), - "line_map": line_map.astype(np.int32), - "heatmap": heatmap.astype(np.int32), - "valid_mask": valid_mask.astype(np.int32) - }) + outputs.update( + { + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + "line_map": line_map.astype(np.int32), + "heatmap": heatmap.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32), + } + ) if desc_training: - outputs.update({ - "line_points": line_points.astype(np.float32), - "line_indices": line_indices.astype(int) - }) - + outputs.update( + { + "line_points": line_points.astype(np.float32), + "line_indices": line_indices.astype(int), + } + ) + return outputs - - def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.): - """ Train preprocessing for paired data for the exported labels - for descriptor training. 
""" + + def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.0): + """Train preprocessing for paired data for the exported labels + for descriptor training.""" outputs = {} # Define the random crop for scaling if necessary @@ -622,51 +667,66 @@ class HolicityDataset(Dataset): h_crop = np.random.randint(H_scale - H) if W_scale > W: w_crop = np.random.randint(W_scale - W) - + # Sample ref homography first homo_config = self.config["augmentation"]["homographic"]["params"] image_shape = self.config["preprocessing"]["resize"] - ref_H, ref_scale = homoaug.sample_homography(image_shape, - **homo_config) + ref_H, ref_scale = homoaug.sample_homography(image_shape, **homo_config) # Data for target view (All augmentation) target_data = self.train_preprocessing_exported( - data, numpy=numpy, desc_training=True, H1=None, H2=ref_H, - scale=scale, h_crop=h_crop, w_crop=w_crop) + data, + numpy=numpy, + desc_training=True, + H1=None, + H2=ref_H, + scale=scale, + h_crop=h_crop, + w_crop=w_crop, + ) # Data for reference view (No homographical augmentation) ref_data = self.train_preprocessing_exported( - data, numpy=numpy, desc_training=True, H1=ref_H, - H1_scale=ref_scale, H2=target_data['homography_mat'].numpy(), - scale=scale, h_crop=h_crop, w_crop=w_crop) + data, + numpy=numpy, + desc_training=True, + H1=ref_H, + H1_scale=ref_scale, + H2=target_data["homography_mat"].numpy(), + scale=scale, + h_crop=h_crop, + w_crop=w_crop, + ) # Spread ref data for key, val in ref_data.items(): outputs["ref_" + key] = val - + # Spread target data for key, val in target_data.items(): outputs["target_" + key] = val - + return outputs def test_preprocessing_exported(self, data, numpy=False): - """ Test preprocessing for the exported labels. """ + """Test preprocessing for the exported labels.""" data = copy.deepcopy(data) # Fetch the corresponding entries image = data["image"] junctions = data["junctions"] - line_map = data["line_map"] + line_map = data["line_map"] image_size = image.shape[:2] # Resize the image before photometric and homographical augmentations - if not(list(image_size) == self.config["preprocessing"]["resize"]): + if not (list(image_size) == self.config["preprocessing"]["resize"]): # Resize the image and the point location. 
- size_old = list(image.shape)[:2] # Only H and W dimensions + size_old = list(image.shape)[:2] # Only H and W dimensions image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) # # In HW format @@ -676,7 +736,7 @@ class HolicityDataset(Dataset): # Optionally convert the image to grayscale if self.config["gray_scale"]: - image = (color.rgb2gray(image) * 255.).astype(np.uint8) + image = (color.rgb2gray(image) * 255.0).astype(np.uint8) # Still need to normalize image image_transform = photoaug.normalize_image() @@ -686,7 +746,7 @@ class HolicityDataset(Dataset): junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) image_size = image.shape[:2] heatmap = get_line_heatmap(junctions_xy, line_map, image_size) - + # Declare default valid mask (all ones) valid_mask = np.ones(image_size) @@ -701,7 +761,7 @@ class HolicityDataset(Dataset): "junction_map": to_tensor(junction_map).to(torch.int), "line_map": to_tensor(line_map).to(torch.int32)[0, ...], "heatmap": to_tensor(heatmap).to(torch.int32), - "valid_mask": to_tensor(valid_mask).to(torch.int32) + "valid_mask": to_tensor(valid_mask).to(torch.int32), } else: outputs = { @@ -710,38 +770,36 @@ class HolicityDataset(Dataset): "junction_map": junction_map.astype(np.int32), "line_map": line_map.astype(np.int32), "heatmap": heatmap.astype(np.int32), - "valid_mask": valid_mask.astype(np.int32) + "valid_mask": valid_mask.astype(np.int32), } - + return outputs def __len__(self): return self.dataset_length - + def get_data_from_key(self, file_key): - """ Get data from file_key. """ + """Get data from file_key.""" # Check key exists if not file_key in self.filename_dataset.keys(): - raise ValueError( - "[Error] the specified key is not in the dataset.") - + raise ValueError("[Error] the specified key is not in the dataset.") + # Get the data paths data_path = self.filename_dataset[file_key] # Read in the image and npz labels data = self.get_data_from_path(data_path) # Perform transform and augmentation - if (self.mode == "train" - or self.config["add_augmentation_to_all_splits"]): + if self.mode == "train" or self.config["add_augmentation_to_all_splits"]: data = self.train_preprocessing(data, numpy=True) else: data = self.test_preprocessing(data, numpy=True) - + # Add file key to the output data["file_key"] = file_key - + return data - + def __getitem__(self, idx): """Return data file_key: str, keys used to retrieve data from the filename dataset. @@ -761,27 +819,25 @@ class HolicityDataset(Dataset): if self.gt_source: with h5py.File(self.gt_source, "r") as f: exported_label = parse_h5_data(f[file_key]) - + data["junctions"] = exported_label["junctions"] data["line_map"] = exported_label["line_map"] - + # Perform transform and augmentation return_type = self.config.get("return_type", "single") if self.gt_source is None: # For export only data = self.export_preprocessing(data) - elif (self.mode == "train" - or self.config["add_augmentation_to_all_splits"]): + elif self.mode == "train" or self.config["add_augmentation_to_all_splits"]: # Perform random scaling first if self.config["augmentation"]["random_scaling"]["enable"]: scale_range = self.config["augmentation"]["random_scaling"]["range"] # Decide the scaling scale = np.random.uniform(min(scale_range), max(scale_range)) else: - scale = 1. 
+ scale = 1.0 if self.mode == "train" and return_type == "paired_desc": - data = self.preprocessing_exported_paired_desc(data, - scale=scale) + data = self.preprocessing_exported_paired_desc(data, scale=scale) else: data = self.train_preprocessing_exported(data, scale=scale) else: @@ -789,9 +845,8 @@ class HolicityDataset(Dataset): data = self.preprocessing_exported_paired_desc(data) else: data = self.test_preprocessing_exported(data) - + # Add file key to the output data["file_key"] = file_key - - return data + return data diff --git a/third_party/SOLD2/sold2/dataset/merge_dataset.py b/third_party/SOLD2/sold2/dataset/merge_dataset.py index 178d3822d56639a49a99f68e392330e388fa8fc3..1f6395873dcfdea0c35898eefbf4c74a8cfac7a1 100644 --- a/third_party/SOLD2/sold2/dataset/merge_dataset.py +++ b/third_party/SOLD2/sold2/dataset/merge_dataset.py @@ -14,23 +14,24 @@ class MergeDataset(Dataset): # Initialize the datasets self._datasets = [] spec_config = deepcopy(config) - for i, d in enumerate(config['datasets']): - spec_config['dataset_name'] = d - spec_config['gt_source_train'] = config['gt_source_train'][i] - spec_config['gt_source_test'] = config['gt_source_test'][i] + for i, d in enumerate(config["datasets"]): + spec_config["dataset_name"] = d + spec_config["gt_source_train"] = config["gt_source_train"][i] + spec_config["gt_source_test"] = config["gt_source_test"][i] if d == "wireframe": self._datasets.append(WireframeDataset(mode, spec_config)) elif d == "holicity": - spec_config['train_split'] = config['train_splits'][i] + spec_config["train_split"] = config["train_splits"][i] self._datasets.append(HolicityDataset(mode, spec_config)) else: - raise ValueError("Unknown dataset: " + d) + raise ValueError("Unknown dataset: " + d) + + self._weights = config["weights"] - self._weights = config['weights'] - def __getitem__(self, item): - dataset = self._datasets[np.random.choice( - range(len(self._datasets)), p=self._weights)] + dataset = self._datasets[ + np.random.choice(range(len(self._datasets)), p=self._weights) + ] return dataset[np.random.randint(len(dataset))] def __len__(self): diff --git a/third_party/SOLD2/sold2/dataset/synthetic_dataset.py b/third_party/SOLD2/sold2/dataset/synthetic_dataset.py index cf5f11e5407e65887f4995291156f7cc361843d1..4a1dab47bd81ec831554ba42a635a350ef7a73dc 100644 --- a/third_party/SOLD2/sold2/dataset/synthetic_dataset.py +++ b/third_party/SOLD2/sold2/dataset/synthetic_dataset.py @@ -25,9 +25,8 @@ from ..misc.train_utils import parse_h5_data def synthetic_collate_fn(batch): - """ Customized collate_fn. """ - batch_keys = ["image", "junction_map", "heatmap", - "valid_mask", "homography"] + """Customized collate_fn.""" + batch_keys = ["image", "junction_map", "heatmap", "valid_mask", "homography"] list_keys = ["junctions", "line_map", "file_key"] outputs = {} @@ -36,27 +35,31 @@ def synthetic_collate_fn(batch): list_match = sum([_ in data_key for _ in list_keys]) # print(batch_match, list_match) if batch_match > 0 and list_match == 0: - outputs[data_key] = torch_loader.default_collate([b[data_key] - for b in batch]) + outputs[data_key] = torch_loader.default_collate( + [b[data_key] for b in batch] + ) elif batch_match == 0 and list_match > 0: outputs[data_key] = [b[data_key] for b in batch] elif batch_match == 0 and list_match == 0: continue else: raise ValueError( - "[Error] A key matches batch keys and list keys simultaneously.") + "[Error] A key matches batch keys and list keys simultaneously." 
+ ) return outputs class SyntheticShapes(Dataset): - """ Dataset of synthetic shapes. """ + """Dataset of synthetic shapes.""" + # Initialize the dataset def __init__(self, mode="train", config=None): super(SyntheticShapes, self).__init__() if not mode in ["train", "val", "test"]: raise ValueError( - "[Error] Supported dataset modes are 'train', 'val', and 'test'.") + "[Error] Supported dataset modes are 'train', 'val', and 'test'." + ) self.mode = mode # Get configuration @@ -67,14 +70,14 @@ class SyntheticShapes(Dataset): # Set all available primitives self.available_primitives = [ - 'draw_lines', - 'draw_polygon', - 'draw_multiple_polygons', - 'draw_star', - 'draw_checkerboard_multiseg', - 'draw_stripes_multiseg', - 'draw_cube', - 'gaussian_noise' + "draw_lines", + "draw_polygon", + "draw_multiple_polygons", + "draw_star", + "draw_checkerboard_multiseg", + "draw_stripes_multiseg", + "draw_cube", + "gaussian_noise", ] # Some cache setting @@ -88,11 +91,14 @@ class SyntheticShapes(Dataset): self.print_dataset_info() # Initialize h5 file handle - self.dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5") - + self.dataset_path = os.path.join( + cfg.synthetic_dataroot, self.dataset_name + ".h5" + ) + # Fix the random seed for torch and numpy in testing mode - if ((self.mode == "val" or self.mode == "test") - and self.config["add_augmentation_to_all_splits"]): + if (self.mode == "val" or self.mode == "test") and self.config[ + "add_augmentation_to_all_splits" + ]: seed = self.config.get("test_augmentation_seed", 200) np.random.seed(seed) torch.manual_seed(seed) @@ -104,7 +110,7 @@ class SyntheticShapes(Dataset): ## Dataset construction related methods ## ########################################## def construct_dataset(self): - """ Dataset constructor. """ + """Dataset constructor.""" # Check if the filename cache exists # If cache exists, load from cache if self._check_dataset_cache(): @@ -117,13 +123,14 @@ class SyntheticShapes(Dataset): print("\t All files exist!") # If not, need to re-export the synthetic dataset else: - print("\t Some files are missing. Re-export the synthetic shape dataset.") + print( + "\t Some files are missing. Re-export the synthetic shape dataset." + ) self.export_synthetic_shapes() print("\t Initialize filename dataset") filename_dataset, datapoints = self.get_filename_dataset() print("\t Create filename dataset cache...") - self.create_filename_dataset_cache(filename_dataset, - datapoints) + self.create_filename_dataset_cache(filename_dataset, datapoints) # If not, initialize dataset from scratch else: @@ -135,7 +142,9 @@ class SyntheticShapes(Dataset): # If export dataset does not exist, export from scratch else: - print("\t Synthetic dataset does not exist. Export the synthetic dataset.") + print( + "\t Synthetic dataset does not exist. Export the synthetic dataset." + ) self.export_synthetic_shapes() print("\t Initialize filename dataset") @@ -146,7 +155,7 @@ class SyntheticShapes(Dataset): return filename_dataset, datapoints def get_cache_name(self): - """ Get cache name from dataset config / default config. """ + """Get cache name from dataset config / default config.""" if self.config["dataset_name"] is None: dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode else: @@ -157,7 +166,7 @@ class SyntheticShapes(Dataset): return cache_name def get_dataset_name(self): - """Get dataset name from dataset config / default config. 
""" + """Get dataset name from dataset config / default config.""" if self.config["dataset_name"] is None: dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode else: @@ -166,7 +175,7 @@ class SyntheticShapes(Dataset): return dataset_name def get_filename_dataset_from_cache(self): - """ Get filename dataset from cache. """ + """Get filename dataset from cache.""" # Load from the pkl cache cache_file_path = os.path.join(self.cache_path, self.cache_name) with open(cache_file_path, "rb") as f: @@ -175,10 +184,9 @@ class SyntheticShapes(Dataset): return data["filename_dataset"], data["datapoints"] def get_filename_dataset(self): - """ Get filename dataset from scratch. """ + """Get filename dataset from scratch.""" # Path to the exported dataset - dataset_path = os.path.join(cfg.synthetic_dataroot, - self.dataset_name + ".h5") + dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5") filename_dataset = {} datapoints = [] @@ -187,8 +195,7 @@ class SyntheticShapes(Dataset): # Iterate through all the primitives for prim_name in f.keys(): filenames = sorted(f[prim_name].keys()) - filenames_full = [os.path.join(prim_name, _) - for _ in filenames] + filenames_full = [os.path.join(prim_name, _) for _ in filenames] filename_dataset[prim_name] = filenames_full datapoints += filenames_full @@ -196,34 +203,30 @@ class SyntheticShapes(Dataset): return filename_dataset, datapoints def create_filename_dataset_cache(self, filename_dataset, datapoints): - """ Create filename dataset cache for faster initialization. """ + """Create filename dataset cache for faster initialization.""" # Check cache path exists if not os.path.exists(self.cache_path): os.makedirs(self.cache_path) cache_file_path = os.path.join(self.cache_path, self.cache_name) - data = { - "filename_dataset": filename_dataset, - "datapoints": datapoints - } + data = {"filename_dataset": filename_dataset, "datapoints": datapoints} with open(cache_file_path, "wb") as f: pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) def export_synthetic_shapes(self): - """ Export synthetic shapes to disk. """ + """Export synthetic shapes to disk.""" # Set the global random state for data generation - synthetic_util.set_random_state(np.random.RandomState( - self.config["generation"]["random_seed"])) + synthetic_util.set_random_state( + np.random.RandomState(self.config["generation"]["random_seed"]) + ) # Define the export path - dataset_path = os.path.join(cfg.synthetic_dataroot, - self.dataset_name + ".h5") + dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5") # Open h5py file with h5py.File(dataset_path, "w", libver="latest") as f: # Iterate through all types of shape - primitives = self.parse_drawing_primitives( - self.config["primitives"]) + primitives = self.parse_drawing_primitives(self.config["primitives"]) split_size = self.config["generation"]["split_sizes"][self.mode] for prim in primitives: # Create h5 group @@ -234,22 +237,23 @@ class SyntheticShapes(Dataset): f.swmr_mode = True def export_single_primitive(self, primitive, split_size, group): - """ Export single primitive. 
""" + """Export single primitive.""" # Check if the primitive is valid or not if primitive not in self.available_primitives: - raise ValueError( - "[Error]: %s is not a supported primitive" % primitive) + raise ValueError("[Error]: %s is not a supported primitive" % primitive) # Set the random seed - synthetic_util.set_random_state(np.random.RandomState( - self.config["generation"]["random_seed"])) + synthetic_util.set_random_state( + np.random.RandomState(self.config["generation"]["random_seed"]) + ) # Generate shapes print("\t Generating %s ..." % primitive) for idx in tqdm(range(split_size), ascii=True): # Generate background image image = synthetic_util.generate_background( - self.config['generation']['image_size'], - **self.config['generation']['params']['generate_background']) + self.config["generation"]["image_size"], + **self.config["generation"]["params"]["generate_background"] + ) # Generate points drawing_func = getattr(synthetic_util, primitive) @@ -260,14 +264,21 @@ class SyntheticShapes(Dataset): min_label_len = self.config["generation"]["min_label_len"] # Some only take min_label_len, and gaussian noises take nothing - if primitive in ["draw_lines", "draw_polygon", - "draw_multiple_polygons", "draw_star"]: - data = drawing_func(image, min_len=min_len, - min_label_len=min_label_len, **kwarg) - elif primitive in ["draw_checkerboard_multiseg", - "draw_stripes_multiseg", "draw_cube"]: - data = drawing_func(image, min_label_len=min_label_len, - **kwarg) + if primitive in [ + "draw_lines", + "draw_polygon", + "draw_multiple_polygons", + "draw_star", + ]: + data = drawing_func( + image, min_len=min_len, min_label_len=min_label_len, **kwarg + ) + elif primitive in [ + "draw_checkerboard_multiseg", + "draw_stripes_multiseg", + "draw_cube", + ]: + data = drawing_func(image, min_label_len=min_label_len, **kwarg) else: data = drawing_func(image, **kwarg) @@ -284,21 +295,24 @@ class SyntheticShapes(Dataset): image = cv2.GaussianBlur(image, (blur_size, blur_size), 0) # Resize the image and the point location. 
- points = (points - * np.array(self.config['preprocessing']['resize'], - np.float) - / np.array(self.config['generation']['image_size'], - np.float)) + points = ( + points + * np.array(self.config["preprocessing"]["resize"], np.float) + / np.array(self.config["generation"]["image_size"], np.float) + ) image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) # Generate the line heatmap after post-processing junctions = np.flip(np.round(points).astype(np.int32), axis=1) - heatmap = (synthetic_util.get_line_heatmap( - junctions, line_map, - size=image.shape) * 255.).astype(np.uint8) + heatmap = ( + synthetic_util.get_line_heatmap(junctions, line_map, size=image.shape) + * 255.0 + ).astype(np.uint8) # Record the data in group num_pad = math.ceil(math.log10(split_size)) + 1 @@ -306,17 +320,13 @@ class SyntheticShapes(Dataset): file_group = group.create_group(file_key_name) # Store data - file_group.create_dataset("points", data=points, - compression="gzip") - file_group.create_dataset("image", data=image, - compression="gzip") - file_group.create_dataset("line_map", data=line_map, - compression="gzip") - file_group.create_dataset("heatmap", data=heatmap, - compression="gzip") + file_group.create_dataset("points", data=points, compression="gzip") + file_group.create_dataset("image", data=image, compression="gzip") + file_group.create_dataset("line_map", data=line_map, compression="gzip") + file_group.create_dataset("heatmap", data=heatmap, compression="gzip") def get_default_config(self): - """ Get default configuration of the dataset. """ + """Get default configuration of the dataset.""" # Initialize the default configuration self.default_config = { "dataset_name": "synthetic_shape", @@ -324,43 +334,43 @@ class SyntheticShapes(Dataset): "add_augmentation_to_all_splits": False, # Shape generation configuration "generation": { - "split_sizes": {'train': 10000, 'val': 400, 'test': 500}, + "split_sizes": {"train": 10000, "val": 400, "test": 500}, "random_seed": 10, "image_size": [960, 1280], "min_len": 0.09, "min_label_len": 0.1, - 'params': { - 'generate_background': { - 'min_kernel_size': 150, 'max_kernel_size': 500, - 'min_rad_ratio': 0.02, 'max_rad_ratio': 0.031}, - 'draw_stripes': {'transform_params': (0.1, 0.1)}, - 'draw_multiple_polygons': {'kernel_boundaries': (50, 100)} + "params": { + "generate_background": { + "min_kernel_size": 150, + "max_kernel_size": 500, + "min_rad_ratio": 0.02, + "max_rad_ratio": 0.031, + }, + "draw_stripes": {"transform_params": (0.1, 0.1)}, + "draw_multiple_polygons": {"kernel_boundaries": (50, 100)}, }, }, # Date preprocessing configuration. - "preprocessing": { - "resize": [240, 320], - "blur_size": 11 - }, - 'augmentation': { - 'photometric': { - 'enable': False, - 'primitives': 'all', - 'params': {}, - 'random_order': True, + "preprocessing": {"resize": [240, 320], "blur_size": 11}, + "augmentation": { + "photometric": { + "enable": False, + "primitives": "all", + "params": {}, + "random_order": True, }, - 'homographic': { - 'enable': False, - 'params': {}, - 'valid_border_margin': 0, + "homographic": { + "enable": False, + "params": {}, + "valid_border_margin": 0, }, - } + }, } return self.default_config def parse_drawing_primitives(self, names): - """ Parse the primitives in config to list of primitive names. 
""" + """Parse the primitives in config to list of primitive names.""" if names == "all": p = self.available_primitives else: @@ -375,42 +385,42 @@ class SyntheticShapes(Dataset): @staticmethod def get_padded_filename(num_pad, idx): - """ Get the padded filename using adaptive padding. """ + """Get the padded filename using adaptive padding.""" file_len = len("%d" % (idx)) filename = "0" * (num_pad - file_len) + "%d" % (idx) return filename def print_dataset_info(self): - """ Print dataset info. """ + """Print dataset info.""" print("\t ---------Summary------------------") print("\t Dataset mode: \t\t %s" % self.mode) print("\t Number of primitive: \t %d" % len(self.filename_dataset.keys())) print("\t Number of data: \t %d" % len(self.datapoints)) print("\t ----------------------------------") - + ######################### ## Pytorch related API ## ######################### def get_data_from_datapoint(self, datapoint, reader=None): - """ Get data given the datapoint - (keyname of the h5 dataset e.g. "draw_lines/0000.h5"). """ + """Get data given the datapoint + (keyname of the h5 dataset e.g. "draw_lines/0000.h5").""" # Check if the datapoint is valid if not datapoint in self.datapoints: raise ValueError( - "[Error] The specified datapoint is not in available datapoints.") + "[Error] The specified datapoint is not in available datapoints." + ) # Get data from h5 dataset if reader is None: - raise ValueError( - "[Error] The reader must be provided in __getitem__.") + raise ValueError("[Error] The reader must be provided in __getitem__.") else: data = reader[datapoint] return parse_h5_data(data) def get_data_from_signature(self, primitive_name, index): - """ Get data given the primitive name and index ("draw_lines", 10) """ + """Get data given the primitive name and index ("draw_lines", 10)""" # Check the primitive name and index self._check_primitive_and_index(primitive_name, index) @@ -420,40 +430,41 @@ class SyntheticShapes(Dataset): return self.get_data_from_datapoint(datapoint) def parse_transforms(self, names, all_transforms): - trans = all_transforms if (names == 'all') \ + trans = ( + all_transforms + if (names == "all") else (names if isinstance(names, list) else [names]) + ) assert set(trans) <= set(all_transforms) return trans def get_photo_transform(self): - """ Get list of photometric transforms (according to the config). """ + """Get list of photometric transforms (according to the config).""" # Get the photometric transform config photo_config = self.config["augmentation"]["photometric"] if not photo_config["enable"]: - raise ValueError( - "[Error] Photometric augmentation is not enabled.") - + raise ValueError("[Error] Photometric augmentation is not enabled.") + # Parse photometric transforms - trans_lst = self.parse_transforms(photo_config["primitives"], - photoaug.available_augmentations) - trans_config_lst = [photo_config["params"].get(p, {}) - for p in trans_lst] + trans_lst = self.parse_transforms( + photo_config["primitives"], photoaug.available_augmentations + ) + trans_config_lst = [photo_config["params"].get(p, {}) for p in trans_lst] # List of photometric augmentation photometric_trans_lst = [ - getattr(photoaug, trans)(**conf) \ + getattr(photoaug, trans)(**conf) for (trans, conf) in zip(trans_lst, trans_config_lst) ] return photometric_trans_lst - + def get_homo_transform(self): - """ Get homographic transforms (according to the config). 
""" + """Get homographic transforms (according to the config).""" # Get homographic transforms for image homo_config = self.config["augmentation"]["homographic"]["params"] if not self.config["augmentation"]["homographic"]["enable"]: - raise ValueError( - "[Error] Homographic augmentation is not enabled") + raise ValueError("[Error] Homographic augmentation is not enabled") # Parse the homographic transforms # ToDo: use the shape from the config @@ -464,33 +475,35 @@ class SyntheticShapes(Dataset): min_label_tmp = self.config["generation"]["min_label_len"] except: min_label_tmp = None - + # float label len => fraction - if isinstance(min_label_tmp, float): # Skip if not provided + if isinstance(min_label_tmp, float): # Skip if not provided min_label_len = min_label_tmp * min(image_shape) # int label len => length in pixel elif isinstance(min_label_tmp, int): - scale_ratio = (self.config["preprocessing"]["resize"] - / self.config["generation"]["image_size"][0]) - min_label_len = (self.config["generation"]["min_label_len"] - * scale_ratio) + scale_ratio = ( + self.config["preprocessing"]["resize"] + / self.config["generation"]["image_size"][0] + ) + min_label_len = self.config["generation"]["min_label_len"] * scale_ratio # if none => no restriction else: min_label_len = 0 - + # Initialize the transform homographic_trans = homoaug.homography_transform( - image_shape, homo_config, 0, min_label_len) + image_shape, homo_config, 0, min_label_len + ) return homographic_trans @staticmethod def junc_to_junc_map(junctions, image_size): - """ Convert junction points to junction maps. """ + """Convert junction points to junction maps.""" junctions = np.round(junctions).astype(np.int) # Clip the boundary by image size - junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1) - junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1) + junctions[:, 0] = np.clip(junctions[:, 0], 0.0, image_size[0] - 1) + junctions[:, 1] = np.clip(junctions[:, 1], 0.0, image_size[1] - 1) # Create junction map junc_map = np.zeros([image_size[0], image_size[1]]) @@ -499,7 +512,7 @@ class SyntheticShapes(Dataset): return junc_map[..., None].astype(np.int) def train_preprocessing(self, data, disable_homoaug=False): - """ Training preprocessing. """ + """Training preprocessing.""" # Fetch corresponding entries image = data["image"] junctions = data["points"] @@ -509,29 +522,32 @@ class SyntheticShapes(Dataset): # Resize the image before the photometric and homographic transforms # Check if we need to do the resizing - if not(list(image.shape) == self.config["preprocessing"]["resize"]): + if not (list(image.shape) == self.config["preprocessing"]["resize"]): # Resize the image and the point location. 
size_old = list(image.shape) image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) junctions = ( junctions - * np.array(self.config['preprocessing']['resize'], np.float) - / np.array(size_old, np.float)) + * np.array(self.config["preprocessing"]["resize"], np.float) + / np.array(size_old, np.float) + ) # Generate the line heatmap after post-processing - junctions_xy = np.flip(np.round(junctions).astype(np.int32), - axis=1) - heatmap = synthetic_util.get_line_heatmap(junctions_xy, line_map, - size=image.shape) - heatmap = (heatmap * 255.).astype(np.uint8) + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + heatmap = synthetic_util.get_line_heatmap( + junctions_xy, line_map, size=image.shape + ) + heatmap = (heatmap * 255.0).astype(np.uint8) # Update image size image_size = image.shape[:2] - + # Declare default valid mask (all ones) valid_mask = np.ones(image_size) @@ -544,7 +560,8 @@ class SyntheticShapes(Dataset): ### Image transform ### np.random.shuffle(photo_trans_lst) image_transform = transforms.Compose( - photo_trans_lst + [photoaug.normalize_image()]) + photo_trans_lst + [photoaug.normalize_image()] + ) else: image_transform = photoaug.normalize_image() image = image_transform(image) @@ -554,40 +571,46 @@ class SyntheticShapes(Dataset): # Convert to tensor and return the results to_tensor = transforms.ToTensor() # Check homographic augmentation - if (self.config["augmentation"]["homographic"]["enable"] - and disable_homoaug == False): + if ( + self.config["augmentation"]["homographic"]["enable"] + and disable_homoaug == False + ): homo_trans = self.get_homo_transform() # Perform homographic transform homo_outputs = homo_trans(image, junctions, line_map) # Record the warped results - junctions = homo_outputs["junctions"] # Should be HW format + junctions = homo_outputs["junctions"] # Should be HW format image = homo_outputs["warped_image"] line_map = homo_outputs["line_map"] heatmap = homo_outputs["warped_heatmap"] valid_mask = homo_outputs["valid_mask"] # Same for pos and neg homography_mat = homo_outputs["homo"] - + # Optionally put warpping information first. - outputs["homography_mat"] = to_tensor( - homography_mat).to(torch.float32)[0, ...] + outputs["homography_mat"] = to_tensor(homography_mat).to(torch.float32)[ + 0, ... + ] junction_map = self.junc_to_junc_map(junctions, image_size) - outputs.update({ - "image": to_tensor(image), - "junctions": to_tensor(np.ascontiguousarray( - junctions).copy()).to(torch.float32)[0, ...], - "junction_map": to_tensor(junction_map).to(torch.int), - "line_map": to_tensor(line_map).to(torch.int32)[0, ...], - "heatmap": to_tensor(heatmap).to(torch.int32), - "valid_mask": to_tensor(valid_mask).to(torch.int32), - }) + outputs.update( + { + "image": to_tensor(image), + "junctions": to_tensor(np.ascontiguousarray(junctions).copy()).to( + torch.float32 + )[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map": to_tensor(line_map).to(torch.int32)[0, ...], + "heatmap": to_tensor(heatmap).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32), + } + ) return outputs def test_preprocessing(self, data): - """ Test preprocessing. 
""" + """Test preprocessing.""" # Fetch corresponding entries image = data["image"] points = data["points"] @@ -600,20 +623,24 @@ class SyntheticShapes(Dataset): # Resize the image and the point location. size_old = list(image.shape) image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) - points = (points - * np.array(self.config['preprocessing']['resize'], - np.float) - / np.array(size_old, np.float)) + points = ( + points + * np.array(self.config["preprocessing"]["resize"], np.float) + / np.array(size_old, np.float) + ) # Generate the line heatmap after post-processing junctions = np.flip(np.round(points).astype(np.int32), axis=1) - heatmap = synthetic_util.get_line_heatmap(junctions, line_map, - size=image.shape) - heatmap = (heatmap * 255.).astype(np.uint8) + heatmap = synthetic_util.get_line_heatmap( + junctions, line_map, size=image.shape + ) + heatmap = (heatmap * 255.0).astype(np.uint8) # Update image size image_size = image.shape[:2] @@ -638,7 +665,7 @@ class SyntheticShapes(Dataset): "junction_map": junction_map, "line_map": line_map, "heatmap": heatmap, - "valid_mask": valid_mask + "valid_mask": valid_mask, } def __getitem__(self, index): @@ -649,8 +676,7 @@ class SyntheticShapes(Dataset): data = self.get_data_from_datapoint(datapoint, reader) # Apply different transforms in different mod. - if (self.mode == "train" - or self.config["add_augmentation_to_all_splits"]): + if self.mode == "train" or self.config["add_augmentation_to_all_splits"]: return_type = self.config.get("return_type", "single") data = self.train_preprocessing(data) else: @@ -665,7 +691,7 @@ class SyntheticShapes(Dataset): ## Some other methods ## ######################## def _check_dataset_cache(self): - """ Check if dataset cache exists. """ + """Check if dataset cache exists.""" cache_file_path = os.path.join(self.cache_path, self.cache_name) if os.path.exists(cache_file_path): return True @@ -673,7 +699,7 @@ class SyntheticShapes(Dataset): return False def _check_export_dataset(self): - """ Check if exported dataset exists. """ + """Check if exported dataset exists.""" dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name) if os.path.exists(dataset_path) and len(os.listdir(dataset_path)) > 0: return True @@ -681,32 +707,30 @@ class SyntheticShapes(Dataset): return False def _check_file_existence(self, filename_dataset): - """ Check if all exported file exists. """ + """Check if all exported file exists.""" # Path to the exported dataset - dataset_path = os.path.join(cfg.synthetic_dataroot, - self.dataset_name + ".h5") + dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5") flag = True # Open the h5 dataset with h5py.File(dataset_path, "r") as f: # Iterate through all the primitives for prim_name in f.keys(): - if (len(filename_dataset[prim_name]) - != len(f[prim_name].keys())): + if len(filename_dataset[prim_name]) != len(f[prim_name].keys()): flag = False return flag def _check_primitive_and_index(self, primitive, index): - """ Check if the primitve and index are valid. 
""" + """Check if the primitve and index are valid.""" # Check primitives if not primitive in self.available_primitives: - raise ValueError( - "[Error] The primitive is not in available primitives.") + raise ValueError("[Error] The primitive is not in available primitives.") prim_len = len(self.filename_dataset[primitive]) # Check the index if not index < prim_len: raise ValueError( "[Error] The index exceeds the total file counts %d for %s" - % (prim_len, primitive)) + % (prim_len, primitive) + ) diff --git a/third_party/SOLD2/sold2/dataset/synthetic_util.py b/third_party/SOLD2/sold2/dataset/synthetic_util.py index af009e0ce7e91391e31d7069064ae6121aa84cc0..63e41c5bbcadd4a1a633a2b33392dc6d4fd088ff 100644 --- a/third_party/SOLD2/sold2/dataset/synthetic_util.py +++ b/third_party/SOLD2/sold2/dataset/synthetic_util.py @@ -17,8 +17,8 @@ def set_random_state(state): def get_random_color(background_color): - """ Output a random scalar in grayscale with a least a small contrast - with the background color. """ + """Output a random scalar in grayscale with a least a small contrast + with the background color.""" color = random_state.randint(256) if abs(color - background_color) < 30: # not enough contrast color = (color + 128) % 256 @@ -26,7 +26,7 @@ def get_random_color(background_color): def get_different_color(previous_colors, min_dist=50, max_count=20): - """ Output a color that contrasts with the previous colors. + """Output a color that contrasts with the previous colors. Parameters: previous_colors: np.array of the previous colors min_dist: the difference between the new color and @@ -42,7 +42,7 @@ def get_different_color(previous_colors, min_dist=50, max_count=20): def add_salt_and_pepper(img): - """ Add salt and pepper noise to an image. """ + """Add salt and pepper noise to an image.""" noise = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8) cv.randu(noise, 0, 255) black = noise < 30 @@ -53,10 +53,15 @@ def add_salt_and_pepper(img): return np.empty((0, 2), dtype=np.int) -def generate_background(size=(960, 1280), nb_blobs=100, min_rad_ratio=0.01, - max_rad_ratio=0.05, min_kernel_size=50, - max_kernel_size=300): - """ Generate a customized background image. +def generate_background( + size=(960, 1280), + nb_blobs=100, + min_rad_ratio=0.01, + max_rad_ratio=0.05, + min_kernel_size=50, + max_kernel_size=300, +): + """Generate a customized background image. 
Parameters: size: size of the image nb_blobs: number of circles to draw @@ -71,22 +76,30 @@ def generate_background(size=(960, 1280), nb_blobs=100, min_rad_ratio=0.01, cv.threshold(img, random_state.randint(256), 255, cv.THRESH_BINARY, img) background_color = int(np.mean(img)) blobs = np.concatenate( - [random_state.randint(0, size[1], size=(nb_blobs, 1)), - random_state.randint(0, size[0], size=(nb_blobs, 1))], axis=1) + [ + random_state.randint(0, size[1], size=(nb_blobs, 1)), + random_state.randint(0, size[0], size=(nb_blobs, 1)), + ], + axis=1, + ) for i in range(nb_blobs): col = get_random_color(background_color) - cv.circle(img, (blobs[i][0], blobs[i][1]), - np.random.randint(int(dim * min_rad_ratio), - int(dim * max_rad_ratio)), - col, -1) + cv.circle( + img, + (blobs[i][0], blobs[i][1]), + np.random.randint(int(dim * min_rad_ratio), int(dim * max_rad_ratio)), + col, + -1, + ) kernel_size = random_state.randint(min_kernel_size, max_kernel_size) cv.blur(img, (kernel_size, kernel_size), img) return img -def generate_custom_background(size, background_color, nb_blobs=3000, - kernel_boundaries=(50, 100)): - """ Generate a customized background to fill the shapes. +def generate_custom_background( + size, background_color, nb_blobs=3000, kernel_boundaries=(50, 100) +): + """Generate a customized background to fill the shapes. Parameters: background_color: average color of the background image nb_blobs: number of circles to draw @@ -95,20 +108,22 @@ def generate_custom_background(size, background_color, nb_blobs=3000, img = np.zeros(size, dtype=np.uint8) img = img + get_random_color(background_color) blobs = np.concatenate( - [np.random.randint(0, size[1], size=(nb_blobs, 1)), - np.random.randint(0, size[0], size=(nb_blobs, 1))], axis=1) + [ + np.random.randint(0, size[1], size=(nb_blobs, 1)), + np.random.randint(0, size[0], size=(nb_blobs, 1)), + ], + axis=1, + ) for i in range(nb_blobs): col = get_random_color(background_color) - cv.circle(img, (blobs[i][0], blobs[i][1]), - np.random.randint(20), col, -1) - kernel_size = np.random.randint(kernel_boundaries[0], - kernel_boundaries[1]) + cv.circle(img, (blobs[i][0], blobs[i][1]), np.random.randint(20), col, -1) + kernel_size = np.random.randint(kernel_boundaries[0], kernel_boundaries[1]) cv.blur(img, (kernel_size, kernel_size), img) return img def final_blur(img, kernel_size=(5, 5)): - """ Gaussian blur applied to an image. + """Gaussian blur applied to an image. Parameters: kernel_size: size of the kernel """ @@ -116,33 +131,39 @@ def final_blur(img, kernel_size=(5, 5)): def ccw(A, B, C, dim): - """ Check if the points are listed in counter-clockwise order. 
""" + """Check if the points are listed in counter-clockwise order.""" if dim == 2: # only 2 dimensions - return((C[:, 1] - A[:, 1]) * (B[:, 0] - A[:, 0]) - > (B[:, 1] - A[:, 1]) * (C[:, 0] - A[:, 0])) + return (C[:, 1] - A[:, 1]) * (B[:, 0] - A[:, 0]) > (B[:, 1] - A[:, 1]) * ( + C[:, 0] - A[:, 0] + ) else: # dim should be equal to 3 - return((C[:, 1, :] - A[:, 1, :]) - * (B[:, 0, :] - A[:, 0, :]) - > (B[:, 1, :] - A[:, 1, :]) - * (C[:, 0, :] - A[:, 0, :])) + return (C[:, 1, :] - A[:, 1, :]) * (B[:, 0, :] - A[:, 0, :]) > ( + B[:, 1, :] - A[:, 1, :] + ) * (C[:, 0, :] - A[:, 0, :]) def intersect(A, B, C, D, dim): - """ Return true if line segments AB and CD intersect """ - return np.any((ccw(A, C, D, dim) != ccw(B, C, D, dim)) & - (ccw(A, B, C, dim) != ccw(A, B, D, dim))) + """Return true if line segments AB and CD intersect""" + return np.any( + (ccw(A, C, D, dim) != ccw(B, C, D, dim)) + & (ccw(A, B, C, dim) != ccw(A, B, D, dim)) + ) def keep_points_inside(points, size): - """ Keep only the points whose coordinates are inside the dimensions of - the image of size 'size' """ - mask = (points[:, 0] >= 0) & (points[:, 0] < size[1]) &\ - (points[:, 1] >= 0) & (points[:, 1] < size[0]) + """Keep only the points whose coordinates are inside the dimensions of + the image of size 'size'""" + mask = ( + (points[:, 0] >= 0) + & (points[:, 0] < size[1]) + & (points[:, 1] >= 0) + & (points[:, 1] < size[0]) + ) return points[mask, :] def get_unique_junctions(segments, min_label_len): - """ Get unique junction points from line segments. """ + """Get unique junction points from line segments.""" # Get all junctions from segments junctions_all = np.concatenate((segments[:, :2], segments[:, 2:]), axis=0) if junctions_all.shape[0] == 0: @@ -159,7 +180,7 @@ def get_unique_junctions(segments, min_label_len): def get_line_map(points: np.ndarray, segments: np.ndarray) -> np.ndarray: - """ Get line map given the points and segment sets. """ + """Get line map given the points and segment sets.""" # create empty line map num_point = points.shape[0] line_map = np.zeros([num_point, num_point]) @@ -183,7 +204,7 @@ def get_line_map(points: np.ndarray, segments: np.ndarray) -> np.ndarray: def get_line_heatmap(junctions, line_map, size=[480, 640], thickness=1): - """ Get line heat map from junctions and line map. """ + """Get line heat map from junctions and line map.""" # Make sure that the thickness is 1 if not isinstance(thickness, int): thickness = int(thickness) @@ -195,7 +216,7 @@ def get_line_heatmap(junctions, line_map, size=[480, 640], thickness=1): # Initialize empty map heat_map = np.zeros(size) - if junctions.shape[0] > 0: # If empty, just return zero map + if junctions.shape[0] > 0: # If empty, just return zero map # Iterate through all the junctions for idx in range(junctions.shape[0]): # if no connectivity, just skip it @@ -209,13 +230,13 @@ def get_line_heatmap(junctions, line_map, size=[480, 640], thickness=1): point2 = junctions[idx2, :] # Draw line - cv.line(heat_map, tuple(point1), tuple(point2), 1., thickness) + cv.line(heat_map, tuple(point1), tuple(point2), 1.0, thickness) return heat_map def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32): - """ Draw random lines and output the positions of the pair of junctions + """Draw random lines and output the positions of the pair of junctions and line associativities. 
Parameters: nb_lines: maximal number of lines @@ -228,9 +249,9 @@ def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32): min_dim = min(img.shape) # Convert length constrain to pixel if given float number - if isinstance(min_len, float) and min_len <= 1.: + if isinstance(min_len, float) and min_len <= 1.0: min_len = int(min_dim * min_len) - if isinstance(min_label_len, float) and min_label_len <= 1.: + if isinstance(min_label_len, float) and min_label_len <= 1.0: min_label_len = int(min_dim * min_label_len) # Generate lines one by one @@ -258,10 +279,8 @@ def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32): # Only record the segments longer than min_label_len seg_len = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) if seg_len >= min_label_len: - segments = np.concatenate([segments, - np.array([[x1, y1, x2, y2]])], axis=0) - points = np.concatenate([points, - np.array([[x1, y1], [x2, y2]])], axis=0) + segments = np.concatenate([segments, np.array([[x1, y1, x2, y2]])], axis=0) + points = np.concatenate([points, np.array([[x1, y1], [x2, y2]])], axis=0) # If no line is drawn, recursively call the function if points.shape[0] == 0: @@ -270,19 +289,16 @@ def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32): # Get the line associativity map line_map = get_line_map(points, segments) - return { - "points": points, - "line_map": line_map - } + return {"points": points, "line_map": line_map} def check_segment_len(segments, min_len=32): - """ Check if one of the segments is too short (True means too short). """ + """Check if one of the segments is too short (True means too short).""" point1_vec = segments[:, :2] point2_vec = segments[:, 2:] diff = point1_vec - point2_vec - dist = np.sqrt(np.sum(diff ** 2, axis=1)) + dist = np.sqrt(np.sum(diff**2, axis=1)) if np.any(dist < min_len): return True else: @@ -290,7 +306,7 @@ def check_segment_len(segments, min_len=32): def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64): - """ Draw a polygon with a random number of corners and return the position + """Draw a polygon with a random number of corners and return the position of the junctions + line map. 
Parameters: max_sides: maximal number of sides + 1 @@ -303,31 +319,42 @@ def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64): y = random_state.randint(rad, img.shape[0] - rad) # Convert length constrain to pixel if given float number - if isinstance(min_len, float) and min_len <= 1.: + if isinstance(min_len, float) and min_len <= 1.0: min_len = int(min_dim * min_len) - if isinstance(min_label_len, float) and min_label_len <= 1.: + if isinstance(min_label_len, float) and min_label_len <= 1.0: min_label_len = int(min_dim * min_label_len) # Sample num_corners points inside the circle slices = np.linspace(0, 2 * math.pi, num_corners + 1) - angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i]) - for i in range(num_corners)] + angles = [ + slices[i] + random_state.rand() * (slices[i + 1] - slices[i]) + for i in range(num_corners) + ] points = np.array( - [[int(x + max(random_state.rand(), 0.4) * rad * math.cos(a)), - int(y + max(random_state.rand(), 0.4) * rad * math.sin(a))] - for a in angles]) + [ + [ + int(x + max(random_state.rand(), 0.4) * rad * math.cos(a)), + int(y + max(random_state.rand(), 0.4) * rad * math.sin(a)), + ] + for a in angles + ] + ) # Filter the points that are too close or that have an angle too flat - norms = [np.linalg.norm(points[(i-1) % num_corners, :] - - points[i, :]) for i in range(num_corners)] + norms = [ + np.linalg.norm(points[(i - 1) % num_corners, :] - points[i, :]) + for i in range(num_corners) + ] mask = np.array(norms) > 0.01 points = points[mask, :] num_corners = points.shape[0] - corner_angles = [angle_between_vectors(points[(i-1) % num_corners, :] - - points[i, :], - points[(i+1) % num_corners, :] - - points[i, :]) - for i in range(num_corners)] + corner_angles = [ + angle_between_vectors( + points[(i - 1) % num_corners, :] - points[i, :], + points[(i + 1) % num_corners, :] - points[i, :], + ) + for i in range(num_corners) + ] mask = np.array(corner_angles) < (2 * math.pi / 3) points = points[mask, :] num_corners = points.shape[0] @@ -349,8 +376,7 @@ def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64): seg_len = np.sqrt(np.sum((p1 - p2) ** 2)) if seg_len >= min_label_len: segments = np.concatenate((segments, segment[None, ...]), axis=0) - segments_raw = np.concatenate((segments_raw, segment[None, ...]), - axis=0) + segments_raw = np.concatenate((segments_raw, segment[None, ...]), axis=0) # If not enough corner, just regenerate one if (num_corners < 3) or check_segment_len(segments_raw, min_len): @@ -372,15 +398,12 @@ def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64): col = get_random_color(int(np.mean(img))) cv.fillPoly(img, [corners], col) - return { - "points": junc_points, - "line_map": line_map - } + return {"points": junc_points, "line_map": line_map} def overlap(center, rad, centers, rads): - """ Check that the circle with (center, rad) - doesn't overlap with the other circles. """ + """Check that the circle with (center, rad) + doesn't overlap with the other circles.""" flag = False for i in range(len(rads)): if np.linalg.norm(center - centers[i]) < rad + rads[i]: @@ -390,15 +413,22 @@ def overlap(center, rad, centers, rads): def angle_between_vectors(v1, v2): - """ Compute the angle (in rad) between the two vectors v1 and v2. 
""" + """Compute the angle (in rad) between the two vectors v1 and v2.""" v1_u = v1 / np.linalg.norm(v1) v2_u = v2 / np.linalg.norm(v2) return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) -def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32, - min_label_len=64, safe_margin=5, **extra): - """ Draw multiple polygons with a random number of corners +def draw_multiple_polygons( + img, + max_sides=8, + nb_polygons=30, + min_len=32, + min_label_len=64, + safe_margin=5, + **extra +): + """Draw multiple polygons with a random number of corners and return the junction points + line map. Parameters: max_sides: maximal number of sides + 1 @@ -413,11 +443,11 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32, min_dim = min(img.shape[0], img.shape[1]) # Convert length constrain to pixel if given float number - if isinstance(min_len, float) and min_len <= 1.: + if isinstance(min_len, float) and min_len <= 1.0: min_len = int(min_dim * min_len) - if isinstance(min_label_len, float) and min_label_len <= 1.: + if isinstance(min_label_len, float) and min_label_len <= 1.0: min_label_len = int(min_dim * min_label_len) - if isinstance(safe_margin, float) and safe_margin <= 1.: + if isinstance(safe_margin, float) and safe_margin <= 1.0: safe_margin = int(min_dim * safe_margin) # Sequentially generate polygons @@ -435,8 +465,10 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32, # Sample num_corners points inside the circle slices = np.linspace(0, 2 * math.pi, num_corners + 1) - angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i]) - for i in range(num_corners)] + angles = [ + slices[i] + random_state.rand() * (slices[i + 1] - slices[i]) + for i in range(num_corners) + ] # Sample outer points and inner points new_points = [] @@ -444,29 +476,38 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32, for a in angles: x_offset = max(random_state.rand(), 0.4) y_offset = max(random_state.rand(), 0.4) - new_points.append([int(x + x_offset * rad * math.cos(a)), - int(y + y_offset * rad * math.sin(a))]) + new_points.append( + [ + int(x + x_offset * rad * math.cos(a)), + int(y + y_offset * rad * math.sin(a)), + ] + ) new_points_real.append( - [int(x + x_offset * rad_real * math.cos(a)), - int(y + y_offset * rad_real * math.sin(a))]) + [ + int(x + x_offset * rad_real * math.cos(a)), + int(y + y_offset * rad_real * math.sin(a)), + ] + ) new_points = np.array(new_points) new_points_real = np.array(new_points_real) # Filter the points that are too close or that have an angle too flat - norms = [np.linalg.norm(new_points[(i-1) % num_corners, :] - - new_points[i, :]) - for i in range(num_corners)] + norms = [ + np.linalg.norm(new_points[(i - 1) % num_corners, :] - new_points[i, :]) + for i in range(num_corners) + ] mask = np.array(norms) > 0.01 new_points = new_points[mask, :] new_points_real = new_points_real[mask, :] num_corners = new_points.shape[0] corner_angles = [ - angle_between_vectors(new_points[(i-1) % num_corners, :] - - new_points[i, :], - new_points[(i+1) % num_corners, :] - - new_points[i, :]) - for i in range(num_corners)] + angle_between_vectors( + new_points[(i - 1) % num_corners, :] - new_points[i, :], + new_points[(i + 1) % num_corners, :] - new_points[i, :], + ) + for i in range(num_corners) + ] mask = np.array(corner_angles) < (2 * math.pi / 3) new_points = new_points[mask, :] new_points_real = new_points_real[mask, :] @@ -480,28 +521,32 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, 
min_len=32, new_segments = np.zeros((1, 4, num_corners)) new_segments[:, 0, :] = [new_points[i][0] for i in range(num_corners)] new_segments[:, 1, :] = [new_points[i][1] for i in range(num_corners)] - new_segments[:, 2, :] = [new_points[(i+1) % num_corners][0] - for i in range(num_corners)] - new_segments[:, 3, :] = [new_points[(i+1) % num_corners][1] - for i in range(num_corners)] + new_segments[:, 2, :] = [ + new_points[(i + 1) % num_corners][0] for i in range(num_corners) + ] + new_segments[:, 3, :] = [ + new_points[(i + 1) % num_corners][1] for i in range(num_corners) + ] # Segments to record (inner circle) new_segments_real = np.zeros((1, 4, num_corners)) - new_segments_real[:, 0, :] = [new_points_real[i][0] - for i in range(num_corners)] - new_segments_real[:, 1, :] = [new_points_real[i][1] - for i in range(num_corners)] + new_segments_real[:, 0, :] = [new_points_real[i][0] for i in range(num_corners)] + new_segments_real[:, 1, :] = [new_points_real[i][1] for i in range(num_corners)] new_segments_real[:, 2, :] = [ - new_points_real[(i + 1) % num_corners][0] - for i in range(num_corners)] + new_points_real[(i + 1) % num_corners][0] for i in range(num_corners) + ] new_segments_real[:, 3, :] = [ - new_points_real[(i + 1) % num_corners][1] - for i in range(num_corners)] + new_points_real[(i + 1) % num_corners][1] for i in range(num_corners) + ] # Check that the polygon will not overlap with pre-existing shapes - if intersect(segments[:, 0:2, None], segments[:, 2:4, None], - new_segments[:, 0:2, :], new_segments[:, 2:4, :], - 3) or overlap(np.array([x, y]), rad, centers, rads): + if intersect( + segments[:, 0:2, None], + segments[:, 2:4, None], + new_segments[:, 0:2, :], + new_segments[:, 2:4, :], + 3, + ) or overlap(np.array([x, y]), rad, centers, rads): continue # Check that the the edges of the polygon is not too short @@ -515,20 +560,19 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32, segments = np.concatenate([segments, new_segments], axis=0) # Only record the segments longer than min_label_len - new_segments_real = np.reshape(np.swapaxes(new_segments_real, 0, 2), - (-1, 4)) + new_segments_real = np.reshape(np.swapaxes(new_segments_real, 0, 2), (-1, 4)) points1 = new_segments_real[:, :2] points2 = new_segments_real[:, 2:] seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) new_label_segment = new_segments_real[seg_len >= min_label_len, :] - label_segments = np.concatenate([label_segments, new_label_segment], - axis=0) + label_segments = np.concatenate([label_segments, new_label_segment], axis=0) # Color the polygon with a custom background corners = new_points_real.reshape((-1, 1, 2)) mask = np.zeros(img.shape, np.uint8) custom_background = generate_custom_background( - img.shape, background_color, **extra) + img.shape, background_color, **extra + ) cv.fillPoly(mask, [corners], 255) locs = np.where(mask != 0) @@ -537,7 +581,8 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32, # Get all junctions from label segments junctions_all = np.concatenate( - (label_segments[:, :2], label_segments[:, 2:]), axis=0) + (label_segments[:, :2], label_segments[:, 2:]), axis=0 + ) if junctions_all.shape[0] == 0: junc_points = None line_map = None @@ -548,14 +593,11 @@ def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32, # Generate line map from points and segments line_map = get_line_map(junc_points, label_segments) - return { - "points": junc_points, - "line_map": line_map - } + return {"points": junc_points, "line_map": 
line_map} def draw_ellipses(img, nb_ellipses=20): - """ Draw several ellipses. + """Draw several ellipses. Parameters: nb_ellipses: maximal number of ellipses """ @@ -585,16 +627,16 @@ def draw_ellipses(img, nb_ellipses=20): def draw_star(img, nb_branches=6, min_len=32, min_label_len=64): - """ Draw a star and return the junction points + line map. + """Draw a star and return the junction points + line map. Parameters: nb_branches: number of branches of the star """ num_branches = random_state.randint(3, nb_branches) min_dim = min(img.shape[0], img.shape[1]) # Convert length constrain to pixel if given float number - if isinstance(min_len, float) and min_len <= 1.: + if isinstance(min_len, float) and min_len <= 1.0: min_len = int(min_dim * min_len) - if isinstance(min_label_len, float) and min_label_len <= 1.: + if isinstance(min_label_len, float) and min_label_len <= 1.0: min_label_len = int(min_dim * min_label_len) thickness = random_state.randint(min_dim * 0.01, min_dim * 0.025) @@ -603,12 +645,19 @@ def draw_star(img, nb_branches=6, min_len=32, min_label_len=64): y = random_state.randint(rad, img.shape[0] - rad) # Sample num_branches points inside the circle slices = np.linspace(0, 2 * math.pi, num_branches + 1) - angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i]) - for i in range(num_branches)] + angles = [ + slices[i] + random_state.rand() * (slices[i + 1] - slices[i]) + for i in range(num_branches) + ] points = np.array( - [[int(x + max(random_state.rand(), 0.3) * rad * math.cos(a)), - int(y + max(random_state.rand(), 0.3) * rad * math.sin(a))] - for a in angles]) + [ + [ + int(x + max(random_state.rand(), 0.3) * rad * math.cos(a)), + int(y + max(random_state.rand(), 0.3) * rad * math.sin(a)), + ] + for a in angles + ] + ) points = np.concatenate(([[x, y]], points), axis=0) # Generate segments and check the length @@ -624,7 +673,8 @@ def draw_star(img, nb_branches=6, min_len=32, min_label_len=64): # Get all junctions from label segments junctions_all = np.concatenate( - (label_segments[:, :2], label_segments[:, 2:]), axis=0) + (label_segments[:, :2], label_segments[:, 2:]), axis=0 + ) if junctions_all.shape[0] == 0: junc_points = None line_map = None @@ -638,19 +688,25 @@ def draw_star(img, nb_branches=6, min_len=32, min_label_len=64): background_color = int(np.mean(img)) for i in range(1, num_branches + 1): col = get_random_color(background_color) - cv.line(img, (points[0][0], points[0][1]), - (points[i][0], points[i][1]), - col, thickness) - return { - "points": junc_points, - "line_map": line_map - } - - -def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, - transform_params=(0.05, 0.15), - min_label_len=64, seed=None): - """ Draw a checkerboard and output the junctions + line segments + cv.line( + img, + (points[0][0], points[0][1]), + (points[i][0], points[i][1]), + col, + thickness, + ) + return {"points": junc_points, "line_map": line_map} + + +def draw_checkerboard_multiseg( + img, + max_rows=7, + max_cols=7, + transform_params=(0.05, 0.15), + min_label_len=64, + seed=None, +): + """Draw a checkerboard and output the junctions + line segments Parameters: max_rows: maximal number of rows + 1 max_cols: maximal number of cols + 1 @@ -664,57 +720,63 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, background_color = int(np.mean(img)) min_dim = min(img.shape) - if isinstance(min_label_len, float) and min_label_len <= 1.: + if isinstance(min_label_len, float) and min_label_len <= 1.0: min_label_len = int(min_dim * min_label_len) # Create 
the grid rows = random_state.randint(3, max_rows) # number of rows cols = random_state.randint(3, max_cols) # number of cols s = min((img.shape[1] - 1) // cols, (img.shape[0] - 1) // rows) - x_coord = np.tile(range(cols + 1), - rows + 1).reshape(((rows + 1) * (cols + 1), 1)) - y_coord = np.repeat(range(rows + 1), - cols + 1).reshape(((rows + 1) * (cols + 1), 1)) + x_coord = np.tile(range(cols + 1), rows + 1).reshape(((rows + 1) * (cols + 1), 1)) + y_coord = np.repeat(range(rows + 1), cols + 1).reshape(((rows + 1) * (cols + 1), 1)) # points are the grid coordinates points = s * np.concatenate([x_coord, y_coord], axis=1) # Warp the grid using an affine transformation and an homography alpha_affine = np.max(img.shape) * ( - transform_params[0] + random_state.rand() * transform_params[1]) + transform_params[0] + random_state.rand() * transform_params[1] + ) center_square = np.float32(img.shape) // 2 min_dim = min(img.shape) square_size = min_dim // 3 - pts1 = np.float32([center_square + square_size, - [center_square[0] + square_size, - center_square[1] - square_size], - center_square - square_size, - [center_square[0] - square_size, - center_square[1] + square_size]]) - pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, - size=pts1.shape).astype(np.float32) + pts1 = np.float32( + [ + center_square + square_size, + [center_square[0] + square_size, center_square[1] - square_size], + center_square - square_size, + [center_square[0] - square_size, center_square[1] + square_size], + ] + ) + pts2 = pts1 + random_state.uniform( + -alpha_affine, alpha_affine, size=pts1.shape + ).astype(np.float32) affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3]) - pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2, - size=pts1.shape).astype(np.float32) + pts2 = pts1 + random_state.uniform( + -alpha_affine / 2, alpha_affine / 2, size=pts1.shape + ).astype(np.float32) perspective_transform = cv.getPerspectiveTransform(pts1, pts2) # Apply the affine transformation - points = np.transpose(np.concatenate( - (points, np.ones(((rows + 1) * (cols + 1), 1))), axis=1)) + points = np.transpose( + np.concatenate((points, np.ones(((rows + 1) * (cols + 1), 1))), axis=1) + ) warped_points = np.transpose(np.dot(affine_transform, points)) # Apply the homography - warped_col0 = np.add(np.sum(np.multiply( - warped_points, perspective_transform[0, :2]), axis=1), - perspective_transform[0, 2]) - warped_col1 = np.add(np.sum(np.multiply( - warped_points, perspective_transform[1, :2]), axis=1), - perspective_transform[1, 2]) - warped_col2 = np.add(np.sum(np.multiply( - warped_points, perspective_transform[2, :2]), axis=1), - perspective_transform[2, 2]) + warped_col0 = np.add( + np.sum(np.multiply(warped_points, perspective_transform[0, :2]), axis=1), + perspective_transform[0, 2], + ) + warped_col1 = np.add( + np.sum(np.multiply(warped_points, perspective_transform[1, :2]), axis=1), + perspective_transform[1, 2], + ) + warped_col2 = np.add( + np.sum(np.multiply(warped_points, perspective_transform[2, :2]), axis=1), + perspective_transform[2, 2], + ) warped_col0 = np.divide(warped_col0, warped_col2) warped_col1 = np.divide(warped_col1, warped_col2) - warped_points = np.concatenate( - [warped_col0[:, None], warped_col1[:, None]], axis=1) + warped_points = np.concatenate([warped_col0[:, None], warped_col1[:, None]], axis=1) warped_points_float = warped_points.copy() warped_points = warped_points.astype(int) @@ -735,15 +797,30 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, colors[i * 
cols + j] = col # Fill the cell - cv.fillConvexPoly(img, np.array( - [(warped_points[i * (cols + 1) + j, 0], - warped_points[i * (cols + 1) + j, 1]), - (warped_points[i * (cols + 1) + j + 1, 0], - warped_points[i * (cols + 1) + j + 1, 1]), - (warped_points[(i + 1) * (cols + 1) + j + 1, 0], - warped_points[(i + 1) * (cols + 1) + j + 1, 1]), - (warped_points[(i + 1) * (cols + 1) + j, 0], - warped_points[(i + 1) * (cols + 1) + j, 1])]), col) + cv.fillConvexPoly( + img, + np.array( + [ + ( + warped_points[i * (cols + 1) + j, 0], + warped_points[i * (cols + 1) + j, 1], + ), + ( + warped_points[i * (cols + 1) + j + 1, 0], + warped_points[i * (cols + 1) + j + 1, 1], + ), + ( + warped_points[(i + 1) * (cols + 1) + j + 1, 0], + warped_points[(i + 1) * (cols + 1) + j + 1, 1], + ), + ( + warped_points[(i + 1) * (cols + 1) + j, 0], + warped_points[(i + 1) * (cols + 1) + j, 1], + ), + ] + ), + col, + ) label_segments = np.empty([0, 4], dtype=np.int) # Iterate through rows @@ -751,12 +828,18 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, # Include all the combination of the junctions # Iterate through all the combination of junction index in that row multi_seg_lst = [ - np.array([warped_points_float[id1, 0], - warped_points_float[id1, 1], - warped_points_float[id2, 0], - warped_points_float[id2, 1]])[None, ...] - for (id1, id2) in combinations(range( - row_idx * (cols + 1), (row_idx + 1) * (cols + 1), 1), 2)] + np.array( + [ + warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1], + ] + )[None, ...] + for (id1, id2) in combinations( + range(row_idx * (cols + 1), (row_idx + 1) * (cols + 1), 1), 2 + ) + ] multi_seg = np.concatenate(multi_seg_lst, axis=0) label_segments = np.concatenate((label_segments, multi_seg), axis=0) @@ -765,20 +848,31 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, # Include all the combination of the junctions # Iterate throuhg all the combination of junction index in that column multi_seg_lst = [ - np.array([warped_points_float[id1, 0], - warped_points_float[id1, 1], - warped_points_float[id2, 0], - warped_points_float[id2, 1]])[None, ...] - for (id1, id2) in combinations(range( - col_idx, col_idx + ((rows + 1) * (cols + 1)), cols + 1), 2)] + np.array( + [ + warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1], + ] + )[None, ...] + for (id1, id2) in combinations( + range(col_idx, col_idx + ((rows + 1) * (cols + 1)), cols + 1), 2 + ) + ] multi_seg = np.concatenate(multi_seg_lst, axis=0) label_segments = np.concatenate((label_segments, multi_seg), axis=0) label_segments_filtered = np.zeros([0, 4]) # Define image boundary polygon (in x y manner) image_poly = shapely.geometry.Polygon( - [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1], - [0, img.shape[0] - 1]]) + [ + [0, 0], + [img.shape[1] - 1, 0], + [img.shape[1] - 1, img.shape[0] - 1], + [0, img.shape[0] - 1], + ] + ) for idx in range(label_segments.shape[0]): # Get the line segment seg_raw = label_segments[idx, :] @@ -787,20 +881,21 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, # The line segment is just inside the image. if seg.intersection(image_poly) == seg: label_segments_filtered = np.concatenate( - (label_segments_filtered, seg_raw[None, ...]), axis=0) + (label_segments_filtered, seg_raw[None, ...]), axis=0 + ) # Intersect with the image. 
elif seg.intersects(image_poly): # Check intersection try: - p = np.array(seg.intersection( - image_poly).coords).reshape([-1, 4]) + p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4]) # If intersect with eact one point except: continue segment = p label_segments_filtered = np.concatenate( - (label_segments_filtered, segment), axis=0) + (label_segments_filtered, segment), axis=0 + ) else: continue @@ -814,8 +909,7 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, label_segments = label_segments[seg_len >= min_label_len, :] # Get all junctions from label segments - junc_points, line_map = get_unique_junctions(label_segments, - min_label_len) + junc_points, line_map = get_unique_junctions(label_segments, min_label_len) # Draw lines on the boundaries of the board at random nb_rows = random_state.randint(2, rows + 2) @@ -826,33 +920,52 @@ def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, col_idx1 = random_state.randint(cols + 1) col_idx2 = random_state.randint(cols + 1) col = get_random_color(background_color) - cv.line(img, (warped_points[row_idx * (cols + 1) + col_idx1, 0], - warped_points[row_idx * (cols + 1) + col_idx1, 1]), - (warped_points[row_idx * (cols + 1) + col_idx2, 0], - warped_points[row_idx * (cols + 1) + col_idx2, 1]), - col, thickness) + cv.line( + img, + ( + warped_points[row_idx * (cols + 1) + col_idx1, 0], + warped_points[row_idx * (cols + 1) + col_idx1, 1], + ), + ( + warped_points[row_idx * (cols + 1) + col_idx2, 0], + warped_points[row_idx * (cols + 1) + col_idx2, 1], + ), + col, + thickness, + ) for _ in range(nb_cols): col_idx = random_state.randint(cols + 1) row_idx1 = random_state.randint(rows + 1) row_idx2 = random_state.randint(rows + 1) col = get_random_color(background_color) - cv.line(img, (warped_points[row_idx1 * (cols + 1) + col_idx, 0], - warped_points[row_idx1 * (cols + 1) + col_idx, 1]), - (warped_points[row_idx2 * (cols + 1) + col_idx, 0], - warped_points[row_idx2 * (cols + 1) + col_idx, 1]), - col, thickness) + cv.line( + img, + ( + warped_points[row_idx1 * (cols + 1) + col_idx, 0], + warped_points[row_idx1 * (cols + 1) + col_idx, 1], + ), + ( + warped_points[row_idx2 * (cols + 1) + col_idx, 0], + warped_points[row_idx2 * (cols + 1) + col_idx, 1], + ), + col, + thickness, + ) # Keep only the points inside the image points = keep_points_inside(warped_points, img.shape[:2]) - return { - "points": junc_points, - "line_map": line_map - } - - -def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, - transform_params=(0.05, 0.15), seed=None): - """ Draw stripes in a distorted rectangle + return {"points": junc_points, "line_map": line_map} + + +def draw_stripes_multiseg( + img, + max_nb_cols=13, + min_len=0.04, + min_label_len=64, + transform_params=(0.05, 0.15), + seed=None, +): + """Draw stripes in a distorted rectangle and output the junctions points + line map. 
Parameters: max_nb_cols: maximal number of stripes to be drawn @@ -868,73 +981,84 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, background_color = int(np.mean(img)) # Create the grid - board_size = (int(img.shape[0] * (1 + random_state.rand())), - int(img.shape[1] * (1 + random_state.rand()))) + board_size = ( + int(img.shape[0] * (1 + random_state.rand())), + int(img.shape[1] * (1 + random_state.rand())), + ) # Number of cols col = random_state.randint(5, max_nb_cols) - cols = np.concatenate([board_size[1] * random_state.rand(col - 1), - np.array([0, board_size[1] - 1])], axis=0) + cols = np.concatenate( + [board_size[1] * random_state.rand(col - 1), np.array([0, board_size[1] - 1])], + axis=0, + ) cols = np.unique(cols.astype(int)) # Remove the indices that are too close min_dim = min(img.shape) # Convert length constrain to pixel if given float number - if isinstance(min_len, float) and min_len <= 1.: + if isinstance(min_len, float) and min_len <= 1.0: min_len = int(min_dim * min_len) - if isinstance(min_label_len, float) and min_label_len <= 1.: + if isinstance(min_label_len, float) and min_label_len <= 1.0: min_label_len = int(min_dim * min_label_len) - cols = cols[(np.concatenate([cols[1:], - np.array([board_size[1] + min_len])], - axis=0) - cols) >= min_len] + cols = cols[ + (np.concatenate([cols[1:], np.array([board_size[1] + min_len])], axis=0) - cols) + >= min_len + ] # Update the number of cols col = cols.shape[0] - 1 cols = np.reshape(cols, (col + 1, 1)) cols1 = np.concatenate([cols, np.zeros((col + 1, 1), np.int32)], axis=1) cols2 = np.concatenate( - [cols, (board_size[0] - 1) * np.ones((col + 1, 1), np.int32)], axis=1) + [cols, (board_size[0] - 1) * np.ones((col + 1, 1), np.int32)], axis=1 + ) points = np.concatenate([cols1, cols2], axis=0) # Warp the grid using an affine transformation and a homography alpha_affine = np.max(img.shape) * ( - transform_params[0] + random_state.rand() * transform_params[1]) + transform_params[0] + random_state.rand() * transform_params[1] + ) center_square = np.float32(img.shape) // 2 square_size = min(img.shape) // 3 - pts1 = np.float32([center_square + square_size, - [center_square[0]+square_size, - center_square[1]-square_size], - center_square - square_size, - [center_square[0]-square_size, - center_square[1]+square_size]]) - pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, - size=pts1.shape).astype(np.float32) + pts1 = np.float32( + [ + center_square + square_size, + [center_square[0] + square_size, center_square[1] - square_size], + center_square - square_size, + [center_square[0] - square_size, center_square[1] + square_size], + ] + ) + pts2 = pts1 + random_state.uniform( + -alpha_affine, alpha_affine, size=pts1.shape + ).astype(np.float32) affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3]) - pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2, - size=pts1.shape).astype(np.float32) + pts2 = pts1 + random_state.uniform( + -alpha_affine / 2, alpha_affine / 2, size=pts1.shape + ).astype(np.float32) perspective_transform = cv.getPerspectiveTransform(pts1, pts2) # Apply the affine transformation - points = np.transpose(np.concatenate((points, - np.ones((2 * (col + 1), 1))), - axis=1)) + points = np.transpose(np.concatenate((points, np.ones((2 * (col + 1), 1))), axis=1)) warped_points = np.transpose(np.dot(affine_transform, points)) # Apply the homography - warped_col0 = np.add(np.sum(np.multiply( - warped_points, perspective_transform[0, :2]), axis=1), - 
perspective_transform[0, 2]) - warped_col1 = np.add(np.sum(np.multiply( - warped_points, perspective_transform[1, :2]), axis=1), - perspective_transform[1, 2]) - warped_col2 = np.add(np.sum(np.multiply( - warped_points, perspective_transform[2, :2]), axis=1), - perspective_transform[2, 2]) + warped_col0 = np.add( + np.sum(np.multiply(warped_points, perspective_transform[0, :2]), axis=1), + perspective_transform[0, 2], + ) + warped_col1 = np.add( + np.sum(np.multiply(warped_points, perspective_transform[1, :2]), axis=1), + perspective_transform[1, 2], + ) + warped_col2 = np.add( + np.sum(np.multiply(warped_points, perspective_transform[2, :2]), axis=1), + perspective_transform[2, 2], + ) warped_col0 = np.divide(warped_col0, warped_col2) warped_col1 = np.divide(warped_col1, warped_col2) - warped_points = np.concatenate( - [warped_col0[:, None], warped_col1[:, None]], axis=1) + warped_points = np.concatenate([warped_col0[:, None], warped_col1[:, None]], axis=1) warped_points_float = warped_points.copy() warped_points = warped_points.astype(int) @@ -944,15 +1068,18 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, for i in range(col): # Fill the color color = (color + 128 + random_state.randint(-30, 30)) % 256 - cv.fillConvexPoly(img, np.array([(warped_points[i, 0], - warped_points[i, 1]), - (warped_points[i+1, 0], - warped_points[i+1, 1]), - (warped_points[i+col+2, 0], - warped_points[i+col+2, 1]), - (warped_points[i+col+1, 0], - warped_points[i+col+1, 1])]), - color) + cv.fillConvexPoly( + img, + np.array( + [ + (warped_points[i, 0], warped_points[i, 1]), + (warped_points[i + 1, 0], warped_points[i + 1, 1]), + (warped_points[i + col + 2, 0], warped_points[i + col + 2, 1]), + (warped_points[i + col + 1, 0], warped_points[i + col + 1, 1]), + ] + ), + color, + ) segments = np.zeros([0, 4]) row = 1 # in stripes case @@ -960,27 +1087,39 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, for row_idx in range(row + 1): # Include all the combination of the junctions # Iterate through all the combination of junction index in that row - multi_seg_lst = [np.array( - [warped_points_float[id1, 0], - warped_points_float[id1, 1], - warped_points_float[id2, 0], - warped_points_float[id2, 1]])[None, ...] - for (id1, id2) in combinations(range( - row_idx * (col + 1), (row_idx + 1) * (col + 1), 1), 2)] + multi_seg_lst = [ + np.array( + [ + warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1], + ] + )[None, ...] + for (id1, id2) in combinations( + range(row_idx * (col + 1), (row_idx + 1) * (col + 1), 1), 2 + ) + ] multi_seg = np.concatenate(multi_seg_lst, axis=0) segments = np.concatenate((segments, multi_seg), axis=0) # Iterate through columns - for col_idx in range(col + 1): # for 5 columns, we will have 5 + 1 edges. + for col_idx in range(col + 1): # for 5 columns, we will have 5 + 1 edges. # Include all the combination of the junctions # Iterate throuhg all the combination of junction index in that column - multi_seg_lst = [np.array( - [warped_points_float[id1, 0], - warped_points_float[id1, 1], - warped_points_float[id2, 0], - warped_points_float[id2, 1]])[None, ...] - for (id1, id2) in combinations(range( - col_idx, col_idx + (row * col) + 2, col + 1), 2)] + multi_seg_lst = [ + np.array( + [ + warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1], + ] + )[None, ...] 
+ for (id1, id2) in combinations( + range(col_idx, col_idx + (row * col) + 2, col + 1), 2 + ) + ] multi_seg = np.concatenate(multi_seg_lst, axis=0) segments = np.concatenate((segments, multi_seg), axis=0) @@ -988,8 +1127,13 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, segments_new = np.zeros([0, 4]) # Define image boundary polygon (in x y manner) image_poly = shapely.geometry.Polygon( - [[0, 0], [img.shape[1]-1, 0], [img.shape[1]-1, img.shape[0]-1], - [0, img.shape[0]-1]]) + [ + [0, 0], + [img.shape[1] - 1, 0], + [img.shape[1] - 1, img.shape[0] - 1], + [0, img.shape[0] - 1], + ] + ) for idx in range(segments.shape[0]): # Get the line segment seg_raw = segments[idx, :] @@ -997,15 +1141,13 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, # The line segment is just inside the image. if seg.intersection(image_poly) == seg: - segments_new = np.concatenate( - (segments_new, seg_raw[None, ...]), axis=0) + segments_new = np.concatenate((segments_new, seg_raw[None, ...]), axis=0) # Intersect with the image. elif seg.intersects(image_poly): # Check intersection try: - p = np.array( - seg.intersection(image_poly).coords).reshape([-1, 4]) + p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4]) # If intersect at exact one point, just continue. except: continue @@ -1025,7 +1167,8 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, # Get all junctions from label segments junctions_all = np.concatenate( - (label_segments[:, :2], label_segments[:, 2:]), axis=0) + (label_segments[:, :2], label_segments[:, 2:]), axis=0 + ) if junctions_all.shape[0] == 0: junc_points = None line_map = None @@ -1045,32 +1188,44 @@ def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, col_idx1 = random_state.randint(col + 1) col_idx2 = random_state.randint(col + 1) color = get_random_color(background_color) - cv.line(img, (warped_points[row_idx + col_idx1, 0], - warped_points[row_idx + col_idx1, 1]), - (warped_points[row_idx + col_idx2, 0], - warped_points[row_idx + col_idx2, 1]), - color, thickness) + cv.line( + img, + ( + warped_points[row_idx + col_idx1, 0], + warped_points[row_idx + col_idx1, 1], + ), + ( + warped_points[row_idx + col_idx2, 0], + warped_points[row_idx + col_idx2, 1], + ), + color, + thickness, + ) for _ in range(nb_cols): col_idx = random_state.randint(col + 1) color = get_random_color(background_color) - cv.line(img, (warped_points[col_idx, 0], - warped_points[col_idx, 1]), - (warped_points[col_idx + col + 1, 0], - warped_points[col_idx + col + 1, 1]), - color, thickness) + cv.line( + img, + (warped_points[col_idx, 0], warped_points[col_idx, 1]), + (warped_points[col_idx + col + 1, 0], warped_points[col_idx + col + 1, 1]), + color, + thickness, + ) # Keep only the points inside the image # points = keep_points_inside(warped_points, img.shape[:2]) - return { - "points": junc_points, - "line_map": line_map - } + return {"points": junc_points, "line_map": line_map} -def draw_cube(img, min_size_ratio=0.2, min_label_len=64, - scale_interval=(0.4, 0.6), trans_interval=(0.5, 0.2)): - """ Draw a 2D projection of a cube and output the visible juntions. +def draw_cube( + img, + min_size_ratio=0.2, + min_label_len=64, + scale_interval=(0.4, 0.6), + trans_interval=(0.5, 0.2), +): + """Draw a 2D projection of a cube and output the visible juntions. 
Parameters: min_size_ratio: min(img.shape) * min_size_ratio is the smallest achievable cube side size @@ -1088,46 +1243,68 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64, lx = min_side + random_state.rand() * 2 * min_dim / 3 # dims of the cube ly = min_side + random_state.rand() * 2 * min_dim / 3 lz = min_side + random_state.rand() * 2 * min_dim / 3 - cube = np.array([[0, 0, 0], - [lx, 0, 0], - [0, ly, 0], - [lx, ly, 0], - [0, 0, lz], - [lx, 0, lz], - [0, ly, lz], - [lx, ly, lz]]) - rot_angles = random_state.rand(3) * 3 * math.pi / 10. + math.pi / 10. - rotation_1 = np.array([[math.cos(rot_angles[0]), - -math.sin(rot_angles[0]), 0], - [math.sin(rot_angles[0]), - math.cos(rot_angles[0]), 0], - [0, 0, 1]]) - rotation_2 = np.array([[1, 0, 0], - [0, math.cos(rot_angles[1]), - -math.sin(rot_angles[1])], - [0, math.sin(rot_angles[1]), - math.cos(rot_angles[1])]]) - rotation_3 = np.array([[math.cos(rot_angles[2]), 0, - -math.sin(rot_angles[2])], - [0, 1, 0], - [math.sin(rot_angles[2]), 0, - math.cos(rot_angles[2])]]) - scaling = np.array([[scale_interval[0] + - random_state.rand() * scale_interval[1], 0, 0], - [0, scale_interval[0] + - random_state.rand() * scale_interval[1], 0], - [0, 0, scale_interval[0] + - random_state.rand() * scale_interval[1]]]) - trans = np.array([img.shape[1] * trans_interval[0] + - random_state.randint(-img.shape[1] * trans_interval[1], - img.shape[1] * trans_interval[1]), - img.shape[0] * trans_interval[0] + - random_state.randint(-img.shape[0] * trans_interval[1], - img.shape[0] * trans_interval[1]), - 0]) + cube = np.array( + [ + [0, 0, 0], + [lx, 0, 0], + [0, ly, 0], + [lx, ly, 0], + [0, 0, lz], + [lx, 0, lz], + [0, ly, lz], + [lx, ly, lz], + ] + ) + rot_angles = random_state.rand(3) * 3 * math.pi / 10.0 + math.pi / 10.0 + rotation_1 = np.array( + [ + [math.cos(rot_angles[0]), -math.sin(rot_angles[0]), 0], + [math.sin(rot_angles[0]), math.cos(rot_angles[0]), 0], + [0, 0, 1], + ] + ) + rotation_2 = np.array( + [ + [1, 0, 0], + [0, math.cos(rot_angles[1]), -math.sin(rot_angles[1])], + [0, math.sin(rot_angles[1]), math.cos(rot_angles[1])], + ] + ) + rotation_3 = np.array( + [ + [math.cos(rot_angles[2]), 0, -math.sin(rot_angles[2])], + [0, 1, 0], + [math.sin(rot_angles[2]), 0, math.cos(rot_angles[2])], + ] + ) + scaling = np.array( + [ + [scale_interval[0] + random_state.rand() * scale_interval[1], 0, 0], + [0, scale_interval[0] + random_state.rand() * scale_interval[1], 0], + [0, 0, scale_interval[0] + random_state.rand() * scale_interval[1]], + ] + ) + trans = np.array( + [ + img.shape[1] * trans_interval[0] + + random_state.randint( + -img.shape[1] * trans_interval[1], img.shape[1] * trans_interval[1] + ), + img.shape[0] * trans_interval[0] + + random_state.randint( + -img.shape[0] * trans_interval[1], img.shape[0] * trans_interval[1] + ), + 0, + ] + ) cube = trans + np.transpose( - np.dot(scaling, np.dot(rotation_1, - np.dot(rotation_2, np.dot(rotation_3, np.transpose(cube)))))) + np.dot( + scaling, + np.dot( + rotation_1, np.dot(rotation_2, np.dot(rotation_3, np.transpose(cube))) + ), + ) + ) # The hidden corner is 0 by construction # The front one is 7 @@ -1145,18 +1322,26 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64, face = faces[face_idx, :] # Brute-forcely expand all the segments segment = np.array( - [np.concatenate((cube[face[0]], cube[face[1]]), axis=0), - np.concatenate((cube[face[1]], cube[face[2]]), axis=0), - np.concatenate((cube[face[2]], cube[face[3]]), axis=0), - np.concatenate((cube[face[3]], cube[face[0]]), axis=0)]) + [ + 
np.concatenate((cube[face[0]], cube[face[1]]), axis=0), + np.concatenate((cube[face[1]], cube[face[2]]), axis=0), + np.concatenate((cube[face[2]], cube[face[3]]), axis=0), + np.concatenate((cube[face[3]], cube[face[0]]), axis=0), + ] + ) segments = np.concatenate((segments, segment), axis=0) # Select and refine the segments segments_new = np.zeros([0, 4]) # Define image boundary polygon (in x y manner) image_poly = shapely.geometry.Polygon( - [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1], - [0, img.shape[0] - 1]]) + [ + [0, 0], + [img.shape[1] - 1, 0], + [img.shape[1] - 1, img.shape[0] - 1], + [0, img.shape[0] - 1], + ] + ) for idx in range(segments.shape[0]): # Get the line segment seg_raw = segments[idx, :] @@ -1164,14 +1349,12 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64, # The line segment is just inside the image. if seg.intersection(image_poly) == seg: - segments_new = np.concatenate( - (segments_new, seg_raw[None, ...]), axis=0) + segments_new = np.concatenate((segments_new, seg_raw[None, ...]), axis=0) # Intersect with the image. elif seg.intersects(image_poly): try: - p = np.array( - seg.intersection(image_poly).coords).reshape([-1, 4]) + p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4]) except: continue segment = p @@ -1190,7 +1373,8 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64, # Get all junctions from label segments junctions_all = np.concatenate( - (label_segments[:, :2], label_segments[:, 2:]), axis=0) + (label_segments[:, :2], label_segments[:, 2:]), axis=0 + ) if junctions_all.shape[0] == 0: junc_points = None line_map = None @@ -1204,29 +1388,25 @@ def draw_cube(img, min_size_ratio=0.2, min_label_len=64, # Fill the faces and draw the contours col_face = get_random_color(background_color) for i in [0, 1, 2]: - cv.fillPoly(img, [cube[faces[i]].reshape((-1, 1, 2))], - col_face) + cv.fillPoly(img, [cube[faces[i]].reshape((-1, 1, 2))], col_face) thickness = random_state.randint(min_dim * 0.003, min_dim * 0.015) for i in [0, 1, 2]: for j in [0, 1, 2, 3]: - col_edge = (col_face + 128 - + random_state.randint(-64, 64))\ - % 256 # color that constrats with the face color - cv.line(img, (cube[faces[i][j], 0], cube[faces[i][j], 1]), - (cube[faces[i][(j + 1) % 4], 0], - cube[faces[i][(j + 1) % 4], 1]), - col_edge, thickness) + col_edge = ( + col_face + 128 + random_state.randint(-64, 64) + ) % 256 # color that contrasts with the face color + cv.line( + img, + (cube[faces[i][j], 0], cube[faces[i][j], 1]), + (cube[faces[i][(j + 1) % 4], 0], cube[faces[i][(j + 1) % 4], 1]), + col_edge, + thickness, + ) - return { - "points": junc_points, - "line_map": line_map - } + return {"points": junc_points, "line_map": line_map} def gaussian_noise(img): - """ Apply random noise to the image.
""" + """Apply random noise to the image.""" cv.randu(img, 0, 255) - return { - "points": None, - "line_map": None - } + return {"points": None, "line_map": None} diff --git a/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py b/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py index d9338abb169f7a86f3c6e702a031e1c0de86c339..b9c63613b57f9064333bf80bd59fa6553f3ccb8e 100644 --- a/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py +++ b/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py @@ -12,11 +12,21 @@ import shapely.geometry def sample_homography( - shape, perspective=True, scaling=True, rotation=True, - translation=True, n_scales=5, n_angles=25, scaling_amplitude=0.1, - perspective_amplitude_x=0.1, perspective_amplitude_y=0.1, - patch_ratio=0.5, max_angle=pi/2, allow_artifacts=False, - translation_overflow=0.): + shape, + perspective=True, + scaling=True, + rotation=True, + translation=True, + n_scales=5, + n_angles=25, + scaling_amplitude=0.1, + perspective_amplitude_x=0.1, + perspective_amplitude_y=0.1, + patch_ratio=0.5, + max_angle=pi / 2, + allow_artifacts=False, + translation_overflow=0.0, +): """ Computes the homography transformation between a random patch in the original image and a warped projection with the same image size. @@ -51,11 +61,12 @@ def sample_homography( shape = np.array(shape) # Corners of the output image - pts1 = np.array([[0., 0.], [0., 1.], [1., 1.], [1., 0.]]) + pts1 = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]) # Corners of the input patch margin = (1 - patch_ratio) / 2 - pts2 = margin + np.array([[0, 0], [0, patch_ratio], - [patch_ratio, patch_ratio], [patch_ratio, 0]]) + pts2 = margin + np.array( + [[0, 0], [0, patch_ratio], [patch_ratio, patch_ratio], [patch_ratio, 0]] + ) # Random perspective and affine perturbations if perspective: @@ -65,25 +76,25 @@ def sample_homography( # normal distribution with mean=0, std=perspective_amplitude_y/2 perspective_displacement = np.random.normal( - 0., perspective_amplitude_y/2, [1]) - h_displacement_left = np.random.normal( - 0., perspective_amplitude_x/2, [1]) - h_displacement_right = np.random.normal( - 0., perspective_amplitude_x/2, [1]) - pts2 += np.stack([np.concatenate([h_displacement_left, - perspective_displacement], 0), - np.concatenate([h_displacement_left, - -perspective_displacement], 0), - np.concatenate([h_displacement_right, - perspective_displacement], 0), - np.concatenate([h_displacement_right, - -perspective_displacement], 0)]) + 0.0, perspective_amplitude_y / 2, [1] + ) + h_displacement_left = np.random.normal(0.0, perspective_amplitude_x / 2, [1]) + h_displacement_right = np.random.normal(0.0, perspective_amplitude_x / 2, [1]) + pts2 += np.stack( + [ + np.concatenate([h_displacement_left, perspective_displacement], 0), + np.concatenate([h_displacement_left, -perspective_displacement], 0), + np.concatenate([h_displacement_right, perspective_displacement], 0), + np.concatenate([h_displacement_right, -perspective_displacement], 0), + ] + ) # Random scaling: sample several scales, check collision with borders, # randomly pick a valid one if scaling: scales = np.concatenate( - [[1.], np.random.normal(1, scaling_amplitude/2, [n_scales])], 0) + [[1.0], np.random.normal(1, scaling_amplitude / 2, [n_scales])], 0 + ) center = np.mean(pts2, axis=0, keepdims=True) scaled = (pts2 - center)[None, ...] 
* scales[..., None, None] + center # all scales are valid except scale=1 @@ -91,17 +102,27 @@ def sample_homography( valid = np.array(range(n_scales)) # Chech the valid scale else: - valid = np.where(np.all((scaled >= 0.) - & (scaled < 1.), (1, 2)))[0] + valid = np.where(np.all((scaled >= 0.0) & (scaled < 1.0), (1, 2)))[0] # No valid scale found => recursively call if valid.shape[0] == 0: return sample_homography( - shape, perspective, scaling, rotation, translation, - n_scales, n_angles, scaling_amplitude, - perspective_amplitude_x, perspective_amplitude_y, - patch_ratio, max_angle, allow_artifacts, translation_overflow) - - idx = valid[np.random.uniform(0., valid.shape[0], ()).astype(np.int32)] + shape, + perspective, + scaling, + rotation, + translation, + n_scales, + n_angles, + scaling_amplitude, + perspective_amplitude_x, + perspective_amplitude_y, + patch_ratio, + max_angle, + allow_artifacts, + translation_overflow, + ) + + idx = valid[np.random.uniform(0.0, valid.shape[0], ()).astype(np.int32)] pts2 = scaled[idx] # Additionally save and return the selected scale. @@ -113,39 +134,60 @@ def sample_homography( if allow_artifacts: t_min += translation_overflow t_max += translation_overflow - pts2 += (np.stack([np.random.uniform(-t_min[0], t_max[0], ()), - np.random.uniform(-t_min[1], - t_max[1], ())]))[None, ...] + pts2 += ( + np.stack( + [ + np.random.uniform(-t_min[0], t_max[0], ()), + np.random.uniform(-t_min[1], t_max[1], ()), + ] + ) + )[None, ...] # Random rotation: sample several rotations, check collision with borders, # randomly pick a valid one if rotation: angles = np.linspace(-max_angle, max_angle, n_angles) # in case no rotation is valid - angles = np.concatenate([[0.], angles], axis=0) + angles = np.concatenate([[0.0], angles], axis=0) center = np.mean(pts2, axis=0, keepdims=True) - rot_mat = np.reshape(np.stack( - [np.cos(angles), -np.sin(angles), - np.sin(angles), np.cos(angles)], axis=1), [-1, 2, 2]) - rotated = np.matmul( - np.tile((pts2 - center)[None, ...], [n_angles+1, 1, 1]), - rot_mat) + center + rot_mat = np.reshape( + np.stack( + [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)], + axis=1, + ), + [-1, 2, 2], + ) + rotated = ( + np.matmul( + np.tile((pts2 - center)[None, ...], [n_angles + 1, 1, 1]), rot_mat + ) + + center + ) if allow_artifacts: # All angles are valid, except angle=0 valid = np.array(range(n_angles)) else: - valid = np.where(np.all((rotated >= 0.) - & (rotated < 1.), axis=(1, 2)))[0] - + valid = np.where(np.all((rotated >= 0.0) & (rotated < 1.0), axis=(1, 2)))[0] + if valid.shape[0] == 0: return sample_homography( - shape, perspective, scaling, rotation, translation, - n_scales, n_angles, scaling_amplitude, - perspective_amplitude_x, perspective_amplitude_y, - patch_ratio, max_angle, allow_artifacts, translation_overflow) - - idx = valid[np.random.uniform(0., valid.shape[0], - ()).astype(np.int32)] + shape, + perspective, + scaling, + rotation, + translation, + n_scales, + n_angles, + scaling_amplitude, + perspective_amplitude_x, + perspective_amplitude_y, + patch_ratio, + max_angle, + allow_artifacts, + translation_overflow, + ) + + idx = valid[np.random.uniform(0.0, valid.shape[0], ()).astype(np.int32)] pts2 = rotated[idx] # Rescale to actual size @@ -153,27 +195,33 @@ def sample_homography( pts1 *= shape[None, ...] pts2 *= shape[None, ...] 
- def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] + def ax(p, q): + return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] - def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]] + def ay(p, q): + return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]] - a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) - for f in (ax, ay)], axis=0) - p_mat = np.transpose(np.stack([[pts2[i][j] for i in range(4) - for j in range(2)]], axis=0)) + a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0) + p_mat = np.transpose( + np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0) + ) homo_vec, _, _, _ = np.linalg.lstsq(a_mat, p_mat, rcond=None) # Compose the homography vector back to matrix - homo_mat = np.concatenate([ - homo_vec[0:3, 0][None, ...], homo_vec[3:6, 0][None, ...], - np.concatenate((homo_vec[6], homo_vec[7], [1]), - axis=0)[None, ...]], axis=0) + homo_mat = np.concatenate( + [ + homo_vec[0:3, 0][None, ...], + homo_vec[3:6, 0][None, ...], + np.concatenate((homo_vec[6], homo_vec[7], [1]), axis=0)[None, ...], + ], + axis=0, + ) return homo_mat, selected_scale def convert_to_line_segments(junctions, line_map): - """ Convert junctions and line map to line segments. """ + """Convert junctions and line map to line segments.""" # Copy the line map line_map_tmp = copy.copy(line_map) @@ -188,9 +236,9 @@ def convert_to_line_segments(junctions, line_map): p1 = junctions[idx, :] p2 = junctions[idx2, :] line_segments = np.concatenate( - (line_segments, - np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), - axis=0) + (line_segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), + axis=0, + ) # Update line_map line_map_tmp[idx, idx2] = 0 line_map_tmp[idx2, idx] = 0 @@ -198,46 +246,50 @@ def convert_to_line_segments(junctions, line_map): return line_segments -def compute_valid_mask(image_size, homography, - border_margin, valid_mask=None): +def compute_valid_mask(image_size, homography, border_margin, valid_mask=None): # Warp the mask if valid_mask is None: initial_mask = np.ones(image_size) else: initial_mask = valid_mask mask = cv2.warpPerspective( - initial_mask, homography, (image_size[1], image_size[0]), - flags=cv2.INTER_NEAREST) + initial_mask, + homography, + (image_size[1], image_size[0]), + flags=cv2.INTER_NEAREST, + ) # Optionally perform erosion if border_margin > 0: - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, - (border_margin*2, )*2) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (border_margin * 2,) * 2) mask = cv2.erode(mask, kernel) - + # Perform dilation if border_margin is negative if border_margin < 0: - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, - (abs(int(border_margin))*2, )*2) + kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (abs(int(border_margin)) * 2,) * 2 + ) mask = cv2.dilate(mask, kernel) return mask def warp_line_segment(line_segments, homography, image_size): - """ Warp the line segments using a homography. """ + """Warp the line segments using a homography.""" # Separate the line segements into 2N points to apply matrix operation num_segments = line_segments.shape[0] junctions = np.concatenate( - (line_segments[:, :2], # The first junction of each segment. - line_segments[:, 2:]), # The second junction of each segment. - axis=0) + ( + line_segments[:, :2], # The first junction of each segment. + line_segments[:, 2:], + ), # The second junction of each segment. 
+ axis=0, + ) # Convert to homogeneous coordinates # Flip the junctions before converting to homogeneous (xy format) junctions = np.flip(junctions, axis=1) - junctions = np.concatenate((junctions, np.ones([2*num_segments, 1])), - axis=1) + junctions = np.concatenate((junctions, np.ones([2 * num_segments, 1])), axis=1) warped_junctions = np.matmul(homography, junctions.T).T # Convert back to segments @@ -245,41 +297,43 @@ def warp_line_segment(line_segments, homography, image_size): # (Convert back to hw format) warped_junctions = np.flip(warped_junctions, axis=1) warped_segments = np.concatenate( - (warped_junctions[:num_segments, :], - warped_junctions[num_segments:, :]), - axis=1 + (warped_junctions[:num_segments, :], warped_junctions[num_segments:, :]), axis=1 ) # Check the intersections with the boundary warped_segments_new = np.zeros([0, 4]) image_poly = shapely.geometry.Polygon( - [[0, 0], [image_size[1]-1, 0], [image_size[1]-1, image_size[0]-1], - [0, image_size[0]-1]]) + [ + [0, 0], + [image_size[1] - 1, 0], + [image_size[1] - 1, image_size[0] - 1], + [0, image_size[0] - 1], + ] + ) for idx in range(warped_segments.shape[0]): # Get the line segment - seg_raw = warped_segments[idx, :] # in HW format. + seg_raw = warped_segments[idx, :] # in HW format. # Convert to shapely line (flip to xy format) - seg = shapely.geometry.LineString([np.flip(seg_raw[:2]), - np.flip(seg_raw[2:])]) + seg = shapely.geometry.LineString([np.flip(seg_raw[:2]), np.flip(seg_raw[2:])]) # The line segment is just inside the image. if seg.intersection(image_poly) == seg: - warped_segments_new = np.concatenate((warped_segments_new, - seg_raw[None, ...]), axis=0) - + warped_segments_new = np.concatenate( + (warped_segments_new, seg_raw[None, ...]), axis=0 + ) + # Intersect with the image. elif seg.intersects(image_poly): # Check intersection try: - p = np.array( - seg.intersection(image_poly).coords).reshape([-1, 4]) + p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4]) # If intersect at exact one point, just continue. except: continue - segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:], - axis=0)])[None, ...] - warped_segments_new = np.concatenate( - (warped_segments_new, segment), axis=0) + segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:], axis=0)])[ + None, ... + ] + warped_segments_new = np.concatenate((warped_segments_new, segment), axis=0) else: continue @@ -289,9 +343,9 @@ def warp_line_segment(line_segments, homography, image_size): class homography_transform(object): - """ # Homography transformations. 
""" - def __init__(self, image_size, homograpy_config, - border_margin=0, min_label_len=20): + """# Homography transformations.""" + + def __init__(self, image_size, homograpy_config, border_margin=0, min_label_len=20): self.homo_config = homograpy_config self.image_size = image_size self.target_size = (self.image_size[1], self.image_size[0]) @@ -300,31 +354,33 @@ class homography_transform(object): raise ValueError("[Error] min_label_len should be in pixels.") self.min_label_len = min_label_len - def __call__(self, input_image, junctions, line_map, - valid_mask=None, homo=None, scale=None): + def __call__( + self, input_image, junctions, line_map, valid_mask=None, homo=None, scale=None + ): # Sample one random homography or use the given one if homo is None or scale is None: - homo, scale = sample_homography(self.image_size, - **self.homo_config) + homo, scale = sample_homography(self.image_size, **self.homo_config) # Warp the image warped_image = cv2.warpPerspective( - input_image, homo, self.target_size, flags=cv2.INTER_LINEAR) - - valid_mask = compute_valid_mask(self.image_size, homo, - self.border_margin, valid_mask) + input_image, homo, self.target_size, flags=cv2.INTER_LINEAR + ) + + valid_mask = compute_valid_mask( + self.image_size, homo, self.border_margin, valid_mask + ) # Convert junctions and line_map back to line segments line_segments = convert_to_line_segments(junctions, line_map) # Warp the segments and check the length. # Adjust the min_label_length - warped_segments = warp_line_segment(line_segments, homo, - self.image_size) + warped_segments = warp_line_segment(line_segments, homo, self.image_size) # Convert back to junctions and line_map - junctions_new = np.concatenate((warped_segments[:, :2], - warped_segments[:, 2:]), axis=0) + junctions_new = np.concatenate( + (warped_segments[:, :2], warped_segments[:, 2:]), axis=0 + ) if junctions_new.shape[0] == 0: junctions_new = np.zeros([0, 2]) line_map = np.zeros([0, 0]) @@ -333,11 +389,11 @@ class homography_transform(object): junctions_new = np.unique(junctions_new, axis=0) # Generate line map from points and segments - line_map = get_line_map(junctions_new, - warped_segments).astype(np.int) + line_map = get_line_map(junctions_new, warped_segments).astype(np.int) # Compute the heatmap - warped_heatmap = get_line_heatmap(np.flip(junctions_new, axis=1), - line_map, self.image_size) + warped_heatmap = get_line_heatmap( + np.flip(junctions_new, axis=1), line_map, self.image_size + ) return { "junctions": junctions_new, @@ -346,5 +402,5 @@ class homography_transform(object): "line_map": line_map, "warped_heatmap": warped_heatmap, "homo": homo, - "scale": scale + "scale": scale, } diff --git a/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py b/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py index 8fa44bf0efa93a47e5f8012988058f1cbd49324f..5f41192cd2cba7b47939f031027e8dce6e1a406f 100644 --- a/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py +++ b/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py @@ -9,17 +9,18 @@ import cv2 # List all the available augmentations available_augmentations = [ - 'additive_gaussian_noise', - 'additive_speckle_noise', - 'random_brightness', - 'random_contrast', - 'additive_shade', - 'motion_blur' + "additive_gaussian_noise", + "additive_speckle_noise", + "random_brightness", + "random_contrast", + "additive_shade", + "motion_blur", ] class additive_gaussian_noise(object): - """ Additive gaussian noise. 
""" + """Additive gaussian noise.""" + def __init__(self, stddev_range=None): # If std is not given, use the default setting if stddev_range is None: @@ -30,14 +31,15 @@ class additive_gaussian_noise(object): def __call__(self, input_image): # Get the noise stddev stddev = np.random.uniform(self.stddev_range[0], self.stddev_range[1]) - noise = np.random.normal(0., stddev, size=input_image.shape) - noisy_image = (input_image + noise).clip(0., 255.) + noise = np.random.normal(0.0, stddev, size=input_image.shape) + noisy_image = (input_image + noise).clip(0.0, 255.0) return noisy_image class additive_speckle_noise(object): - """ Additive speckle noise. """ + """Additive speckle noise.""" + def __init__(self, prob_range=None): # If prob range is not given, use the default setting if prob_range is None: @@ -48,7 +50,7 @@ class additive_speckle_noise(object): def __call__(self, input_image): # Sample prob = np.random.uniform(self.prob_range[0], self.prob_range[1]) - sample = np.random.uniform(0., 1., size=input_image.shape) + sample = np.random.uniform(0.0, 1.0, size=input_image.shape) # Get the mask mask0 = sample <= prob @@ -56,14 +58,15 @@ class additive_speckle_noise(object): # Mask the image (here we assume the image ranges from 0~255 noisy = input_image.copy() - noisy[mask0] = 0. - noisy[mask1] = 255. + noisy[mask0] = 0.0 + noisy[mask1] = 255.0 return noisy class random_brightness(object): - """ Brightness change. """ + """Brightness change.""" + def __init__(self, brightness=None): # If the brightness is not given, use the default setting if brightness is None: @@ -83,7 +86,8 @@ class random_brightness(object): class random_contrast(object): - """ Additive contrast. """ + """Additive contrast.""" + def __init__(self, contrast=None): # If the brightness is not given, use the default setting if contrast is None: @@ -103,9 +107,9 @@ class random_contrast(object): class additive_shade(object): - """ Additive shade. """ - def __init__(self, nb_ellipses=20, transparency_range=None, - kernel_size_range=None): + """Additive shade.""" + + def __init__(self, nb_ellipses=20, transparency_range=None, kernel_size_range=None): self.nb_ellipses = nb_ellipses if transparency_range is None: self.transparency_range = [-0.5, 0.8] @@ -136,39 +140,40 @@ class additive_shade(object): # kernel_size has to be odd if (kernel_size % 2) == 0: kernel_size += 1 - mask = cv2.GaussianBlur(mask.astype(np.float32), - (kernel_size, kernel_size), 0) - shaded = (input_image[..., None] - * (1 - transparency * mask[..., np.newaxis]/255.)) + mask = cv2.GaussianBlur(mask.astype(np.float32), (kernel_size, kernel_size), 0) + shaded = input_image[..., None] * ( + 1 - transparency * mask[..., np.newaxis] / 255.0 + ) shaded = np.clip(shaded, 0, 255) return np.reshape(shaded, input_image.shape) class motion_blur(object): - """ Motion blur. """ + """Motion blur.""" + def __init__(self, max_kernel_size=10): self.max_kernel_size = max_kernel_size def __call__(self, input_image): # Either vertical, horizontal or diagonal blur - mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up']) - ksize = np.random.randint( - 0, int(round((self.max_kernel_size + 1) / 2))) * 2 + 1 + mode = np.random.choice(["h", "v", "diag_down", "diag_up"]) + ksize = np.random.randint(0, int(round((self.max_kernel_size + 1) / 2))) * 2 + 1 center = int((ksize - 1) / 2) kernel = np.zeros((ksize, ksize)) - if mode == 'h': - kernel[center, :] = 1. - elif mode == 'v': - kernel[:, center] = 1. 
- elif mode == 'diag_down': + if mode == "h": + kernel[center, :] = 1.0 + elif mode == "v": + kernel[:, center] = 1.0 + elif mode == "diag_down": kernel = np.eye(ksize) - elif mode == 'diag_up': + elif mode == "diag_up": kernel = np.flip(np.eye(ksize), 0) - var = ksize * ksize / 16. + var = ksize * ksize / 16.0 grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1) - gaussian = np.exp(-(np.square(grid - center) - + np.square(grid.T - center)) / (2. * var)) + gaussian = np.exp( + -(np.square(grid - center) + np.square(grid.T - center)) / (2.0 * var) + ) kernel *= gaussian kernel /= np.sum(kernel) blurred = cv2.filter2D(input_image, -1, kernel) @@ -177,7 +182,8 @@ class motion_blur(object): class normalize_image(object): - """ Image normalization to the range [0, 1]. """ + """Image normalization to the range [0, 1].""" + def __init__(self): self.normalize_value = 255 diff --git a/third_party/SOLD2/sold2/dataset/transforms/utils.py b/third_party/SOLD2/sold2/dataset/transforms/utils.py index 5f1ed09e5b32e2ae2f3577e0e8e5491495e7b05b..4e2d9b4234400b16c59773ebcf15ecc557df6cac 100644 --- a/third_party/SOLD2/sold2/dataset/transforms/utils.py +++ b/third_party/SOLD2/sold2/dataset/transforms/utils.py @@ -9,7 +9,7 @@ from ..synthetic_util import get_line_map from . import homographic_transforms as homoaug -def random_scaling(image, junctions, line_map, scale=1., h_crop=0, w_crop=0): +def random_scaling(image, junctions, line_map, scale=1.0, h_crop=0, w_crop=0): H, W = image.shape[:2] H_scale, W_scale = round(H * scale), round(W * scale) @@ -18,42 +18,46 @@ def random_scaling(image, junctions, line_map, scale=1., h_crop=0, w_crop=0): return (image, junctions, line_map, np.ones([H, W], dtype=np.int)) # Zoom-in => resize and random crop - if scale >= 1.: - image_big = cv2.resize(image, (W_scale, H_scale), - interpolation=cv2.INTER_LINEAR) + if scale >= 1.0: + image_big = cv2.resize( + image, (W_scale, H_scale), interpolation=cv2.INTER_LINEAR + ) # Crop the image - image = image_big[h_crop:h_crop+H, w_crop:w_crop+W, ...] + image = image_big[h_crop : h_crop + H, w_crop : w_crop + W, ...] valid_mask = np.ones([H, W], dtype=np.int) # Process junctions junctions, line_map = process_junctions_and_line_map( - h_crop, w_crop, H, W, H_scale, W_scale, - junctions, line_map, "zoom-in") + h_crop, w_crop, H, W, H_scale, W_scale, junctions, line_map, "zoom-in" + ) # Zoom-out => resize and pad else: image_shape_raw = image.shape - image_small = cv2.resize(image, (W_scale, H_scale), - interpolation=cv2.INTER_AREA) + image_small = cv2.resize( + image, (W_scale, H_scale), interpolation=cv2.INTER_AREA + ) # Decide the pasting location h_start = round((H - H_scale) / 2) w_start = round((W - W_scale) / 2) # Paste the image to the middle image = np.zeros(image_shape_raw, dtype=np.float) - image[h_start:h_start+H_scale, - w_start:w_start+W_scale, ...] = image_small + image[ + h_start : h_start + H_scale, w_start : w_start + W_scale, ... 
+ ] = image_small valid_mask = np.zeros([H, W], dtype=np.int) - valid_mask[h_start:h_start+H_scale, w_start:w_start+W_scale] = 1 + valid_mask[h_start : h_start + H_scale, w_start : w_start + W_scale] = 1 # Process the junctions junctions, line_map = process_junctions_and_line_map( - h_start, w_start, H, W, H_scale, W_scale, - junctions, line_map, "zoom-out") + h_start, w_start, H, W, H_scale, W_scale, junctions, line_map, "zoom-out" + ) return image, junctions, line_map, valid_mask -def process_junctions_and_line_map(h_start, w_start, H, W, H_scale, W_scale, - junctions, line_map, mode="zoom-in"): +def process_junctions_and_line_map( + h_start, w_start, H, W, H_scale, W_scale, junctions, line_map, mode="zoom-in" +): if mode == "zoom-in": junctions[:, 0] = junctions[:, 0] * H_scale / H junctions[:, 1] = junctions[:, 1] * W_scale / W @@ -61,53 +65,55 @@ def process_junctions_and_line_map(h_start, w_start, H, W, H_scale, W_scale, # Crop segments to the new boundaries line_segments_new = np.zeros([0, 4]) image_poly = sg.Polygon( - [[w_start, h_start], - [w_start+W, h_start], - [w_start+W, h_start+H], - [w_start, h_start+H] - ]) + [ + [w_start, h_start], + [w_start + W, h_start], + [w_start + W, h_start + H], + [w_start, h_start + H], + ] + ) for idx in range(line_segments.shape[0]): # Get the line segment - seg_raw = line_segments[idx, :] # in HW format. + seg_raw = line_segments[idx, :] # in HW format. # Convert to shapely line (flip to xy format) - seg = sg.LineString([np.flip(seg_raw[:2]), - np.flip(seg_raw[2:])]) + seg = sg.LineString([np.flip(seg_raw[:2]), np.flip(seg_raw[2:])]) # The line segment is just inside the image. if seg.intersection(image_poly) == seg: line_segments_new = np.concatenate( - (line_segments_new, seg_raw[None, ...]), axis=0) + (line_segments_new, seg_raw[None, ...]), axis=0 + ) # Intersect with the image. elif seg.intersects(image_poly): # Check intersection try: - p = np.array( - seg.intersection(image_poly).coords).reshape([-1, 4]) + p = np.array(seg.intersection(image_poly).coords).reshape([-1, 4]) # If intersect at exact one point, just continue. except: continue - segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:], - axis=0)])[None, ...] - line_segments_new = np.concatenate( - (line_segments_new, segment), axis=0) + segment = np.concatenate( + [np.flip(p[0, :2]), np.flip(p[0, 2:], axis=0)] + )[None, ...] 
+ line_segments_new = np.concatenate((line_segments_new, segment), axis=0) else: continue line_segments_new = (np.round(line_segments_new)).astype(np.int) # Filter segments with 0 length segment_lens = np.linalg.norm( - line_segments_new[:, :2] - line_segments_new[:, 2:], axis=-1) + line_segments_new[:, :2] - line_segments_new[:, 2:], axis=-1 + ) seg_mask = segment_lens != 0 line_segments_new = line_segments_new[seg_mask, :] # Convert back to junctions and line_map junctions_new = np.concatenate( - (line_segments_new[:, :2], line_segments_new[:, 2:]), axis=0) + (line_segments_new[:, :2], line_segments_new[:, 2:]), axis=0 + ) if junctions_new.shape[0] == 0: junctions_new = np.zeros([0, 2]) line_map = np.zeros([0, 0]) else: junctions_new = np.unique(junctions_new, axis=0) # Generate line map from points and segments - line_map = get_line_map(junctions_new, - line_segments_new).astype(np.int) + line_map = get_line_map(junctions_new, line_segments_new).astype(np.int) junctions_new[:, 0] -= h_start junctions_new[:, 1] -= w_start junctions = junctions_new diff --git a/third_party/SOLD2/sold2/dataset/wireframe_dataset.py b/third_party/SOLD2/sold2/dataset/wireframe_dataset.py index ed5bb910bed1b89934ddaaec3bcddf111ea0faef..44341d7394303188db3ba69123bb4b4212700466 100644 --- a/third_party/SOLD2/sold2/dataset/wireframe_dataset.py +++ b/third_party/SOLD2/sold2/dataset/wireframe_dataset.py @@ -27,12 +27,19 @@ from ..misc.geometry_utils import warp_points, mask_points def wireframe_collate_fn(batch): - """ Customized collate_fn for wireframe dataset. """ - batch_keys = ["image", "junction_map", "valid_mask", "heatmap", - "heatmap_pos", "heatmap_neg", "homography", - "line_points", "line_indices"] - list_keys = ["junctions", "line_map", "line_map_pos", - "line_map_neg", "file_key"] + """Customized collate_fn for wireframe dataset.""" + batch_keys = [ + "image", + "junction_map", + "valid_mask", + "heatmap", + "heatmap_pos", + "heatmap_neg", + "homography", + "line_points", + "line_indices", + ] + list_keys = ["junctions", "line_map", "line_map_pos", "line_map_neg", "file_key"] outputs = {} for data_key in batch[0].keys(): @@ -41,14 +48,16 @@ def wireframe_collate_fn(batch): # print(batch_match, list_match) if batch_match > 0 and list_match == 0: outputs[data_key] = torch_loader.default_collate( - [b[data_key] for b in batch]) + [b[data_key] for b in batch] + ) elif batch_match == 0 and list_match > 0: outputs[data_key] = [b[data_key] for b in batch] elif batch_match == 0 and list_match == 0: continue else: raise ValueError( - "[Error] A key matches batch keys and list keys simultaneously.") + "[Error] A key matches batch keys and list keys simultaneously." + ) return outputs @@ -58,7 +67,8 @@ class WireframeDataset(Dataset): super(WireframeDataset, self).__init__() if not mode in ["train", "test"]: raise ValueError( - "[Error] Unknown mode for Wireframe dataset. Only 'train' and 'test'.") + "[Error] Unknown mode for Wireframe dataset. Only 'train' and 'test'." 
+ ) self.mode = mode if config is None: @@ -72,18 +82,17 @@ class WireframeDataset(Dataset): self.dataset_name = self.get_dataset_name() self.cache_name = self.get_cache_name() self.cache_path = cfg.wireframe_cache_path - + # Get the ground truth source - self.gt_source = self.config.get("gt_source_%s"%(self.mode), - "official") + self.gt_source = self.config.get("gt_source_%s" % (self.mode), "official") if not self.gt_source == "official": # Convert gt_source to full path self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source) # Check the full path exists if not os.path.exists(self.gt_source): raise ValueError( - "[Error] The specified ground truth source does not exist.") - + "[Error] The specified ground truth source does not exist." + ) # Get the filename dataset print("[Info] Initializing wireframe dataset...") @@ -95,22 +104,22 @@ class WireframeDataset(Dataset): # Print some info print("[Info] Successfully initialized dataset") print("\t Name: wireframe") - print("\t Mode: %s" %(self.mode)) - print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode), - "official"))) - print("\t Counts: %d" %(self.dataset_length)) + print("\t Mode: %s" % (self.mode)) + print("\t Gt: %s" % (self.config.get("gt_source_%s" % (self.mode), "official"))) + print("\t Counts: %d" % (self.dataset_length)) print("----------------------------------------") ####################################### ## Dataset construction related APIs ## ####################################### def construct_dataset(self): - """ Construct the dataset (from scratch or from cache). """ + """Construct the dataset (from scratch or from cache).""" # Check if the filename cache exists # If cache exists, load from cache if self._check_dataset_cache(): - print("\t Found filename cache %s at %s"%(self.cache_name, - self.cache_path)) + print( + "\t Found filename cache %s at %s" % (self.cache_name, self.cache_path) + ) print("\t Load filename cache...") filename_dataset, datapoints = self.get_filename_dataset_from_cache() # If not, initialize dataset from scratch @@ -120,30 +129,27 @@ class WireframeDataset(Dataset): filename_dataset, datapoints = self.get_filename_dataset() print("\t Create filename dataset cache...") self.create_filename_dataset_cache(filename_dataset, datapoints) - + return filename_dataset, datapoints - + def create_filename_dataset_cache(self, filename_dataset, datapoints): - """ Create filename dataset cache for faster initialization. """ + """Create filename dataset cache for faster initialization.""" # Check cache path exists if not os.path.exists(self.cache_path): os.makedirs(self.cache_path) cache_file_path = os.path.join(self.cache_path, self.cache_name) - data = { - "filename_dataset": filename_dataset, - "datapoints": datapoints - } + data = {"filename_dataset": filename_dataset, "datapoints": datapoints} with open(cache_file_path, "wb") as f: pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) - + def get_filename_dataset_from_cache(self): - """ Get filename dataset from cache. 
""" + """Get filename dataset from cache.""" # Load from pkl cache cache_file_path = os.path.join(self.cache_path, self.cache_name) with open(cache_file_path, "rb") as f: data = pickle.load(f) - + return data["filename_dataset"], data["datapoints"] def get_filename_dataset(self): @@ -152,14 +158,18 @@ class WireframeDataset(Dataset): dataset_path = os.path.join(cfg.wireframe_dataroot, "train") elif self.mode == "test": dataset_path = os.path.join(cfg.wireframe_dataroot, "valid") - + # Get paths to all image files - image_paths = sorted([os.path.join(dataset_path, _) - for _ in os.listdir(dataset_path)\ - if os.path.splitext(_)[-1] == ".png"]) + image_paths = sorted( + [ + os.path.join(dataset_path, _) + for _ in os.listdir(dataset_path) + if os.path.splitext(_)[-1] == ".png" + ] + ) # Get the shared prefix prefix_paths = [_.split(".png")[0] for _ in image_paths] - + # Get the label paths (different procedure for different split) if self.mode == "train": label_paths = [_ + "_label.npz" for _ in prefix_paths] @@ -171,17 +181,18 @@ class WireframeDataset(Dataset): for idx in range(len(image_paths)): image_path = image_paths[idx] label_path = label_paths[idx] - if (not (os.path.exists(image_path) - and os.path.exists(label_path))): + if not (os.path.exists(image_path) and os.path.exists(label_path)): raise ValueError( - "[Error] The image and label do not exist. %s"%(image_path)) + "[Error] The image and label do not exist. %s" % (image_path) + ) # Further verify mat paths for test split if self.mode == "test": mat_path = mat_paths[idx] if not os.path.exists(mat_path): raise ValueError( - "[Error] The mat file does not exist. %s"%(mat_path)) - + "[Error] The mat file does not exist. %s" % (mat_path) + ) + # Construct the filename dataset num_pad = int(math.ceil(math.log10(len(image_paths))) + 1) filename_dataset = {} @@ -191,25 +202,25 @@ class WireframeDataset(Dataset): filename_dataset[key] = { "image": image_paths[idx], - "label": label_paths[idx] + "label": label_paths[idx], } # Get the datapoints datapoints = list(sorted(filename_dataset.keys())) return filename_dataset, datapoints - + def get_dataset_name(self): - """ Get dataset name from dataset config / default config. """ + """Get dataset name from dataset config / default config.""" if self.config["dataset_name"] is None: dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode else: dataset_name = self.config["dataset_name"] + "_%s" % self.mode return dataset_name - + def get_cache_name(self): - """ Get cache name from dataset config / default config. """ + """Get cache name from dataset config / default config.""" if self.config["dataset_name"] is None: dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode else: @@ -218,35 +229,27 @@ class WireframeDataset(Dataset): cache_name = dataset_name + "_cache.pkl" return cache_name - + @staticmethod def get_padded_filename(num_pad, idx): - """ Get the padded filename using adaptive padding. """ + """Get the padded filename using adaptive padding.""" file_len = len("%d" % (idx)) filename = "0" * (num_pad - file_len) + "%d" % (idx) return filename def get_default_config(self): - """ Get the default configuration. 
""" + """Get the default configuration.""" return { "dataset_name": "wireframe", "add_augmentation_to_all_splits": False, - "preprocessing": { - "resize": [240, 320], - "blur_size": 11 - }, - "augmentation":{ - "photometric":{ - "enable": False - }, - "homographic":{ - "enable": False - }, + "preprocessing": {"resize": [240, 320], "blur_size": 11}, + "augmentation": { + "photometric": {"enable": False}, + "homographic": {"enable": False}, }, } - ############################################ ## Pytorch and preprocessing related APIs ## ############################################ @@ -280,13 +283,13 @@ class WireframeDataset(Dataset): # TODO: How to process mat data if data_path.get("line_mat") is not None: raise NotImplementedError - + return output - + @staticmethod def convert_line_map(lcnn_line_map, num_junctions): - """ Convert the line_pos or line_neg - (represented by two junction indexes) to our line map. """ + """Convert the line_pos or line_neg + (represented by two junction indexes) to our line map.""" # Initialize empty line map line_map = np.zeros([num_junctions, num_junctions]) @@ -297,59 +300,60 @@ class WireframeDataset(Dataset): line_map[index1, index2] = 1 line_map[index2, index1] = 1 - + return line_map - + @staticmethod def junc_to_junc_map(junctions, image_size): - """ Convert junction points to junction maps. """ + """Convert junction points to junction maps.""" junctions = np.round(junctions).astype(np.int) # Clip the boundary by image size - junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1) - junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1) + junctions[:, 0] = np.clip(junctions[:, 0], 0.0, image_size[0] - 1) + junctions[:, 1] = np.clip(junctions[:, 1], 0.0, image_size[1] - 1) # Create junction map junc_map = np.zeros([image_size[0], image_size[1]]) junc_map[junctions[:, 0], junctions[:, 1]] = 1 return junc_map[..., None].astype(np.int) - + def parse_transforms(self, names, all_transforms): - """ Parse the transform. """ - trans = all_transforms if (names == 'all') \ + """Parse the transform.""" + trans = ( + all_transforms + if (names == "all") else (names if isinstance(names, list) else [names]) + ) assert set(trans) <= set(all_transforms) return trans def get_photo_transform(self): - """ Get list of photometric transforms (according to the config). """ + """Get list of photometric transforms (according to the config).""" # Get the photometric transform config photo_config = self.config["augmentation"]["photometric"] if not photo_config["enable"]: - raise ValueError( - "[Error] Photometric augmentation is not enabled.") - + raise ValueError("[Error] Photometric augmentation is not enabled.") + # Parse photometric transforms - trans_lst = self.parse_transforms(photo_config["primitives"], - photoaug.available_augmentations) - trans_config_lst = [photo_config["params"].get(p, {}) - for p in trans_lst] + trans_lst = self.parse_transforms( + photo_config["primitives"], photoaug.available_augmentations + ) + trans_config_lst = [photo_config["params"].get(p, {}) for p in trans_lst] # List of photometric augmentation photometric_trans_lst = [ - getattr(photoaug, trans)(**conf) \ + getattr(photoaug, trans)(**conf) for (trans, conf) in zip(trans_lst, trans_config_lst) ] return photometric_trans_lst def get_homo_transform(self): - """ Get homographic transforms (according to the config). 
""" + """Get homographic transforms (according to the config).""" # Get homographic transforms for image homo_config = self.config["augmentation"]["homographic"]["params"] if not self.config["augmentation"]["homographic"]["enable"]: - raise ValueError( - "[Error] Homographic augmentation is not enabled.") + raise ValueError("[Error] Homographic augmentation is not enabled.") # Parse the homographic transforms image_shape = self.config["preprocessing"]["resize"] @@ -359,67 +363,73 @@ class WireframeDataset(Dataset): min_label_tmp = self.config["generation"]["min_label_len"] except: min_label_tmp = None - + # float label len => fraction - if isinstance(min_label_tmp, float): # Skip if not provided + if isinstance(min_label_tmp, float): # Skip if not provided min_label_len = min_label_tmp * min(image_shape) # int label len => length in pixel elif isinstance(min_label_tmp, int): - scale_ratio = (self.config["preprocessing"]["resize"] - / self.config["generation"]["image_size"][0]) - min_label_len = (self.config["generation"]["min_label_len"] - * scale_ratio) + scale_ratio = ( + self.config["preprocessing"]["resize"] + / self.config["generation"]["image_size"][0] + ) + min_label_len = self.config["generation"]["min_label_len"] * scale_ratio # if none => no restriction else: min_label_len = 0 - + # Initialize the transform homographic_trans = homoaug.homography_transform( - image_shape, homo_config, 0, min_label_len) + image_shape, homo_config, 0, min_label_len + ) return homographic_trans - def get_line_points(self, junctions, line_map, H1=None, H2=None, - img_size=None, warp=False): - """ Sample evenly points along each line segments - and keep track of line idx. """ + def get_line_points( + self, junctions, line_map, H1=None, H2=None, img_size=None, warp=False + ): + """Sample evenly points along each line segments + and keep track of line idx.""" if np.sum(line_map) == 0: # No segment detected in the image line_indices = np.zeros(self.config["max_pts"], dtype=int) line_points = np.zeros((self.config["max_pts"], 2), dtype=float) return line_points, line_indices - + # Extract all pairs of connected junctions junc_indices = np.array( - [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i]) - line_segments = np.stack([junctions[junc_indices[:, 0]], - junctions[junc_indices[:, 1]]], axis=1) + [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i] + ) + line_segments = np.stack( + [junctions[junc_indices[:, 0]], junctions[junc_indices[:, 1]]], axis=1 + ) # line_segments is (num_lines, 2, 2) - line_lengths = np.linalg.norm( - line_segments[:, 0] - line_segments[:, 1], axis=1) + line_lengths = np.linalg.norm(line_segments[:, 0] - line_segments[:, 1], axis=1) # Sample the points separated by at least min_dist_pts along each line # The number of samples depends on the length of the line - num_samples = np.minimum(line_lengths // self.config["min_dist_pts"], - self.config["max_num_samples"]) + num_samples = np.minimum( + line_lengths // self.config["min_dist_pts"], self.config["max_num_samples"] + ) line_points = [] line_indices = [] cur_line_idx = 1 for n in np.arange(2, self.config["max_num_samples"] + 1): # Consider all lines where we can fit up to n points cur_line_seg = line_segments[num_samples == n] - line_points_x = np.linspace(cur_line_seg[:, 0, 0], - cur_line_seg[:, 1, 0], - n, axis=-1).flatten() - line_points_y = np.linspace(cur_line_seg[:, 0, 1], - cur_line_seg[:, 1, 1], - n, axis=-1).flatten() + line_points_x = np.linspace( + cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], n, axis=-1 + 
).flatten() + line_points_y = np.linspace( + cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], n, axis=-1 + ).flatten() jitter = self.config.get("jittering", 0) if jitter: # Add a small random jittering of all points along the line angles = np.arctan2( cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0], - cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n) + cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1], + ).repeat(n) jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter line_points_x += jitter_hyp * np.sin(angles) line_points_y += jitter_hyp * np.cos(angles) @@ -429,10 +439,8 @@ class WireframeDataset(Dataset): line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines) line_indices.append(line_idx.repeat(n)) cur_line_idx += num_cur_lines - line_points = np.concatenate(line_points, - axis=0)[:self.config["max_pts"]] - line_indices = np.concatenate(line_indices, - axis=0)[:self.config["max_pts"]] + line_points = np.concatenate(line_points, axis=0)[: self.config["max_pts"]] + line_indices = np.concatenate(line_indices, axis=0)[: self.config["max_pts"]] # Warp the points if need be, and filter unvalid ones # If the other view is also warped @@ -454,20 +462,24 @@ class WireframeDataset(Dataset): mask = mask_points(warped_points, img_size) line_points = line_points[mask] line_indices = line_indices[mask] - + # Pad the line points to a fixed length # Index of 0 means padded line - line_indices = np.concatenate([line_indices, np.zeros( - self.config["max_pts"] - len(line_indices))], axis=0) + line_indices = np.concatenate( + [line_indices, np.zeros(self.config["max_pts"] - len(line_indices))], axis=0 + ) line_points = np.concatenate( - [line_points, - np.zeros((self.config["max_pts"] - len(line_points), 2), - dtype=float)], axis=0) - + [ + line_points, + np.zeros((self.config["max_pts"] - len(line_points), 2), dtype=float), + ], + axis=0, + ) + return line_points, line_indices def train_preprocessing(self, data, numpy=False): - """ Train preprocessing for GT data. """ + """Train preprocessing for GT data.""" # Fetch the corresponding entries image = data["image"] junctions = data["junc"][:, :2] @@ -476,23 +488,27 @@ class WireframeDataset(Dataset): image_size = image.shape[:2] # Convert junctions to pixel coordinates (from 128x128) junctions[:, 0] *= image_size[0] / 128 - junctions[:, 1] *= image_size[1] / 128 + junctions[:, 1] *= image_size[1] / 128 # Resize the image before photometric and homographical augmentations - if not(list(image_size) == self.config["preprocessing"]["resize"]): + if not (list(image_size) == self.config["preprocessing"]["resize"]): # Resize the image and the point location. 
- size_old = list(image.shape)[:2] # Only H and W dimensions + size_old = list(image.shape)[:2] # Only H and W dimensions image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) # In HW format - junctions = (junctions * np.array( - self.config['preprocessing']['resize'], np.float) - / np.array(size_old, np.float)) - + junctions = ( + junctions + * np.array(self.config["preprocessing"]["resize"], np.float) + / np.array(size_old, np.float) + ) + # Convert to positive line map and negative line map (our format) num_junctions = junctions.shape[0] line_map_pos = self.convert_line_map(line_pos, num_junctions) @@ -509,7 +525,7 @@ class WireframeDataset(Dataset): # Optionally convert the image to grayscale if self.config["gray_scale"]: - image = (color.rgb2gray(image) * 255.).astype(np.uint8) + image = (color.rgb2gray(image) * 255.0).astype(np.uint8) # Check if we need to apply augmentations # In training mode => yes. @@ -519,7 +535,8 @@ class WireframeDataset(Dataset): ### Image transform ### np.random.shuffle(photo_trans_lst) image_transform = transforms.Compose( - photo_trans_lst + [photoaug.normalize_image()]) + photo_trans_lst + [photoaug.normalize_image()] + ) else: image_transform = photoaug.normalize_image() image = image_transform(image) @@ -549,13 +566,11 @@ class WireframeDataset(Dataset): "image": to_tensor(image), "junctions": to_tensor(junctions).to(torch.float32)[0, ...], "junction_map": to_tensor(junction_map).to(torch.int), - "line_map_pos": to_tensor( - line_map_pos).to(torch.int32)[0, ...], - "line_map_neg": to_tensor( - line_map_neg).to(torch.int32)[0, ...], + "line_map_pos": to_tensor(line_map_pos).to(torch.int32)[0, ...], + "line_map_neg": to_tensor(line_map_neg).to(torch.int32)[0, ...], "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32), "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32), - "valid_mask": to_tensor(valid_mask).to(torch.int32) + "valid_mask": to_tensor(valid_mask).to(torch.int32), } else: return { @@ -566,14 +581,23 @@ class WireframeDataset(Dataset): "line_map_neg": line_map_neg.astype(np.int32), "heatmap_pos": heatmap_pos.astype(np.int32), "heatmap_neg": heatmap_neg.astype(np.int32), - "valid_mask": valid_mask.astype(np.int32) + "valid_mask": valid_mask.astype(np.int32), } - + def train_preprocessing_exported( - self, data, numpy=False, disable_homoaug=False, - desc_training=False, H1=None, H1_scale=None, H2=None, scale=1., - h_crop=None, w_crop=None): - """ Train preprocessing for the exported labels. """ + self, + data, + numpy=False, + disable_homoaug=False, + desc_training=False, + H1=None, + H1_scale=None, + H2=None, + scale=1.0, + h_crop=None, + w_crop=None, + ): + """Train preprocessing for the exported labels.""" data = copy.deepcopy(data) # Fetch the corresponding entries image = data["image"] @@ -593,13 +617,15 @@ class WireframeDataset(Dataset): w_crop = np.random.randint(W_scale - W) # Resize the image before photometric and homographical augmentations - if not(list(image_size) == self.config["preprocessing"]["resize"]): + if not (list(image_size) == self.config["preprocessing"]["resize"]): # Resize the image and the point location. 
- size_old = list(image.shape)[:2] # Only H and W dimensions + size_old = list(image.shape)[:2] # Only H and W dimensions image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) # # In HW format @@ -614,7 +640,7 @@ class WireframeDataset(Dataset): # Optionally convert the image to grayscale if self.config["gray_scale"]: - image = (color.rgb2gray(image) * 255.).astype(np.uint8) + image = (color.rgb2gray(image) * 255.0).astype(np.uint8) # Check if we need to apply augmentations # In training mode => yes. @@ -624,40 +650,49 @@ class WireframeDataset(Dataset): ### Image transform ### np.random.shuffle(photo_trans_lst) image_transform = transforms.Compose( - photo_trans_lst + [photoaug.normalize_image()]) + photo_trans_lst + [photoaug.normalize_image()] + ) else: image_transform = photoaug.normalize_image() image = image_transform(image) - + # Perform the random scaling - if scale != 1.: + if scale != 1.0: image, junctions, line_map, valid_mask = random_scaling( - image, junctions, line_map, scale, - h_crop=h_crop, w_crop=w_crop) + image, junctions, line_map, scale, h_crop=h_crop, w_crop=w_crop + ) else: # Declare default valid mask (all ones) valid_mask = np.ones(image_size) - + # Initialize the empty output dict outputs = {} # Convert to tensor and return the results to_tensor = transforms.ToTensor() # Check homographic augmentation - warp = (self.config["augmentation"]["homographic"]["enable"] - and disable_homoaug == False) + warp = ( + self.config["augmentation"]["homographic"]["enable"] + and disable_homoaug == False + ) if warp: homo_trans = self.get_homo_transform() # Perform homographic transform if H1 is None: homo_outputs = homo_trans( - image, junctions, line_map, valid_mask=valid_mask) + image, junctions, line_map, valid_mask=valid_mask + ) else: homo_outputs = homo_trans( - image, junctions, line_map, homo=H1, scale=H1_scale, - valid_mask=valid_mask) + image, + junctions, + line_map, + homo=H1, + scale=H1_scale, + valid_mask=valid_mask, + ) homography_mat = homo_outputs["homo"] - + # Give the warp of the other view if H1 is None: H1 = homo_outputs["homo"] @@ -665,8 +700,8 @@ class WireframeDataset(Dataset): # Sample points along each line segments for the descriptor if desc_training: line_points, line_indices = self.get_line_points( - junctions, line_map, H1=H1, H2=H2, - img_size=image_size, warp=warp) + junctions, line_map, H1=H1, H2=H2, img_size=image_size, warp=warp + ) # Record the warped results if warp: @@ -675,52 +710,59 @@ class WireframeDataset(Dataset): line_map = homo_outputs["line_map"] valid_mask = homo_outputs["valid_mask"] # Same for pos and neg heatmap = homo_outputs["warped_heatmap"] - + # Optionally put warping information first. if not numpy: - outputs["homography_mat"] = to_tensor( - homography_mat).to(torch.float32)[0, ...] + outputs["homography_mat"] = to_tensor(homography_mat).to(torch.float32)[ + 0, ... 
+ ] else: outputs["homography_mat"] = homography_mat.astype(np.float32) junction_map = self.junc_to_junc_map(junctions, image_size) - + if not numpy: - outputs.update({ - "image": to_tensor(image).to(torch.float32), - "junctions": to_tensor(junctions).to(torch.float32)[0, ...], - "junction_map": to_tensor(junction_map).to(torch.int), - "line_map": to_tensor(line_map).to(torch.int32)[0, ...], - "heatmap": to_tensor(heatmap).to(torch.int32), - "valid_mask": to_tensor(valid_mask).to(torch.int32) - }) + outputs.update( + { + "image": to_tensor(image).to(torch.float32), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map": to_tensor(line_map).to(torch.int32)[0, ...], + "heatmap": to_tensor(heatmap).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32), + } + ) if desc_training: - outputs.update({ - "line_points": to_tensor( - line_points).to(torch.float32)[0], - "line_indices": torch.tensor(line_indices, - dtype=torch.int) - }) + outputs.update( + { + "line_points": to_tensor(line_points).to(torch.float32)[0], + "line_indices": torch.tensor(line_indices, dtype=torch.int), + } + ) else: - outputs.update({ - "image": image, - "junctions": junctions.astype(np.float32), - "junction_map": junction_map.astype(np.int32), - "line_map": line_map.astype(np.int32), - "heatmap": heatmap.astype(np.int32), - "valid_mask": valid_mask.astype(np.int32) - }) + outputs.update( + { + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + "line_map": line_map.astype(np.int32), + "heatmap": heatmap.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32), + } + ) if desc_training: - outputs.update({ - "line_points": line_points.astype(np.float32), - "line_indices": line_indices.astype(int) - }) - + outputs.update( + { + "line_points": line_points.astype(np.float32), + "line_indices": line_indices.astype(int), + } + ) + return outputs - - def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.): - """ Train preprocessing for paired data for the exported labels - for descriptor training. 
""" + + def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.0): + """Train preprocessing for paired data for the exported labels + for descriptor training.""" outputs = {} # Define the random crop for scaling if necessary @@ -732,36 +774,49 @@ class WireframeDataset(Dataset): h_crop = np.random.randint(H_scale - H) if W_scale > W: w_crop = np.random.randint(W_scale - W) - + # Sample ref homography first homo_config = self.config["augmentation"]["homographic"]["params"] image_shape = self.config["preprocessing"]["resize"] - ref_H, ref_scale = homoaug.sample_homography(image_shape, - **homo_config) + ref_H, ref_scale = homoaug.sample_homography(image_shape, **homo_config) # Data for target view (All augmentation) target_data = self.train_preprocessing_exported( - data, numpy=numpy, desc_training=True, H1=None, H2=ref_H, - scale=scale, h_crop=h_crop, w_crop=w_crop) + data, + numpy=numpy, + desc_training=True, + H1=None, + H2=ref_H, + scale=scale, + h_crop=h_crop, + w_crop=w_crop, + ) # Data for reference view (No homographical augmentation) ref_data = self.train_preprocessing_exported( - data, numpy=numpy, desc_training=True, H1=ref_H, - H1_scale=ref_scale, H2=target_data["homography_mat"].numpy(), - scale=scale, h_crop=h_crop, w_crop=w_crop) + data, + numpy=numpy, + desc_training=True, + H1=ref_H, + H1_scale=ref_scale, + H2=target_data["homography_mat"].numpy(), + scale=scale, + h_crop=h_crop, + w_crop=w_crop, + ) # Spread ref data for key, val in ref_data.items(): outputs["ref_" + key] = val - + # Spread target data for key, val in target_data.items(): outputs["target_" + key] = val - + return outputs def test_preprocessing(self, data, numpy=False): - """ Test preprocessing for GT data. """ + """Test preprocessing for GT data.""" data = copy.deepcopy(data) # Fetch the corresponding entries image = data["image"] @@ -771,31 +826,35 @@ class WireframeDataset(Dataset): image_size = image.shape[:2] # Convert junctions to pixel coordinates (from 128x128) junctions[:, 0] *= image_size[0] / 128 - junctions[:, 1] *= image_size[1] / 128 + junctions[:, 1] *= image_size[1] / 128 # Resize the image before photometric and homographical augmentations - if not(list(image_size) == self.config["preprocessing"]["resize"]): + if not (list(image_size) == self.config["preprocessing"]["resize"]): # Resize the image and the point location. 
- size_old = list(image.shape)[:2] # Only H and W dimensions + size_old = list(image.shape)[:2] # Only H and W dimensions image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) # In HW format - junctions = (junctions * np.array( - self.config['preprocessing']['resize'], np.float) - / np.array(size_old, np.float)) - + junctions = ( + junctions + * np.array(self.config["preprocessing"]["resize"], np.float) + / np.array(size_old, np.float) + ) + # Optionally convert the image to grayscale if self.config["gray_scale"]: - image = (color.rgb2gray(image) * 255.).astype(np.uint8) + image = (color.rgb2gray(image) * 255.0).astype(np.uint8) # Still need to normalize image image_transform = photoaug.normalize_image() image = image_transform(image) - + # Convert to positive line map and negative line map (our format) num_junctions = junctions.shape[0] line_map_pos = self.convert_line_map(line_pos, num_junctions) @@ -819,13 +878,11 @@ class WireframeDataset(Dataset): "image": to_tensor(image), "junctions": to_tensor(junctions).to(torch.float32)[0, ...], "junction_map": to_tensor(junction_map).to(torch.int), - "line_map_pos": to_tensor( - line_map_pos).to(torch.int32)[0, ...], - "line_map_neg": to_tensor( - line_map_neg).to(torch.int32)[0, ...], + "line_map_pos": to_tensor(line_map_pos).to(torch.int32)[0, ...], + "line_map_neg": to_tensor(line_map_neg).to(torch.int32)[0, ...], "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32), "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32), - "valid_mask": to_tensor(valid_mask).to(torch.int32) + "valid_mask": to_tensor(valid_mask).to(torch.int32), } else: return { @@ -836,26 +893,28 @@ class WireframeDataset(Dataset): "line_map_neg": line_map_neg.astype(np.int32), "heatmap_pos": heatmap_pos.astype(np.int32), "heatmap_neg": heatmap_neg.astype(np.int32), - "valid_mask": valid_mask.astype(np.int32) + "valid_mask": valid_mask.astype(np.int32), } - - def test_preprocessing_exported(self, data, numpy=False, scale=1.): - """ Test preprocessing for the exported labels. """ + + def test_preprocessing_exported(self, data, numpy=False, scale=1.0): + """Test preprocessing for the exported labels.""" data = copy.deepcopy(data) # Fetch the corresponding entries image = data["image"] junctions = data["junctions"] - line_map = data["line_map"] + line_map = data["line_map"] image_size = image.shape[:2] # Resize the image before photometric and homographical augmentations - if not(list(image_size) == self.config["preprocessing"]["resize"]): + if not (list(image_size) == self.config["preprocessing"]["resize"]): # Resize the image and the point location. 
- size_old = list(image.shape)[:2] # Only H and W dimensions + size_old = list(image.shape)[:2] # Only H and W dimensions image = cv2.resize( - image, tuple(self.config['preprocessing']['resize'][::-1]), - interpolation=cv2.INTER_LINEAR) + image, + tuple(self.config["preprocessing"]["resize"][::-1]), + interpolation=cv2.INTER_LINEAR, + ) image = np.array(image, dtype=np.uint8) # # In HW format @@ -865,7 +924,7 @@ class WireframeDataset(Dataset): # Optionally convert the image to grayscale if self.config["gray_scale"]: - image = (color.rgb2gray(image) * 255.).astype(np.uint8) + image = (color.rgb2gray(image) * 255.0).astype(np.uint8) # Still need to normalize image image_transform = photoaug.normalize_image() @@ -875,7 +934,7 @@ class WireframeDataset(Dataset): junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) image_size = image.shape[:2] heatmap = get_line_heatmap(junctions_xy, line_map, image_size) - + # Declare default valid mask (all ones) valid_mask = np.ones(image_size) @@ -890,7 +949,7 @@ class WireframeDataset(Dataset): "junction_map": to_tensor(junction_map).to(torch.int), "line_map": to_tensor(line_map).to(torch.int32)[0, ...], "heatmap": to_tensor(heatmap).to(torch.int32), - "valid_mask": to_tensor(valid_mask).to(torch.int32) + "valid_mask": to_tensor(valid_mask).to(torch.int32), } else: outputs = { @@ -899,20 +958,20 @@ class WireframeDataset(Dataset): "junction_map": junction_map.astype(np.int32), "line_map": line_map.astype(np.int32), "heatmap": heatmap.astype(np.int32), - "valid_mask": valid_mask.astype(np.int32) + "valid_mask": valid_mask.astype(np.int32), } - + return outputs def __len__(self): return self.dataset_length def get_data_from_key(self, file_key): - """ Get data from file_key. """ + """Get data from file_key.""" # Check key exists if not file_key in self.filename_dataset.keys(): raise ValueError("[Error] the specified key is not in the dataset.") - + # Get the data paths data_path = self.filename_dataset[file_key] # Read in the image and npz labels (but haven't applied any transform) @@ -923,12 +982,12 @@ class WireframeDataset(Dataset): data = self.train_preprocessing(data, numpy=True) else: data = self.test_preprocessing(data, numpy=True) - + # Add file key to the output data["file_key"] = file_key - + return data - + def __getitem__(self, idx): """Return data file_key: str, keys used to retrieve data from the filename dataset. @@ -951,30 +1010,27 @@ class WireframeDataset(Dataset): if not self.gt_source == "official": with h5py.File(self.gt_source, "r") as f: exported_label = parse_h5_data(f[file_key]) - + data["junctions"] = exported_label["junctions"] data["line_map"] = exported_label["line_map"] - + # Perform transform and augmentation return_type = self.config.get("return_type", "single") - if (self.mode == "train" - or self.config["add_augmentation_to_all_splits"]): + if self.mode == "train" or self.config["add_augmentation_to_all_splits"]: # Perform random scaling first if self.config["augmentation"]["random_scaling"]["enable"]: scale_range = self.config["augmentation"]["random_scaling"]["range"] # Decide the scaling scale = np.random.uniform(min(scale_range), max(scale_range)) else: - scale = 1. 
+ scale = 1.0 if self.gt_source == "official": data = self.train_preprocessing(data) else: if return_type == "paired_desc": - data = self.preprocessing_exported_paired_desc( - data, scale=scale) + data = self.preprocessing_exported_paired_desc(data, scale=scale) else: - data = self.train_preprocessing_exported(data, - scale=scale) + data = self.train_preprocessing_exported(data, scale=scale) else: if self.gt_source == "official": data = self.test_preprocessing(data) @@ -982,17 +1038,17 @@ class WireframeDataset(Dataset): data = self.preprocessing_exported_paired_desc(data) else: data = self.test_preprocessing_exported(data) - + # Add file key to the output data["file_key"] = file_key - + return data - + ######################## ## Some other methods ## ######################## def _check_dataset_cache(self): - """ Check if dataset cache exists. """ + """Check if dataset cache exists.""" cache_file_path = os.path.join(self.cache_path, self.cache_name) if os.path.exists(cache_file_path): return True diff --git a/third_party/SOLD2/sold2/experiment.py b/third_party/SOLD2/sold2/experiment.py index 3bf4db1c9f148b9e33c6d7d0ba973375cd770a14..0a2d5c0dc359cec13304813ac7732c5968d70a80 100644 --- a/third_party/SOLD2/sold2/experiment.py +++ b/third_party/SOLD2/sold2/experiment.py @@ -19,7 +19,7 @@ torch.backends.cudnn.benchmark = True def load_config(config_path): - """ Load configurations from a given yaml file. """ + """Load configurations from a given yaml file.""" # Check file exists if not os.path.exists(config_path): raise ValueError("[Error] The provided config path is not valid.") @@ -32,7 +32,7 @@ def load_config(config_path): def update_config(path, model_cfg=None, dataset_cfg=None): - """ Update configuration file from the resume path. """ + """Update configuration file from the resume path.""" # Check we need to update or completely override. model_cfg = {} if model_cfg is None else model_cfg dataset_cfg = {} if dataset_cfg is None else dataset_cfg @@ -57,23 +57,23 @@ def update_config(path, model_cfg=None, dataset_cfg=None): def record_config(model_cfg, dataset_cfg, output_path): - """ Record dataset config to the log path. """ + """Record dataset config to the log path.""" # Record model config with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f: - yaml.safe_dump(model_cfg, f) - + yaml.safe_dump(model_cfg, f) + # Record dataset config with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f: - yaml.safe_dump(dataset_cfg, f) - + yaml.safe_dump(dataset_cfg, f) + def train(args, dataset_cfg, model_cfg, output_path): - """ Training function. """ + """Training function.""" # Update model config from the resume path (only in resume mode) if args.resume: if os.path.realpath(output_path) != os.path.realpath(args.resume_path): record_config(model_cfg, dataset_cfg, output_path) - + # First time, then write the config file to the output path else: record_config(model_cfg, dataset_cfg, output_path) @@ -82,23 +82,32 @@ def train(args, dataset_cfg, model_cfg, output_path): train_net(args, dataset_cfg, model_cfg, output_path) -def export(args, dataset_cfg, model_cfg, output_path, - export_dataset_mode=None, device=torch.device("cuda")): - """ Export function. 
""" +def export( + args, + dataset_cfg, + model_cfg, + output_path, + export_dataset_mode=None, + device=torch.device("cuda"), +): + """Export function.""" # Choose between normal predictions export or homography adaptation if dataset_cfg.get("homography_adaptation") is not None: print("[Info] Export predictions with homography adaptation.") - export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path, - export_dataset_mode, device) + export_homograpy_adaptation( + args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device + ) else: print("[Info] Export predictions normally.") - export_predictions(args, dataset_cfg, model_cfg, output_path, - export_dataset_mode) + export_predictions( + args, dataset_cfg, model_cfg, output_path, export_dataset_mode + ) -def main(args, dataset_cfg, model_cfg, export_dataset_mode=None, - device=torch.device("cuda")): - """ Main function. """ +def main( + args, dataset_cfg, model_cfg, export_dataset_mode=None, device=torch.device("cuda") +): + """Main function.""" # Make the output path output_path = os.path.join(cfg.EXP_PATH, args.exp_name) @@ -113,7 +122,14 @@ def main(args, dataset_cfg, model_cfg, export_dataset_mode=None, output_path = os.path.join(cfg.export_dataroot, args.exp_name) print("[Info] Export mode") print("\t Output path: %s" % output_path) - export(args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device=device) + export( + args, + dataset_cfg, + model_cfg, + output_path, + export_dataset_mode, + device=device, + ) else: raise ValueError("[Error]: Unknown mode: " + args.mode) @@ -126,28 +142,43 @@ def set_random_seed(seed): if __name__ == "__main__": # Parse input arguments parser = argparse.ArgumentParser() - parser.add_argument("--mode", type=str, default="train", - help="'train' or 'export'.") - parser.add_argument("--dataset_config", type=str, default=None, - help="Path to the dataset config.") - parser.add_argument("--model_config", type=str, default=None, - help="Path to the model config.") - parser.add_argument("--exp_name", type=str, default="exp", - help="Experiment name.") - parser.add_argument("--resume", action="store_true", default=False, - help="Load a previously trained model.") - parser.add_argument("--pretrained", action="store_true", default=False, - help="Start training from a pre-trained model.") - parser.add_argument("--resume_path", default=None, - help="Path from which to resume training.") - parser.add_argument("--pretrained_path", default=None, - help="Path to the pre-trained model.") - parser.add_argument("--checkpoint_name", default=None, - help="Name of the checkpoint to use.") - parser.add_argument("--export_dataset_mode", default=None, - help="'train' or 'test'.") - parser.add_argument("--export_batch_size", default=4, type=int, - help="Export batch size.") + parser.add_argument( + "--mode", type=str, default="train", help="'train' or 'export'." + ) + parser.add_argument( + "--dataset_config", type=str, default=None, help="Path to the dataset config." + ) + parser.add_argument( + "--model_config", type=str, default=None, help="Path to the model config." 
+ ) + parser.add_argument("--exp_name", type=str, default="exp", help="Experiment name.") + parser.add_argument( + "--resume", + action="store_true", + default=False, + help="Load a previously trained model.", + ) + parser.add_argument( + "--pretrained", + action="store_true", + default=False, + help="Start training from a pre-trained model.", + ) + parser.add_argument( + "--resume_path", default=None, help="Path from which to resume training." + ) + parser.add_argument( + "--pretrained_path", default=None, help="Path to the pre-trained model." + ) + parser.add_argument( + "--checkpoint_name", default=None, help="Name of the checkpoint to use." + ) + parser.add_argument( + "--export_dataset_mode", default=None, help="'train' or 'test'." + ) + parser.add_argument( + "--export_batch_size", default=4, type=int, help="Export batch size." + ) args = parser.parse_args() @@ -159,28 +190,29 @@ if __name__ == "__main__": device = torch.device("cpu") # Check if dataset config and model config is given. - if (((args.dataset_config is None) or (args.model_config is None)) - and (not args.resume) and (args.mode == "train")): + if ( + ((args.dataset_config is None) or (args.model_config is None)) + and (not args.resume) + and (args.mode == "train") + ): raise ValueError( - "[Error] The dataset config and model config should be given in non-resume mode") + "[Error] The dataset config and model config should be given in non-resume mode" + ) # If resume, check if the resume path has been given if args.resume and (args.resume_path is None): - raise ValueError( - "[Error] Missing resume path.") + raise ValueError("[Error] Missing resume path.") # [Training] Load the config file. if args.mode == "train" and (not args.resume): # Check the pretrained checkpoint_path exists if args.pretrained: checkpoint_folder = args.resume_path - checkpoint_path = os.path.join(args.pretrained_path, - args.checkpoint_name) + checkpoint_path = os.path.join(args.pretrained_path, args.checkpoint_name) if not os.path.exists(checkpoint_path): - raise ValueError("[Error] Missing checkpoint: " - + checkpoint_path) + raise ValueError("[Error] Missing checkpoint: " + checkpoint_path) dataset_cfg = load_config(args.dataset_config) - model_cfg = load_config(args.model_config) + model_cfg = load_config(args.model_config) # [resume Training, Test, Export] Load the config file. elif (args.mode == "train" and args.resume) or (args.mode == "export"): @@ -195,33 +227,35 @@ if __name__ == "__main__": print("[Info] No model config provided. Loading from checkpoint folder.") model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml") if not os.path.exists(model_cfg_path): - raise ValueError( - "[Error] Missing model config in checkpoint path.") + raise ValueError("[Error] Missing model config in checkpoint path.") model_cfg = load_config(model_cfg_path) else: model_cfg = load_config(args.model_config) - + # Load dataset_cfg from checkpoint folder if not provided if args.dataset_config is None: print("[Info] No dataset config provided. 
Loading from checkpoint folder.") - dataset_cfg_path = os.path.join(checkpoint_folder, - "dataset_cfg.yaml") + dataset_cfg_path = os.path.join(checkpoint_folder, "dataset_cfg.yaml") if not os.path.exists(dataset_cfg_path): - raise ValueError( - "[Error] Missing dataset config in checkpoint path.") + raise ValueError("[Error] Missing dataset config in checkpoint path.") dataset_cfg = load_config(dataset_cfg_path) else: dataset_cfg = load_config(args.dataset_config) - + # Check the --export_dataset_mode flag if (args.mode == "export") and (args.export_dataset_mode is None): raise ValueError("[Error] Empty --export_dataset_mode flag.") else: raise ValueError("[Error] Unknown mode: " + args.mode) - + # Set the random seed seed = dataset_cfg.get("random_seed", 0) set_random_seed(seed) - main(args, dataset_cfg, model_cfg, - export_dataset_mode=args.export_dataset_mode, device=device) + main( + args, + dataset_cfg, + model_cfg, + export_dataset_mode=args.export_dataset_mode, + device=device, + ) diff --git a/third_party/SOLD2/sold2/export.py b/third_party/SOLD2/sold2/export.py index 19683d982c6d7fd429b27868b620fd20562d1aa7..ec5bf2dcb1c51999c80b6d1ff170c238883e34a0 100644 --- a/third_party/SOLD2/sold2/export.py +++ b/third_party/SOLD2/sold2/export.py @@ -17,7 +17,7 @@ from .dataset.transforms.homographic_transforms import sample_homography def restore_weights(model, state_dict): - """ Restore weights in compatible mode. """ + """Restore weights in compatible mode.""" # Try to directly load state dict try: model.load_state_dict(state_dict) @@ -38,15 +38,14 @@ def restore_weights(model, state_dict): def get_padded_filename(num_pad, idx): - """ Get the filename padded with 0. """ + """Get the filename padded with 0.""" file_len = len("%d" % (idx)) filename = "0" * (num_pad - file_len) + "%d" % (idx) return filename -def export_predictions(args, dataset_cfg, model_cfg, output_path, - export_dataset_mode): - """ Export predictions. 
""" +def export_predictions(args, dataset_cfg, model_cfg, output_path, export_dataset_mode): + """Export predictions.""" # Get the test configuration test_cfg = model_cfg["test"] @@ -54,10 +53,14 @@ def export_predictions(args, dataset_cfg, model_cfg, output_path, print("\t Initializing dataset and dataloader") batch_size = 4 export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg) - export_loader = DataLoader(export_dataset, batch_size=batch_size, - num_workers=test_cfg.get("num_workers", 4), - shuffle=False, pin_memory=False, - collate_fn=collate_fn) + export_loader = DataLoader( + export_dataset, + batch_size=batch_size, + num_workers=test_cfg.get("num_workers", 4), + shuffle=False, + pin_memory=False, + collate_fn=collate_fn, + ) print("\t Successfully intialized dataset and dataloader.") # Initialize model and load the checkpoint @@ -87,11 +90,18 @@ def export_predictions(args, dataset_cfg, model_cfg, output_path, # Convert predictions junc_np = convert_junc_predictions( - outputs["junctions"], model_cfg["grid_size"], - model_cfg["detection_thresh"], 300) + outputs["junctions"], + model_cfg["grid_size"], + model_cfg["detection_thresh"], + 300, + ) junc_map_np = junc_map.numpy().transpose(0, 2, 3, 1) - heatmap_np = softmax(outputs["heatmap"].detach(), - dim=1).cpu().numpy().transpose(0, 2, 3, 1) + heatmap_np = ( + softmax(outputs["heatmap"].detach(), dim=1) + .cpu() + .numpy() + .transpose(0, 2, 3, 1) + ) heatmap_gt_np = heatmap.numpy().transpose(0, 2, 3, 1) valid_mask_np = valid_mask.numpy().transpose(0, 2, 3, 1) @@ -99,15 +109,22 @@ def export_predictions(args, dataset_cfg, model_cfg, output_path, current_batch_size = input_images.shape[0] for batch_idx in range(current_batch_size): output_data = { - "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], + "image": input_images.cpu() + .numpy() + .transpose(0, 2, 3, 1)[batch_idx], "junc_gt": junc_map_np[batch_idx], "junc_pred": junc_np["junc_pred"][batch_idx], - "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(np.float32), + "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype( + np.float32 + ), "heatmap_gt": heatmap_gt_np[batch_idx], "heatmap_pred": heatmap_np[batch_idx], "valid_mask": valid_mask_np[batch_idx], - "junc_points": data["junctions"][batch_idx].numpy()[0].round().astype(np.int32), - "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32) + "junc_points": data["junctions"][batch_idx] + .numpy()[0] + .round() + .astype(np.int32), + "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32), } # Save data to h5 dataset @@ -117,19 +134,18 @@ def export_predictions(args, dataset_cfg, model_cfg, output_path, # Store data for key, output_data in output_data.items(): - f_group.create_dataset(key, data=output_data, - compression="gzip") + f_group.create_dataset(key, data=output_data, compression="gzip") filename_idx += 1 -def export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path, - export_dataset_mode, device): - """ Export homography adaptation results. 
""" +def export_homograpy_adaptation( + args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device +): + """Export homography adaptation results.""" # Check if the export_dataset_mode is supported supported_modes = ["train", "test"] if not export_dataset_mode in supported_modes: - raise ValueError( - "[Error] The specified export_dataset_mode is not supported.") + raise ValueError("[Error] The specified export_dataset_mode is not supported.") # Get the test configuration test_cfg = model_cfg["test"] @@ -137,66 +153,87 @@ def export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path, # Get the homography adaptation configurations homography_cfg = dataset_cfg.get("homography_adaptation", None) if homography_cfg is None: - raise ValueError( - "[Error] Empty homography_adaptation entry in config.") + raise ValueError("[Error] Empty homography_adaptation entry in config.") # Create the dataset and dataloader based on the export_dataset_mode print("\t Initializing dataset and dataloader") batch_size = args.export_batch_size export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg) - export_loader = DataLoader(export_dataset, batch_size=batch_size, - num_workers=test_cfg.get("num_workers", 4), - shuffle=False, pin_memory=False, - collate_fn=collate_fn) + export_loader = DataLoader( + export_dataset, + batch_size=batch_size, + num_workers=test_cfg.get("num_workers", 4), + shuffle=False, + pin_memory=False, + collate_fn=collate_fn, + ) print("\t Successfully intialized dataset and dataloader.") # Initialize model and load the checkpoint model = get_model(model_cfg, mode="test") - checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name, - device) + checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name, device) model = restore_weights(model, checkpoint["model_state_dict"]) model = model.to(device).eval() print("\t Successfully initialized model") # Start the export process - print("[Info] Start exporting predictions") + print("[Info] Start exporting predictions") output_dataset_path = output_path + ".h5" with h5py.File(output_dataset_path, "w", libver="latest") as f: - f.swmr_mode=True + f.swmr_mode = True for _, data in enumerate(tqdm(export_loader, ascii=True)): input_images = data["image"].to(device) file_keys = data["file_key"] batch_size = input_images.shape[0] - + # Run the homograpy adaptation - outputs = homography_adaptation(input_images, model, - model_cfg["grid_size"], - homography_cfg) + outputs = homography_adaptation( + input_images, model, model_cfg["grid_size"], homography_cfg + ) # Save the entries for batch_idx in range(batch_size): # Get the save key save_key = file_keys[batch_idx] output_data = { - "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], - "junc_prob_mean": outputs["junc_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], - "junc_prob_max": outputs["junc_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], - "junc_count": outputs["junc_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], - "heatmap_prob_mean": outputs["heatmap_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], - "heatmap_prob_max": outputs["heatmap_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], - "heatmap_cout": outputs["heatmap_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx] + "image": input_images.cpu() + .numpy() + .transpose(0, 2, 3, 1)[batch_idx], + "junc_prob_mean": outputs["junc_probs_mean"] + .cpu() + .numpy() + .transpose(0, 2, 3, 1)[batch_idx], + 
"junc_prob_max": outputs["junc_probs_max"] + .cpu() + .numpy() + .transpose(0, 2, 3, 1)[batch_idx], + "junc_count": outputs["junc_counts"] + .cpu() + .numpy() + .transpose(0, 2, 3, 1)[batch_idx], + "heatmap_prob_mean": outputs["heatmap_probs_mean"] + .cpu() + .numpy() + .transpose(0, 2, 3, 1)[batch_idx], + "heatmap_prob_max": outputs["heatmap_probs_max"] + .cpu() + .numpy() + .transpose(0, 2, 3, 1)[batch_idx], + "heatmap_cout": outputs["heatmap_counts"] + .cpu() + .numpy() + .transpose(0, 2, 3, 1)[batch_idx], } # Create group and write data f_group = f.create_group(save_key) for key, output_data in output_data.items(): - f_group.create_dataset(key, data=output_data, - compression="gzip") + f_group.create_dataset(key, data=output_data, compression="gzip") def homography_adaptation(input_images, model, grid_size, homography_cfg): - """ The homography adaptation process. + """The homography adaptation process. Arguments: input_images: The images to be evaluated. model: The pytorch model in evaluation mode. @@ -222,121 +259,140 @@ def homography_adaptation(input_images, model, grid_size, homography_cfg): for idx in range(num_iter): if idx <= num_iter // 5: # Ensure that 20% of the homographies have no artifact - H_mat_lst = [sample_homography( - [H,W], **homography_cfg_no_artifacts)[0][None] - for _ in range(batch_size)] + H_mat_lst = [ + sample_homography([H, W], **homography_cfg_no_artifacts)[0][None] + for _ in range(batch_size) + ] else: - H_mat_lst = [sample_homography( - [H,W], **homography_cfg["homographies"])[0][None] - for _ in range(batch_size)] + H_mat_lst = [ + sample_homography([H, W], **homography_cfg["homographies"])[0][None] + for _ in range(batch_size) + ] H_mats = np.concatenate(H_mat_lst, axis=0) H_tensor = torch.tensor(H_mats, dtype=torch.float, device=device) H_inv_tensor = torch.inverse(H_tensor) # Perform the homography warp - images_warped = warp_perspective(input_images, H_tensor, (H, W), - flags="bilinear") - + images_warped = warp_perspective( + input_images, H_tensor, (H, W), flags="bilinear" + ) + # Warp the mask masks_junc_warped = warp_perspective( torch.ones([batch_size, 1, H, W], device=device), - H_tensor, (H, W), flags="nearest") + H_tensor, + (H, W), + flags="nearest", + ) masks_heatmap_warped = warp_perspective( torch.ones([batch_size, 1, H, W], device=device), - H_tensor, (H, W), flags="nearest") + H_tensor, + (H, W), + flags="nearest", + ) # Run the network forward pass with torch.no_grad(): outputs = model(images_warped) - + # Unwarp and mask the junction prediction - junc_prob_warped = pixel_shuffle(softmax( - outputs["junctions"], dim=1)[:, :-1, :, :], grid_size) - junc_prob = warp_perspective(junc_prob_warped, H_inv_tensor, - (H, W), flags="bilinear") + junc_prob_warped = pixel_shuffle( + softmax(outputs["junctions"], dim=1)[:, :-1, :, :], grid_size + ) + junc_prob = warp_perspective( + junc_prob_warped, H_inv_tensor, (H, W), flags="bilinear" + ) # Create the out of boundary mask out_boundary_mask = warp_perspective( torch.ones([batch_size, 1, H, W], device=device), - H_inv_tensor, (H, W), flags="nearest") + H_inv_tensor, + (H, W), + flags="nearest", + ) out_boundary_mask = adjust_border(out_boundary_mask, device, margin) junc_prob = junc_prob * out_boundary_mask - junc_count = warp_perspective(masks_junc_warped * out_boundary_mask, - H_inv_tensor, (H, W), flags="nearest") + junc_count = warp_perspective( + masks_junc_warped * out_boundary_mask, H_inv_tensor, (H, W), flags="nearest" + ) # Unwarp the mask and heatmap prediction # Always fetch only one 
channel if outputs["heatmap"].shape[1] == 2: # Convert to single channel directly from here - heatmap_prob_warped = softmax(outputs["heatmap"], - dim=1)[:, 1:, :, :] + heatmap_prob_warped = softmax(outputs["heatmap"], dim=1)[:, 1:, :, :] else: heatmap_prob_warped = torch.sigmoid(outputs["heatmap"]) - + heatmap_prob_warped = heatmap_prob_warped * masks_heatmap_warped - heatmap_prob = warp_perspective(heatmap_prob_warped, H_inv_tensor, - (H, W), flags="bilinear") - heatmap_count = warp_perspective(masks_heatmap_warped, H_inv_tensor, - (H, W), flags="nearest") + heatmap_prob = warp_perspective( + heatmap_prob_warped, H_inv_tensor, (H, W), flags="bilinear" + ) + heatmap_count = warp_perspective( + masks_heatmap_warped, H_inv_tensor, (H, W), flags="nearest" + ) # Record the results - junc_probs[:, idx:idx+1, :, :] = junc_prob - heatmap_probs[:, idx:idx+1, :, :] = heatmap_prob + junc_probs[:, idx : idx + 1, :, :] = junc_prob + heatmap_probs[:, idx : idx + 1, :, :] = heatmap_prob junc_counts += junc_count heatmap_counts += heatmap_count # Perform the accumulation operation if homography_cfg["min_counts"] > 0: min_counts = homography_cfg["min_counts"] - junc_count_mask = (junc_counts < min_counts) - heatmap_count_mask = (heatmap_counts < min_counts) + junc_count_mask = junc_counts < min_counts + heatmap_count_mask = heatmap_counts < min_counts junc_counts[junc_count_mask] = 0 heatmap_counts[heatmap_count_mask] = 0 else: junc_count_mask = np.zeros_like(junc_counts, dtype=bool) heatmap_count_mask = np.zeros_like(heatmap_counts, dtype=bool) - + # Compute the mean accumulation junc_probs_mean = torch.sum(junc_probs, dim=1, keepdim=True) / junc_counts - junc_probs_mean[junc_count_mask] = 0. - heatmap_probs_mean = (torch.sum(heatmap_probs, dim=1, keepdim=True) - / heatmap_counts) - heatmap_probs_mean[heatmap_count_mask] = 0. + junc_probs_mean[junc_count_mask] = 0.0 + heatmap_probs_mean = torch.sum(heatmap_probs, dim=1, keepdim=True) / heatmap_counts + heatmap_probs_mean[heatmap_count_mask] = 0.0 # Compute the max accumulation junc_probs_max = torch.max(junc_probs, dim=1, keepdim=True)[0] - junc_probs_max[junc_count_mask] = 0. + junc_probs_max[junc_count_mask] = 0.0 heatmap_probs_max = torch.max(heatmap_probs, dim=1, keepdim=True)[0] - heatmap_probs_max[heatmap_count_mask] = 0. + heatmap_probs_max[heatmap_count_mask] = 0.0 - return {"junc_probs_mean": junc_probs_mean, - "junc_probs_max": junc_probs_max, - "junc_counts": junc_counts, - "heatmap_probs_mean": heatmap_probs_mean, - "heatmap_probs_max": heatmap_probs_max, - "heatmap_counts": heatmap_counts} + return { + "junc_probs_mean": junc_probs_mean, + "junc_probs_max": junc_probs_max, + "junc_counts": junc_counts, + "heatmap_probs_mean": heatmap_probs_mean, + "heatmap_probs_max": heatmap_probs_max, + "heatmap_counts": heatmap_counts, + } def adjust_border(input_masks, device, margin=3): - """ Adjust the border of the counts and valid_mask. 
""" + """Adjust the border of the counts and valid_mask.""" # Convert the mask to numpy array dtype = input_masks.dtype input_masks = np.squeeze(input_masks.cpu().numpy(), axis=1) - erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, - (margin*2, margin*2)) + erosion_kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (margin * 2, margin * 2) + ) batch_size = input_masks.shape[0] - + output_mask_lst = [] # Erode all the masks for i in range(batch_size): output_mask = cv2.erode(input_masks[i, ...], erosion_kernel) output_mask_lst.append( - torch.tensor(output_mask, dtype=dtype, device=device)[None]) - + torch.tensor(output_mask, dtype=dtype, device=device)[None] + ) + # Concat back along the batch dimension. output_masks = torch.cat(output_mask_lst, dim=0) return output_masks.unsqueeze(dim=1) diff --git a/third_party/SOLD2/sold2/export_line_features.py b/third_party/SOLD2/sold2/export_line_features.py index 4cbde860a446d758dff254ea5320ca13bb79e6b7..6df203c6ad62a559a1617744b200df283b9bb9a7 100644 --- a/third_party/SOLD2/sold2/export_line_features.py +++ b/third_party/SOLD2/sold2/export_line_features.py @@ -12,24 +12,29 @@ from .experiment import load_config from .model.line_matcher import LineMatcher -def export_descriptors(images_list, ckpt_path, config, device, extension, - output_folder, multiscale=False): +def export_descriptors( + images_list, ckpt_path, config, device, extension, output_folder, multiscale=False +): # Extract the image paths - with open(images_list, 'r') as f: + with open(images_list, "r") as f: image_files = f.readlines() - image_files = [path.strip('\n') for path in image_files] + image_files = [path.strip("\n") for path in image_files] # Initialize the line matcher line_matcher = LineMatcher( - config["model_cfg"], ckpt_path, device, config["line_detector_cfg"], - config["line_matcher_cfg"], multiscale) + config["model_cfg"], + ckpt_path, + device, + config["line_detector_cfg"], + config["line_matcher_cfg"], + multiscale, + ) print("\t Successfully initialized model") # Run the inference on each image and write the output on disk for img_path in tqdm(image_files): img = cv2.imread(img_path, 0) - img = torch.tensor(img[None, None] / 255., dtype=torch.float, - device=device) + img = torch.tensor(img[None, None] / 255.0, dtype=torch.float, device=device) # Run the line detection and description ref_detection = line_matcher.line_detection(img) @@ -39,21 +44,29 @@ def export_descriptors(images_list, ckpt_path, config, device, extension, # Write the output on disk img_name = os.path.splitext(os.path.basename(img_path))[0] output_file = os.path.join(output_folder, img_name + extension) - np.savez_compressed(output_file, line_seg=ref_line_seg, - descriptors=ref_descriptors) + np.savez_compressed( + output_file, line_seg=ref_line_seg, descriptors=ref_descriptors + ) if __name__ == "__main__": # Parse input arguments parser = argparse.ArgumentParser() - parser.add_argument("--img_list", type=str, required=True, - help="List of input images in a text file.") - parser.add_argument("--output_folder", type=str, required=True, - help="Path to the output folder.") - parser.add_argument("--config", type=str, - default="config/export_line_features.yaml") - parser.add_argument("--checkpoint_path", type=str, - default="pretrained_models/sold2_wireframe.tar") + parser.add_argument( + "--img_list", + type=str, + required=True, + help="List of input images in a text file.", + ) + parser.add_argument( + "--output_folder", type=str, required=True, help="Path to the output 
folder." + ) + parser.add_argument( + "--config", type=str, default="config/export_line_features.yaml" + ) + parser.add_argument( + "--checkpoint_path", type=str, default="pretrained_models/sold2_wireframe.tar" + ) parser.add_argument("--multiscale", action="store_true", default=False) parser.add_argument("--extension", type=str, default=None) args = parser.parse_args() @@ -67,8 +80,15 @@ if __name__ == "__main__": # Get the model config, extension and checkpoint path config = load_config(args.config) ckpt_path = os.path.abspath(args.checkpoint_path) - extension = 'sold2' if args.extension is None else args.extension + extension = "sold2" if args.extension is None else args.extension extension = "." + extension - export_descriptors(args.img_list, ckpt_path, config, device, extension, - args.output_folder, args.multiscale) + export_descriptors( + args.img_list, + ckpt_path, + config, + device, + extension, + args.output_folder, + args.multiscale, + ) diff --git a/third_party/SOLD2/sold2/misc/geometry_utils.py b/third_party/SOLD2/sold2/misc/geometry_utils.py index 50f0478062cd19ebac812bff62b6c3a3d5f124c2..024430a07b9b094d2eca6e4e9e14edd5105ad1c5 100644 --- a/third_party/SOLD2/sold2/misc/geometry_utils.py +++ b/third_party/SOLD2/sold2/misc/geometry_utils.py @@ -7,8 +7,9 @@ import torch # Warp a list of points using a homography def warp_points(points, homography): # Convert to homogeneous and in xy format - new_points = np.concatenate([points[..., [1, 0]], - np.ones_like(points[..., :1])], axis=-1) + new_points = np.concatenate( + [points[..., [1, 0]], np.ones_like(points[..., :1])], axis=-1 + ) # Warp new_points = (homography @ new_points.T).T # Convert back to inhomogeneous and hw format @@ -18,10 +19,12 @@ def warp_points(points, homography): # Mask out the points that are outside of img_size def mask_points(points, img_size): - mask = ((points[..., 0] >= 0) - & (points[..., 0] < img_size[0]) - & (points[..., 1] >= 0) - & (points[..., 1] < img_size[1])) + mask = ( + (points[..., 0] >= 0) + & (points[..., 0] < img_size[0]) + & (points[..., 1] >= 0) + & (points[..., 1] < img_size[1]) + ) return mask @@ -30,8 +33,12 @@ def mask_points(points, img_size): def keypoints_to_grid(keypoints, img_size): n_points = keypoints.size()[-2] device = keypoints.device - grid_points = keypoints.float() * 2. / torch.tensor( - img_size, dtype=torch.float, device=device) - 1. 
+ grid_points = ( + keypoints.float() + * 2.0 + / torch.tensor(img_size, dtype=torch.float, device=device) + - 1.0 + ) grid_points = grid_points[..., [1, 0]].view(-1, n_points, 1, 2) return grid_points @@ -44,8 +51,9 @@ def get_dist_mask(kp0, kp1, valid_mask, dist_thresh): dist_mask1 = torch.norm(kp1.unsqueeze(2) - kp1.unsqueeze(1), dim=-1) dist_mask = torch.min(dist_mask0, dist_mask1) dist_mask = dist_mask <= dist_thresh - dist_mask = dist_mask.repeat(1, 1, b_size).reshape(b_size * n_points, - b_size * n_points) + dist_mask = dist_mask.repeat(1, 1, b_size).reshape( + b_size * n_points, b_size * n_points + ) dist_mask = dist_mask[valid_mask, :][:, valid_mask] return dist_mask @@ -75,7 +83,8 @@ def mask_lines(lines, valid_mask): def get_common_line_mask(line_indices, valid_mask): b_size, n_points = line_indices.shape common_mask = line_indices[:, :, None] == line_indices[:, None, :] - common_mask = common_mask.repeat(1, 1, b_size).reshape(b_size * n_points, - b_size * n_points) + common_mask = common_mask.repeat(1, 1, b_size).reshape( + b_size * n_points, b_size * n_points + ) common_mask = common_mask[valid_mask, :][:, valid_mask] return common_mask diff --git a/third_party/SOLD2/sold2/misc/train_utils.py b/third_party/SOLD2/sold2/misc/train_utils.py index d5ada35eea660df1f78b9f20d9bf7ed726eaee2c..99113247351ceef152f308e793234a952df78166 100644 --- a/third_party/SOLD2/sold2/misc/train_utils.py +++ b/third_party/SOLD2/sold2/misc/train_utils.py @@ -10,7 +10,7 @@ import torch ## image utils ## ################# def convert_image(input_tensor, axis): - """ Convert single channel images to 3-channel images. """ + """Convert single channel images to 3-channel images.""" image_lst = [input_tensor for _ in range(3)] outputs = np.concatenate(image_lst, axis) return outputs @@ -19,29 +19,32 @@ def convert_image(input_tensor, axis): ###################### ## checkpoint utils ## ###################### -def get_latest_checkpoint(checkpoint_root, checkpoint_name, - device=torch.device("cuda")): - """ Get the latest checkpoint or by filename. """ +def get_latest_checkpoint( + checkpoint_root, checkpoint_name, device=torch.device("cuda") +): + """Get the latest checkpoint or by filename.""" # Load specific checkpoint if checkpoint_name is not None: checkpoint = torch.load( - os.path.join(checkpoint_root, checkpoint_name), - map_location=device) + os.path.join(checkpoint_root, checkpoint_name), map_location=device + ) # Load the latest checkpoint else: - lastest_checkpoint = sorted(os.listdir(os.path.join( - checkpoint_root, "*.tar")))[-1] - checkpoint = torch.load(os.path.join( - checkpoint_root, lastest_checkpoint), map_location=device) + lastest_checkpoint = sorted(os.listdir(os.path.join(checkpoint_root, "*.tar")))[ + -1 + ] + checkpoint = torch.load( + os.path.join(checkpoint_root, lastest_checkpoint), map_location=device + ) return checkpoint def remove_old_checkpoints(checkpoint_root, max_ckpt=15): - """ Remove the outdated checkpoints. 
""" + """Remove the outdated checkpoints.""" # Get sorted list of checkpoints checkpoint_list = sorted( - [_ for _ in os.listdir(os.path.join(checkpoint_root)) - if _.endswith(".tar")]) + [_ for _ in os.listdir(os.path.join(checkpoint_root)) if _.endswith(".tar")] + ) # Get the checkpoints to be removed if len(checkpoint_list) > max_ckpt: @@ -55,7 +58,7 @@ def remove_old_checkpoints(checkpoint_root, max_ckpt=15): def adapt_checkpoint(state_dict): new_state_dict = {} for k, v in state_dict.items(): - if k.startswith('module.'): + if k.startswith("module."): new_state_dict[k[7:]] = v else: new_state_dict[k] = v @@ -66,9 +69,9 @@ def adapt_checkpoint(state_dict): ## HDF5 utils ## ################ def parse_h5_data(h5_data): - """ Parse h5 dataset. """ + """Parse h5 dataset.""" output_data = {} for key in h5_data.keys(): output_data[key] = np.array(h5_data[key]) - + return output_data diff --git a/third_party/SOLD2/sold2/misc/visualize_util.py b/third_party/SOLD2/sold2/misc/visualize_util.py index 4aa46877f79724221b7caa423de6916acdc021f8..2d1aa38bb992302fe504bc166a3fa113e5365337 100644 --- a/third_party/SOLD2/sold2/misc/visualize_util.py +++ b/third_party/SOLD2/sold2/misc/visualize_util.py @@ -20,15 +20,17 @@ def plot_junctions(input_image, junctions, junc_size=3, color=None): if image.dtype == np.uint8: pass # A float type image ranging from 0~1 - elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.: - image = (image * 255.).astype(np.uint8) + elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0: + image = (image * 255.0).astype(np.uint8) # A float type image ranging from 0.~255. - elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.: + elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0: image = image.astype(np.uint8) else: - raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.") + raise ValueError( + "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8." 
+ ) - # Check whether the image is single channel + # Check whether the image is single channel if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)): # Squeeze to H*W first image = image.squeeze() @@ -46,30 +48,38 @@ def plot_junctions(input_image, junctions, junc_size=3, color=None): junctions = junctions.T else: raise ValueError("[Error] At least one of the two dims should be 2.") - + # Round and convert junctions to int (and check the boundary) H, W = image.shape[:2] junctions = (np.round(junctions)).astype(np.int) - junctions[junctions < 0] = 0 - junctions[junctions[:, 0] >= H, 0] = H-1 # (first dim) max bounded by H-1 - junctions[junctions[:, 1] >= W, 1] = W-1 # (second dim) max bounded by W-1 + junctions[junctions < 0] = 0 + junctions[junctions[:, 0] >= H, 0] = H - 1 # (first dim) max bounded by H-1 + junctions[junctions[:, 1] >= W, 1] = W - 1 # (second dim) max bounded by W-1 # Iterate through all the junctions num_junc = junctions.shape[0] if color is None: - color = (0, 255., 0) + color = (0, 255.0, 0) for idx in range(num_junc): # Fetch one junction junc = junctions[idx, :] - cv2.circle(image, tuple(np.flip(junc)), radius=junc_size, - color=color, thickness=3) - + cv2.circle( + image, tuple(np.flip(junc)), radius=junc_size, color=color, thickness=3 + ) + return image # Plot line segements given junctions and line adjecent map -def plot_line_segments(input_image, junctions, line_map, junc_size=3, - color=(0, 255., 0), line_width=1, plot_survived_junc=True): +def plot_line_segments( + input_image, + junctions, + line_map, + junc_size=3, + color=(0, 255.0, 0), + line_width=1, + plot_survived_junc=True, +): """ input_image: can be 0~1 float or 0~255 uint8. junctions: Nx2 or 2xN np array. @@ -85,15 +95,17 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3, if image.dtype == np.uint8: pass # A float type image ranging from 0~1 - elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.: - image = (image * 255.).astype(np.uint8) + elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0: + image = (image * 255.0).astype(np.uint8) # A float type image ranging from 0.~255. - elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.: + elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0: image = image.astype(np.uint8) else: - raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.") + raise ValueError( + "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8." + ) - # Check whether the image is single channel + # Check whether the image is single channel if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)): # Squeeze to H*W first image = image.squeeze() @@ -111,7 +123,7 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3, junctions = junctions.T else: raise ValueError("[Error] At least one of the two dims should be 2.") - + # line_map dimension should be 2 if not len(line_map.shape) == 2: raise ValueError("[Error] line_map should be 2-dim array.") @@ -122,8 +134,10 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3, raise ValueError("[Error] color should have type list or tuple.") else: if len(color) != 3: - raise ValueError("[Error] color should be a list or tuple with length 3.") - + raise ValueError( + "[Error] color should be a list or tuple with length 3." 
+ ) + # Make a copy of the line_map line_map_tmp = copy.copy(line_map) @@ -136,14 +150,17 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3, # record the line segment else: for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]: - p1 = np.flip(junctions[idx, :]) # Convert to xy format - p2 = np.flip(junctions[idx2, :]) # Convert to xy format - segments = np.concatenate((segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), axis=0) - + p1 = np.flip(junctions[idx, :]) # Convert to xy format + p2 = np.flip(junctions[idx2, :]) # Convert to xy format + segments = np.concatenate( + (segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), + axis=0, + ) + # Update line_map line_map_tmp[idx, idx2] = 0 line_map_tmp[idx2, idx] = 0 - + # Draw segment pairs for idx in range(segments.shape[0]): seg = np.round(segments[idx, :]).astype(np.int) @@ -151,8 +168,14 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3, if color != "random": color = tuple(color) else: - color = tuple(np.random.rand(3,)) - cv2.line(image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width) + color = tuple( + np.random.rand( + 3, + ) + ) + cv2.line( + image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width + ) # Also draw the junctions if not plot_survived_junc: @@ -160,45 +183,63 @@ def plot_line_segments(input_image, junctions, line_map, junc_size=3, for idx in range(num_junc): # Fetch one junction junc = junctions[idx, :] - cv2.circle(image, tuple(np.flip(junc)), radius=junc_size, - color=(0, 255., 0), thickness=3) + cv2.circle( + image, + tuple(np.flip(junc)), + radius=junc_size, + color=(0, 255.0, 0), + thickness=3, + ) # Only plot the junctions which are part of a line segment else: for idx in range(segments.shape[0]): - seg = np.round(segments[idx, :]).astype(np.int) # Already in HW format. - cv2.circle(image, tuple(seg[:2]), radius=junc_size, - color=(0, 255., 0), thickness=3) - cv2.circle(image, tuple(seg[2:]), radius=junc_size, - color=(0, 255., 0), thickness=3) - + seg = np.round(segments[idx, :]).astype(np.int) # Already in HW format. + cv2.circle( + image, + tuple(seg[:2]), + radius=junc_size, + color=(0, 255.0, 0), + thickness=3, + ) + cv2.circle( + image, + tuple(seg[2:]), + radius=junc_size, + color=(0, 255.0, 0), + thickness=3, + ) + return image # Plot line segments given Nx4 or Nx2x2 line segments -def plot_line_segments_from_segments(input_image, line_segments, junc_size=3, - color=(0, 255., 0), line_width=1): +def plot_line_segments_from_segments( + input_image, line_segments, junc_size=3, color=(0, 255.0, 0), line_width=1 +): # Create image copy image = copy.copy(input_image) # Make sure the image is converted to 255 uint8 if image.dtype == np.uint8: pass # A float type image ranging from 0~1 - elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.: - image = (image * 255.).astype(np.uint8) + elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0: + image = (image * 255.0).astype(np.uint8) # A float type image ranging from 0.~255. - elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.: + elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0: image = image.astype(np.uint8) else: - raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.") + raise ValueError( + "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8." 
+ ) - # Check whether the image is single channel + # Check whether the image is single channel if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)): # Squeeze to H*W first image = image.squeeze() # Stack to channle 3 image = np.concatenate([image[..., None] for _ in range(3)], axis=-1) - + # Check the if line_segments are in (1) Nx4, or (2) Nx2x2. H, W, _ = image.shape # (1) Nx4 format @@ -207,18 +248,20 @@ def plot_line_segments_from_segments(input_image, line_segments, junc_size=3, line_segments = line_segments.astype(np.int32) # Clip H dimension - line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H-1) - line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H-1) + line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H - 1) + line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H - 1) # Clip W dimension - line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W-1) - line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W-1) + line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W - 1) + line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W - 1) # Convert to Nx2x2 format line_segments = np.concatenate( - [np.expand_dims(line_segments[:, :2], axis=1), - np.expand_dims(line_segments[:, 2:], axis=1)], - axis=1 + [ + np.expand_dims(line_segments[:, :2], axis=1), + np.expand_dims(line_segments[:, 2:], axis=1), + ], + axis=1, ) # (2) Nx2x2 format @@ -227,11 +270,13 @@ def plot_line_segments_from_segments(input_image, line_segments, junc_size=3, line_segments = line_segments.astype(np.int32) # Clip H dimension - line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H-1) - line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W-1) + line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H - 1) + line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W - 1) else: - raise ValueError("[Error] line_segments should be either Nx4 or Nx2x2 in HW format.") + raise ValueError( + "[Error] line_segments should be either Nx4 or Nx2x2 in HW format." + ) # Draw segment pairs (all segments should be in HW format) image = image.copy() @@ -241,21 +286,41 @@ def plot_line_segments_from_segments(input_image, line_segments, junc_size=3, if color != "random": color = tuple(color) else: - color = tuple(np.random.rand(3,)) - cv2.line(image, tuple(np.flip(seg[0, :])), - tuple(np.flip(seg[1, :])), - color=color, thickness=line_width) + color = tuple( + np.random.rand( + 3, + ) + ) + cv2.line( + image, + tuple(np.flip(seg[0, :])), + tuple(np.flip(seg[1, :])), + color=color, + thickness=line_width, + ) # Also draw the junctions - cv2.circle(image, tuple(np.flip(seg[0, :])), radius=junc_size, color=(0, 255., 0), thickness=3) - cv2.circle(image, tuple(np.flip(seg[1, :])), radius=junc_size, color=(0, 255., 0), thickness=3) - + cv2.circle( + image, + tuple(np.flip(seg[0, :])), + radius=junc_size, + color=(0, 255.0, 0), + thickness=3, + ) + cv2.circle( + image, + tuple(np.flip(seg[1, :])), + radius=junc_size, + color=(0, 255.0, 0), + thickness=3, + ) + return image # Additional functions to visualize multiple images at the same time, # e.g. for line matching -def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5): +def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=6, pad=0.5): """Plot a set of images horizontally. Args: imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). 
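A minimal usage sketch of the plotting helpers touched in this file (the sold2.misc.visualize_util import path, the synthetic image, and the hand-written segments are illustrative assumptions; endpoint order follows the (x, y) convention used by the drawing code):

import numpy as np
import matplotlib.pyplot as plt
from sold2.misc.visualize_util import plot_images, plot_lines

img = np.random.rand(240, 320)                       # stand-in grayscale image
lines = np.array([[[40.0, 60.0], [200.0, 60.0]],     # (N, 2, 2) endpoints in (x, y)
                  [[120.0, 30.0], [120.0, 210.0]]])

plot_images([img], titles=["detected lines"], cmaps="gray")
plot_lines([lines], line_colors=["orange"], point_colors=["cyan"], lw=2, indices=(0,))
plt.savefig("lines.png")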
@@ -265,7 +330,7 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5): n = len(imgs) if not isinstance(cmaps, (list, tuple)): cmaps = [cmaps] * n - figsize = (size*n, size*3/4) if size is not None else None + figsize = (size * n, size * 3 / 4) if size is not None else None fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi) if n == 1: ax = [ax] @@ -281,7 +346,7 @@ def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5): fig.tight_layout(pad=pad) -def plot_keypoints(kpts, colors='lime', ps=4): +def plot_keypoints(kpts, colors="lime", ps=4): """Plot keypoints for existing images. Args: kpts: list of ndarrays of size (N, 2). @@ -295,7 +360,7 @@ def plot_keypoints(kpts, colors='lime', ps=4): a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0) -def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.): +def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0): """Plot matches for a pair of existing images. Args: kpts0, kpts1: corresponding keypoints of size (N, 2). @@ -322,11 +387,18 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.): transFigure = fig.transFigure.inverted() fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) - fig.lines += [matplotlib.lines.Line2D( - (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), - zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, - alpha=a) - for i in range(len(kpts0))] + fig.lines += [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=color[i], + linewidth=lw, + alpha=a, + ) + for i in range(len(kpts0)) + ] # freeze the axes to prevent the transform to change ax0.autoscale(enable=False) @@ -337,8 +409,9 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.): ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2) -def plot_lines(lines, line_colors='orange', point_colors='cyan', - ps=4, lw=2, indices=(0, 1)): +def plot_lines( + lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1) +): """Plot lines and endpoints for existing images. Args: lines: list of ndarrays of size (N, 2, 2). @@ -361,16 +434,19 @@ def plot_lines(lines, line_colors='orange', point_colors='cyan', # Plot the lines and junctions for a, l, lc, pc in zip(axes, lines, line_colors, point_colors): for i in range(len(l)): - line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]), - (l[i, 0, 1], l[i, 1, 1]), - zorder=1, c=lc, linewidth=lw) + line = matplotlib.lines.Line2D( + (l[i, 0, 0], l[i, 1, 0]), + (l[i, 0, 1], l[i, 1, 1]), + zorder=1, + c=lc, + linewidth=lw, + ) a.add_line(line) pts = l.reshape(-1, 2) - a.scatter(pts[:, 0], pts[:, 1], - c=pc, s=ps, linewidths=0, zorder=2) + a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2) -def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.): +def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.0): """Plot matches for a pair of existing images, parametrized by their middle point. Args: kpts0, kpts1: corresponding middle points of the lines of size (N, 2). 
@@ -396,19 +472,25 @@ def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.): transFigure = fig.transFigure.inverted() fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) - fig.lines += [matplotlib.lines.Line2D( - (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), - zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, - alpha=a) - for i in range(len(kpts0))] + fig.lines += [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=color[i], + linewidth=lw, + alpha=a, + ) + for i in range(len(kpts0)) + ] # freeze the axes to prevent the transform to change ax0.autoscale(enable=False) ax1.autoscale(enable=False) -def plot_color_line_matches(lines, correct_matches=None, - lw=2, indices=(0, 1)): +def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)): """Plot line matches for existing images with multiple colors. Args: lines: list of ndarrays of size (N, 2, 2). @@ -417,7 +499,7 @@ def plot_color_line_matches(lines, correct_matches=None, indices: indices of the images to draw the matches on. """ n_lines = len(lines[0]) - colors = sns.color_palette('husl', n_colors=n_lines) + colors = sns.color_palette("husl", n_colors=n_lines) np.random.shuffle(colors) alphas = np.ones(n_lines) # If correct_matches is not None, display wrong matches with a low alpha @@ -436,15 +518,21 @@ def plot_color_line_matches(lines, correct_matches=None, transFigure = fig.transFigure.inverted() endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) - fig.lines += [matplotlib.lines.Line2D( - (endpoint0[i, 0], endpoint1[i, 0]), - (endpoint0[i, 1], endpoint1[i, 1]), - zorder=1, transform=fig.transFigure, c=colors[i], - alpha=alphas[i], linewidth=lw) for i in range(n_lines)] - - -def plot_color_lines(lines, correct_matches, wrong_matches, - lw=2, indices=(0, 1)): + fig.lines += [ + matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=colors[i], + alpha=alphas[i], + linewidth=lw, + ) + for i in range(n_lines) + ] + + +def plot_color_lines(lines, correct_matches, wrong_matches, lw=2, indices=(0, 1)): """Plot line matches for existing images with multiple colors: green for correct matches, red for wrong ones, and blue for the rest. Args: @@ -476,15 +564,21 @@ def plot_color_lines(lines, correct_matches, wrong_matches, transFigure = fig.transFigure.inverted() endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) - fig.lines += [matplotlib.lines.Line2D( - (endpoint0[i, 0], endpoint1[i, 0]), - (endpoint0[i, 1], endpoint1[i, 1]), - zorder=1, transform=fig.transFigure, c=c[i], - linewidth=lw) for i in range(len(l))] + fig.lines += [ + matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=c[i], + linewidth=lw, + ) + for i in range(len(l)) + ] def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)): - """ Plot line matches for existing images with multiple colors and + """Plot line matches for existing images with multiple colors and highlight the actually matched subsegments. Args: lines: list of ndarrays of size (N, 2, 2). 
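A similarly hedged sketch for the match-visualization helpers above (import path and the fabricated correspondences are assumptions); plot_images creates the pair of axes that plot_color_line_matches then draws on:

import numpy as np
import matplotlib.pyplot as plt
from sold2.misc.visualize_util import plot_images, plot_color_line_matches

img0 = np.random.rand(240, 320)
img1 = np.random.rand(240, 320)
lines0 = np.random.rand(10, 2, 2) * [320.0, 240.0]   # endpoints in (x, y)
lines1 = lines0 + 5.0                                 # pretend matched lines

plot_images([img0, img1], titles=["image 0", "image 1"])
plot_color_line_matches([lines0, lines1], lw=2)       # one color per match
plt.savefig("line_matches.png")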
@@ -493,8 +587,9 @@ def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)): indices: indices of the images to draw the matches on. """ n_lines = len(lines[0]) - colors = sns.cubehelix_palette(start=2, rot=-0.2, dark=0.3, light=.7, - gamma=1.3, hue=1, n_colors=n_lines) + colors = sns.cubehelix_palette( + start=2, rot=-0.2, dark=0.3, light=0.7, gamma=1.3, hue=1, n_colors=n_lines + ) fig = plt.gcf() ax = fig.axes @@ -510,17 +605,31 @@ def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)): # Draw full line endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) - fig.lines += [matplotlib.lines.Line2D( - (endpoint0[i, 0], endpoint1[i, 0]), - (endpoint0[i, 1], endpoint1[i, 1]), - zorder=1, transform=fig.transFigure, c='red', - alpha=0.7, linewidth=lw) for i in range(n_lines)] + fig.lines += [ + matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, + transform=fig.transFigure, + c="red", + alpha=0.7, + linewidth=lw, + ) + for i in range(n_lines) + ] # Draw matched subsegment endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0])) endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1])) - fig.lines += [matplotlib.lines.Line2D( - (endpoint0[i, 0], endpoint1[i, 0]), - (endpoint0[i, 1], endpoint1[i, 1]), - zorder=1, transform=fig.transFigure, c=colors[i], - alpha=1, linewidth=lw) for i in range(n_lines)] \ No newline at end of file + fig.lines += [ + matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=colors[i], + alpha=1, + linewidth=lw, + ) + for i in range(n_lines) + ] diff --git a/third_party/SOLD2/sold2/model/line_detection.py b/third_party/SOLD2/sold2/model/line_detection.py index 0c186337b0ce2072ddd5246408c538dac2cf325f..8ff379a8de3ff5d54dc807b397f947ea8f361ef9 100644 --- a/third_party/SOLD2/sold2/model/line_detection.py +++ b/third_party/SOLD2/sold2/model/line_detection.py @@ -7,14 +7,25 @@ import torch class LineSegmentDetectionModule(object): - """ Module extracting line segments from junctions and line heatmaps. """ + """Module extracting line segments from junctions and line heatmaps.""" + def __init__( - self, detect_thresh, num_samples=64, sampling_method="local_max", - inlier_thresh=0., heatmap_low_thresh=0.15, heatmap_high_thresh=0.2, - max_local_patch_radius=3, lambda_radius=2., - use_candidate_suppression=False, nms_dist_tolerance=3., - use_heatmap_refinement=False, heatmap_refine_cfg=None, - use_junction_refinement=False, junction_refine_cfg=None): + self, + detect_thresh, + num_samples=64, + sampling_method="local_max", + inlier_thresh=0.0, + heatmap_low_thresh=0.15, + heatmap_high_thresh=0.2, + max_local_patch_radius=3, + lambda_radius=2.0, + use_candidate_suppression=False, + nms_dist_tolerance=3.0, + use_heatmap_refinement=False, + heatmap_refine_cfg=None, + use_junction_refinement=False, + junction_refine_cfg=None, + ): """ Parameters: detect_thresh: The probability threshold for mean activation (0. ~ 1.) 
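The detection logic reformatted in the hunks that follow scores every candidate junction pair by sampling the line heatmap between the two junctions and keeps the pair when the mean activation exceeds detect_thresh. A simplified sketch of that test, with illustrative names and nearest-neighbour sampling (the module itself also supports bilinear and local-max sampling, inlier ratios, and candidate suppression):

import torch

def segment_score(heatmap, junc_start, junc_end, num_samples=64):
    # heatmap: (H, W) tensor; junctions given as (h, w) coordinates.
    t = torch.linspace(0.0, 1.0, num_samples)
    h = (junc_start[0] * t + junc_end[0] * (1 - t)).clamp(0, heatmap.shape[0] - 1)
    w = (junc_start[1] * t + junc_end[1] * (1 - t)).clamp(0, heatmap.shape[1] - 1)
    return heatmap[h.round().long(), w.round().long()].mean()

heatmap = torch.rand(128, 128)
score = segment_score(heatmap, torch.tensor([10.0, 12.0]), torch.tensor([90.0, 100.0]))
keep = score > 0.25   # compare the mean activation against detect_thresh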
@@ -41,7 +52,7 @@ class LineSegmentDetectionModule(object): self.inlier_thresh = inlier_thresh self.local_patch_radius = max_local_patch_radius self.lambda_radius = lambda_radius - + # Detecting junctions on the boundary parameters self.low_thresh = heatmap_low_thresh self.high_thresh = heatmap_high_thresh @@ -65,56 +76,61 @@ class LineSegmentDetectionModule(object): self.junction_refine_cfg = junction_refine_cfg if self.use_junction_refinement and self.junction_refine_cfg is None: raise ValueError("[Error] Missing junction refinement config.") - + def convert_inputs(self, inputs, device): - """ Convert inputs to desired torch tensor. """ + """Convert inputs to desired torch tensor.""" if isinstance(inputs, np.ndarray): outputs = torch.tensor(inputs, dtype=torch.float32, device=device) elif isinstance(inputs, torch.Tensor): outputs = inputs.to(torch.float32).to(device) else: raise ValueError( - "[Error] Inputs must either be torch tensor or numpy ndarray.") - + "[Error] Inputs must either be torch tensor or numpy ndarray." + ) + return outputs - + def detect(self, junctions, heatmap, device=torch.device("cpu")): - """ Main function performing line segment detection. """ + """Main function performing line segment detection.""" # Convert inputs to torch tensor junctions = self.convert_inputs(junctions, device=device) heatmap = self.convert_inputs(heatmap, device=device) - + # Perform the heatmap refinement if self.use_heatmap_refinement: if self.heatmap_refine_cfg["mode"] == "global": heatmap = self.refine_heatmap( - heatmap, + heatmap, self.heatmap_refine_cfg["ratio"], - self.heatmap_refine_cfg["valid_thresh"] + self.heatmap_refine_cfg["valid_thresh"], ) elif self.heatmap_refine_cfg["mode"] == "local": heatmap = self.refine_heatmap_local( - heatmap, + heatmap, self.heatmap_refine_cfg["num_blocks"], self.heatmap_refine_cfg["overlap_ratio"], self.heatmap_refine_cfg["ratio"], - self.heatmap_refine_cfg["valid_thresh"] + self.heatmap_refine_cfg["valid_thresh"], ) - + # Initialize empty line map num_junctions = junctions.shape[0] - line_map_pred = torch.zeros([num_junctions, num_junctions], - device=device, dtype=torch.int32) - + line_map_pred = torch.zeros( + [num_junctions, num_junctions], device=device, dtype=torch.int32 + ) + # Stop if there are not enough junctions if num_junctions < 2: return line_map_pred, junctions, heatmap # Generate the candidate map - candidate_map = torch.triu(torch.ones( - [num_junctions, num_junctions], device=device, dtype=torch.int32), - diagonal=1) - + candidate_map = torch.triu( + torch.ones( + [num_junctions, num_junctions], device=device, dtype=torch.int32 + ), + diagonal=1, + ) + # Fetch the image boundary if len(heatmap.shape) > 2: H, W, _ = heatmap.shape @@ -123,39 +139,47 @@ class LineSegmentDetectionModule(object): # Optionally perform candidate filtering if self.use_candidate_suppression: - candidate_map = self.candidate_suppression(junctions, - candidate_map) + candidate_map = self.candidate_suppression(junctions, candidate_map) # Fetch the candidates candidate_index_map = torch.where(candidate_map) - candidate_index_map = torch.cat([candidate_index_map[0][..., None], - candidate_index_map[1][..., None]], - dim=-1) - + candidate_index_map = torch.cat( + [candidate_index_map[0][..., None], candidate_index_map[1][..., None]], + dim=-1, + ) + # Get the corresponding start and end junctions candidate_junc_start = junctions[candidate_index_map[:, 0], :] candidate_junc_end = junctions[candidate_index_map[:, 1], :] # Get the sampling locations (N x 64) sampler 
= self.torch_sampler.to(device)[None, ...] - cand_samples_h = candidate_junc_start[:, 0:1] * sampler + \ - candidate_junc_end[:, 0:1] * (1 - sampler) - cand_samples_w = candidate_junc_start[:, 1:2] * sampler + \ - candidate_junc_end[:, 1:2] * (1 - sampler) - + cand_samples_h = candidate_junc_start[:, 0:1] * sampler + candidate_junc_end[ + :, 0:1 + ] * (1 - sampler) + cand_samples_w = candidate_junc_start[:, 1:2] * sampler + candidate_junc_end[ + :, 1:2 + ] * (1 - sampler) + # Clip to image boundary - cand_h = torch.clamp(cand_samples_h, min=0, max=H-1) - cand_w = torch.clamp(cand_samples_w, min=0, max=W-1) - + cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1) + cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1) + # Local maximum search if self.sampling_method == "local_max": # Compute normalized segment lengths - segments_length = torch.sqrt(torch.sum( - (candidate_junc_start.to(torch.float32) - - candidate_junc_end.to(torch.float32)) ** 2, dim=-1)) - normalized_seg_length = (segments_length - / (((H ** 2) + (W ** 2)) ** 0.5)) - + segments_length = torch.sqrt( + torch.sum( + ( + candidate_junc_start.to(torch.float32) + - candidate_junc_end.to(torch.float32) + ) + ** 2, + dim=-1, + ) + ) + normalized_seg_length = segments_length / (((H**2) + (W**2)) ** 0.5) + # Perform local max search num_cand = cand_h.shape[0] group_size = 10000 @@ -163,85 +187,88 @@ class LineSegmentDetectionModule(object): num_iter = math.ceil(num_cand / group_size) sampled_feat_lst = [] for iter_idx in range(num_iter): - if not iter_idx == num_iter-1: - cand_h_ = cand_h[iter_idx * group_size: - (iter_idx+1) * group_size, :] - cand_w_ = cand_w[iter_idx * group_size: - (iter_idx+1) * group_size, :] + if not iter_idx == num_iter - 1: + cand_h_ = cand_h[ + iter_idx * group_size : (iter_idx + 1) * group_size, : + ] + cand_w_ = cand_w[ + iter_idx * group_size : (iter_idx + 1) * group_size, : + ] normalized_seg_length_ = normalized_seg_length[ - iter_idx * group_size: (iter_idx+1) * group_size] + iter_idx * group_size : (iter_idx + 1) * group_size + ] else: - cand_h_ = cand_h[iter_idx * group_size:, :] - cand_w_ = cand_w[iter_idx * group_size:, :] + cand_h_ = cand_h[iter_idx * group_size :, :] + cand_w_ = cand_w[iter_idx * group_size :, :] normalized_seg_length_ = normalized_seg_length[ - iter_idx * group_size:] + iter_idx * group_size : + ] sampled_feat_ = self.detect_local_max( - heatmap, cand_h_, cand_w_, H, W, - normalized_seg_length_, device) + heatmap, cand_h_, cand_w_, H, W, normalized_seg_length_, device + ) sampled_feat_lst.append(sampled_feat_) sampled_feat = torch.cat(sampled_feat_lst, dim=0) else: sampled_feat = self.detect_local_max( - heatmap, cand_h, cand_w, H, W, - normalized_seg_length, device) + heatmap, cand_h, cand_w, H, W, normalized_seg_length, device + ) # Bilinear sampling elif self.sampling_method == "bilinear": # Perform bilinear sampling - sampled_feat = self.detect_bilinear( - heatmap, cand_h, cand_w, H, W, device) + sampled_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device) else: raise ValueError("[Error] Unknown sampling method.") - + # [Simple threshold detection] # detection_results is a mask over all candidates - detection_results = (torch.mean(sampled_feat, dim=-1) - > self.detect_thresh) - + detection_results = torch.mean(sampled_feat, dim=-1) > self.detect_thresh + # [Inlier threshold detection] - if self.inlier_thresh > 0.: - inlier_ratio = torch.sum( - sampled_feat > self.detect_thresh, - dim=-1).to(torch.float32) / self.num_samples + if self.inlier_thresh > 0.0: 
+ inlier_ratio = ( + torch.sum(sampled_feat > self.detect_thresh, dim=-1).to(torch.float32) + / self.num_samples + ) detection_results_inlier = inlier_ratio >= self.inlier_thresh detection_results = detection_results * detection_results_inlier # Convert detection results back to line_map_pred detected_junc_indexes = candidate_index_map[detection_results, :] - line_map_pred[detected_junc_indexes[:, 0], - detected_junc_indexes[:, 1]] = 1 - line_map_pred[detected_junc_indexes[:, 1], - detected_junc_indexes[:, 0]] = 1 - + line_map_pred[detected_junc_indexes[:, 0], detected_junc_indexes[:, 1]] = 1 + line_map_pred[detected_junc_indexes[:, 1], detected_junc_indexes[:, 0]] = 1 + # Perform junction refinement if self.use_junction_refinement and len(detected_junc_indexes) > 0: junctions, line_map_pred = self.refine_junction_perturb( - junctions, line_map_pred, heatmap, H, W, device) + junctions, line_map_pred, heatmap, H, W, device + ) return line_map_pred, junctions, heatmap - + def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2): - """ Global heatmap refinement method. """ + """Global heatmap refinement method.""" # Grab the top 10% values heatmap_values = heatmap[heatmap > valid_thresh] sorted_values = torch.sort(heatmap_values, descending=True)[0] top10_len = math.ceil(sorted_values.shape[0] * ratio) max20 = torch.mean(sorted_values[:top10_len]) - heatmap = torch.clamp(heatmap / max20, min=0., max=1.) + heatmap = torch.clamp(heatmap / max20, min=0.0, max=1.0) return heatmap - - def refine_heatmap_local(self, heatmap, num_blocks=5, overlap_ratio=0.5, - ratio=0.2, valid_thresh=2e-3): - """ Local heatmap refinement method. """ + + def refine_heatmap_local( + self, heatmap, num_blocks=5, overlap_ratio=0.5, ratio=0.2, valid_thresh=2e-3 + ): + """Local heatmap refinement method.""" # Get the shape of the heatmap H, W = heatmap.shape increase_ratio = 1 - overlap_ratio h_block = round(H / (1 + (num_blocks - 1) * increase_ratio)) w_block = round(W / (1 + (num_blocks - 1) * increase_ratio)) - count_map = torch.zeros(heatmap.shape, dtype=torch.int, - device=heatmap.device) - heatmap_output = torch.zeros(heatmap.shape, dtype=torch.float, - device=heatmap.device) + count_map = torch.zeros(heatmap.shape, dtype=torch.int, device=heatmap.device) + heatmap_output = torch.zeros( + heatmap.shape, dtype=torch.float, device=heatmap.device + ) # Iterate through each block for h_idx in range(num_blocks): for w_idx in range(num_blocks): @@ -254,25 +281,29 @@ class LineSegmentDetectionModule(object): subheatmap = heatmap[h_start:h_end, w_start:w_end] if subheatmap.max() > valid_thresh: subheatmap = self.refine_heatmap( - subheatmap, ratio, valid_thresh=valid_thresh) - + subheatmap, ratio, valid_thresh=valid_thresh + ) + # Aggregate it to the final heatmap heatmap_output[h_start:h_end, w_start:w_end] += subheatmap count_map[h_start:h_end, w_start:w_end] += 1 - heatmap_output = torch.clamp(heatmap_output / count_map, - max=1., min=0.) + heatmap_output = torch.clamp(heatmap_output / count_map, max=1.0, min=0.0) return heatmap_output def candidate_suppression(self, junctions, candidate_map): - """ Suppress overlapping long lines in the candidate segments. 
""" + """Suppress overlapping long lines in the candidate segments.""" # Define the distance tolerance dist_tolerance = self.nms_dist_tolerance # Compute distance between junction pairs # (num_junc x 1 x 2) - (1 x num_junc x 2) => num_junc x num_junc map - line_dist_map = torch.sum((torch.unsqueeze(junctions, dim=1) - - junctions[None, ...]) ** 2, dim=-1) ** 0.5 + line_dist_map = ( + torch.sum( + (torch.unsqueeze(junctions, dim=1) - junctions[None, ...]) ** 2, dim=-1 + ) + ** 0.5 + ) # Fetch all the "detected lines" seg_indexes = torch.where(torch.triu(candidate_map, diagonal=1)) @@ -285,20 +316,23 @@ class LineSegmentDetectionModule(object): line_dists = line_dist_map[start_point_idxs, end_point_idxs] # Check whether they are on the line - dir_vecs = ((end_points - start_points) - / torch.norm(end_points - start_points, - dim=-1)[..., None]) + dir_vecs = (end_points - start_points) / torch.norm( + end_points - start_points, dim=-1 + )[..., None] # Get the orthogonal distance cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1) cand_vecs_norm = torch.norm(cand_vecs, dim=-1) # Check whether they are projected directly onto the segment - proj = (torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None]) - / line_dists[..., None, None]) + proj = ( + torch.einsum("bij,bjk->bik", cand_vecs, dir_vecs[..., None]) + / line_dists[..., None, None] + ) # proj is num_segs x num_junction x 1 - proj_mask = (proj >=0) * (proj <= 1) + proj_mask = (proj >= 0) * (proj <= 1) cand_angles = torch.acos( - torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None]) - / cand_vecs_norm[..., None]) + torch.einsum("bij,bjk->bik", cand_vecs, dir_vecs[..., None]) + / cand_vecs_norm[..., None] + ) cand_dists = cand_vecs_norm[..., None] * torch.sin(cand_angles) junc_dist_mask = cand_dists <= dist_tolerance junc_mask = junc_dist_mask * proj_mask @@ -306,21 +340,21 @@ class LineSegmentDetectionModule(object): # Minus starting points num_segs = start_point_idxs.shape[0] junc_counts = torch.sum(junc_mask, dim=[1, 2]) - junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs), - start_point_idxs].to(torch.int) - junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs), - end_point_idxs].to(torch.int) - + junc_counts -= junc_mask[..., 0][ + torch.arange(0, num_segs), start_point_idxs + ].to(torch.int) + junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs), end_point_idxs].to( + torch.int + ) + # Get the invalid candidate mask final_mask = junc_counts > 0 - candidate_map[start_point_idxs[final_mask], - end_point_idxs[final_mask]] = 0 - + candidate_map[start_point_idxs[final_mask], end_point_idxs[final_mask]] = 0 + return candidate_map - - def refine_junction_perturb(self, junctions, line_map_pred, - heatmap, H, W, device): - """ Refine the line endpoints in a similar way as in LSD. 
""" + + def refine_junction_perturb(self, junctions, line_map_pred, heatmap, H, W, device): + """Refine the line endpoints in a similar way as in LSD.""" # Get the config junction_refine_cfg = self.junction_refine_cfg @@ -330,14 +364,23 @@ class LineSegmentDetectionModule(object): side_perturbs = (num_perturbs - 1) // 2 # Fetch the 2D perturb mat perturb_vec = torch.arange( - start=-perturb_interval*side_perturbs, - end=perturb_interval*(side_perturbs+1), - step=perturb_interval, device=device) + start=-perturb_interval * side_perturbs, + end=perturb_interval * (side_perturbs + 1), + step=perturb_interval, + device=device, + ) w1_grid, h1_grid, w2_grid, h2_grid = torch.meshgrid( - perturb_vec, perturb_vec, perturb_vec, perturb_vec) - perturb_tensor = torch.cat([ - w1_grid[..., None], h1_grid[..., None], - w2_grid[..., None], h2_grid[..., None]], dim=-1) + perturb_vec, perturb_vec, perturb_vec, perturb_vec + ) + perturb_tensor = torch.cat( + [ + w1_grid[..., None], + h1_grid[..., None], + w2_grid[..., None], + h2_grid[..., None], + ], + dim=-1, + ) perturb_tensor_flat = perturb_tensor.view(-1, 2, 2) # Fetch the junctions and line_map @@ -351,16 +394,20 @@ class LineSegmentDetectionModule(object): start_points = junctions[start_point_idxs, :] end_points = junctions[end_point_idxs, :] - line_segments = torch.cat([start_points.unsqueeze(dim=1), - end_points.unsqueeze(dim=1)], dim=1) + line_segments = torch.cat( + [start_points.unsqueeze(dim=1), end_points.unsqueeze(dim=1)], dim=1 + ) - line_segment_candidates = (line_segments.unsqueeze(dim=1) - + perturb_tensor_flat[None, ...]) + line_segment_candidates = ( + line_segments.unsqueeze(dim=1) + perturb_tensor_flat[None, ...] + ) # Clip the boundaries line_segment_candidates[..., 0] = torch.clamp( - line_segment_candidates[..., 0], min=0, max=H - 1) + line_segment_candidates[..., 0], min=0, max=H - 1 + ) line_segment_candidates[..., 1] = torch.clamp( - line_segment_candidates[..., 1], min=0, max=W - 1) + line_segment_candidates[..., 1], min=0, max=W - 1 + ) # Iterate through all the segments refined_segment_lst = [] @@ -373,36 +420,37 @@ class LineSegmentDetectionModule(object): # Get the sampling locations (N x 64) sampler = self.torch_sampler.to(device)[None, ...] 
- cand_samples_h = (candidate_junc_start[:, 0:1] * sampler + - candidate_junc_end[:, 0:1] * (1 - sampler)) - cand_samples_w = (candidate_junc_start[:, 1:2] * sampler + - candidate_junc_end[:, 1:2] * (1 - sampler)) - + cand_samples_h = candidate_junc_start[ + :, 0:1 + ] * sampler + candidate_junc_end[:, 0:1] * (1 - sampler) + cand_samples_w = candidate_junc_start[ + :, 1:2 + ] * sampler + candidate_junc_end[:, 1:2] * (1 - sampler) + # Clip to image boundary cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1) cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1) # Perform bilinear sampling - segment_feat = self.detect_bilinear( - heatmap, cand_h, cand_w, H, W, device) + segment_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device) segment_results = torch.mean(segment_feat, dim=-1) max_idx = torch.argmax(segment_results) refined_segment_lst.append(segment[max_idx, ...][None, ...]) - + # Concatenate back to segments refined_segments = torch.cat(refined_segment_lst, dim=0) # Convert back to junctions and line_map junctions_new = torch.cat( - [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0) + [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0 + ) junctions_new = torch.unique(junctions_new, dim=0) - line_map_new = self.segments_to_line_map(junctions_new, - refined_segments) + line_map_new = self.segments_to_line_map(junctions_new, refined_segments) return junctions_new, line_map_new - + def segments_to_line_map(self, junctions, segments): - """ Convert the list of segments to line map. """ + """Convert the list of segments to line map.""" # Create empty line map device = junctions.device num_junctions = junctions.shape[0] @@ -416,10 +464,8 @@ class LineSegmentDetectionModule(object): junction2 = seg[1, :] # Get index - idx_junction1 = torch.where( - (junctions == junction1).sum(axis=1) == 2)[0] - idx_junction2 = torch.where( - (junctions == junction2).sum(axis=1) == 2)[0] + idx_junction1 = torch.where((junctions == junction1).sum(axis=1) == 2)[0] + idx_junction2 = torch.where((junctions == junction2).sum(axis=1) == 2)[0] # label the corresponding entries line_map[idx_junction1, idx_junction2] = 1 @@ -428,7 +474,7 @@ class LineSegmentDetectionModule(object): return line_map def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device): - """ Detection by bilinear sampling. 
""" + """Detection by bilinear sampling.""" # Get the floor and ceiling locations cand_h_floor = torch.floor(cand_h).to(torch.long) cand_h_ceil = torch.ceil(cand_h).to(torch.long) @@ -437,63 +483,83 @@ class LineSegmentDetectionModule(object): # Perform the bilinear sampling cand_samples_feat = ( - heatmap[cand_h_floor, cand_w_floor] * (cand_h_ceil - cand_h) - * (cand_w_ceil - cand_w) + heatmap[cand_h_floor, cand_w_ceil] - * (cand_h_ceil - cand_h) * (cand_w - cand_w_floor) + - heatmap[cand_h_ceil, cand_w_floor] * (cand_h - cand_h_floor) - * (cand_w_ceil - cand_w) + heatmap[cand_h_ceil, cand_w_ceil] - * (cand_h - cand_h_floor) * (cand_w - cand_w_floor)) - + heatmap[cand_h_floor, cand_w_floor] + * (cand_h_ceil - cand_h) + * (cand_w_ceil - cand_w) + + heatmap[cand_h_floor, cand_w_ceil] + * (cand_h_ceil - cand_h) + * (cand_w - cand_w_floor) + + heatmap[cand_h_ceil, cand_w_floor] + * (cand_h - cand_h_floor) + * (cand_w_ceil - cand_w) + + heatmap[cand_h_ceil, cand_w_ceil] + * (cand_h - cand_h_floor) + * (cand_w - cand_w_floor) + ) + return cand_samples_feat - - def detect_local_max(self, heatmap, cand_h, cand_w, H, W, - normalized_seg_length, device): - """ Detection by local maximum search. """ + + def detect_local_max( + self, heatmap, cand_h, cand_w, H, W, normalized_seg_length, device + ): + """Detection by local maximum search.""" # Compute the distance threshold - dist_thresh = (0.5 * (2 ** 0.5) - + self.lambda_radius * normalized_seg_length) + dist_thresh = 0.5 * (2**0.5) + self.lambda_radius * normalized_seg_length # Make it N x 64 - dist_thresh = torch.repeat_interleave(dist_thresh[..., None], - self.num_samples, dim=-1) - + dist_thresh = torch.repeat_interleave( + dist_thresh[..., None], self.num_samples, dim=-1 + ) + # Compute the candidate points - cand_points = torch.cat([cand_h[..., None], cand_w[..., None]], - dim=-1) - cand_points_round = torch.round(cand_points) # N x 64 x 2 - + cand_points = torch.cat([cand_h[..., None], cand_w[..., None]], dim=-1) + cand_points_round = torch.round(cand_points) # N x 64 x 2 + # Construct local patches 9x9 = 81 - patch_mask = torch.zeros([int(2 * self.local_patch_radius + 1), - int(2 * self.local_patch_radius + 1)], - device=device) + patch_mask = torch.zeros( + [ + int(2 * self.local_patch_radius + 1), + int(2 * self.local_patch_radius + 1), + ], + device=device, + ) patch_center = torch.tensor( - [[self.local_patch_radius, self.local_patch_radius]], - device=device, dtype=torch.float32) + [[self.local_patch_radius, self.local_patch_radius]], + device=device, + dtype=torch.float32, + ) H_patch_points, W_patch_points = torch.where(patch_mask >= 0) - patch_points = torch.cat([H_patch_points[..., None], - W_patch_points[..., None]], dim=-1) + patch_points = torch.cat( + [H_patch_points[..., None], W_patch_points[..., None]], dim=-1 + ) # Fetch the circle region - patch_center_dist = torch.sqrt(torch.sum( - (patch_points - patch_center) ** 2, dim=-1)) - patch_points = (patch_points[patch_center_dist - <= self.local_patch_radius, :]) + patch_center_dist = torch.sqrt( + torch.sum((patch_points - patch_center) ** 2, dim=-1) + ) + patch_points = patch_points[patch_center_dist <= self.local_patch_radius, :] # Shift [0, 0] to the center patch_points = patch_points - self.local_patch_radius - + # Construct local patch mask - patch_points_shifted = (torch.unsqueeze(cand_points_round, dim=2) - + patch_points[None, None, ...]) - patch_dist = torch.sqrt(torch.sum((torch.unsqueeze(cand_points, dim=2) - - patch_points_shifted) ** 2, - dim=-1)) + 
patch_points_shifted = ( + torch.unsqueeze(cand_points_round, dim=2) + patch_points[None, None, ...] + ) + patch_dist = torch.sqrt( + torch.sum( + (torch.unsqueeze(cand_points, dim=2) - patch_points_shifted) ** 2, + dim=-1, + ) + ) patch_dist_mask = patch_dist < dist_thresh[..., None] - + # Get all points => num_points_center x num_patch_points x 2 - points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0, - max=H - 1).to(torch.long) - points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0, - max=W - 1).to(torch.long) + points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0, max=H - 1).to( + torch.long + ) + points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0, max=W - 1).to( + torch.long + ) points = torch.cat([points_H[..., None], points_W[..., None]], dim=-1) - + # Sample the feature (N x 64 x 81) sampled_feat = heatmap[points[:, :, :, 0], points[:, :, :, 1]] # Filtering using the valid mask @@ -502,5 +568,5 @@ class LineSegmentDetectionModule(object): sampled_feat_lmax = torch.empty(0, 64) else: sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1) - + return sampled_feat_lmax diff --git a/third_party/SOLD2/sold2/model/line_detector.py b/third_party/SOLD2/sold2/model/line_detector.py index 2f3d059e130178c482e8e569171ef9e0370424c7..33429f8bc48d21d223efaf83ab6a8f1375b359ec 100644 --- a/third_party/SOLD2/sold2/model/line_detector.py +++ b/third_party/SOLD2/sold2/model/line_detector.py @@ -14,7 +14,7 @@ from ..misc.train_utils import adapt_checkpoint def line_map_to_segments(junctions, line_map): - """ Convert a line map to a Nx2x2 list of segments. """ + """Convert a line map to a Nx2x2 list of segments.""" line_map_tmp = line_map.copy() output_segments = np.zeros([0, 2, 2]) @@ -27,22 +27,23 @@ def line_map_to_segments(junctions, line_map): for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]: p1 = junctions[idx, :] # HW format p2 = junctions[idx2, :] - single_seg = np.concatenate([p1[None, ...], p2[None, ...]], - axis=0) + single_seg = np.concatenate([p1[None, ...], p2[None, ...]], axis=0) output_segments = np.concatenate( - (output_segments, single_seg[None, ...]), axis=0) - + (output_segments, single_seg[None, ...]), axis=0 + ) + # Update line_map line_map_tmp[idx, idx2] = 0 line_map_tmp[idx2, idx] = 0 - + return output_segments class LineDetector(object): - def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg, - junc_detect_thresh=None): - """ SOLD² line detector taking raw images as input. + def __init__( + self, model_cfg, ckpt_path, device, line_detector_cfg, junc_detect_thresh=None + ): + """SOLD² line detector taking raw images as input. 
Parameters: model_cfg: config for CNN model ckpt_path: path to the weights @@ -51,7 +52,7 @@ class LineDetector(object): # Get loss weights if dynamic weighting _, loss_weights = get_loss_and_weights(model_cfg, device) self.device = device - + # Initialize the cnn backbone self.model = get_model(model_cfg, loss_weights) checkpoint = torch.load(ckpt_path, map_location=self.device) @@ -65,20 +66,21 @@ class LineDetector(object): if junc_detect_thresh is not None: self.junc_detect_thresh = junc_detect_thresh else: - self.junc_detect_thresh = model_cfg.get("detection_thresh", 1/65) + self.junc_detect_thresh = model_cfg.get("detection_thresh", 1 / 65) self.max_num_junctions = model_cfg.get("max_num_junctions", 300) # Initialize the line detector self.line_detector_cfg = line_detector_cfg self.line_detector = LineSegmentDetectionModule(**line_detector_cfg) - - def __call__(self, input_image, valid_mask=None, - return_heatmap=False, profile=False): + + def __call__( + self, input_image, valid_mask=None, return_heatmap=False, profile=False + ): # Now we restrict input_image to 4D torch tensor - if ((not len(input_image.shape) == 4) - or (not isinstance(input_image, torch.Tensor))): - raise ValueError( - "[Error] the input image should be a 4D torch tensor.") + if (not len(input_image.shape) == 4) or ( + not isinstance(input_image, torch.Tensor) + ): + raise ValueError("[Error] the input image should be a 4D torch tensor.") # Move the input to corresponding device input_image = input_image.to(self.device) @@ -89,15 +91,18 @@ class LineDetector(object): net_outputs = self.model(input_image) junc_np = convert_junc_predictions( - net_outputs["junctions"], self.grid_size, - self.junc_detect_thresh, self.max_num_junctions) + net_outputs["junctions"], + self.grid_size, + self.junc_detect_thresh, + self.max_num_junctions, + ) if valid_mask is None: junctions = np.where(junc_np["junc_pred_nms"].squeeze()) else: - junctions = np.where(junc_np["junc_pred_nms"].squeeze() - * valid_mask) + junctions = np.where(junc_np["junc_pred_nms"].squeeze() * valid_mask) junctions = np.concatenate( - [junctions[0][..., None], junctions[1][..., None]], axis=-1) + [junctions[0][..., None], junctions[1][..., None]], axis=-1 + ) if net_outputs["heatmap"].shape[1] == 2: # Convert to single channel directly from here @@ -108,7 +113,8 @@ class LineDetector(object): # Run the line detector. line_map, junctions, heatmap = self.line_detector.detect( - junctions, heatmap, device=self.device) + junctions, heatmap, device=self.device + ) heatmap = heatmap.cpu().numpy() if isinstance(line_map, torch.Tensor): line_map = line_map.cpu().numpy() @@ -123,5 +129,5 @@ class LineDetector(object): outputs["heatmap"] = heatmap if profile: outputs["time"] = end_time - start_time - + return outputs diff --git a/third_party/SOLD2/sold2/model/line_matcher.py b/third_party/SOLD2/sold2/model/line_matcher.py index bc5a003573c91313e2295c75871edcb1c113662a..458a5e3141c0ad27c0ba665dbd72d5ce0c1c9a86 100644 --- a/third_party/SOLD2/sold2/model/line_matcher.py +++ b/third_party/SOLD2/sold2/model/line_matcher.py @@ -19,14 +19,23 @@ from .line_detector import line_map_to_segments class LineMatcher(object): - """ Full line matcher including line detection and matching - with the Needleman-Wunsch algorithm. 
""" - def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg, - line_matcher_cfg, multiscale=False, scales=[1., 2.]): + """Full line matcher including line detection and matching + with the Needleman-Wunsch algorithm.""" + + def __init__( + self, + model_cfg, + ckpt_path, + device, + line_detector_cfg, + line_matcher_cfg, + multiscale=False, + scales=[1.0, 2.0], + ): # Get loss weights if dynamic weighting _, loss_weights = get_loss_and_weights(model_cfg, device) self.device = device - + # Initialize the cnn backbone self.model = get_model(model_cfg, loss_weights) checkpoint = torch.load(ckpt_path, map_location=self.device) @@ -46,23 +55,22 @@ class LineMatcher(object): # Initialize the line matcher self.line_matcher = WunschLineMatcher(**line_matcher_cfg) - + # Print some debug messages for key, val in line_detector_cfg.items(): print(f"[Debug] {key}: {val}") # print("[Debug] detect_thresh: %f" % (line_detector_cfg["detect_thresh"])) # print("[Debug] num_samples: %d" % (line_detector_cfg["num_samples"])) - - # Perform line detection and descriptor inference on a single image - def line_detection(self, input_image, valid_mask=None, - desc_only=False, profile=False): + def line_detection( + self, input_image, valid_mask=None, desc_only=False, profile=False + ): # Restrict input_image to 4D torch tensor - if ((not len(input_image.shape) == 4) - or (not isinstance(input_image, torch.Tensor))): - raise ValueError( - "[Error] the input image should be a 4D torch tensor") + if (not len(input_image.shape) == 4) or ( + not isinstance(input_image, torch.Tensor) + ): + raise ValueError("[Error] the input image should be a 4D torch tensor") # Move the input to corresponding device input_image = input_image.to(self.device) @@ -76,29 +84,40 @@ class LineMatcher(object): if not desc_only: junc_np = convert_junc_predictions( - net_outputs["junctions"], self.grid_size, - self.junc_detect_thresh, self.max_num_junctions) + net_outputs["junctions"], + self.grid_size, + self.junc_detect_thresh, + self.max_num_junctions, + ) if valid_mask is None: junctions = np.where(junc_np["junc_pred_nms"].squeeze()) else: - junctions = np.where( - junc_np["junc_pred_nms"].squeeze() * valid_mask) - junctions = np.concatenate([junctions[0][..., None], - junctions[1][..., None]], axis=-1) + junctions = np.where(junc_np["junc_pred_nms"].squeeze() * valid_mask) + junctions = np.concatenate( + [junctions[0][..., None], junctions[1][..., None]], axis=-1 + ) if net_outputs["heatmap"].shape[1] == 2: # Convert to single channel directly from here - heatmap = softmax( - net_outputs["heatmap"], - dim=1)[:, 1:, :, :].cpu().numpy().transpose(0, 2, 3, 1) + heatmap = ( + softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :] + .cpu() + .numpy() + .transpose(0, 2, 3, 1) + ) else: - heatmap = torch.sigmoid( - net_outputs["heatmap"]).cpu().numpy().transpose(0, 2, 3, 1) + heatmap = ( + torch.sigmoid(net_outputs["heatmap"]) + .cpu() + .numpy() + .transpose(0, 2, 3, 1) + ) heatmap = heatmap[0, :, :, 0] # Run the line detector. 
line_map, junctions, heatmap = self.line_detector.detect( - junctions, heatmap, device=self.device) + junctions, heatmap, device=self.device + ) if isinstance(line_map, torch.Tensor): line_map = line_map.cpu().numpy() if isinstance(junctions, torch.Tensor): @@ -115,7 +134,9 @@ class LineMatcher(object): line_segments_inlier = [] for inlier_idx in range(num_inlier_thresh): line_map_tmp = line_map[detect_idx, inlier_idx, :, :] - line_segments_tmp = line_map_to_segments(junctions, line_map_tmp) + line_segments_tmp = line_map_to_segments( + junctions, line_map_tmp + ) line_segments_inlier.append(line_segments_tmp) line_segments.append(line_segments_inlier) else: @@ -127,18 +148,24 @@ class LineMatcher(object): if profile: outputs["time"] = end_time - start_time - + return outputs # Perform line detection and descriptor inference at multiple scales - def multiscale_line_detection(self, input_image, valid_mask=None, - desc_only=False, profile=False, - scales=[1., 2.], aggregation='mean'): + def multiscale_line_detection( + self, + input_image, + valid_mask=None, + desc_only=False, + profile=False, + scales=[1.0, 2.0], + aggregation="mean", + ): # Restrict input_image to 4D torch tensor - if ((not len(input_image.shape) == 4) - or (not isinstance(input_image, torch.Tensor))): - raise ValueError( - "[Error] the input image should be a 4D torch tensor") + if (not len(input_image.shape) == 4) or ( + not isinstance(input_image, torch.Tensor) + ): + raise ValueError("[Error] the input image should be a 4D torch tensor") # Move the input to corresponding device input_image = input_image.to(self.device) @@ -150,34 +177,39 @@ class LineMatcher(object): junctions, heatmaps, descriptors = [], [], [] for s in scales: # Resize the image - resized_img = F.interpolate(input_image, scale_factor=s, - mode='bilinear') + resized_img = F.interpolate(input_image, scale_factor=s, mode="bilinear") # Forward of the CNN backbone with torch.no_grad(): net_outputs = self.model(resized_img) - descriptors.append(F.interpolate( - net_outputs["descriptors"], size=desc_size, mode="bilinear")) + descriptors.append( + F.interpolate( + net_outputs["descriptors"], size=desc_size, mode="bilinear" + ) + ) if not desc_only: junc_prob = convert_junc_predictions( - net_outputs["junctions"], self.grid_size)["junc_pred"] - junctions.append(cv2.resize(junc_prob.squeeze(), - (img_size[1], img_size[0]), - interpolation=cv2.INTER_LINEAR)) + net_outputs["junctions"], self.grid_size + )["junc_pred"] + junctions.append( + cv2.resize( + junc_prob.squeeze(), + (img_size[1], img_size[0]), + interpolation=cv2.INTER_LINEAR, + ) + ) if net_outputs["heatmap"].shape[1] == 2: # Convert to single channel directly from here - heatmap = softmax(net_outputs["heatmap"], - dim=1)[:, 1:, :, :] + heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :] else: heatmap = torch.sigmoid(net_outputs["heatmap"]) - heatmaps.append(F.interpolate(heatmap, size=img_size, - mode="bilinear")) + heatmaps.append(F.interpolate(heatmap, size=img_size, mode="bilinear")) # Aggregate the results - if aggregation == 'mean': + if aggregation == "mean": # Aggregation through the mean activation descriptors = torch.stack(descriptors, dim=0).mean(0) else: @@ -186,7 +218,7 @@ class LineMatcher(object): outputs = {"descriptor": descriptors} if not desc_only: - if aggregation == 'mean': + if aggregation == "mean": junctions = np.stack(junctions, axis=0).mean(0)[None] heatmap = torch.stack(heatmaps, dim=0).mean(0)[0, 0, :, :] heatmap = heatmap.cpu().numpy() @@ -197,18 +229,23 @@ class 
LineMatcher(object): # Extract junctions junc_pred_nms = super_nms( - junctions[..., None], self.grid_size, - self.junc_detect_thresh, self.max_num_junctions) + junctions[..., None], + self.grid_size, + self.junc_detect_thresh, + self.max_num_junctions, + ) if valid_mask is None: junctions = np.where(junc_pred_nms.squeeze()) else: junctions = np.where(junc_pred_nms.squeeze() * valid_mask) - junctions = np.concatenate([junctions[0][..., None], - junctions[1][..., None]], axis=-1) + junctions = np.concatenate( + [junctions[0][..., None], junctions[1][..., None]], axis=-1 + ) # Run the line detector. line_map, junctions, heatmap = self.line_detector.detect( - junctions, heatmap, device=self.device) + junctions, heatmap, device=self.device + ) if isinstance(line_map, torch.Tensor): line_map = line_map.cpu().numpy() if isinstance(junctions, torch.Tensor): @@ -226,7 +263,8 @@ class LineMatcher(object): for inlier_idx in range(num_inlier_thresh): line_map_tmp = line_map[detect_idx, inlier_idx, :, :] line_segments_tmp = line_map_to_segments( - junctions, line_map_tmp) + junctions, line_map_tmp + ) line_segments_inlier.append(line_segments_tmp) line_segments.append(line_segments_inlier) else: @@ -238,25 +276,25 @@ class LineMatcher(object): if profile: outputs["time"] = end_time - start_time - + return outputs - + def __call__(self, images, valid_masks=[None, None], profile=False): # Line detection and descriptor inference on both images if self.multiscale: forward_outputs = [ self.multiscale_line_detection( - images[0], valid_masks[0], profile=profile, - scales=self.scales), + images[0], valid_masks[0], profile=profile, scales=self.scales + ), self.multiscale_line_detection( - images[1], valid_masks[1], profile=profile, - scales=self.scales)] + images[1], valid_masks[1], profile=profile, scales=self.scales + ), + ] else: forward_outputs = [ - self.line_detection(images[0], valid_masks[0], - profile=profile), - self.line_detection(images[1], valid_masks[1], - profile=profile)] + self.line_detection(images[0], valid_masks[0], profile=profile), + self.line_detection(images[1], valid_masks[1], profile=profile), + ] line_seg1 = forward_outputs[0]["line_segments"] line_seg2 = forward_outputs[1]["line_segments"] desc1 = forward_outputs[0]["descriptor"] @@ -264,16 +302,15 @@ class LineMatcher(object): # Match the lines in both images start_time = time.time() - matches = self.line_matcher.forward(line_seg1, line_seg2, - desc1, desc2) + matches = self.line_matcher.forward(line_seg1, line_seg2, desc1, desc2) end_time = time.time() - outputs = {"line_segments": [line_seg1, line_seg2], - "matches": matches} + outputs = {"line_segments": [line_seg1, line_seg2], "matches": matches} if profile: - outputs["line_detection_time"] = (forward_outputs[0]["time"] - + forward_outputs[1]["time"]) + outputs["line_detection_time"] = ( + forward_outputs[0]["time"] + forward_outputs[1]["time"] + ) outputs["line_matching_time"] = end_time - start_time - + return outputs diff --git a/third_party/SOLD2/sold2/model/line_matching.py b/third_party/SOLD2/sold2/model/line_matching.py index 89b71879e3104f9a8b52c1cf5e534cd124fe83b2..bfceb5a161732c3f7f4cf97e988d5e369a4c25fa 100644 --- a/third_party/SOLD2/sold2/model/line_matching.py +++ b/third_party/SOLD2/sold2/model/line_matching.py @@ -10,11 +10,19 @@ from ..misc.geometry_utils import keypoints_to_grid class WunschLineMatcher(object): - """ Class matching two sets of line segments - with the Needleman-Wunsch algorithm. 
""" - def __init__(self, cross_check=True, num_samples=10, min_dist_pts=8, - top_k_candidates=10, grid_size=8, sampling="regular", - line_score=False): + """Class matching two sets of line segments + with the Needleman-Wunsch algorithm.""" + + def __init__( + self, + cross_check=True, + num_samples=10, + min_dist_pts=8, + top_k_candidates=10, + grid_size=8, + sampling="regular", + line_score=False, + ): self.cross_check = cross_check self.num_samples = num_samples self.min_dist_pts = min_dist_pts @@ -27,13 +35,11 @@ class WunschLineMatcher(object): def forward(self, line_seg1, line_seg2, desc1, desc2): """ - Find the best matches between two sets of line segments - and their corresponding descriptors. + Find the best matches between two sets of line segments + and their corresponding descriptors. """ - img_size1 = (desc1.shape[2] * self.grid_size, - desc1.shape[3] * self.grid_size) - img_size2 = (desc2.shape[2] * self.grid_size, - desc2.shape[3] * self.grid_size) + img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size) + img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size) device = desc1.device # Default case when an image has no lines @@ -48,13 +54,17 @@ class WunschLineMatcher(object): line_points2, valid_points2 = self.sample_line_points(line_seg2) else: line_points1, valid_points1 = self.sample_salient_points( - line_seg1, desc1, img_size1, self.sampling_mode) + line_seg1, desc1, img_size1, self.sampling_mode + ) line_points2, valid_points2 = self.sample_salient_points( - line_seg2, desc2, img_size2, self.sampling_mode) - line_points1 = torch.tensor(line_points1.reshape(-1, 2), - dtype=torch.float, device=device) - line_points2 = torch.tensor(line_points2.reshape(-1, 2), - dtype=torch.float, device=device) + line_seg2, desc2, img_size2, self.sampling_mode + ) + line_points1 = torch.tensor( + line_points1.reshape(-1, 2), dtype=torch.float, device=device + ) + line_points2 = torch.tensor( + line_points2.reshape(-1, 2), dtype=torch.float, device=device + ) # Extract the descriptors for each point grid1 = keypoints_to_grid(line_points1, img_size1) @@ -67,8 +77,9 @@ class WunschLineMatcher(object): scores = desc1.t() @ desc2 scores[~valid_points1.flatten()] = -1 scores[:, ~valid_points2.flatten()] = -1 - scores = scores.reshape(len(line_seg1), self.num_samples, - len(line_seg2), self.num_samples) + scores = scores.reshape( + len(line_seg1), self.num_samples, len(line_seg2), self.num_samples + ) scores = scores.permute(0, 2, 1, 3) # scores.shape = (n_lines1, n_lines2, num_samples, num_samples) @@ -77,16 +88,15 @@ class WunschLineMatcher(object): # [Optionally] filter matches with mutual nearest neighbor filtering if self.cross_check: - matches2 = self.filter_and_match_lines( - scores.permute(1, 0, 3, 2)) + matches2 = self.filter_and_match_lines(scores.permute(1, 0, 3, 2)) mutual = matches2[matches] == np.arange(len(line_seg1)) matches[~mutual] = -1 return matches def d2_net_saliency_score(self, desc): - """ Compute the D2-Net saliency score - on a 3D or 4D descriptor. 
""" + """Compute the D2-Net saliency score + on a 3D or 4D descriptor.""" is_3d = len(desc.shape) == 3 b_size = len(desc) feat = F.relu(desc) @@ -94,11 +104,9 @@ class WunschLineMatcher(object): # Compute the soft local max exp = torch.exp(feat) if is_3d: - sum_exp = 3 * F.avg_pool1d(exp, kernel_size=3, stride=1, - padding=1) + sum_exp = 3 * F.avg_pool1d(exp, kernel_size=3, stride=1, padding=1) else: - sum_exp = 9 * F.avg_pool2d(exp, kernel_size=3, stride=1, - padding=1) + sum_exp = 9 * F.avg_pool2d(exp, kernel_size=3, stride=1, padding=1) soft_local_max = exp / sum_exp # Compute the depth-wise maximum @@ -116,7 +124,7 @@ class WunschLineMatcher(object): return score def asl_feat_saliency_score(self, desc): - """ Compute the ASLFeat saliency score on a 3D or 4D descriptor. """ + """Compute the ASLFeat saliency score on a 3D or 4D descriptor.""" is_3d = len(desc.shape) == 3 b_size = len(desc) @@ -141,8 +149,7 @@ class WunschLineMatcher(object): score = score / normalization return score - def sample_salient_points(self, line_seg, desc, img_size, - saliency_type='d2_net'): + def sample_salient_points(self, line_seg, desc, img_size, saliency_type="d2_net"): """ Sample the most salient points along each line segments, with a minimal distance between each point. Pad the remaining points. @@ -167,8 +174,9 @@ class WunschLineMatcher(object): line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1) # The number of samples depends on the length of the line - num_samples_lst = np.clip(line_lengths // self.min_dist_pts, - 2, self.num_samples) + num_samples_lst = np.clip( + line_lengths // self.min_dist_pts, 2, self.num_samples + ) line_points = np.empty((num_lines, self.num_samples, 2), dtype=float) valid_points = np.empty((num_lines, self.num_samples), dtype=bool) @@ -182,17 +190,19 @@ class WunschLineMatcher(object): cur_num_lines = len(cur_line_seg) if cur_num_lines == 0: continue - line_points_x = np.linspace(cur_line_seg[:, 0, 0], - cur_line_seg[:, 1, 0], - sample_rate, axis=-1) - line_points_y = np.linspace(cur_line_seg[:, 0, 1], - cur_line_seg[:, 1, 1], - sample_rate, axis=-1) - cur_line_points = np.stack([line_points_x, line_points_y], - axis=-1).reshape(-1, 2) + line_points_x = np.linspace( + cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], sample_rate, axis=-1 + ) + line_points_y = np.linspace( + cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], sample_rate, axis=-1 + ) + cur_line_points = np.stack([line_points_x, line_points_y], axis=-1).reshape( + -1, 2 + ) # cur_line_points is of shape (n_cur_lines * sample_rate, 2) - cur_line_points = torch.tensor(cur_line_points, dtype=torch.float, - device=device) + cur_line_points = torch.tensor( + cur_line_points, dtype=torch.float, device=device + ) grid_points = keypoints_to_grid(cur_line_points, img_size) if self.line_score: @@ -206,25 +216,26 @@ class WunschLineMatcher(object): else: scores = self.asl_feat_saliency_score(line_desc) else: - scores = F.grid_sample(score.unsqueeze(1), - grid_points).squeeze() + scores = F.grid_sample(score.unsqueeze(1), grid_points).squeeze() # Take the most salient point in n distinct regions scores = scores.reshape(-1, n, n_samples_per_region) best = torch.max(scores, dim=2, keepdim=True)[1].cpu().numpy() - cur_line_points = cur_line_points.reshape(-1, n, - n_samples_per_region, 2) + cur_line_points = cur_line_points.reshape(-1, n, n_samples_per_region, 2) cur_line_points = np.take_along_axis( - cur_line_points, best[..., None], axis=2)[:, :, 0] + cur_line_points, best[..., None], axis=2 + )[:, :, 0] # Pad 
- cur_valid_points = np.ones((cur_num_lines, self.num_samples), - dtype=bool) + cur_valid_points = np.ones((cur_num_lines, self.num_samples), dtype=bool) cur_valid_points[:, n:] = False - cur_line_points = np.concatenate([ - cur_line_points, - np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)], - axis=1) + cur_line_points = np.concatenate( + [ + cur_line_points, + np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float), + ], + axis=1, + ) line_points[cur_mask] = cur_line_points valid_points[cur_mask] = cur_valid_points @@ -246,31 +257,34 @@ class WunschLineMatcher(object): # Sample the points separated by at least min_dist_pts along each line # The number of samples depends on the length of the line - num_samples_lst = np.clip(line_lengths // self.min_dist_pts, - 2, self.num_samples) + num_samples_lst = np.clip( + line_lengths // self.min_dist_pts, 2, self.num_samples + ) line_points = np.empty((num_lines, self.num_samples, 2), dtype=float) valid_points = np.empty((num_lines, self.num_samples), dtype=bool) for n in np.arange(2, self.num_samples + 1): # Consider all lines where we can fit up to n points cur_mask = num_samples_lst == n cur_line_seg = line_seg[cur_mask] - line_points_x = np.linspace(cur_line_seg[:, 0, 0], - cur_line_seg[:, 1, 0], - n, axis=-1) - line_points_y = np.linspace(cur_line_seg[:, 0, 1], - cur_line_seg[:, 1, 1], - n, axis=-1) + line_points_x = np.linspace( + cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], n, axis=-1 + ) + line_points_y = np.linspace( + cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], n, axis=-1 + ) cur_line_points = np.stack([line_points_x, line_points_y], axis=-1) # Pad cur_num_lines = len(cur_line_seg) - cur_valid_points = np.ones((cur_num_lines, self.num_samples), - dtype=bool) + cur_valid_points = np.ones((cur_num_lines, self.num_samples), dtype=bool) cur_valid_points[:, n:] = False - cur_line_points = np.concatenate([ - cur_line_points, - np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)], - axis=1) + cur_line_points = np.concatenate( + [ + cur_line_points, + np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float), + ], + axis=1, + ) line_points[cur_mask] = cur_line_points valid_points[cur_mask] = cur_valid_points @@ -290,23 +304,18 @@ class WunschLineMatcher(object): # Pre-filter the pairs and keep the top k best candidate lines line_scores1 = scores.max(3)[0] valid_scores1 = line_scores1 != -1 - line_scores1 = ((line_scores1 * valid_scores1).sum(2) - / valid_scores1.sum(2)) + line_scores1 = (line_scores1 * valid_scores1).sum(2) / valid_scores1.sum(2) line_scores2 = scores.max(2)[0] valid_scores2 = line_scores2 != -1 - line_scores2 = ((line_scores2 * valid_scores2).sum(2) - / valid_scores2.sum(2)) + line_scores2 = (line_scores2 * valid_scores2).sum(2) / valid_scores2.sum(2) line_scores = (line_scores1 + line_scores2) / 2 - topk_lines = torch.argsort(line_scores, - dim=1)[:, -self.top_k_candidates:] + topk_lines = torch.argsort(line_scores, dim=1)[:, -self.top_k_candidates :] scores, topk_lines = scores.cpu().numpy(), topk_lines.cpu().numpy() # topk_lines.shape = (n_lines1, top_k_candidates) - top_scores = np.take_along_axis(scores, topk_lines[:, :, None, None], - axis=1) + top_scores = np.take_along_axis(scores, topk_lines[:, :, None, None], axis=1) # Consider the reversed line segments as well - top_scores = np.concatenate([top_scores, top_scores[..., ::-1]], - axis=1) + top_scores = np.concatenate([top_scores, top_scores[..., ::-1]], axis=1) # Compute the line distance matrix with Needleman-Wunsch algo 
and # retrieve the closest line neighbor @@ -339,30 +348,33 @@ class WunschLineMatcher(object): for j in range(m): nw_grid[:, i + 1, j + 1] = np.maximum( np.maximum(nw_grid[:, i + 1, j], nw_grid[:, i, j + 1]), - nw_grid[:, i, j] + nw_scores[:, i, j]) + nw_grid[:, i, j] + nw_scores[:, i, j], + ) return nw_grid[:, -1, -1] def get_pairwise_distance(self, line_seg1, line_seg2, desc1, desc2): """ - Compute the OPPOSITE of the NW score for pairs of line segments - and their corresponding descriptors. + Compute the OPPOSITE of the NW score for pairs of line segments + and their corresponding descriptors. """ num_lines = len(line_seg1) - assert num_lines == len(line_seg2), "The same number of lines is required in pairwise score." - img_size1 = (desc1.shape[2] * self.grid_size, - desc1.shape[3] * self.grid_size) - img_size2 = (desc2.shape[2] * self.grid_size, - desc2.shape[3] * self.grid_size) + assert num_lines == len( + line_seg2 + ), "The same number of lines is required in pairwise score." + img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size) + img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size) device = desc1.device # Sample points regularly along each line line_points1, valid_points1 = self.sample_line_points(line_seg1) line_points2, valid_points2 = self.sample_line_points(line_seg2) - line_points1 = torch.tensor(line_points1.reshape(-1, 2), - dtype=torch.float, device=device) - line_points2 = torch.tensor(line_points2.reshape(-1, 2), - dtype=torch.float, device=device) + line_points1 = torch.tensor( + line_points1.reshape(-1, 2), dtype=torch.float, device=device + ) + line_points2 = torch.tensor( + line_points2.reshape(-1, 2), dtype=torch.float, device=device + ) # Extract the descriptors for each point grid1 = keypoints_to_grid(line_points1, img_size1) @@ -374,9 +386,8 @@ class WunschLineMatcher(object): # Compute the distance between line points for every pair of lines # Assign a score of -1 for unvalid points - scores = torch.einsum('dns,dnt->nst', desc1, desc2).cpu().numpy() - scores = scores.reshape(num_lines * self.num_samples, - self.num_samples) + scores = torch.einsum("dns,dnt->nst", desc1, desc2).cpu().numpy() + scores = scores.reshape(num_lines * self.num_samples, self.num_samples) scores[~valid_points1.flatten()] = -1 scores = scores.reshape(num_lines, self.num_samples, self.num_samples) scores = scores.transpose(1, 0, 2).reshape(self.num_samples, -1) diff --git a/third_party/SOLD2/sold2/model/loss.py b/third_party/SOLD2/sold2/model/loss.py index aaad3c67f3fd59db308869901f8a56623901e318..c1d2bfd232958fc19a4a775fe561dd5089079bff 100644 --- a/third_party/SOLD2/sold2/model/loss.py +++ b/third_party/SOLD2/sold2/model/loss.py @@ -7,17 +7,16 @@ import torch.nn as nn import torch.nn.functional as F from kornia.geometry import warp_perspective -from ..misc.geometry_utils import (keypoints_to_grid, get_dist_mask, - get_common_line_mask) +from ..misc.geometry_utils import keypoints_to_grid, get_dist_mask, get_common_line_mask def get_loss_and_weights(model_cfg, device=torch.device("cuda")): - """ Get loss functions and either static or dynamic weighting. 
""" + """Get loss functions and either static or dynamic weighting.""" # Get the global weighting policy w_policy = model_cfg.get("weighting_policy", "static") if not w_policy in ["static", "dynamic"]: raise ValueError("[Error] Not supported weighting policy.") - + loss_func = {} loss_weight = {} # Get junction loss function and weight @@ -27,14 +26,16 @@ def get_loss_and_weights(model_cfg, device=torch.device("cuda")): # Get heatmap loss function and weight w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight( - model_cfg, w_policy, device) + model_cfg, w_policy, device + ) loss_func["heatmap_loss"] = heatmap_loss_func.to(device) loss_weight["w_heatmap"] = w_heatmap # [Optionally] get descriptor loss function and weight if model_cfg.get("descriptor_loss_func", None) is not None: w_descriptor, descriptor_loss_func = get_descriptor_loss_and_weight( - model_cfg, w_policy) + model_cfg, w_policy + ) loss_func["descriptor_loss"] = descriptor_loss_func.to(device) loss_weight["w_desc"] = w_descriptor @@ -42,26 +43,26 @@ def get_loss_and_weights(model_cfg, device=torch.device("cuda")): def get_junction_loss_and_weight(model_cfg, global_w_policy): - """ Get the junction loss function and weight. """ + """Get the junction loss function and weight.""" junction_loss_cfg = model_cfg.get("junction_loss_cfg", {}) - + # Get the junction loss weight w_policy = junction_loss_cfg.get("policy", global_w_policy) if w_policy == "static": w_junc = torch.tensor(model_cfg["w_junc"], dtype=torch.float32) elif w_policy == "dynamic": w_junc = nn.Parameter( - torch.tensor(model_cfg["w_junc"], dtype=torch.float32), - requires_grad=True) + torch.tensor(model_cfg["w_junc"], dtype=torch.float32), requires_grad=True + ) else: - raise ValueError( - "[Error] Unknown weighting policy for junction loss weight.") + raise ValueError("[Error] Unknown weighting policy for junction loss weight.") # Get the junction loss function junc_loss_name = model_cfg.get("junction_loss_func", "superpoint") if junc_loss_name == "superpoint": - junc_loss_func = JunctionDetectionLoss(model_cfg["grid_size"], - model_cfg["keep_border_valid"]) + junc_loss_func = JunctionDetectionLoss( + model_cfg["grid_size"], model_cfg["keep_border_valid"] + ) else: raise ValueError("[Error] Not supported junction loss function.") @@ -69,7 +70,7 @@ def get_junction_loss_and_weight(model_cfg, global_w_policy): def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device): - """ Get the heatmap loss function and weight. """ + """Get the heatmap loss function and weight.""" heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {}) # Get the heatmap loss weight @@ -78,19 +79,20 @@ def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device): w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32) elif w_policy == "dynamic": w_heatmap = nn.Parameter( - torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32), - requires_grad=True) + torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32), + requires_grad=True, + ) else: - raise ValueError( - "[Error] Unknown weighting policy for junction loss weight.") + raise ValueError("[Error] Unknown weighting policy for junction loss weight.") # Get the corresponding heatmap loss based on the config heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy") if heatmap_loss_name == "cross_entropy": # Get the heatmap class weight (always static) - heatmap_class_w = model_cfg.get("w_heatmap_class", 1.) 
- class_weight = torch.tensor( - np.array([1., heatmap_class_w])).to(torch.float).to(device) + heatmap_class_w = model_cfg.get("w_heatmap_class", 1.0) + class_weight = ( + torch.tensor(np.array([1.0, heatmap_class_w])).to(torch.float).to(device) + ) heatmap_loss_func = HeatmapLoss(class_weight=class_weight) else: raise ValueError("[Error] Not supported heatmap loss function.") @@ -99,28 +101,28 @@ def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device): def get_descriptor_loss_and_weight(model_cfg, global_w_policy): - """ Get the descriptor loss function and weight. """ + """Get the descriptor loss function and weight.""" descriptor_loss_cfg = model_cfg.get("descriptor_loss_cfg", {}) - + # Get the descriptor loss weight w_policy = descriptor_loss_cfg.get("policy", global_w_policy) if w_policy == "static": w_descriptor = torch.tensor(model_cfg["w_desc"], dtype=torch.float32) elif w_policy == "dynamic": - w_descriptor = nn.Parameter(torch.tensor(model_cfg["w_desc"], - dtype=torch.float32), requires_grad=True) + w_descriptor = nn.Parameter( + torch.tensor(model_cfg["w_desc"], dtype=torch.float32), requires_grad=True + ) else: - raise ValueError( - "[Error] Unknown weighting policy for descriptor loss weight.") + raise ValueError("[Error] Unknown weighting policy for descriptor loss weight.") # Get the descriptor loss function - descriptor_loss_name = model_cfg.get("descriptor_loss_func", - "regular_sampling") + descriptor_loss_name = model_cfg.get("descriptor_loss_func", "regular_sampling") if descriptor_loss_name == "regular_sampling": descriptor_loss_func = TripletDescriptorLoss( descriptor_loss_cfg["grid_size"], descriptor_loss_cfg["dist_threshold"], - descriptor_loss_cfg["margin"]) + descriptor_loss_cfg["margin"], + ) else: raise ValueError("[Error] Not supported descriptor loss function.") @@ -128,79 +130,88 @@ def get_descriptor_loss_and_weight(model_cfg, global_w_policy): def space_to_depth(input_tensor, grid_size): - """ PixelUnshuffle for pytorch. """ + """PixelUnshuffle for pytorch.""" N, C, H, W = input_tensor.size() # (N, C, H//bs, bs, W//bs, bs) x = input_tensor.view(N, C, H // grid_size, grid_size, W // grid_size, grid_size) # (N, bs, bs, C, H//bs, W//bs) x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, C*bs^2, H//bs, W//bs) - x = x.view(N, C * (grid_size ** 2), H // grid_size, W // grid_size) + x = x.view(N, C * (grid_size**2), H // grid_size, W // grid_size) return x -def junction_detection_loss(junction_map, junc_predictions, valid_mask=None, - grid_size=8, keep_border=True): - """ Junction detection loss. 
""" +def junction_detection_loss( + junction_map, junc_predictions, valid_mask=None, grid_size=8, keep_border=True +): + """Junction detection loss.""" # Convert junc_map to channel tensor junc_map = space_to_depth(junction_map, grid_size) map_shape = junc_map.shape[-2:] batch_size = junc_map.shape[0] - dust_bin_label = torch.ones( - [batch_size, 1, map_shape[0], - map_shape[1]]).to(junc_map.device).to(torch.int) - junc_map = torch.cat([junc_map*2, dust_bin_label], dim=1) + dust_bin_label = ( + torch.ones([batch_size, 1, map_shape[0], map_shape[1]]) + .to(junc_map.device) + .to(torch.int) + ) + junc_map = torch.cat([junc_map * 2, dust_bin_label], dim=1) labels = torch.argmax( - junc_map.to(torch.float) + - torch.distributions.Uniform(0, 0.1).sample(junc_map.shape).to(junc_map.device), - dim=1) + junc_map.to(torch.float) + + torch.distributions.Uniform(0, 0.1) + .sample(junc_map.shape) + .to(junc_map.device), + dim=1, + ) # Also convert the valid mask to channel tensor - valid_mask = (torch.ones(junction_map.shape) if valid_mask is None - else valid_mask) + valid_mask = torch.ones(junction_map.shape) if valid_mask is None else valid_mask valid_mask = space_to_depth(valid_mask, grid_size) - + # Compute junction loss on the border patch or not if keep_border: - valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int), - dim=1, keepdim=True) > 0 + valid_mask = ( + torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True) > 0 + ) else: - valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int), - dim=1, keepdim=True) >= grid_size * grid_size + valid_mask = ( + torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True) + >= grid_size * grid_size + ) # Compute the classification loss loss_func = nn.CrossEntropyLoss(reduction="none") # The loss still need NCHW format - loss = loss_func(input=junc_predictions, - target=labels.to(torch.long)) - + loss = loss_func(input=junc_predictions, target=labels.to(torch.long)) + # Weighted sum by the valid mask - loss_ = torch.sum(loss * torch.squeeze(valid_mask.to(torch.float), - dim=1), dim=[0, 1, 2]) - loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float), - dim=1)) + loss_ = torch.sum( + loss * torch.squeeze(valid_mask.to(torch.float), dim=1), dim=[0, 1, 2] + ) + loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float), dim=1)) return loss_final -def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, - class_weight=None): - """ Heatmap prediction loss. 
""" +def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, class_weight=None): + """Heatmap prediction loss.""" # Compute the classification loss on each pixel if class_weight is None: loss_func = nn.CrossEntropyLoss(reduction="none") else: loss_func = nn.CrossEntropyLoss(class_weight, reduction="none") - loss = loss_func(input=heatmap_pred, - target=torch.squeeze(heatmap_gt.to(torch.long), dim=1)) + loss = loss_func( + input=heatmap_pred, target=torch.squeeze(heatmap_gt.to(torch.long), dim=1) + ) # Weighted sum by the valid mask # Sum over H and W - loss_spatial_sum = torch.sum(loss * torch.squeeze( - valid_mask.to(torch.float), dim=1), dim=[1, 2]) - valid_spatial_sum = torch.sum(torch.squeeze(valid_mask.to(torch.float32), - dim=1), dim=[1, 2]) + loss_spatial_sum = torch.sum( + loss * torch.squeeze(valid_mask.to(torch.float), dim=1), dim=[1, 2] + ) + valid_spatial_sum = torch.sum( + torch.squeeze(valid_mask.to(torch.float32), dim=1), dim=[1, 2] + ) # Mean to single scalar over batch dimension loss = torch.sum(loss_spatial_sum) / torch.sum(valid_spatial_sum) @@ -208,19 +219,22 @@ def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, class JunctionDetectionLoss(nn.Module): - """ Junction detection loss. """ + """Junction detection loss.""" + def __init__(self, grid_size, keep_border): super(JunctionDetectionLoss, self).__init__() self.grid_size = grid_size self.keep_border = keep_border def forward(self, prediction, target, valid_mask=None): - return junction_detection_loss(target, prediction, valid_mask, - self.grid_size, self.keep_border) + return junction_detection_loss( + target, prediction, valid_mask, self.grid_size, self.keep_border + ) class HeatmapLoss(nn.Module): - """ Heatmap prediction loss. """ + """Heatmap prediction loss.""" + def __init__(self, class_weight): super(HeatmapLoss, self).__init__() self.class_weight = class_weight @@ -230,7 +244,8 @@ class HeatmapLoss(nn.Module): class RegularizationLoss(nn.Module): - """ Module for regularization loss. """ + """Module for regularization loss.""" + def __init__(self): super(RegularizationLoss, self).__init__() self.name = "regularization_loss" @@ -242,14 +257,23 @@ class RegularizationLoss(nn.Module): for _, val in loss_weights.items(): if isinstance(val, nn.Parameter): loss += val - + return loss -def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices, - epoch, grid_size=8, dist_threshold=8, - init_dist_threshold=64, margin=1): - """ Regular triplet loss for descriptor learning. 
""" +def triplet_loss( + desc_pred1, + desc_pred2, + points1, + points2, + line_indices, + epoch, + grid_size=8, + dist_threshold=8, + init_dist_threshold=64, + margin=1, +): + """Regular triplet loss for descriptor learning.""" b_size, _, Hc, Wc = desc_pred1.size() img_size = (Hc * grid_size, Wc * grid_size) device = desc_pred1.device @@ -259,12 +283,11 @@ def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices, valid_points = line_indices.bool().flatten() n_correct_points = torch.sum(valid_points).item() if n_correct_points == 0: - return torch.tensor(0., dtype=torch.float, device=device) + return torch.tensor(0.0, dtype=torch.float, device=device) # Check which keypoints are too close to be matched # dist_threshold is decreased at each epoch for easier training - dist_threshold = max(dist_threshold, - 2 * init_dist_threshold // (epoch + 1)) + dist_threshold = max(dist_threshold, 2 * init_dist_threshold // (epoch + 1)) dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold) # Additionally ban negative mining along the same line @@ -276,11 +299,17 @@ def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices, grid2 = keypoints_to_grid(points2, img_size) # Extract the descriptors - desc1 = F.grid_sample(desc_pred1, grid1).permute( - 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] + desc1 = ( + F.grid_sample(desc_pred1, grid1) + .permute(0, 2, 3, 1) + .reshape(b_size * n_points, -1)[valid_points] + ) desc1 = F.normalize(desc1, dim=1) - desc2 = F.grid_sample(desc_pred2, grid2).permute( - 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] + desc2 = ( + F.grid_sample(desc_pred2, grid2) + .permute(0, 2, 3, 1) + .reshape(b_size * n_points, -1)[valid_points] + ) desc2 = F.normalize(desc2, dim=1) desc_dists = 2 - 2 * (desc1 @ desc2.t()) @@ -288,20 +317,23 @@ def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices, pos_dist = torch.diag(desc_dists) # Negative distance loss - max_dist = torch.tensor(4., dtype=torch.float, device=device) + max_dist = torch.tensor(4.0, dtype=torch.float, device=device) desc_dists[ torch.arange(n_correct_points, dtype=torch.long), - torch.arange(n_correct_points, dtype=torch.long)] = max_dist + torch.arange(n_correct_points, dtype=torch.long), + ] = max_dist desc_dists[dist_mask] = max_dist - neg_dist = torch.min(torch.min(desc_dists, dim=1)[0], - torch.min(desc_dists, dim=0)[0]) + neg_dist = torch.min( + torch.min(desc_dists, dim=1)[0], torch.min(desc_dists, dim=0)[0] + ) triplet_loss = F.relu(margin + pos_dist - neg_dist) return triplet_loss, grid1, grid2, valid_points class TripletDescriptorLoss(nn.Module): - """ Triplet descriptor loss. 
""" + """Triplet descriptor loss.""" + def __init__(self, grid_size, dist_threshold, margin): super(TripletDescriptorLoss, self).__init__() self.grid_size = grid_size @@ -309,23 +341,35 @@ class TripletDescriptorLoss(nn.Module): self.dist_threshold = dist_threshold self.margin = margin - def forward(self, desc_pred1, desc_pred2, points1, - points2, line_indices, epoch): - return self.descriptor_loss(desc_pred1, desc_pred2, points1, - points2, line_indices, epoch) + def forward(self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch): + return self.descriptor_loss( + desc_pred1, desc_pred2, points1, points2, line_indices, epoch + ) # The descriptor loss based on regularly sampled points along the lines - def descriptor_loss(self, desc_pred1, desc_pred2, points1, - points2, line_indices, epoch): - return torch.mean(triplet_loss( - desc_pred1, desc_pred2, points1, points2, line_indices, epoch, - self.grid_size, self.dist_threshold, self.init_dist_threshold, - self.margin)[0]) + def descriptor_loss( + self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch + ): + return torch.mean( + triplet_loss( + desc_pred1, + desc_pred2, + points1, + points2, + line_indices, + epoch, + self.grid_size, + self.dist_threshold, + self.init_dist_threshold, + self.margin, + )[0] + ) class TotalLoss(nn.Module): - """ Total loss summing junction, heatma, descriptor - and regularization losses. """ + """Total loss summing junction, heatma, descriptor + and regularization losses.""" + def __init__(self, loss_funcs, loss_weights, weighting_policy): super(TotalLoss, self).__init__() # Whether we need to compute the descriptor loss @@ -338,23 +382,26 @@ class TotalLoss(nn.Module): # Always add regularization loss (it will return zero if not used) self.loss_funcs["reg_loss"] = RegularizationLoss().cuda() - def forward(self, junc_pred, junc_target, heatmap_pred, - heatmap_target, valid_mask=None): - """ Detection only loss. """ + def forward( + self, junc_pred, junc_target, heatmap_pred, heatmap_target, valid_mask=None + ): + """Detection only loss.""" # Compute the junction loss - junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target, - valid_mask) + junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target, valid_mask) # Compute the heatmap loss heatmap_loss = self.loss_funcs["heatmap_loss"]( - heatmap_pred, heatmap_target, valid_mask) + heatmap_pred, heatmap_target, valid_mask + ) # Compute the total loss. 
if self.weighting_policy == "dynamic": reg_loss = self.loss_funcs["reg_loss"](self.loss_weights) - total_loss = junc_loss * torch.exp(-self.loss_weights["w_junc"]) + \ - heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"]) + \ - reg_loss - + total_loss = ( + junc_loss * torch.exp(-self.loss_weights["w_junc"]) + + heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"]) + + reg_loss + ) + return { "total_loss": total_loss, "junc_loss": junc_loss, @@ -363,32 +410,47 @@ class TotalLoss(nn.Module): "w_junc": torch.exp(-self.loss_weights["w_junc"]).item(), "w_heatmap": torch.exp(-self.loss_weights["w_heatmap"]).item(), } - + elif self.weighting_policy == "static": - total_loss = junc_loss * self.loss_weights["w_junc"] + \ - heatmap_loss * self.loss_weights["w_heatmap"] - + total_loss = ( + junc_loss * self.loss_weights["w_junc"] + + heatmap_loss * self.loss_weights["w_heatmap"] + ) + return { "total_loss": total_loss, "junc_loss": junc_loss, - "heatmap_loss": heatmap_loss + "heatmap_loss": heatmap_loss, } else: raise ValueError("[Error] Unknown weighting policy.") - - def forward_descriptors(self, - junc_map_pred1, junc_map_pred2, junc_map_target1, - junc_map_target2, heatmap_pred1, heatmap_pred2, heatmap_target1, - heatmap_target2, line_points1, line_points2, line_indices, - desc_pred1, desc_pred2, epoch, valid_mask1=None, - valid_mask2=None): - """ Loss for detection + description. """ + + def forward_descriptors( + self, + junc_map_pred1, + junc_map_pred2, + junc_map_target1, + junc_map_target2, + heatmap_pred1, + heatmap_pred2, + heatmap_target1, + heatmap_target2, + line_points1, + line_points2, + line_indices, + desc_pred1, + desc_pred2, + epoch, + valid_mask1=None, + valid_mask2=None, + ): + """Loss for detection + description.""" # Compute junction loss junc_loss = self.loss_funcs["junc_loss"]( - torch.cat([junc_map_pred1, junc_map_pred2], dim=0), + torch.cat([junc_map_pred1, junc_map_pred2], dim=0), torch.cat([junc_map_target1, junc_map_target2], dim=0), - torch.cat([valid_mask1, valid_mask2], dim=0) + torch.cat([valid_mask1, valid_mask2], dim=0), ) # Get junction loss weight (dynamic or not) if isinstance(self.loss_weights["w_junc"], nn.Parameter): @@ -398,9 +460,9 @@ class TotalLoss(nn.Module): # Compute heatmap loss heatmap_loss = self.loss_funcs["heatmap_loss"]( - torch.cat([heatmap_pred1, heatmap_pred2], dim=0), + torch.cat([heatmap_pred1, heatmap_pred2], dim=0), torch.cat([heatmap_target1, heatmap_target2], dim=0), - torch.cat([valid_mask1, valid_mask2], dim=0) + torch.cat([valid_mask1, valid_mask2], dim=0), ) # Get heatmap loss weight (dynamic or not) if isinstance(self.loss_weights["w_heatmap"], nn.Parameter): @@ -410,8 +472,8 @@ class TotalLoss(nn.Module): # Compute the descriptor loss descriptor_loss = self.loss_funcs["descriptor_loss"]( - desc_pred1, desc_pred2, line_points1, - line_points2, line_indices, epoch) + desc_pred1, desc_pred2, line_points1, line_points2, line_indices, epoch + ) # Get descriptor loss weight (dynamic or not) if isinstance(self.loss_weights["w_desc"], nn.Parameter): w_descriptor = torch.exp(-self.loss_weights["w_desc"]) @@ -419,27 +481,27 @@ class TotalLoss(nn.Module): w_descriptor = self.loss_weights["w_desc"] # Update the total loss - total_loss = (junc_loss * w_junc - + heatmap_loss * w_heatmap - + descriptor_loss * w_descriptor) + total_loss = ( + junc_loss * w_junc + + heatmap_loss * w_heatmap + + descriptor_loss * w_descriptor + ) outputs = { "junc_loss": junc_loss, "heatmap_loss": heatmap_loss, - "w_junc": w_junc.item() \ - if 
isinstance(w_junc, nn.Parameter) else w_junc, - "w_heatmap": w_heatmap.item() \ - if isinstance(w_heatmap, nn.Parameter) else w_heatmap, + "w_junc": w_junc.item() if isinstance(w_junc, nn.Parameter) else w_junc, + "w_heatmap": w_heatmap.item() + if isinstance(w_heatmap, nn.Parameter) + else w_heatmap, "descriptor_loss": descriptor_loss, - "w_desc": w_descriptor.item() \ - if isinstance(w_descriptor, nn.Parameter) else w_descriptor + "w_desc": w_descriptor.item() + if isinstance(w_descriptor, nn.Parameter) + else w_descriptor, } - + # Compute the regularization loss reg_loss = self.loss_funcs["reg_loss"](self.loss_weights) total_loss += reg_loss - outputs.update({ - "reg_loss": reg_loss, - "total_loss": total_loss - }) + outputs.update({"reg_loss": reg_loss, "total_loss": total_loss}) return outputs diff --git a/third_party/SOLD2/sold2/model/lr_scheduler.py b/third_party/SOLD2/sold2/model/lr_scheduler.py index 3faa4f68a67564719008a932b40c16c5e908949f..fa3f5903c92a61f01eaa8aed95fb2261212f3762 100644 --- a/third_party/SOLD2/sold2/model/lr_scheduler.py +++ b/third_party/SOLD2/sold2/model/lr_scheduler.py @@ -5,18 +5,17 @@ import torch def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer): - """ Get the learning rate scheduler according to the config. """ + """Get the learning rate scheduler according to the config.""" # If no lr_decay is specified => return None if (lr_decay == False) or (lr_decay_cfg is None): schduler = None # Exponential decay elif (lr_decay == True) and (lr_decay_cfg["policy"] == "exp"): schduler = torch.optim.lr_scheduler.ExponentialLR( - optimizer, - gamma=lr_decay_cfg["gamma"] + optimizer, gamma=lr_decay_cfg["gamma"] ) # Unknown policy else: raise ValueError("[Error] Unknow learning rate decay policy!") - return schduler \ No newline at end of file + return schduler diff --git a/third_party/SOLD2/sold2/model/metrics.py b/third_party/SOLD2/sold2/model/metrics.py index 0894a7207ee4afa344cb332c605c715b14db73a4..668daaf99acb9bbb80d7ca2746926f9d79d55cf0 100644 --- a/third_party/SOLD2/sold2/model/metrics.py +++ b/third_party/SOLD2/sold2/model/metrics.py @@ -10,15 +10,26 @@ from ..misc.geometry_utils import keypoints_to_grid class Metrics(object): - """ Metric evaluation calculator. 
""" - def __init__(self, detection_thresh, prob_thresh, grid_size, - junc_metric_lst=None, heatmap_metric_lst=None, - pr_metric_lst=None, desc_metric_lst=None): + """Metric evaluation calculator.""" + + def __init__( + self, + detection_thresh, + prob_thresh, + grid_size, + junc_metric_lst=None, + heatmap_metric_lst=None, + pr_metric_lst=None, + desc_metric_lst=None, + ): # List supported metrics - self.supported_junc_metrics = ["junc_precision", "junc_precision_nms", - "junc_recall", "junc_recall_nms"] - self.supported_heatmap_metrics = ["heatmap_precision", - "heatmap_recall"] + self.supported_junc_metrics = [ + "junc_precision", + "junc_precision_nms", + "junc_recall", + "junc_recall_nms", + ] + self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"] self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"] self.supported_desc_metrics = ["matching_score"] @@ -38,14 +49,13 @@ class Metrics(object): # For the descriptors, the default None assumes no desc metric at all if desc_metric_lst is None: self.desc_metric_lst = [] - elif desc_metric_lst == 'all': + elif desc_metric_lst == "all": self.desc_metric_lst = self.supported_desc_metrics else: self.desc_metric_lst = desc_metric_lst if not self._check_metrics(): - raise ValueError( - "[Error] Some elements in the metric_lst are invalid.") + raise ValueError("[Error] Some elements in the metric_lst are invalid.") # Metric mapping table self.metric_table = { @@ -57,18 +67,29 @@ class Metrics(object): "heatmap_recall": heatmap_recall(prob_thresh), "junc_pr": junction_pr(), "junc_nms_pr": junction_pr(), - "matching_score": matching_score(grid_size) + "matching_score": matching_score(grid_size), } # Initialize the results self.metric_results = {} for key in self.metric_table.keys(): - self.metric_results[key] = 0. - - def evaluate(self, junc_pred, junc_pred_nms, junc_gt, heatmap_pred, - heatmap_gt, valid_mask, line_points1=None, line_points2=None, - desc_pred1=None, desc_pred2=None, valid_points=None): - """ Perform evaluation. """ + self.metric_results[key] = 0.0 + + def evaluate( + self, + junc_pred, + junc_pred_nms, + junc_gt, + heatmap_pred, + heatmap_gt, + valid_mask, + line_points1=None, + line_points2=None, + desc_pred1=None, + desc_pred2=None, + valid_points=None, + ): + """Perform evaluation.""" for metric in self.junc_metric_lst: # If nms metrics then use nms to compute it. if "nms" in metric: @@ -77,27 +98,31 @@ class Metrics(object): else: junc_pred_input = junc_pred self.metric_results[metric] = self.metric_table[metric]( - junc_pred_input, junc_gt, valid_mask) + junc_pred_input, junc_gt, valid_mask + ) for metric in self.heatmap_metric_lst: self.metric_results[metric] = self.metric_table[metric]( - heatmap_pred, heatmap_gt, valid_mask) + heatmap_pred, heatmap_gt, valid_mask + ) for metric in self.pr_metric_lst: if "nms" in metric: self.metric_results[metric] = self.metric_table[metric]( - junc_pred_nms, junc_gt, valid_mask) + junc_pred_nms, junc_gt, valid_mask + ) else: self.metric_results[metric] = self.metric_table[metric]( - junc_pred, junc_gt, valid_mask) + junc_pred, junc_gt, valid_mask + ) for metric in self.desc_metric_lst: self.metric_results[metric] = self.metric_table[metric]( - line_points1, line_points2, desc_pred1, - desc_pred2, valid_points) + line_points1, line_points2, desc_pred1, desc_pred2, valid_points + ) def _check_metrics(self): - """ Check if all input metrics are valid. 
""" + """Check if all input metrics are valid.""" flag = True for metric in self.junc_metric_lst: if not metric in self.supported_junc_metrics: @@ -116,19 +141,31 @@ class Metrics(object): class AverageMeter(object): - def __init__(self, junc_metric_lst=None, heatmap_metric_lst=None, - is_training=True, desc_metric_lst=None): + def __init__( + self, + junc_metric_lst=None, + heatmap_metric_lst=None, + is_training=True, + desc_metric_lst=None, + ): # List supported metrics - self.supported_junc_metrics = ["junc_precision", "junc_precision_nms", - "junc_recall", "junc_recall_nms"] - self.supported_heatmap_metrics = ["heatmap_precision", - "heatmap_recall"] + self.supported_junc_metrics = [ + "junc_precision", + "junc_precision_nms", + "junc_recall", + "junc_recall_nms", + ] + self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"] self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"] self.supported_desc_metrics = ["matching_score"] # Record loss in training mode # if is_training: self.supported_loss = [ - "junc_loss", "heatmap_loss", "descriptor_loss", "total_loss"] + "junc_loss", + "heatmap_loss", + "descriptor_loss", + "total_loss", + ] self.is_training = is_training @@ -144,21 +181,23 @@ class AverageMeter(object): # For the descriptors, the default None assumes no desc metric at all if desc_metric_lst is None: self.desc_metric_lst = [] - elif desc_metric_lst == 'all': + elif desc_metric_lst == "all": self.desc_metric_lst = self.supported_desc_metrics else: self.desc_metric_lst = desc_metric_lst if not self._check_metrics(): - raise ValueError( - "[Error] Some elements in the metric_lst are invalid.") + raise ValueError("[Error] Some elements in the metric_lst are invalid.") # Initialize the results self.metric_results = {} - for key in (self.supported_junc_metrics - + self.supported_heatmap_metrics - + self.supported_loss + self.supported_desc_metrics): - self.metric_results[key] = 0. 
+ for key in ( + self.supported_junc_metrics + + self.supported_heatmap_metrics + + self.supported_loss + + self.supported_desc_metrics + ): + self.metric_results[key] = 0.0 for key in self.supported_pr_metrics: zero_lst = [0 for _ in range(50)] self.metric_results[key] = { @@ -167,7 +206,7 @@ class AverageMeter(object): "fp": zero_lst, "fn": zero_lst, "precision": zero_lst, - "recall": zero_lst + "recall": zero_lst, } # Initialize total count @@ -176,18 +215,18 @@ class AverageMeter(object): def update(self, metrics, loss_dict=None, num_samples=1): # loss should be given in the training mode if self.is_training and (loss_dict is None): - raise ValueError( - "[Error] loss info should be given in the training mode.") + raise ValueError("[Error] loss info should be given in the training mode.") # update total counts self.count += num_samples # update all the metrics - for met in (self.supported_junc_metrics - + self.supported_heatmap_metrics - + self.supported_desc_metrics): - self.metric_results[met] += (num_samples - * metrics.metric_results[met]) + for met in ( + self.supported_junc_metrics + + self.supported_heatmap_metrics + + self.supported_desc_metrics + ): + self.metric_results[met] += num_samples * metrics.metric_results[met] # Update all the losses for loss in loss_dict.keys(): @@ -200,8 +239,8 @@ class AverageMeter(object): # Update each interval for idx in range(len(self.metric_results[pr_met][key])): self.metric_results[pr_met][key][idx] += ( - num_samples - * metrics.metric_results[pr_met][key][idx]) + num_samples * metrics.metric_results[pr_met][key][idx] + ) def average(self): results = {} @@ -217,21 +256,22 @@ class AverageMeter(object): "fp": self.metric_results[met]["fp"], "fn": self.metric_results[met]["fn"], "precision": [], - "recall": [] + "recall": [], } for idx in range(len(self.metric_results[met]["precision"])): met_results["precision"].append( - self.metric_results[met]["precision"][idx] - / self.count) + self.metric_results[met]["precision"][idx] / self.count + ) met_results["recall"].append( - self.metric_results[met]["recall"][idx] / self.count) + self.metric_results[met]["recall"][idx] / self.count + ) results[met] = met_results return results def _check_metrics(self): - """ Check if all input metrics are valid. """ + """Check if all input metrics are valid.""" flag = True for metric in self.junc_metric_lst: if not metric in self.supported_junc_metrics: @@ -250,7 +290,8 @@ class AverageMeter(object): class junction_precision(object): - """ Junction precision. """ + """Junction precision.""" + def __init__(self, detection_thresh): self.detection_thresh = detection_thresh @@ -262,8 +303,7 @@ class junction_precision(object): # Deal with the corner case of the prediction if np.sum(junc_pred) > 0: - precision = (np.sum(junc_pred * junc_gt.squeeze()) - / np.sum(junc_pred)) + precision = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_pred) else: precision = 0 @@ -271,7 +311,8 @@ class junction_precision(object): class junction_recall(object): - """ Junction recall. """ + """Junction recall.""" + def __init__(self, detection_thresh): self.detection_thresh = detection_thresh @@ -291,7 +332,8 @@ class junction_recall(object): class junction_pr(object): - """ Junction precision-recall info. 
""" + """Junction precision-recall info.""" + def __init__(self, num_threshold=50): self.max = 0.4 step = self.max / num_threshold @@ -316,12 +358,21 @@ class junction_pr(object): # Compute tp, fp, tn, fn junc_gt = junc_gt.squeeze() tp = np.sum(junc_pred * junc_gt) - tn = np.sum((junc_pred == 0).astype(np.float) - * (junc_gt == 0).astype(np.float) * valid_mask) - fp = np.sum((junc_pred == 1).astype(np.float) - * (junc_gt == 0).astype(np.float) * valid_mask) - fn = np.sum((junc_pred == 0).astype(np.float) - * (junc_gt == 1).astype(np.float) * valid_mask) + tn = np.sum( + (junc_pred == 0).astype(np.float) + * (junc_gt == 0).astype(np.float) + * valid_mask + ) + fp = np.sum( + (junc_pred == 1).astype(np.float) + * (junc_gt == 0).astype(np.float) + * valid_mask + ) + fn = np.sum( + (junc_pred == 0).astype(np.float) + * (junc_gt == 1).astype(np.float) + * valid_mask + ) tp_lst.append(tp) tn_lst.append(tn) @@ -336,12 +387,13 @@ class junction_pr(object): "fp": np.array(fp_lst), "fn": np.array(fn_lst), "precision": np.array(precision_lst), - "recall": np.array(recall_lst) + "recall": np.array(recall_lst), } class heatmap_precision(object): - """ Heatmap precision. """ + """Heatmap precision.""" + def __init__(self, prob_thresh): self.prob_thresh = prob_thresh @@ -352,16 +404,18 @@ class heatmap_precision(object): # Deal with the corner case of the prediction if np.sum(heatmap_pred) > 0: - precision = (np.sum(heatmap_pred * heatmap_gt.squeeze()) - / np.sum(heatmap_pred)) + precision = np.sum(heatmap_pred * heatmap_gt.squeeze()) / np.sum( + heatmap_pred + ) else: - precision = 0. + precision = 0.0 return precision class heatmap_recall(object): - """ Heatmap recall. """ + """Heatmap recall.""" + def __init__(self, prob_thresh): self.prob_thresh = prob_thresh @@ -372,21 +426,20 @@ class heatmap_recall(object): # Deal with the corner case of the ground truth if np.sum(heatmap_gt) > 0: - recall = (np.sum(heatmap_pred * heatmap_gt.squeeze()) - / np.sum(heatmap_gt)) + recall = np.sum(heatmap_pred * heatmap_gt.squeeze()) / np.sum(heatmap_gt) else: - recall = 0. + recall = 0.0 return recall class matching_score(object): - """ Descriptors matching score. 
""" + """Descriptors matching score.""" + def __init__(self, grid_size): self.grid_size = grid_size - def __call__(self, points1, points2, desc_pred1, - desc_pred2, line_indices): + def __call__(self, points1, points2, desc_pred1, desc_pred2, line_indices): b_size, _, Hc, Wc = desc_pred1.size() img_size = (Hc * self.grid_size, Wc * self.grid_size) device = desc_pred1.device @@ -396,32 +449,37 @@ class matching_score(object): valid_points = line_indices.bool().flatten() n_correct_points = torch.sum(valid_points).item() if n_correct_points == 0: - return torch.tensor(0., dtype=torch.float, device=device) + return torch.tensor(0.0, dtype=torch.float, device=device) # Convert the keypoints to a grid suitable for interpolation grid1 = keypoints_to_grid(points1, img_size) grid2 = keypoints_to_grid(points2, img_size) # Extract the descriptors - desc1 = F.grid_sample(desc_pred1, grid1).permute( - 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] + desc1 = ( + F.grid_sample(desc_pred1, grid1) + .permute(0, 2, 3, 1) + .reshape(b_size * n_points, -1)[valid_points] + ) desc1 = F.normalize(desc1, dim=1) - desc2 = F.grid_sample(desc_pred2, grid2).permute( - 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] + desc2 = ( + F.grid_sample(desc_pred2, grid2) + .permute(0, 2, 3, 1) + .reshape(b_size * n_points, -1)[valid_points] + ) desc2 = F.normalize(desc2, dim=1) desc_dists = 2 - 2 * (desc1 @ desc2.t()) # Compute percentage of correct matches matches0 = torch.min(desc_dists, dim=1)[1] matches1 = torch.min(desc_dists, dim=0)[1] - matching_score = (matches1[matches0] - == torch.arange(len(matches0)).to(device)) + matching_score = matches1[matches0] == torch.arange(len(matches0)).to(device) matching_score = matching_score.float().mean() return matching_score def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0): - """ Non-maximum suppression adapted from SuperPoint. """ + """Non-maximum suppression adapted from SuperPoint.""" # Iterate through batch dimension im_h = prob_predictions.shape[1] im_w = prob_predictions.shape[2] @@ -430,17 +488,19 @@ def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0): # print(i) prob_pred = prob_predictions[i, ...] 
# Filter the points using prob_thresh - coord = np.where(prob_pred >= prob_thresh) # HW format - points = np.concatenate((coord[0][..., None], coord[1][..., None]), - axis=1) # HW format + coord = np.where(prob_pred >= prob_thresh) # HW format + points = np.concatenate( + (coord[0][..., None], coord[1][..., None]), axis=1 + ) # HW format # Get the probability score prob_score = prob_pred[points[:, 0], points[:, 1]] # Perform super nms # Modify the in_points to xy format (instead of HW format) - in_points = np.concatenate((coord[1][..., None], coord[0][..., None], - prob_score), axis=1).T + in_points = np.concatenate( + (coord[1][..., None], coord[0][..., None], prob_score), axis=1 + ).T keep_points_, keep_inds = nms_fast(in_points, im_h, im_w, dist_thresh) # Remember to flip outputs back to HW format keep_points = np.round(np.flip(keep_points_[:2, :], axis=0).T) @@ -454,8 +514,9 @@ def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0): # Re-compose the probability map output_map = np.zeros([im_h, im_w]) - output_map[keep_points[:, 0].astype(np.int), - keep_points[:, 1].astype(np.int)] = keep_score.squeeze() + output_map[ + keep_points[:, 0].astype(np.int), keep_points[:, 1].astype(np.int) + ] = keep_score.squeeze() output_lst.append(output_map[None, ...]) @@ -506,14 +567,14 @@ def nms_fast(in_corners, H, W, dist_thresh): inds[rcorners[1, i], rcorners[0, i]] = i # Pad the border of the grid, so that we can NMS points near the border. pad = dist_thresh - grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant') + grid = np.pad(grid, ((pad, pad), (pad, pad)), mode="constant") # Iterate through points, highest to lowest conf, suppress neighborhood. count = 0 for i, rc in enumerate(rcorners.T): # Account for top and left padding. pt = (rc[0] + pad, rc[1] + pad) if grid[pt[1], pt[0]] == 1: # If not yet suppressed. - grid[pt[1] - pad:pt[1] + pad + 1, pt[0] - pad:pt[0] + pad + 1] = 0 + grid[pt[1] - pad : pt[1] + pad + 1, pt[0] - pad : pt[0] + pad + 1] = 0 grid[pt[1], pt[0]] = -1 count += 1 # Get all surviving -1's and return sorted array of remaining corners. diff --git a/third_party/SOLD2/sold2/model/model_util.py b/third_party/SOLD2/sold2/model/model_util.py index f70d80da40a72c207edfcfc1509e820846f0b731..037239e45d50123c7d679e36df5c6b0de314fa8b 100644 --- a/third_party/SOLD2/sold2/model/model_util.py +++ b/third_party/SOLD2/sold2/model/model_util.py @@ -9,7 +9,7 @@ from .nets.descriptor_decoder import SuperpointDescriptor def get_model(model_cfg=None, loss_weights=None, mode="train"): - """ Get model based on the model configuration. 
""" + """Get model based on the model configuration.""" # Check dataset config is given if model_cfg is None: raise ValueError("[Error] The model config is required!") @@ -18,26 +18,27 @@ def get_model(model_cfg=None, loss_weights=None, mode="train"): print("\n\n\t--------Initializing model----------") supported_arch = ["simple"] if not model_cfg["model_architecture"] in supported_arch: - raise ValueError( - "[Error] The model architecture is not in supported arch!") + raise ValueError("[Error] The model architecture is not in supported arch!") if model_cfg["model_architecture"] == "simple": model = SOLD2Net(model_cfg) else: - raise ValueError( - "[Error] The model architecture is not in supported arch!") + raise ValueError("[Error] The model architecture is not in supported arch!") # Optionally register loss weights to the model if mode == "train": if loss_weights is not None: for param_name, param in loss_weights.items(): if isinstance(param, nn.Parameter): - print("\t [Debug] Adding %s with value %f to model" - % (param_name, param.item())) + print( + "\t [Debug] Adding %s with value %f to model" + % (param_name, param.item()) + ) model.register_parameter(param_name, param) else: raise ValueError( - "[Error] the loss weights can not be None in dynamic weighting mode during training.") + "[Error] the loss weights can not be None in dynamic weighting mode during training." + ) # Display some summary info. print("\tModel architecture: %s" % model_cfg["model_architecture"]) @@ -50,7 +51,8 @@ def get_model(model_cfg=None, loss_weights=None, mode="train"): class SOLD2Net(nn.Module): - """ Full network for SOLD². """ + """Full network for SOLD².""" + def __init__(self, model_cfg): super(SOLD2Net, self).__init__() self.name = model_cfg["model_name"] @@ -65,8 +67,7 @@ class SOLD2Net(nn.Module): self.junction_decoder = self.get_junction_decoder() # List supported heatmap decoder options - self.supported_heatmap_decoder = ["pixel_shuffle", - "pixel_shuffle_single"] + self.supported_heatmap_decoder = ["pixel_shuffle", "pixel_shuffle_single"] self.heatmap_decoder = self.get_heatmap_decoder() # List supported descriptor decoder options @@ -96,10 +97,9 @@ class SOLD2Net(nn.Module): return outputs def get_backbone(self): - """ Retrieve the backbone encoder network. """ + """Retrieve the backbone encoder network.""" if not self.cfg["backbone"] in self.supported_backbone: - raise ValueError( - "[Error] The backbone selection is not supported.") + raise ValueError("[Error] The backbone selection is not supported.") # lcnn backbone (stacked hourglass) if self.cfg["backbone"] == "lcnn": @@ -113,79 +113,73 @@ class SOLD2Net(nn.Module): feat_channel = 128 else: - raise ValueError( - "[Error] The backbone selection is not supported.") + raise ValueError("[Error] The backbone selection is not supported.") return backbone, feat_channel def get_junction_decoder(self): - """ Get the junction decoder. 
""" - if (not self.cfg["junction_decoder"] - in self.supported_junction_decoder): - raise ValueError( - "[Error] The junction decoder selection is not supported.") + """Get the junction decoder.""" + if not self.cfg["junction_decoder"] in self.supported_junction_decoder: + raise ValueError("[Error] The junction decoder selection is not supported.") # superpoint decoder if self.cfg["junction_decoder"] == "superpoint_decoder": - decoder = SuperpointDecoder(self.feat_channel, - self.cfg["backbone"]) + decoder = SuperpointDecoder(self.feat_channel, self.cfg["backbone"]) else: - raise ValueError( - "[Error] The junction decoder selection is not supported.") + raise ValueError("[Error] The junction decoder selection is not supported.") return decoder def get_heatmap_decoder(self): - """ Get the heatmap decoder. """ + """Get the heatmap decoder.""" if not self.cfg["heatmap_decoder"] in self.supported_heatmap_decoder: - raise ValueError( - "[Error] The heatmap decoder selection is not supported.") + raise ValueError("[Error] The heatmap decoder selection is not supported.") # Pixel_shuffle decoder if self.cfg["heatmap_decoder"] == "pixel_shuffle": if self.cfg["backbone"] == "lcnn": - decoder = PixelShuffleDecoder(self.feat_channel, - num_upsample=2) + decoder = PixelShuffleDecoder(self.feat_channel, num_upsample=2) elif self.cfg["backbone"] == "superpoint": - decoder = PixelShuffleDecoder(self.feat_channel, - num_upsample=3) + decoder = PixelShuffleDecoder(self.feat_channel, num_upsample=3) else: raise ValueError("[Error] Unknown backbone option.") # Pixel_shuffle decoder with single channel output elif self.cfg["heatmap_decoder"] == "pixel_shuffle_single": if self.cfg["backbone"] == "lcnn": decoder = PixelShuffleDecoder( - self.feat_channel, num_upsample=2, output_channel=1) + self.feat_channel, num_upsample=2, output_channel=1 + ) elif self.cfg["backbone"] == "superpoint": decoder = PixelShuffleDecoder( - self.feat_channel, num_upsample=3, output_channel=1) + self.feat_channel, num_upsample=3, output_channel=1 + ) else: raise ValueError("[Error] Unknown backbone option.") else: - raise ValueError( - "[Error] The heatmap decoder selection is not supported.") + raise ValueError("[Error] The heatmap decoder selection is not supported.") return decoder def get_descriptor_decoder(self): - """ Get the descriptor decoder. """ - if (not self.cfg["descriptor_decoder"] - in self.supported_descriptor_decoder): + """Get the descriptor decoder.""" + if not self.cfg["descriptor_decoder"] in self.supported_descriptor_decoder: raise ValueError( - "[Error] The descriptor decoder selection is not supported.") + "[Error] The descriptor decoder selection is not supported." + ) # SuperPoint descriptor if self.cfg["descriptor_decoder"] == "superpoint_descriptor": decoder = SuperpointDescriptor(self.feat_channel) else: raise ValueError( - "[Error] The descriptor decoder selection is not supported.") + "[Error] The descriptor decoder selection is not supported." + ) return decoder def weight_init(m): - """ Weight initialization function. 
""" + """Weight initialization function.""" # Conv2D if isinstance(m, nn.Conv2d): init.xavier_normal_(m.weight.data) diff --git a/third_party/SOLD2/sold2/model/nets/backbone.py b/third_party/SOLD2/sold2/model/nets/backbone.py index 71f260aef108c77d54319cab7bc082c3c51112e7..26b5a1366223b9148bc110ec28917cc1f81b5cbf 100644 --- a/third_party/SOLD2/sold2/model/nets/backbone.py +++ b/third_party/SOLD2/sold2/model/nets/backbone.py @@ -5,49 +5,46 @@ from .lcnn_hourglass import MultitaskHead, hg class HourglassBackbone(nn.Module): - """ Hourglass backbone. """ - def __init__(self, input_channel=1, depth=4, num_stacks=2, - num_blocks=1, num_classes=5): + """Hourglass backbone.""" + + def __init__( + self, input_channel=1, depth=4, num_stacks=2, num_blocks=1, num_classes=5 + ): super(HourglassBackbone, self).__init__() self.head = MultitaskHead - self.net = hg(**{ - "head": self.head, - "depth": depth, - "num_stacks": num_stacks, - "num_blocks": num_blocks, - "num_classes": num_classes, - "input_channels": input_channel - }) + self.net = hg( + **{ + "head": self.head, + "depth": depth, + "num_stacks": num_stacks, + "num_blocks": num_blocks, + "num_classes": num_classes, + "input_channels": input_channel, + } + ) def forward(self, input_images): return self.net(input_images)[1] class SuperpointBackbone(nn.Module): - """ SuperPoint backbone. """ + """SuperPoint backbone.""" + def __init__(self): super(SuperpointBackbone, self).__init__() self.relu = torch.nn.ReLU(inplace=True) self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2) c1, c2, c3, c4 = 64, 64, 128, 128 # Shared Encoder. - self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3, - stride=1, padding=1) - self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, - stride=1, padding=1) - self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, - stride=1, padding=1) - self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, - stride=1, padding=1) - self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, - stride=1, padding=1) - self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, - stride=1, padding=1) - self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, - stride=1, padding=1) - self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, - stride=1, padding=1) - + self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) + self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) + self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) + self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) + def forward(self, input_images): # Shared Encoder. x = self.relu(self.conv1a(input_images)) diff --git a/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py b/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py index 6ed4306fad764efab2c22ede9cae253c9b17d6c2..449bac37e6b0e6ff7802c0dbcea92f4829786578 100644 --- a/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py +++ b/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py @@ -3,17 +3,18 @@ import torch.nn as nn class SuperpointDescriptor(nn.Module): - """ Descriptor decoder based on the SuperPoint arcihtecture. 
""" + """Descriptor decoder based on the SuperPoint arcihtecture.""" + def __init__(self, input_feat_dim=128): super(SuperpointDescriptor, self).__init__() self.relu = torch.nn.ReLU(inplace=True) - self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3, - stride=1, padding=1) - self.convPb = torch.nn.Conv2d(256, 128, kernel_size=1, - stride=1, padding=0) + self.convPa = torch.nn.Conv2d( + input_feat_dim, 256, kernel_size=3, stride=1, padding=1 + ) + self.convPb = torch.nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0) def forward(self, input_features): feat = self.relu(self.convPa(input_features)) semi = self.convPb(feat) - return semi \ No newline at end of file + return semi diff --git a/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py b/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py index bd5157ca740c8c7e25f2183b2a3c1fefa813deca..11828426a2852fb3e9ee3e6a3310ca89cbcd4d78 100644 --- a/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py +++ b/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py @@ -2,7 +2,8 @@ import torch.nn as nn class PixelShuffleDecoder(nn.Module): - """ Pixel shuffle decoder. """ + """Pixel shuffle decoder.""" + def __init__(self, input_feat_dim=128, num_upsample=2, output_channel=2): super(PixelShuffleDecoder, self).__init__() # Get channel parameters @@ -10,35 +11,46 @@ class PixelShuffleDecoder(nn.Module): # Define the pixel shuffle self.pixshuffle = nn.PixelShuffle(2) - + # Process the feature self.conv_block_lst = [] # The input block self.conv_block_lst.append( nn.Sequential( - nn.Conv2d(input_feat_dim, self.channel_conf[0], - kernel_size=3, stride=1, padding=1), + nn.Conv2d( + input_feat_dim, + self.channel_conf[0], + kernel_size=3, + stride=1, + padding=1, + ), nn.BatchNorm2d(self.channel_conf[0]), - nn.ReLU(inplace=True) - )) + nn.ReLU(inplace=True), + ) + ) # Intermediate block for channel in self.channel_conf[1:-1]: self.conv_block_lst.append( nn.Sequential( - nn.Conv2d(channel, channel, kernel_size=3, - stride=1, padding=1), + nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(channel), - nn.ReLU(inplace=True) - )) - + nn.ReLU(inplace=True), + ) + ) + # Output block self.conv_block_lst.append( - nn.Conv2d(self.channel_conf[-1], output_channel, - kernel_size=1, stride=1, padding=0) + nn.Conv2d( + self.channel_conf[-1], + output_channel, + kernel_size=1, + stride=1, + padding=0, + ) ) self.conv_block_lst = nn.ModuleList(self.conv_block_lst) - + # Get num of channels based on number of upsampling. def get_channel_conf(self, num_upsample): if num_upsample == 2: @@ -52,7 +64,7 @@ class PixelShuffleDecoder(nn.Module): for block in self.conv_block_lst[:-1]: out = block(out) out = self.pixshuffle(out) - + # Output layer out = self.conv_block_lst[-1](out) diff --git a/third_party/SOLD2/sold2/model/nets/junction_decoder.py b/third_party/SOLD2/sold2/model/nets/junction_decoder.py index d2bb649518896501c784940028a772d688c2b3a7..ea90a6b6821d994461dee83f85a6d2851d78e055 100644 --- a/third_party/SOLD2/sold2/model/nets/junction_decoder.py +++ b/third_party/SOLD2/sold2/model/nets/junction_decoder.py @@ -3,25 +3,27 @@ import torch.nn as nn class SuperpointDecoder(nn.Module): - """ Junction decoder based on the SuperPoint architecture. """ + """Junction decoder based on the SuperPoint architecture.""" + def __init__(self, input_feat_dim=128, backbone_name="lcnn"): super(SuperpointDecoder, self).__init__() self.relu = torch.nn.ReLU(inplace=True) # Perform strided convolution when using lcnn backbone. 
if backbone_name == "lcnn": - self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3, - stride=2, padding=1) + self.convPa = torch.nn.Conv2d( + input_feat_dim, 256, kernel_size=3, stride=2, padding=1 + ) elif backbone_name == "superpoint": - self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3, - stride=1, padding=1) + self.convPa = torch.nn.Conv2d( + input_feat_dim, 256, kernel_size=3, stride=1, padding=1 + ) else: raise ValueError("[Error] Unknown backbone option.") - - self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1, - stride=1, padding=0) + + self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0) def forward(self, input_features): feat = self.relu(self.convPa(input_features)) semi = self.convPb(feat) - return semi \ No newline at end of file + return semi diff --git a/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py b/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py index a9dc78eef34e7ee146166b1b66c10070799d63f3..c25594d9dda28624337546fd8fec27e1c59b452f 100644 --- a/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py +++ b/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py @@ -39,8 +39,7 @@ class Bottleneck2D(nn.Module): self.bn1 = nn.BatchNorm2d(inplanes) self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1) self.bn2 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=stride, padding=1) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1) self.bn3 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1) self.relu = nn.ReLU(inplace=True) @@ -116,15 +115,17 @@ class Hourglass(nn.Module): class HourglassNet(nn.Module): """Hourglass model from Newell et al ECCV 2016""" - def __init__(self, block, head, depth, num_stacks, num_blocks, - num_classes, input_channels): + def __init__( + self, block, head, depth, num_stacks, num_blocks, num_classes, input_channels + ): super(HourglassNet, self).__init__() self.inplanes = 64 self.num_feats = 128 self.num_stacks = num_stacks - self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, - stride=2, padding=3) + self.conv1 = nn.Conv2d( + input_channels, self.inplanes, kernel_size=7, stride=2, padding=3 + ) self.bn1 = nn.BatchNorm2d(self.inplanes) self.relu = nn.ReLU(inplace=True) self.layer1 = self._make_residual(block, self.inplanes, 1) @@ -215,12 +216,11 @@ class HourglassNet(nn.Module): def hg(**kwargs): model = HourglassNet( Bottleneck2D, - head=kwargs.get("head", - lambda c_in, c_out: nn.Conv2D(c_in, c_out, 1)), + head=kwargs.get("head", lambda c_in, c_out: nn.Conv2D(c_in, c_out, 1)), depth=kwargs["depth"], num_stacks=kwargs["num_stacks"], num_blocks=kwargs["num_blocks"], num_classes=kwargs["num_classes"], - input_channels=kwargs["input_channels"] + input_channels=kwargs["input_channels"], ) return model diff --git a/third_party/SOLD2/sold2/postprocess/convert_homography_results.py b/third_party/SOLD2/sold2/postprocess/convert_homography_results.py index 352eebbde00f6d8a9c20517dccd7024fd0758ffd..61045777bde0190e872c1c3983f1172ef36d8f1c 100644 --- a/third_party/SOLD2/sold2/postprocess/convert_homography_results.py +++ b/third_party/SOLD2/sold2/postprocess/convert_homography_results.py @@ -2,6 +2,7 @@ Convert the aggregation results from the homography adaptation to GT labels. 
""" import sys + sys.path.append("../") import os import yaml @@ -17,9 +18,10 @@ from model.metrics import super_nms from misc.train_utils import parse_h5_data -def convert_raw_exported_predictions(input_data, grid_size=8, - detect_thresh=1/65, topk=300): - """ Convert the exported junctions and heatmaps predictions +def convert_raw_exported_predictions( + input_data, grid_size=8, detect_thresh=1 / 65, topk=300 +): + """Convert the exported junctions and heatmaps predictions to a standard format. Arguments: input_data: the raw data (dict) decoded from the hdf5 dataset @@ -31,28 +33,29 @@ def convert_raw_exported_predictions(input_data, grid_size=8, # Check the input_data is from (1) single prediction, # or (2) homography adaptation. # Homography adaptation raw predictions - if (("junc_prob_mean" in input_data.keys()) - and ("heatmap_prob_mean" in input_data.keys())): + if ("junc_prob_mean" in input_data.keys()) and ( + "heatmap_prob_mean" in input_data.keys() + ): # Get the junction predictions and convert if to Nx2 format junc_prob = input_data["junc_prob_mean"] junc_pred_np = junc_prob[None, ...] - junc_pred_np_nms = super_nms(junc_pred_np, grid_size, - detect_thresh, topk) + junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk) junctions = np.where(junc_pred_np_nms.squeeze()) - junc_points_pred = np.concatenate([junctions[0][..., None], - junctions[1][..., None]], axis=-1) + junc_points_pred = np.concatenate( + [junctions[0][..., None], junctions[1][..., None]], axis=-1 + ) # Get the heatmap predictions heatmap_pred = input_data["heatmap_prob_mean"].squeeze() valid_mask = np.ones(heatmap_pred.shape, dtype=np.int32) - + # Single predictions else: # Get the junction point predictions and convert to Nx2 format junc_points_pred = np.where(input_data["junc_pred_nms"]) junc_points_pred = np.concatenate( - [junc_points_pred[0][..., None], - junc_points_pred[1][..., None]], axis=-1) + [junc_points_pred[0][..., None], junc_points_pred[1][..., None]], axis=-1 + ) # Get the heatmap predictions heatmap_pred = input_data["heatmap_pred"] @@ -61,34 +64,29 @@ def convert_raw_exported_predictions(input_data, grid_size=8, return { "junctions_pred": junc_points_pred, "heatmap_pred": heatmap_pred, - "valid_mask": valid_mask + "valid_mask": valid_mask, } if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("input_dataset", type=str, - help="Name of the exported dataset.") - parser.add_argument("output_dataset", type=str, - help="Name of the output dataset.") - parser.add_argument("config", type=str, - help="Path to the model config.") - args = parser.parse_args() - + parser.add_argument("input_dataset", type=str, help="Name of the exported dataset.") + parser.add_argument("output_dataset", type=str, help="Name of the output dataset.") + parser.add_argument("config", type=str, help="Path to the model config.") + args = parser.parse_args() + # Define the path to the input exported dataset - exported_dataset_path = os.path.join(cfg.export_dataroot, - args.input_dataset) + exported_dataset_path = os.path.join(cfg.export_dataroot, args.input_dataset) if not os.path.exists(exported_dataset_path): raise ValueError("Missing input dataset: " + exported_dataset_path) exported_dataset = h5py.File(exported_dataset_path, "r") # Define the output path for the results - output_dataset_path = os.path.join(cfg.export_dataroot, - args.output_dataset) + output_dataset_path = os.path.join(cfg.export_dataroot, args.output_dataset) device = torch.device("cuda") nms_device = 
torch.device("cuda") - + # Read the config file if not os.path.exists(args.config): raise ValueError("Missing config file: " + args.config) @@ -96,41 +94,43 @@ if __name__ == "__main__": config = yaml.safe_load(f) model_cfg = config["model_cfg"] line_detector_cfg = config["line_detector_cfg"] - + # Initialize the line detection module line_detector = LineSegmentDetectionModule(**line_detector_cfg) # Iterate through all the dataset keys with h5py.File(output_dataset_path, "w") as output_dataset: - for idx, output_key in enumerate(tqdm(list(exported_dataset.keys()), - ascii=True)): + for idx, output_key in enumerate( + tqdm(list(exported_dataset.keys()), ascii=True) + ): # Get the data data = parse_h5_data(exported_dataset[output_key]) # Preprocess the data converted_data = convert_raw_exported_predictions( - data, grid_size=model_cfg["grid_size"], - detect_thresh=model_cfg["detection_thresh"]) + data, + grid_size=model_cfg["grid_size"], + detect_thresh=model_cfg["detection_thresh"], + ) junctions_pred_raw = converted_data["junctions_pred"] heatmap_pred = converted_data["heatmap_pred"] valid_mask = converted_data["valid_mask"] line_map_pred, junctions_pred, heatmap_pred = line_detector.detect( - junctions_pred_raw, heatmap_pred, device=device) + junctions_pred_raw, heatmap_pred, device=device + ) if isinstance(line_map_pred, torch.Tensor): line_map_pred = line_map_pred.cpu().numpy() if isinstance(junctions_pred, torch.Tensor): junctions_pred = junctions_pred.cpu().numpy() if isinstance(heatmap_pred, torch.Tensor): heatmap_pred = heatmap_pred.cpu().numpy() - - output_data = {"junctions": junctions_pred, - "line_map": line_map_pred} + + output_data = {"junctions": junctions_pred, "line_map": line_map_pred} # Record it to the h5 dataset f_group = output_dataset.create_group(output_key) # Store data for key, output_data in output_data.items(): - f_group.create_dataset(key, data=output_data, - compression="gzip") + f_group.create_dataset(key, data=output_data, compression="gzip") diff --git a/third_party/SOLD2/sold2/train.py b/third_party/SOLD2/sold2/train.py index 2064e00e6d192f9202f011c3626d6f53c4fe6270..148c9b23464d975f1efc03ea459c82d4a0759b05 100644 --- a/third_party/SOLD2/sold2/train.py +++ b/third_party/SOLD2/sold2/train.py @@ -15,12 +15,15 @@ from .model.model_util import get_model from .model.loss import TotalLoss, get_loss_and_weights from .model.metrics import AverageMeter, Metrics, super_nms from .model.lr_scheduler import get_lr_scheduler -from .misc.train_utils import (convert_image, get_latest_checkpoint, - remove_old_checkpoints) +from .misc.train_utils import ( + convert_image, + get_latest_checkpoint, + remove_old_checkpoints, +) def customized_collate_fn(batch): - """ Customized collate_fn. """ + """Customized collate_fn.""" batch_keys = ["image", "junction_map", "heatmap", "valid_mask"] list_keys = ["junctions", "line_map"] @@ -34,14 +37,14 @@ def customized_collate_fn(batch): def restore_weights(model, state_dict, strict=True): - """ Restore weights in compatible mode. 
""" + """Restore weights in compatible mode.""" # Try to directly load state dict try: model.load_state_dict(state_dict, strict=strict) # Deal with some version compatibility issue (catch version incompatible) except: err = model.load_state_dict(state_dict, strict=False) - + # missing keys are those in model but not in state_dict missing_keys = err.missing_keys # Unexpected keys are those in state_dict but not in model @@ -53,12 +56,12 @@ def restore_weights(model, state_dict, strict=True): dict_keys = [_ for _ in unexpected_keys if not "tracked" in _] model_dict[key] = state_dict[dict_keys[idx]] model.load_state_dict(model_dict) - + return model def train_net(args, dataset_cfg, model_cfg, output_path): - """ Main training function. """ + """Main training function.""" # Add some version compatibility check if model_cfg.get("weighting_policy") is None: # Default to static @@ -74,44 +77,50 @@ def train_net(args, dataset_cfg, model_cfg, output_path): test_dataset, test_collate_fn = get_dataset("test", dataset_cfg) # Create the dataloader - train_loader = DataLoader(train_dataset, - batch_size=train_cfg["batch_size"], - num_workers=8, - shuffle=True, pin_memory=True, - collate_fn=train_collate_fn) - test_loader = DataLoader(test_dataset, - batch_size=test_cfg.get("batch_size", 1), - num_workers=test_cfg.get("num_workers", 1), - shuffle=False, pin_memory=False, - collate_fn=test_collate_fn) + train_loader = DataLoader( + train_dataset, + batch_size=train_cfg["batch_size"], + num_workers=8, + shuffle=True, + pin_memory=True, + collate_fn=train_collate_fn, + ) + test_loader = DataLoader( + test_dataset, + batch_size=test_cfg.get("batch_size", 1), + num_workers=test_cfg.get("num_workers", 1), + shuffle=False, + pin_memory=False, + collate_fn=test_collate_fn, + ) print("\t Successfully intialized dataloaders.") - # Get the loss function and weight first loss_funcs, loss_weights = get_loss_and_weights(model_cfg) # If resume. if args.resume: # Create model and load the state dict - checkpoint = get_latest_checkpoint(args.resume_path, - args.checkpoint_name) + checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name) model = get_model(model_cfg, loss_weights) model = restore_weights(model, checkpoint["model_state_dict"]) model = model.cuda() optimizer = torch.optim.Adam( - [{"params": model.parameters(), - "initial_lr": model_cfg["learning_rate"]}], - model_cfg["learning_rate"], - amsgrad=True) + [{"params": model.parameters(), "initial_lr": model_cfg["learning_rate"]}], + model_cfg["learning_rate"], + amsgrad=True, + ) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) # Optionally get the learning rate scheduler scheduler = get_lr_scheduler( lr_decay=model_cfg.get("lr_decay", False), lr_decay_cfg=model_cfg.get("lr_decay_cfg", None), - optimizer=optimizer) + optimizer=optimizer, + ) # If we start to use learning rate scheduler from the middle - if ((scheduler is not None) - and (checkpoint.get("scheduler_state_dict", None) is not None)): + if (scheduler is not None) and ( + checkpoint.get("scheduler_state_dict", None) is not None + ): scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) start_epoch = checkpoint["epoch"] + 1 # Initialize all the components. 
@@ -121,40 +130,45 @@ def train_net(args, dataset_cfg, model_cfg, output_path): # Optionally get the pretrained wieghts if args.pretrained: print("\t [Debug] Loading pretrained weights...") - checkpoint = get_latest_checkpoint(args.pretrained_path, - args.checkpoint_name) + checkpoint = get_latest_checkpoint( + args.pretrained_path, args.checkpoint_name + ) # If auto weighting restore from non-auto weighting - model = restore_weights(model, checkpoint["model_state_dict"], - strict=False) + model = restore_weights(model, checkpoint["model_state_dict"], strict=False) print("\t [Debug] Finished loading pretrained weights!") - + model = model.cuda() optimizer = torch.optim.Adam( - [{"params": model.parameters(), - "initial_lr": model_cfg["learning_rate"]}], - model_cfg["learning_rate"], - amsgrad=True) + [{"params": model.parameters(), "initial_lr": model_cfg["learning_rate"]}], + model_cfg["learning_rate"], + amsgrad=True, + ) # Optionally get the learning rate scheduler scheduler = get_lr_scheduler( lr_decay=model_cfg.get("lr_decay", False), lr_decay_cfg=model_cfg.get("lr_decay_cfg", None), - optimizer=optimizer) + optimizer=optimizer, + ) start_epoch = 0 - + print("\t Successfully initialized model") # Define the total loss policy = model_cfg.get("weighting_policy", "static") loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda() if "descriptor_decoder" in model_cfg: - metric_func = Metrics(model_cfg["detection_thresh"], - model_cfg["prob_thresh"], - model_cfg["descriptor_loss_cfg"]["grid_size"], - desc_metric_lst='all') + metric_func = Metrics( + model_cfg["detection_thresh"], + model_cfg["prob_thresh"], + model_cfg["descriptor_loss_cfg"]["grid_size"], + desc_metric_lst="all", + ) else: - metric_func = Metrics(model_cfg["detection_thresh"], - model_cfg["prob_thresh"], - model_cfg["grid_size"]) + metric_func = Metrics( + model_cfg["detection_thresh"], + model_cfg["prob_thresh"], + model_cfg["grid_size"], + ) # Define the summary writer logdir = os.path.join(output_path, "log") @@ -176,7 +190,8 @@ def train_net(args, dataset_cfg, model_cfg, output_path): metric_func=metric_func, train_loader=train_loader, writer=writer, - epoch=epoch) + epoch=epoch, + ) # Do the validation print("\n\n================== Validation ==================") @@ -187,21 +202,22 @@ def train_net(args, dataset_cfg, model_cfg, output_path): metric_func=metric_func, val_loader=test_loader, writer=writer, - epoch=epoch) + epoch=epoch, + ) # Update the scheduler if scheduler is not None: scheduler.step() # Save checkpoints - file_name = os.path.join(output_path, - "checkpoint-epoch%03d-end.tar"%(epoch)) + file_name = os.path.join(output_path, "checkpoint-epoch%03d-end.tar" % (epoch)) print("[Info] Saving checkpoint %s ..." % file_name) save_dict = { "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), - "model_cfg": model_cfg} + "model_cfg": model_cfg, + } if scheduler is not None: save_dict.update({"scheduler_state_dict": scheduler.state_dict()}) torch.save(save_dict, file_name) @@ -210,16 +226,17 @@ def train_net(args, dataset_cfg, model_cfg, output_path): remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15)) -def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func, - train_loader, writer, epoch): - """ Train for one epoch. 
""" +def train_single_epoch( + model, model_cfg, optimizer, loss_func, metric_func, train_loader, writer, epoch +): + """Train for one epoch.""" # Switch the model to training mode model.train() # Initialize the average meter compute_descriptors = loss_func.compute_descriptors if compute_descriptors: - average_meter = AverageMeter(is_training=True, desc_metric_lst='all') + average_meter = AverageMeter(is_training=True, desc_metric_lst="all") else: average_meter = AverageMeter(is_training=True) @@ -244,11 +261,23 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func, # Compute losses losses = loss_func.forward_descriptors( - outputs["junctions"], outputs2["junctions"], - junc_map, junc_map2, outputs["heatmap"], outputs2["heatmap"], - heatmap, heatmap2, line_points, line_points2, - line_indices, outputs['descriptors'], outputs2['descriptors'], - epoch, valid_mask, valid_mask2) + outputs["junctions"], + outputs2["junctions"], + junc_map, + junc_map2, + outputs["heatmap"], + outputs2["heatmap"], + heatmap, + heatmap2, + line_points, + line_points2, + line_indices, + outputs["descriptors"], + outputs2["descriptors"], + epoch, + valid_mask, + valid_mask2, + ) else: junc_map = data["junction_map"].cuda() heatmap = data["heatmap"].cuda() @@ -260,58 +289,74 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func, # Compute losses losses = loss_func( - outputs["junctions"], junc_map, - outputs["heatmap"], heatmap, - valid_mask) - + outputs["junctions"], junc_map, outputs["heatmap"], heatmap, valid_mask + ) + total_loss = losses["total_loss"] # Update the model optimizer.zero_grad() - total_loss.backward() + total_loss.backward() optimizer.step() # Compute the global step global_step = epoch * len(train_loader) + idx ############## Measure the metric error ######################### # Only do this when needed - if (((idx % model_cfg["disp_freq"]) == 0) - or ((idx % model_cfg["summary_freq"]) == 0)): + if ((idx % model_cfg["disp_freq"]) == 0) or ( + (idx % model_cfg["summary_freq"]) == 0 + ): junc_np = convert_junc_predictions( - outputs["junctions"], model_cfg["grid_size"], - model_cfg["detection_thresh"], 300) + outputs["junctions"], + model_cfg["grid_size"], + model_cfg["detection_thresh"], + 300, + ) junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1) # Always fetch only one channel (compatible with L1, L2, and CE) if outputs["heatmap"].shape[1] == 2: - heatmap_np = softmax(outputs["heatmap"].detach(), - dim=1).cpu().numpy() + heatmap_np = softmax(outputs["heatmap"].detach(), dim=1).cpu().numpy() heatmap_np = heatmap_np.transpose(0, 2, 3, 1)[:, :, :, 1:] else: heatmap_np = torch.sigmoid(outputs["heatmap"].detach()) heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1) - + heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1) valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1) # Evaluate metric results if compute_descriptors: metric_func.evaluate( - junc_np["junc_pred"], junc_np["junc_pred_nms"], - junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np, - line_points, line_points2, outputs["descriptors"], - outputs2["descriptors"], line_indices) + junc_np["junc_pred"], + junc_np["junc_pred_nms"], + junc_map_np, + heatmap_np, + heatmap_gt_np, + valid_mask_np, + line_points, + line_points2, + outputs["descriptors"], + outputs2["descriptors"], + line_indices, + ) else: metric_func.evaluate( - junc_np["junc_pred"], junc_np["junc_pred_nms"], - junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np) + junc_np["junc_pred"], + 
junc_np["junc_pred_nms"], + junc_map_np, + heatmap_np, + heatmap_gt_np, + valid_mask_np, + ) # Update average meter junc_loss = losses["junc_loss"].item() heatmap_loss = losses["heatmap_loss"].item() loss_dict = { "junc_loss": junc_loss, "heatmap_loss": heatmap_loss, - "total_loss": total_loss.item()} + "total_loss": total_loss.item(), + } if compute_descriptors: descriptor_loss = losses["descriptor_loss"].item() loss_dict["descriptor_loss"] = losses["descriptor_loss"].item() @@ -323,34 +368,75 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func, results = metric_func.metric_results average = average_meter.average() # Get gpu memory usage in GB - gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) + gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3) if compute_descriptors: - print("Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB" - % (epoch, model_cfg["epochs"], idx, len(train_loader), - total_loss.item(), average["total_loss"], junc_loss, - average["junc_loss"], heatmap_loss, - average["heatmap_loss"], descriptor_loss, - average["descriptor_loss"], gpu_mem_usage)) + print( + "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB" + % ( + epoch, + model_cfg["epochs"], + idx, + len(train_loader), + total_loss.item(), + average["total_loss"], + junc_loss, + average["junc_loss"], + heatmap_loss, + average["heatmap_loss"], + descriptor_loss, + average["descriptor_loss"], + gpu_mem_usage, + ) + ) else: - print("Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB" - % (epoch, model_cfg["epochs"], idx, len(train_loader), - total_loss.item(), average["total_loss"], - junc_loss, average["junc_loss"], heatmap_loss, - average["heatmap_loss"], gpu_mem_usage)) - print("\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)" - % (results["junc_precision"], average["junc_precision"], - results["junc_recall"], average["junc_recall"])) - print("\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)" - % (results["junc_precision_nms"], - average["junc_precision_nms"], - results["junc_recall_nms"], average["junc_recall_nms"])) - print("\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)" - %(results["heatmap_precision"], + print( + "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB" + % ( + epoch, + model_cfg["epochs"], + idx, + len(train_loader), + total_loss.item(), + average["total_loss"], + junc_loss, + average["junc_loss"], + heatmap_loss, + average["heatmap_loss"], + gpu_mem_usage, + ) + ) + print( + "\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % ( + results["junc_precision"], + average["junc_precision"], + results["junc_recall"], + average["junc_recall"], + ) + ) + print( + "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % ( + results["junc_precision_nms"], + average["junc_precision_nms"], + results["junc_recall_nms"], + average["junc_recall_nms"], + ) + ) + print( + "\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % ( + results["heatmap_precision"], average["heatmap_precision"], - results["heatmap_recall"], average["heatmap_recall"])) + results["heatmap_recall"], + average["heatmap_recall"], + ) + ) if compute_descriptors: - print("\t Descriptors matching score=%.4f (%.4f)" - %(results["matching_score"], 
average["matching_score"])) + print( + "\t Descriptors matching score=%.4f (%.4f)" + % (results["matching_score"], average["matching_score"]) + ) # Record summaries if (idx % model_cfg["summary_freq"]) == 0: @@ -362,7 +448,8 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func, "heatmap_loss": heatmap_loss, "total_loss": total_loss.detach().cpu().numpy(), "metrics": results, - "average": average} + "average": average, + } # Add descriptor terms if compute_descriptors: scalar_summaries["descriptor_loss"] = descriptor_loss @@ -374,10 +461,13 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func, scalar_summaries["reg_loss"] = losses["reg_loss"].item() num_images = 3 - junc_pred_binary = (junc_np["junc_pred"][:num_images, ...] - > model_cfg["detection_thresh"]) - junc_pred_nms_binary = (junc_np["junc_pred_nms"][:num_images, ...] - > model_cfg["detection_thresh"]) + junc_pred_binary = ( + junc_np["junc_pred"][:num_images, ...] > model_cfg["detection_thresh"] + ) + junc_pred_nms_binary = ( + junc_np["junc_pred_nms"][:num_images, ...] + > model_cfg["detection_thresh"] + ) image_summaries = { "image": input_images.cpu().numpy()[:num_images, ...], "valid_mask": valid_mask_np[:num_images, ...], @@ -386,22 +476,23 @@ def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func, "junc_map_gt": junc_map_np[:num_images, ...], "junc_prob_map": junc_np["junc_prob"][:num_images, ...], "heatmap_pred": heatmap_np[:num_images, ...], - "heatmap_gt": heatmap_gt_np[:num_images, ...]} + "heatmap_gt": heatmap_gt_np[:num_images, ...], + } # Record the training summary record_train_summaries( - writer, global_step, scalars=scalar_summaries, - images=image_summaries) + writer, global_step, scalars=scalar_summaries, images=image_summaries + ) def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch): - """ Validation. 
""" + """Validation.""" # Switch the model to eval mode model.eval() # Initialize the average meter compute_descriptors = loss_func.compute_descriptors if compute_descriptors: - average_meter = AverageMeter(is_training=True, desc_metric_lst='all') + average_meter = AverageMeter(is_training=True, desc_metric_lst="all") else: average_meter = AverageMeter(is_training=True) @@ -427,11 +518,23 @@ def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch # Compute losses losses = loss_func.forward_descriptors( - outputs["junctions"], outputs2["junctions"], - junc_map, junc_map2, outputs["heatmap"], - outputs2["heatmap"], heatmap, heatmap2, line_points, - line_points2, line_indices, outputs['descriptors'], - outputs2['descriptors'], epoch, valid_mask, valid_mask2) + outputs["junctions"], + outputs2["junctions"], + junc_map, + junc_map2, + outputs["heatmap"], + outputs2["heatmap"], + heatmap, + heatmap2, + line_points, + line_points2, + line_indices, + outputs["descriptors"], + outputs2["descriptors"], + epoch, + valid_mask, + valid_mask2, + ) else: junc_map = data["junction_map"].cuda() heatmap = data["heatmap"].cuda() @@ -444,47 +547,70 @@ def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch # Compute losses losses = loss_func( - outputs["junctions"], junc_map, - outputs["heatmap"], heatmap, - valid_mask) + outputs["junctions"], + junc_map, + outputs["heatmap"], + heatmap, + valid_mask, + ) total_loss = losses["total_loss"] ############## Measure the metric error ######################### junc_np = convert_junc_predictions( - outputs["junctions"], model_cfg["grid_size"], - model_cfg["detection_thresh"], 300) + outputs["junctions"], + model_cfg["grid_size"], + model_cfg["detection_thresh"], + 300, + ) junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1) # Always fetch only one channel (compatible with L1, L2, and CE) if outputs["heatmap"].shape[1] == 2: - heatmap_np = softmax(outputs["heatmap"].detach(), - dim=1).cpu().numpy().transpose(0, 2, 3, 1) + heatmap_np = ( + softmax(outputs["heatmap"].detach(), dim=1) + .cpu() + .numpy() + .transpose(0, 2, 3, 1) + ) heatmap_np = heatmap_np[:, :, :, 1:] else: heatmap_np = torch.sigmoid(outputs["heatmap"].detach()) heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1) - heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1) valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1) # Evaluate metric results if compute_descriptors: metric_func.evaluate( - junc_np["junc_pred"], junc_np["junc_pred_nms"], - junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np, - line_points, line_points2, outputs["descriptors"], - outputs2["descriptors"], line_indices) + junc_np["junc_pred"], + junc_np["junc_pred_nms"], + junc_map_np, + heatmap_np, + heatmap_gt_np, + valid_mask_np, + line_points, + line_points2, + outputs["descriptors"], + outputs2["descriptors"], + line_indices, + ) else: metric_func.evaluate( - junc_np["junc_pred"], junc_np["junc_pred_nms"], junc_map_np, - heatmap_np, heatmap_gt_np, valid_mask_np) + junc_np["junc_pred"], + junc_np["junc_pred_nms"], + junc_map_np, + heatmap_np, + heatmap_gt_np, + valid_mask_np, + ) # Update average meter junc_loss = losses["junc_loss"].item() heatmap_loss = losses["heatmap_loss"].item() loss_dict = { "junc_loss": junc_loss, "heatmap_loss": heatmap_loss, - "total_loss": total_loss.item()} + "total_loss": total_loss.item(), + } if compute_descriptors: descriptor_loss = losses["descriptor_loss"].item() loss_dict["descriptor_loss"] = 
losses["descriptor_loss"].item() @@ -495,32 +621,67 @@ def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch results = metric_func.metric_results average = average_meter.average() if compute_descriptors: - print("Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)" - % (idx, len(val_loader), - total_loss.item(), average["total_loss"], - junc_loss, average["junc_loss"], - heatmap_loss, average["heatmap_loss"], - descriptor_loss, average["descriptor_loss"])) + print( + "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)" + % ( + idx, + len(val_loader), + total_loss.item(), + average["total_loss"], + junc_loss, + average["junc_loss"], + heatmap_loss, + average["heatmap_loss"], + descriptor_loss, + average["descriptor_loss"], + ) + ) else: - print("Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)" - % (idx, len(val_loader), - total_loss.item(), average["total_loss"], - junc_loss, average["junc_loss"], - heatmap_loss, average["heatmap_loss"])) - print("\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)" - % (results["junc_precision"], average["junc_precision"], - results["junc_recall"], average["junc_recall"])) - print("\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)" - % (results["junc_precision_nms"], - average["junc_precision_nms"], - results["junc_recall_nms"], average["junc_recall_nms"])) - print("\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)" - % (results["heatmap_precision"], - average["heatmap_precision"], - results["heatmap_recall"], average["heatmap_recall"])) + print( + "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)" + % ( + idx, + len(val_loader), + total_loss.item(), + average["total_loss"], + junc_loss, + average["junc_loss"], + heatmap_loss, + average["heatmap_loss"], + ) + ) + print( + "\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % ( + results["junc_precision"], + average["junc_precision"], + results["junc_recall"], + average["junc_recall"], + ) + ) + print( + "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % ( + results["junc_precision_nms"], + average["junc_precision_nms"], + results["junc_recall_nms"], + average["junc_recall_nms"], + ) + ) + print( + "\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % ( + results["heatmap_precision"], + average["heatmap_precision"], + results["heatmap_recall"], + average["heatmap_recall"], + ) + ) if compute_descriptors: - print("\t Descriptors matching score=%.4f (%.4f)" - %(results["matching_score"], average["matching_score"])) + print( + "\t Descriptors matching score=%.4f (%.4f)" + % (results["matching_score"], average["matching_score"]) + ) # Record summaries average = average_meter.average() @@ -529,143 +690,182 @@ def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch record_test_summaries(writer, epoch, scalar_summaries) -def convert_junc_predictions(predictions, grid_size, - detect_thresh=1/65, topk=300): - """ Convert torch predictions to numpy arrays for evaluation. 
""" +def convert_junc_predictions(predictions, grid_size, detect_thresh=1 / 65, topk=300): + """Convert torch predictions to numpy arrays for evaluation.""" # Convert to probability outputs first junc_prob = softmax(predictions.detach(), dim=1).cpu() junc_pred = junc_prob[:, :-1, :, :] junc_prob_np = junc_prob.numpy().transpose(0, 2, 3, 1)[:, :, :, :-1] junc_prob_np = np.sum(junc_prob_np, axis=-1) - junc_pred_np = pixel_shuffle( - junc_pred, grid_size).cpu().numpy().transpose(0, 2, 3, 1) + junc_pred_np = ( + pixel_shuffle(junc_pred, grid_size).cpu().numpy().transpose(0, 2, 3, 1) + ) junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk) junc_pred_np = junc_pred_np.squeeze(-1) - return {"junc_pred": junc_pred_np, "junc_pred_nms": junc_pred_np_nms, - "junc_prob": junc_prob_np} + return { + "junc_pred": junc_pred_np, + "junc_pred_nms": junc_pred_np_nms, + "junc_prob": junc_prob_np, + } def record_train_summaries(writer, global_step, scalars, images): - """ Record training summaries. """ + """Record training summaries.""" # Record the scalar summaries results = scalars["metrics"] average = scalars["average"] # GPU memory part # Get gpu memory usage in GB - gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) + gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3) writer.add_scalar("GPU/GPU_memory_usage", gpu_mem_usage, global_step) # Loss part - writer.add_scalar("Train_loss/junc_loss", scalars["junc_loss"], - global_step) - writer.add_scalar("Train_loss/heatmap_loss", scalars["heatmap_loss"], - global_step) - writer.add_scalar("Train_loss/total_loss", scalars["total_loss"], - global_step) + writer.add_scalar("Train_loss/junc_loss", scalars["junc_loss"], global_step) + writer.add_scalar("Train_loss/heatmap_loss", scalars["heatmap_loss"], global_step) + writer.add_scalar("Train_loss/total_loss", scalars["total_loss"], global_step) # Add regularization loss if "reg_loss" in scalars.keys(): - writer.add_scalar("Train_loss/reg_loss", scalars["reg_loss"], - global_step) + writer.add_scalar("Train_loss/reg_loss", scalars["reg_loss"], global_step) # Add descriptor loss if "descriptor_loss" in scalars.keys(): key = "descriptor_loss" - writer.add_scalar("Train_loss/%s"%(key), scalars[key], global_step) - writer.add_scalar("Train_loss_average/%s"%(key), average[key], - global_step) - + writer.add_scalar("Train_loss/%s" % (key), scalars[key], global_step) + writer.add_scalar("Train_loss_average/%s" % (key), average[key], global_step) + # Record weighting for key in scalars.keys(): if "w_" in key: - writer.add_scalar("Train_weight/%s"%(key), scalars[key], - global_step) - + writer.add_scalar("Train_weight/%s" % (key), scalars[key], global_step) + # Smoothed loss - writer.add_scalar("Train_loss_average/junc_loss", average["junc_loss"], - global_step) - writer.add_scalar("Train_loss_average/heatmap_loss", - average["heatmap_loss"], global_step) - writer.add_scalar("Train_loss_average/total_loss", average["total_loss"], - global_step) + writer.add_scalar("Train_loss_average/junc_loss", average["junc_loss"], global_step) + writer.add_scalar( + "Train_loss_average/heatmap_loss", average["heatmap_loss"], global_step + ) + writer.add_scalar( + "Train_loss_average/total_loss", average["total_loss"], global_step + ) # Add smoothed descriptor loss if "descriptor_loss" in average.keys(): - writer.add_scalar("Train_loss_average/descriptor_loss", - average["descriptor_loss"], global_step) + writer.add_scalar( + "Train_loss_average/descriptor_loss", + average["descriptor_loss"], + 
global_step, + ) # Metrics part - writer.add_scalar("Train_metrics/junc_precision", - results["junc_precision"], global_step) - writer.add_scalar("Train_metrics/junc_precision_nms", - results["junc_precision_nms"], global_step) - writer.add_scalar("Train_metrics/junc_recall", - results["junc_recall"], global_step) - writer.add_scalar("Train_metrics/junc_recall_nms", - results["junc_recall_nms"], global_step) - writer.add_scalar("Train_metrics/heatmap_precision", - results["heatmap_precision"], global_step) - writer.add_scalar("Train_metrics/heatmap_recall", - results["heatmap_recall"], global_step) + writer.add_scalar( + "Train_metrics/junc_precision", results["junc_precision"], global_step + ) + writer.add_scalar( + "Train_metrics/junc_precision_nms", results["junc_precision_nms"], global_step + ) + writer.add_scalar("Train_metrics/junc_recall", results["junc_recall"], global_step) + writer.add_scalar( + "Train_metrics/junc_recall_nms", results["junc_recall_nms"], global_step + ) + writer.add_scalar( + "Train_metrics/heatmap_precision", results["heatmap_precision"], global_step + ) + writer.add_scalar( + "Train_metrics/heatmap_recall", results["heatmap_recall"], global_step + ) # Add descriptor metric if "matching_score" in results.keys(): - writer.add_scalar("Train_metrics/matching_score", - results["matching_score"], global_step) + writer.add_scalar( + "Train_metrics/matching_score", results["matching_score"], global_step + ) # Average part - writer.add_scalar("Train_metrics_average/junc_precision", - average["junc_precision"], global_step) - writer.add_scalar("Train_metrics_average/junc_precision_nms", - average["junc_precision_nms"], global_step) - writer.add_scalar("Train_metrics_average/junc_recall", - average["junc_recall"], global_step) - writer.add_scalar("Train_metrics_average/junc_recall_nms", - average["junc_recall_nms"], global_step) - writer.add_scalar("Train_metrics_average/heatmap_precision", - average["heatmap_precision"], global_step) - writer.add_scalar("Train_metrics_average/heatmap_recall", - average["heatmap_recall"], global_step) + writer.add_scalar( + "Train_metrics_average/junc_precision", average["junc_precision"], global_step + ) + writer.add_scalar( + "Train_metrics_average/junc_precision_nms", + average["junc_precision_nms"], + global_step, + ) + writer.add_scalar( + "Train_metrics_average/junc_recall", average["junc_recall"], global_step + ) + writer.add_scalar( + "Train_metrics_average/junc_recall_nms", average["junc_recall_nms"], global_step + ) + writer.add_scalar( + "Train_metrics_average/heatmap_precision", + average["heatmap_precision"], + global_step, + ) + writer.add_scalar( + "Train_metrics_average/heatmap_recall", average["heatmap_recall"], global_step + ) # Add smoothed descriptor metric if "matching_score" in average.keys(): - writer.add_scalar("Train_metrics_average/matching_score", - average["matching_score"], global_step) + writer.add_scalar( + "Train_metrics_average/matching_score", + average["matching_score"], + global_step, + ) # Record the image summary # Image part image_tensor = convert_image(images["image"], 1) valid_masks = convert_image(images["valid_mask"], -1) - writer.add_images("Train/images", image_tensor, global_step, - dataformats="NCHW") - writer.add_images("Train/valid_map", valid_masks, global_step, - dataformats="NHWC") + writer.add_images("Train/images", image_tensor, global_step, dataformats="NCHW") + writer.add_images("Train/valid_map", valid_masks, global_step, dataformats="NHWC") # Heatmap part - 
writer.add_images("Train/heatmap_gt", - convert_image(images["heatmap_gt"], -1), global_step, - dataformats="NHWC") - writer.add_images("Train/heatmap_pred", - convert_image(images["heatmap_pred"], -1), global_step, - dataformats="NHWC") + writer.add_images( + "Train/heatmap_gt", + convert_image(images["heatmap_gt"], -1), + global_step, + dataformats="NHWC", + ) + writer.add_images( + "Train/heatmap_pred", + convert_image(images["heatmap_pred"], -1), + global_step, + dataformats="NHWC", + ) # Junction prediction part junc_plots = plot_junction_detection( - image_tensor, images["junc_map_pred"], - images["junc_map_pred_nms"], images["junc_map_gt"]) - writer.add_images("Train/junc_gt", junc_plots["junc_gt_plot"] / 255., - global_step, dataformats="NHWC") - writer.add_images("Train/junc_pred", junc_plots["junc_pred_plot"] / 255., - global_step, dataformats="NHWC") - writer.add_images("Train/junc_pred_nms", - junc_plots["junc_pred_nms_plot"] / 255., global_step, - dataformats="NHWC") + image_tensor, + images["junc_map_pred"], + images["junc_map_pred_nms"], + images["junc_map_gt"], + ) + writer.add_images( + "Train/junc_gt", + junc_plots["junc_gt_plot"] / 255.0, + global_step, + dataformats="NHWC", + ) + writer.add_images( + "Train/junc_pred", + junc_plots["junc_pred_plot"] / 255.0, + global_step, + dataformats="NHWC", + ) + writer.add_images( + "Train/junc_pred_nms", + junc_plots["junc_pred_nms_plot"] / 255.0, + global_step, + dataformats="NHWC", + ) writer.add_images( "Train/junc_prob_map", convert_image(images["junc_prob_map"][..., None], axis=-1), - global_step, dataformats="NHWC") + global_step, + dataformats="NHWC", + ) def record_test_summaries(writer, epoch, scalars): - """ Record testing summaries. """ + """Record testing summaries.""" average = scalars["average"] # Average loss @@ -675,30 +875,30 @@ def record_test_summaries(writer, epoch, scalars): # Add descriptor loss if "descriptor_loss" in average.keys(): key = "descriptor_loss" - writer.add_scalar("Val_loss/%s"%(key), average[key], epoch) + writer.add_scalar("Val_loss/%s" % (key), average[key], epoch) # Average metrics - writer.add_scalar("Val_metrics/junc_precision", average["junc_precision"], - epoch) - writer.add_scalar("Val_metrics/junc_precision_nms", - average["junc_precision_nms"], epoch) - writer.add_scalar("Val_metrics/junc_recall", - average["junc_recall"], epoch) - writer.add_scalar("Val_metrics/junc_recall_nms", - average["junc_recall_nms"], epoch) - writer.add_scalar("Val_metrics/heatmap_precision", - average["heatmap_precision"], epoch) - writer.add_scalar("Val_metrics/heatmap_recall", - average["heatmap_recall"], epoch) + writer.add_scalar("Val_metrics/junc_precision", average["junc_precision"], epoch) + writer.add_scalar( + "Val_metrics/junc_precision_nms", average["junc_precision_nms"], epoch + ) + writer.add_scalar("Val_metrics/junc_recall", average["junc_recall"], epoch) + writer.add_scalar("Val_metrics/junc_recall_nms", average["junc_recall_nms"], epoch) + writer.add_scalar( + "Val_metrics/heatmap_precision", average["heatmap_precision"], epoch + ) + writer.add_scalar("Val_metrics/heatmap_recall", average["heatmap_recall"], epoch) # Add descriptor metric if "matching_score" in average.keys(): - writer.add_scalar("Val_metrics/matching_score", - average["matching_score"], epoch) + writer.add_scalar( + "Val_metrics/matching_score", average["matching_score"], epoch + ) -def plot_junction_detection(image_tensor, junc_pred_tensor, - junc_pred_nms_tensor, junc_gt_tensor): - """ Plot the junction points on images. 
""" +def plot_junction_detection( + image_tensor, junc_pred_tensor, junc_pred_nms_tensor, junc_gt_tensor +): + """Plot the junction points on images.""" # Get the batch_size batch_size = image_tensor.shape[0] @@ -708,45 +908,61 @@ def plot_junction_detection(image_tensor, junc_pred_tensor, junc_gt_lst = [] for i in range(batch_size): # Convert image to 255 uint8 - image = (image_tensor[i, :, :, :] - * 255.).astype(np.uint8).transpose(1,2,0) + image = (image_tensor[i, :, :, :] * 255.0).astype(np.uint8).transpose(1, 2, 0) # Plot groundtruth onto image junc_gt = junc_gt_tensor[i, ...] coord_gt = np.where(junc_gt.squeeze() > 0) - points_gt = np.concatenate((coord_gt[0][..., None], - coord_gt[1][..., None]), - axis=1) + points_gt = np.concatenate( + (coord_gt[0][..., None], coord_gt[1][..., None]), axis=1 + ) plot_gt = image.copy() for id in range(points_gt.shape[0]): - cv2.circle(plot_gt, tuple(np.flip(points_gt[id, :])), 3, - color=(255, 0, 0), thickness=2) + cv2.circle( + plot_gt, + tuple(np.flip(points_gt[id, :])), + 3, + color=(255, 0, 0), + thickness=2, + ) junc_gt_lst.append(plot_gt[None, ...]) # Plot junc_pred junc_pred = junc_pred_tensor[i, ...] coord_pred = np.where(junc_pred > 0) - points_pred = np.concatenate((coord_pred[0][..., None], - coord_pred[1][..., None]), - axis=1) + points_pred = np.concatenate( + (coord_pred[0][..., None], coord_pred[1][..., None]), axis=1 + ) plot_pred = image.copy() for id in range(points_pred.shape[0]): - cv2.circle(plot_pred, tuple(np.flip(points_pred[id, :])), 3, - color=(0, 255, 0), thickness=2) + cv2.circle( + plot_pred, + tuple(np.flip(points_pred[id, :])), + 3, + color=(0, 255, 0), + thickness=2, + ) junc_pred_lst.append(plot_pred[None, ...]) # Plot junc_pred_nms junc_pred_nms = junc_pred_nms_tensor[i, ...] 
coord_pred_nms = np.where(junc_pred_nms > 0) - points_pred_nms = np.concatenate((coord_pred_nms[0][..., None], - coord_pred_nms[1][..., None]), - axis=1) + points_pred_nms = np.concatenate( + (coord_pred_nms[0][..., None], coord_pred_nms[1][..., None]), axis=1 + ) plot_pred_nms = image.copy() for id in range(points_pred_nms.shape[0]): - cv2.circle(plot_pred_nms, tuple(np.flip(points_pred_nms[id, :])), - 3, color=(0, 255, 0), thickness=2) + cv2.circle( + plot_pred_nms, + tuple(np.flip(points_pred_nms[id, :])), + 3, + color=(0, 255, 0), + thickness=2, + ) junc_pred_nms_lst.append(plot_pred_nms[None, ...]) - return {"junc_gt_plot": np.concatenate(junc_gt_lst, axis=0), - "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0), - "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0)} + return { + "junc_gt_plot": np.concatenate(junc_gt_lst, axis=0), + "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0), + "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0), + } diff --git a/third_party/SuperGluePretrainedNetwork/demo_superglue.py b/third_party/SuperGluePretrainedNetwork/demo_superglue.py index 32d4ad3c7df1b7da141c4c6aa51f871a7d756aaf..c639efd7481052b842c640d4aa23aaf18e0eb449 100644 --- a/third_party/SuperGluePretrainedNetwork/demo_superglue.py +++ b/third_party/SuperGluePretrainedNetwork/demo_superglue.py @@ -51,69 +51,110 @@ import matplotlib.cm as cm import torch from models.matching import Matching -from models.utils import (AverageTimer, VideoStreamer, - make_matching_plot_fast, frame2tensor) +from models.utils import ( + AverageTimer, + VideoStreamer, + make_matching_plot_fast, + frame2tensor, +) torch.set_grad_enabled(False) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser( - description='SuperGlue demo', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + description="SuperGlue demo", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( - '--input', type=str, default='0', - help='ID of a USB webcam, URL of an IP camera, ' - 'or path to an image directory or movie file') + "--input", + type=str, + default="0", + help="ID of a USB webcam, URL of an IP camera, " + "or path to an image directory or movie file", + ) parser.add_argument( - '--output_dir', type=str, default=None, - help='Directory where to write output frames (If None, no output)') + "--output_dir", + type=str, + default=None, + help="Directory where to write output frames (If None, no output)", + ) parser.add_argument( - '--image_glob', type=str, nargs='+', default=['*.png', '*.jpg', '*.jpeg'], - help='Glob if a directory of images is specified') + "--image_glob", + type=str, + nargs="+", + default=["*.png", "*.jpg", "*.jpeg"], + help="Glob if a directory of images is specified", + ) parser.add_argument( - '--skip', type=int, default=1, - help='Images to skip if input is a movie or directory') + "--skip", + type=int, + default=1, + help="Images to skip if input is a movie or directory", + ) parser.add_argument( - '--max_length', type=int, default=1000000, - help='Maximum length if input is a movie or directory') + "--max_length", + type=int, + default=1000000, + help="Maximum length if input is a movie or directory", + ) parser.add_argument( - '--resize', type=int, nargs='+', default=[640, 480], - help='Resize the input image before running inference. 
If two numbers, ' - 'resize to the exact dimensions, if one number, resize the max ' - 'dimension, if -1, do not resize') + "--resize", + type=int, + nargs="+", + default=[640, 480], + help="Resize the input image before running inference. If two numbers, " + "resize to the exact dimensions, if one number, resize the max " + "dimension, if -1, do not resize", + ) parser.add_argument( - '--superglue', choices={'indoor', 'outdoor'}, default='indoor', - help='SuperGlue weights') + "--superglue", + choices={"indoor", "outdoor"}, + default="indoor", + help="SuperGlue weights", + ) parser.add_argument( - '--max_keypoints', type=int, default=-1, - help='Maximum number of keypoints detected by Superpoint' - ' (\'-1\' keeps all keypoints)') + "--max_keypoints", + type=int, + default=-1, + help="Maximum number of keypoints detected by Superpoint" + " ('-1' keeps all keypoints)", + ) parser.add_argument( - '--keypoint_threshold', type=float, default=0.005, - help='SuperPoint keypoint detector confidence threshold') + "--keypoint_threshold", + type=float, + default=0.005, + help="SuperPoint keypoint detector confidence threshold", + ) parser.add_argument( - '--nms_radius', type=int, default=4, - help='SuperPoint Non Maximum Suppression (NMS) radius' - ' (Must be positive)') + "--nms_radius", + type=int, + default=4, + help="SuperPoint Non Maximum Suppression (NMS) radius" " (Must be positive)", + ) parser.add_argument( - '--sinkhorn_iterations', type=int, default=20, - help='Number of Sinkhorn iterations performed by SuperGlue') + "--sinkhorn_iterations", + type=int, + default=20, + help="Number of Sinkhorn iterations performed by SuperGlue", + ) parser.add_argument( - '--match_threshold', type=float, default=0.2, - help='SuperGlue match threshold') + "--match_threshold", type=float, default=0.2, help="SuperGlue match threshold" + ) parser.add_argument( - '--show_keypoints', action='store_true', - help='Show the detected keypoints') + "--show_keypoints", action="store_true", help="Show the detected keypoints" + ) parser.add_argument( - '--no_display', action='store_true', - help='Do not display images to screen. Useful if running remotely') + "--no_display", + action="store_true", + help="Do not display images to screen. Useful if running remotely", + ) parser.add_argument( - '--force_cpu', action='store_true', - help='Force pytorch to run in CPU mode.') + "--force_cpu", action="store_true", help="Force pytorch to run in CPU mode." 
+ ) opt = parser.parse_args() print(opt) @@ -121,138 +162,160 @@ if __name__ == '__main__': if len(opt.resize) == 2 and opt.resize[1] == -1: opt.resize = opt.resize[0:1] if len(opt.resize) == 2: - print('Will resize to {}x{} (WxH)'.format( - opt.resize[0], opt.resize[1])) + print("Will resize to {}x{} (WxH)".format(opt.resize[0], opt.resize[1])) elif len(opt.resize) == 1 and opt.resize[0] > 0: - print('Will resize max dimension to {}'.format(opt.resize[0])) + print("Will resize max dimension to {}".format(opt.resize[0])) elif len(opt.resize) == 1: - print('Will not resize images') + print("Will not resize images") else: - raise ValueError('Cannot specify more than two integers for --resize') + raise ValueError("Cannot specify more than two integers for --resize") - device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu' - print('Running inference on device \"{}\"'.format(device)) + device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu" + print('Running inference on device "{}"'.format(device)) config = { - 'superpoint': { - 'nms_radius': opt.nms_radius, - 'keypoint_threshold': opt.keypoint_threshold, - 'max_keypoints': opt.max_keypoints + "superpoint": { + "nms_radius": opt.nms_radius, + "keypoint_threshold": opt.keypoint_threshold, + "max_keypoints": opt.max_keypoints, + }, + "superglue": { + "weights": opt.superglue, + "sinkhorn_iterations": opt.sinkhorn_iterations, + "match_threshold": opt.match_threshold, }, - 'superglue': { - 'weights': opt.superglue, - 'sinkhorn_iterations': opt.sinkhorn_iterations, - 'match_threshold': opt.match_threshold, - } } matching = Matching(config).eval().to(device) - keys = ['keypoints', 'scores', 'descriptors'] + keys = ["keypoints", "scores", "descriptors"] - vs = VideoStreamer(opt.input, opt.resize, opt.skip, - opt.image_glob, opt.max_length) + vs = VideoStreamer(opt.input, opt.resize, opt.skip, opt.image_glob, opt.max_length) frame, ret = vs.next_frame() - assert ret, 'Error when reading the first frame (try different --input?)' + assert ret, "Error when reading the first frame (try different --input?)" frame_tensor = frame2tensor(frame, device) - last_data = matching.superpoint({'image': frame_tensor}) - last_data = {k+'0': last_data[k] for k in keys} - last_data['image0'] = frame_tensor + last_data = matching.superpoint({"image": frame_tensor}) + last_data = {k + "0": last_data[k] for k in keys} + last_data["image0"] = frame_tensor last_frame = frame last_image_id = 0 if opt.output_dir is not None: - print('==> Will write outputs to {}'.format(opt.output_dir)) + print("==> Will write outputs to {}".format(opt.output_dir)) Path(opt.output_dir).mkdir(exist_ok=True) # Create a window to display the demo. if not opt.no_display: - cv2.namedWindow('SuperGlue matches', cv2.WINDOW_NORMAL) - cv2.resizeWindow('SuperGlue matches', 640*2, 480) + cv2.namedWindow("SuperGlue matches", cv2.WINDOW_NORMAL) + cv2.resizeWindow("SuperGlue matches", 640 * 2, 480) else: - print('Skipping visualization, will not show a GUI.') + print("Skipping visualization, will not show a GUI.") # Print the keyboard help menu. 
- print('==> Keyboard control:\n' - '\tn: select the current frame as the anchor\n' - '\te/r: increase/decrease the keypoint confidence threshold\n' - '\td/f: increase/decrease the match filtering threshold\n' - '\tk: toggle the visualization of keypoints\n' - '\tq: quit') + print( + "==> Keyboard control:\n" + "\tn: select the current frame as the anchor\n" + "\te/r: increase/decrease the keypoint confidence threshold\n" + "\td/f: increase/decrease the match filtering threshold\n" + "\tk: toggle the visualization of keypoints\n" + "\tq: quit" + ) timer = AverageTimer() while True: frame, ret = vs.next_frame() if not ret: - print('Finished demo_superglue.py') + print("Finished demo_superglue.py") break - timer.update('data') + timer.update("data") stem0, stem1 = last_image_id, vs.i - 1 frame_tensor = frame2tensor(frame, device) - pred = matching({**last_data, 'image1': frame_tensor}) - kpts0 = last_data['keypoints0'][0].cpu().numpy() - kpts1 = pred['keypoints1'][0].cpu().numpy() - matches = pred['matches0'][0].cpu().numpy() - confidence = pred['matching_scores0'][0].cpu().numpy() - timer.update('forward') + pred = matching({**last_data, "image1": frame_tensor}) + kpts0 = last_data["keypoints0"][0].cpu().numpy() + kpts1 = pred["keypoints1"][0].cpu().numpy() + matches = pred["matches0"][0].cpu().numpy() + confidence = pred["matching_scores0"][0].cpu().numpy() + timer.update("forward") valid = matches > -1 mkpts0 = kpts0[valid] mkpts1 = kpts1[matches[valid]] color = cm.jet(confidence[valid]) text = [ - 'SuperGlue', - 'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)), - 'Matches: {}'.format(len(mkpts0)) + "SuperGlue", + "Keypoints: {}:{}".format(len(kpts0), len(kpts1)), + "Matches: {}".format(len(mkpts0)), ] - k_thresh = matching.superpoint.config['keypoint_threshold'] - m_thresh = matching.superglue.config['match_threshold'] + k_thresh = matching.superpoint.config["keypoint_threshold"] + m_thresh = matching.superglue.config["match_threshold"] small_text = [ - 'Keypoint Threshold: {:.4f}'.format(k_thresh), - 'Match Threshold: {:.2f}'.format(m_thresh), - 'Image Pair: {:06}:{:06}'.format(stem0, stem1), + "Keypoint Threshold: {:.4f}".format(k_thresh), + "Match Threshold: {:.2f}".format(m_thresh), + "Image Pair: {:06}:{:06}".format(stem0, stem1), ] out = make_matching_plot_fast( - last_frame, frame, kpts0, kpts1, mkpts0, mkpts1, color, text, - path=None, show_keypoints=opt.show_keypoints, small_text=small_text) + last_frame, + frame, + kpts0, + kpts1, + mkpts0, + mkpts1, + color, + text, + path=None, + show_keypoints=opt.show_keypoints, + small_text=small_text, + ) if not opt.no_display: - cv2.imshow('SuperGlue matches', out) + cv2.imshow("SuperGlue matches", out) key = chr(cv2.waitKey(1) & 0xFF) - if key == 'q': + if key == "q": vs.cleanup() - print('Exiting (via q) demo_superglue.py') + print("Exiting (via q) demo_superglue.py") break - elif key == 'n': # set the current frame as anchor - last_data = {k+'0': pred[k+'1'] for k in keys} - last_data['image0'] = frame_tensor + elif key == "n": # set the current frame as anchor + last_data = {k + "0": pred[k + "1"] for k in keys} + last_data["image0"] = frame_tensor last_frame = frame - last_image_id = (vs.i - 1) - elif key in ['e', 'r']: + last_image_id = vs.i - 1 + elif key in ["e", "r"]: # Increase/decrease keypoint threshold by 10% each keypress. 
- d = 0.1 * (-1 if key == 'e' else 1) - matching.superpoint.config['keypoint_threshold'] = min(max( - 0.0001, matching.superpoint.config['keypoint_threshold']*(1+d)), 1) - print('\nChanged the keypoint threshold to {:.4f}'.format( - matching.superpoint.config['keypoint_threshold'])) - elif key in ['d', 'f']: + d = 0.1 * (-1 if key == "e" else 1) + matching.superpoint.config["keypoint_threshold"] = min( + max( + 0.0001, + matching.superpoint.config["keypoint_threshold"] * (1 + d), + ), + 1, + ) + print( + "\nChanged the keypoint threshold to {:.4f}".format( + matching.superpoint.config["keypoint_threshold"] + ) + ) + elif key in ["d", "f"]: # Increase/decrease match threshold by 0.05 each keypress. - d = 0.05 * (-1 if key == 'd' else 1) - matching.superglue.config['match_threshold'] = min(max( - 0.05, matching.superglue.config['match_threshold']+d), .95) - print('\nChanged the match threshold to {:.2f}'.format( - matching.superglue.config['match_threshold'])) - elif key == 'k': + d = 0.05 * (-1 if key == "d" else 1) + matching.superglue.config["match_threshold"] = min( + max(0.05, matching.superglue.config["match_threshold"] + d), 0.95 + ) + print( + "\nChanged the match threshold to {:.2f}".format( + matching.superglue.config["match_threshold"] + ) + ) + elif key == "k": opt.show_keypoints = not opt.show_keypoints - timer.update('viz') + timer.update("viz") timer.print() if opt.output_dir is not None: - #stem = 'matches_{:06}_{:06}'.format(last_image_id, vs.i-1) - stem = 'matches_{:06}_{:06}'.format(stem0, stem1) - out_file = str(Path(opt.output_dir, stem + '.png')) - print('\nWriting image to {}'.format(out_file)) + # stem = 'matches_{:06}_{:06}'.format(last_image_id, vs.i-1) + stem = "matches_{:06}_{:06}".format(stem0, stem1) + out_file = str(Path(opt.output_dir, stem + ".png")) + print("\nWriting image to {}".format(out_file)) cv2.imwrite(out_file, out) cv2.destroyAllWindows() diff --git a/third_party/SuperGluePretrainedNetwork/match_pairs.py b/third_party/SuperGluePretrainedNetwork/match_pairs.py index 7079687cf69fd71d810ec80442548ad2a7b869e0..9dcbcadd3ca8efc053cf4ea33c825ff75728bef1 100644 --- a/third_party/SuperGluePretrainedNetwork/match_pairs.py +++ b/third_party/SuperGluePretrainedNetwork/match_pairs.py @@ -53,118 +53,176 @@ import torch from models.matching import Matching -from models.utils import (compute_pose_error, compute_epipolar_error, - estimate_pose, make_matching_plot, - error_colormap, AverageTimer, pose_auc, read_image, - rotate_intrinsics, rotate_pose_inplane, - scale_intrinsics) +from models.utils import ( + compute_pose_error, + compute_epipolar_error, + estimate_pose, + make_matching_plot, + error_colormap, + AverageTimer, + pose_auc, + read_image, + rotate_intrinsics, + rotate_pose_inplane, + scale_intrinsics, +) torch.set_grad_enabled(False) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser( - description='Image pair matching and pose evaluation with SuperGlue', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + description="Image pair matching and pose evaluation with SuperGlue", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( - '--input_pairs', type=str, default='assets/scannet_sample_pairs_with_gt.txt', - help='Path to the list of image pairs') + "--input_pairs", + type=str, + default="assets/scannet_sample_pairs_with_gt.txt", + help="Path to the list of image pairs", + ) parser.add_argument( - '--input_dir', type=str, default='assets/scannet_sample_images/', - help='Path to 
the directory that contains the images') + "--input_dir", + type=str, + default="assets/scannet_sample_images/", + help="Path to the directory that contains the images", + ) parser.add_argument( - '--output_dir', type=str, default='dump_match_pairs/', - help='Path to the directory in which the .npz results and optionally,' - 'the visualization images are written') + "--output_dir", + type=str, + default="dump_match_pairs/", + help="Path to the directory in which the .npz results and optionally," + "the visualization images are written", + ) parser.add_argument( - '--max_length', type=int, default=-1, - help='Maximum number of pairs to evaluate') + "--max_length", type=int, default=-1, help="Maximum number of pairs to evaluate" + ) parser.add_argument( - '--resize', type=int, nargs='+', default=[640, 480], - help='Resize the input image before running inference. If two numbers, ' - 'resize to the exact dimensions, if one number, resize the max ' - 'dimension, if -1, do not resize') + "--resize", + type=int, + nargs="+", + default=[640, 480], + help="Resize the input image before running inference. If two numbers, " + "resize to the exact dimensions, if one number, resize the max " + "dimension, if -1, do not resize", + ) parser.add_argument( - '--resize_float', action='store_true', - help='Resize the image after casting uint8 to float') + "--resize_float", + action="store_true", + help="Resize the image after casting uint8 to float", + ) parser.add_argument( - '--superglue', choices={'indoor', 'outdoor'}, default='indoor', - help='SuperGlue weights') + "--superglue", + choices={"indoor", "outdoor"}, + default="indoor", + help="SuperGlue weights", + ) parser.add_argument( - '--max_keypoints', type=int, default=1024, - help='Maximum number of keypoints detected by Superpoint' - ' (\'-1\' keeps all keypoints)') + "--max_keypoints", + type=int, + default=1024, + help="Maximum number of keypoints detected by Superpoint" + " ('-1' keeps all keypoints)", + ) parser.add_argument( - '--keypoint_threshold', type=float, default=0.005, - help='SuperPoint keypoint detector confidence threshold') + "--keypoint_threshold", + type=float, + default=0.005, + help="SuperPoint keypoint detector confidence threshold", + ) parser.add_argument( - '--nms_radius', type=int, default=4, - help='SuperPoint Non Maximum Suppression (NMS) radius' - ' (Must be positive)') + "--nms_radius", + type=int, + default=4, + help="SuperPoint Non Maximum Suppression (NMS) radius" " (Must be positive)", + ) parser.add_argument( - '--sinkhorn_iterations', type=int, default=20, - help='Number of Sinkhorn iterations performed by SuperGlue') + "--sinkhorn_iterations", + type=int, + default=20, + help="Number of Sinkhorn iterations performed by SuperGlue", + ) parser.add_argument( - '--match_threshold', type=float, default=0.2, - help='SuperGlue match threshold') + "--match_threshold", type=float, default=0.2, help="SuperGlue match threshold" + ) parser.add_argument( - '--viz', action='store_true', - help='Visualize the matches and dump the plots') + "--viz", action="store_true", help="Visualize the matches and dump the plots" + ) parser.add_argument( - '--eval', action='store_true', - help='Perform the evaluation' - ' (requires ground truth pose and intrinsics)') + "--eval", + action="store_true", + help="Perform the evaluation" " (requires ground truth pose and intrinsics)", + ) parser.add_argument( - '--fast_viz', action='store_true', - help='Use faster image visualization with OpenCV instead of Matplotlib') + "--fast_viz", + 
action="store_true", + help="Use faster image visualization with OpenCV instead of Matplotlib", + ) parser.add_argument( - '--cache', action='store_true', - help='Skip the pair if output .npz files are already found') + "--cache", + action="store_true", + help="Skip the pair if output .npz files are already found", + ) parser.add_argument( - '--show_keypoints', action='store_true', - help='Plot the keypoints in addition to the matches') + "--show_keypoints", + action="store_true", + help="Plot the keypoints in addition to the matches", + ) parser.add_argument( - '--viz_extension', type=str, default='png', choices=['png', 'pdf'], - help='Visualization file extension. Use pdf for highest-quality.') + "--viz_extension", + type=str, + default="png", + choices=["png", "pdf"], + help="Visualization file extension. Use pdf for highest-quality.", + ) parser.add_argument( - '--opencv_display', action='store_true', - help='Visualize via OpenCV before saving output images') + "--opencv_display", + action="store_true", + help="Visualize via OpenCV before saving output images", + ) parser.add_argument( - '--shuffle', action='store_true', - help='Shuffle ordering of pairs before processing') + "--shuffle", + action="store_true", + help="Shuffle ordering of pairs before processing", + ) parser.add_argument( - '--force_cpu', action='store_true', - help='Force pytorch to run in CPU mode.') + "--force_cpu", action="store_true", help="Force pytorch to run in CPU mode." + ) opt = parser.parse_args() print(opt) - assert not (opt.opencv_display and not opt.viz), 'Must use --viz with --opencv_display' - assert not (opt.opencv_display and not opt.fast_viz), 'Cannot use --opencv_display without --fast_viz' - assert not (opt.fast_viz and not opt.viz), 'Must use --viz with --fast_viz' - assert not (opt.fast_viz and opt.viz_extension == 'pdf'), 'Cannot use pdf extension with --fast_viz' + assert not ( + opt.opencv_display and not opt.viz + ), "Must use --viz with --opencv_display" + assert not ( + opt.opencv_display and not opt.fast_viz + ), "Cannot use --opencv_display without --fast_viz" + assert not (opt.fast_viz and not opt.viz), "Must use --viz with --fast_viz" + assert not ( + opt.fast_viz and opt.viz_extension == "pdf" + ), "Cannot use pdf extension with --fast_viz" if len(opt.resize) == 2 and opt.resize[1] == -1: opt.resize = opt.resize[0:1] if len(opt.resize) == 2: - print('Will resize to {}x{} (WxH)'.format( - opt.resize[0], opt.resize[1])) + print("Will resize to {}x{} (WxH)".format(opt.resize[0], opt.resize[1])) elif len(opt.resize) == 1 and opt.resize[0] > 0: - print('Will resize max dimension to {}'.format(opt.resize[0])) + print("Will resize max dimension to {}".format(opt.resize[0])) elif len(opt.resize) == 1: - print('Will not resize images') + print("Will not resize images") else: - raise ValueError('Cannot specify more than two integers for --resize') + raise ValueError("Cannot specify more than two integers for --resize") - with open(opt.input_pairs, 'r') as f: + with open(opt.input_pairs, "r") as f: pairs = [l.split() for l in f.readlines()] if opt.max_length > -1: - pairs = pairs[0:np.min([len(pairs), opt.max_length])] + pairs = pairs[0 : np.min([len(pairs), opt.max_length])] if opt.shuffle: random.Random(0).shuffle(pairs) @@ -172,48 +230,50 @@ if __name__ == '__main__': if opt.eval: if not all([len(p) == 38 for p in pairs]): raise ValueError( - 'All pairs should have ground truth info for evaluation.' 
- 'File \"{}\" needs 38 valid entries per row'.format(opt.input_pairs)) + "All pairs should have ground truth info for evaluation." + 'File "{}" needs 38 valid entries per row'.format(opt.input_pairs) + ) # Load the SuperPoint and SuperGlue models. - device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu' - print('Running inference on device \"{}\"'.format(device)) + device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu" + print('Running inference on device "{}"'.format(device)) config = { - 'superpoint': { - 'nms_radius': opt.nms_radius, - 'keypoint_threshold': opt.keypoint_threshold, - 'max_keypoints': opt.max_keypoints + "superpoint": { + "nms_radius": opt.nms_radius, + "keypoint_threshold": opt.keypoint_threshold, + "max_keypoints": opt.max_keypoints, + }, + "superglue": { + "weights": opt.superglue, + "sinkhorn_iterations": opt.sinkhorn_iterations, + "match_threshold": opt.match_threshold, }, - 'superglue': { - 'weights': opt.superglue, - 'sinkhorn_iterations': opt.sinkhorn_iterations, - 'match_threshold': opt.match_threshold, - } } matching = Matching(config).eval().to(device) # Create the output directories if they do not exist already. input_dir = Path(opt.input_dir) - print('Looking for data in directory \"{}\"'.format(input_dir)) + print('Looking for data in directory "{}"'.format(input_dir)) output_dir = Path(opt.output_dir) output_dir.mkdir(exist_ok=True, parents=True) - print('Will write matches to directory \"{}\"'.format(output_dir)) + print('Will write matches to directory "{}"'.format(output_dir)) if opt.eval: - print('Will write evaluation results', - 'to directory \"{}\"'.format(output_dir)) + print("Will write evaluation results", 'to directory "{}"'.format(output_dir)) if opt.viz: - print('Will write visualization images to', - 'directory \"{}\"'.format(output_dir)) + print("Will write visualization images to", 'directory "{}"'.format(output_dir)) timer = AverageTimer(newline=True) for i, pair in enumerate(pairs): name0, name1 = pair[:2] stem0, stem1 = Path(name0).stem, Path(name1).stem - matches_path = output_dir / '{}_{}_matches.npz'.format(stem0, stem1) - eval_path = output_dir / '{}_{}_evaluation.npz'.format(stem0, stem1) - viz_path = output_dir / '{}_{}_matches.{}'.format(stem0, stem1, opt.viz_extension) - viz_eval_path = output_dir / \ - '{}_{}_evaluation.{}'.format(stem0, stem1, opt.viz_extension) + matches_path = output_dir / "{}_{}_matches.npz".format(stem0, stem1) + eval_path = output_dir / "{}_{}_evaluation.npz".format(stem0, stem1) + viz_path = output_dir / "{}_{}_matches.{}".format( + stem0, stem1, opt.viz_extension + ) + viz_eval_path = output_dir / "{}_{}_evaluation.{}".format( + stem0, stem1, opt.viz_extension + ) # Handle --cache logic. 
do_match = True @@ -225,31 +285,30 @@ if __name__ == '__main__': try: results = np.load(matches_path) except: - raise IOError('Cannot load matches .npz file: %s' % - matches_path) + raise IOError("Cannot load matches .npz file: %s" % matches_path) - kpts0, kpts1 = results['keypoints0'], results['keypoints1'] - matches, conf = results['matches'], results['match_confidence'] + kpts0, kpts1 = results["keypoints0"], results["keypoints1"] + matches, conf = results["matches"], results["match_confidence"] do_match = False if opt.eval and eval_path.exists(): try: results = np.load(eval_path) except: - raise IOError('Cannot load eval .npz file: %s' % eval_path) - err_R, err_t = results['error_R'], results['error_t'] - precision = results['precision'] - matching_score = results['matching_score'] - num_correct = results['num_correct'] - epi_errs = results['epipolar_errors'] + raise IOError("Cannot load eval .npz file: %s" % eval_path) + err_R, err_t = results["error_R"], results["error_t"] + precision = results["precision"] + matching_score = results["matching_score"] + num_correct = results["num_correct"] + epi_errs = results["epipolar_errors"] do_eval = False if opt.viz and viz_path.exists(): do_viz = False if opt.viz and opt.eval and viz_eval_path.exists(): do_viz_eval = False - timer.update('load_cache') + timer.update("load_cache") if not (do_match or do_eval or do_viz or do_viz_eval): - timer.print('Finished pair {:5} of {:5}'.format(i, len(pairs))) + timer.print("Finished pair {:5} of {:5}".format(i, len(pairs))) continue # If a rotation integer is provided (e.g. from EXIF data), use it: @@ -260,26 +319,35 @@ if __name__ == '__main__': # Load the image pair. image0, inp0, scales0 = read_image( - input_dir / name0, device, opt.resize, rot0, opt.resize_float) + input_dir / name0, device, opt.resize, rot0, opt.resize_float + ) image1, inp1, scales1 = read_image( - input_dir / name1, device, opt.resize, rot1, opt.resize_float) + input_dir / name1, device, opt.resize, rot1, opt.resize_float + ) if image0 is None or image1 is None: - print('Problem reading image pair: {} {}'.format( - input_dir/name0, input_dir/name1)) + print( + "Problem reading image pair: {} {}".format( + input_dir / name0, input_dir / name1 + ) + ) exit(1) - timer.update('load_image') + timer.update("load_image") if do_match: # Perform the matching. - pred = matching({'image0': inp0, 'image1': inp1}) + pred = matching({"image0": inp0, "image1": inp1}) pred = {k: v[0].cpu().numpy() for k, v in pred.items()} - kpts0, kpts1 = pred['keypoints0'], pred['keypoints1'] - matches, conf = pred['matches0'], pred['matching_scores0'] - timer.update('matcher') + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + matches, conf = pred["matches0"], pred["matching_scores0"] + timer.update("matcher") # Write the matches to disk. - out_matches = {'keypoints0': kpts0, 'keypoints1': kpts1, - 'matches': matches, 'match_confidence': conf} + out_matches = { + "keypoints0": kpts0, + "keypoints1": kpts1, + "matches": matches, + "match_confidence": conf, + } np.savez(str(matches_path), **out_matches) # Keep the matching keypoints. @@ -290,7 +358,7 @@ if __name__ == '__main__': if do_eval: # Estimate the pose and compute the pose error. 
- assert len(pair) == 38, 'Pair does not have ground truth info' + assert len(pair) == 38, "Pair does not have ground truth info" K0 = np.array(pair[4:13]).astype(float).reshape(3, 3) K1 = np.array(pair[13:22]).astype(float).reshape(3, 3) T_0to1 = np.array(pair[22:]).astype(float).reshape(4, 4) @@ -318,7 +386,7 @@ if __name__ == '__main__': precision = np.mean(correct) if len(correct) > 0 else 0 matching_score = num_correct / len(kpts0) if len(kpts0) > 0 else 0 - thresh = 1. # In pixels relative to resized image size. + thresh = 1.0 # In pixels relative to resized image size. ret = estimate_pose(mkpts0, mkpts1, K0, K1, thresh) if ret is None: err_t, err_R = np.inf, np.inf @@ -327,77 +395,103 @@ if __name__ == '__main__': err_t, err_R = compute_pose_error(T_0to1, R, t) # Write the evaluation results to disk. - out_eval = {'error_t': err_t, - 'error_R': err_R, - 'precision': precision, - 'matching_score': matching_score, - 'num_correct': num_correct, - 'epipolar_errors': epi_errs} + out_eval = { + "error_t": err_t, + "error_R": err_R, + "precision": precision, + "matching_score": matching_score, + "num_correct": num_correct, + "epipolar_errors": epi_errs, + } np.savez(str(eval_path), **out_eval) - timer.update('eval') + timer.update("eval") if do_viz: # Visualize the matches. color = cm.jet(mconf) text = [ - 'SuperGlue', - 'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)), - 'Matches: {}'.format(len(mkpts0)), + "SuperGlue", + "Keypoints: {}:{}".format(len(kpts0), len(kpts1)), + "Matches: {}".format(len(mkpts0)), ] if rot0 != 0 or rot1 != 0: - text.append('Rotation: {}:{}'.format(rot0, rot1)) + text.append("Rotation: {}:{}".format(rot0, rot1)) # Display extra parameter info. - k_thresh = matching.superpoint.config['keypoint_threshold'] - m_thresh = matching.superglue.config['match_threshold'] + k_thresh = matching.superpoint.config["keypoint_threshold"] + m_thresh = matching.superglue.config["match_threshold"] small_text = [ - 'Keypoint Threshold: {:.4f}'.format(k_thresh), - 'Match Threshold: {:.2f}'.format(m_thresh), - 'Image Pair: {}:{}'.format(stem0, stem1), + "Keypoint Threshold: {:.4f}".format(k_thresh), + "Match Threshold: {:.2f}".format(m_thresh), + "Image Pair: {}:{}".format(stem0, stem1), ] make_matching_plot( - image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, - text, viz_path, opt.show_keypoints, - opt.fast_viz, opt.opencv_display, 'Matches', small_text) - - timer.update('viz_match') + image0, + image1, + kpts0, + kpts1, + mkpts0, + mkpts1, + color, + text, + viz_path, + opt.show_keypoints, + opt.fast_viz, + opt.opencv_display, + "Matches", + small_text, + ) + + timer.update("viz_match") if do_viz_eval: # Visualize the evaluation results for the image pair. 
color = np.clip((epi_errs - 0) / (1e-3 - 0), 0, 1) color = error_colormap(1 - color) - deg, delta = ' deg', 'Delta ' + deg, delta = " deg", "Delta " if not opt.fast_viz: - deg, delta = '°', '$\\Delta$' - e_t = 'FAIL' if np.isinf(err_t) else '{:.1f}{}'.format(err_t, deg) - e_R = 'FAIL' if np.isinf(err_R) else '{:.1f}{}'.format(err_R, deg) + deg, delta = "°", "$\\Delta$" + e_t = "FAIL" if np.isinf(err_t) else "{:.1f}{}".format(err_t, deg) + e_R = "FAIL" if np.isinf(err_R) else "{:.1f}{}".format(err_R, deg) text = [ - 'SuperGlue', - '{}R: {}'.format(delta, e_R), '{}t: {}'.format(delta, e_t), - 'inliers: {}/{}'.format(num_correct, (matches > -1).sum()), + "SuperGlue", + "{}R: {}".format(delta, e_R), + "{}t: {}".format(delta, e_t), + "inliers: {}/{}".format(num_correct, (matches > -1).sum()), ] if rot0 != 0 or rot1 != 0: - text.append('Rotation: {}:{}'.format(rot0, rot1)) + text.append("Rotation: {}:{}".format(rot0, rot1)) # Display extra parameter info (only works with --fast_viz). - k_thresh = matching.superpoint.config['keypoint_threshold'] - m_thresh = matching.superglue.config['match_threshold'] + k_thresh = matching.superpoint.config["keypoint_threshold"] + m_thresh = matching.superglue.config["match_threshold"] small_text = [ - 'Keypoint Threshold: {:.4f}'.format(k_thresh), - 'Match Threshold: {:.2f}'.format(m_thresh), - 'Image Pair: {}:{}'.format(stem0, stem1), + "Keypoint Threshold: {:.4f}".format(k_thresh), + "Match Threshold: {:.2f}".format(m_thresh), + "Image Pair: {}:{}".format(stem0, stem1), ] make_matching_plot( - image0, image1, kpts0, kpts1, mkpts0, - mkpts1, color, text, viz_eval_path, - opt.show_keypoints, opt.fast_viz, - opt.opencv_display, 'Relative Pose', small_text) - - timer.update('viz_eval') - - timer.print('Finished pair {:5} of {:5}'.format(i, len(pairs))) + image0, + image1, + kpts0, + kpts1, + mkpts0, + mkpts1, + color, + text, + viz_eval_path, + opt.show_keypoints, + opt.fast_viz, + opt.opencv_display, + "Relative Pose", + small_text, + ) + + timer.update("viz_eval") + + timer.print("Finished pair {:5} of {:5}".format(i, len(pairs))) if opt.eval: # Collate the results into a final table and print to terminal. 
@@ -407,19 +501,21 @@ if __name__ == '__main__': for pair in pairs: name0, name1 = pair[:2] stem0, stem1 = Path(name0).stem, Path(name1).stem - eval_path = output_dir / \ - '{}_{}_evaluation.npz'.format(stem0, stem1) + eval_path = output_dir / "{}_{}_evaluation.npz".format(stem0, stem1) results = np.load(eval_path) - pose_error = np.maximum(results['error_t'], results['error_R']) + pose_error = np.maximum(results["error_t"], results["error_R"]) pose_errors.append(pose_error) - precisions.append(results['precision']) - matching_scores.append(results['matching_score']) + precisions.append(results["precision"]) + matching_scores.append(results["matching_score"]) thresholds = [5, 10, 20] aucs = pose_auc(pose_errors, thresholds) - aucs = [100.*yy for yy in aucs] - prec = 100.*np.mean(precisions) - ms = 100.*np.mean(matching_scores) - print('Evaluation Results (mean over {} pairs):'.format(len(pairs))) - print('AUC@5\t AUC@10\t AUC@20\t Prec\t MScore\t') - print('{:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t'.format( - aucs[0], aucs[1], aucs[2], prec, ms)) + aucs = [100.0 * yy for yy in aucs] + prec = 100.0 * np.mean(precisions) + ms = 100.0 * np.mean(matching_scores) + print("Evaluation Results (mean over {} pairs):".format(len(pairs))) + print("AUC@5\t AUC@10\t AUC@20\t Prec\t MScore\t") + print( + "{:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t".format( + aucs[0], aucs[1], aucs[2], prec, ms + ) + ) diff --git a/third_party/SuperGluePretrainedNetwork/models/matching.py b/third_party/SuperGluePretrainedNetwork/models/matching.py index 5d174208d146373230a8a68dd1420fc59c180633..c5c0eda3337d021464eb6283e57b7412c08afb03 100644 --- a/third_party/SuperGluePretrainedNetwork/models/matching.py +++ b/third_party/SuperGluePretrainedNetwork/models/matching.py @@ -47,14 +47,15 @@ from .superglue import SuperGlue class Matching(torch.nn.Module): - """ Image Matching Frontend (SuperPoint + SuperGlue) """ + """Image Matching Frontend (SuperPoint + SuperGlue)""" + def __init__(self, config={}): super().__init__() - self.superpoint = SuperPoint(config.get('superpoint', {})) - self.superglue = SuperGlue(config.get('superglue', {})) + self.superpoint = SuperPoint(config.get("superpoint", {})) + self.superglue = SuperGlue(config.get("superglue", {})) def forward(self, data): - """ Run SuperPoint (optionally) and SuperGlue + """Run SuperPoint (optionally) and SuperGlue SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input Args: data: dictionary with minimal keys: ['image0', 'image1'] @@ -62,12 +63,12 @@ class Matching(torch.nn.Module): pred = {} # Extract SuperPoint (keypoints, scores, descriptors) if not provided - if 'keypoints0' not in data: - pred0 = self.superpoint({'image': data['image0']}) - pred = {**pred, **{k+'0': v for k, v in pred0.items()}} - if 'keypoints1' not in data: - pred1 = self.superpoint({'image': data['image1']}) - pred = {**pred, **{k+'1': v for k, v in pred1.items()}} + if "keypoints0" not in data: + pred0 = self.superpoint({"image": data["image0"]}) + pred = {**pred, **{k + "0": v for k, v in pred0.items()}} + if "keypoints1" not in data: + pred1 = self.superpoint({"image": data["image1"]}) + pred = {**pred, **{k + "1": v for k, v in pred1.items()}} # Batch all features # We should either have i) one image per batch, or diff --git a/third_party/SuperGluePretrainedNetwork/models/superglue.py b/third_party/SuperGluePretrainedNetwork/models/superglue.py index 5316234dee9be9cdc083e3b4bebe97a6e51e587d..70156e07b83614b1dfb36207ea96b4b79a6ddbb9 100644 --- 
a/third_party/SuperGluePretrainedNetwork/models/superglue.py
+++ b/third_party/SuperGluePretrainedNetwork/models/superglue.py
@@ -49,13 +49,12 @@ from torch import nn
 
 
 def MLP(channels: List[int], do_bn: bool = True) -> nn.Module:
-    """ Multi-layer perceptron """
+    """Multi-layer perceptron"""
     n = len(channels)
     layers = []
     for i in range(1, n):
-        layers.append(
-            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
-        if i < (n-1):
+        layers.append(nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
+        if i < (n - 1):
             if do_bn:
                 layers.append(nn.BatchNorm1d(channels[i]))
             layers.append(nn.ReLU())
@@ -63,17 +62,18 @@ def MLP(channels: List[int], do_bn: bool = True) -> nn.Module:
 
 
 def normalize_keypoints(kpts, image_shape):
-    """ Normalize keypoints locations based on image image_shape"""
+    """Normalize keypoint locations based on the image shape"""
     _, _, height, width = image_shape
     one = kpts.new_tensor(1)
-    size = torch.stack([one*width, one*height])[None]
+    size = torch.stack([one * width, one * height])[None]
     center = size / 2
     scaling = size.max(1, keepdim=True).values * 0.7
     return (kpts - center[:, None, :]) / scaling[:, None, :]
 
 
 class KeypointEncoder(nn.Module):
-    """ Joint encoding of visual appearance and location using MLPs"""
+    """Joint encoding of visual appearance and location using MLPs"""
+
     def __init__(self, feature_dim: int, layers: List[int]) -> None:
         super().__init__()
         self.encoder = MLP([3] + layers + [feature_dim])
@@ -84,15 +84,18 @@ class KeypointEncoder(nn.Module):
         return self.encoder(torch.cat(inputs, dim=1))
 
 
-def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
+def attention(
+    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
     dim = query.shape[1]
-    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
+    scores = torch.einsum("bdhn,bdhm->bhnm", query, key) / dim**0.5
     prob = torch.nn.functional.softmax(scores, dim=-1)
-    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
+    return torch.einsum("bhnm,bdhm->bdhn", prob, value), prob
 
 
 class MultiHeadedAttention(nn.Module):
-    """ Multi-head attention to increase model expressivitiy """
+    """Multi-head attention to increase model expressivity"""
+
     def __init__(self, num_heads: int, d_model: int):
         super().__init__()
         assert d_model % num_heads == 0
@@ -101,19 +104,23 @@ class MultiHeadedAttention(nn.Module):
         self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
         self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
 
-    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> torch.Tensor:
         batch_dim = query.size(0)
-        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
-                             for l, x in zip(self.proj, (query, key, value))]
+        query, key, value = [
+            l(x).view(batch_dim, self.dim, self.num_heads, -1)
+            for l, x in zip(self.proj, (query, key, value))
+        ]
         x, _ = attention(query, key, value)
-        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
+        return self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))
 
 
 class AttentionalPropagation(nn.Module):
    def __init__(self, feature_dim: int, num_heads: int):
         super().__init__()
         self.attn = MultiHeadedAttention(num_heads, feature_dim)
-        self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
+        self.mlp = MLP([feature_dim * 2, feature_dim * 
2, feature_dim]) nn.init.constant_(self.mlp[-1].bias, 0.0) def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor: @@ -124,14 +131,16 @@ class AttentionalPropagation(nn.Module): class AttentionalGNN(nn.Module): def __init__(self, feature_dim: int, layer_names: List[str]) -> None: super().__init__() - self.layers = nn.ModuleList([ - AttentionalPropagation(feature_dim, 4) - for _ in range(len(layer_names))]) + self.layers = nn.ModuleList( + [AttentionalPropagation(feature_dim, 4) for _ in range(len(layer_names))] + ) self.names = layer_names - def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]: + def forward( + self, desc0: torch.Tensor, desc1: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: for layer, name in zip(self.layers, self.names): - if name == 'cross': + if name == "cross": src0, src1 = desc1, desc0 else: # if name == 'self': src0, src1 = desc0, desc1 @@ -140,8 +149,10 @@ class AttentionalGNN(nn.Module): return desc0, desc1 -def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor: - """ Perform Sinkhorn Normalization in Log-space for stability""" +def log_sinkhorn_iterations( + Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int +) -> torch.Tensor: + """Perform Sinkhorn Normalization in Log-space for stability""" u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) for _ in range(iters): u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2) @@ -149,20 +160,23 @@ def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch return Z + u.unsqueeze(2) + v.unsqueeze(1) -def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor: - """ Perform Differentiable Optimal Transport in Log-space for stability""" +def log_optimal_transport( + scores: torch.Tensor, alpha: torch.Tensor, iters: int +) -> torch.Tensor: + """Perform Differentiable Optimal Transport in Log-space for stability""" b, m, n = scores.shape one = scores.new_tensor(1) - ms, ns = (m*one).to(scores), (n*one).to(scores) + ms, ns = (m * one).to(scores), (n * one).to(scores) bins0 = alpha.expand(b, m, 1) bins1 = alpha.expand(b, 1, n) alpha = alpha.expand(b, 1, 1) - couplings = torch.cat([torch.cat([scores, bins0], -1), - torch.cat([bins1, alpha], -1)], 1) + couplings = torch.cat( + [torch.cat([scores, bins0], -1), torch.cat([bins1, alpha], -1)], 1 + ) - norm = - (ms + ns).log() + norm = -(ms + ns).log() log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm]) log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1) @@ -194,13 +208,14 @@ class SuperGlue(nn.Module): Networks. In CVPR, 2020. 
https://arxiv.org/abs/1911.11763 """ + default_config = { - 'descriptor_dim': 256, - 'weights': 'indoor', - 'keypoint_encoder': [32, 64, 128, 256], - 'GNN_layers': ['self', 'cross'] * 9, - 'sinkhorn_iterations': 100, - 'match_threshold': 0.2, + "descriptor_dim": 256, + "weights": "indoor", + "keypoint_encoder": [32, 64, 128, 256], + "GNN_layers": ["self", "cross"] * 9, + "sinkhorn_iterations": 100, + "match_threshold": 0.2, } def __init__(self, config): @@ -208,46 +223,51 @@ class SuperGlue(nn.Module): self.config = {**self.default_config, **config} self.kenc = KeypointEncoder( - self.config['descriptor_dim'], self.config['keypoint_encoder']) + self.config["descriptor_dim"], self.config["keypoint_encoder"] + ) self.gnn = AttentionalGNN( - feature_dim=self.config['descriptor_dim'], layer_names=self.config['GNN_layers']) + feature_dim=self.config["descriptor_dim"], + layer_names=self.config["GNN_layers"], + ) self.final_proj = nn.Conv1d( - self.config['descriptor_dim'], self.config['descriptor_dim'], - kernel_size=1, bias=True) + self.config["descriptor_dim"], + self.config["descriptor_dim"], + kernel_size=1, + bias=True, + ) - bin_score = torch.nn.Parameter(torch.tensor(1.)) - self.register_parameter('bin_score', bin_score) + bin_score = torch.nn.Parameter(torch.tensor(1.0)) + self.register_parameter("bin_score", bin_score) - assert self.config['weights'] in ['indoor', 'outdoor'] + assert self.config["weights"] in ["indoor", "outdoor"] path = Path(__file__).parent - path = path / 'weights/superglue_{}.pth'.format(self.config['weights']) + path = path / "weights/superglue_{}.pth".format(self.config["weights"]) self.load_state_dict(torch.load(str(path))) - print('Loaded SuperGlue model (\"{}\" weights)'.format( - self.config['weights'])) + print('Loaded SuperGlue model ("{}" weights)'.format(self.config["weights"])) def forward(self, data): """Run SuperGlue on a pair of keypoints and descriptors""" - desc0, desc1 = data['descriptors0'], data['descriptors1'] - kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + desc0, desc1 = data["descriptors0"], data["descriptors1"] + kpts0, kpts1 = data["keypoints0"], data["keypoints1"] if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] return { - 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int), - 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int), - 'matching_scores0': kpts0.new_zeros(shape0), - 'matching_scores1': kpts1.new_zeros(shape1), + "matches0": kpts0.new_full(shape0, -1, dtype=torch.int), + "matches1": kpts1.new_full(shape1, -1, dtype=torch.int), + "matching_scores0": kpts0.new_zeros(shape0), + "matching_scores1": kpts1.new_zeros(shape1), } # Keypoint normalization. - kpts0 = normalize_keypoints(kpts0, data['image0'].shape) - kpts1 = normalize_keypoints(kpts1, data['image1'].shape) + kpts0 = normalize_keypoints(kpts0, data["image0"].shape) + kpts1 = normalize_keypoints(kpts1, data["image1"].shape) # Keypoint MLP encoder. - desc0 = desc0 + self.kenc(kpts0, data['scores0']) - desc1 = desc1 + self.kenc(kpts1, data['scores1']) + desc0 = desc0 + self.kenc(kpts0, data["scores0"]) + desc1 = desc1 + self.kenc(kpts1, data["scores1"]) # Multi-layer Transformer network. desc0, desc1 = self.gnn(desc0, desc1) @@ -256,13 +276,13 @@ class SuperGlue(nn.Module): mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) # Compute matching descriptor distance. 
- scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) - scores = scores / self.config['descriptor_dim']**.5 + scores = torch.einsum("bdn,bdm->bnm", mdesc0, mdesc1) + scores = scores / self.config["descriptor_dim"] ** 0.5 # Run the optimal transport. scores = log_optimal_transport( - scores, self.bin_score, - iters=self.config['sinkhorn_iterations']) + scores, self.bin_score, iters=self.config["sinkhorn_iterations"] + ) # Get the matches with score above "match_threshold". max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) @@ -272,13 +292,13 @@ class SuperGlue(nn.Module): zero = scores.new_tensor(0) mscores0 = torch.where(mutual0, max0.values.exp(), zero) mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) - valid0 = mutual0 & (mscores0 > self.config['match_threshold']) + valid0 = mutual0 & (mscores0 > self.config["match_threshold"]) valid1 = mutual1 & valid0.gather(1, indices1) indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) return { - 'matches0': indices0, # use -1 for invalid match - 'matches1': indices1, # use -1 for invalid match - 'matching_scores0': mscores0, - 'matching_scores1': mscores1, + "matches0": indices0, # use -1 for invalid match + "matches1": indices1, # use -1 for invalid match + "matching_scores0": mscores0, + "matching_scores1": mscores1, } diff --git a/third_party/SuperGluePretrainedNetwork/models/superpoint.py b/third_party/SuperGluePretrainedNetwork/models/superpoint.py index b837d938f755850180ddc168e957742e874adacd..ab9712eed30ea30f1578cabb97c0c8f2fbed8c7c 100644 --- a/third_party/SuperGluePretrainedNetwork/models/superpoint.py +++ b/third_party/SuperGluePretrainedNetwork/models/superpoint.py @@ -44,13 +44,15 @@ from pathlib import Path import torch from torch import nn + def simple_nms(scores, nms_radius: int): - """ Fast Non-maximum suppression to remove nearby points """ - assert(nms_radius >= 0) + """Fast Non-maximum suppression to remove nearby points""" + assert nms_radius >= 0 def max_pool(x): return torch.nn.functional.max_pool2d( - x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) zeros = torch.zeros_like(scores) max_mask = scores == max_pool(scores) @@ -63,7 +65,7 @@ def simple_nms(scores, nms_radius: int): def remove_borders(keypoints, scores, border: int, height: int, width: int): - """ Removes keypoints too close to the border """ + """Removes keypoints too close to the border""" mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) mask = mask_h & mask_w @@ -78,17 +80,20 @@ def top_k_keypoints(keypoints, scores, k: int): def sample_descriptors(keypoints, descriptors, s: int = 8): - """ Interpolate descriptors at keypoint locations """ + """Interpolate descriptors at keypoint locations""" b, c, h, w = descriptors.shape keypoints = keypoints - s / 2 + 0.5 - keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], - ).to(keypoints)[None] - keypoints = keypoints*2 - 1 # normalize to (-1, 1) - args = {'align_corners': True} if torch.__version__ >= '1.3' else {} + keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],).to( + keypoints + )[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + args = {"align_corners": True} if torch.__version__ >= "1.3" else {} descriptors = torch.nn.functional.grid_sample( - descriptors, 
keypoints.view(b, 1, -1, 2), mode='bilinear', **args) + descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args + ) descriptors = torch.nn.functional.normalize( - descriptors.reshape(b, c, -1), p=2, dim=1) + descriptors.reshape(b, c, -1), p=2, dim=1 + ) return descriptors @@ -100,12 +105,13 @@ class SuperPoint(nn.Module): Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629 """ + default_config = { - 'descriptor_dim': 256, - 'nms_radius': 4, - 'keypoint_threshold': 0.005, - 'max_keypoints': -1, - 'remove_borders': 4, + "descriptor_dim": 256, + "nms_radius": 4, + "keypoint_threshold": 0.005, + "max_keypoints": -1, + "remove_borders": 4, } def __init__(self, config): @@ -130,22 +136,22 @@ class SuperPoint(nn.Module): self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) self.convDb = nn.Conv2d( - c5, self.config['descriptor_dim'], - kernel_size=1, stride=1, padding=0) + c5, self.config["descriptor_dim"], kernel_size=1, stride=1, padding=0 + ) - path = Path(__file__).parent / 'weights/superpoint_v1.pth' + path = Path(__file__).parent / "weights/superpoint_v1.pth" self.load_state_dict(torch.load(str(path))) - mk = self.config['max_keypoints'] + mk = self.config["max_keypoints"] if mk == 0 or mk < -1: - raise ValueError('\"max_keypoints\" must be positive or \"-1\"') + raise ValueError('"max_keypoints" must be positive or "-1"') - print('Loaded SuperPoint model') + print("Loaded SuperPoint model") def forward(self, data): - """ Compute keypoints, scores, descriptors for image """ + """Compute keypoints, scores, descriptors for image""" # Shared Encoder - x = self.relu(self.conv1a(data['image'])) + x = self.relu(self.conv1a(data["image"])) x = self.relu(self.conv1b(x)) x = self.pool(x) x = self.relu(self.conv2a(x)) @@ -163,25 +169,35 @@ class SuperPoint(nn.Module): scores = torch.nn.functional.softmax(scores, 1)[:, :-1] b, _, h, w = scores.shape scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) - scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) - scores = simple_nms(scores, self.config['nms_radius']) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + scores = simple_nms(scores, self.config["nms_radius"]) # Extract keypoints keypoints = [ - torch.nonzero(s > self.config['keypoint_threshold']) - for s in scores] + torch.nonzero(s > self.config["keypoint_threshold"]) for s in scores + ] scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] # Discard keypoints near the image borders - keypoints, scores = list(zip(*[ - remove_borders(k, s, self.config['remove_borders'], h*8, w*8) - for k, s in zip(keypoints, scores)])) + keypoints, scores = list( + zip( + *[ + remove_borders(k, s, self.config["remove_borders"], h * 8, w * 8) + for k, s in zip(keypoints, scores) + ] + ) + ) # Keep the k keypoints with highest score - if self.config['max_keypoints'] >= 0: - keypoints, scores = list(zip(*[ - top_k_keypoints(k, s, self.config['max_keypoints']) - for k, s in zip(keypoints, scores)])) + if self.config["max_keypoints"] >= 0: + keypoints, scores = list( + zip( + *[ + top_k_keypoints(k, s, self.config["max_keypoints"]) + for k, s in zip(keypoints, scores) + ] + ) + ) # Convert (h, w) to (x, y) keypoints = [torch.flip(k, [1]).float() for k in keypoints] @@ -192,11 +208,13 @@ class SuperPoint(nn.Module): descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) # Extract descriptors - descriptors = [sample_descriptors(k[None], d[None], 8)[0] - for k, d in zip(keypoints, descriptors)] + descriptors = [ + 
sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, descriptors) + ] return { - 'keypoints': keypoints, - 'scores': scores, - 'descriptors': descriptors, + "keypoints": keypoints, + "scores": scores, + "descriptors": descriptors, } diff --git a/third_party/SuperGluePretrainedNetwork/models/utils.py b/third_party/SuperGluePretrainedNetwork/models/utils.py index 1206244aa2a004d9f653782de798bfef9e5e726b..d302ff84cf316f3dad016f1f23bbb54518566d2e 100644 --- a/third_party/SuperGluePretrainedNetwork/models/utils.py +++ b/third_party/SuperGluePretrainedNetwork/models/utils.py @@ -51,11 +51,12 @@ import cv2 import torch import matplotlib.pyplot as plt import matplotlib -matplotlib.use('Agg') + +matplotlib.use("Agg") class AverageTimer: - """ Class to help manage printing simple timing of code execution. """ + """Class to help manage printing simple timing of code execution.""" def __init__(self, smoothing=0.3, newline=False): self.smoothing = smoothing @@ -71,7 +72,7 @@ class AverageTimer: for name in self.will_print: self.will_print[name] = False - def update(self, name='default'): + def update(self, name="default"): now = time.time() dt = now - self.last_time if name in self.times: @@ -80,29 +81,30 @@ class AverageTimer: self.will_print[name] = True self.last_time = now - def print(self, text='Timer'): - total = 0. - print('[{}]'.format(text), end=' ') + def print(self, text="Timer"): + total = 0.0 + print("[{}]".format(text), end=" ") for key in self.times: val = self.times[key] if self.will_print[key]: - print('%s=%.3f' % (key, val), end=' ') + print("%s=%.3f" % (key, val), end=" ") total += val - print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ') + print("total=%.3f sec {%.1f FPS}" % (total, 1.0 / total), end=" ") if self.newline: print(flush=True) else: - print(end='\r', flush=True) + print(end="\r", flush=True) self.reset() class VideoStreamer: - """ Class to help process image streams. Four types of possible inputs:" - 1.) USB Webcam. - 2.) An IP camera - 3.) A directory of images (files in directory matching 'image_glob'). - 4.) A video file, such as an .mp4 or .avi file. + """Class to help process image streams. Four types of possible inputs:" + 1.) USB Webcam. + 2.) An IP camera + 3.) A directory of images (files in directory matching 'image_glob'). + 4.) A video file, such as an .mp4 or .avi file. 
""" + def __init__(self, basedir, resize, skip, image_glob, max_length=1000000): self._ip_grabbed = False self._ip_running = False @@ -119,45 +121,45 @@ class VideoStreamer: self.skip = skip self.max_length = max_length if isinstance(basedir, int) or basedir.isdigit(): - print('==> Processing USB webcam input: {}'.format(basedir)) + print("==> Processing USB webcam input: {}".format(basedir)) self.cap = cv2.VideoCapture(int(basedir)) self.listing = range(0, self.max_length) - elif basedir.startswith(('http', 'rtsp')): - print('==> Processing IP camera input: {}'.format(basedir)) + elif basedir.startswith(("http", "rtsp")): + print("==> Processing IP camera input: {}".format(basedir)) self.cap = cv2.VideoCapture(basedir) self.start_ip_camera_thread() self._ip_camera = True self.listing = range(0, self.max_length) elif Path(basedir).is_dir(): - print('==> Processing image directory input: {}'.format(basedir)) + print("==> Processing image directory input: {}".format(basedir)) self.listing = list(Path(basedir).glob(image_glob[0])) for j in range(1, len(image_glob)): image_path = list(Path(basedir).glob(image_glob[j])) self.listing = self.listing + image_path self.listing.sort() - self.listing = self.listing[::self.skip] + self.listing = self.listing[:: self.skip] self.max_length = np.min([self.max_length, len(self.listing)]) if self.max_length == 0: - raise IOError('No images found (maybe bad \'image_glob\' ?)') - self.listing = self.listing[:self.max_length] + raise IOError("No images found (maybe bad 'image_glob' ?)") + self.listing = self.listing[: self.max_length] self.camera = False elif Path(basedir).exists(): - print('==> Processing video input: {}'.format(basedir)) + print("==> Processing video input: {}".format(basedir)) self.cap = cv2.VideoCapture(basedir) self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) self.listing = range(0, num_frames) - self.listing = self.listing[::self.skip] + self.listing = self.listing[:: self.skip] self.video_file = True self.max_length = np.min([self.max_length, len(self.listing)]) - self.listing = self.listing[:self.max_length] + self.listing = self.listing[: self.max_length] else: - raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir)) + raise ValueError('VideoStreamer input "{}" not recognized.'.format(basedir)) if self.camera and not self.cap.isOpened(): - raise IOError('Could not read camera') + raise IOError("Could not read camera") def load_image(self, impath): - """ Read image as grayscale and resize to img_size. + """Read image as grayscale and resize to img_size. Inputs impath: Path to input image. Returns @@ -165,15 +167,14 @@ class VideoStreamer: """ grayim = cv2.imread(impath, 0) if grayim is None: - raise Exception('Error reading image %s' % impath) + raise Exception("Error reading image %s" % impath) w, h = grayim.shape[1], grayim.shape[0] w_new, h_new = process_resize(w, h, self.resize) - grayim = cv2.resize( - grayim, (w_new, h_new), interpolation=self.interp) + grayim = cv2.resize(grayim, (w_new, h_new), interpolation=self.interp) return grayim def next_frame(self): - """ Return the next frame, and increment internal counter. + """Return the next frame, and increment internal counter. Returns image: Next H x W image. status: True or False depending whether image was loaded. 
@@ -184,9 +185,9 @@ class VideoStreamer: if self.camera: if self._ip_camera: - #Wait for first image, making sure we haven't exited + # Wait for first image, making sure we haven't exited while self._ip_grabbed is False and self._ip_exited is False: - time.sleep(.001) + time.sleep(0.001) ret, image = self._ip_grabbed, self._ip_image.copy() if ret is False: @@ -194,15 +195,14 @@ class VideoStreamer: else: ret, image = self.cap.read() if ret is False: - print('VideoStreamer: Cannot get image from camera') + print("VideoStreamer: Cannot get image from camera") return (None, False) w, h = image.shape[1], image.shape[0] if self.video_file: self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i]) w_new, h_new = process_resize(w, h, self.resize) - image = cv2.resize(image, (w_new, h_new), - interpolation=self.interp) + image = cv2.resize(image, (w_new, h_new), interpolation=self.interp) image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) else: image_file = str(self.listing[self.i]) @@ -229,19 +229,20 @@ class VideoStreamer: self._ip_image = img self._ip_grabbed = ret self._ip_index += 1 - #print('IPCAMERA THREAD got frame {}'.format(self._ip_index)) - + # print('IPCAMERA THREAD got frame {}'.format(self._ip_index)) def cleanup(self): self._ip_running = False + # --- PREPROCESSING --- + def process_resize(w, h, resize): - assert(len(resize) > 0 and len(resize) <= 2) + assert len(resize) > 0 and len(resize) <= 2 if len(resize) == 1 and resize[0] > -1: scale = resize[0] / max(h, w) - w_new, h_new = int(round(w*scale)), int(round(h*scale)) + w_new, h_new = int(round(w * scale)), int(round(h * scale)) elif len(resize) == 1 and resize[0] == -1: w_new, h_new = w, h else: # len(resize) == 2: @@ -249,15 +250,15 @@ def process_resize(w, h, resize): # Issue warning if resolution is too small or too large. 
if max(w_new, h_new) < 160: - print('Warning: input resolution is very small, results may vary') + print("Warning: input resolution is very small, results may vary") elif max(w_new, h_new) > 2000: - print('Warning: input resolution is very large, results may vary') + print("Warning: input resolution is very large, results may vary") return w_new, h_new def frame2tensor(frame, device): - return torch.from_numpy(frame/255.).float()[None, None].to(device) + return torch.from_numpy(frame / 255.0).float()[None, None].to(device) def read_image(path, device, resize, rotation, resize_float): @@ -269,9 +270,9 @@ def read_image(path, device, resize, rotation, resize_float): scales = (float(w) / float(w_new), float(h) / float(h_new)) if resize_float: - image = cv2.resize(image.astype('float32'), (w_new, h_new)) + image = cv2.resize(image.astype("float32"), (w_new, h_new)) else: - image = cv2.resize(image, (w_new, h_new)).astype('float32') + image = cv2.resize(image, (w_new, h_new)).astype("float32") if rotation != 0: image = np.rot90(image, k=rotation) @@ -296,16 +297,15 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] E, mask = cv2.findEssentialMat( - kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, - method=cv2.RANSAC) + kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC + ) assert E is not None best_num_inliers = 0 ret = None for _E in np.split(E, len(E) / 3): - n, R, t, _ = cv2.recoverPose( - _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) if n > best_num_inliers: best_num_inliers = n ret = (R, t[:, 0], mask.ravel() > 0) @@ -315,36 +315,42 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): def rotate_intrinsics(K, image_shape, rot): """image_shape is the shape of the image after rotation""" assert rot <= 3 - h, w = image_shape[:2][::-1 if (rot % 2) else 1] + h, w = image_shape[:2][:: -1 if (rot % 2) else 1] fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] rot = rot % 4 if rot == 1: - return np.array([[fy, 0., cy], - [0., fx, w-1-cx], - [0., 0., 1.]], dtype=K.dtype) + return np.array( + [[fy, 0.0, cy], [0.0, fx, w - 1 - cx], [0.0, 0.0, 1.0]], dtype=K.dtype + ) elif rot == 2: - return np.array([[fx, 0., w-1-cx], - [0., fy, h-1-cy], - [0., 0., 1.]], dtype=K.dtype) + return np.array( + [[fx, 0.0, w - 1 - cx], [0.0, fy, h - 1 - cy], [0.0, 0.0, 1.0]], + dtype=K.dtype, + ) else: # if rot == 3: - return np.array([[fy, 0., h-1-cy], - [0., fx, cx], - [0., 0., 1.]], dtype=K.dtype) + return np.array( + [[fy, 0.0, h - 1 - cy], [0.0, fx, cx], [0.0, 0.0, 1.0]], dtype=K.dtype + ) def rotate_pose_inplane(i_T_w, rot): rotation_matrices = [ - np.array([[np.cos(r), -np.sin(r), 0., 0.], - [np.sin(r), np.cos(r), 0., 0.], - [0., 0., 1., 0.], - [0., 0., 0., 1.]], dtype=np.float32) + np.array( + [ + [np.cos(r), -np.sin(r), 0.0, 0.0], + [np.sin(r), np.cos(r), 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] ] return np.dot(rotation_matrices[rot], i_T_w) def scale_intrinsics(K, scales): - scales = np.diag([1./scales[0], 1./scales[1], 1.]) + scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0]) return np.dot(scales, K) @@ -359,24 +365,22 @@ def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1): kpts1 = to_homogeneous(kpts1) t0, t1, t2 = T_0to1[:3, 3] - t_skew = np.array([ - [0, -t2, t1], - [t2, 0, -t0], - [-t1, t0, 0] 
- ]) + t_skew = np.array([[0, -t2, t1], [t2, 0, -t0], [-t1, t0, 0]]) E = t_skew @ T_0to1[:3, :3] Ep0 = kpts0 @ E.T # N x 3 p1Ep0 = np.sum(kpts1 * Ep0, -1) # N Etp1 = kpts1 @ E # N x 3 - d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) - + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) + d = p1Ep0**2 * ( + 1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2) + + 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2) + ) return d def angle_error_mat(R1, R2): cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 - cos = np.clip(cos, -1., 1.) # numercial errors can make it out of bounds + cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds return np.rad2deg(np.abs(np.arccos(cos))) @@ -398,27 +402,27 @@ def pose_auc(errors, thresholds): sort_idx = np.argsort(errors) errors = np.array(errors.copy())[sort_idx] recall = (np.arange(len(errors)) + 1) / len(errors) - errors = np.r_[0., errors] - recall = np.r_[0., recall] + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] aucs = [] for t in thresholds: last_index = np.searchsorted(errors, t) - r = np.r_[recall[:last_index], recall[last_index-1]] + r = np.r_[recall[:last_index], recall[last_index - 1]] e = np.r_[errors[:last_index], t] - aucs.append(np.trapz(r, x=e)/t) + aucs.append(np.trapz(r, x=e) / t) return aucs # --- VISUALIZATION --- -def plot_image_pair(imgs, dpi=100, size=6, pad=.5): +def plot_image_pair(imgs, dpi=100, size=6, pad=0.5): n = len(imgs) - assert n == 2, 'number of images must be two' - figsize = (size*n, size*3/4) if size is not None else None + assert n == 2, "number of images must be two" + figsize = (size * n, size * 3 / 4) if size is not None else None _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi) for i in range(n): - ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255) + ax[i].imshow(imgs[i], cmap=plt.get_cmap("gray"), vmin=0, vmax=255) ax[i].get_yaxis().set_ticks([]) ax[i].get_xaxis().set_ticks([]) for spine in ax[i].spines.values(): # remove frame @@ -426,7 +430,7 @@ def plot_image_pair(imgs, dpi=100, size=6, pad=.5): plt.tight_layout(pad=pad) -def plot_keypoints(kpts0, kpts1, color='w', ps=2): +def plot_keypoints(kpts0, kpts1, color="w", ps=2): ax = plt.gcf().axes ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) @@ -441,59 +445,116 @@ def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4): fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0)) fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1)) - fig.lines = [matplotlib.lines.Line2D( - (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1, - transform=fig.transFigure, c=color[i], linewidth=lw) - for i in range(len(kpts0))] + fig.lines = [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=color[i], + linewidth=lw, + ) + for i in range(len(kpts0)) + ] ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) -def make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1, - color, text, path, show_keypoints=False, - fast_viz=False, opencv_display=False, - opencv_title='matches', small_text=[]): +def make_matching_plot( + image0, + image1, + kpts0, + kpts1, + mkpts0, + mkpts1, + color, + text, + path, + show_keypoints=False, + fast_viz=False, + opencv_display=False, + opencv_title="matches", + small_text=[], +): if fast_viz: - make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, mkpts1, - color, text, path, 
show_keypoints, 10, - opencv_display, opencv_title, small_text) + make_matching_plot_fast( + image0, + image1, + kpts0, + kpts1, + mkpts0, + mkpts1, + color, + text, + path, + show_keypoints, + 10, + opencv_display, + opencv_title, + small_text, + ) return plot_image_pair([image0, image1]) if show_keypoints: - plot_keypoints(kpts0, kpts1, color='k', ps=4) - plot_keypoints(kpts0, kpts1, color='w', ps=2) + plot_keypoints(kpts0, kpts1, color="k", ps=4) + plot_keypoints(kpts0, kpts1, color="w", ps=2) plot_matches(mkpts0, mkpts1, color) fig = plt.gcf() - txt_color = 'k' if image0[:100, :150].mean() > 200 else 'w' + txt_color = "k" if image0[:100, :150].mean() > 200 else "w" fig.text( - 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, - fontsize=15, va='top', ha='left', color=txt_color) - - txt_color = 'k' if image0[-100:, :150].mean() > 200 else 'w' + 0.01, + 0.99, + "\n".join(text), + transform=fig.axes[0].transAxes, + fontsize=15, + va="top", + ha="left", + color=txt_color, + ) + + txt_color = "k" if image0[-100:, :150].mean() > 200 else "w" fig.text( - 0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes, - fontsize=5, va='bottom', ha='left', color=txt_color) - - plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + 0.01, + 0.01, + "\n".join(small_text), + transform=fig.axes[0].transAxes, + fontsize=5, + va="bottom", + ha="left", + color=txt_color, + ) + + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) plt.close() -def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, - mkpts1, color, text, path=None, - show_keypoints=False, margin=10, - opencv_display=False, opencv_title='', - small_text=[]): +def make_matching_plot_fast( + image0, + image1, + kpts0, + kpts1, + mkpts0, + mkpts1, + color, + text, + path=None, + show_keypoints=False, + margin=10, + opencv_display=False, + opencv_title="", + small_text=[], +): H0, W0 = image0.shape H1, W1 = image1.shape H, W = max(H0, H1), W0 + W1 + margin - out = 255*np.ones((H, W), np.uint8) + out = 255 * np.ones((H, W), np.uint8) out[:H0, :W0] = image0 - out[:H1, W0+margin:] = image1 - out = np.stack([out]*3, -1) + out[:H1, W0 + margin :] = image1 + out = np.stack([out] * 3, -1) if show_keypoints: kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int) @@ -503,42 +564,77 @@ def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, cv2.circle(out, (x, y), 2, black, -1, lineType=cv2.LINE_AA) cv2.circle(out, (x, y), 1, white, -1, lineType=cv2.LINE_AA) for x, y in kpts1: - cv2.circle(out, (x + margin + W0, y), 2, black, -1, - lineType=cv2.LINE_AA) - cv2.circle(out, (x + margin + W0, y), 1, white, -1, - lineType=cv2.LINE_AA) + cv2.circle(out, (x + margin + W0, y), 2, black, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x + margin + W0, y), 1, white, -1, lineType=cv2.LINE_AA) mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int) - color = (np.array(color[:, :3])*255).astype(int)[:, ::-1] + color = (np.array(color[:, :3]) * 255).astype(int)[:, ::-1] for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color): c = c.tolist() - cv2.line(out, (x0, y0), (x1 + margin + W0, y1), - color=c, thickness=1, lineType=cv2.LINE_AA) + cv2.line( + out, + (x0, y0), + (x1 + margin + W0, y1), + color=c, + thickness=1, + lineType=cv2.LINE_AA, + ) # display line end-points as circles cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA) - cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1, - lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1, lineType=cv2.LINE_AA) # 
Scale factor for consistent visualization across scales. - sc = min(H / 640., 2.0) + sc = min(H / 640.0, 2.0) # Big text. Ht = int(30 * sc) # text height txt_color_fg = (255, 255, 255) txt_color_bg = (0, 0, 0) for i, t in enumerate(text): - cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, - 1.0*sc, txt_color_bg, 2, cv2.LINE_AA) - cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, - 1.0*sc, txt_color_fg, 1, cv2.LINE_AA) + cv2.putText( + out, + t, + (int(8 * sc), Ht * (i + 1)), + cv2.FONT_HERSHEY_DUPLEX, + 1.0 * sc, + txt_color_bg, + 2, + cv2.LINE_AA, + ) + cv2.putText( + out, + t, + (int(8 * sc), Ht * (i + 1)), + cv2.FONT_HERSHEY_DUPLEX, + 1.0 * sc, + txt_color_fg, + 1, + cv2.LINE_AA, + ) # Small text. Ht = int(18 * sc) # text height for i, t in enumerate(reversed(small_text)): - cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, - 0.5*sc, txt_color_bg, 2, cv2.LINE_AA) - cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, - 0.5*sc, txt_color_fg, 1, cv2.LINE_AA) + cv2.putText( + out, + t, + (int(8 * sc), int(H - Ht * (i + 0.6))), + cv2.FONT_HERSHEY_DUPLEX, + 0.5 * sc, + txt_color_bg, + 2, + cv2.LINE_AA, + ) + cv2.putText( + out, + t, + (int(8 * sc), int(H - Ht * (i + 0.6))), + cv2.FONT_HERSHEY_DUPLEX, + 0.5 * sc, + txt_color_fg, + 1, + cv2.LINE_AA, + ) if path is not None: cv2.imwrite(str(path), out) @@ -552,4 +648,5 @@ def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, def error_colormap(x): return np.clip( - np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)], -1), 0, 1) + np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x)], -1), 0, 1 + ) diff --git a/third_party/TopicFM/configs/data/base.py b/third_party/TopicFM/configs/data/base.py index 6cab7e67019a6fee2657c1a28609c8aca5b2a1d8..1897a84393e186cc46f34fe856243756e8393a2a 100644 --- a/third_party/TopicFM/configs/data/base.py +++ b/third_party/TopicFM/configs/data/base.py @@ -4,6 +4,7 @@ Setups in data configs will override all existed setups! 
""" from yacs.config import CfgNode as CN + _CN = CN() _CN.DATASET = CN() _CN.TRAINER = CN() diff --git a/third_party/TopicFM/configs/data/megadepth_trainval.py b/third_party/TopicFM/configs/data/megadepth_trainval.py index 215b5c34cc41d36aa4444a58ca0cb69afbc11952..7b7b0a77e26bbf6e7b7ceb2cd54f8c2e3b709db4 100644 --- a/third_party/TopicFM/configs/data/megadepth_trainval.py +++ b/third_party/TopicFM/configs/data/megadepth_trainval.py @@ -11,9 +11,13 @@ cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0 TEST_BASE_PATH = "data/megadepth/index" cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" -cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500" -cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" -cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val +cfg.DATASET.VAL_NPZ_ROOT = ( + cfg.DATASET.TEST_NPZ_ROOT +) = f"{TEST_BASE_PATH}/scene_info_val_1500" +cfg.DATASET.VAL_LIST_PATH = ( + cfg.DATASET.TEST_LIST_PATH +) = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val # 368 scenes in total for MegaDepth # (with difficulty balanced (further split each scene to 3 sub-scenes)) diff --git a/third_party/TopicFM/configs/model/outdoor/model_ds.py b/third_party/TopicFM/configs/model/outdoor/model_ds.py index 2c090edbfbdcd66cea225c39af6f62da8feb50b9..e0c234e8b3c932656052aa58836ed2b158344fb5 100644 --- a/third_party/TopicFM/configs/model/outdoor/model_ds.py +++ b/third_party/TopicFM/configs/model/outdoor/model_ds.py @@ -1,6 +1,6 @@ from src.config.default import _CN as cfg -cfg.MODEL.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +cfg.MODEL.MATCH_COARSE.MATCH_TYPE = "dual_softmax" cfg.MODEL.COARSE.N_SAMPLES = 8 cfg.TRAINER.CANONICAL_LR = 1e-2 diff --git a/third_party/TopicFM/flop_counter.py b/third_party/TopicFM/flop_counter.py index ea87fa0139897434ca52b369450aa82203311181..915f703bd76146e54a3f2f9e819a7b1b85f2d700 100644 --- a/third_party/TopicFM/flop_counter.py +++ b/third_party/TopicFM/flop_counter.py @@ -27,7 +27,7 @@ def coarse_model_flops(coarse_model, config, inputs): return flops.total() / 1e9 -if __name__ == '__main__': +if __name__ == "__main__": path_img0 = "assets/scannet_sample_images/scene0711_00_frame-001680.jpg" path_img1 = "assets/scannet_sample_images/scene0711_00_frame-001995.jpg" img0, img1 = read_scannet_gray(path_img0), read_scannet_gray(path_img1) @@ -35,21 +35,48 @@ if __name__ == '__main__': # LoFTR loftr_conf = dict(default_cfg) - feat_c0, loftr_featnet_flops0 = feat_net_flops(loftr_featnet, loftr_conf["resnetfpn"], img0) - feat_c1, loftr_featnet_flops1 = feat_net_flops(loftr_featnet, loftr_conf["resnetfpn"], img1) - print("FLOPs of feature extraction in LoFTR: {} GFLOPs".format((loftr_featnet_flops0 + loftr_featnet_flops1)/2)) - feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c') - feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c') - loftr_coarse_model_flops = coarse_model_flops(LocalFeatureTransformer, loftr_conf["coarse"], (feat_c0, feat_c1)) - print("FLOPs of coarse matching model in LoFTR: {} GFLOPs".format(loftr_coarse_model_flops)) + feat_c0, loftr_featnet_flops0 = feat_net_flops( + loftr_featnet, loftr_conf["resnetfpn"], img0 + ) + feat_c1, loftr_featnet_flops1 = feat_net_flops( + loftr_featnet, loftr_conf["resnetfpn"], img1 + ) + print( + "FLOPs of feature extraction in LoFTR: {} GFLOPs".format( + (loftr_featnet_flops0 + 
loftr_featnet_flops1) / 2 + ) + ) + feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c") + feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c") + loftr_coarse_model_flops = coarse_model_flops( + LocalFeatureTransformer, loftr_conf["coarse"], (feat_c0, feat_c1) + ) + print( + "FLOPs of coarse matching model in LoFTR: {} GFLOPs".format( + loftr_coarse_model_flops + ) + ) # TopicFM topicfm_conf = get_model_cfg() - feat_c0, topicfm_featnet_flops0 = feat_net_flops(topicfm_featnet, topicfm_conf["fpn"], img0) - feat_c1, topicfm_featnet_flops1 = feat_net_flops(topicfm_featnet, topicfm_conf["fpn"], img1) - print("FLOPs of feature extraction in TopicFM: {} GFLOPs".format((topicfm_featnet_flops0 + topicfm_featnet_flops1) / 2)) - feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c') - feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c') - topicfm_coarse_model_flops = coarse_model_flops(TopicFormer, topicfm_conf["coarse"], (feat_c0, feat_c1)) - print("FLOPs of coarse matching model in TopicFM: {} GFLOPs".format(topicfm_coarse_model_flops)) - + feat_c0, topicfm_featnet_flops0 = feat_net_flops( + topicfm_featnet, topicfm_conf["fpn"], img0 + ) + feat_c1, topicfm_featnet_flops1 = feat_net_flops( + topicfm_featnet, topicfm_conf["fpn"], img1 + ) + print( + "FLOPs of feature extraction in TopicFM: {} GFLOPs".format( + (topicfm_featnet_flops0 + topicfm_featnet_flops1) / 2 + ) + ) + feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c") + feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c") + topicfm_coarse_model_flops = coarse_model_flops( + TopicFormer, topicfm_conf["coarse"], (feat_c0, feat_c1) + ) + print( + "FLOPs of coarse matching model in TopicFM: {} GFLOPs".format( + topicfm_coarse_model_flops + ) + ) diff --git a/third_party/TopicFM/src/__init__.py b/third_party/TopicFM/src/__init__.py index 30caef94f911f99e0c12510d8181b3c1537daf1a..aa7ba68e1b8fa7c7854ca49680c07d54d468d83e 100644 --- a/third_party/TopicFM/src/__init__.py +++ b/third_party/TopicFM/src/__init__.py @@ -1,11 +1,13 @@ from yacs.config import CfgNode from .config.default import _CN + def lower_config(yacs_cfg): if not isinstance(yacs_cfg, CfgNode): return yacs_cfg return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + def get_model_cfg(): cfg = lower_config(lower_config(_CN)) - return cfg["model"] \ No newline at end of file + return cfg["model"] diff --git a/third_party/TopicFM/src/config/default.py b/third_party/TopicFM/src/config/default.py index 591558b3f358cdce0e9e72e94acba702b2a4e896..a252b1a13952480b5c22e50d6b90432f5a328112 100644 --- a/third_party/TopicFM/src/config/default.py +++ b/third_party/TopicFM/src/config/default.py @@ -1,9 +1,10 @@ from yacs.config import CfgNode as CN + _CN = CN() ############## ↓ MODEL Pipeline ↓ ############## _CN.MODEL = CN() -_CN.MODEL.BACKBONE_TYPE = 'FPN' +_CN.MODEL.BACKBONE_TYPE = "FPN" _CN.MODEL.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] _CN.MODEL.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd _CN.MODEL.FINE_CONCAT_COARSE_FEAT = False @@ -18,8 +19,8 @@ _CN.MODEL.COARSE = CN() _CN.MODEL.COARSE.D_MODEL = 256 _CN.MODEL.COARSE.D_FFN = 256 _CN.MODEL.COARSE.NHEAD = 8 -_CN.MODEL.COARSE.LAYER_NAMES = ['seed', 'seed', 'seed', 'seed', 'seed'] -_CN.MODEL.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] +_CN.MODEL.COARSE.LAYER_NAMES = ["seed", "seed", "seed", "seed", "seed"] +_CN.MODEL.COARSE.ATTENTION = "linear" # options: ['linear', 'full'] _CN.MODEL.COARSE.TEMP_BUG_FIX = True _CN.MODEL.COARSE.N_TOPICS = 100 _CN.MODEL.COARSE.N_SAMPLES = 6 @@ -29,7 +30,7 @@ 
_CN.MODEL.COARSE.N_TOPIC_TRANSFORMERS = 1 _CN.MODEL.MATCH_COARSE = CN() _CN.MODEL.MATCH_COARSE.THR = 0.2 _CN.MODEL.MATCH_COARSE.BORDER_RM = 2 -_CN.MODEL.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +_CN.MODEL.MATCH_COARSE.MATCH_TYPE = "dual_softmax" _CN.MODEL.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 _CN.MODEL.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory _CN.MODEL.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock @@ -40,8 +41,8 @@ _CN.MODEL.FINE = CN() _CN.MODEL.FINE.D_MODEL = 128 _CN.MODEL.FINE.D_FFN = 128 _CN.MODEL.FINE.NHEAD = 4 -_CN.MODEL.FINE.LAYER_NAMES = ['cross'] * 1 -_CN.MODEL.FINE.ATTENTION = 'linear' +_CN.MODEL.FINE.LAYER_NAMES = ["cross"] * 1 +_CN.MODEL.FINE.ATTENTION = "linear" _CN.MODEL.FINE.N_TOPICS = 1 # 5. MODEL Losses @@ -57,7 +58,7 @@ _CN.MODEL.LOSS.NEG_WEIGHT = 1.0 # use `_CN.MODEL.MATCH_COARSE.MATCH_TYPE` # -- # fine-level -_CN.MODEL.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2'] +_CN.MODEL.LOSS.FINE_TYPE = "l2_with_std" # ['l2_with_std', 'l2'] _CN.MODEL.LOSS.FINE_WEIGHT = 1.0 _CN.MODEL.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window) @@ -75,25 +76,33 @@ _CN.DATASET.TRAIN_INTRINSIC_PATH = None _CN.DATASET.VAL_DATA_ROOT = None _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) _CN.DATASET.VAL_NPZ_ROOT = None -_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file +_CN.DATASET.VAL_LIST_PATH = ( + None # None if val data from all scenes are bundled into a single npz file +) _CN.DATASET.VAL_INTRINSIC_PATH = None # testing _CN.DATASET.TEST_DATA_SOURCE = None _CN.DATASET.TEST_DATA_ROOT = None _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) _CN.DATASET.TEST_NPZ_ROOT = None -_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file +_CN.DATASET.TEST_LIST_PATH = ( + None # None if test data from all scenes are bundled into a single npz file +) _CN.DATASET.TEST_INTRINSIC_PATH = None _CN.DATASET.TEST_IMGSIZE = None # 2. dataset config # general options -_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = ( + 0.4 # discard data with overlap_score < min_overlap_score +) _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile'] # MegaDepth options -_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. +_CN.DATASET.MGDPT_IMG_RESIZE = ( + 640 # resize the longer side, zero-pad bottom-right to square. +) _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 _CN.DATASET.MGDPT_DF = 8 @@ -109,17 +118,17 @@ _CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning # optimizer _CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] _CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime -_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam +_CN.TRAINER.ADAM_DECAY = 0.0 # ADAM: for adam _CN.TRAINER.ADAMW_DECAY = 0.01 # step-based warm-up -_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] -_CN.TRAINER.WARMUP_RATIO = 0. 
+_CN.TRAINER.WARMUP_TYPE = "linear" # [linear, constant] +_CN.TRAINER.WARMUP_RATIO = 0.0 _CN.TRAINER.WARMUP_STEP = 4800 # learning rate scheduler -_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR] -_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] +_CN.TRAINER.SCHEDULER = "MultiStepLR" # [MultiStepLR, CosineAnnealing, ExponentialLR] +_CN.TRAINER.SCHEDULER_INTERVAL = "epoch" # [epoch, step] _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR _CN.TRAINER.MSLR_GAMMA = 0.5 _CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing @@ -127,25 +136,33 @@ _CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' in # plotting related _CN.TRAINER.ENABLE_PLOTTING = True -_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test paris for plotting -_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence'] -_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic' +_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test paris for plotting +_CN.TRAINER.PLOT_MODE = "evaluation" # ['evaluation', 'confidence'] +_CN.TRAINER.PLOT_MATCHES_ALPHA = "dynamic" # geometric metrics and pose solver -_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) -_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] -_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] +_CN.TRAINER.EPI_ERR_THR = ( + 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) +) +_CN.TRAINER.POSE_GEO_MODEL = "E" # ['E', 'F', 'H'] +_CN.TRAINER.POSE_ESTIMATION_METHOD = "RANSAC" # [RANSAC, DEGENSAC, MAGSAC] _CN.TRAINER.RANSAC_PIXEL_THR = 0.5 _CN.TRAINER.RANSAC_CONF = 0.99999 _CN.TRAINER.RANSAC_MAX_ITERS = 10000 _CN.TRAINER.USE_MAGSACPP = False # data sampler for train_dataloader -_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] +_CN.TRAINER.DATA_SAMPLER = ( + "scene_balance" # options: ['scene_balance', 'random', 'normal'] +) # 'scene_balance' config _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200 -_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not -_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not +_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = ( + True # whether sample each scene with replacement or not +) +_CN.TRAINER.SB_SUBSET_SHUFFLE = ( + True # after sampling from scenes, whether shuffle within the epoch or not +) _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data # 'random' config _CN.TRAINER.RDM_REPLACEMENT = True diff --git a/third_party/TopicFM/src/datasets/aachen.py b/third_party/TopicFM/src/datasets/aachen.py index ebfeee4dbfbd78770976ec027ceee8ef333a4574..71f2dd18855f3536a5159e7f420044d6536d960b 100644 --- a/third_party/TopicFM/src/datasets/aachen.py +++ b/third_party/TopicFM/src/datasets/aachen.py @@ -9,7 +9,7 @@ class AachenDataset(Dataset): self.img_path = img_path self.img_resize = img_resize self.down_factor = down_factor - with open(match_list_path, 'r') as f: + with open(match_list_path, "r") as f: self.raw_pairs = f.readlines() print("number of matching pairs: ", len(self.raw_pairs)) @@ -18,12 +18,20 @@ class AachenDataset(Dataset): def __getitem__(self, idx): raw_pair = self.raw_pairs[idx] - image_name0, image_name1 = raw_pair.strip('\n').split(' ') + image_name0, image_name1 = raw_pair.strip("\n").split(" ") path_img0 = os.path.join(self.img_path, image_name0) path_img1 = os.path.join(self.img_path, image_name1) - img0, 
scale0 = read_img_gray(path_img0, resize=self.img_resize, down_factor=self.down_factor) - img1, scale1 = read_img_gray(path_img1, resize=self.img_resize, down_factor=self.down_factor) - return {"image0": img0, "image1": img1, - "scale0": scale0, "scale1": scale1, - "pair_names": (image_name0, image_name1), - "dataset_name": "AachenDayNight"} \ No newline at end of file + img0, scale0 = read_img_gray( + path_img0, resize=self.img_resize, down_factor=self.down_factor + ) + img1, scale1 = read_img_gray( + path_img1, resize=self.img_resize, down_factor=self.down_factor + ) + return { + "image0": img0, + "image1": img1, + "scale0": scale0, + "scale1": scale1, + "pair_names": (image_name0, image_name1), + "dataset_name": "AachenDayNight", + } diff --git a/third_party/TopicFM/src/datasets/custom_dataloader.py b/third_party/TopicFM/src/datasets/custom_dataloader.py index 46d55d4f4d56d2c96cd42b6597834f945a5eb20d..eb3bd7a083baf5d0a1e8a9a21b97a08dcc22f163 100644 --- a/third_party/TopicFM/src/datasets/custom_dataloader.py +++ b/third_party/TopicFM/src/datasets/custom_dataloader.py @@ -28,99 +28,124 @@ class TestDataLoader(DataLoader): # 2. dataset config # general options - self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score + self.min_overlap_score_test = ( + config.DATASET.MIN_OVERLAP_SCORE_TEST + ) # 0.4, omit data with overlap_score < min_overlap_score # MegaDepth options - if dataset_name == 'megadepth': + if dataset_name == "megadepth": self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 800 self.mgdpt_img_pad = True self.mgdpt_depth_pad = True self.mgdpt_df = 8 self.coarse_scale = 0.125 - if dataset_name == 'scannet': + if dataset_name == "scannet": self.img_resize = config.DATASET.TEST_IMGSIZE - if (dataset_name == 'megadepth') or (dataset_name == 'scannet'): + if (dataset_name == "megadepth") or (dataset_name == "scannet"): test_dataset = self._setup_dataset( self.test_data_root, self.test_npz_root, self.test_list_path, self.test_intrinsic_path, - mode='test', + mode="test", min_overlap_score=self.min_overlap_score_test, - pose_dir=self.test_pose_root) - elif dataset_name == 'aachen_v1.1': - test_dataset = AachenDataset(self.test_data_root, self.test_list_path, - img_resize=config.DATASET.TEST_IMGSIZE) - elif dataset_name == 'inloc': - test_dataset = InLocDataset(self.test_data_root, self.test_list_path, - img_resize=config.DATASET.TEST_IMGSIZE) + pose_dir=self.test_pose_root, + ) + elif dataset_name == "aachen_v1.1": + test_dataset = AachenDataset( + self.test_data_root, + self.test_list_path, + img_resize=config.DATASET.TEST_IMGSIZE, + ) + elif dataset_name == "inloc": + test_dataset = InLocDataset( + self.test_data_root, + self.test_list_path, + img_resize=config.DATASET.TEST_IMGSIZE, + ) else: raise "unknown dataset" self.test_loader_params = { - 'batch_size': 1, - 'shuffle': False, - 'num_workers': 4, - 'pin_memory': True + "batch_size": 1, + "shuffle": False, + "num_workers": 4, + "pin_memory": True, } # sampler = Seq(self.test_dataset, shuffle=False) super(TestDataLoader, self).__init__(test_dataset, **self.test_loader_params) - def _setup_dataset(self, - data_root, - split_npz_root, - scene_list_path, - intri_path, - mode='train', - min_overlap_score=0., - pose_dir=None): - """ Setup train / val / test set""" - with open(scene_list_path, 'r') as f: + def _setup_dataset( + self, + data_root, + split_npz_root, + scene_list_path, + intri_path, + mode="train", + min_overlap_score=0.0, + pose_dir=None, + ): + 
"""Setup train / val / test set""" + with open(scene_list_path, "r") as f: npz_names = [name.split()[0] for name in f.readlines()] local_npz_names = npz_names - return self._build_concat_dataset(data_root, local_npz_names, split_npz_root, intri_path, - mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir) + return self._build_concat_dataset( + data_root, + local_npz_names, + split_npz_root, + intri_path, + mode=mode, + min_overlap_score=min_overlap_score, + pose_dir=pose_dir, + ) def _build_concat_dataset( - self, - data_root, - npz_names, - npz_dir, - intrinsic_path, - mode, - min_overlap_score=0., - pose_dir=None + self, + data_root, + npz_names, + npz_dir, + intrinsic_path, + mode, + min_overlap_score=0.0, + pose_dir=None, ): datasets = [] # augment_fn = self.augment_fn if mode == 'train' else None data_source = self.test_data_source - if str(data_source).lower() == 'megadepth': - npz_names = [f'{n}.npz' for n in npz_names] + if str(data_source).lower() == "megadepth": + npz_names = [f"{n}.npz" for n in npz_names] for npz_name in tqdm(npz_names): # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time. npz_path = osp.join(npz_dir, npz_name) - if data_source == 'ScanNet': + if data_source == "ScanNet": datasets.append( - ScanNetDataset(data_root, - npz_path, - intrinsic_path, - mode=mode, img_resize=self.img_resize, - min_overlap_score=min_overlap_score, - pose_dir=pose_dir)) - elif data_source == 'MegaDepth': + ScanNetDataset( + data_root, + npz_path, + intrinsic_path, + mode=mode, + img_resize=self.img_resize, + min_overlap_score=min_overlap_score, + pose_dir=pose_dir, + ) + ) + elif data_source == "MegaDepth": datasets.append( - MegaDepthDataset(data_root, - npz_path, - mode=mode, - min_overlap_score=min_overlap_score, - img_resize=self.mgdpt_img_resize, - df=self.mgdpt_df, - img_padding=self.mgdpt_img_pad, - depth_padding=self.mgdpt_depth_pad, - coarse_scale=self.coarse_scale)) + MegaDepthDataset( + data_root, + npz_path, + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + coarse_scale=self.coarse_scale, + ) + ) else: raise NotImplementedError() return ConcatDataset(datasets) diff --git a/third_party/TopicFM/src/datasets/inloc.py b/third_party/TopicFM/src/datasets/inloc.py index 5421099d11b4dbbea8c09568c493d844d5c6a1b0..dc176761b7626aafd90e9674c5d85ff6e95f537c 100644 --- a/third_party/TopicFM/src/datasets/inloc.py +++ b/third_party/TopicFM/src/datasets/inloc.py @@ -9,7 +9,7 @@ class InLocDataset(Dataset): self.img_path = img_path self.img_resize = img_resize self.down_factor = down_factor - with open(match_list_path, 'r') as f: + with open(match_list_path, "r") as f: self.raw_pairs = f.readlines() print("number of matching pairs: ", len(self.raw_pairs)) @@ -18,12 +18,20 @@ class InLocDataset(Dataset): def __getitem__(self, idx): raw_pair = self.raw_pairs[idx] - image_name0, image_name1 = raw_pair.strip('\n').split(' ') + image_name0, image_name1 = raw_pair.strip("\n").split(" ") path_img0 = os.path.join(self.img_path, image_name0) path_img1 = os.path.join(self.img_path, image_name1) - img0, scale0 = read_img_gray(path_img0, resize=self.img_resize, down_factor=self.down_factor) - img1, scale1 = read_img_gray(path_img1, resize=self.img_resize, down_factor=self.down_factor) - return {"image0": img0, "image1": img1, - "scale0": scale0, "scale1": scale1, - "pair_names": (image_name0, image_name1), - 
"dataset_name": "InLoc"} \ No newline at end of file + img0, scale0 = read_img_gray( + path_img0, resize=self.img_resize, down_factor=self.down_factor + ) + img1, scale1 = read_img_gray( + path_img1, resize=self.img_resize, down_factor=self.down_factor + ) + return { + "image0": img0, + "image1": img1, + "scale0": scale0, + "scale1": scale1, + "pair_names": (image_name0, image_name1), + "dataset_name": "InLoc", + } diff --git a/third_party/TopicFM/src/datasets/megadepth.py b/third_party/TopicFM/src/datasets/megadepth.py index e92768e72e373c2a8ebeaf1158f9710fb1bfb5f1..77516327ebed8ca4ea8be9692a7077d94f03ee5b 100644 --- a/third_party/TopicFM/src/datasets/megadepth.py +++ b/third_party/TopicFM/src/datasets/megadepth.py @@ -9,20 +9,22 @@ from src.utils.dataset import read_megadepth_gray, read_megadepth_depth class MegaDepthDataset(Dataset): - def __init__(self, - root_dir, - npz_path, - mode='train', - min_overlap_score=0.4, - img_resize=None, - df=None, - img_padding=False, - depth_padding=False, - augment_fn=None, - **kwargs): + def __init__( + self, + root_dir, + npz_path, + mode="train", + min_overlap_score=0.4, + img_resize=None, + df=None, + img_padding=False, + depth_padding=False, + augment_fn=None, + **kwargs + ): """ Manage one scene(npz_path) of MegaDepth dataset. - + Args: root_dir (str): megadepth root directory that has `phoenix`. npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. @@ -38,30 +40,38 @@ class MegaDepthDataset(Dataset): super().__init__() self.root_dir = root_dir self.mode = mode - self.scene_id = npz_path.split('.')[0] + self.scene_id = npz_path.split(".")[0] # prepare scene_info and pair_info - if mode == 'test' and min_overlap_score != 0: - logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.") + if mode == "test" and min_overlap_score != 0: + logger.warning( + "You are using `min_overlap_score`!=0 in test mode. Set to 0." + ) min_overlap_score = 0 self.scene_info = np.load(npz_path, allow_pickle=True) - self.pair_infos = self.scene_info['pair_infos'].copy() - del self.scene_info['pair_infos'] - self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] + self.pair_infos = self.scene_info["pair_infos"].copy() + del self.scene_info["pair_infos"] + self.pair_infos = [ + pair_info + for pair_info in self.pair_infos + if pair_info[1] > min_overlap_score + ] # parameters for image resizing, padding and depthmap padding - if mode == 'train': + if mode == "train": assert img_resize is not None and img_padding and depth_padding self.img_resize = img_resize - if mode == 'val': + if mode == "val": self.img_resize = 864 self.df = df self.img_padding = img_padding - self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. + self.depth_max_size = ( + 2000 if depth_padding else None + ) # the upperbound of depthmaps size in megadepth. # for training LoFTR - self.augment_fn = augment_fn if mode == 'train' else None - self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) + self.augment_fn = augment_fn if mode == "train" else None + self.coarse_scale = getattr(kwargs, "coarse_scale", 0.125) def __len__(self): return len(self.pair_infos) @@ -70,60 +80,77 @@ class MegaDepthDataset(Dataset): (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx] # read grayscale image and mask. 
(1, h, w) and (h, w) - img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) - img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) - + img_name0 = osp.join(self.root_dir, self.scene_info["image_paths"][idx0]) + img_name1 = osp.join(self.root_dir, self.scene_info["image_paths"][idx1]) + # TODO: Support augmentation & handle seeds for each worker correctly. image0, mask0, scale0 = read_megadepth_gray( - img_name0, self.img_resize, self.df, self.img_padding, None) - # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + img_name0, self.img_resize, self.df, self.img_padding, None + ) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) image1, mask1, scale1 = read_megadepth_gray( - img_name1, self.img_resize, self.df, self.img_padding, None) - # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + img_name1, self.img_resize, self.df, self.img_padding, None + ) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) # read depth. shape: (h, w) - if self.mode in ['train', 'val']: + if self.mode in ["train", "val"]: depth0 = read_megadepth_depth( - osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) + osp.join(self.root_dir, self.scene_info["depth_paths"][idx0]), + pad_to=self.depth_max_size, + ) depth1 = read_megadepth_depth( - osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) + osp.join(self.root_dir, self.scene_info["depth_paths"][idx1]), + pad_to=self.depth_max_size, + ) else: depth0 = depth1 = torch.tensor([]) # read intrinsics of original size - K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) - K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) + K_0 = torch.tensor( + self.scene_info["intrinsics"][idx0].copy(), dtype=torch.float + ).reshape(3, 3) + K_1 = torch.tensor( + self.scene_info["intrinsics"][idx1].copy(), dtype=torch.float + ).reshape(3, 3) # read and compute relative poses - T0 = self.scene_info['poses'][idx0] - T1 = self.scene_info['poses'][idx1] - T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) + T0 = self.scene_info["poses"][idx0] + T1 = self.scene_info["poses"][idx1] + T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[ + :4, :4 + ] # (4, 4) T_1to0 = T_0to1.inverse() data = { - 'image0': image0, # (1, h, w) - 'depth0': depth0, # (h, w) - 'image1': image1, - 'depth1': depth1, - 'T_0to1': T_0to1, # (4, 4) - 'T_1to0': T_1to0, - 'K0': K_0, # (3, 3) - 'K1': K_1, - 'scale0': scale0, # [scale_w, scale_h] - 'scale1': scale1, - 'dataset_name': 'MegaDepth', - 'scene_id': self.scene_id, - 'pair_id': idx, - 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), + "image0": image0, # (1, h, w) + "depth0": depth0, # (h, w) + "image1": image1, + "depth1": depth1, + "T_0to1": T_0to1, # (4, 4) + "T_1to0": T_1to0, + "K0": K_0, # (3, 3) + "K1": K_1, + "scale0": scale0, # [scale_w, scale_h] + "scale1": scale1, + "dataset_name": "MegaDepth", + "scene_id": self.scene_id, + "pair_id": idx, + "pair_names": ( + self.scene_info["image_paths"][idx0], + self.scene_info["image_paths"][idx1], + ), } # for LoFTR training if mask0 is not None: # img_padding is True if self.coarse_scale: - [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), - scale_factor=self.coarse_scale, - mode='nearest', - recompute_scale_factor=False)[0].bool() - 
data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + [ts_mask_0, ts_mask_1] = F.interpolate( + torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.coarse_scale, + mode="nearest", + recompute_scale_factor=False, + )[0].bool() + data.update({"mask0": ts_mask_0, "mask1": ts_mask_1}) return data diff --git a/third_party/TopicFM/src/datasets/sampler.py b/third_party/TopicFM/src/datasets/sampler.py index 81b6f435645632a013476f9a665a0861ab7fcb61..131111c4cf69cd8770058dfac2be717aa183978e 100644 --- a/third_party/TopicFM/src/datasets/sampler.py +++ b/third_party/TopicFM/src/datasets/sampler.py @@ -3,10 +3,10 @@ from torch.utils.data import Sampler, ConcatDataset class RandomConcatSampler(Sampler): - """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset + """Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase. - + For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not. Args: shuffle (bool): shuffle the random sampled indices across all sub-datsets. @@ -18,16 +18,19 @@ class RandomConcatSampler(Sampler): TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs. ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 """ - def __init__(self, - data_source: ConcatDataset, - n_samples_per_subset: int, - subset_replacement: bool=True, - shuffle: bool=True, - repeat: int=1, - seed: int=None): + + def __init__( + self, + data_source: ConcatDataset, + n_samples_per_subset: int, + subset_replacement: bool = True, + shuffle: bool = True, + repeat: int = 1, + seed: int = None, + ): if not isinstance(data_source, ConcatDataset): raise TypeError("data_source should be torch.utils.data.ConcatDataset") - + self.data_source = data_source self.n_subset = len(self.data_source.datasets) self.n_samples_per_subset = n_samples_per_subset @@ -37,27 +40,37 @@ class RandomConcatSampler(Sampler): self.shuffle = shuffle self.generator = torch.manual_seed(seed) assert self.repeat >= 1 - + def __len__(self): return self.n_samples - + def __iter__(self): indices = [] # sample from each sub-dataset for d_idx in range(self.n_subset): - low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1] + low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1] high = self.data_source.cumulative_sizes[d_idx] if self.subset_replacement: - rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ), - generator=self.generator, dtype=torch.int64) + rand_tensor = torch.randint( + low, + high, + (self.n_samples_per_subset,), + generator=self.generator, + dtype=torch.int64, + ) else: # sample without replacement len_subset = len(self.data_source.datasets[d_idx]) rand_tensor = torch.randperm(len_subset, generator=self.generator) + low if len_subset >= self.n_samples_per_subset: - rand_tensor = rand_tensor[:self.n_samples_per_subset] - else: # padding with replacement - rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ), - 
generator=self.generator, dtype=torch.int64) + rand_tensor = rand_tensor[: self.n_samples_per_subset] + else: # padding with replacement + rand_tensor_replacement = torch.randint( + low, + high, + (self.n_samples_per_subset - len_subset,), + generator=self.generator, + dtype=torch.int64, + ) rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) indices.append(rand_tensor) indices = torch.cat(indices) @@ -72,6 +85,6 @@ class RandomConcatSampler(Sampler): _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)] repeat_indices = map(_choice, repeat_indices) indices = torch.cat([indices, *repeat_indices], 0) - + assert indices.shape[0] == self.n_samples return iter(indices.tolist()) diff --git a/third_party/TopicFM/src/datasets/scannet.py b/third_party/TopicFM/src/datasets/scannet.py index fb5dab7b150a3c6f54eb07b0459bbf3e9ba58fbf..b955c4fa1609625be2c6c1a0ed6665109908bba0 100644 --- a/third_party/TopicFM/src/datasets/scannet.py +++ b/third_party/TopicFM/src/datasets/scannet.py @@ -10,20 +10,22 @@ from src.utils.dataset import ( read_scannet_gray, read_scannet_depth, read_scannet_pose, - read_scannet_intrinsic + read_scannet_intrinsic, ) class ScanNetDataset(utils.data.Dataset): - def __init__(self, - root_dir, - npz_path, - intrinsic_path, - mode='train', - min_overlap_score=0.4, - augment_fn=None, - pose_dir=None, - **kwargs): + def __init__( + self, + root_dir, + npz_path, + intrinsic_path, + mode="train", + min_overlap_score=0.4, + augment_fn=None, + pose_dir=None, + **kwargs, + ): """Manage one scene of ScanNet Dataset. Args: root_dir (str): ScanNet root directory that contains scene folders. @@ -38,78 +40,88 @@ class ScanNetDataset(utils.data.Dataset): self.root_dir = root_dir self.pose_dir = pose_dir if pose_dir is not None else root_dir self.mode = mode - self.img_resize = (640, 480) if 'img_resize' not in kwargs else kwargs['img_resize'] + self.img_resize = ( + (640, 480) if "img_resize" not in kwargs else kwargs["img_resize"] + ) # prepare data_names, intrinsics and extrinsics(T) with np.load(npz_path) as data: - self.data_names = data['name'] - if 'score' in data.keys() and mode not in ['val' or 'test']: - kept_mask = data['score'] > min_overlap_score + self.data_names = data["name"] + if "score" in data.keys() and mode not in ["val" or "test"]: + kept_mask = data["score"] > min_overlap_score self.data_names = self.data_names[kept_mask] self.intrinsics = dict(np.load(intrinsic_path)) # for training LoFTR - self.augment_fn = augment_fn if mode == 'train' else None + self.augment_fn = augment_fn if mode == "train" else None def __len__(self): return len(self.data_names) def _read_abs_pose(self, scene_name, name): - pth = osp.join(self.pose_dir, - scene_name, - 'pose', f'{name}.txt') + pth = osp.join(self.pose_dir, scene_name, "pose", f"{name}.txt") return read_scannet_pose(pth) def _compute_rel_pose(self, scene_name, name0, name1): pose0 = self._read_abs_pose(scene_name, name0) pose1 = self._read_abs_pose(scene_name, name1) - + return np.matmul(pose1, inv(pose0)) # (4, 4) def __getitem__(self, idx): data_name = self.data_names[idx] scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name - scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}" # read the grayscale image which will be resized to (1, 480, 640) - img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg') - img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg') - + img_name0 = 
osp.join(self.root_dir, scene_name, "color", f"{stem_name_0}.jpg") + img_name1 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_1}.jpg") + # TODO: Support augmentation & handle seeds for each worker correctly. image0 = read_scannet_gray(img_name0, resize=self.img_resize, augment_fn=None) - # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) image1 = read_scannet_gray(img_name1, resize=self.img_resize, augment_fn=None) - # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) # read the depthmap which is stored as (480, 640) - if self.mode in ['train', 'val']: - depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png')) - depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png')) + if self.mode in ["train", "val"]: + depth0 = read_scannet_depth( + osp.join(self.root_dir, scene_name, "depth", f"{stem_name_0}.png") + ) + depth1 = read_scannet_depth( + osp.join(self.root_dir, scene_name, "depth", f"{stem_name_1}.png") + ) else: depth0 = depth1 = torch.tensor([]) # read the intrinsic of depthmap - K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3) + K_0 = K_1 = torch.tensor( + self.intrinsics[scene_name].copy(), dtype=torch.float + ).reshape(3, 3) # read and compute relative poses - T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), - dtype=torch.float32) + T_0to1 = torch.tensor( + self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), + dtype=torch.float32, + ) T_1to0 = T_0to1.inverse() data = { - 'image0': image0, # (1, h, w) - 'depth0': depth0, # (h, w) - 'image1': image1, - 'depth1': depth1, - 'T_0to1': T_0to1, # (4, 4) - 'T_1to0': T_1to0, - 'K0': K_0, # (3, 3) - 'K1': K_1, - 'dataset_name': 'ScanNet', - 'scene_id': scene_name, - 'pair_id': idx, - 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'), - osp.join(scene_name, 'color', f'{stem_name_1}.jpg')) + "image0": image0, # (1, h, w) + "depth0": depth0, # (h, w) + "image1": image1, + "depth1": depth1, + "T_0to1": T_0to1, # (4, 4) + "T_1to0": T_1to0, + "K0": K_0, # (3, 3) + "K1": K_1, + "dataset_name": "ScanNet", + "scene_id": scene_name, + "pair_id": idx, + "pair_names": ( + osp.join(scene_name, "color", f"{stem_name_0}.jpg"), + osp.join(scene_name, "color", f"{stem_name_1}.jpg"), + ), } return data diff --git a/third_party/TopicFM/src/lightning_trainer/data.py b/third_party/TopicFM/src/lightning_trainer/data.py index 8deb713b6300e0e9e8a261e2230031174b452862..95f6a5eeecf39a993b86674242eacb7b42f8a566 100644 --- a/third_party/TopicFM/src/lightning_trainer/data.py +++ b/third_party/TopicFM/src/lightning_trainer/data.py @@ -16,7 +16,7 @@ from torch.utils.data import ( ConcatDataset, DistributedSampler, RandomSampler, - dataloader + dataloader, ) from src.utils.augment import build_augmentor @@ -29,10 +29,11 @@ from src.datasets.sampler import RandomConcatSampler class MultiSceneDataModule(pl.LightningDataModule): - """ + """ For distributed training, each training process is assgined only a part of the training scenes to reduce memory overhead. """ + def __init__(self, args, config): super().__init__() @@ -60,47 +61,51 @@ class MultiSceneDataModule(pl.LightningDataModule): # 2. 
dataset config # general options - self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score + self.min_overlap_score_test = ( + config.DATASET.MIN_OVERLAP_SCORE_TEST + ) # 0.4, omit data with overlap_score < min_overlap_score self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN - self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile'] + self.augment_fn = build_augmentor( + config.DATASET.AUGMENTATION_TYPE + ) # None, options: [None, 'dark', 'mobile'] # MegaDepth options self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840 - self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True - self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True + self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True + self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True self.mgdpt_df = config.DATASET.MGDPT_DF # 8 self.coarse_scale = 1 / config.MODEL.RESOLUTION[0] # 0.125. for training loftr. # 3.loader parameters self.train_loader_params = { - 'batch_size': args.batch_size, - 'num_workers': args.num_workers, - 'pin_memory': getattr(args, 'pin_memory', True) + "batch_size": args.batch_size, + "num_workers": args.num_workers, + "pin_memory": getattr(args, "pin_memory", True), } self.val_loader_params = { - 'batch_size': 1, - 'shuffle': False, - 'num_workers': args.num_workers, - 'pin_memory': getattr(args, 'pin_memory', True) + "batch_size": 1, + "shuffle": False, + "num_workers": args.num_workers, + "pin_memory": getattr(args, "pin_memory", True), } self.test_loader_params = { - 'batch_size': 1, - 'shuffle': False, - 'num_workers': args.num_workers, - 'pin_memory': True + "batch_size": 1, + "shuffle": False, + "num_workers": args.num_workers, + "pin_memory": True, } - + # 4. sampler self.data_sampler = config.TRAINER.DATA_SAMPLER self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE self.repeat = config.TRAINER.SB_REPEAT - + # (optional) RandomSampler for debugging # misc configurations - self.parallel_load_data = getattr(args, 'parallel_load_data', False) + self.parallel_load_data = getattr(args, "parallel_load_data", False) self.seed = config.TRAINER.SEED # 66 def setup(self, stage=None): @@ -110,7 +115,7 @@ class MultiSceneDataModule(pl.LightningDataModule): stage (str): 'fit' in training phase, and 'test' in testing phase. 
""" - assert stage in ['fit', 'test'], "stage must be either fit or test" + assert stage in ["fit", "test"], "stage must be either fit or test" try: self.world_size = dist.get_world_size() @@ -121,73 +126,94 @@ class MultiSceneDataModule(pl.LightningDataModule): self.rank = 0 logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)") - if stage == 'fit': + if stage == "fit": self.train_dataset = self._setup_dataset( self.train_data_root, self.train_npz_root, self.train_list_path, self.train_intrinsic_path, - mode='train', + mode="train", min_overlap_score=self.min_overlap_score_train, - pose_dir=self.train_pose_root) + pose_dir=self.train_pose_root, + ) # setup multiple (optional) validation subsets if isinstance(self.val_list_path, (list, tuple)): self.val_dataset = [] if not isinstance(self.val_npz_root, (list, tuple)): - self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))] + self.val_npz_root = [ + self.val_npz_root for _ in range(len(self.val_list_path)) + ] for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root): - self.val_dataset.append(self._setup_dataset( - self.val_data_root, - npz_root, - npz_list, - self.val_intrinsic_path, - mode='val', - min_overlap_score=self.min_overlap_score_test, - pose_dir=self.val_pose_root)) + self.val_dataset.append( + self._setup_dataset( + self.val_data_root, + npz_root, + npz_list, + self.val_intrinsic_path, + mode="val", + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root, + ) + ) else: self.val_dataset = self._setup_dataset( self.val_data_root, self.val_npz_root, self.val_list_path, self.val_intrinsic_path, - mode='val', + mode="val", min_overlap_score=self.min_overlap_score_test, - pose_dir=self.val_pose_root) - logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!') + pose_dir=self.val_pose_root, + ) + logger.info(f"[rank:{self.rank}] Train & Val Dataset loaded!") else: # stage == 'test self.test_dataset = self._setup_dataset( self.test_data_root, self.test_npz_root, self.test_list_path, self.test_intrinsic_path, - mode='test', + mode="test", min_overlap_score=self.min_overlap_score_test, - pose_dir=self.test_pose_root) - logger.info(f'[rank:{self.rank}]: Test Dataset loaded!') + pose_dir=self.test_pose_root, + ) + logger.info(f"[rank:{self.rank}]: Test Dataset loaded!") - def _setup_dataset(self, - data_root, - split_npz_root, - scene_list_path, - intri_path, - mode='train', - min_overlap_score=0., - pose_dir=None): - """ Setup train / val / test set""" - with open(scene_list_path, 'r') as f: + def _setup_dataset( + self, + data_root, + split_npz_root, + scene_list_path, + intri_path, + mode="train", + min_overlap_score=0.0, + pose_dir=None, + ): + """Setup train / val / test set""" + with open(scene_list_path, "r") as f: npz_names = [name.split()[0] for name in f.readlines()] - if mode == 'train': - local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed) + if mode == "train": + local_npz_names = get_local_split( + npz_names, self.world_size, self.rank, self.seed + ) else: local_npz_names = npz_names - logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.') - - dataset_builder = self._build_concat_dataset_parallel \ - if self.parallel_load_data \ - else self._build_concat_dataset - return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path, - mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir) + logger.info(f"[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.") + + 
dataset_builder = ( + self._build_concat_dataset_parallel + if self.parallel_load_data + else self._build_concat_dataset + ) + return dataset_builder( + data_root, + local_npz_names, + split_npz_root, + intri_path, + mode=mode, + min_overlap_score=min_overlap_score, + pose_dir=pose_dir, + ) def _build_concat_dataset( self, @@ -196,44 +222,56 @@ class MultiSceneDataModule(pl.LightningDataModule): npz_dir, intrinsic_path, mode, - min_overlap_score=0., - pose_dir=None + min_overlap_score=0.0, + pose_dir=None, ): datasets = [] - augment_fn = self.augment_fn if mode == 'train' else None - data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source - if str(data_source).lower() == 'megadepth': - npz_names = [f'{n}.npz' for n in npz_names] - for npz_name in tqdm(npz_names, - desc=f'[rank:{self.rank}] loading {mode} datasets', - disable=int(self.rank) != 0): + augment_fn = self.augment_fn if mode == "train" else None + data_source = ( + self.trainval_data_source + if mode in ["train", "val"] + else self.test_data_source + ) + if str(data_source).lower() == "megadepth": + npz_names = [f"{n}.npz" for n in npz_names] + for npz_name in tqdm( + npz_names, + desc=f"[rank:{self.rank}] loading {mode} datasets", + disable=int(self.rank) != 0, + ): # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time. npz_path = osp.join(npz_dir, npz_name) - if data_source == 'ScanNet': + if data_source == "ScanNet": datasets.append( - ScanNetDataset(data_root, - npz_path, - intrinsic_path, - mode=mode, - min_overlap_score=min_overlap_score, - augment_fn=augment_fn, - pose_dir=pose_dir)) - elif data_source == 'MegaDepth': + ScanNetDataset( + data_root, + npz_path, + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir, + ) + ) + elif data_source == "MegaDepth": datasets.append( - MegaDepthDataset(data_root, - npz_path, - mode=mode, - min_overlap_score=min_overlap_score, - img_resize=self.mgdpt_img_resize, - df=self.mgdpt_df, - img_padding=self.mgdpt_img_pad, - depth_padding=self.mgdpt_depth_pad, - augment_fn=augment_fn, - coarse_scale=self.coarse_scale)) + MegaDepthDataset( + data_root, + npz_path, + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale, + ) + ) else: raise NotImplementedError() return ConcatDataset(datasets) - + def _build_concat_dataset_parallel( self, data_root, @@ -241,77 +279,118 @@ class MultiSceneDataModule(pl.LightningDataModule): npz_dir, intrinsic_path, mode, - min_overlap_score=0., + min_overlap_score=0.0, pose_dir=None, ): - augment_fn = self.augment_fn if mode == 'train' else None - data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source - if str(data_source).lower() == 'megadepth': - npz_names = [f'{n}.npz' for n in npz_names] - with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets', - total=len(npz_names), disable=int(self.rank) != 0)): - if data_source == 'ScanNet': - datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( - delayed(lambda x: _build_dataset( - ScanNetDataset, - data_root, - osp.join(npz_dir, x), - intrinsic_path, - mode=mode, - min_overlap_score=min_overlap_score, - augment_fn=augment_fn, - pose_dir=pose_dir))(name) - for name in npz_names) - elif 
data_source == 'MegaDepth': + augment_fn = self.augment_fn if mode == "train" else None + data_source = ( + self.trainval_data_source + if mode in ["train", "val"] + else self.test_data_source + ) + if str(data_source).lower() == "megadepth": + npz_names = [f"{n}.npz" for n in npz_names] + with tqdm_joblib( + tqdm( + desc=f"[rank:{self.rank}] loading {mode} datasets", + total=len(npz_names), + disable=int(self.rank) != 0, + ) + ): + if data_source == "ScanNet": + datasets = Parallel( + n_jobs=math.floor( + len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size() + ) + )( + delayed( + lambda x: _build_dataset( + ScanNetDataset, + data_root, + osp.join(npz_dir, x), + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir, + ) + )(name) + for name in npz_names + ) + elif data_source == "MegaDepth": # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers. raise NotImplementedError() - datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( - delayed(lambda x: _build_dataset( - MegaDepthDataset, - data_root, - osp.join(npz_dir, x), - mode=mode, - min_overlap_score=min_overlap_score, - img_resize=self.mgdpt_img_resize, - df=self.mgdpt_df, - img_padding=self.mgdpt_img_pad, - depth_padding=self.mgdpt_depth_pad, - augment_fn=augment_fn, - coarse_scale=self.coarse_scale))(name) - for name in npz_names) + datasets = Parallel( + n_jobs=math.floor( + len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size() + ) + )( + delayed( + lambda x: _build_dataset( + MegaDepthDataset, + data_root, + osp.join(npz_dir, x), + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale, + ) + )(name) + for name in npz_names + ) else: - raise ValueError(f'Unknown dataset: {data_source}') + raise ValueError(f"Unknown dataset: {data_source}") return ConcatDataset(datasets) def train_dataloader(self): - """ Build training dataloader for ScanNet / MegaDepth. """ - assert self.data_sampler in ['scene_balance'] - logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).') - if self.data_sampler == 'scene_balance': - sampler = RandomConcatSampler(self.train_dataset, - self.n_samples_per_subset, - self.subset_replacement, - self.shuffle, self.repeat, self.seed) + """Build training dataloader for ScanNet / MegaDepth.""" + assert self.data_sampler in ["scene_balance"] + logger.info( + f"[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!)." + ) + if self.data_sampler == "scene_balance": + sampler = RandomConcatSampler( + self.train_dataset, + self.n_samples_per_subset, + self.subset_replacement, + self.shuffle, + self.repeat, + self.seed, + ) else: sampler = None - dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params) + dataloader = DataLoader( + self.train_dataset, sampler=sampler, **self.train_loader_params + ) return dataloader - + def val_dataloader(self): - """ Build validation dataloader for ScanNet / MegaDepth. """ - logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.') + """Build validation dataloader for ScanNet / MegaDepth.""" + logger.info( + f"[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init." 
+ ) if not isinstance(self.val_dataset, abc.Sequence): sampler = DistributedSampler(self.val_dataset, shuffle=False) - return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params) + return DataLoader( + self.val_dataset, sampler=sampler, **self.val_loader_params + ) else: dataloaders = [] for dataset in self.val_dataset: sampler = DistributedSampler(dataset, shuffle=False) - dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params)) + dataloaders.append( + DataLoader(dataset, sampler=sampler, **self.val_loader_params) + ) return dataloaders def test_dataloader(self, *args, **kwargs): - logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.') + logger.info( + f"[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init." + ) sampler = DistributedSampler(self.test_dataset, shuffle=False) return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params) diff --git a/third_party/TopicFM/src/lightning_trainer/trainer.py b/third_party/TopicFM/src/lightning_trainer/trainer.py index acf51f66130be66b7d3294ca5c081a2df3856d96..cce4839b536eba974426309eca10415547479f50 100644 --- a/third_party/TopicFM/src/lightning_trainer/trainer.py +++ b/third_party/TopicFM/src/lightning_trainer/trainer.py @@ -1,4 +1,3 @@ - from collections import defaultdict import pprint from loguru import logger @@ -10,13 +9,16 @@ import pytorch_lightning as pl from matplotlib import pyplot as plt from src.models import TopicFM -from src.models.utils.supervision import compute_supervision_coarse, compute_supervision_fine +from src.models.utils.supervision import ( + compute_supervision_coarse, + compute_supervision_fine, +) from src.losses.loss import TopicFMLoss from src.optimizers import build_optimizer, build_scheduler from src.utils.metrics import ( compute_symmetrical_epipolar_errors, compute_pose_errors, - aggregate_metrics + aggregate_metrics, ) from src.utils.plotting import make_matching_figures from src.utils.comm import gather, all_gather @@ -34,168 +36,225 @@ class PL_Trainer(pl.LightningModule): # Misc self.config = config # full config _config = lower_config(self.config) - self.model_cfg = lower_config(_config['model']) + self.model_cfg = lower_config(_config["model"]) self.profiler = profiler or PassThroughProfiler() - self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1) + self.n_vals_plot = max( + config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1 + ) # Matcher: TopicFM - self.matcher = TopicFM(config=_config['model']) + self.matcher = TopicFM(config=_config["model"]) self.loss = TopicFMLoss(_config) # Pretrained weights if pretrained_ckpt: - state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict'] + state_dict = torch.load(pretrained_ckpt, map_location="cpu")["state_dict"] self.matcher.load_state_dict(state_dict, strict=True) - logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") - + logger.info(f"Load '{pretrained_ckpt}' as pretrained checkpoint") + # Testing self.dump_dir = dump_dir - + def configure_optimizers(self): # FIXME: The scheduler did not work properly when `--resume_from_checkpoint` optimizer = build_optimizer(self, self.config) scheduler = build_scheduler(self.config, optimizer) return [optimizer], [scheduler] - + def optimizer_step( - self, epoch, batch_idx, optimizer, optimizer_idx, - optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + self, + epoch, + batch_idx, + optimizer, + optimizer_idx, + 
optimizer_closure, + on_tpu, + using_native_amp, + using_lbfgs, + ): # learning rate warm up warmup_step = self.config.TRAINER.WARMUP_STEP if self.trainer.global_step < warmup_step: - if self.config.TRAINER.WARMUP_TYPE == 'linear': + if self.config.TRAINER.WARMUP_TYPE == "linear": base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR - lr = base_lr + \ - (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \ - abs(self.config.TRAINER.TRUE_LR - base_lr) + lr = base_lr + ( + self.trainer.global_step / self.config.TRAINER.WARMUP_STEP + ) * abs(self.config.TRAINER.TRUE_LR - base_lr) for pg in optimizer.param_groups: - pg['lr'] = lr - elif self.config.TRAINER.WARMUP_TYPE == 'constant': + pg["lr"] = lr + elif self.config.TRAINER.WARMUP_TYPE == "constant": pass else: - raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}') + raise ValueError( + f"Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}" + ) # update params optimizer.step(closure=optimizer_closure) optimizer.zero_grad() - + def _trainval_inference(self, batch): with self.profiler.profile("Compute coarse supervision"): compute_supervision_coarse(batch, self.config) - + with self.profiler.profile("TopicFM"): self.matcher(batch) - + with self.profiler.profile("Compute fine supervision"): compute_supervision_fine(batch, self.config) - + with self.profiler.profile("Compute losses"): self.loss(batch) - + def _compute_metrics(self, batch): with self.profiler.profile("Copmute metrics"): - compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match - compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair + compute_symmetrical_epipolar_errors( + batch + ) # compute epi_errs for each match + compute_pose_errors( + batch, self.config + ) # compute R_errs, t_errs, pose_errs for each pair - rel_pair_names = list(zip(*batch['pair_names'])) - bs = batch['image0'].size(0) + rel_pair_names = list(zip(*batch["pair_names"])) + bs = batch["image0"].size(0) metrics = { # to filter duplicate pairs caused by DistributedSampler - 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], - 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)], - 'R_errs': batch['R_errs'], - 't_errs': batch['t_errs'], - 'inliers': batch['inliers']} - ret_dict = {'metrics': metrics} + "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)], + "epi_errs": [ + batch["epi_errs"][batch["m_bids"] == b].cpu().numpy() + for b in range(bs) + ], + "R_errs": batch["R_errs"], + "t_errs": batch["t_errs"], + "inliers": batch["inliers"], + } + ret_dict = {"metrics": metrics} return ret_dict, rel_pair_names - + def training_step(self, batch, batch_idx): self._trainval_inference(batch) - + # logging - if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0: + if ( + self.trainer.global_rank == 0 + and self.global_step % self.trainer.log_every_n_steps == 0 + ): # scalars - for k, v in batch['loss_scalars'].items(): - self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step) + for k, v in batch["loss_scalars"].items(): + self.logger.experiment.add_scalar(f"train/{k}", v, self.global_step) # figures if self.config.TRAINER.ENABLE_PLOTTING: - compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match - figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE) + compute_symmetrical_epipolar_errors( + batch + ) # compute epi_errs for each match + 
figures = make_matching_figures( + batch, self.config, self.config.TRAINER.PLOT_MODE + ) for k, v in figures.items(): - self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step) + self.logger.experiment.add_figure( + f"train_match/{k}", v, self.global_step + ) - return {'loss': batch['loss']} + return {"loss": batch["loss"]} def training_epoch_end(self, outputs): - avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + avg_loss = torch.stack([x["loss"] for x in outputs]).mean() if self.trainer.global_rank == 0: self.logger.experiment.add_scalar( - 'train/avg_loss_on_epoch', avg_loss, - global_step=self.current_epoch) - + "train/avg_loss_on_epoch", avg_loss, global_step=self.current_epoch + ) + def validation_step(self, batch, batch_idx): self._trainval_inference(batch) - + ret_dict, _ = self._compute_metrics(batch) - + val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1) figures = {self.config.TRAINER.PLOT_MODE: []} if batch_idx % val_plot_interval == 0: - figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE) + figures = make_matching_figures( + batch, self.config, mode=self.config.TRAINER.PLOT_MODE + ) return { **ret_dict, - 'loss_scalars': batch['loss_scalars'], - 'figures': figures, + "loss_scalars": batch["loss_scalars"], + "figures": figures, } - + def validation_epoch_end(self, outputs): # handle multiple validation sets - multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs + multi_outputs = ( + [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs + ) multi_val_metrics = defaultdict(list) - + for valset_idx, outputs in enumerate(multi_outputs): # since pl performs sanity_check at the very begining of the training cur_epoch = self.trainer.current_epoch - if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check: + if ( + not self.trainer.resume_from_checkpoint + and self.trainer.running_sanity_check + ): cur_epoch = -1 # 1. loss_scalars: dict of list, on cpu - _loss_scalars = [o['loss_scalars'] for o in outputs] - loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]} + _loss_scalars = [o["loss_scalars"] for o in outputs] + loss_scalars = { + k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) + for k in _loss_scalars[0] + } # 2. val metrics: dict of list, numpy - _metrics = [o['metrics'] for o in outputs] - metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} - # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 - val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + _metrics = [o["metrics"] for o in outputs] + metrics = { + k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) + for k in _metrics[0] + } + # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 + val_metrics_4tb = aggregate_metrics( + metrics, self.config.TRAINER.EPI_ERR_THR + ) for thr in [5, 10, 20]: - multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}']) - + multi_val_metrics[f"auc@{thr}"].append(val_metrics_4tb[f"auc@{thr}"]) + # 3. 
figures - _figures = [o['figures'] for o in outputs] - figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]} + _figures = [o["figures"] for o in outputs] + figures = { + k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) + for k in _figures[0] + } # tensorboard records only on rank 0 if self.trainer.global_rank == 0: for k, v in loss_scalars.items(): mean_v = torch.stack(v).mean() - self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch) + self.logger.experiment.add_scalar( + f"val_{valset_idx}/avg_{k}", mean_v, global_step=cur_epoch + ) for k, v in val_metrics_4tb.items(): - self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch) - + self.logger.experiment.add_scalar( + f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch + ) + for k, v in figures.items(): if self.trainer.global_rank == 0: for plot_idx, fig in enumerate(v): self.logger.experiment.add_figure( - f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True) - plt.close('all') + f"val_match_{valset_idx}/{k}/pair-{plot_idx}", + fig, + cur_epoch, + close=True, + ) + plt.close("all") for thr in [5, 10, 20]: # log on all ranks for ModelCheckpoint callback to work properly - self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this + self.log( + f"auc@{thr}", torch.tensor(np.mean(multi_val_metrics[f"auc@{thr}"])) + ) # ckpt monitors on this def test_step(self, batch, batch_idx): with self.profiler.profile("TopicFM"): @@ -206,39 +265,46 @@ class PL_Trainer(pl.LightningModule): with self.profiler.profile("dump_results"): if self.dump_dir is not None: # dump results for further analysis - keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'} - pair_names = list(zip(*batch['pair_names'])) - bs = batch['image0'].shape[0] + keys_to_save = {"mkpts0_f", "mkpts1_f", "mconf", "epi_errs"} + pair_names = list(zip(*batch["pair_names"])) + bs = batch["image0"].shape[0] dumps = [] for b_id in range(bs): item = {} - mask = batch['m_bids'] == b_id - item['pair_names'] = pair_names[b_id] - item['identifier'] = '#'.join(rel_pair_names[b_id]) + mask = batch["m_bids"] == b_id + item["pair_names"] = pair_names[b_id] + item["identifier"] = "#".join(rel_pair_names[b_id]) for key in keys_to_save: item[key] = batch[key][mask].cpu().numpy() - for key in ['R_errs', 't_errs', 'inliers']: + for key in ["R_errs", "t_errs", "inliers"]: item[key] = batch[key][b_id] dumps.append(item) - ret_dict['dumps'] = dumps + ret_dict["dumps"] = dumps return ret_dict def test_epoch_end(self, outputs): # metrics: dict of list, numpy - _metrics = [o['metrics'] for o in outputs] - metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + _metrics = [o["metrics"] for o in outputs] + metrics = { + k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) + for k in _metrics[0] + } # [{key: [{...}, *#bs]}, *#batch] if self.dump_dir is not None: Path(self.dump_dir).mkdir(parents=True, exist_ok=True) - _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch] + _dumps = flattenList([o["dumps"] for o in outputs]) # [{...}, #bs*#batch] dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch] - logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}') + logger.info( + f"Prediction and evaluation results will be saved to: {self.dump_dir}" + ) if self.trainer.global_rank == 0: 
print(self.profiler.summary()) - val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) - logger.info('\n' + pprint.pformat(val_metrics_4tb)) + val_metrics_4tb = aggregate_metrics( + metrics, self.config.TRAINER.EPI_ERR_THR + ) + logger.info("\n" + pprint.pformat(val_metrics_4tb)) if self.dump_dir is not None: - np.save(Path(self.dump_dir) / 'TopicFM_pred_eval', dumps) + np.save(Path(self.dump_dir) / "TopicFM_pred_eval", dumps) diff --git a/third_party/TopicFM/src/losses/loss.py b/third_party/TopicFM/src/losses/loss.py index 4be58498579c9fe649ed0ce2d42f230e59cef581..e386bb557285a290962477179e9a3a36b665368f 100644 --- a/third_party/TopicFM/src/losses/loss.py +++ b/third_party/TopicFM/src/losses/loss.py @@ -13,10 +13,10 @@ def sample_non_matches(pos_mask, match_ids=None, sampling_ratio=10): return ~pos_mask neg_mask = torch.zeros_like(pos_mask) - probs = torch.ones((HW - 1)//3, device=pos_mask.device) + probs = torch.ones((HW - 1) // 3, device=pos_mask.device) for _ in range(sampling_ratio): d = torch.multinomial(probs, len(j_ids), replacement=True) - sampled_j_ids = (j_ids + d*3 + 1) % HW + sampled_j_ids = (j_ids + d * 3 + 1) % HW neg_mask[b_ids, i_ids, sampled_j_ids] = True # neg_mask = neg_matrix == 1 else: @@ -29,18 +29,20 @@ class TopicFMLoss(nn.Module): def __init__(self, config): super().__init__() self.config = config # config under the global namespace - self.loss_config = config['model']['loss'] - self.match_type = self.config['model']['match_coarse']['match_type'] - + self.loss_config = config["model"]["loss"] + self.match_type = self.config["model"]["match_coarse"]["match_type"] + # coarse-level - self.correct_thr = self.loss_config['fine_correct_thr'] - self.c_pos_w = self.loss_config['pos_weight'] - self.c_neg_w = self.loss_config['neg_weight'] + self.correct_thr = self.loss_config["fine_correct_thr"] + self.c_pos_w = self.loss_config["pos_weight"] + self.c_neg_w = self.loss_config["neg_weight"] # fine-level - self.fine_type = self.loss_config['fine_type'] + self.fine_type = self.loss_config["fine_type"] - def compute_coarse_loss(self, conf, topic_mat, conf_gt, match_ids=None, weight=None): - """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt. + def compute_coarse_loss( + self, conf, topic_mat, conf_gt, match_ids=None, weight=None + ): + """Point-wise CE / Focal Loss with 0 / 1 confidence as gt. Args: conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1) conf_gt (torch.Tensor): (N, HW0, HW1) @@ -53,30 +55,30 @@ class TopicFMLoss(nn.Module): if not pos_mask.any(): # assign a wrong gt pos_mask[0, 0, 0] = True if weight is not None: - weight[0, 0, 0] = 0. - c_pos_w = 0. + weight[0, 0, 0] = 0.0 + c_pos_w = 0.0 if not neg_mask.any(): neg_mask[0, 0, 0] = True if weight is not None: - weight[0, 0, 0] = 0. - c_neg_w = 0. 
+ weight[0, 0, 0] = 0.0 + c_neg_w = 0.0 conf = torch.clamp(conf, 1e-6, 1 - 1e-6) - alpha = self.loss_config['focal_alpha'] + alpha = self.loss_config["focal_alpha"] loss = 0.0 if isinstance(topic_mat, torch.Tensor): pos_topic = topic_mat[pos_mask] - loss_pos_topic = - alpha * (pos_topic + 1e-6).log() + loss_pos_topic = -alpha * (pos_topic + 1e-6).log() neg_topic = topic_mat[neg_mask] - loss_neg_topic = - alpha * (1 - neg_topic + 1e-6).log() + loss_neg_topic = -alpha * (1 - neg_topic + 1e-6).log() if weight is not None: loss_pos_topic = loss_pos_topic * weight[pos_mask] loss_neg_topic = loss_neg_topic * weight[neg_mask] loss = loss_pos_topic.mean() + loss_neg_topic.mean() pos_conf = conf[pos_mask] - loss_pos = - alpha * pos_conf.log() + loss_pos = -alpha * pos_conf.log() # handle loss weights if weight is not None: # Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out, @@ -86,11 +88,11 @@ class TopicFMLoss(nn.Module): loss = loss + c_pos_w * loss_pos.mean() return loss - + def compute_fine_loss(self, expec_f, expec_f_gt): - if self.fine_type == 'l2_with_std': + if self.fine_type == "l2_with_std": return self._compute_fine_loss_l2_std(expec_f, expec_f_gt) - elif self.fine_type == 'l2': + elif self.fine_type == "l2": return self._compute_fine_loss_l2(expec_f, expec_f_gt) else: raise NotImplementedError() @@ -101,9 +103,13 @@ class TopicFMLoss(nn.Module): expec_f (torch.Tensor): [M, 2] expec_f_gt (torch.Tensor): [M, 2] """ - correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr + correct_mask = ( + torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr + ) if correct_mask.sum() == 0: - if self.training: # this seldomly happen when training, since we pad prediction with gt + if ( + self.training + ): # this seldomly happen when training, since we pad prediction with gt logger.warning("assign a false supervision to avoid ddp deadlock") correct_mask[0] = True else: @@ -118,34 +124,45 @@ class TopicFMLoss(nn.Module): expec_f_gt (torch.Tensor): [M, 2] """ # correct_mask tells you which pair to compute fine-loss - correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr + correct_mask = ( + torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr + ) # use std as weight that measures uncertainty std = expec_f[:, 2] - inverse_std = 1. / torch.clamp(std, min=1e-10) - weight = (inverse_std / torch.mean(inverse_std)).detach() # avoid minizing loss through increase std + inverse_std = 1.0 / torch.clamp(std, min=1e-10) + weight = ( + inverse_std / torch.mean(inverse_std) + ).detach() # avoid minizing loss through increase std # corner case: no correct coarse match found if not correct_mask.any(): - if self.training: # this seldomly happen during training, since we pad prediction with gt - # sometimes there is not coarse-level gt at all. + if ( + self.training + ): # this seldomly happen during training, since we pad prediction with gt + # sometimes there is not coarse-level gt at all. logger.warning("assign a false supervision to avoid ddp deadlock") correct_mask[0] = True - weight[0] = 0. 
+ weight[0] = 0.0 else: return None # l2 loss with std - offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum(-1) + offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum( + -1 + ) loss = (offset_l2 * weight[correct_mask]).mean() return loss - + @torch.no_grad() def compute_c_weight(self, data): - """ compute element-wise weights for computing coarse-level loss. """ - if 'mask0' in data: - c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float() + """compute element-wise weights for computing coarse-level loss.""" + if "mask0" in data: + c_weight = ( + data["mask0"].flatten(-2)[..., None] + * data["mask1"].flatten(-2)[:, None] + ).float() else: c_weight = None return c_weight @@ -163,20 +180,24 @@ class TopicFMLoss(nn.Module): c_weight = self.compute_c_weight(data) # 1. coarse-level loss - loss_c = self.compute_coarse_loss(data['conf_matrix'], data['topic_matrix'], - data['conf_matrix_gt'], match_ids=(data['spv_b_ids'], data['spv_i_ids'], data['spv_j_ids']), - weight=c_weight) - loss = loss_c * self.loss_config['coarse_weight'] + loss_c = self.compute_coarse_loss( + data["conf_matrix"], + data["topic_matrix"], + data["conf_matrix_gt"], + match_ids=(data["spv_b_ids"], data["spv_i_ids"], data["spv_j_ids"]), + weight=c_weight, + ) + loss = loss_c * self.loss_config["coarse_weight"] loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()}) # 2. fine-level loss - loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt']) + loss_f = self.compute_fine_loss(data["expec_f"], data["expec_f_gt"]) if loss_f is not None: - loss += loss_f * self.loss_config['fine_weight'] - loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()}) + loss += loss_f * self.loss_config["fine_weight"] + loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()}) else: assert self.training is False - loss_scalars.update({'loss_f': torch.tensor(1.)}) # 1 is the upper bound + loss_scalars.update({"loss_f": torch.tensor(1.0)}) # 1 is the upper bound - loss_scalars.update({'loss': loss.clone().detach().cpu()}) + loss_scalars.update({"loss": loss.clone().detach().cpu()}) data.update({"loss": loss, "loss_scalars": loss_scalars}) diff --git a/third_party/TopicFM/src/models/backbone/__init__.py b/third_party/TopicFM/src/models/backbone/__init__.py index 53f98db4e910b46716bed7cfc6ebbf8c8bfad399..72a80de20ba3f6bc02454f4930b25d6b18f4b34f 100644 --- a/third_party/TopicFM/src/models/backbone/__init__.py +++ b/third_party/TopicFM/src/models/backbone/__init__.py @@ -2,4 +2,4 @@ from .fpn import FPN def build_backbone(config): - return FPN(config['fpn']) + return FPN(config["fpn"]) diff --git a/third_party/TopicFM/src/models/backbone/fpn.py b/third_party/TopicFM/src/models/backbone/fpn.py index 93cc475f57317f9dbb8132cdfe0297391972f9e2..7f38ec13f196793a00cacbaaa3eb7c0a5d8e9605 100644 --- a/third_party/TopicFM/src/models/backbone/fpn.py +++ b/third_party/TopicFM/src/models/backbone/fpn.py @@ -4,12 +4,16 @@ import torch.nn.functional as F def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution without padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False + ) def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + return nn.Conv2d( + in_planes, out_planes, 
kernel_size=3, stride=stride, padding=1, bias=False + ) class ConvBlock(nn.Module): @@ -22,7 +26,7 @@ class ConvBlock(nn.Module): def forward(self, x): y = self.conv(x) if self.bn: - y = self.bn(y) #F.layer_norm(y, y.shape[1:]) + y = self.bn(y) # F.layer_norm(y, y.shape[1:]) y = self.act(y) return y @@ -37,14 +41,16 @@ class FPN(nn.Module): super().__init__() # Config block = ConvBlock - initial_dim = config['initial_dim'] - block_dims = config['block_dims'] + initial_dim = config["initial_dim"] + block_dims = config["block_dims"] # Class Variable self.in_planes = initial_dim # Networks - self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.conv1 = nn.Conv2d( + 1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False + ) self.bn1 = nn.BatchNorm2d(initial_dim) self.relu = nn.ReLU(inplace=True) @@ -72,7 +78,7 @@ class FPN(nn.Module): for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -94,16 +100,22 @@ class FPN(nn.Module): x4 = self.layer4(x3) # 1/16 # FPN - x4_out_2x = F.interpolate(x4, scale_factor=2., mode='bilinear', align_corners=True) + x4_out_2x = F.interpolate( + x4, scale_factor=2.0, mode="bilinear", align_corners=True + ) x3_out = self.layer3_outconv(x3) - x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + x3_out = self.layer3_outconv2(x3_out + x4_out_2x) - x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out_2x = F.interpolate( + x3_out, scale_factor=2.0, mode="bilinear", align_corners=True + ) x2_out = self.layer2_outconv(x2) - x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + x2_out = self.layer2_outconv2(x2_out + x3_out_2x) - x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out_2x = F.interpolate( + x2_out, scale_factor=2.0, mode="bilinear", align_corners=True + ) x1_out = self.layer1_outconv(x1) - x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + x1_out = self.layer1_outconv2(x1_out + x2_out_2x) return [x3_out, x1_out] diff --git a/third_party/TopicFM/src/models/modules/fine_preprocess.py b/third_party/TopicFM/src/models/modules/fine_preprocess.py index 4c8d264c1895be8f4e124fc3982d4e0d3b876af3..4cdce2d327ebc88371769946a292824f834729a5 100644 --- a/third_party/TopicFM/src/models/modules/fine_preprocess.py +++ b/third_party/TopicFM/src/models/modules/fine_preprocess.py @@ -9,15 +9,15 @@ class FinePreprocess(nn.Module): super().__init__() self.config = config - self.cat_c_feat = config['fine_concat_coarse_feat'] - self.W = self.config['fine_window_size'] + self.cat_c_feat = config["fine_concat_coarse_feat"] + self.W = self.config["fine_window_size"] - d_model_c = self.config['coarse']['d_model'] - d_model_f = self.config['fine']['d_model'] + d_model_c = self.config["coarse"]["d_model"] + d_model_f = self.config["fine"]["d_model"] self.d_model_f = d_model_f if self.cat_c_feat: self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) - self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) + self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True) self._reset_parameters() @@ -28,32 +28,48 @@ class FinePreprocess(nn.Module): def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): W = self.W - stride = data['hw0_f'][0] // data['hw0_c'][0] + stride = data["hw0_f"][0] // 
data["hw0_c"][0] - data.update({'W': W}) - if data['b_ids'].shape[0] == 0: + data.update({"W": W}) + if data["b_ids"].shape[0] == 0: feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) return feat0, feat1 # 1. unfold(crop) all local windows - feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) - feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) - feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) - feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f0_unfold = F.unfold( + feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2 + ) + feat_f0_unfold = rearrange(feat_f0_unfold, "n (c ww) l -> n l ww c", ww=W**2) + feat_f1_unfold = F.unfold( + feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2 + ) + feat_f1_unfold = rearrange(feat_f1_unfold, "n (c ww) l -> n l ww c", ww=W**2) # 2. select only the predicted matches - feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] - feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + feat_f0_unfold = feat_f0_unfold[data["b_ids"], data["i_ids"]] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data["b_ids"], data["j_ids"]] # option: use coarse-level feature as context: concat and linear if self.cat_c_feat: - feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], - feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] - feat_cf_win = self.merge_feat(torch.cat([ - torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] - repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] - ], -1)) + feat_c_win = self.down_proj( + torch.cat( + [ + feat_c0[data["b_ids"], data["i_ids"]], + feat_c1[data["b_ids"], data["j_ids"]], + ], + 0, + ) + ) # [2n, c] + feat_cf_win = self.merge_feat( + torch.cat( + [ + torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] + repeat(feat_c_win, "n c -> n ww c", ww=W**2), # [2n, ww, cf] + ], + -1, + ) + ) feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) return feat_f0_unfold, feat_f1_unfold diff --git a/third_party/TopicFM/src/models/modules/linear_attention.py b/third_party/TopicFM/src/models/modules/linear_attention.py index af6cd825033e98b7be15cc694ce28110ef84cc93..57b86b3ba682da62f9ff65893aa0ccd6753d32af 100644 --- a/third_party/TopicFM/src/models/modules/linear_attention.py +++ b/third_party/TopicFM/src/models/modules/linear_attention.py @@ -18,7 +18,7 @@ class LinearAttention(Module): self.eps = eps def forward(self, queries, keys, values, q_mask=None, kv_mask=None): - """ Multi-Head linear attention proposed in "Transformers are RNNs" + """Multi-Head linear attention proposed in "Transformers are RNNs" Args: queries: [N, L, H, D] keys: [N, S, H, D] @@ -54,7 +54,7 @@ class FullAttention(Module): self.dropout = Dropout(attention_dropout) def forward(self, queries, keys, values, q_mask=None, kv_mask=None): - """ Multi-head scaled dot-product attention, a.k.a full attention. + """Multi-head scaled dot-product attention, a.k.a full attention. 
Args: queries: [N, L, H, D] keys: [N, S, H, D] @@ -68,10 +68,12 @@ class FullAttention(Module): # Compute the unnormalized attention and apply the masks QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) if kv_mask is not None: - QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]).bool(), -1e9) + QK.masked_fill_( + ~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]).bool(), -1e9 + ) # Compute the attention and the weighted average - softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) + softmax_temp = 1.0 / queries.size(3) ** 0.5 # sqrt(D) A = torch.softmax(softmax_temp * QK, dim=2) if self.use_dropout: A = self.dropout(A) diff --git a/third_party/TopicFM/src/models/modules/transformer.py b/third_party/TopicFM/src/models/modules/transformer.py index 27ff8f6554844b1e14a7094fcbad40876f766db8..cef17ca689cd0f844c1d6bd6c0f987a3e0c3be59 100644 --- a/third_party/TopicFM/src/models/modules/transformer.py +++ b/third_party/TopicFM/src/models/modules/transformer.py @@ -8,10 +8,7 @@ from .linear_attention import LinearAttention, FullAttention class LoFTREncoderLayer(nn.Module): - def __init__(self, - d_model, - nhead, - attention='linear'): + def __init__(self, d_model, nhead, attention="linear"): super(LoFTREncoderLayer, self).__init__() self.dim = d_model // nhead @@ -21,14 +18,14 @@ class LoFTREncoderLayer(nn.Module): self.q_proj = nn.Linear(d_model, d_model, bias=False) self.k_proj = nn.Linear(d_model, d_model, bias=False) self.v_proj = nn.Linear(d_model, d_model, bias=False) - self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.attention = LinearAttention() if attention == "linear" else FullAttention() self.merge = nn.Linear(d_model, d_model, bias=False) # feed-forward network self.mlp = nn.Sequential( - nn.Linear(d_model*2, d_model*2, bias=False), + nn.Linear(d_model * 2, d_model * 2, bias=False), nn.GELU(), - nn.Linear(d_model*2, d_model, bias=False), + nn.Linear(d_model * 2, d_model, bias=False), ) # norm and dropout @@ -50,8 +47,10 @@ class LoFTREncoderLayer(nn.Module): query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) - message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] - message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.attention( + query, key, value, q_mask=x_mask, kv_mask=source_mask + ) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C] message = self.norm1(message) # feed-forward network @@ -68,18 +67,33 @@ class TopicFormer(nn.Module): super(TopicFormer, self).__init__() self.config = config - self.d_model = config['d_model'] - self.nhead = config['nhead'] - self.layer_names = config['layer_names'] - encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) - self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) - - self.topic_transformers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2*config['n_topic_transformers'])]) if config['n_samples'] > 0 else None #nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)]) - self.n_iter_topic_transformer = config['n_topic_transformers'] + self.d_model = config["d_model"] + self.nhead = config["nhead"] + self.layer_names = config["layer_names"] + encoder_layer = LoFTREncoderLayer( + 
config["d_model"], config["nhead"], config["attention"] + ) + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))] + ) - self.seed_tokens = nn.Parameter(torch.randn(config['n_topics'], config['d_model'])) - self.register_parameter('seed_tokens', self.seed_tokens) - self.n_samples = config['n_samples'] + self.topic_transformers = ( + nn.ModuleList( + [ + copy.deepcopy(encoder_layer) + for _ in range(2 * config["n_topic_transformers"]) + ] + ) + if config["n_samples"] > 0 + else None + ) # nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)]) + self.n_iter_topic_transformer = config["n_topic_transformers"] + + self.seed_tokens = nn.Parameter( + torch.randn(config["n_topics"], config["d_model"]) + ) + self.register_parameter("seed_tokens", self.seed_tokens) + self.n_samples = config["n_samples"] self._reset_parameters() @@ -94,9 +108,9 @@ class TopicFormer(nn.Module): topics (torch.Tensor): [N, L+S, K] """ prob_topics0, prob_topics1 = prob_topics[:, :L], prob_topics[:, L:] - topics0, topics1 = topics[:, :L], topics[:, L:] + topics0, topics1 = topics[:, :L], topics[:, L:] - theta0 = F.normalize(prob_topics0.sum(dim=1), p=1, dim=-1) # [N, K] + theta0 = F.normalize(prob_topics0.sum(dim=1), p=1, dim=-1) # [N, K] theta1 = F.normalize(prob_topics1.sum(dim=1), p=1, dim=-1) theta = F.normalize(theta0 * theta1, p=1, dim=-1) if self.n_samples == 0: @@ -106,18 +120,28 @@ class TopicFormer(nn.Module): sampled_values = torch.gather(theta, dim=-1, index=sampled_inds) else: sampled_values, sampled_inds = torch.topk(theta, self.n_samples, dim=-1) - sampled_topics0 = torch.gather(topics0, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics0.shape[1], 1)) - sampled_topics1 = torch.gather(topics1, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics1.shape[1], 1)) + sampled_topics0 = torch.gather( + topics0, + dim=-1, + index=sampled_inds.unsqueeze(1).repeat(1, topics0.shape[1], 1), + ) + sampled_topics1 = torch.gather( + topics1, + dim=-1, + index=sampled_inds.unsqueeze(1).repeat(1, topics1.shape[1], 1), + ) return sampled_topics0, sampled_topics1 def reduce_feat(self, feat, topick, N, C): len_topic = topick.sum(dim=-1).int() max_len = len_topic.max().item() selected_ids = topick.bool() - resized_feat = torch.zeros((N, max_len, C), dtype=torch.float32, device=feat.device) + resized_feat = torch.zeros( + (N, max_len, C), dtype=torch.float32, device=feat.device + ) new_mask = torch.zeros_like(resized_feat[..., 0]).bool() for i in range(N): - new_mask[i, :len_topic[i]] = True + new_mask[i, : len_topic[i]] = True resized_feat[new_mask, :] = feat[selected_ids, :] return resized_feat, new_mask, selected_ids @@ -130,8 +154,16 @@ class TopicFormer(nn.Module): mask1 (torch.Tensor): [N, S] (optional) """ - assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal" - N, L, S, C, K = feat0.shape[0], feat0.shape[1], feat1.shape[1], feat0.shape[2], self.config['n_topics'] + assert ( + self.d_model == feat0.shape[2] + ), "the feature number of src and transformer must be equal" + N, L, S, C, K = ( + feat0.shape[0], + feat0.shape[1], + feat1.shape[1], + feat0.shape[2], + self.config["n_topics"], + ) seeds = self.seed_tokens.unsqueeze(0).repeat(N, 1, 1) @@ -142,18 +174,20 @@ class TopicFormer(nn.Module): mask = None for layer, name in zip(self.layers, self.layer_names): - if name == 'seed': + if name == "seed": # seeds = layer(seeds, feat0, None, mask0) # seeds = layer(seeds, feat1, None, mask1) seeds = layer(seeds, feat, 
None, mask) - elif name == 'feat': + elif name == "feat": feat0 = layer(feat0, seeds, mask0, None) feat1 = layer(feat1, seeds, mask1, None) dmatrix = torch.einsum("nmd,nkd->nmk", feat, seeds) prob_topics = F.softmax(dmatrix, dim=-1) - feat_topics = torch.zeros_like(dmatrix).scatter_(-1, torch.argmax(dmatrix, dim=-1, keepdim=True), 1.0) + feat_topics = torch.zeros_like(dmatrix).scatter_( + -1, torch.argmax(dmatrix, dim=-1, keepdim=True), 1.0 + ) if mask is not None: feat_topics = feat_topics * mask.unsqueeze(-1) @@ -163,35 +197,57 @@ class TopicFormer(nn.Module): logger.warning("topic distribution is highly sparse!") sampled_topics = self.sample_topic(prob_topics.detach(), feat_topics, L) if sampled_topics is not None: - updated_feat0, updated_feat1 = torch.zeros_like(feat0), torch.zeros_like(feat1) + updated_feat0, updated_feat1 = torch.zeros_like(feat0), torch.zeros_like( + feat1 + ) s_topics0, s_topics1 = sampled_topics for k in range(s_topics0.shape[-1]): - topick0, topick1 = s_topics0[..., k], s_topics1[..., k] # [N, L+S] + topick0, topick1 = s_topics0[..., k], s_topics1[..., k] # [N, L+S] if (topick0.sum() > 0) and (topick1.sum() > 0): - new_feat0, new_mask0, selected_ids0 = self.reduce_feat(feat0, topick0, N, C) - new_feat1, new_mask1, selected_ids1 = self.reduce_feat(feat1, topick1, N, C) + new_feat0, new_mask0, selected_ids0 = self.reduce_feat( + feat0, topick0, N, C + ) + new_feat1, new_mask1, selected_ids1 = self.reduce_feat( + feat1, topick1, N, C + ) for idt in range(self.n_iter_topic_transformer): - new_feat0 = self.topic_transformers[idt*2](new_feat0, new_feat0, new_mask0, new_mask0) - new_feat1 = self.topic_transformers[idt*2](new_feat1, new_feat1, new_mask1, new_mask1) - new_feat0 = self.topic_transformers[idt*2+1](new_feat0, new_feat1, new_mask0, new_mask1) - new_feat1 = self.topic_transformers[idt*2+1](new_feat1, new_feat0, new_mask1, new_mask0) + new_feat0 = self.topic_transformers[idt * 2]( + new_feat0, new_feat0, new_mask0, new_mask0 + ) + new_feat1 = self.topic_transformers[idt * 2]( + new_feat1, new_feat1, new_mask1, new_mask1 + ) + new_feat0 = self.topic_transformers[idt * 2 + 1]( + new_feat0, new_feat1, new_mask0, new_mask1 + ) + new_feat1 = self.topic_transformers[idt * 2 + 1]( + new_feat1, new_feat0, new_mask1, new_mask0 + ) updated_feat0[selected_ids0, :] = new_feat0[new_mask0, :] updated_feat1[selected_ids1, :] = new_feat1[new_mask1, :] feat0 = (1 - s_topics0.sum(dim=-1, keepdim=True)) * feat0 + updated_feat0 feat1 = (1 - s_topics1.sum(dim=-1, keepdim=True)) * feat1 + updated_feat1 - conf_matrix = torch.einsum("nlc,nsc->nls", feat0, feat1) / C**.5 #(C * temperature) + conf_matrix = ( + torch.einsum("nlc,nsc->nls", feat0, feat1) / C**0.5 + ) # (C * temperature) if self.training: - topic_matrix = torch.einsum("nlk,nsk->nls", prob_topics[:, :L], prob_topics[:, L:]) - outlier_mask = torch.einsum("nlk,nsk->nls", feat_topics[:, :L], feat_topics[:, L:]) + topic_matrix = torch.einsum( + "nlk,nsk->nls", prob_topics[:, :L], prob_topics[:, L:] + ) + outlier_mask = torch.einsum( + "nlk,nsk->nls", feat_topics[:, :L], feat_topics[:, L:] + ) else: topic_matrix = {"img0": feat_topics[:, :L], "img1": feat_topics[:, L:]} outlier_mask = torch.ones_like(conf_matrix) if mask0 is not None: - outlier_mask = (outlier_mask * mask0[..., None] * mask1[:, None]) #.bool() + outlier_mask = outlier_mask * mask0[..., None] * mask1[:, None] # .bool() conf_matrix.masked_fill_(~outlier_mask.bool(), -1e9) - conf_matrix = F.softmax(conf_matrix, 1) * F.softmax(conf_matrix, 2) # * topic_matrix + 
conf_matrix = F.softmax(conf_matrix, 1) * F.softmax( + conf_matrix, 2 + ) # * topic_matrix return feat0, feat1, conf_matrix, topic_matrix @@ -203,11 +259,15 @@ class LocalFeatureTransformer(nn.Module): super(LocalFeatureTransformer, self).__init__() self.config = config - self.d_model = config['d_model'] - self.nhead = config['nhead'] - self.layer_names = config['layer_names'] - encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) - self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)]) #len(self.layer_names))]) + self.d_model = config["d_model"] + self.nhead = config["nhead"] + self.layer_names = config["layer_names"] + encoder_layer = LoFTREncoderLayer( + config["d_model"], config["nhead"], config["attention"] + ) + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(2)] + ) # len(self.layer_names))]) self._reset_parameters() def _reset_parameters(self): @@ -224,7 +284,9 @@ class LocalFeatureTransformer(nn.Module): mask1 (torch.Tensor): [N, S] (optional) """ - assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal" + assert ( + self.d_model == feat0.shape[2] + ), "the feature number of src and transformer must be equal" feat0 = self.layers[0](feat0, feat1, mask0, mask1) feat1 = self.layers[1](feat1, feat0, mask1, mask0) diff --git a/third_party/TopicFM/src/models/topic_fm.py b/third_party/TopicFM/src/models/topic_fm.py index 95cd22f9b66d08760382fe4cd22c4df918cc9f68..2556bdbb489574e13a5e5af60be87c546473d406 100644 --- a/third_party/TopicFM/src/models/topic_fm.py +++ b/third_party/TopicFM/src/models/topic_fm.py @@ -17,14 +17,14 @@ class TopicFM(nn.Module): # Modules self.backbone = build_backbone(config) - self.loftr_coarse = TopicFormer(config['coarse']) - self.coarse_matching = CoarseMatching(config['match_coarse']) + self.loftr_coarse = TopicFormer(config["coarse"]) + self.coarse_matching = CoarseMatching(config["match_coarse"]) self.fine_preprocess = FinePreprocess(config) self.loftr_fine = LocalFeatureTransformer(config["fine"]) self.fine_matching = FineMatching() def forward(self, data): - """ + """ Update: data (dict): { 'image0': (torch.Tensor): (N, 1, H, W) @@ -34,46 +34,65 @@ class TopicFM(nn.Module): } """ # 1. 
Local Feature CNN - data.update({ - 'bs': data['image0'].size(0), - 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] - }) + data.update( + { + "bs": data["image0"].size(0), + "hw0_i": data["image0"].shape[2:], + "hw1_i": data["image1"].shape[2:], + } + ) - if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence - feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) - (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs']) + if data["hw0_i"] == data["hw1_i"]: # faster & better BN convergence + feats_c, feats_f = self.backbone( + torch.cat([data["image0"], data["image1"]], dim=0) + ) + (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split( + data["bs"] + ), feats_f.split(data["bs"]) else: # handle different input shapes - (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1']) + (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone( + data["image0"] + ), self.backbone(data["image1"]) - data.update({ - 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], - 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] - }) + data.update( + { + "hw0_c": feat_c0.shape[2:], + "hw1_c": feat_c1.shape[2:], + "hw0_f": feat_f0.shape[2:], + "hw1_f": feat_f1.shape[2:], + } + ) # 2. coarse-level loftr module - feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c') - feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c') + feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c") + feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c") mask_c0 = mask_c1 = None # mask is useful in training - if 'mask0' in data: - mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) + if "mask0" in data: + mask_c0, mask_c1 = data["mask0"].flatten(-2), data["mask1"].flatten(-2) - feat_c0, feat_c1, conf_matrix, topic_matrix = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) - data.update({"conf_matrix": conf_matrix, "topic_matrix": topic_matrix}) ###### + feat_c0, feat_c1, conf_matrix, topic_matrix = self.loftr_coarse( + feat_c0, feat_c1, mask_c0, mask_c1 + ) + data.update({"conf_matrix": conf_matrix, "topic_matrix": topic_matrix}) ###### # 3. match coarse-level self.coarse_matching(data) # 4. fine-level refinement - feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0.detach(), feat_c1.detach(), data) + feat_f0_unfold, feat_f1_unfold = self.fine_preprocess( + feat_f0, feat_f1, feat_c0.detach(), feat_c1.detach(), data + ) if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted - feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold) + feat_f0_unfold, feat_f1_unfold = self.loftr_fine( + feat_f0_unfold, feat_f1_unfold + ) # 5. 
match fine-level self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) def load_state_dict(self, state_dict, *args, **kwargs): for k in list(state_dict.keys()): - if k.startswith('matcher.'): - state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + if k.startswith("matcher."): + state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k) return super().load_state_dict(state_dict, *args, **kwargs) diff --git a/third_party/TopicFM/src/models/utils/coarse_matching.py b/third_party/TopicFM/src/models/utils/coarse_matching.py index 75adbb5cc465220e759a044f96f86c08da2d7a50..0cd0ea3db496fe50f82bf7660696e96e26b23484 100644 --- a/third_party/TopicFM/src/models/utils/coarse_matching.py +++ b/third_party/TopicFM/src/models/utils/coarse_matching.py @@ -5,8 +5,9 @@ from einops.einops import rearrange INF = 1e9 + def mask_border(m, b: int, v): - """ Mask borders with value + """Mask borders with value Args: m (torch.Tensor): [N, H0, W0, H1, W1] b (int) @@ -37,22 +38,21 @@ def mask_border_with_padding(m, bd, v, p_m0, p_m1): h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): - m[b_idx, h0 - bd:] = v - m[b_idx, :, w0 - bd:] = v - m[b_idx, :, :, h1 - bd:] = v - m[b_idx, :, :, :, w1 - bd:] = v + m[b_idx, h0 - bd :] = v + m[b_idx, :, w0 - bd :] = v + m[b_idx, :, :, h1 - bd :] = v + m[b_idx, :, :, :, w1 - bd :] = v def compute_max_candidates(p_m0, p_m1): """Compute the max candidates of all pairs within a batch - + Args: p_m0, p_m1 (torch.Tensor): padded masks """ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] - max_cand = torch.sum( - torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + max_cand = torch.sum(torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) return max_cand @@ -61,26 +61,27 @@ class CoarseMatching(nn.Module): super().__init__() self.config = config # general config - self.thr = config['thr'] - self.border_rm = config['border_rm'] + self.thr = config["thr"] + self.border_rm = config["border_rm"] # -- # for trainig fine-level LoFTR - self.train_coarse_percent = config['train_coarse_percent'] - self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + self.train_coarse_percent = config["train_coarse_percent"] + self.train_pad_num_gt_min = config["train_pad_num_gt_min"] # we provide 2 options for differentiable matching - self.match_type = config['match_type'] - if self.match_type == 'dual_softmax': - self.temperature = config['dsmax_temperature'] - elif self.match_type == 'sinkhorn': + self.match_type = config["match_type"] + if self.match_type == "dual_softmax": + self.temperature = config["dsmax_temperature"] + elif self.match_type == "sinkhorn": try: from .superglue import log_optimal_transport except ImportError: raise ImportError("download superglue.py first!") self.log_optimal_transport = log_optimal_transport self.bin_score = nn.Parameter( - torch.tensor(config['skh_init_bin_score'], requires_grad=True)) - self.skh_iters = config['skh_iters'] - self.skh_prefilter = config['skh_prefilter'] + torch.tensor(config["skh_init_bin_score"], requires_grad=True) + ) + self.skh_iters = config["skh_iters"] + self.skh_prefilter = config["skh_prefilter"] else: raise NotImplementedError() @@ -99,7 +100,7 @@ class CoarseMatching(nn.Module): 'mconf' (torch.Tensor): [M]} NOTE: M' != M during training. 
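# --------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the upstream diff. The
# coarse stage above scores patch pairs with a dual-softmax confidence matrix
# (softmax over rows times softmax over columns) and then keeps only
# confident, mutually-nearest pairs. A condensed, self-contained version of
# those two steps, with a hypothetical threshold `thr`, might read:
# --------------------------------------------------------------------------
import torch


def dual_softmax_mutual_matches(sim, thr=0.2):
    """sim: [N, L, S] similarity logits -> (b_ids, i_ids, j_ids, mconf)."""
    conf = torch.softmax(sim, dim=1) * torch.softmax(sim, dim=2)
    mask = conf > thr
    # keep mutual nearest neighbours: a pair survives only if it is both the
    # row maximum and the column maximum of the confidence matrix
    mask = (
        mask
        * (conf == conf.max(dim=2, keepdim=True)[0])
        * (conf == conf.max(dim=1, keepdim=True)[0])
    )
    mask_v, all_j_ids = mask.max(dim=2)       # at most one True per row
    b_ids, i_ids = torch.where(mask_v)
    j_ids = all_j_ids[b_ids, i_ids]
    return b_ids, i_ids, j_ids, conf[b_ids, i_ids, j_ids]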
""" - conf_matrix = data['conf_matrix'] + conf_matrix = data["conf_matrix"] # predict coarse matches from conf_matrix data.update(**self.get_coarse_match(conf_matrix, data)) @@ -121,28 +122,33 @@ class CoarseMatching(nn.Module): 'mconf' (torch.Tensor): [M]} """ axes_lengths = { - 'h0c': data['hw0_c'][0], - 'w0c': data['hw0_c'][1], - 'h1c': data['hw1_c'][0], - 'w1c': data['hw1_c'][1] + "h0c": data["hw0_c"][0], + "w0c": data["hw0_c"][1], + "h1c": data["hw1_c"][0], + "w1c": data["hw1_c"][1], } _device = conf_matrix.device # 1. confidence thresholding mask = conf_matrix > self.thr - mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', - **axes_lengths) - if 'mask0' not in data: + mask = rearrange( + mask, "b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c", **axes_lengths + ) + if "mask0" not in data: mask_border(mask, self.border_rm, False) else: - mask_border_with_padding(mask, self.border_rm, False, - data['mask0'], data['mask1']) - mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', - **axes_lengths) + mask_border_with_padding( + mask, self.border_rm, False, data["mask0"], data["mask1"] + ) + mask = rearrange( + mask, "b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)", **axes_lengths + ) # 2. mutual nearest - mask = mask \ - * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + mask = ( + mask + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + ) # 3. find all valid coarse matches # this only works when at most one `True` in each row @@ -157,16 +163,17 @@ class CoarseMatching(nn.Module): # NOTE: # The sampling is performed across all pairs in a batch without manually balancing # #samples for fine-level increases w.r.t. batch_size - if 'mask0' not in data: - num_candidates_max = mask.size(0) * max( - mask.size(1), mask.size(2)) + if "mask0" not in data: + num_candidates_max = mask.size(0) * max(mask.size(1), mask.size(2)) else: num_candidates_max = compute_max_candidates( - data['mask0'], data['mask1']) - num_matches_train = int(num_candidates_max * - self.train_coarse_percent) + data["mask0"], data["mask1"] + ) + num_matches_train = int(num_candidates_max * self.train_coarse_percent) num_matches_pred = len(b_ids) - assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + assert ( + self.train_pad_num_gt_min < num_matches_train + ), "min-num-gt-pad should be less than num-train-matches" # pred_indices is to select from prediction if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: @@ -174,44 +181,55 @@ class CoarseMatching(nn.Module): else: pred_indices = torch.randint( num_matches_pred, - (num_matches_train - self.train_pad_num_gt_min, ), - device=_device) + (num_matches_train - self.train_pad_num_gt_min,), + device=_device, + ) # gt_pad_indices is to select from gt padding. e.g. 
max(3787-4800, 200) gt_pad_indices = torch.randint( - len(data['spv_b_ids']), - (max(num_matches_train - num_matches_pred, - self.train_pad_num_gt_min), ), - device=_device) - mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + len(data["spv_b_ids"]), + (max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min),), + device=_device, + ) + mconf_gt = torch.zeros( + len(data["spv_b_ids"]), device=_device + ) # set conf of gt paddings to all zero b_ids, i_ids, j_ids, mconf = map( - lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], - dim=0), - *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], - [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], dim=0), + *zip( + [b_ids, data["spv_b_ids"]], + [i_ids, data["spv_i_ids"]], + [j_ids, data["spv_j_ids"]], + [mconf, mconf_gt], + ) + ) # These matches select patches that feed into fine-level network - coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + coarse_matches = {"b_ids": b_ids, "i_ids": i_ids, "j_ids": j_ids} # 4. Update with matches in original image resolution - scale = data['hw0_i'][0] / data['hw0_c'][0] - scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale - scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale - mkpts0_c = torch.stack( - [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], - dim=1) * scale0 - mkpts1_c = torch.stack( - [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], - dim=1) * scale1 + scale = data["hw0_i"][0] / data["hw0_c"][0] + scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale + scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale + mkpts0_c = ( + torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1) + * scale0 + ) + mkpts1_c = ( + torch.stack([j_ids % data["hw1_c"][1], j_ids // data["hw1_c"][1]], dim=1) + * scale1 + ) # These matches is the current prediction (for visualization) - coarse_matches.update({ - 'gt_mask': mconf == 0, - 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches - 'mkpts0_c': mkpts0_c[mconf != 0], - 'mkpts1_c': mkpts1_c[mconf != 0], - 'mconf': mconf[mconf != 0] - }) + coarse_matches.update( + { + "gt_mask": mconf == 0, + "m_bids": b_ids[mconf != 0], # mconf == 0 => gt matches + "mkpts0_c": mkpts0_c[mconf != 0], + "mkpts1_c": mkpts1_c[mconf != 0], + "mconf": mconf[mconf != 0], + } + ) return coarse_matches diff --git a/third_party/TopicFM/src/models/utils/fine_matching.py b/third_party/TopicFM/src/models/utils/fine_matching.py index 018f2fe475600b319998c263a97237ce135c3aaf..7156e3e1f22e2e341062565e5ad6baee41dd9bc6 100644 --- a/third_party/TopicFM/src/models/utils/fine_matching.py +++ b/third_party/TopicFM/src/models/utils/fine_matching.py @@ -27,39 +27,57 @@ class FineMatching(nn.Module): """ M, WW, C = feat_f0.shape W = int(math.sqrt(WW)) - scale = data['hw0_i'][0] / data['hw0_f'][0] + scale = data["hw0_i"][0] / data["hw0_f"][0] self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale # corner case: if no coarse matches found if M == 0: - assert self.training == False, "M is always >0, when training, see coarse_matching.py" + assert ( + self.training == False + ), "M is always >0, when training, see coarse_matching.py" # logger.warning('No matches found in coarse-level.') - data.update({ - 'expec_f': torch.empty(0, 3, device=feat_f0.device), - 'mkpts0_f': data['mkpts0_c'], - 'mkpts1_f': data['mkpts1_c'], - }) + data.update( + { + 
"expec_f": torch.empty(0, 3, device=feat_f0.device), + "mkpts0_f": data["mkpts0_c"], + "mkpts1_f": data["mkpts1_c"], + } + ) return - feat_f0_picked = feat_f0[:, WW//2, :] + feat_f0_picked = feat_f0[:, WW // 2, :] - sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) - softmax_temp = 1. / C**.5 + sim_matrix = torch.einsum("mc,mrc->mr", feat_f0_picked, feat_f1) + softmax_temp = 1.0 / C**0.5 heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) - feat_f1_picked = (feat_f1 * heatmap.unsqueeze(-1)).sum(dim=1) # [M, C] + feat_f1_picked = (feat_f1 * heatmap.unsqueeze(-1)).sum(dim=1) # [M, C] heatmap = heatmap.view(-1, W, W) # compute coordinates from heatmap - coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] - grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[ + 0 + ] # [M, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape( + 1, -1, 2 + ) # [1, WW, 2] # compute std over - var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords1_normalized**2 # [M, 2] - std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability - + var = ( + torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) + - coords1_normalized**2 + ) # [M, 2] + std = torch.sum( + torch.sqrt(torch.clamp(var, min=1e-10)), -1 + ) # [M] clamp needed for numerical stability + # for fine-level supervision - data.update({'expec_f': torch.cat([coords1_normalized, std.unsqueeze(1)], -1), - 'descriptors0': feat_f0_picked.detach(), 'descriptors1': feat_f1_picked.detach()}) + data.update( + { + "expec_f": torch.cat([coords1_normalized, std.unsqueeze(1)], -1), + "descriptors0": feat_f0_picked.detach(), + "descriptors1": feat_f1_picked.detach(), + } + ) # compute absolute kpt coords self.get_fine_match(coords1_normalized, data) @@ -70,11 +88,13 @@ class FineMatching(nn.Module): # mkpts0_f and mkpts1_f # scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale - mkpts0_f = data['mkpts0_c'] # + (coords0_normed * (W // 2) * scale0 )[:len(data['mconf'])] - scale1 = scale * data['scale1'][data['b_ids']] if 'scale1' in data else scale - mkpts1_f = data['mkpts1_c'] + (coords1_normed * (W // 2) * scale1)[:len(data['mconf'])] + mkpts0_f = data[ + "mkpts0_c" + ] # + (coords0_normed * (W // 2) * scale0 )[:len(data['mconf'])] + scale1 = scale * data["scale1"][data["b_ids"]] if "scale1" in data else scale + mkpts1_f = ( + data["mkpts1_c"] + + (coords1_normed * (W // 2) * scale1)[: len(data["mconf"])] + ) - data.update({ - "mkpts0_f": mkpts0_f, - "mkpts1_f": mkpts1_f - }) + data.update({"mkpts0_f": mkpts0_f, "mkpts1_f": mkpts1_f}) diff --git a/third_party/TopicFM/src/models/utils/geometry.py b/third_party/TopicFM/src/models/utils/geometry.py index f95cdb65b48324c4f4ceb20231b1bed992b41116..6101f738f2b2b7ee014fcb53a4032391939ed8cd 100644 --- a/third_party/TopicFM/src/models/utils/geometry.py +++ b/third_party/TopicFM/src/models/utils/geometry.py @@ -3,10 +3,10 @@ import torch @torch.no_grad() def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): - """ Warp kpts0 from I0 to I1 with depth, K and Rt + """Warp kpts0 from I0 to I1 with depth, K and Rt Also check covisibility and depth consistency. Depth is consistent if relative error < 0.2 (hard-coded). 
- + Args: kpts0 (torch.Tensor): [N, L, 2] - , depth0 (torch.Tensor): [N, H, W], @@ -22,33 +22,52 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): # Sample depth, get calculable_mask on depth != 0 kpts0_depth = torch.stack( - [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + [ + depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] + for i in range(kpts0.shape[0]) + ], + dim=0, ) # (N, L) nonzero_mask = kpts0_depth != 0 # Unproject - kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) # Rigid Transform - w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] # Project w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) - w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth # Covisible Check h, w = depth1.shape[1:3] - covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ - (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) w_kpts0_long = w_kpts0.long() w_kpts0_long[~covisible_mask, :] = 0 w_kpts0_depth = torch.stack( - [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + [ + depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] + for i in range(w_kpts0_long.shape[0]) + ], + dim=0, ) # (N, L) - consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + consistent_mask = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() < 0.2 valid_mask = nonzero_mask * covisible_mask * consistent_mask return valid_mask, w_kpts0 diff --git a/third_party/TopicFM/src/models/utils/supervision.py b/third_party/TopicFM/src/models/utils/supervision.py index 1f1f0478fdcbe7f8ceffbc4aff4d507cec55bbd2..86f167e95439d588c998ca32b9296c3482484215 100644 --- a/third_party/TopicFM/src/models/utils/supervision.py +++ b/third_party/TopicFM/src/models/utils/supervision.py @@ -13,7 +13,7 @@ from .geometry import warp_kpts @torch.no_grad() def mask_pts_at_padded_regions(grid_pt, mask): """For megadepth dataset, zero-padding exists in images""" - mask = repeat(mask, 'n h w -> n (h w) c', c=2) + mask = repeat(mask, "n h w -> n (h w) c", c=2) grid_pt[~mask.bool()] = 0 return grid_pt @@ -30,37 +30,55 @@ def spvs_coarse(data, config): 'spv_w_pt0_i': [N, hw0, 2], in original image resolution 'spv_pt1_i': [N, hw1, 2], in original image resolution } - + NOTE: - for scannet dataset, there're 3 kinds of resolution {i, c, f} - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} """ # 1. 
misc - device = data['image0'].device - N, _, H0, W0 = data['image0'].shape - _, _, H1, W1 = data['image1'].shape - scale = config['MODEL']['RESOLUTION'][0] - scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale - scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale + device = data["image0"].device + N, _, H0, W0 = data["image0"].shape + _, _, H1, W1 = data["image1"].shape + scale = config["MODEL"]["RESOLUTION"][0] + scale0 = scale * data["scale0"][:, None] if "scale0" in data else scale + scale1 = scale * data["scale1"][:, None] if "scale0" in data else scale h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) # 2. warp grids # create kpts in meshgrid and resize them to image resolution - grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_c = ( + create_meshgrid(h0, w0, False, device).reshape(1, h0 * w0, 2).repeat(N, 1, 1) + ) # [N, hw, 2] grid_pt0_i = scale0 * grid_pt0_c - grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_c = ( + create_meshgrid(h1, w1, False, device).reshape(1, h1 * w1, 2).repeat(N, 1, 1) + ) grid_pt1_i = scale1 * grid_pt1_c # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt - if 'mask0' in data: - grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) - grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) + if "mask0" in data: + grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data["mask0"]) + grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data["mask1"]) # warp kpts bi-directionally and resize them to coarse-level resolution # (no depth consistency check, since it leads to worse results experimentally) # (unhandled edge case: points with 0-depth will be warped to the left-up corner) - _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) - _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) + _, w_pt0_i = warp_kpts( + grid_pt0_i, + data["depth0"], + data["depth1"], + data["T_0to1"], + data["K0"], + data["K1"], + ) + _, w_pt1_i = warp_kpts( + grid_pt1_i, + data["depth1"], + data["depth0"], + data["T_1to0"], + data["K1"], + data["K0"], + ) w_pt0_c = w_pt0_i / scale1 w_pt1_c = w_pt1_i / scale0 @@ -72,21 +90,26 @@ def spvs_coarse(data, config): # corner case: out of boundary def out_bound_mask(pt, w, h): - return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + return ( + (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + ) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 - loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) - correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1) + loop_back = torch.stack( + [nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0 + ) + correct_0to1 = loop_back == torch.arange(h0 * w0, device=device)[None].repeat(N, 1) correct_0to1[:, 0] = False # ignore the top-left corner # 4. 
construct a gt conf_matrix - conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) + conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device) b_ids, i_ids = torch.where(correct_0to1 != 0) j_ids = nearest_index1[b_ids, i_ids] conf_matrix_gt[b_ids, i_ids, j_ids] = 1 - data.update({'conf_matrix_gt': conf_matrix_gt}) + data.update({"conf_matrix_gt": conf_matrix_gt}) # 5. save coarse matches(gt) for training fine level if len(b_ids) == 0: @@ -96,30 +119,26 @@ def spvs_coarse(data, config): i_ids = torch.tensor([0], device=device) j_ids = torch.tensor([0], device=device) - data.update({ - 'spv_b_ids': b_ids, - 'spv_i_ids': i_ids, - 'spv_j_ids': j_ids - }) + data.update({"spv_b_ids": b_ids, "spv_i_ids": i_ids, "spv_j_ids": j_ids}) # 6. save intermediate results (for fast fine-level computation) - data.update({ - 'spv_w_pt0_i': w_pt0_i, - 'spv_pt1_i': grid_pt1_i - }) + data.update({"spv_w_pt0_i": w_pt0_i, "spv_pt1_i": grid_pt1_i}) def compute_supervision_coarse(data, config): - assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" - data_source = data['dataset_name'][0] - if data_source.lower() in ['scannet', 'megadepth']: + assert ( + len(set(data["dataset_name"])) == 1 + ), "Do not support mixed datasets training!" + data_source = data["dataset_name"][0] + if data_source.lower() in ["scannet", "megadepth"]: spvs_coarse(data, config) else: - raise ValueError(f'Unknown data source: {data_source}') + raise ValueError(f"Unknown data source: {data_source}") ############## ↓ Fine-Level supervision ↓ ############## + @torch.no_grad() def spvs_fine(data, config): """ @@ -129,23 +148,25 @@ def spvs_fine(data, config): """ # 1. misc # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i') - w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i'] - scale = config['MODEL']['RESOLUTION'][1] - radius = config['MODEL']['FINE_WINDOW_SIZE'] // 2 + w_pt0_i, pt1_i = data["spv_w_pt0_i"], data["spv_pt1_i"] + scale = config["MODEL"]["RESOLUTION"][1] + radius = config["MODEL"]["FINE_WINDOW_SIZE"] // 2 # 2. get coarse prediction - b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + b_ids, i_ids, j_ids = data["b_ids"], data["i_ids"], data["j_ids"] # 3. compute gt - scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale + scale = scale * data["scale1"][b_ids] if "scale0" in data else scale # `expec_f_gt` might exceed the window, i.e. 
abs(*) > 1, which would be filtered later - expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2] + expec_f_gt = ( + (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius + ) # [M, 2] data.update({"expec_f_gt": expec_f_gt}) def compute_supervision_fine(data, config): - data_source = data['dataset_name'][0] - if data_source.lower() in ['scannet', 'megadepth']: + data_source = data["dataset_name"][0] + if data_source.lower() in ["scannet", "megadepth"]: spvs_fine(data, config) else: raise NotImplementedError diff --git a/third_party/TopicFM/src/optimizers/__init__.py b/third_party/TopicFM/src/optimizers/__init__.py index e1db2285352586c250912bdd2c4ae5029620ab5f..e4e36c22e00217deccacd589f8924b2f74589456 100644 --- a/third_party/TopicFM/src/optimizers/__init__.py +++ b/third_party/TopicFM/src/optimizers/__init__.py @@ -7,9 +7,13 @@ def build_optimizer(model, config): lr = config.TRAINER.TRUE_LR if name == "adam": - return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY) + return torch.optim.Adam( + model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY + ) elif name == "adamw": - return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY) + return torch.optim.AdamW( + model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY + ) else: raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") @@ -24,18 +28,27 @@ def build_scheduler(config, optimizer): 'frequency': x, (optional) } """ - scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} + scheduler = {"interval": config.TRAINER.SCHEDULER_INTERVAL} name = config.TRAINER.SCHEDULER - if name == 'MultiStepLR': + if name == "MultiStepLR": scheduler.update( - {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) - elif name == 'CosineAnnealing': + { + "scheduler": MultiStepLR( + optimizer, + config.TRAINER.MSLR_MILESTONES, + gamma=config.TRAINER.MSLR_GAMMA, + ) + } + ) + elif name == "CosineAnnealing": scheduler.update( - {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) - elif name == 'ExponentialLR': + {"scheduler": CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)} + ) + elif name == "ExponentialLR": scheduler.update( - {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) + {"scheduler": ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)} + ) else: raise NotImplementedError() diff --git a/third_party/TopicFM/src/utils/augment.py b/third_party/TopicFM/src/utils/augment.py index d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d..068751c6c07091bbaed76debd43a73155f61b9bd 100644 --- a/third_party/TopicFM/src/utils/augment.py +++ b/third_party/TopicFM/src/utils/augment.py @@ -7,16 +7,21 @@ class DarkAug(object): """ def __init__(self) -> None: - self.augmentor = A.Compose([ - A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), - A.Blur(p=0.1, blur_limit=(3, 9)), - A.MotionBlur(p=0.2, blur_limit=(3, 25)), - A.RandomGamma(p=0.1, gamma_limit=(15, 65)), - A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) - ], p=0.75) + self.augmentor = A.Compose( + [ + A.RandomBrightnessContrast( + p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3) + ), + A.Blur(p=0.1, blur_limit=(3, 9)), + A.MotionBlur(p=0.2, blur_limit=(3, 25)), + A.RandomGamma(p=0.1, gamma_limit=(15, 65)), + A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)), + ], + p=0.75, + ) def __call__(self, x): - 
return self.augmentor(image=x)['image'] + return self.augmentor(image=x)["image"] class MobileAug(object): @@ -25,31 +30,36 @@ class MobileAug(object): """ def __init__(self): - self.augmentor = A.Compose([ - A.MotionBlur(p=0.25), - A.ColorJitter(p=0.5), - A.RandomRain(p=0.1), # random occlusion - A.RandomSunFlare(p=0.1), - A.JpegCompression(p=0.25), - A.ISONoise(p=0.25) - ], p=1.0) + self.augmentor = A.Compose( + [ + A.MotionBlur(p=0.25), + A.ColorJitter(p=0.5), + A.RandomRain(p=0.1), # random occlusion + A.RandomSunFlare(p=0.1), + A.JpegCompression(p=0.25), + A.ISONoise(p=0.25), + ], + p=1.0, + ) def __call__(self, x): - return self.augmentor(image=x)['image'] + return self.augmentor(image=x)["image"] def build_augmentor(method=None, **kwargs): if method is not None: - raise NotImplementedError('Using of augmentation functions are not supported yet!') - if method == 'dark': + raise NotImplementedError( + "Using of augmentation functions are not supported yet!" + ) + if method == "dark": return DarkAug() - elif method == 'mobile': + elif method == "mobile": return MobileAug() elif method is None: return None else: - raise ValueError(f'Invalid augmentation method: {method}') + raise ValueError(f"Invalid augmentation method: {method}") -if __name__ == '__main__': - augmentor = build_augmentor('FDA') +if __name__ == "__main__": + augmentor = build_augmentor("FDA") diff --git a/third_party/TopicFM/src/utils/comm.py b/third_party/TopicFM/src/utils/comm.py index 26ec9517cc47e224430106d8ae9aa99a3fe49167..9f578cda8933cc358934c645fcf413c63ab4d79d 100644 --- a/third_party/TopicFM/src/utils/comm.py +++ b/third_party/TopicFM/src/utils/comm.py @@ -98,11 +98,11 @@ def _serialize_to_tensor(data, group): device = torch.device("cpu" if backend == "gloo" else "cuda") buffer = pickle.dumps(data) - if len(buffer) > 1024 ** 3: + if len(buffer) > 1024**3: logger = logging.getLogger(__name__) logger.warning( "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( - get_rank(), len(buffer) / (1024 ** 3), device + get_rank(), len(buffer) / (1024**3), device ) ) storage = torch.ByteStorage.from_buffer(buffer) @@ -122,7 +122,8 @@ def _pad_to_largest_tensor(tensor, group): ), "comm.gather/all_gather must be called from ranks within the given group!" 
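# --------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the upstream diff. The comm
# helpers reformatted here pickle an arbitrary object into a uint8 tensor and
# right-pad it, because torch.distributed.all_gather requires every rank to
# contribute a tensor of identical shape. The padding / unpadding round trip,
# stripped of the distributed calls, looks roughly like this (helper names
# are the editor's own):
# --------------------------------------------------------------------------
import pickle

import torch


def obj_to_byte_tensor(obj):
    buf = pickle.dumps(obj)
    return torch.ByteTensor(torch.ByteStorage.from_buffer(buf))


def pad_right(t, target_numel):
    pad = torch.zeros((target_numel - t.numel(),), dtype=torch.uint8)
    return torch.cat((t, pad), dim=0)


def byte_tensor_to_obj(t, true_numel):
    # drop the zero padding before unpickling
    return pickle.loads(t[:true_numel].numpy().tobytes())


# round trip: obj -> uint8 tensor (n bytes) -> pad to max size -> slice -> obj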
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) size_list = [ - torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) + torch.zeros([1], dtype=torch.int64, device=tensor.device) + for _ in range(world_size) ] dist.all_gather(size_list, local_size, group=group) @@ -133,7 +134,9 @@ def _pad_to_largest_tensor(tensor, group): # we pad the tensor because torch all_gather does not support # gathering tensors of different shapes if local_size != max_size: - padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) + padding = torch.zeros( + (max_size - local_size,), dtype=torch.uint8, device=tensor.device + ) tensor = torch.cat((tensor, padding), dim=0) return size_list, tensor @@ -164,7 +167,8 @@ def all_gather(data, group=None): # receiving Tensor from all ranks tensor_list = [ - torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) + for _ in size_list ] dist.all_gather(tensor_list, tensor, group=group) @@ -205,7 +209,8 @@ def gather(data, dst=0, group=None): if rank == dst: max_size = max(size_list) tensor_list = [ - torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) + for _ in size_list ] dist.gather(tensor, tensor_list, dst=dst, group=group) @@ -228,7 +233,7 @@ def shared_random_seed(): All workers must call this function, otherwise it will deadlock. """ - ints = np.random.randint(2 ** 31) + ints = np.random.randint(2**31) all_ints = all_gather(ints) return all_ints[0] diff --git a/third_party/TopicFM/src/utils/dataloader.py b/third_party/TopicFM/src/utils/dataloader.py index 6da37b880a290c2bb3ebb028d0c8dab592acc5c1..b980dfd344714870ecdacd9e7a9742f51c3ee14d 100644 --- a/third_party/TopicFM/src/utils/dataloader.py +++ b/third_party/TopicFM/src/utils/dataloader.py @@ -3,21 +3,22 @@ import numpy as np # --- PL-DATAMODULE --- + def get_local_split(items: list, world_size: int, rank: int, seed: int): - """ The local rank only loads a split of the dataset. 
""" + """The local rank only loads a split of the dataset.""" n_items = len(items) items_permute = np.random.RandomState(seed).permutation(items) if n_items % world_size == 0: padded_items = items_permute else: padding = np.random.RandomState(seed).choice( - items, - world_size - (n_items % world_size), - replace=True) + items, world_size - (n_items % world_size), replace=True + ) padded_items = np.concatenate([items_permute, padding]) - assert len(padded_items) % world_size == 0, \ - f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' + assert ( + len(padded_items) % world_size == 0 + ), f"len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}" n_per_rank = len(padded_items) // world_size - local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] + local_items = padded_items[n_per_rank * rank : n_per_rank * (rank + 1)] return local_items diff --git a/third_party/TopicFM/src/utils/dataset.py b/third_party/TopicFM/src/utils/dataset.py index 647bbadd821b6c90736ed45462270670b1017b0b..f26722dddcc15516b1986182a246b0cdb52c347a 100644 --- a/third_party/TopicFM/src/utils/dataset.py +++ b/third_party/TopicFM/src/utils/dataset.py @@ -12,8 +12,11 @@ MEGADEPTH_CLIENT = SCANNET_CLIENT = None # --- DATA IO --- + def load_array_from_s3( - path, client, cv_type, + path, + client, + cv_type, use_h5py=False, ): byte_str = client.Get(path) @@ -23,7 +26,7 @@ def load_array_from_s3( data = cv2.imdecode(raw_array, cv_type) else: f = io.BytesIO(byte_str) - data = np.array(h5py.File(f, 'r')['/depth']) + data = np.array(h5py.File(f, "r")["/depth"]) except Exception as ex: print(f"==> Data loading failure: {path}") raise ex @@ -33,9 +36,8 @@ def load_array_from_s3( def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): - cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ - else cv2.IMREAD_COLOR - if str(path).startswith('s3://'): + cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR + if str(path).startswith("s3://"): image = load_array_from_s3(str(path), client, cv_type) else: image = cv2.imread(str(path), cv_type) @@ -49,9 +51,9 @@ def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): def get_resized_wh(w, h, resize=None): - if (resize is not None) and (max(h,w) > resize): # resize the longer edge + if (resize is not None) and (max(h, w) > resize): # resize the longer edge scale = resize / max(h, w) - w_new, h_new = int(round(w*scale)), int(round(h*scale)) + w_new, h_new = int(round(w * scale)), int(round(h * scale)) else: w_new, h_new = w, h return w_new, h_new @@ -66,20 +68,22 @@ def get_divisible_wh(w, h, df=None): def pad_bottom_right(inp, pad_size, ret_mask=False): - assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + assert isinstance(pad_size, int) and pad_size >= max( + inp.shape[-2:] + ), f"{pad_size} < {max(inp.shape[-2:])}" mask = None if inp.ndim == 2: padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) - padded[:inp.shape[0], :inp.shape[1]] = inp + padded[: inp.shape[0], : inp.shape[1]] = inp if ret_mask: mask = np.zeros((pad_size, pad_size), dtype=bool) - mask[:inp.shape[0], :inp.shape[1]] = True + mask[: inp.shape[0], : inp.shape[1]] = True elif inp.ndim == 3: padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) - padded[:, :inp.shape[1], :inp.shape[2]] = inp + padded[:, : inp.shape[1], : inp.shape[2]] = inp if ret_mask: mask = np.zeros((inp.shape[0], pad_size, pad_size), 
dtype=bool) - mask[:, :inp.shape[1], :inp.shape[2]] = True + mask[:, : inp.shape[1], : inp.shape[2]] = True else: raise NotImplementedError() return padded, mask @@ -87,6 +91,7 @@ def pad_bottom_right(inp, pad_size, ret_mask=False): # --- MEGADEPTH --- + def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): """ Args: @@ -96,7 +101,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No Returns: image (torch.tensor): (1, h, w) mask (torch.tensor): (h, w) - scale (torch.tensor): [w/w_new, h/h_new] + scale (torch.tensor): [w/w_new, h/h_new] """ # read image image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) @@ -107,25 +112,27 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No w_new, h_new = get_divisible_wh(w_new, h_new, df) image = cv2.resize(image, (w_new, h_new)) - scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float) if padding: # padding - pad_to = resize #max(h_new, w_new) + pad_to = resize # max(h_new, w_new) image, mask = pad_bottom_right(image, pad_to, ret_mask=True) else: mask = None - image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + image = ( + torch.from_numpy(image).float()[None] / 255 + ) # (h, w) -> (1, h, w) and normalized mask = torch.from_numpy(mask) if mask is not None else None return image, mask, scale def read_megadepth_depth(path, pad_to=None): - if str(path).startswith('s3://'): + if str(path).startswith("s3://"): depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) else: - depth = np.array(h5py.File(path, 'r')['depth']) + depth = np.array(h5py.File(path, "r")["depth"]) if pad_to is not None: depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) depth = torch.from_numpy(depth).float() # (h, w) @@ -134,6 +141,7 @@ def read_megadepth_depth(path, pad_to=None): # --- ScanNet --- + def read_scannet_gray(path, resize=(640, 480), augment_fn=None): """ Args: @@ -142,7 +150,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None): Returns: image (torch.tensor): (1, h, w) mask (torch.tensor): (h, w) - scale (torch.tensor): [w/w_new, h/h_new] + scale (torch.tensor): [w/w_new, h/h_new] """ # read and resize image image = imread_gray(path, augment_fn) @@ -155,6 +163,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None): # ---- evaluation datasets: HLoc, Aachen, InLoc + def read_img_gray(path, resize=None, down_factor=16): # read and resize image image = imread_gray(path, None) @@ -174,7 +183,7 @@ def read_img_gray(path, resize=None, down_factor=16): def read_scannet_depth(path): - if str(path).startswith('s3://'): + if str(path).startswith("s3://"): depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) else: depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) @@ -184,18 +193,17 @@ def read_scannet_depth(path): def read_scannet_pose(path): - """ Read ScanNet's Camera2World pose and transform it to World2Camera. - + """Read ScanNet's Camera2World pose and transform it to World2Camera. + Returns: pose_w2c (np.ndarray): (4, 4) """ - cam2world = np.loadtxt(path, delimiter=' ') + cam2world = np.loadtxt(path, delimiter=" ") world2cam = inv(cam2world) return world2cam def read_scannet_intrinsic(path): - """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 
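# --------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the upstream diff. The
# dataset helpers above resize an image so its longer edge fits `resize`,
# snap the new size down to a multiple of `df`, and zero-pad bottom/right to
# a square together with a validity mask. A dependency-light sketch of that
# preprocessing (nearest-neighbour resize instead of cv2.resize; the default
# values are the editor's placeholders) might be:
# --------------------------------------------------------------------------
import numpy as np


def resize_and_pad_gray(image, resize=840, df=8):
    """image: (h, w) array -> (padded, mask, scale), as in read_megadepth_gray."""
    h, w = image.shape
    s = resize / max(h, w) if max(h, w) > resize else 1.0
    w_new, h_new = int(round(w * s)) // df * df, int(round(h * s)) // df * df
    rows = np.linspace(0, h - 1, h_new).astype(int)
    cols = np.linspace(0, w - 1, w_new).astype(int)
    resized = image[rows][:, cols]
    scale = np.array([w / w_new, h / h_new], dtype=np.float32)  # [w/w_new, h/h_new]
    padded = np.zeros((resize, resize), dtype=image.dtype)
    padded[:h_new, :w_new] = resized
    mask = np.zeros((resize, resize), dtype=bool)
    mask[:h_new, :w_new] = True
    return padded, mask, scale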
- """ - intrinsic = np.loadtxt(path, delimiter=' ') + """Read ScanNet's intrinsic matrix and return the 3x3 matrix.""" + intrinsic = np.loadtxt(path, delimiter=" ") return intrinsic[:-1, :-1] diff --git a/third_party/TopicFM/src/utils/metrics.py b/third_party/TopicFM/src/utils/metrics.py index a93c31ed1d151cd41e2449a19be2d6abc5f9d419..6190b04f0af85680a0c951f74309c0b66c80e1e5 100644 --- a/third_party/TopicFM/src/utils/metrics.py +++ b/third_party/TopicFM/src/utils/metrics.py @@ -9,6 +9,7 @@ from kornia.geometry.conversions import convert_points_to_homogeneous # --- METRICS --- + def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): # angle error between 2 vectors t_gt = T_0to1[:3, 3] @@ -21,7 +22,7 @@ def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): # angle error between 2 rotation matrices R_gt = T_0to1[:3, :3] cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 - cos = np.clip(cos, -1., 1.) # handle numercial errors + cos = np.clip(cos, -1.0, 1.0) # handle numercial errors R_err = np.rad2deg(np.abs(np.arccos(cos))) return t_err, R_err @@ -43,30 +44,36 @@ def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] Etp1 = pts1 @ E # [N, 3] - d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N + d = p1Ep0**2 * ( + 1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2) + + 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2) + ) # N return d def compute_symmetrical_epipolar_errors(data): - """ + """ Update: data (dict):{"epi_errs": [M]} """ - Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) - E_mat = Tx @ data['T_0to1'][:, :3, :3] + Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3]) + E_mat = Tx @ data["T_0to1"][:, :3, :3] - m_bids = data['m_bids'] - pts0 = data['mkpts0_f'] - pts1 = data['mkpts1_f'] + m_bids = data["m_bids"] + pts0 = data["mkpts0_f"] + pts1 = data["mkpts1_f"] epi_errs = [] for bs in range(Tx.size(0)): mask = m_bids == bs epi_errs.append( - symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs])) + symmetric_epipolar_distance( + pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs] + ) + ) epi_errs = torch.cat(epi_errs, dim=0) - data.update({'epi_errs': epi_errs}) + data.update({"epi_errs": epi_errs}) def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): @@ -81,7 +88,8 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): # compute pose with cv2 E, mask = cv2.findEssentialMat( - kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC) + kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC + ) if E is None: print("\nE is None while trying to recover pose.\n") return None @@ -99,7 +107,7 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): def compute_pose_errors(data, config=None, ransac_thr=0.5, ransac_conf=0.99999): - """ + """ Update: data (dict):{ "R_errs" List[float]: [N] @@ -107,35 +115,40 @@ def compute_pose_errors(data, config=None, ransac_thr=0.5, ransac_conf=0.99999): "inliers" List[np.ndarray]: [N] } """ - pixel_thr = config.TRAINER.RANSAC_PIXEL_THR if config is not None else ransac_thr # 0.5 + pixel_thr = ( + config.TRAINER.RANSAC_PIXEL_THR if config is not None else ransac_thr + ) # 0.5 conf = config.TRAINER.RANSAC_CONF if config is not None else ransac_conf # 0.99999 - data.update({'R_errs': [], 't_errs': [], 'inliers': []}) + data.update({"R_errs": [], "t_errs": [], "inliers": []}) - m_bids = data['m_bids'].cpu().numpy() - pts0 = 
data['mkpts0_f'].cpu().numpy() - pts1 = data['mkpts1_f'].cpu().numpy() - K0 = data['K0'].cpu().numpy() - K1 = data['K1'].cpu().numpy() - T_0to1 = data['T_0to1'].cpu().numpy() + m_bids = data["m_bids"].cpu().numpy() + pts0 = data["mkpts0_f"].cpu().numpy() + pts1 = data["mkpts1_f"].cpu().numpy() + K0 = data["K0"].cpu().numpy() + K1 = data["K1"].cpu().numpy() + T_0to1 = data["T_0to1"].cpu().numpy() for bs in range(K0.shape[0]): mask = m_bids == bs - ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf) + ret = estimate_pose( + pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf + ) if ret is None: - data['R_errs'].append(np.inf) - data['t_errs'].append(np.inf) - data['inliers'].append(np.array([]).astype(np.bool)) + data["R_errs"].append(np.inf) + data["t_errs"].append(np.inf) + data["inliers"].append(np.array([]).astype(np.bool)) else: R, t, inliers = ret t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) - data['R_errs'].append(R_err) - data['t_errs'].append(t_err) - data['inliers'].append(inliers) + data["R_errs"].append(R_err) + data["t_errs"].append(t_err) + data["inliers"].append(inliers) # --- METRIC AGGREGATION --- + def error_auc(errors, thresholds): """ Args: @@ -149,11 +162,11 @@ def error_auc(errors, thresholds): thresholds = [5, 10, 20] for thr in thresholds: last_index = np.searchsorted(errors, thr) - y = recall[:last_index] + [recall[last_index-1]] + y = recall[:last_index] + [recall[last_index - 1]] x = errors[:last_index] + [thr] aucs.append(np.trapz(y, x) / thr) - return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} + return {f"auc@{t}": auc for t, auc in zip(thresholds, aucs)} def epidist_prec(errors, thresholds, ret_dict=False): @@ -165,29 +178,33 @@ def epidist_prec(errors, thresholds, ret_dict=False): prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) precs.append(np.mean(prec_) if len(prec_) > 0 else 0) if ret_dict: - return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} + return {f"prec@{t:.0e}": prec for t, prec in zip(thresholds, precs)} else: return precs def aggregate_metrics(metrics, epi_err_thr=5e-4): - """ Aggregate metrics for the whole dataset: + """Aggregate metrics for the whole dataset: (This method should be called once per dataset) 1. AUC of the pose error (angular) at the threshold [5, 10, 20] 2. 
Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) """ # filter duplicates - unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) + unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics["identifiers"])) unq_ids = list(unq_ids.values()) - logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') + logger.info(f"Aggregating metrics over {len(unq_ids)} unique items...") # pose auc angular_thresholds = [5, 10, 20] - pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] + pose_errors = np.max(np.stack([metrics["R_errs"], metrics["t_errs"]]), axis=0)[ + unq_ids + ] aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) # matching precision dist_thresholds = [epi_err_thr] - precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) + precs = epidist_prec( + np.array(metrics["epi_errs"], dtype=object)[unq_ids], dist_thresholds, True + ) # (prec@err_thr) return {**aucs, **precs} diff --git a/third_party/TopicFM/src/utils/misc.py b/third_party/TopicFM/src/utils/misc.py index 9c8db04666519753ea2df43903ab6c47ec00a9a1..461077d77f1628c67055d841a5e70c29c7b82ade 100644 --- a/third_party/TopicFM/src/utils/misc.py +++ b/third_party/TopicFM/src/utils/misc.py @@ -24,7 +24,7 @@ def upper_config(dict_cfg): def log_on(condition, message, level): if condition: - assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] + assert level in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"] logger.log(level, message) @@ -34,32 +34,35 @@ def get_rank_zero_only_logger(logger: _Logger): else: for _level in logger._core.levels.keys(): level = _level.lower() - setattr(logger, level, - lambda x: None) + setattr(logger, level, lambda x: None) logger._log = lambda x: None return logger def setup_gpus(gpus: Union[str, int]) -> int: - """ A temporary fix for pytorch-lighting 1.3.x """ + """A temporary fix for pytorch-lighting 1.3.x""" gpus = str(gpus) gpu_ids = [] - - if ',' not in gpus: + + if "," not in gpus: n_gpus = int(gpus) return n_gpus if n_gpus != -1 else torch.cuda.device_count() else: - gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] - + gpu_ids = [i.strip() for i in gpus.split(",") if i != ""] + # setup environment variables - visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") if visible_devices is None: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) - visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') - logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids) + visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + logger.warning( + f"[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}" + ) else: - logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') + logger.warning( + "[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process." 
+        )
     return len(gpu_ids)

@@ -70,11 +73,11 @@ def flattenList(x):
 @contextlib.contextmanager
 def tqdm_joblib(tqdm_object):
     """Context manager to patch joblib to report into tqdm progress bar given as argument
-
+
     Usage:
         with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
             Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
-
+
     When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing)
         ret_vals = Parallel(n_jobs=args.world_size)(
             delayed(lambda x: _compute_cov_score(pid, *x))(param)
@@ -83,6 +86,7 @@ def tqdm_joblib(tqdm_object):
                 total=len(image_ids)*(len(image_ids)-1)/2))
     Src: https://stackoverflow.com/a/58936697
     """
+
     class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
@@ -98,4 +102,3 @@ def tqdm_joblib(tqdm_object):
     finally:
         joblib.parallel.BatchCompletionCallBack = old_batch_callback
         tqdm_object.close()
-
diff --git a/third_party/TopicFM/src/utils/plotting.py b/third_party/TopicFM/src/utils/plotting.py
index 89b22ef27e6152225d07ab24bb3e62718d180b59..189045409c822f2e1d79610b29ea7e2825ae4bbd 100644
--- a/third_party/TopicFM/src/utils/plotting.py
+++ b/third_party/TopicFM/src/utils/plotting.py
@@ -9,37 +9,49 @@ import torch


 def _compute_conf_thresh(data):
-    dataset_name = data['dataset_name'][0].lower()
-    if dataset_name == 'scannet':
+    dataset_name = data["dataset_name"][0].lower()
+    if dataset_name == "scannet":
         thr = 5e-4
-    elif dataset_name == 'megadepth':
+    elif dataset_name == "megadepth":
         thr = 1e-4
     else:
-        raise ValueError(f'Unknown dataset: {dataset_name}')
+        raise ValueError(f"Unknown dataset: {dataset_name}")
     return thr


 # --- VISUALIZATION --- #
+
 def make_matching_figure(
-        img0, img1, mkpts0, mkpts1, color,
-        kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+    img0,
+    img1,
+    mkpts0,
+    mkpts1,
+    color,
+    kpts0=None,
+    kpts1=None,
+    text=[],
+    dpi=75,
+    path=None,
+):
     # draw image pair
-    assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+    assert (
+        mkpts0.shape[0] == mkpts1.shape[0]
+    ), f"mkpts0: {mkpts0.shape[0]} v.s. 
mkpts1: {mkpts1.shape[0]}" fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) axes[0].imshow(img0) # , cmap='gray') axes[1].imshow(img1) # , cmap='gray') - for i in range(2): # clear all frames + for i in range(2): # clear all frames axes[i].get_yaxis().set_ticks([]) axes[i].get_xaxis().set_ticks([]) for spine in axes[i].spines.values(): spine.set_visible(False) plt.tight_layout(pad=1) - + if kpts0 is not None: assert kpts1 is not None - axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=5) - axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=5) + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5) # draw matches if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: @@ -47,99 +59,112 @@ def make_matching_figure( transFigure = fig.transFigure.inverted() fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) - fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), - (fkpts0[i, 1], fkpts1[i, 1]), - transform=fig.transFigure, c=color[i], linewidth=2) - for i in range(len(mkpts0))] - + fig.lines = [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, + c=color[i], + linewidth=2, + ) + for i in range(len(mkpts0)) + ] + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4) axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4) # put txts - txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + txt_color = "k" if img0[:100, :200].mean() > 200 else "w" fig.text( - 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, - fontsize=15, va='top', ha='left', color=txt_color) + 0.01, + 0.99, + "\n".join(text), + transform=fig.axes[0].transAxes, + fontsize=15, + va="top", + ha="left", + color=txt_color, + ) # save or return figure if path: - plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) plt.close() else: return fig -def _make_evaluation_figure(data, b_id, alpha='dynamic'): - b_mask = data['m_bids'] == b_id +def _make_evaluation_figure(data, b_id, alpha="dynamic"): + b_mask = data["m_bids"] == b_id conf_thr = _compute_conf_thresh(data) - - img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) - img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) - kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() - kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() - + + img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + kpts0 = data["mkpts0_f"][b_mask].cpu().numpy() + kpts1 = data["mkpts1_f"][b_mask].cpu().numpy() + # for megadepth, we visualize matches on the resized image - if 'scale0' in data: - kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]] - kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]] + if "scale0" in data: + kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]] + kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]] - epi_errs = data['epi_errs'][b_mask].cpu().numpy() + epi_errs = data["epi_errs"][b_mask].cpu().numpy() correct_mask = epi_errs < conf_thr precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 n_correct = np.sum(correct_mask) - n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) + n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu()) recall = 0 if n_gt_matches == 0 
else n_correct / (n_gt_matches) # recall might be larger than 1, since the calculation of conf_matrix_gt # uses groundtruth depths and camera poses, but epipolar distance is used here. # matching info - if alpha == 'dynamic': + if alpha == "dynamic": alpha = dynamic_alpha(len(correct_mask)) color = error_colormap(epi_errs, conf_thr, alpha=alpha) - + text = [ - f'#Matches {len(kpts0)}', - f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', - f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + f"#Matches {len(kpts0)}", + f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}", + f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}", ] - + # make the figure - figure = make_matching_figure(img0, img1, kpts0, kpts1, - color, text=text) + figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text) return figure + def _make_confidence_figure(data, b_id): # TODO: Implement confidence figure raise NotImplementedError() -def make_matching_figures(data, config, mode='evaluation'): - """ Make matching figures for a batch. - +def make_matching_figures(data, config, mode="evaluation"): + """Make matching figures for a batch. + Args: data (Dict): a batch updated by PL_LoFTR. config (Dict): matcher config Returns: figures (Dict[str, List[plt.figure]] """ - assert mode in ['evaluation', 'confidence'] # 'confidence' + assert mode in ["evaluation", "confidence"] # 'confidence' figures = {mode: []} - for b_id in range(data['image0'].size(0)): - if mode == 'evaluation': + for b_id in range(data["image0"].size(0)): + if mode == "evaluation": fig = _make_evaluation_figure( - data, b_id, - alpha=config.TRAINER.PLOT_MATCHES_ALPHA) - elif mode == 'confidence': + data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA + ) + elif mode == "confidence": fig = _make_confidence_figure(data, b_id) else: - raise ValueError(f'Unknown plot mode: {mode}') + raise ValueError(f"Unknown plot mode: {mode}") figures[mode].append(fig) return figures -def dynamic_alpha(n_matches, - milestones=[0, 300, 1000, 2000], - alphas=[1.0, 0.8, 0.4, 0.2]): +def dynamic_alpha( + n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2] +): if n_matches == 0: return 1.0 ranges = list(zip(alphas, alphas[1:] + [None])) @@ -148,14 +173,18 @@ def dynamic_alpha(n_matches, if _range[1] is None: return _range[0] return _range[1] + (milestones[loc + 1] - n_matches) / ( - milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) + milestones[loc + 1] - milestones[loc] + ) * (_range[0] - _range[1]) def error_colormap(err, thr, alpha=1.0): assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" x = 1 - np.clip(err / (thr * 2), 0, 1) return np.clip( - np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) + np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1), + 0, + 1, + ) np.random.seed(1995) @@ -163,7 +192,9 @@ color_map = np.arange(100) np.random.shuffle(color_map) -def draw_topics(data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None): +def draw_topics( + data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None +): topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"] hw0_c, hw1_c = data["hw0_c"], data["hw1_c"] @@ -188,27 +219,38 @@ def draw_topics(data, img0, img1, saved_folder="viz_topics", show_n_topics=8, sa theta1 /= theta1.sum().float() # top_topic0 = torch.argsort(theta0, 
descending=True)[:show_n_topics] # top_topic1 = torch.argsort(theta1, descending=True)[:show_n_topics] - top_topics = torch.argsort(theta0*theta1, descending=True)[:show_n_topics] + top_topics = torch.argsort(theta0 * theta1, descending=True)[:show_n_topics] # print(sum_topic0, sum_topic1) - topic0 = topic0[0].argmax(dim=-1, keepdim=True) #.float() / (n_topics - 1) #* 255 + 1 # + topic0 = topic0[0].argmax( + dim=-1, keepdim=True + ) # .float() / (n_topics - 1) #* 255 + 1 # # topic0[~mask0_nonzero] = -1 - topic1 = topic1[0].argmax(dim=-1, keepdim=True) #.float() / (n_topics - 1) #* 255 + 1 + topic1 = topic1[0].argmax( + dim=-1, keepdim=True + ) # .float() / (n_topics - 1) #* 255 + 1 # topic1[~mask1_nonzero] = -1 label_img0, label_img1 = torch.zeros_like(topic0) - 1, torch.zeros_like(topic1) - 1 for i, k in enumerate(top_topics): label_img0[topic0 == k] = color_map[k] label_img1[topic1 == k] = color_map[k] -# print(hw0_c, scale0) -# print(hw1_c, scale1) + # print(hw0_c, scale0) + # print(hw1_c, scale1) # map_topic0 = F.fold(label_img0.unsqueeze(0), hw0_i, kernel_size=scale0, stride=scale0) - map_topic0 = label_img0.float().view(hw0_c).cpu().numpy() #map_topic0.squeeze(0).squeeze(0).cpu().numpy() - map_topic0 = cv2.resize(map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1]))) + map_topic0 = ( + label_img0.float().view(hw0_c).cpu().numpy() + ) # map_topic0.squeeze(0).squeeze(0).cpu().numpy() + map_topic0 = cv2.resize( + map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1])) + ) # map_topic1 = F.fold(label_img1.unsqueeze(0), hw1_i, kernel_size=scale1, stride=scale1) - map_topic1 = label_img1.float().view(hw1_c).cpu().numpy() #map_topic1.squeeze(0).squeeze(0).cpu().numpy() - map_topic1 = cv2.resize(map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1]))) - + map_topic1 = ( + label_img1.float().view(hw1_c).cpu().numpy() + ) # map_topic1.squeeze(0).squeeze(0).cpu().numpy() + map_topic1 = cv2.resize( + map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1])) + ) # show image0 if saved_name is None: @@ -219,28 +261,57 @@ def draw_topics(data, img0, img1, saved_folder="viz_topics", show_n_topics=8, sa path_saved_img0 = os.path.join(saved_folder, "{}_0.png".format(saved_name)) plt.imshow(img0) masked_map_topic0 = np.ma.masked_where(map_topic0 < 0, map_topic0) - plt.imshow(masked_map_topic0, cmap=plt.cm.jet, vmin=0, vmax=n_topics-1, alpha=.3, interpolation='bilinear') + plt.imshow( + masked_map_topic0, + cmap=plt.cm.jet, + vmin=0, + vmax=n_topics - 1, + alpha=0.3, + interpolation="bilinear", + ) # plt.show() - plt.axis('off') - plt.savefig(path_saved_img0, bbox_inches='tight', pad_inches=0, dpi=250) + plt.axis("off") + plt.savefig(path_saved_img0, bbox_inches="tight", pad_inches=0, dpi=250) plt.close() path_saved_img1 = os.path.join(saved_folder, "{}_1.png".format(saved_name)) plt.imshow(img1) masked_map_topic1 = np.ma.masked_where(map_topic1 < 0, map_topic1) - plt.imshow(masked_map_topic1, cmap=plt.cm.jet, vmin=0, vmax=n_topics-1, alpha=.3, interpolation='bilinear') - plt.axis('off') - plt.savefig(path_saved_img1, bbox_inches='tight', pad_inches=0, dpi=250) + plt.imshow( + masked_map_topic1, + cmap=plt.cm.jet, + vmin=0, + vmax=n_topics - 1, + alpha=0.3, + interpolation="bilinear", + ) + plt.axis("off") + plt.savefig(path_saved_img1, bbox_inches="tight", pad_inches=0, dpi=250) plt.close() -def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_topics=8, - topic_alpha=0.3, margin=5, path=None, opencv_display=False, opencv_title=''): 
+def draw_topicfm_demo( + data, + img0, + img1, + mkpts0, + mkpts1, + mcolor, + text, + show_n_topics=8, + topic_alpha=0.3, + margin=5, + path=None, + opencv_display=False, + opencv_title="", +): topic_map0, topic_map1 = draw_topics(data, img0, img1, show_n_topics=show_n_topics) - mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(topic_map1 >= 0, axis=-1) + mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims( + topic_map1 >= 0, axis=-1 + ) - topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.), cm.jet(topic_map1 / 99.) + topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0) topic_cm0 = cv2.cvtColor(topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR) topic_cm1 = cv2.cvtColor(topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR) overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32) @@ -249,7 +320,9 @@ def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_top cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0) cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1) - overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype(np.uint8) + overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype( + np.uint8 + ) h0, w0 = img0.shape[:2] h1, w1 = img1.shape[:2] @@ -258,19 +331,25 @@ def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_top out_fig[:h0, :w0] = overlay0 if h0 >= h1: start = (h0 - h1) // 2 - out_fig[start:(start+h1), (w0+margin):(w0+margin+w1)] = overlay1 + out_fig[start : (start + h1), (w0 + margin) : (w0 + margin + w1)] = overlay1 else: start = (h1 - h0) // 2 - out_fig[:h0, (w0+margin):(w0+margin+w1)] = overlay1[start:(start+h0)] + out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[ + start : (start + h0) + ] step_h = h0 + margin * 2 - out_fig[step_h:step_h+h0, :w0] = (img0 * 255).astype(np.uint8) + out_fig[step_h : step_h + h0, :w0] = (img0 * 255).astype(np.uint8) if h0 >= h1: start = step_h + (h0 - h1) // 2 - out_fig[start:start+h1, (w0+margin):(w0+margin+w1)] = (img1 * 255).astype(np.uint8) + out_fig[start : start + h1, (w0 + margin) : (w0 + margin + w1)] = ( + img1 * 255 + ).astype(np.uint8) else: start = (h1 - h0) // 2 - out_fig[step_h:step_h+h0, (w0+margin):(w0+margin+w1)] = (img1[start:start+h0] * 255).astype(np.uint8) + out_fig[step_h : step_h + h0, (w0 + margin) : (w0 + margin + w1)] = ( + img1[start : start + h0] * 255 + ).astype(np.uint8) # draw matching lines, this is inspried from https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int) @@ -278,24 +357,53 @@ def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_top for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, mcolor): c = c.tolist() - cv2.line(out_fig, (x0, y0+step_h), (x1+margin+w0, y1+step_h+(h0-h1)//2), - color=c, thickness=1, lineType=cv2.LINE_AA) + cv2.line( + out_fig, + (x0, y0 + step_h), + (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2), + color=c, + thickness=1, + lineType=cv2.LINE_AA, + ) # display line end-points as circles - cv2.circle(out_fig, (x0, y0+step_h), 2, c, -1, lineType=cv2.LINE_AA) - cv2.circle(out_fig, (x1+margin+w0, y1+step_h+(h0-h1)//2), 2, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out_fig, (x0, y0 + step_h), 2, c, -1, lineType=cv2.LINE_AA) + cv2.circle( + out_fig, + (x1 + margin + w0, y1 + step_h + (h0 - h1) // 2), 
+ 2, + c, + -1, + lineType=cv2.LINE_AA, + ) # Scale factor for consistent visualization across scales. - sc = min(h / 960., 2.0) + sc = min(h / 960.0, 2.0) # Big text. Ht = int(30 * sc) # text height txt_color_fg = (255, 255, 255) txt_color_bg = (0, 0, 0) for i, t in enumerate(text): - cv2.putText(out_fig, t, (int(8 * sc), Ht + step_h*i), cv2.FONT_HERSHEY_DUPLEX, - 1.0 * sc, txt_color_bg, 2, cv2.LINE_AA) - cv2.putText(out_fig, t, (int(8 * sc), Ht + step_h*i), cv2.FONT_HERSHEY_DUPLEX, - 1.0 * sc, txt_color_fg, 1, cv2.LINE_AA) + cv2.putText( + out_fig, + t, + (int(8 * sc), Ht + step_h * i), + cv2.FONT_HERSHEY_DUPLEX, + 1.0 * sc, + txt_color_bg, + 2, + cv2.LINE_AA, + ) + cv2.putText( + out_fig, + t, + (int(8 * sc), Ht + step_h * i), + cv2.FONT_HERSHEY_DUPLEX, + 1.0 * sc, + txt_color_fg, + 1, + cv2.LINE_AA, + ) if path is not None: cv2.imwrite(str(path), out_fig) @@ -305,9 +413,3 @@ def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_top cv2.waitKey(1) return out_fig - - - - - - diff --git a/third_party/TopicFM/src/utils/profiler.py b/third_party/TopicFM/src/utils/profiler.py index 6d21ed79fb506ef09c75483355402c48a195aaa9..0275ea34e3eb9cceb4ed809bebeda209749f5bc5 100644 --- a/third_party/TopicFM/src/utils/profiler.py +++ b/third_party/TopicFM/src/utils/profiler.py @@ -7,7 +7,7 @@ from pytorch_lightning.utilities import rank_zero_only class InferenceProfiler(SimpleProfiler): """ This profiler records duration of actions with cuda.synchronize() - Use this in test time. + Use this in test time. """ def __init__(self): @@ -28,12 +28,13 @@ class InferenceProfiler(SimpleProfiler): def build_profiler(name): - if name == 'inference': + if name == "inference": return InferenceProfiler() - elif name == 'pytorch': + elif name == "pytorch": from pytorch_lightning.profiler import PyTorchProfiler + return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) elif name is None: return PassThroughProfiler() else: - raise ValueError(f'Invalid profiler: {name}') + raise ValueError(f"Invalid profiler: {name}") diff --git a/third_party/TopicFM/test.py b/third_party/TopicFM/test.py index aeb451cde3674b70b0d2e02f37ff1fd391004d30..7b941ea4f6529c2206d527be85a23523dcf0e148 100644 --- a/third_party/TopicFM/test.py +++ b/third_party/TopicFM/test.py @@ -13,29 +13,43 @@ from src.lightning_trainer.trainer import PL_Trainer def parse_args(): # init a costum parser which will be added into pl.Trainer parser # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("data_cfg_path", type=str, help="data config path") + parser.add_argument("main_cfg_path", type=str, help="main config path") parser.add_argument( - 'data_cfg_path', type=str, help='data config path') + "--ckpt_path", + type=str, + default="weights/indoor_ds.ckpt", + help="path to the checkpoint", + ) parser.add_argument( - 'main_cfg_path', type=str, help='main config path') + "--dump_dir", + type=str, + default=None, + help="if set, the matching results will be dump to dump_dir", + ) parser.add_argument( - '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint') + "--profiler_name", + type=str, + default=None, + help="options: [inference, pytorch], or leave it unset", + ) + parser.add_argument("--batch_size", type=int, default=1, 
help="batch_size per gpu") + parser.add_argument("--num_workers", type=int, default=2) parser.add_argument( - '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir") - parser.add_argument( - '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset') - parser.add_argument( - '--batch_size', type=int, default=1, help='batch_size per gpu') - parser.add_argument( - '--num_workers', type=int, default=2) - parser.add_argument( - '--thr', type=float, default=None, help='modify the coarse-level matching threshold.') + "--thr", + type=float, + default=None, + help="modify the coarse-level matching threshold.", + ) parser = pl.Trainer.add_argparse_args(parser) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": # parse arguments args = parse_args() pprint.pprint(vars(args)) @@ -54,7 +68,12 @@ if __name__ == '__main__': # lightning module profiler = build_profiler(args.profiler_name) - model = PL_Trainer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir) + model = PL_Trainer( + config, + pretrained_ckpt=args.ckpt_path, + profiler=profiler, + dump_dir=args.dump_dir, + ) loguru_logger.info(f"Model-lightning initialized!") # lightning data @@ -62,7 +81,9 @@ if __name__ == '__main__': loguru_logger.info(f"DataModule initialized!") # lightning trainer - trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False) + trainer = pl.Trainer.from_argparse_args( + args, replace_sampler_ddp=False, logger=False + ) loguru_logger.info(f"Start testing!") trainer.test(model, datamodule=data_module, verbose=False) diff --git a/third_party/TopicFM/train.py b/third_party/TopicFM/train.py index a552c23718b81ddcb282cedbfe3ceb45e50b3f29..9188b80a3fb407f4871b8147a2c90fa382380e25 100644 --- a/third_party/TopicFM/train.py +++ b/third_party/TopicFM/train.py @@ -23,32 +23,43 @@ loguru_logger = get_rank_zero_only_logger(loguru_logger) def parse_args(): # init a costum parser which will be added into pl.Trainer parser # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("data_cfg_path", type=str, help="data config path") + parser.add_argument("main_cfg_path", type=str, help="main config path") + parser.add_argument("--exp_name", type=str, default="default_exp_name") + parser.add_argument("--batch_size", type=int, default=4, help="batch_size per gpu") + parser.add_argument("--num_workers", type=int, default=4) parser.add_argument( - 'data_cfg_path', type=str, help='data config path') + "--pin_memory", + type=lambda x: bool(strtobool(x)), + nargs="?", + default=True, + help="whether loading data to pinned memory or not", + ) parser.add_argument( - 'main_cfg_path', type=str, help='main config path') + "--ckpt_path", + type=str, + default=None, + help="pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR", + ) parser.add_argument( - '--exp_name', type=str, default='default_exp_name') + "--disable_ckpt", + action="store_true", + help="disable checkpoint saving (useful for debugging).", + ) parser.add_argument( - '--batch_size', type=int, default=4, help='batch_size per gpu') + "--profiler_name", + type=str, + default=None, + help="options: [inference, pytorch], or leave it 
unset", + ) parser.add_argument( - '--num_workers', type=int, default=4) - parser.add_argument( - '--pin_memory', type=lambda x: bool(strtobool(x)), - nargs='?', default=True, help='whether loading data to pinned memory or not') - parser.add_argument( - '--ckpt_path', type=str, default=None, - help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR') - parser.add_argument( - '--disable_ckpt', action='store_true', - help='disable checkpoint saving (useful for debugging).') - parser.add_argument( - '--profiler_name', type=str, default=None, - help='options: [inference, pytorch], or leave it unset') - parser.add_argument( - '--parallel_load_data', action='store_true', - help='load datasets in with multiple processes.') + "--parallel_load_data", + action="store_true", + help="load datasets in with multiple processes.", + ) parser = pl.Trainer.add_argparse_args(parser) return parser.parse_args() @@ -66,7 +77,7 @@ def main(): pl.seed_everything(config.TRAINER.SEED) # reproducibility # TODO: Use different seeds for each dataloader workers # This is needed for data augmentation - + # scale lr and warmup-step automatically args.gpus = _n_gpus = setup_gpus(args.gpus) config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes @@ -75,49 +86,59 @@ def main(): config.TRAINER.SCALING = _scaling config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling) - + # lightning module profiler = build_profiler(args.profiler_name) model = PL_Trainer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler) loguru_logger.info(f"Model LightningModule initialized!") - + # lightning data data_module = MultiSceneDataModule(args, config) loguru_logger.info(f"Model DataModule initialized!") - + # TensorBoard Logger - logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False) - ckpt_dir = Path(logger.log_dir) / 'checkpoints' - + logger = TensorBoardLogger( + save_dir="logs/tb_logs", name=args.exp_name, default_hp_metric=False + ) + ckpt_dir = Path(logger.log_dir) / "checkpoints" + # Callbacks # TODO: update ModelCheckpoint to monitor multiple metrics - ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max', - save_last=True, - dirpath=str(ckpt_dir), - filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}') - lr_monitor = LearningRateMonitor(logging_interval='step') + ckpt_callback = ModelCheckpoint( + monitor="auc@10", + verbose=True, + save_top_k=5, + mode="max", + save_last=True, + dirpath=str(ckpt_dir), + filename="{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}", + ) + lr_monitor = LearningRateMonitor(logging_interval="step") callbacks = [lr_monitor] if not args.disable_ckpt: callbacks.append(ckpt_callback) - + # Lightning Trainer trainer = pl.Trainer.from_argparse_args( args, - plugins=DDPPlugin(find_unused_parameters=False, - num_nodes=args.num_nodes, - sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), + plugins=DDPPlugin( + find_unused_parameters=False, + num_nodes=args.num_nodes, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, + ), gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING, callbacks=callbacks, logger=logger, sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, replace_sampler_ddp=False, # use custom sampler reload_dataloaders_every_epoch=False, # avoid repeated samples! 
- weights_summary='full', - profiler=profiler) + weights_summary="full", + profiler=profiler, + ) loguru_logger.info(f"Trainer initialized!") loguru_logger.info(f"Start training!") trainer.fit(model, datamodule=data_module) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/third_party/TopicFM/visualization.py b/third_party/TopicFM/visualization.py index 279b41cd88f61ce3414e2f3077fec642b2c8333a..73ec7dd74e21ac72204484cf8d4f3c6fd56a72a2 100644 --- a/third_party/TopicFM/visualization.py +++ b/third_party/TopicFM/visualization.py @@ -15,9 +15,9 @@ from configs.data.base import cfg as data_cfg import viz -def get_model_config(method_name, dataset_name, root_dir='viz'): - config_file = f'{root_dir}/configs/{method_name}.yml' - with open(config_file, 'r') as f: +def get_model_config(method_name, dataset_name, root_dir="viz"): + config_file = f"{root_dir}/configs/{method_name}.yml" + with open(config_file, "r") as f: model_conf = yaml.load(f, Loader=yaml.FullLoader)[dataset_name] return model_conf @@ -30,7 +30,10 @@ class DemoDataset(Dataset): self.list_img_files.sort() else: with open(img_file) as f: - self.list_img_files = [os.path.join(dataset_dir, img_file.strip()) for img_file in f.readlines()] + self.list_img_files = [ + os.path.join(dataset_dir, img_file.strip()) + for img_file in f.readlines() + ] self.resize = resize self.down_factor = down_factor @@ -38,24 +41,31 @@ class DemoDataset(Dataset): return len(self.list_img_files) def __getitem__(self, idx): - img_path = self.list_img_files[idx] #os.path.join(self.dataset_dir, self.list_img_files[idx]) - img, scale = read_img_gray(img_path, resize=self.resize, down_factor=self.down_factor) + img_path = self.list_img_files[ + idx + ] # os.path.join(self.dataset_dir, self.list_img_files[idx]) + img, scale = read_img_gray( + img_path, resize=self.resize, down_factor=self.down_factor + ) return {"img": img, "id": idx, "img_path": img_path} -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Visualize matches') - parser.add_argument('--gpu', '-gpu', type=str, default='0') - parser.add_argument('--method', type=str, default=None) - parser.add_argument('--dataset_dir', type=str, default='data/aachen-day-night') - parser.add_argument('--pair_dir', type=str, default=None) +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Visualize matches") + parser.add_argument("--gpu", "-gpu", type=str, default="0") + parser.add_argument("--method", type=str, default=None) + parser.add_argument("--dataset_dir", type=str, default="data/aachen-day-night") + parser.add_argument("--pair_dir", type=str, default=None) parser.add_argument( - '--dataset_name', type=str, choices=['megadepth', 'scannet', 'aachen_v1.1', 'inloc'], default='megadepth' + "--dataset_name", + type=str, + choices=["megadepth", "scannet", "aachen_v1.1", "inloc"], + default="megadepth", ) - parser.add_argument('--measure_time', action="store_true") - parser.add_argument('--no_viz', action="store_true") - parser.add_argument('--compute_eval_metrics', action="store_true") - parser.add_argument('--run_demo', action="store_true") + parser.add_argument("--measure_time", action="store_true") + parser.add_argument("--no_viz", action="store_true") + parser.add_argument("--compute_eval_metrics", action="store_true") + parser.add_argument("--run_demo", action="store_true") args = parser.parse_args() @@ -64,26 +74,45 @@ if __name__ == '__main__': model = viz.__dict__[class_name](model_cfg) # all_args = Namespace(**vars(args), **model_cfg) 
if not args.run_demo: - if args.dataset_name == 'megadepth': + if args.dataset_name == "megadepth": from configs.data.megadepth_test_1500 import cfg data_cfg.merge_from_other_cfg(cfg) - elif args.dataset_name == 'scannet': + elif args.dataset_name == "scannet": from configs.data.scannet_test_1500 import cfg data_cfg.merge_from_other_cfg(cfg) - elif args.dataset_name == 'aachen_v1.1': - data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "aachen_v1.1", - "DATASET.TEST_DATA_ROOT", os.path.join(args.dataset_dir, "images/images_upright"), - "DATASET.TEST_LIST_PATH", args.pair_dir, - "DATASET.TEST_IMGSIZE", model_cfg["imsize"]]) - elif args.dataset_name == 'inloc': - data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "inloc", - "DATASET.TEST_DATA_ROOT", args.dataset_dir, - "DATASET.TEST_LIST_PATH", args.pair_dir, - "DATASET.TEST_IMGSIZE", model_cfg["imsize"]]) - - has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in ["megadepth", "scannet"] + elif args.dataset_name == "aachen_v1.1": + data_cfg.merge_from_list( + [ + "DATASET.TEST_DATA_SOURCE", + "aachen_v1.1", + "DATASET.TEST_DATA_ROOT", + os.path.join(args.dataset_dir, "images/images_upright"), + "DATASET.TEST_LIST_PATH", + args.pair_dir, + "DATASET.TEST_IMGSIZE", + model_cfg["imsize"], + ] + ) + elif args.dataset_name == "inloc": + data_cfg.merge_from_list( + [ + "DATASET.TEST_DATA_SOURCE", + "inloc", + "DATASET.TEST_DATA_ROOT", + args.dataset_dir, + "DATASET.TEST_LIST_PATH", + args.pair_dir, + "DATASET.TEST_IMGSIZE", + model_cfg["imsize"], + ] + ) + + has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in [ + "megadepth", + "scannet", + ] dataloader = TestDataLoader(data_cfg) with torch.no_grad(): for data_dict in tqdm(dataloader): @@ -91,11 +120,20 @@ if __name__ == '__main__': if isinstance(v, torch.Tensor): data_dict[k] = v.cuda() if torch.cuda.is_available() else v img_root_dir = data_cfg.DATASET.TEST_DATA_ROOT - model.match_and_draw(data_dict, root_dir=img_root_dir, ground_truth=has_ground_truth, - measure_time=args.measure_time, viz_matches=(not args.no_viz)) + model.match_and_draw( + data_dict, + root_dir=img_root_dir, + ground_truth=has_ground_truth, + measure_time=args.measure_time, + viz_matches=(not args.no_viz), + ) if args.measure_time: - print("Running time for each image is {} miliseconds".format(model.measure_time())) + print( + "Running time for each image is {} miliseconds".format( + model.measure_time() + ) + ) if args.compute_eval_metrics and has_ground_truth: model.compute_eval_metrics() else: @@ -103,6 +141,13 @@ if __name__ == '__main__': sampler = SequentialSampler(demo_dataset) dataloader = DataLoader(demo_dataset, batch_size=1, sampler=sampler) - writer = cv2.VideoWriter('topicfm_demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, (640 * 2 + 5, 480 * 2 + 10)) + writer = cv2.VideoWriter( + "topicfm_demo.mp4", + cv2.VideoWriter_fourcc(*"mp4v"), + 15, + (640 * 2 + 5, 480 * 2 + 10), + ) - model.run_demo(iter(dataloader), writer) #, output_dir="demo", no_display=True) + model.run_demo( + iter(dataloader), writer + ) # , output_dir="demo", no_display=True) diff --git a/third_party/TopicFM/viz/methods/base.py b/third_party/TopicFM/viz/methods/base.py index 377e95134f339459bff3c5a0d30b3bfbc122d978..1dfc23efb5fb49bbf510364599489c9acf1df263 100644 --- a/third_party/TopicFM/viz/methods/base.py +++ b/third_party/TopicFM/viz/methods/base.py @@ -14,7 +14,9 @@ def flatten_list(x): class Viz(metaclass=ABCMeta): def __init__(self): super().__init__() - self.device = torch.device('cuda:{}'.format(0) if 
torch.cuda.is_available() else 'cpu') + self.device = torch.device( + "cuda:{}".format(0) if torch.cuda.is_available() else "cpu" + ) torch.set_grad_enabled(False) # for evaluation metrics of MegaDepth and ScanNet @@ -33,11 +35,15 @@ class Viz(metaclass=ABCMeta): f"{self.name}", f"#Matches: {len(mkpts0)}", ] - if 'R_errs' in kwargs: - text.append(f"$\\Delta$R:{kwargs['R_errs']:.2f}°, $\\Delta$t:{kwargs['t_errs']:.2f}°",) + if "R_errs" in kwargs: + text.append( + f"$\\Delta$R:{kwargs['R_errs']:.2f}°, $\\Delta$t:{kwargs['t_errs']:.2f}°", + ) if path: - make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text, path=path, dpi=150) + make_matching_figure( + img0, img1, mkpts0, mkpts1, color, text=text, path=path, dpi=150 + ) else: return make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text) @@ -47,11 +53,11 @@ class Viz(metaclass=ABCMeta): def compute_eval_metrics(self, epi_err_thr=5e-4): # metrics: dict of list, numpy - _metrics = [o['metrics'] for o in self.eval_stats] + _metrics = [o["metrics"] for o in self.eval_stats] metrics = {k: flatten_list([_me[k] for _me in _metrics]) for k in _metrics[0]} val_metrics_4tb = aggregate_metrics(metrics, epi_err_thr) - print('\n' + pprint.pformat(val_metrics_4tb)) + print("\n" + pprint.pformat(val_metrics_4tb)) def measure_time(self): if len(self.time_stats) == 0: diff --git a/third_party/TopicFM/viz/methods/loftr.py b/third_party/TopicFM/viz/methods/loftr.py index 53d0c00c1a067cee10bf1587197e4780ac8b2eda..29046a2aa95596cbfe9656c3bda6dafcb1a55058 100644 --- a/third_party/TopicFM/viz/methods/loftr.py +++ b/third_party/TopicFM/viz/methods/loftr.py @@ -19,20 +19,27 @@ class VizLoFTR(Viz): # Load model conf = dict(default_cfg) - conf['match_coarse']['thr'] = self.match_threshold + conf["match_coarse"]["thr"] = self.match_threshold print(conf) self.model = LoFTR(config=conf) ckpt_dict = torch.load(args.ckpt) - self.model.load_state_dict(ckpt_dict['state_dict']) + self.model.load_state_dict(ckpt_dict["state_dict"]) self.model = self.model.eval().to(self.device) # Name the method # self.ckpt_name = args.ckpt.split('/')[-1].split('.')[0] - self.name = 'LoFTR' + self.name = "LoFTR" - print(f'Initialize {self.name}') + print(f"Initialize {self.name}") - def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True): + def match_and_draw( + self, + data_dict, + root_dir=None, + ground_truth=False, + measure_time=False, + viz_matches=True, + ): if measure_time: torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) @@ -45,41 +52,72 @@ class VizLoFTR(Viz): torch.cuda.synchronize() self.time_stats.append(start.elapsed_time(end)) - kpts0 = data_dict['mkpts0_f'].cpu().numpy() - kpts1 = data_dict['mkpts1_f'].cpu().numpy() + kpts0 = data_dict["mkpts0_f"].cpu().numpy() + kpts1 = data_dict["mkpts1_f"].cpu().numpy() - img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0] + img_name0, img_name1 = list(zip(*data_dict["pair_names"]))[0] img0 = cv2.imread(os.path.join(root_dir, img_name0)) img1 = cv2.imread(os.path.join(root_dir, img_name1)) - if str(data_dict["dataset_name"][0]).lower() == 'scannet': + if str(data_dict["dataset_name"][0]).lower() == "scannet": img0 = cv2.resize(img0, (640, 480)) img1 = cv2.resize(img1, (640, 480)) if viz_matches: - saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]]) + saved_name = "_".join( + [ + img_name0.split("/")[-1].split(".")[0], + img_name1.split("/")[-1].split(".")[0], + ] + ) folder_matches = 
os.path.join(root_dir, "{}_viz_matches".format(self.name)) if not os.path.exists(folder_matches): os.makedirs(folder_matches) - path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name)) + path_to_save_matches = os.path.join( + folder_matches, "{}.png".format(saved_name) + ) if ground_truth: - compute_symmetrical_epipolar_errors(data_dict) # compute epi_errs for each match - compute_pose_errors(data_dict) # compute R_errs, t_errs, pose_errs for each pair - epi_errors = data_dict['epi_errs'].cpu().numpy() - R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0] + compute_symmetrical_epipolar_errors( + data_dict + ) # compute epi_errs for each match + compute_pose_errors( + data_dict + ) # compute R_errs, t_errs, pose_errs for each pair + epi_errors = data_dict["epi_errs"].cpu().numpy() + R_errors, t_errors = data_dict["R_errs"][0], data_dict["t_errs"][0] - self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches, - R_errs=R_errors, t_errs=t_errors) + self.draw_matches( + kpts0, + kpts1, + img0, + img1, + epi_errors, + path=path_to_save_matches, + R_errs=R_errors, + t_errs=t_errors, + ) - rel_pair_names = list(zip(*data_dict['pair_names'])) - bs = data_dict['image0'].size(0) + rel_pair_names = list(zip(*data_dict["pair_names"])) + bs = data_dict["image0"].size(0) metrics = { # to filter duplicate pairs caused by DistributedSampler - 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], - 'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)], - 'R_errs': data_dict['R_errs'], - 't_errs': data_dict['t_errs'], - 'inliers': data_dict['inliers']} - self.eval_stats.append({'metrics': metrics}) + "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)], + "epi_errs": [ + data_dict["epi_errs"][data_dict["m_bids"] == b].cpu().numpy() + for b in range(bs) + ], + "R_errs": data_dict["R_errs"], + "t_errs": data_dict["t_errs"], + "inliers": data_dict["inliers"], + } + self.eval_stats.append({"metrics": metrics}) else: m_conf = 1 - data_dict["mconf"].cpu().numpy() - self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4) + self.draw_matches( + kpts0, + kpts1, + img0, + img1, + m_conf, + path=path_to_save_matches, + conf_thr=0.4, + ) diff --git a/third_party/TopicFM/viz/methods/patch2pix.py b/third_party/TopicFM/viz/methods/patch2pix.py index 14a1d345881e2021be97dc5dde91d8bbe1cd18fa..4d2df36f35c5b06ea8d45980e0b6b91e7482c718 100644 --- a/third_party/TopicFM/viz/methods/patch2pix.py +++ b/third_party/TopicFM/viz/methods/patch2pix.py @@ -7,7 +7,7 @@ from pathlib import Path from .base import Viz from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors -patch2pix_path = Path(__file__).parent / '../../third_party/patch2pix' +patch2pix_path = Path(__file__).parent / "../../third_party/patch2pix" sys.path.append(str(patch2pix_path)) from third_party.patch2pix.utils.eval.model_helper import load_model, estimate_matches @@ -21,25 +21,39 @@ class VizPatch2Pix(Viz): self.imsize = args.imsize self.match_threshold = args.match_threshold self.ksize = args.ksize - self.model = load_model(args.ckpt, method='patch2pix') - self.name = 'Patch2Pix' - print(f'Initialize {self.name} with image size {self.imsize}') + self.model = load_model(args.ckpt, method="patch2pix") + self.name = "Patch2Pix" + print(f"Initialize {self.name} with image size {self.imsize}") - def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, 
measure_time=False, viz_matches=True): - img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0] + def match_and_draw( + self, + data_dict, + root_dir=None, + ground_truth=False, + measure_time=False, + viz_matches=True, + ): + img_name0, img_name1 = list(zip(*data_dict["pair_names"]))[0] path_img0 = os.path.join(root_dir, img_name0) path_img1 = os.path.join(root_dir, img_name1) img0, img1 = cv2.imread(path_img0), cv2.imread(path_img1) return_m_upscale = True - if str(data_dict["dataset_name"][0]).lower() == 'scannet': + if str(data_dict["dataset_name"][0]).lower() == "scannet": # self.imsize = 640 img0 = cv2.resize(img0, tuple(self.imsize)) # (640, 480)) img1 = cv2.resize(img1, tuple(self.imsize)) # (640, 480)) return_m_upscale = False - outputs = estimate_matches(self.model, path_img0, path_img1, - ksize=self.ksize, io_thres=self.match_threshold, - eval_type='fine', imsize=self.imsize, - return_upscale=return_m_upscale, measure_time=measure_time) + outputs = estimate_matches( + self.model, + path_img0, + path_img1, + ksize=self.ksize, + io_thres=self.match_threshold, + eval_type="fine", + imsize=self.imsize, + return_upscale=return_m_upscale, + measure_time=measure_time, + ) if measure_time: self.time_stats.append(outputs[-1]) matches, mconf = outputs[0], outputs[1] @@ -47,34 +61,71 @@ class VizPatch2Pix(Viz): kpts1 = matches[:, 2:4] if viz_matches: - saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]]) + saved_name = "_".join( + [ + img_name0.split("/")[-1].split(".")[0], + img_name1.split("/")[-1].split(".")[0], + ] + ) folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name)) if not os.path.exists(folder_matches): os.makedirs(folder_matches) - path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name)) + path_to_save_matches = os.path.join( + folder_matches, "{}.png".format(saved_name) + ) if ground_truth: - data_dict["mkpts0_f"] = torch.from_numpy(matches[:, :2]).float().to(self.device) - data_dict["mkpts1_f"] = torch.from_numpy(matches[:, 2:4]).float().to(self.device) - data_dict["m_bids"] = torch.zeros(matches.shape[0], device=self.device, dtype=torch.float32) - compute_symmetrical_epipolar_errors(data_dict) # compute epi_errs for each match - compute_pose_errors(data_dict) # compute R_errs, t_errs, pose_errs for each pair - epi_errors = data_dict['epi_errs'].cpu().numpy() - R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0] + data_dict["mkpts0_f"] = ( + torch.from_numpy(matches[:, :2]).float().to(self.device) + ) + data_dict["mkpts1_f"] = ( + torch.from_numpy(matches[:, 2:4]).float().to(self.device) + ) + data_dict["m_bids"] = torch.zeros( + matches.shape[0], device=self.device, dtype=torch.float32 + ) + compute_symmetrical_epipolar_errors( + data_dict + ) # compute epi_errs for each match + compute_pose_errors( + data_dict + ) # compute R_errs, t_errs, pose_errs for each pair + epi_errors = data_dict["epi_errs"].cpu().numpy() + R_errors, t_errors = data_dict["R_errs"][0], data_dict["t_errs"][0] - self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches, - R_errs=R_errors, t_errs=t_errors) + self.draw_matches( + kpts0, + kpts1, + img0, + img1, + epi_errors, + path=path_to_save_matches, + R_errs=R_errors, + t_errs=t_errors, + ) - rel_pair_names = list(zip(*data_dict['pair_names'])) - bs = data_dict['image0'].size(0) + rel_pair_names = list(zip(*data_dict["pair_names"])) + bs = data_dict["image0"].size(0) metrics = { # to filter duplicate pairs 
caused by DistributedSampler - 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], - 'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)], - 'R_errs': data_dict['R_errs'], - 't_errs': data_dict['t_errs'], - 'inliers': data_dict['inliers']} - self.eval_stats.append({'metrics': metrics}) + "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)], + "epi_errs": [ + data_dict["epi_errs"][data_dict["m_bids"] == b].cpu().numpy() + for b in range(bs) + ], + "R_errs": data_dict["R_errs"], + "t_errs": data_dict["t_errs"], + "inliers": data_dict["inliers"], + } + self.eval_stats.append({"metrics": metrics}) else: m_conf = 1 - mconf - self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4) + self.draw_matches( + kpts0, + kpts1, + img0, + img1, + m_conf, + path=path_to_save_matches, + conf_thr=0.4, + ) diff --git a/third_party/TopicFM/viz/methods/topicfm.py b/third_party/TopicFM/viz/methods/topicfm.py index cd8b1485d5296947a38480cc031c5d7439bf163d..e066dc4e031d47b295c4c14db774643ba0a2f25c 100644 --- a/third_party/TopicFM/viz/methods/topicfm.py +++ b/third_party/TopicFM/viz/methods/topicfm.py @@ -26,21 +26,28 @@ class VizTopicFM(Viz): # Load model conf = dict(get_model_cfg()) - conf['match_coarse']['thr'] = self.match_threshold - conf['coarse']['n_samples'] = self.n_sampling_topics + conf["match_coarse"]["thr"] = self.match_threshold + conf["coarse"]["n_samples"] = self.n_sampling_topics print("model config: ", conf) self.model = TopicFM(config=conf) ckpt_dict = torch.load(args.ckpt) - self.model.load_state_dict(ckpt_dict['state_dict']) + self.model.load_state_dict(ckpt_dict["state_dict"]) self.model = self.model.eval().to(self.device) # Name the method # self.ckpt_name = args.ckpt.split('/')[-1].split('.')[0] - self.name = 'TopicFM' - - print(f'Initialize {self.name}') - - def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True): + self.name = "TopicFM" + + print(f"Initialize {self.name}") + + def match_and_draw( + self, + data_dict, + root_dir=None, + ground_truth=False, + measure_time=False, + viz_matches=True, + ): if measure_time: torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) @@ -53,86 +60,133 @@ class VizTopicFM(Viz): torch.cuda.synchronize() self.time_stats.append(start.elapsed_time(end)) - kpts0 = data_dict['mkpts0_f'].cpu().numpy() - kpts1 = data_dict['mkpts1_f'].cpu().numpy() + kpts0 = data_dict["mkpts0_f"].cpu().numpy() + kpts1 = data_dict["mkpts1_f"].cpu().numpy() - img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0] + img_name0, img_name1 = list(zip(*data_dict["pair_names"]))[0] img0 = cv2.imread(os.path.join(root_dir, img_name0)) img1 = cv2.imread(os.path.join(root_dir, img_name1)) - if str(data_dict["dataset_name"][0]).lower() == 'scannet': + if str(data_dict["dataset_name"][0]).lower() == "scannet": img0 = cv2.resize(img0, (640, 480)) img1 = cv2.resize(img1, (640, 480)) if viz_matches: - saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]]) + saved_name = "_".join( + [ + img_name0.split("/")[-1].split(".")[0], + img_name1.split("/")[-1].split(".")[0], + ] + ) folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name)) if not os.path.exists(folder_matches): os.makedirs(folder_matches) - path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name)) + path_to_save_matches = os.path.join( + folder_matches, 
"{}.png".format(saved_name) + ) if ground_truth: - compute_symmetrical_epipolar_errors(data_dict) # compute epi_errs for each match - compute_pose_errors(data_dict) # compute R_errs, t_errs, pose_errs for each pair - epi_errors = data_dict['epi_errs'].cpu().numpy() - R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0] - - self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches, - R_errs=R_errors, t_errs=t_errors) + compute_symmetrical_epipolar_errors( + data_dict + ) # compute epi_errs for each match + compute_pose_errors( + data_dict + ) # compute R_errs, t_errs, pose_errs for each pair + epi_errors = data_dict["epi_errs"].cpu().numpy() + R_errors, t_errors = data_dict["R_errs"][0], data_dict["t_errs"][0] + + self.draw_matches( + kpts0, + kpts1, + img0, + img1, + epi_errors, + path=path_to_save_matches, + R_errs=R_errors, + t_errs=t_errors, + ) # compute evaluation metrics - rel_pair_names = list(zip(*data_dict['pair_names'])) - bs = data_dict['image0'].size(0) + rel_pair_names = list(zip(*data_dict["pair_names"])) + bs = data_dict["image0"].size(0) metrics = { # to filter duplicate pairs caused by DistributedSampler - 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], - 'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)], - 'R_errs': data_dict['R_errs'], - 't_errs': data_dict['t_errs'], - 'inliers': data_dict['inliers']} - self.eval_stats.append({'metrics': metrics}) + "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)], + "epi_errs": [ + data_dict["epi_errs"][data_dict["m_bids"] == b].cpu().numpy() + for b in range(bs) + ], + "R_errs": data_dict["R_errs"], + "t_errs": data_dict["t_errs"], + "inliers": data_dict["inliers"], + } + self.eval_stats.append({"metrics": metrics}) else: m_conf = 1 - data_dict["mconf"].cpu().numpy() - self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4) + self.draw_matches( + kpts0, + kpts1, + img0, + img1, + m_conf, + path=path_to_save_matches, + conf_thr=0.4, + ) if self.show_n_topics > 0: - folder_topics = os.path.join(root_dir, "{}_viz_topics".format(self.name)) + folder_topics = os.path.join( + root_dir, "{}_viz_topics".format(self.name) + ) if not os.path.exists(folder_topics): os.makedirs(folder_topics) - draw_topics(data_dict, img0, img1, saved_folder=folder_topics, show_n_topics=self.show_n_topics, - saved_name=saved_name) - - def run_demo(self, dataloader, writer=None, output_dir=None, no_display=False, skip_frames=1): + draw_topics( + data_dict, + img0, + img1, + saved_folder=folder_topics, + show_n_topics=self.show_n_topics, + saved_name=saved_name, + ) + + def run_demo( + self, dataloader, writer=None, output_dir=None, no_display=False, skip_frames=1 + ): data_dict = next(dataloader) frame_id = 0 last_image_id = 0 - img0 = np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255 + img0 = ( + np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255 + ) frame_tensor = data_dict["img"].to(self.device) - pair_data = {'image0': frame_tensor} - last_frame = cv2.resize(img0, (frame_tensor.shape[-1], frame_tensor.shape[-2]), cv2.INTER_LINEAR) + pair_data = {"image0": frame_tensor} + last_frame = cv2.resize( + img0, (frame_tensor.shape[-1], frame_tensor.shape[-2]), cv2.INTER_LINEAR + ) if output_dir is not None: - print('==> Will write outputs to {}'.format(output_dir)) + print("==> Will write outputs to {}".format(output_dir)) Path(output_dir).mkdir(exist_ok=True) # 
Create a window to display the demo. if not no_display: - window_name = 'Topic-assisted Feature Matching' + window_name = "Topic-assisted Feature Matching" cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) cv2.resizeWindow(window_name, (640 * 2, 480 * 2)) else: - print('Skipping visualization, will not show a GUI.') + print("Skipping visualization, will not show a GUI.") # Print the keyboard help menu. - print('==> Keyboard control:\n' - '\tn: select the current frame as the reference image (left)\n' - '\tq: quit') + print( + "==> Keyboard control:\n" + "\tn: select the current frame as the reference image (left)\n" + "\tq: quit" + ) # vis_range = [kwargs["bottom_k"], kwargs["top_k"]] while True: frame_id += 1 if frame_id == len(dataloader): - print('Finished demo_loftr.py') + print("Finished demo_loftr.py") break data_dict = next(dataloader) if frame_id % skip_frames != 0: @@ -140,17 +194,24 @@ class VizTopicFM(Viz): continue stem0, stem1 = last_image_id, data_dict["id"][0].item() - 1 - frame = np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255 + frame = ( + np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) + / 255 + ) frame_tensor = data_dict["img"].to(self.device) - frame = cv2.resize(frame, (frame_tensor.shape[-1], frame_tensor.shape[-2]), interpolation=cv2.INTER_LINEAR) - pair_data = {**pair_data, 'image1': frame_tensor} + frame = cv2.resize( + frame, + (frame_tensor.shape[-1], frame_tensor.shape[-2]), + interpolation=cv2.INTER_LINEAR, + ) + pair_data = {**pair_data, "image1": frame_tensor} self.model(pair_data) - total_n_matches = len(pair_data['mkpts0_f']) - mkpts0 = pair_data['mkpts0_f'].cpu().numpy() # [vis_range[0]:vis_range[1]] - mkpts1 = pair_data['mkpts1_f'].cpu().numpy() # [vis_range[0]:vis_range[1]] - mconf = pair_data['mconf'].cpu().numpy() # [vis_range[0]:vis_range[1]] + total_n_matches = len(pair_data["mkpts0_f"]) + mkpts0 = pair_data["mkpts0_f"].cpu().numpy() # [vis_range[0]:vis_range[1]] + mkpts1 = pair_data["mkpts1_f"].cpu().numpy() # [vis_range[0]:vis_range[1]] + mconf = pair_data["mconf"].cpu().numpy() # [vis_range[0]:vis_range[1]] # Normalize confidence. 
if len(mconf) > 0: @@ -161,33 +222,42 @@ class VizTopicFM(Viz): color = error_colormap(mconf, thr=0.4, alpha=0.1) text = [ - f'Topics', - '#Matches: {}'.format(total_n_matches), + f"Topics", + "#Matches: {}".format(total_n_matches), ] - out = draw_topicfm_demo(pair_data, last_frame, frame, mkpts0, mkpts1, color, text, - show_n_topics=4, path=None) + out = draw_topicfm_demo( + pair_data, + last_frame, + frame, + mkpts0, + mkpts1, + color, + text, + show_n_topics=4, + path=None, + ) if not no_display: if writer is not None: writer.write(out) - cv2.imshow('TopicFM Matches', out) + cv2.imshow("TopicFM Matches", out) key = chr(cv2.waitKey(10) & 0xFF) - if key == 'q': + if key == "q": if writer is not None: writer.release() - print('Exiting...') + print("Exiting...") break - elif key == 'n': - pair_data['image0'] = frame_tensor + elif key == "n": + pair_data["image0"] = frame_tensor last_frame = frame - last_image_id = (data_dict["id"][0].item() - 1) + last_image_id = data_dict["id"][0].item() - 1 frame_id_left = frame_id elif output_dir is not None: - stem = 'matches_{:06}_{:06}'.format(stem0, stem1) - out_file = str(Path(output_dir, stem + '.png')) - print('\nWriting image to {}'.format(out_file)) + stem = "matches_{:06}_{:06}".format(stem0, stem1) + out_file = str(Path(output_dir, stem + ".png")) + print("\nWriting image to {}".format(out_file)) cv2.imwrite(out_file, out) else: raise ValueError("output_dir is required when no display is given.") @@ -195,4 +265,3 @@ class VizTopicFM(Viz): cv2.destroyAllWindows() if writer is not None: writer.release() - diff --git a/third_party/d2net/extract_features.py b/third_party/d2net/extract_features.py index 628463a7d042a90b5cadea8a317237cde86f5ae4..ebcac0889d084c59d86bb21ed80d1e1ed8f17d8d 100644 --- a/third_party/d2net/extract_features.py +++ b/third_party/d2net/extract_features.py @@ -21,49 +21,55 @@ use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if use_cuda else "cpu") # Argument parsing -parser = argparse.ArgumentParser(description='Feature extraction script') +parser = argparse.ArgumentParser(description="Feature extraction script") parser.add_argument( - '--image_list_file', type=str, required=True, - help='path to a file containing a list of images to process' + "--image_list_file", + type=str, + required=True, + help="path to a file containing a list of images to process", ) parser.add_argument( - '--preprocessing', type=str, default='caffe', - help='image preprocessing (caffe or torch)' + "--preprocessing", + type=str, + default="caffe", + help="image preprocessing (caffe or torch)", ) parser.add_argument( - '--model_file', type=str, default='models/d2_tf.pth', - help='path to the full model' + "--model_file", type=str, default="models/d2_tf.pth", help="path to the full model" ) parser.add_argument( - '--max_edge', type=int, default=1600, - help='maximum image size at network input' + "--max_edge", type=int, default=1600, help="maximum image size at network input" ) parser.add_argument( - '--max_sum_edges', type=int, default=2800, - help='maximum sum of image sizes at network input' + "--max_sum_edges", + type=int, + default=2800, + help="maximum sum of image sizes at network input", ) parser.add_argument( - '--output_extension', type=str, default='.d2-net', - help='extension for the output' + "--output_extension", type=str, default=".d2-net", help="extension for the output" ) parser.add_argument( - '--output_type', type=str, default='npz', - help='output file type (npz or mat)' + "--output_type", type=str, default="npz", 
help="output file type (npz or mat)" ) parser.add_argument( - '--multiscale', dest='multiscale', action='store_true', - help='extract multiscale features' + "--multiscale", + dest="multiscale", + action="store_true", + help="extract multiscale features", ) parser.set_defaults(multiscale=False) parser.add_argument( - '--no-relu', dest='use_relu', action='store_false', - help='remove ReLU after the dense feature extraction module' + "--no-relu", + dest="use_relu", + action="store_false", + help="remove ReLU after the dense feature extraction module", ) parser.set_defaults(use_relu=True) @@ -72,14 +78,10 @@ args = parser.parse_args() print(args) # Creating CNN model -model = D2Net( - model_file=args.model_file, - use_relu=args.use_relu, - use_cuda=use_cuda -) +model = D2Net(model_file=args.model_file, use_relu=args.use_relu, use_cuda=use_cuda) # Process the file -with open(args.image_list_file, 'r') as f: +with open(args.image_list_file, "r") as f: lines = f.readlines() for line in tqdm(lines, total=len(lines)): path = line.strip() @@ -93,39 +95,32 @@ for line in tqdm(lines, total=len(lines)): resized_image = image if max(resized_image.shape) > args.max_edge: resized_image = scipy.misc.imresize( - resized_image, - args.max_edge / max(resized_image.shape) - ).astype('float') - if sum(resized_image.shape[: 2]) > args.max_sum_edges: + resized_image, args.max_edge / max(resized_image.shape) + ).astype("float") + if sum(resized_image.shape[:2]) > args.max_sum_edges: resized_image = scipy.misc.imresize( - resized_image, - args.max_sum_edges / sum(resized_image.shape[: 2]) - ).astype('float') + resized_image, args.max_sum_edges / sum(resized_image.shape[:2]) + ).astype("float") fact_i = image.shape[0] / resized_image.shape[0] fact_j = image.shape[1] / resized_image.shape[1] - input_image = preprocess_image( - resized_image, - preprocessing=args.preprocessing - ) + input_image = preprocess_image(resized_image, preprocessing=args.preprocessing) with torch.no_grad(): if args.multiscale: keypoints, scores, descriptors = process_multiscale( torch.tensor( - input_image[np.newaxis, :, :, :].astype(np.float32), - device=device + input_image[np.newaxis, :, :, :].astype(np.float32), device=device ), - model + model, ) else: keypoints, scores, descriptors = process_multiscale( torch.tensor( - input_image[np.newaxis, :, :, :].astype(np.float32), - device=device + input_image[np.newaxis, :, :, :].astype(np.float32), device=device ), model, - scales=[1] + scales=[1], ) # Input image coordinates @@ -134,23 +129,16 @@ for line in tqdm(lines, total=len(lines)): # i, j -> u, v keypoints = keypoints[:, [1, 0, 2]] - if args.output_type == 'npz': - with open(path + args.output_extension, 'wb') as output_file: + if args.output_type == "npz": + with open(path + args.output_extension, "wb") as output_file: np.savez( - output_file, - keypoints=keypoints, - scores=scores, - descriptors=descriptors + output_file, keypoints=keypoints, scores=scores, descriptors=descriptors ) - elif args.output_type == 'mat': - with open(path + args.output_extension, 'wb') as output_file: + elif args.output_type == "mat": + with open(path + args.output_extension, "wb") as output_file: scipy.io.savemat( output_file, - { - 'keypoints': keypoints, - 'scores': scores, - 'descriptors': descriptors - } + {"keypoints": keypoints, "scores": scores, "descriptors": descriptors}, ) else: - raise ValueError('Unknown output type.') + raise ValueError("Unknown output type.") diff --git a/third_party/d2net/extract_kapture.py 
b/third_party/d2net/extract_kapture.py index 23198b978229c699dbe24cd3bc0400d62bcab030..bad6ad4254238b9c9425243ff80f830bc4f02198 100644 --- a/third_party/d2net/extract_kapture.py +++ b/third_party/d2net/extract_kapture.py @@ -13,9 +13,21 @@ from os import path import kapture from kapture.io.records import get_image_fullpath from kapture.io.csv import kapture_from_dir, get_all_tar_handlers -from kapture.io.csv import get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file -from kapture.io.features import get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file -from kapture.io.features import get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file +from kapture.io.csv import ( + get_feature_csv_fullpath, + keypoints_to_file, + descriptors_to_file, +) +from kapture.io.features import ( + get_keypoints_fullpath, + keypoints_check_dir, + image_keypoints_to_file, +) +from kapture.io.features import ( + get_descriptors_fullpath, + descriptors_check_dir, + image_descriptors_to_file, +) from lib.model_test import D2Net from lib.utils import preprocess_image @@ -28,68 +40,89 @@ use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if use_cuda else "cpu") # Argument parsing -parser = argparse.ArgumentParser(description='Feature extraction script') +parser = argparse.ArgumentParser(description="Feature extraction script") parser.add_argument( - '--kapture-root', type=str, required=True, - help='path to kapture root directory' + "--kapture-root", type=str, required=True, help="path to kapture root directory" ) parser.add_argument( - '--preprocessing', type=str, default='caffe', - help='image preprocessing (caffe or torch)' + "--preprocessing", + type=str, + default="caffe", + help="image preprocessing (caffe or torch)", ) parser.add_argument( - '--model_file', type=str, default='models/d2_tf.pth', - help='path to the full model' + "--model_file", type=str, default="models/d2_tf.pth", help="path to the full model" ) parser.add_argument( - '--keypoints-type', type=str, default=None, - help='keypoint type_name, default is filename of model' + "--keypoints-type", + type=str, + default=None, + help="keypoint type_name, default is filename of model", ) parser.add_argument( - '--descriptors-type', type=str, default=None, - help='descriptors type_name, default is filename of model' + "--descriptors-type", + type=str, + default=None, + help="descriptors type_name, default is filename of model", ) parser.add_argument( - '--max_edge', type=int, default=1600, - help='maximum image size at network input' + "--max_edge", type=int, default=1600, help="maximum image size at network input" ) parser.add_argument( - '--max_sum_edges', type=int, default=2800, - help='maximum sum of image sizes at network input' + "--max_sum_edges", + type=int, + default=2800, + help="maximum sum of image sizes at network input", ) parser.add_argument( - '--multiscale', dest='multiscale', action='store_true', - help='extract multiscale features' + "--multiscale", + dest="multiscale", + action="store_true", + help="extract multiscale features", ) parser.set_defaults(multiscale=False) parser.add_argument( - '--no-relu', dest='use_relu', action='store_false', - help='remove ReLU after the dense feature extraction module' + "--no-relu", + dest="use_relu", + action="store_false", + help="remove ReLU after the dense feature extraction module", ) parser.set_defaults(use_relu=True) -parser.add_argument("--max-keypoints", type=int, default=float("+inf"), - help='max number of keypoints save to 
disk') +parser.add_argument( + "--max-keypoints", + type=int, + default=float("+inf"), + help="max number of keypoints save to disk", +) args = parser.parse_args() print(args) -with get_all_tar_handlers(args.kapture_root, - mode={kapture.Keypoints: 'a', - kapture.Descriptors: 'a', - kapture.GlobalFeatures: 'r', - kapture.Matches: 'r'}) as tar_handlers: - kdata = kapture_from_dir(args.kapture_root, - skip_list=[kapture.GlobalFeatures, - kapture.Matches, - kapture.Points3d, - kapture.Observations], - tar_handlers=tar_handlers) +with get_all_tar_handlers( + args.kapture_root, + mode={ + kapture.Keypoints: "a", + kapture.Descriptors: "a", + kapture.GlobalFeatures: "r", + kapture.Matches: "r", + }, +) as tar_handlers: + kdata = kapture_from_dir( + args.kapture_root, + skip_list=[ + kapture.GlobalFeatures, + kapture.Matches, + kapture.Points3d, + kapture.Observations, + ], + tar_handlers=tar_handlers, + ) if kdata.keypoints is None: kdata.keypoints = {} if kdata.descriptors is None: @@ -99,28 +132,29 @@ with get_all_tar_handlers(args.kapture_root, image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)] if args.keypoints_type is None: args.keypoints_type = path.splitext(path.basename(args.model_file))[0] - print(f'keypoints_type set to {args.keypoints_type}') + print(f"keypoints_type set to {args.keypoints_type}") if args.descriptors_type is None: args.descriptors_type = path.splitext(path.basename(args.model_file))[0] - print(f'descriptors_type set to {args.descriptors_type}') - if args.keypoints_type in kdata.keypoints and args.descriptors_type in kdata.descriptors: - image_list = [name - for name in image_list - if name not in kdata.keypoints[args.keypoints_type] or - name not in kdata.descriptors[args.descriptors_type]] + print(f"descriptors_type set to {args.descriptors_type}") + if ( + args.keypoints_type in kdata.keypoints + and args.descriptors_type in kdata.descriptors + ): + image_list = [ + name + for name in image_list + if name not in kdata.keypoints[args.keypoints_type] + or name not in kdata.descriptors[args.descriptors_type] + ] if len(image_list) == 0: - print('All features were already extracted') + print("All features were already extracted") exit(0) else: - print(f'Extracting d2net features for {len(image_list)} images') + print(f"Extracting d2net features for {len(image_list)} images") # Creating CNN model - model = D2Net( - model_file=args.model_file, - use_relu=args.use_relu, - use_cuda=use_cuda - ) + model = D2Net(model_file=args.model_file, use_relu=args.use_relu, use_cuda=use_cuda) if args.keypoints_type not in kdata.keypoints: keypoints_dtype = None @@ -138,7 +172,7 @@ with get_all_tar_handlers(args.kapture_root, # Process the files for image_name in tqdm(image_list, total=len(image_list)): img_path = get_image_fullpath(args.kapture_root, image_name) - image = Image.open(img_path).convert('RGB') + image = Image.open(img_path).convert("RGB") width, height = image.size @@ -162,30 +196,27 @@ with get_all_tar_handlers(args.kapture_root, fact_i = width / resized_width fact_j = height / resized_height - resized_image = np.array(resized_image).astype('float') + resized_image = np.array(resized_image).astype("float") - input_image = preprocess_image( - resized_image, - preprocessing=args.preprocessing - ) + input_image = preprocess_image(resized_image, preprocessing=args.preprocessing) with torch.no_grad(): if args.multiscale: keypoints, scores, descriptors = process_multiscale( torch.tensor( input_image[np.newaxis, :, :, :].astype(np.float32), - 
device=device + device=device, ), - model + model, ) else: keypoints, scores, descriptors = process_multiscale( torch.tensor( input_image[np.newaxis, :, :, :].astype(np.float32), - device=device + device=device, ), model, - scales=[1] + scales=[1], ) # Input image coordinates @@ -196,7 +227,7 @@ with get_all_tar_handlers(args.kapture_root, if args.max_keypoints != float("+inf"): # keep the last (the highest) indexes - idx_keep = scores.argsort()[-min(len(keypoints), args.max_keypoints):] + idx_keep = scores.argsort()[-min(len(keypoints), args.max_keypoints) :] keypoints = keypoints[idx_keep] descriptors = descriptors[idx_keep] @@ -207,42 +238,65 @@ with get_all_tar_handlers(args.kapture_root, keypoints_dsize = keypoints.shape[1] descriptors_dsize = descriptors.shape[1] - kdata.keypoints[args.keypoints_type] = kapture.Keypoints('d2net', keypoints_dtype, keypoints_dsize) - kdata.descriptors[args.descriptors_type] = kapture.Descriptors('d2net', descriptors_dtype, - descriptors_dsize, - args.keypoints_type, 'L2') - - keypoints_config_absolute_path = get_feature_csv_fullpath(kapture.Keypoints, - args.keypoints_type, - args.kapture_root) - descriptors_config_absolute_path = get_feature_csv_fullpath(kapture.Descriptors, - args.descriptors_type, - args.kapture_root) - - keypoints_to_file(keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type]) - descriptors_to_file(descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type]) + kdata.keypoints[args.keypoints_type] = kapture.Keypoints( + "d2net", keypoints_dtype, keypoints_dsize + ) + kdata.descriptors[args.descriptors_type] = kapture.Descriptors( + "d2net", descriptors_dtype, descriptors_dsize, args.keypoints_type, "L2" + ) + + keypoints_config_absolute_path = get_feature_csv_fullpath( + kapture.Keypoints, args.keypoints_type, args.kapture_root + ) + descriptors_config_absolute_path = get_feature_csv_fullpath( + kapture.Descriptors, args.descriptors_type, args.kapture_root + ) + + keypoints_to_file( + keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type] + ) + descriptors_to_file( + descriptors_config_absolute_path, + kdata.descriptors[args.descriptors_type], + ) else: assert kdata.keypoints[args.keypoints_type].dtype == keypoints.dtype assert kdata.descriptors[args.descriptors_type].dtype == descriptors.dtype assert kdata.keypoints[args.keypoints_type].dsize == keypoints.shape[1] - assert kdata.descriptors[args.descriptors_type].dsize == descriptors.shape[1] - assert kdata.descriptors[args.descriptors_type].keypoints_type == args.keypoints_type - assert kdata.descriptors[args.descriptors_type].metric_type == 'L2' - - keypoints_fullpath = get_keypoints_fullpath(args.keypoints_type, args.kapture_root, - image_name, tar_handlers) + assert ( + kdata.descriptors[args.descriptors_type].dsize == descriptors.shape[1] + ) + assert ( + kdata.descriptors[args.descriptors_type].keypoints_type + == args.keypoints_type + ) + assert kdata.descriptors[args.descriptors_type].metric_type == "L2" + + keypoints_fullpath = get_keypoints_fullpath( + args.keypoints_type, args.kapture_root, image_name, tar_handlers + ) print(f"Saving {keypoints.shape[0]} keypoints to {keypoints_fullpath}") image_keypoints_to_file(keypoints_fullpath, keypoints) kdata.keypoints[args.keypoints_type].add(image_name) - descriptors_fullpath = get_descriptors_fullpath(args.descriptors_type, args.kapture_root, - image_name, tar_handlers) + descriptors_fullpath = get_descriptors_fullpath( + args.descriptors_type, args.kapture_root, image_name, 
tar_handlers + ) print(f"Saving {descriptors.shape[0]} descriptors to {descriptors_fullpath}") image_descriptors_to_file(descriptors_fullpath, descriptors) kdata.descriptors[args.descriptors_type].add(image_name) - if not keypoints_check_dir(kdata.keypoints[args.keypoints_type], args.keypoints_type, - args.kapture_root, tar_handlers) or \ - not descriptors_check_dir(kdata.descriptors[args.descriptors_type], args.descriptors_type, - args.kapture_root, tar_handlers): - print('local feature extraction ended successfully but not all files were saved') + if not keypoints_check_dir( + kdata.keypoints[args.keypoints_type], + args.keypoints_type, + args.kapture_root, + tar_handlers, + ) or not descriptors_check_dir( + kdata.descriptors[args.descriptors_type], + args.descriptors_type, + args.kapture_root, + tar_handlers, + ): + print( + "local feature extraction ended successfully but not all files were saved" + ) diff --git a/third_party/d2net/megadepth_utils/preprocess_scene.py b/third_party/d2net/megadepth_utils/preprocess_scene.py index fc68a403795e7cddce88dfcb74b38d19ab09e133..5364058829b7e45eabd61a32a591711645fc1ded 100644 --- a/third_party/d2net/megadepth_utils/preprocess_scene.py +++ b/third_party/d2net/megadepth_utils/preprocess_scene.py @@ -6,78 +6,63 @@ import numpy as np import os -parser = argparse.ArgumentParser(description='MegaDepth preprocessing script') +parser = argparse.ArgumentParser(description="MegaDepth preprocessing script") -parser.add_argument( - '--base_path', type=str, required=True, - help='path to MegaDepth' -) -parser.add_argument( - '--scene_id', type=str, required=True, - help='scene ID' -) +parser.add_argument("--base_path", type=str, required=True, help="path to MegaDepth") +parser.add_argument("--scene_id", type=str, required=True, help="scene ID") parser.add_argument( - '--output_path', type=str, required=True, - help='path to the output directory' + "--output_path", type=str, required=True, help="path to the output directory" ) args = parser.parse_args() base_path = args.base_path # Remove the trailing / if need be. 
-if base_path[-1] in ['/', '\\']: - base_path = base_path[: - 1] +if base_path[-1] in ["/", "\\"]: + base_path = base_path[:-1] scene_id = args.scene_id -base_depth_path = os.path.join( - base_path, 'phoenix/S6/zl548/MegaDepth_v1' -) -base_undistorted_sfm_path = os.path.join( - base_path, 'Undistorted_SfM' -) +base_depth_path = os.path.join(base_path, "phoenix/S6/zl548/MegaDepth_v1") +base_undistorted_sfm_path = os.path.join(base_path, "Undistorted_SfM") undistorted_sparse_path = os.path.join( - base_undistorted_sfm_path, scene_id, 'sparse-txt' + base_undistorted_sfm_path, scene_id, "sparse-txt" ) if not os.path.exists(undistorted_sparse_path): exit() -depths_path = os.path.join( - base_depth_path, scene_id, 'dense0', 'depths' -) +depths_path = os.path.join(base_depth_path, scene_id, "dense0", "depths") if not os.path.exists(depths_path): exit() -images_path = os.path.join( - base_undistorted_sfm_path, scene_id, 'images' -) +images_path = os.path.join(base_undistorted_sfm_path, scene_id, "images") if not os.path.exists(images_path): exit() # Process cameras.txt -with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f: - raw = f.readlines()[3 :] # skip the header +with open(os.path.join(undistorted_sparse_path, "cameras.txt"), "r") as f: + raw = f.readlines()[3:] # skip the header camera_intrinsics = {} for camera in raw: - camera = camera.split(' ') - camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]] + camera = camera.split(" ") + camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2:]] # Process points3D.txt -with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f: - raw = f.readlines()[3 :] # skip the header +with open(os.path.join(undistorted_sparse_path, "points3D.txt"), "r") as f: + raw = f.readlines()[3:] # skip the header points3D = {} for point3D in raw: - point3D = point3D.split(' ') - points3D[int(point3D[0])] = np.array([ - float(point3D[1]), float(point3D[2]), float(point3D[3]) - ]) - + point3D = point3D.split(" ") + points3D[int(point3D[0])] = np.array( + [float(point3D[1]), float(point3D[2]), float(point3D[3])] + ) + # Process images.txt -with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f: - raw = f.readlines()[4 :] # skip the header +with open(os.path.join(undistorted_sparse_path, "images.txt"), "r") as f: + raw = f.readlines()[4:] # skip the header image_id_to_idx = {} image_names = [] @@ -85,19 +70,19 @@ raw_pose = [] camera = [] points3D_id_to_2D = [] n_points3D = [] -for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])): - image = image.split(' ') - points = points.split(' ') +for idx, (image, points) in enumerate(zip(raw[::2], raw[1::2])): + image = image.split(" ") + points = points.split(" ") image_id_to_idx[int(image[0])] = idx - image_name = image[-1].strip('\n') + image_name = image[-1].strip("\n") image_names.append(image_name) - raw_pose.append([float(elem) for elem in image[1 : -2]]) + raw_pose.append([float(elem) for elem in image[1:-2]]) camera.append(int(image[-2])) current_points3D_id_to_2D = {} - for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]): + for x, y, point3D_id in zip(points[::3], points[1::3], points[2::3]): if int(point3D_id) == -1: continue current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)] @@ -110,12 +95,10 @@ image_paths = [] depth_paths = [] for image_name in image_names: image_path = os.path.join(images_path, image_name) - + # Path to the depth file - depth_path = os.path.join( - depths_path, 
'%s.h5' % os.path.splitext(image_name)[0] - ) - + depth_path = os.path.join(depths_path, "%s.h5" % os.path.splitext(image_name)[0]) + if os.path.exists(depth_path): # Check if depth map or background / foreground mask file_size = os.stat(depth_path).st_size @@ -152,32 +135,22 @@ for idx, image_name in enumerate(image_names): intrinsics.append(K) image_pose = raw_pose[idx] - qvec = image_pose[: 4] + qvec = image_pose[:4] qvec = qvec / np.linalg.norm(qvec) w, x, y, z = qvec - R = np.array([ - [ - 1 - 2 * y * y - 2 * z * z, - 2 * x * y - 2 * z * w, - 2 * x * z + 2 * y * w - ], + R = np.array( [ - 2 * x * y + 2 * z * w, - 1 - 2 * x * x - 2 * z * z, - 2 * y * z - 2 * x * w - ], - [ - 2 * x * z - 2 * y * w, - 2 * y * z + 2 * x * w, - 1 - 2 * x * x - 2 * y * y + [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w], + [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w], + [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y], ] - ]) + ) principal_axis.append(R[2, :]) - t = image_pose[4 : 7] + t = image_pose[4:7] # World-to-Camera pose current_pose = np.zeros([4, 4]) - current_pose[: 3, : 3] = R - current_pose[: 3, 3] = t + current_pose[:3, :3] = R + current_pose[:3, 3] = t current_pose[3, 3] = 1 # Camera-to-World pose # pose = np.zeros([4, 4]) @@ -185,38 +158,38 @@ for idx, image_name in enumerate(image_names): # pose[: 3, 3] = -np.matmul(np.transpose(R), t) # pose[3, 3] = 1 poses.append(current_pose) - + current_points3D_id_to_ndepth = {} for point3D_id in points3D_id_to_2D[idx].keys(): p3d = points3D[point3D_id] - current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1])) + current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / ( + 0.5 * (K[0, 0] + K[1, 1]) + ) points3D_id_to_ndepth.append(current_points3D_id_to_ndepth) principal_axis = np.array(principal_axis) -angles = np.rad2deg(np.arccos( - np.clip( - np.dot(principal_axis, np.transpose(principal_axis)), - -1, 1 - ) -)) +angles = np.rad2deg( + np.arccos(np.clip(np.dot(principal_axis, np.transpose(principal_axis)), -1, 1)) +) # Compute overlap score -overlap_matrix = np.full([n_images, n_images], -1.) -scale_ratio_matrix = np.full([n_images, n_images], -1.) 
+overlap_matrix = np.full([n_images, n_images], -1.0) +scale_ratio_matrix = np.full([n_images, n_images], -1.0) for idx1 in range(n_images): if image_paths[idx1] is None or depth_paths[idx1] is None: continue for idx2 in range(idx1 + 1, n_images): if image_paths[idx2] is None or depth_paths[idx2] is None: continue - matches = ( - points3D_id_to_2D[idx1].keys() & - points3D_id_to_2D[idx2].keys() - ) + matches = points3D_id_to_2D[idx1].keys() & points3D_id_to_2D[idx2].keys() min_num_points3D = min( len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2]) ) - overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1]) # min_num_points3D - overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2]) # min_num_points3D + overlap_matrix[idx1, idx2] = len(matches) / len( + points3D_id_to_2D[idx1] + ) # min_num_points3D + overlap_matrix[idx2, idx1] = len(matches) / len( + points3D_id_to_2D[idx2] + ) # min_num_points3D if len(matches) == 0: continue points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1] @@ -228,7 +201,7 @@ for idx1 in range(n_images): scale_ratio_matrix[idx2, idx1] = min_scale_ratio np.savez( - os.path.join(args.output_path, '%s.npz' % scene_id), + os.path.join(args.output_path, "%s.npz" % scene_id), image_paths=image_paths, depth_paths=depth_paths, intrinsics=intrinsics, @@ -238,5 +211,5 @@ np.savez( angles=angles, n_points3D=n_points3D, points3D_id_to_2D=points3D_id_to_2D, - points3D_id_to_ndepth=points3D_id_to_ndepth + points3D_id_to_ndepth=points3D_id_to_ndepth, ) diff --git a/third_party/d2net/megadepth_utils/undistort_reconstructions.py b/third_party/d2net/megadepth_utils/undistort_reconstructions.py index a6b99a72f81206e6fbefae9daa9aa683c8754051..822c9abd3fc75fd8fc1e8d9ada75aa76802c6798 100644 --- a/third_party/d2net/megadepth_utils/undistort_reconstructions.py +++ b/third_party/d2net/megadepth_utils/undistort_reconstructions.py @@ -6,28 +6,18 @@ import os import subprocess -parser = argparse.ArgumentParser(description='MegaDepth Undistortion') +parser = argparse.ArgumentParser(description="MegaDepth Undistortion") parser.add_argument( - '--colmap_path', type=str, required=True, - help='path to colmap executable' -) -parser.add_argument( - '--base_path', type=str, required=True, - help='path to MegaDepth' + "--colmap_path", type=str, required=True, help="path to colmap executable" ) +parser.add_argument("--base_path", type=str, required=True, help="path to MegaDepth") args = parser.parse_args() -sfm_path = os.path.join( - args.base_path, 'MegaDepth_v1_SfM' -) -base_depth_path = os.path.join( - args.base_path, 'phoenix/S6/zl548/MegaDepth_v1' -) -output_path = os.path.join( - args.base_path, 'Undistorted_SfM' -) +sfm_path = os.path.join(args.base_path, "MegaDepth_v1_SfM") +base_depth_path = os.path.join(args.base_path, "phoenix/S6/zl548/MegaDepth_v1") +output_path = os.path.join(args.base_path, "Undistorted_SfM") os.mkdir(output_path) @@ -35,35 +25,45 @@ for scene_name in os.listdir(base_depth_path): current_output_path = os.path.join(output_path, scene_name) os.mkdir(current_output_path) - image_path = os.path.join( - base_depth_path, scene_name, 'dense0', 'imgs' - ) + image_path = os.path.join(base_depth_path, scene_name, "dense0", "imgs") if not os.path.exists(image_path): continue - + # Find the maximum image size in scene. 
max_image_size = 0 for image_name in os.listdir(image_path): max_image_size = max( - max_image_size, - max(imagesize.get(os.path.join(image_path, image_name))) + max_image_size, max(imagesize.get(os.path.join(image_path, image_name))) ) # Undistort the images and update the reconstruction. - subprocess.call([ - os.path.join(args.colmap_path, 'colmap'), 'image_undistorter', - '--image_path', os.path.join(sfm_path, scene_name, 'images'), - '--input_path', os.path.join(sfm_path, scene_name, 'sparse', 'manhattan', '0'), - '--output_path', current_output_path, - '--max_image_size', str(max_image_size) - ]) + subprocess.call( + [ + os.path.join(args.colmap_path, "colmap"), + "image_undistorter", + "--image_path", + os.path.join(sfm_path, scene_name, "images"), + "--input_path", + os.path.join(sfm_path, scene_name, "sparse", "manhattan", "0"), + "--output_path", + current_output_path, + "--max_image_size", + str(max_image_size), + ] + ) # Transform the reconstruction to raw text format. - sparse_txt_path = os.path.join(current_output_path, 'sparse-txt') + sparse_txt_path = os.path.join(current_output_path, "sparse-txt") os.mkdir(sparse_txt_path) - subprocess.call([ - os.path.join(args.colmap_path, 'colmap'), 'model_converter', - '--input_path', os.path.join(current_output_path, 'sparse'), - '--output_path', sparse_txt_path, - '--output_type', 'TXT' - ]) \ No newline at end of file + subprocess.call( + [ + os.path.join(args.colmap_path, "colmap"), + "model_converter", + "--input_path", + os.path.join(current_output_path, "sparse"), + "--output_path", + sparse_txt_path, + "--output_type", + "TXT", + ] + ) diff --git a/third_party/d2net/train.py b/third_party/d2net/train.py index 5817f1712bda0779175fb18437d1f8c263f29f3b..5ca584e131c14930f86c3252f93b89f1aea40713 100644 --- a/third_party/d2net/train.py +++ b/third_party/d2net/train.py @@ -32,72 +32,64 @@ if use_cuda: np.random.seed(1) # Argument parsing -parser = argparse.ArgumentParser(description='Training script') +parser = argparse.ArgumentParser(description="Training script") parser.add_argument( - '--dataset_path', type=str, required=True, - help='path to the dataset' + "--dataset_path", type=str, required=True, help="path to the dataset" ) parser.add_argument( - '--scene_info_path', type=str, required=True, - help='path to the processed scenes' + "--scene_info_path", type=str, required=True, help="path to the processed scenes" ) parser.add_argument( - '--preprocessing', type=str, default='caffe', - help='image preprocessing (caffe or torch)' + "--preprocessing", + type=str, + default="caffe", + help="image preprocessing (caffe or torch)", ) parser.add_argument( - '--model_file', type=str, default='models/d2_ots.pth', - help='path to the full model' + "--model_file", type=str, default="models/d2_ots.pth", help="path to the full model" ) parser.add_argument( - '--num_epochs', type=int, default=10, - help='number of training epochs' + "--num_epochs", type=int, default=10, help="number of training epochs" ) +parser.add_argument("--lr", type=float, default=1e-3, help="initial learning rate") +parser.add_argument("--batch_size", type=int, default=1, help="batch size") parser.add_argument( - '--lr', type=float, default=1e-3, - help='initial learning rate' -) -parser.add_argument( - '--batch_size', type=int, default=1, - help='batch size' -) -parser.add_argument( - '--num_workers', type=int, default=4, - help='number of workers for data loading' + "--num_workers", type=int, default=4, help="number of workers for data loading" ) parser.add_argument( - 
'--use_validation', dest='use_validation', action='store_true', - help='use the validation split' + "--use_validation", + dest="use_validation", + action="store_true", + help="use the validation split", ) parser.set_defaults(use_validation=False) parser.add_argument( - '--log_interval', type=int, default=250, - help='loss logging interval' + "--log_interval", type=int, default=250, help="loss logging interval" ) -parser.add_argument( - '--log_file', type=str, default='log.txt', - help='loss logging file' -) +parser.add_argument("--log_file", type=str, default="log.txt", help="loss logging file") parser.add_argument( - '--plot', dest='plot', action='store_true', - help='plot training pairs' + "--plot", dest="plot", action="store_true", help="plot training pairs" ) parser.set_defaults(plot=False) parser.add_argument( - '--checkpoint_directory', type=str, default='checkpoints', - help='directory for training checkpoints' + "--checkpoint_directory", + type=str, + default="checkpoints", + help="directory for training checkpoints", ) parser.add_argument( - '--checkpoint_prefix', type=str, default='d2', - help='prefix for training checkpoints' + "--checkpoint_prefix", + type=str, + default="d2", + help="prefix for training checkpoints", ) args = parser.parse_args() @@ -106,17 +98,14 @@ print(args) # Create the folders for plotting if need be if args.plot: - plot_path = 'train_vis' + plot_path = "train_vis" if os.path.isdir(plot_path): - print('[Warning] Plotting directory already exists.') + print("[Warning] Plotting directory already exists.") else: os.mkdir(plot_path) # Creating CNN model -model = D2Net( - model_file=args.model_file, - use_cuda=use_cuda -) +model = D2Net(model_file=args.model_file, use_cuda=use_cuda) # Optimizer optimizer = optim.Adam( @@ -126,37 +115,39 @@ optimizer = optim.Adam( # Dataset if args.use_validation: validation_dataset = MegaDepthDataset( - scene_list_path='megadepth_utils/valid_scenes.txt', + scene_list_path="megadepth_utils/valid_scenes.txt", scene_info_path=args.scene_info_path, base_path=args.dataset_path, train=False, preprocessing=args.preprocessing, - pairs_per_scene=25 + pairs_per_scene=25, ) validation_dataloader = DataLoader( - validation_dataset, - batch_size=args.batch_size, - num_workers=args.num_workers + validation_dataset, batch_size=args.batch_size, num_workers=args.num_workers ) training_dataset = MegaDepthDataset( - scene_list_path='megadepth_utils/train_scenes.txt', + scene_list_path="megadepth_utils/train_scenes.txt", scene_info_path=args.scene_info_path, base_path=args.dataset_path, - preprocessing=args.preprocessing + preprocessing=args.preprocessing, ) training_dataloader = DataLoader( - training_dataset, - batch_size=args.batch_size, - num_workers=args.num_workers + training_dataset, batch_size=args.batch_size, num_workers=args.num_workers ) # Define epoch function def process_epoch( - epoch_idx, - model, loss_function, optimizer, dataloader, device, - log_file, args, train=True + epoch_idx, + model, + loss_function, + optimizer, + dataloader, + device, + log_file, + args, + train=True, ): epoch_losses = [] @@ -167,12 +158,12 @@ def process_epoch( if train: optimizer.zero_grad() - batch['train'] = train - batch['epoch_idx'] = epoch_idx - batch['batch_idx'] = batch_idx - batch['batch_size'] = args.batch_size - batch['preprocessing'] = args.preprocessing - batch['log_interval'] = args.log_interval + batch["train"] = train + batch["epoch_idx"] = epoch_idx + batch["batch_idx"] = batch_idx + batch["batch_size"] = args.batch_size + 
batch["preprocessing"] = args.preprocessing + batch["log_interval"] = args.log_interval try: loss = loss_function(model, batch, device, plot=args.plot) @@ -182,23 +173,28 @@ def process_epoch( current_loss = loss.data.cpu().numpy()[0] epoch_losses.append(current_loss) - progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses))) + progress_bar.set_postfix(loss=("%.4f" % np.mean(epoch_losses))) if batch_idx % args.log_interval == 0: - log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % ( - 'train' if train else 'valid', - epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses) - )) + log_file.write( + "[%s] epoch %d - batch %d / %d - avg_loss: %f\n" + % ( + "train" if train else "valid", + epoch_idx, + batch_idx, + len(dataloader), + np.mean(epoch_losses), + ) + ) if train: loss.backward() optimizer.step() - log_file.write('[%s] epoch %d - avg_loss: %f\n' % ( - 'train' if train else 'valid', - epoch_idx, - np.mean(epoch_losses) - )) + log_file.write( + "[%s] epoch %d - avg_loss: %f\n" + % ("train" if train else "valid", epoch_idx, np.mean(epoch_losses)) + ) log_file.flush() return np.mean(epoch_losses) @@ -206,15 +202,15 @@ def process_epoch( # Create the checkpoint directory if os.path.isdir(args.checkpoint_directory): - print('[Warning] Checkpoint directory already exists.') + print("[Warning] Checkpoint directory already exists.") else: os.mkdir(args.checkpoint_directory) - + # Open the log file for writing if os.path.exists(args.log_file): - print('[Warning] Log file already exists.') -log_file = open(args.log_file, 'a+') + print("[Warning] Log file already exists.") +log_file = open(args.log_file, "a+") # Initialize the history train_loss_history = [] @@ -223,9 +219,14 @@ if args.use_validation: validation_dataset.build_dataset() min_validation_loss = process_epoch( 0, - model, loss_function, optimizer, validation_dataloader, device, - log_file, args, - train=False + model, + loss_function, + optimizer, + validation_dataloader, + device, + log_file, + args, + train=False, ) # Start the training @@ -235,8 +236,13 @@ for epoch_idx in range(1, args.num_epochs + 1): train_loss_history.append( process_epoch( epoch_idx, - model, loss_function, optimizer, training_dataloader, device, - log_file, args + model, + loss_function, + optimizer, + training_dataloader, + device, + log_file, + args, ) ) @@ -244,34 +250,34 @@ for epoch_idx in range(1, args.num_epochs + 1): validation_loss_history.append( process_epoch( epoch_idx, - model, loss_function, optimizer, validation_dataloader, device, - log_file, args, - train=False + model, + loss_function, + optimizer, + validation_dataloader, + device, + log_file, + args, + train=False, ) ) # Save the current checkpoint checkpoint_path = os.path.join( - args.checkpoint_directory, - '%s.%02d.pth' % (args.checkpoint_prefix, epoch_idx) + args.checkpoint_directory, "%s.%02d.pth" % (args.checkpoint_prefix, epoch_idx) ) checkpoint = { - 'args': args, - 'epoch_idx': epoch_idx, - 'model': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'train_loss_history': train_loss_history, - 'validation_loss_history': validation_loss_history + "args": args, + "epoch_idx": epoch_idx, + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "train_loss_history": train_loss_history, + "validation_loss_history": validation_loss_history, } torch.save(checkpoint, checkpoint_path) - if ( - args.use_validation and - validation_loss_history[-1] < min_validation_loss - ): + if args.use_validation and validation_loss_history[-1] < 
min_validation_loss: min_validation_loss = validation_loss_history[-1] best_checkpoint_path = os.path.join( - args.checkpoint_directory, - '%s.best.pth' % args.checkpoint_prefix + args.checkpoint_directory, "%s.best.pth" % args.checkpoint_prefix ) shutil.copy(checkpoint_path, best_checkpoint_path) diff --git a/third_party/lanet/augmentations.py b/third_party/lanet/augmentations.py index f4e4496c77ce8fc8cdadb230dd0d0750166152a9..c39b7bfee0b42730f81e8f614352a58c25187b59 100644 --- a/third_party/lanet/augmentations.py +++ b/third_party/lanet/augmentations.py @@ -54,110 +54,163 @@ def resize_sample(sample, image_shape, image_interpolation=Image.ANTIALIAS): """ # image image_transform = transforms.Resize(image_shape, interpolation=image_interpolation) - sample['image'] = image_transform(sample['image']) + sample["image"] = image_transform(sample["image"]) return sample + def spatial_augment_sample(sample): - """ Apply spatial augmentation to an image (flipping and random affine transformation).""" - augment_image = transforms.Compose([ - transforms.RandomVerticalFlip(p=0.5), - transforms.RandomHorizontalFlip(p=0.5), - transforms.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)) - - ]) - sample['image'] = augment_image(sample['image']) + """Apply spatial augmentation to an image (flipping and random affine transformation).""" + augment_image = transforms.Compose( + [ + transforms.RandomVerticalFlip(p=0.5), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)), + ] + ) + sample["image"] = augment_image(sample["image"]) return sample + def unnormalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): - """ Counterpart method of torchvision.transforms.Normalize.""" + """Counterpart method of torchvision.transforms.Normalize.""" for t, m, s in zip(tensor, mean, std): t.div_(1 / s).sub_(-m) return tensor def sample_homography( - shape, perspective=True, scaling=True, rotation=True, translation=True, - n_scales=100, n_angles=100, scaling_amplitude=0.1, perspective_amplitude=0.4, - patch_ratio=0.8, max_angle=pi/4): - """ Sample a random homography that includes perspective, scale, translation and rotation operations.""" + shape, + perspective=True, + scaling=True, + rotation=True, + translation=True, + n_scales=100, + n_angles=100, + scaling_amplitude=0.1, + perspective_amplitude=0.4, + patch_ratio=0.8, + max_angle=pi / 4, +): + """Sample a random homography that includes perspective, scale, translation and rotation operations.""" width = float(shape[1]) hw_ratio = float(shape[0]) / float(shape[1]) - pts1 = np.stack([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]], axis=0) + pts1 = np.stack([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]], axis=0) pts2 = pts1.copy() * patch_ratio - pts2[:,1] *= hw_ratio + pts2[:, 1] *= hw_ratio if perspective: - perspective_amplitude_x = np.random.normal(0., perspective_amplitude/2, (2)) - perspective_amplitude_y = np.random.normal(0., hw_ratio * perspective_amplitude/2, (2)) + perspective_amplitude_x = np.random.normal(0.0, perspective_amplitude / 2, (2)) + perspective_amplitude_y = np.random.normal( + 0.0, hw_ratio * perspective_amplitude / 2, (2) + ) - perspective_amplitude_x = np.clip(perspective_amplitude_x, -perspective_amplitude/2, perspective_amplitude/2) - perspective_amplitude_y = np.clip(perspective_amplitude_y, hw_ratio * -perspective_amplitude/2, hw_ratio * perspective_amplitude/2) + perspective_amplitude_x = np.clip( + perspective_amplitude_x, + -perspective_amplitude / 2, + 
perspective_amplitude / 2, + ) + perspective_amplitude_y = np.clip( + perspective_amplitude_y, + hw_ratio * -perspective_amplitude / 2, + hw_ratio * perspective_amplitude / 2, + ) - pts2[0,0] -= perspective_amplitude_x[1] - pts2[0,1] -= perspective_amplitude_y[1] + pts2[0, 0] -= perspective_amplitude_x[1] + pts2[0, 1] -= perspective_amplitude_y[1] - pts2[1,0] -= perspective_amplitude_x[0] - pts2[1,1] += perspective_amplitude_y[1] + pts2[1, 0] -= perspective_amplitude_x[0] + pts2[1, 1] += perspective_amplitude_y[1] - pts2[2,0] += perspective_amplitude_x[1] - pts2[2,1] -= perspective_amplitude_y[0] + pts2[2, 0] += perspective_amplitude_x[1] + pts2[2, 1] -= perspective_amplitude_y[0] - pts2[3,0] += perspective_amplitude_x[0] - pts2[3,1] += perspective_amplitude_y[0] + pts2[3, 0] += perspective_amplitude_x[0] + pts2[3, 1] += perspective_amplitude_y[0] if scaling: - random_scales = np.random.normal(1, scaling_amplitude/2, (n_scales)) - random_scales = np.clip(random_scales, 1-scaling_amplitude/2, 1+scaling_amplitude/2) + random_scales = np.random.normal(1, scaling_amplitude / 2, (n_scales)) + random_scales = np.clip( + random_scales, 1 - scaling_amplitude / 2, 1 + scaling_amplitude / 2 + ) - scales = np.concatenate([[1.], random_scales], 0) + scales = np.concatenate([[1.0], random_scales], 0) center = np.mean(pts2, axis=0, keepdims=True) - scaled = np.expand_dims(pts2 - center, axis=0) * np.expand_dims( - np.expand_dims(scales, 1), 1) + center + scaled = ( + np.expand_dims(pts2 - center, axis=0) + * np.expand_dims(np.expand_dims(scales, 1), 1) + + center + ) valid = np.arange(n_scales) # all scales are valid except scale=1 idx = valid[np.random.randint(valid.shape[0])] pts2 = scaled[idx] if translation: - t_min, t_max = np.min(pts2 - [-1., -hw_ratio], axis=0), np.min([1., hw_ratio] - pts2, axis=0) - pts2 += np.expand_dims(np.stack([np.random.uniform(-t_min[0], t_max[0]), - np.random.uniform(-t_min[1], t_max[1])]), - axis=0) + t_min, t_max = np.min(pts2 - [-1.0, -hw_ratio], axis=0), np.min( + [1.0, hw_ratio] - pts2, axis=0 + ) + pts2 += np.expand_dims( + np.stack( + [ + np.random.uniform(-t_min[0], t_max[0]), + np.random.uniform(-t_min[1], t_max[1]), + ] + ), + axis=0, + ) if rotation: angles = np.linspace(-max_angle, max_angle, n_angles) - angles = np.concatenate([[0.], angles], axis=0) + angles = np.concatenate([[0.0], angles], axis=0) center = np.mean(pts2, axis=0, keepdims=True) - rot_mat = np.reshape(np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), - np.cos(angles)], axis=1), [-1, 2, 2]) - rotated = np.matmul( - np.tile(np.expand_dims(pts2 - center, axis=0), [n_angles+1, 1, 1]), - rot_mat) + center + rot_mat = np.reshape( + np.stack( + [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)], + axis=1, + ), + [-1, 2, 2], + ) + rotated = ( + np.matmul( + np.tile(np.expand_dims(pts2 - center, axis=0), [n_angles + 1, 1, 1]), + rot_mat, + ) + + center + ) - valid = np.where(np.all((rotated >= [-1.,-hw_ratio]) & (rotated < [1.,hw_ratio]), - axis=(1, 2)))[0] + valid = np.where( + np.all( + (rotated >= [-1.0, -hw_ratio]) & (rotated < [1.0, hw_ratio]), + axis=(1, 2), + ) + )[0] idx = valid[np.random.randint(valid.shape[0])] pts2 = rotated[idx] - pts2[:,1] /= hw_ratio + pts2[:, 1] /= hw_ratio + + def ax(p, q): + return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] - def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] - def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]] + def ay(p, q): + return [0, 0, 0, p[0], p[1], 1, -p[0] * 
q[1], -p[1] * q[1]] a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0) - p_mat = np.transpose(np.stack( - [[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)) + p_mat = np.transpose( + np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0) + ) homography = np.matmul(np.linalg.pinv(a_mat), p_mat).squeeze() - homography = np.concatenate([homography, [1.]]).reshape(3,3) + homography = np.concatenate([homography, [1.0]]).reshape(3, 3) return homography + def warp_homography(sources, homography): """Warp features given a homography @@ -175,12 +228,15 @@ def warp_homography(sources, homography): """ _, H, W, _ = sources.shape warped_sources = sources.clone().squeeze() - warped_sources = warped_sources.view(-1,2) - warped_sources = torch.addmm(homography[:,2], warped_sources, homography[:,:2].t()) - warped_sources.mul_(1/warped_sources[:,2].unsqueeze(1)) - warped_sources = warped_sources[:,:2].contiguous().view(1,H,W,2) + warped_sources = warped_sources.view(-1, 2) + warped_sources = torch.addmm( + homography[:, 2], warped_sources, homography[:, :2].t() + ) + warped_sources.mul_(1 / warped_sources[:, 2].unsqueeze(1)) + warped_sources = warped_sources[:, :2].contiguous().view(1, H, W, 2) return warped_sources + def add_noise(img, mode="gaussian", percent=0.02): """Add image noise @@ -259,36 +315,40 @@ def add_noise(img, mode="gaussian", percent=0.02): return noisy -def non_spatial_augmentation(img_warp_ori, jitter_paramters, color_order=[0,1,2], to_gray=False): - """ Apply non-spatial augmentation to an image (jittering, color swap, convert to gray scale, Gaussian blur).""" +def non_spatial_augmentation( + img_warp_ori, jitter_paramters, color_order=[0, 1, 2], to_gray=False +): + """Apply non-spatial augmentation to an image (jittering, color swap, convert to gray scale, Gaussian blur).""" brightness, contrast, saturation, hue = jitter_paramters color_augmentation = transforms.ColorJitter(brightness, contrast, saturation, hue) - ''' + """ augment_image = color_augmentation.get_params(brightness=[max(0, 1 - brightness), 1 + brightness], contrast=[max(0, 1 - contrast), 1 + contrast], saturation=[max(0, 1 - saturation), 1 + saturation], hue=[-hue, hue]) - ''' + """ B = img_warp_ori.shape[0] img_warp = [] - kernel_sizes = [0,1,3,5] + kernel_sizes = [0, 1, 3, 5] for b in range(B): img_warp_sub = img_warp_ori[b].cpu() img_warp_sub = torchvision.transforms.functional.to_pil_image(img_warp_sub) - img_warp_sub_np = np.array(img_warp_sub) - img_warp_sub_np = img_warp_sub_np[:,:,color_order] - + img_warp_sub_np = np.array(img_warp_sub) + img_warp_sub_np = img_warp_sub_np[:, :, color_order] + if np.random.rand() > 0.5: img_warp_sub_np = add_noise(img_warp_sub_np) rand_index = np.random.randint(4) kernel_size = kernel_sizes[rand_index] - if kernel_size >0: - img_warp_sub_np = cv2.GaussianBlur(img_warp_sub_np, (kernel_size, kernel_size), sigmaX=0) - + if kernel_size > 0: + img_warp_sub_np = cv2.GaussianBlur( + img_warp_sub_np, (kernel_size, kernel_size), sigmaX=0 + ) + if to_gray: img_warp_sub_np = cv2.cvtColor(img_warp_sub_np, cv2.COLOR_RGB2GRAY) img_warp_sub_np = cv2.cvtColor(img_warp_sub_np, cv2.COLOR_GRAY2RGB) @@ -296,35 +356,54 @@ def non_spatial_augmentation(img_warp_ori, jitter_paramters, color_order=[0,1,2] img_warp_sub = Image.fromarray(img_warp_sub_np) img_warp_sub = color_augmentation(img_warp_sub) - img_warp_sub = torchvision.transforms.functional.to_tensor(img_warp_sub).to(img_warp_ori.device) + img_warp_sub = 
torchvision.transforms.functional.to_tensor(img_warp_sub).to( + img_warp_ori.device + ) img_warp.append(img_warp_sub) img_warp = torch.stack(img_warp, dim=0) return img_warp -def ha_augment_sample(data, jitter_paramters=[0.5, 0.5, 0.2, 0.05], patch_ratio=0.7, scaling_amplitude=0.2, max_angle=pi/4): + +def ha_augment_sample( + data, + jitter_paramters=[0.5, 0.5, 0.2, 0.05], + patch_ratio=0.7, + scaling_amplitude=0.2, + max_angle=pi / 4, +): """Apply Homography Adaptation image augmentation.""" - input_img = data['image'].unsqueeze(0) + input_img = data["image"].unsqueeze(0) _, _, H, W = input_img.shape device = input_img.device - - homography = torch.from_numpy( - sample_homography([H, W], - patch_ratio=patch_ratio, - scaling_amplitude=scaling_amplitude, - max_angle=max_angle)).float().to(device) + + homography = ( + torch.from_numpy( + sample_homography( + [H, W], + patch_ratio=patch_ratio, + scaling_amplitude=scaling_amplitude, + max_angle=max_angle, + ) + ) + .float() + .to(device) + ) homography_inv = torch.inverse(homography) - source = image_grid(1, H, W, - dtype=input_img.dtype, - device=device, - ones=False, normalized=True).clone().permute(0, 2, 3, 1) + source = ( + image_grid( + 1, H, W, dtype=input_img.dtype, device=device, ones=False, normalized=True + ) + .clone() + .permute(0, 2, 3, 1) + ) target_warped = warp_homography(source, homography) img_warp = torch.nn.functional.grid_sample(input_img, target_warped) - color_order = [0,1,2] + color_order = [0, 1, 2] if np.random.rand() > 0.5: random.shuffle(color_order) @@ -332,11 +411,21 @@ def ha_augment_sample(data, jitter_paramters=[0.5, 0.5, 0.2, 0.05], patch_ratio= if np.random.rand() > 0.5: to_gray = True - input_img = non_spatial_augmentation(input_img, jitter_paramters=jitter_paramters, color_order=color_order, to_gray=to_gray) - img_warp = non_spatial_augmentation(img_warp, jitter_paramters=jitter_paramters, color_order=color_order, to_gray=to_gray) - - data['image'] = input_img.squeeze() - data['image_aug'] = img_warp.squeeze() - data['homography'] = homography - data['homography_inv'] = homography_inv + input_img = non_spatial_augmentation( + input_img, + jitter_paramters=jitter_paramters, + color_order=color_order, + to_gray=to_gray, + ) + img_warp = non_spatial_augmentation( + img_warp, + jitter_paramters=jitter_paramters, + color_order=color_order, + to_gray=to_gray, + ) + + data["image"] = input_img.squeeze() + data["image_aug"] = img_warp.squeeze() + data["homography"] = homography + data["homography_inv"] = homography_inv return data diff --git a/third_party/lanet/config.py b/third_party/lanet/config.py index baa3aedc95410b231c29ab64b31ea5a2bd3266d7..84419d0a1f7199e8bec1afc7b046e674a629d886 100644 --- a/third_party/lanet/config.py +++ b/third_party/lanet/config.py @@ -1,78 +1,94 @@ import argparse arg_lists = [] -parser = argparse.ArgumentParser(description='LANet') +parser = argparse.ArgumentParser(description="LANet") + def str2bool(v): - return v.lower() in ('true', '1') + return v.lower() in ("true", "1") + def add_argument_group(name): arg = parser.add_argument_group(name) arg_lists.append(arg) return arg + # train data params -traindata_arg = add_argument_group('Traindata Params') -traindata_arg.add_argument('--train_txt', type=str, default='', - help='Train set.') -traindata_arg.add_argument('--train_root', type=str, default='', - help='Where the train images are.') -traindata_arg.add_argument('--batch_size', type=int, default=8, - help='# of images in each batch of data') 
-traindata_arg.add_argument('--num_workers', type=int, default=4, - help='# of subprocesses to use for data loading') -traindata_arg.add_argument('--pin_memory', type=str2bool, default=True, - help='# of subprocesses to use for data loading') -traindata_arg.add_argument('--shuffle', type=str2bool, default=True, - help='Whether to shuffle the train and valid indices') -traindata_arg.add_argument('--image_shape', type=tuple, default=(240, 320), - help='') -traindata_arg.add_argument('--jittering', type=tuple, default=(0.5, 0.5, 0.2, 0.05), - help='') +traindata_arg = add_argument_group("Traindata Params") +traindata_arg.add_argument("--train_txt", type=str, default="", help="Train set.") +traindata_arg.add_argument( + "--train_root", type=str, default="", help="Where the train images are." +) +traindata_arg.add_argument( + "--batch_size", type=int, default=8, help="# of images in each batch of data" +) +traindata_arg.add_argument( + "--num_workers", + type=int, + default=4, + help="# of subprocesses to use for data loading", +) +traindata_arg.add_argument( + "--pin_memory", + type=str2bool, + default=True, + help="# of subprocesses to use for data loading", +) +traindata_arg.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="Whether to shuffle the train and valid indices", +) +traindata_arg.add_argument("--image_shape", type=tuple, default=(240, 320), help="") +traindata_arg.add_argument( + "--jittering", type=tuple, default=(0.5, 0.5, 0.2, 0.05), help="" +) # data storage -storage_arg = add_argument_group('Storage') -storage_arg.add_argument('--ckpt_name', type=str, default='PointModel', - help='') +storage_arg = add_argument_group("Storage") +storage_arg.add_argument("--ckpt_name", type=str, default="PointModel", help="") # training params -train_arg = add_argument_group('Training Params') -train_arg.add_argument('--start_epoch', type=int, default=0, - help='') -train_arg.add_argument('--max_epoch', type=int, default=12, - help='') -train_arg.add_argument('--init_lr', type=float, default=3e-4, - help='Initial learning rate value.') -train_arg.add_argument('--lr_factor', type=float, default=0.5, - help='Reduce learning rate value.') -train_arg.add_argument('--momentum', type=float, default=0.9, - help='Nesterov momentum value.') -train_arg.add_argument('--display', type=int, default=50, - help='') +train_arg = add_argument_group("Training Params") +train_arg.add_argument("--start_epoch", type=int, default=0, help="") +train_arg.add_argument("--max_epoch", type=int, default=12, help="") +train_arg.add_argument( + "--init_lr", type=float, default=3e-4, help="Initial learning rate value." +) +train_arg.add_argument( + "--lr_factor", type=float, default=0.5, help="Reduce learning rate value." +) +train_arg.add_argument( + "--momentum", type=float, default=0.9, help="Nesterov momentum value." 
+) +train_arg.add_argument("--display", type=int, default=50, help="") # loss function params -loss_arg = add_argument_group('Loss function Params') -loss_arg.add_argument('--score_weight', type=float, default=1., - help='') -loss_arg.add_argument('--loc_weight', type=float, default=1., - help='') -loss_arg.add_argument('--desc_weight', type=float, default=4., - help='') -loss_arg.add_argument('--corres_weight', type=float, default=.5, - help='') -loss_arg.add_argument('--corres_threshold', type=int, default=4., - help='') - +loss_arg = add_argument_group("Loss function Params") +loss_arg.add_argument("--score_weight", type=float, default=1.0, help="") +loss_arg.add_argument("--loc_weight", type=float, default=1.0, help="") +loss_arg.add_argument("--desc_weight", type=float, default=4.0, help="") +loss_arg.add_argument("--corres_weight", type=float, default=0.5, help="") +loss_arg.add_argument("--corres_threshold", type=int, default=4.0, help="") + # other params -misc_arg = add_argument_group('Misc.') -misc_arg.add_argument('--use_gpu', type=str2bool, default=True, - help="Whether to run on the GPU.") -misc_arg.add_argument('--gpu', type=int, default=0, - help="Which GPU to run on.") -misc_arg.add_argument('--seed', type=int, default=1001, - help='Seed to ensure reproducibility.') -misc_arg.add_argument('--ckpt_dir', type=str, default='./checkpoints', - help='Directory in which to save model checkpoints.') +misc_arg = add_argument_group("Misc.") +misc_arg.add_argument( + "--use_gpu", type=str2bool, default=True, help="Whether to run on the GPU." +) +misc_arg.add_argument("--gpu", type=int, default=0, help="Which GPU to run on.") +misc_arg.add_argument( + "--seed", type=int, default=1001, help="Seed to ensure reproducibility." +) +misc_arg.add_argument( + "--ckpt_dir", + type=str, + default="./checkpoints", + help="Directory in which to save model checkpoints.", +) + def get_config(): config, unparsed = parser.parse_known_args() diff --git a/third_party/lanet/data_loader.py b/third_party/lanet/data_loader.py index e694e39bb5f3e7ad6763a5cfcce3ca4804071262..d8e7bcac2274a512127920e1695a8923fd009f8a 100644 --- a/third_party/lanet/data_loader.py +++ b/third_party/lanet/data_loader.py @@ -4,6 +4,7 @@ from torch.utils.data import Dataset, DataLoader from augmentations import ha_augment_sample, resize_sample, spatial_augment_sample from utils import to_tensor_sample + def image_transforms(shape, jittering): def train_transforms(sample): sample = resize_sample(sample, image_shape=shape) @@ -12,14 +13,15 @@ def image_transforms(shape, jittering): sample = ha_augment_sample(sample, jitter_paramters=jittering) return sample - return {'train': train_transforms} + return {"train": train_transforms} + class GetData(Dataset): def __init__(self, config, transforms=None): """ Get the list containing all images and labels. """ - datafile = open(config.train_txt, 'r') + datafile = open(config.train_txt, "r") lines = datafile.readlines() dataset = [] @@ -31,9 +33,9 @@ class GetData(Dataset): self.config = config self.dataset = dataset self.root = config.train_root - + self.transforms = transforms - + def __getitem__(self, index): """ Return image'data and its label. 
@@ -41,14 +43,14 @@ class GetData(Dataset): img_path = self.dataset[index] img_file = self.root + img_path img = Image.open(img_file) - - # image.mode == 'L' means the image is in gray scale - if img.mode == 'L': + + # image.mode == 'L' means the image is in gray scale + if img.mode == "L": img_new = Image.new("RGB", img.size) img_new.paste(img) - sample = {'image': img_new, 'idx': index} + sample = {"image": img_new, "idx": index} else: - sample = {'image': img, 'idx': index} + sample = {"image": img, "idx": index} if self.transforms: sample = self.transforms(sample) @@ -61,26 +63,27 @@ class GetData(Dataset): """ return len(self.dataset) + def get_data_loader( - config, - transforms=None, - sampler=None, - drop_last=True, - ): + config, + transforms=None, + sampler=None, + drop_last=True, +): """ Return batch data for training. """ transforms = image_transforms(shape=config.image_shape, jittering=config.jittering) - dataset = GetData(config, transforms=transforms['train']) + dataset = GetData(config, transforms=transforms["train"]) train_loader = DataLoader( - dataset, - batch_size=config.batch_size, - shuffle=config.shuffle, - sampler=sampler, - num_workers=config.num_workers, - pin_memory=config.pin_memory, - drop_last=drop_last - ) + dataset, + batch_size=config.batch_size, + shuffle=config.shuffle, + sampler=sampler, + num_workers=config.num_workers, + pin_memory=config.pin_memory, + drop_last=drop_last, + ) return train_loader diff --git a/third_party/lanet/datasets/hp_loader.py b/third_party/lanet/datasets/hp_loader.py index b4c1d8f3c33fd51bfa928c529544a77c06ed73f0..f255c87dac6e06e56b67ad0f04f7da5c131f0189 100644 --- a/third_party/lanet/datasets/hp_loader.py +++ b/third_party/lanet/datasets/hp_loader.py @@ -30,7 +30,15 @@ class PatchesDataset(Dataset): v - viewpoint sequences all - all sequences """ - def __init__(self, root_dir, use_color=True, data_transform=None, output_shape=None, type='all'): + + def __init__( + self, + root_dir, + use_color=True, + data_transform=None, + output_shape=None, + type="all", + ): super().__init__() self.type = type self.root_dir = root_dir @@ -43,33 +51,36 @@ class PatchesDataset(Dataset): warped_image_paths = [] homographies = [] for path in folder_paths: - if self.type == 'i' and path.stem[0] != 'i': + if self.type == "i" and path.stem[0] != "i": continue - if self.type == 'v' and path.stem[0] != 'v': + if self.type == "v" and path.stem[0] != "v": continue num_images = 5 - file_ext = '.ppm' + file_ext = ".ppm" for i in range(2, 2 + num_images): image_paths.append(str(Path(path, "1" + file_ext))) warped_image_paths.append(str(Path(path, str(i) + file_ext))) homographies.append(np.loadtxt(str(Path(path, "H_1_" + str(i))))) - self.files = {'image_paths': image_paths, 'warped_image_paths': warped_image_paths, 'homography': homographies} + self.files = { + "image_paths": image_paths, + "warped_image_paths": warped_image_paths, + "homography": homographies, + } def scale_homography(self, homography, original_scale, new_scale, pre): scales = np.divide(new_scale, original_scale) if pre: - s = np.diag(np.append(scales, 1.)) + s = np.diag(np.append(scales, 1.0)) homography = np.matmul(s, homography) else: - sinv = np.diag(np.append(1. 
/ scales, 1.)) + sinv = np.diag(np.append(1.0 / scales, 1.0)) homography = np.matmul(homography, sinv) return homography def __len__(self): - return len(self.files['image_paths']) + return len(self.files["image_paths"]) def __getitem__(self, idx): - def _read_image(path): img = cv2.imread(path, cv2.IMREAD_COLOR) if self.use_color: @@ -77,30 +88,39 @@ class PatchesDataset(Dataset): gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) return gray - image = _read_image(self.files['image_paths'][idx]) + image = _read_image(self.files["image_paths"][idx]) - warped_image = _read_image(self.files['warped_image_paths'][idx]) - homography = np.array(self.files['homography'][idx]) - sample = {'image': image, 'warped_image': warped_image, 'homography': homography, 'index' : idx} + warped_image = _read_image(self.files["warped_image_paths"][idx]) + homography = np.array(self.files["homography"][idx]) + sample = { + "image": image, + "warped_image": warped_image, + "homography": homography, + "index": idx, + } # Apply transformations if self.output_shape is not None: - sample['homography'] = self.scale_homography(sample['homography'], - sample['image'].shape[:2][::-1], - self.output_shape, - pre=False) - sample['homography'] = self.scale_homography(sample['homography'], - sample['warped_image'].shape[:2][::-1], - self.output_shape, - pre=True) + sample["homography"] = self.scale_homography( + sample["homography"], + sample["image"].shape[:2][::-1], + self.output_shape, + pre=False, + ) + sample["homography"] = self.scale_homography( + sample["homography"], + sample["warped_image"].shape[:2][::-1], + self.output_shape, + pre=True, + ) - for key in ['image', 'warped_image']: + for key in ["image", "warped_image"]: sample[key] = cv2.resize(sample[key], self.output_shape) if self.use_color is False: sample[key] = np.expand_dims(sample[key], axis=2) transform = transforms.ToTensor() - for key in ['image', 'warped_image']: - sample[key] = transform(sample[key]).type('torch.FloatTensor') + for key in ["image", "warped_image"]: + sample[key] = transform(sample[key]).type("torch.FloatTensor") return sample diff --git a/third_party/lanet/datasets/prepare_coco.py b/third_party/lanet/datasets/prepare_coco.py index 0468aba19c6c2c76bda1a1af2b86dc7f20176fdb..612fb400000c66476a3be796d4dcceea8bc331d4 100644 --- a/third_party/lanet/datasets/prepare_coco.py +++ b/third_party/lanet/datasets/prepare_coco.py @@ -1,26 +1,24 @@ import os import argparse + def prepare_coco(args): - train_file = open(os.path.join(args.saved_dir, args.saved_txt), 'w') + train_file = open(os.path.join(args.saved_dir, args.saved_txt), "w") dirs = os.listdir(args.raw_dir) for file in dirs: # Write training files - train_file.write('%s\n' % (file)) + train_file.write("%s\n" % (file)) + + print("Data Preparation Finished.") - print('Data Preparation Finished.') -if __name__ == '__main__': +if __name__ == "__main__": arg_parser = argparse.ArgumentParser(description="coco prepareing.") - arg_parser.add_argument('--dataset', type=str, default='coco', - help='') - arg_parser.add_argument('--raw_dir', type=str, default='', - help='') - arg_parser.add_argument('--saved_dir', type=str, default='', - help='') - arg_parser.add_argument('--saved_txt', type=str, default='train2017.txt', - help='') - args = arg_parser.parse_args() + arg_parser.add_argument("--dataset", type=str, default="coco", help="") + arg_parser.add_argument("--raw_dir", type=str, default="", help="") + arg_parser.add_argument("--saved_dir", type=str, default="", help="") + 
arg_parser.add_argument("--saved_txt", type=str, default="train2017.txt", help="") + args = arg_parser.parse_args() - prepare_coco(args) \ No newline at end of file + prepare_coco(args) diff --git a/third_party/lanet/evaluation/descriptor_evaluation.py b/third_party/lanet/evaluation/descriptor_evaluation.py index c0e1f84199d353ac5858641c8f68bc298f9d6413..924918a64e769e0b4e661366a0b7d59a2f819ec5 100644 --- a/third_party/lanet/evaluation/descriptor_evaluation.py +++ b/third_party/lanet/evaluation/descriptor_evaluation.py @@ -12,7 +12,7 @@ from utils import warp_keypoints def select_k_best(points, descriptors, k): - """ Select the k most probable points (and strip their probability). + """Select the k most probable points (and strip their probability). points has shape (num_points, 3) where the last coordinate is the probability. Parameters @@ -25,7 +25,7 @@ def select_k_best(points, descriptors, k): Number of keypoints to select, based on probability. Returns ------- - + selected_points: numpy.ndarray (k,2) k most probable keypoints. selected_descriptors: numpy.ndarray (k,256) @@ -44,7 +44,7 @@ def keep_shared_points(keypoints, descriptors, H, shape, keep_k_points=1000): Compute a list of keypoints from the map, filter the list of points by keeping only the points that once mapped by H are still inside the shape of the map and keep at most 'keep_k_points' keypoints in the image. - + Parameters ---------- keypoints: numpy.ndarray (N,3) @@ -53,36 +53,44 @@ def keep_shared_points(keypoints, descriptors, H, shape, keep_k_points=1000): Keypoint descriptors. H: numpy.ndarray (3,3) Homography. - shape: tuple + shape: tuple Image shape. keep_k_points: int Number of keypoints to select, based on probability. Returns - ------- + ------- selected_points: numpy.ndarray (k,2) k most probable keypoints. selected_descriptors: numpy.ndarray (k,256) Descriptors corresponding to the k most probable keypoints. """ - + def keep_true_keypoints(points, descriptors, H, shape): - """ Keep only the points whose warped coordinates by H are still inside shape. """ + """Keep only the points whose warped coordinates by H are still inside shape.""" warped_points = warp_keypoints(points[:, [1, 0]], H) warped_points[:, [0, 1]] = warped_points[:, [1, 0]] - mask = (warped_points[:, 0] >= 0) & (warped_points[:, 0] < shape[0]) &\ - (warped_points[:, 1] >= 0) & (warped_points[:, 1] < shape[1]) + mask = ( + (warped_points[:, 0] >= 0) + & (warped_points[:, 0] < shape[0]) + & (warped_points[:, 1] >= 0) + & (warped_points[:, 1] < shape[1]) + ) return points[mask, :], descriptors[mask, :] - selected_keypoints, selected_descriptors = keep_true_keypoints(keypoints, descriptors, H, shape) - selected_keypoints, selected_descriptors = select_k_best(selected_keypoints, selected_descriptors, keep_k_points) + selected_keypoints, selected_descriptors = keep_true_keypoints( + keypoints, descriptors, H, shape + ) + selected_keypoints, selected_descriptors = select_k_best( + selected_keypoints, selected_descriptors, keep_k_points + ) return selected_keypoints, selected_descriptors def compute_matching_score(data, keep_k_points=1000): """ Compute the matching score between two sets of keypoints with associated descriptors. - + Parameters ---------- data: dict @@ -103,31 +111,35 @@ def compute_matching_score(data, keep_k_points=1000): Number of keypoints to select, based on probability. Returns - ------- + ------- ms: float Matching score. 
""" - shape = data['image_shape'] - real_H = data['homography'] + shape = data["image_shape"] + real_H = data["homography"] # Filter out predictions - keypoints = data['prob'][:, :2].T + keypoints = data["prob"][:, :2].T keypoints = keypoints[::-1] - prob = data['prob'][:, 2] + prob = data["prob"][:, 2] keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1) - warped_keypoints = data['warped_prob'][:, :2].T + warped_keypoints = data["warped_prob"][:, :2].T warped_keypoints = warped_keypoints[::-1] - warped_prob = data['warped_prob'][:, 2] - warped_keypoints = np.stack([warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1) + warped_prob = data["warped_prob"][:, 2] + warped_keypoints = np.stack( + [warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1 + ) + + desc = data["desc"] + warped_desc = data["warped_desc"] - desc = data['desc'] - warped_desc = data['warped_desc'] - # Keeps all points for the next frame. The matching for caculating M.Score shouldnt use only in view points. - keypoints, desc = select_k_best(keypoints, desc, keep_k_points) - warped_keypoints, warped_desc = select_k_best(warped_keypoints, warped_desc, keep_k_points) - + keypoints, desc = select_k_best(keypoints, desc, keep_k_points) + warped_keypoints, warped_desc = select_k_best( + warped_keypoints, warped_desc, keep_k_points + ) + # Match the keypoints with the warped_keypoints with nearest neighbor search # This part needs to be done with crossCheck=False. # All the matched pairs need to be evaluated without any selection. @@ -139,11 +151,16 @@ def compute_matching_score(data, keep_k_points=1000): matches_idx = np.array([m.trainIdx for m in matches]) m_warped_keypoints = warped_keypoints[matches_idx, :] - true_warped_keypoints = warp_keypoints(m_warped_keypoints[:, [1, 0]], np.linalg.inv(real_H))[:,::-1] - vis_warped = np.all((true_warped_keypoints >= 0) & (true_warped_keypoints <= (np.array(shape)-1)), axis=-1) + true_warped_keypoints = warp_keypoints( + m_warped_keypoints[:, [1, 0]], np.linalg.inv(real_H) + )[:, ::-1] + vis_warped = np.all( + (true_warped_keypoints >= 0) & (true_warped_keypoints <= (np.array(shape) - 1)), + axis=-1, + ) norm1 = np.linalg.norm(true_warped_keypoints - m_keypoints, axis=-1) - correct1 = (norm1 < 3) + correct1 = norm1 < 3 count1 = np.sum(correct1 * vis_warped) score1 = count1 / np.maximum(np.sum(vis_warped), 1.0) @@ -153,11 +170,13 @@ def compute_matching_score(data, keep_k_points=1000): matches_idx = np.array([m.trainIdx for m in matches]) m_keypoints = keypoints[matches_idx, :] - true_keypoints = warp_keypoints(m_keypoints[:, [1, 0]], real_H)[:,::-1] - vis = np.all((true_keypoints >= 0) & (true_keypoints <= (np.array(shape)-1)), axis=-1) + true_keypoints = warp_keypoints(m_keypoints[:, [1, 0]], real_H)[:, ::-1] + vis = np.all( + (true_keypoints >= 0) & (true_keypoints <= (np.array(shape) - 1)), axis=-1 + ) norm2 = np.linalg.norm(true_keypoints - m_warped_keypoints, axis=-1) - correct2 = (norm2 < 3) + correct2 = norm2 < 3 count2 = np.sum(correct2 * vis) score2 = count2 / np.maximum(np.sum(vis), 1.0) @@ -165,9 +184,10 @@ def compute_matching_score(data, keep_k_points=1000): return ms + def compute_homography(data, keep_k_points=1000): """ - Compute the homography between 2 sets of Keypoints and descriptors inside data. + Compute the homography between 2 sets of Keypoints and descriptors inside data. Use the homography to compute the correctness metrics (1,3,5). 
Parameters @@ -190,7 +210,7 @@ def compute_homography(data, keep_k_points=1000): Number of keypoints to select, based on probability. Returns - ------- + ------- correctness1: float correctness1 metric. correctness3: float @@ -198,27 +218,30 @@ def compute_homography(data, keep_k_points=1000): correctness5: float correctness5 metric. """ - shape = data['image_shape'] - real_H = data['homography'] + shape = data["image_shape"] + real_H = data["homography"] # Filter out predictions - keypoints = data['prob'][:, :2].T + keypoints = data["prob"][:, :2].T keypoints = keypoints[::-1] - prob = data['prob'][:, 2] + prob = data["prob"][:, 2] keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1) - warped_keypoints = data['warped_prob'][:, :2].T + warped_keypoints = data["warped_prob"][:, :2].T warped_keypoints = warped_keypoints[::-1] - warped_prob = data['warped_prob'][:, 2] - warped_keypoints = np.stack([warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1) + warped_prob = data["warped_prob"][:, 2] + warped_keypoints = np.stack( + [warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1 + ) + + desc = data["desc"] + warped_desc = data["warped_desc"] - desc = data['desc'] - warped_desc = data['warped_desc'] - # Keeps only the points shared between the two views keypoints, desc = keep_shared_points(keypoints, desc, real_H, shape, keep_k_points) - warped_keypoints, warped_desc = keep_shared_points(warped_keypoints, warped_desc, np.linalg.inv(real_H), shape, - keep_k_points) + warped_keypoints, warped_desc = keep_shared_points( + warped_keypoints, warped_desc, np.linalg.inv(real_H), shape, keep_k_points + ) bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) matches = bf.match(desc, warped_desc) @@ -228,8 +251,13 @@ def compute_homography(data, keep_k_points=1000): m_warped_keypoints = warped_keypoints[matches_idx, :] # Estimate the homography between the matches using RANSAC - H, _ = cv2.findHomography(m_keypoints[:, [1, 0]], - m_warped_keypoints[:, [1, 0]], cv2.RANSAC, 3, maxIters=5000) + H, _ = cv2.findHomography( + m_keypoints[:, [1, 0]], + m_warped_keypoints[:, [1, 0]], + cv2.RANSAC, + 3, + maxIters=5000, + ) if H is None: return 0, 0, 0 @@ -237,15 +265,19 @@ def compute_homography(data, keep_k_points=1000): shape = shape[::-1] # Compute correctness - corners = np.array([[0, 0, 1], - [0, shape[1] - 1, 1], - [shape[0] - 1, 0, 1], - [shape[0] - 1, shape[1] - 1, 1]]) + corners = np.array( + [ + [0, 0, 1], + [0, shape[1] - 1, 1], + [shape[0] - 1, 0, 1], + [shape[0] - 1, shape[1] - 1, 1], + ] + ) real_warped_corners = np.dot(corners, np.transpose(real_H)) real_warped_corners = real_warped_corners[:, :2] / real_warped_corners[:, 2:] warped_corners = np.dot(corners, np.transpose(H)) warped_corners = warped_corners[:, :2] / warped_corners[:, 2:] - + mean_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1)) correctness1 = float(mean_dist <= 1) correctness3 = float(mean_dist <= 3) diff --git a/third_party/lanet/evaluation/detector_evaluation.py b/third_party/lanet/evaluation/detector_evaluation.py index ccc8792d17a6fbb6b446f0f9f84a2b82e3cdb57c..7198eaec0e6042baf111208885f4040311cc605e 100644 --- a/third_party/lanet/evaluation/detector_evaluation.py +++ b/third_party/lanet/evaluation/detector_evaluation.py @@ -33,7 +33,7 @@ def compute_repeatability(data, keep_k_points=300, distance_thresh=3): Distance threshold in pixels for a corresponding keypoint to be considered a correct match. 
Returns - ------- + ------- N1: int Number of true keypoints in the first image. N2: int @@ -43,47 +43,59 @@ def compute_repeatability(data, keep_k_points=300, distance_thresh=3): loc_err: float Keypoint localization error. """ + def filter_keypoints(points, shape): - """ Keep only the points whose coordinates are inside the dimensions of shape. """ - mask = (points[:, 0] >= 0) & (points[:, 0] < shape[0]) &\ - (points[:, 1] >= 0) & (points[:, 1] < shape[1]) + """Keep only the points whose coordinates are inside the dimensions of shape.""" + mask = ( + (points[:, 0] >= 0) + & (points[:, 0] < shape[0]) + & (points[:, 1] >= 0) + & (points[:, 1] < shape[1]) + ) return points[mask, :] def keep_true_keypoints(points, H, shape): - """ Keep only the points whose warped coordinates by H are still inside shape. """ + """Keep only the points whose warped coordinates by H are still inside shape.""" warped_points = warp_keypoints(points[:, [1, 0]], H) warped_points[:, [0, 1]] = warped_points[:, [1, 0]] - mask = (warped_points[:, 0] >= 0) & (warped_points[:, 0] < shape[0]) &\ - (warped_points[:, 1] >= 0) & (warped_points[:, 1] < shape[1]) + mask = ( + (warped_points[:, 0] >= 0) + & (warped_points[:, 0] < shape[0]) + & (warped_points[:, 1] >= 0) + & (warped_points[:, 1] < shape[1]) + ) return points[mask, :] - def select_k_best(points, k): - """ Select the k most probable points (and strip their probability). - points has shape (num_points, 3) where the last coordinate is the probability. """ + """Select the k most probable points (and strip their probability). + points has shape (num_points, 3) where the last coordinate is the probability.""" sorted_prob = points[points[:, 2].argsort(), :2] start = min(k, points.shape[0]) return sorted_prob[-start:, :] - H = data['homography'] - shape = data['image_shape'] + H = data["homography"] + shape = data["image_shape"] # # Filter out predictions - keypoints = data['prob'][:, :2].T + keypoints = data["prob"][:, :2].T keypoints = keypoints[::-1] - prob = data['prob'][:, 2] + prob = data["prob"][:, 2] - warped_keypoints = data['warped_prob'][:, :2].T + warped_keypoints = data["warped_prob"][:, :2].T warped_keypoints = warped_keypoints[::-1] - warped_prob = data['warped_prob'][:, 2] + warped_prob = data["warped_prob"][:, 2] keypoints = np.stack([keypoints[0], keypoints[1]], axis=-1) - warped_keypoints = np.stack([warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1) + warped_keypoints = np.stack( + [warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1 + ) warped_keypoints = keep_true_keypoints(warped_keypoints, np.linalg.inv(H), shape) # Warp the original keypoints with the true homography true_warped_keypoints = warp_keypoints(keypoints[:, [1, 0]], H) - true_warped_keypoints = np.stack([true_warped_keypoints[:, 1], true_warped_keypoints[:, 0], prob], axis=-1) + true_warped_keypoints = np.stack( + [true_warped_keypoints[:, 1], true_warped_keypoints[:, 0], prob], axis=-1 + ) true_warped_keypoints = filter_keypoints(true_warped_keypoints, shape) # Keep only the keep_k_points best predictions @@ -103,12 +115,12 @@ def compute_repeatability(data, keep_k_points=300, distance_thresh=3): le2 = 0 if N2 != 0: min1 = np.min(norm, axis=1) - correct1 = (min1 <= distance_thresh) + correct1 = min1 <= distance_thresh count1 = np.sum(correct1) le1 = min1[correct1].sum() if N1 != 0: min2 = np.min(norm, axis=0) - correct2 = (min2 <= distance_thresh) + correct2 = min2 <= distance_thresh count2 = np.sum(correct2) le2 = min2[correct2].sum() if N1 + N2 > 0: diff 
--git a/third_party/lanet/evaluation/evaluate.py b/third_party/lanet/evaluation/evaluate.py index fa9e91ee6d9cc0142ebbe8f2a3f904f6fae8434c..06bec8e5e01b8d285622e6c1eca9000f2a0541cb 100644 --- a/third_party/lanet/evaluation/evaluate.py +++ b/third_party/lanet/evaluation/evaluate.py @@ -5,24 +5,25 @@ import torch import torchvision.transforms as transforms from tqdm import tqdm -from evaluation.descriptor_evaluation import (compute_homography, - compute_matching_score) +from evaluation.descriptor_evaluation import compute_homography, compute_matching_score from evaluation.detector_evaluation import compute_repeatability -def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), top_k=300): - """Keypoint net evaluation script. +def evaluate_keypoint_net( + data_loader, keypoint_net, output_shape=(320, 240), top_k=300 +): + """Keypoint net evaluation script. Parameters ---------- data_loader: torch.utils.data.DataLoader - Dataset loader. + Dataset loader. keypoint_net: torch.nn.module Keypoint network. output_shape: tuple Original image shape. top_k: int - Number of keypoints to use to compute metrics, selected based on probability. + Number of keypoints to use to compute metrics, selected based on probability. use_color: bool Use color or grayscale images. """ @@ -36,8 +37,8 @@ def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), to with torch.no_grad(): for i, sample in tqdm(enumerate(data_loader), desc="Evaluate point model"): - image = sample['image'].cuda() - warped_image = sample['warped_image'].cuda() + image = sample["image"].cuda() + warped_image = sample["warped_image"].cuda() score_1, coord_1, desc1 = keypoint_net(image) score_2, coord_2, desc2 = keypoint_net(warped_image) @@ -48,7 +49,7 @@ def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), to score_2 = torch.cat([coord_2, score_2], dim=1).view(3, -1).t().cpu().numpy() desc1 = desc1.view(256, Hc, Wc).view(256, -1).t().cpu().numpy() desc2 = desc2.view(256, Hc, Wc).view(256, -1).t().cpu().numpy() - + # Filter based on confidence threshold desc1 = desc1[score_1[:, 2] > conf_threshold, :] desc2 = desc2[score_2[:, 2] > conf_threshold, :] @@ -56,17 +57,21 @@ def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), to score_2 = score_2[score_2[:, 2] > conf_threshold, :] # Prepare data for eval - data = {'image': sample['image'].numpy().squeeze(), - 'image_shape' : output_shape[::-1], - 'warped_image': sample['warped_image'].numpy().squeeze(), - 'homography': sample['homography'].squeeze().numpy(), - 'prob': score_1, - 'warped_prob': score_2, - 'desc': desc1, - 'warped_desc': desc2} - + data = { + "image": sample["image"].numpy().squeeze(), + "image_shape": output_shape[::-1], + "warped_image": sample["warped_image"].numpy().squeeze(), + "homography": sample["homography"].squeeze().numpy(), + "prob": score_1, + "warped_prob": score_2, + "desc": desc1, + "warped_desc": desc2, + } + # Compute repeatabilty and localization error - _, _, rep, loc_err = compute_repeatability(data, keep_k_points=top_k, distance_thresh=3) + _, _, rep, loc_err = compute_repeatability( + data, keep_k_points=top_k, distance_thresh=3 + ) repeatability.append(rep) localization_err.append(loc_err) @@ -80,5 +85,11 @@ def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), to mscore = compute_matching_score(data, keep_k_points=top_k) MScore.append(mscore) - return np.mean(repeatability), np.mean(localization_err), \ - np.mean(correctness1), 
np.mean(correctness3), np.mean(correctness5), np.mean(MScore) + return ( + np.mean(repeatability), + np.mean(localization_err), + np.mean(correctness1), + np.mean(correctness3), + np.mean(correctness5), + np.mean(MScore), + ) diff --git a/third_party/lanet/loss_function.py b/third_party/lanet/loss_function.py index 2e74cf2b53af3c3fc26c34394df4cfe538b3b49c..b5a40c3a969f8e7725e2f30d453762a0eca6b062 100644 --- a/third_party/lanet/loss_function.py +++ b/third_party/lanet/loss_function.py @@ -1,6 +1,9 @@ import torch -def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False): + +def build_descriptor_loss( + source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False +): """ Desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf. @@ -10,12 +13,12 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re Source image descriptors. target_des: torch.Tensor (B,256,H/8,W/8) Target image descriptors. - source_points: torch.Tensor (B,H/8,W/8,2) - Source image keypoints + source_points: torch.Tensor (B,H/8,W/8,2) + Source image keypoints tar_points: torch.Tensor (B,H/8,W/8,2) - Target image keypoints + Target image keypoints tar_points_un: torch.Tensor (B,2,H/8,W/8) - Target image keypoints unnormalized + Target image keypoints unnormalized eval_only: bool Computes only recall without the loss. Returns @@ -28,11 +31,11 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re device = source_des.device loss = 0 batch_size = source_des.size(0) - recall = 0. + recall = 0.0 relax_field_size = [relax_field] - margins = [1.0] - weights = [1.0] + margins = [1.0] + weights = [1.0] isource_dense = top_kk is None @@ -50,7 +53,7 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re continue ref_desc = source_des[b_id].squeeze()[:, top_k] - tar_desc = target_des[b_id].squeeze()[:, top_k] + tar_desc = target_des[b_id].squeeze()[:, top_k] tar_points_raw = tar_points_un[b_id][:, top_k] # Compute dense descriptor distance matrix and find nearest neighbor @@ -61,7 +64,6 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1)) _, idx = torch.sort(dmat, dim=1) - # Compute triplet loss and recall for pyramid in range(len(relax_field_size)): @@ -74,24 +76,41 @@ def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, re tru_y = tar_points_raw[1] if pyramid == 0: - correct2 = (abs(match_k_x[0]-tru_x) == 0) & (abs(match_k_y[0]-tru_y) == 0) + correct2 = (abs(match_k_x[0] - tru_x) == 0) & ( + abs(match_k_y[0] - tru_y) == 0 + ) correct2_cnt = correct2.float().sum() - recall += float(1.0 / batch_size) * (float(correct2_cnt) / float( ref_desc.size(1))) + recall += float(1.0 / batch_size) * ( + float(correct2_cnt) / float(ref_desc.size(1)) + ) if eval_only: continue - correct_k = (abs(match_k_x - tru_x) <= relax_field_size[pyramid]) & (abs(match_k_y - tru_y) <= relax_field_size[pyramid]) - - incorrect_index = torch.arange(start=correct_k.shape[0]-1, end=-1, step=-1).unsqueeze(1).repeat(1,correct_k.shape[1]).to(device) - incorrect_first = torch.argmax(incorrect_index * (1 - correct_k.long()), dim=0) - - incorrect_first_index = candidates.gather(0, incorrect_first.unsqueeze(0)).squeeze() + correct_k = (abs(match_k_x - tru_x) <= relax_field_size[pyramid]) & ( + abs(match_k_y - tru_y) <= relax_field_size[pyramid] + ) + + incorrect_index = ( + 
torch.arange(start=correct_k.shape[0] - 1, end=-1, step=-1) + .unsqueeze(1) + .repeat(1, correct_k.shape[1]) + .to(device) + ) + incorrect_first = torch.argmax( + incorrect_index * (1 - correct_k.long()), dim=0 + ) + + incorrect_first_index = candidates.gather( + 0, incorrect_first.unsqueeze(0) + ).squeeze() anchor_var = ref_desc posource_var = tar_desc neg_var = tar_desc[:, incorrect_first_index] - loss += float(1.0 / batch_size) * torch.nn.functional.triplet_margin_loss(anchor_var.t(), posource_var.t(), neg_var.t(), margin=margins[pyramid]).mul(weights[pyramid]) + loss += float(1.0 / batch_size) * torch.nn.functional.triplet_margin_loss( + anchor_var.t(), posource_var.t(), neg_var.t(), margin=margins[pyramid] + ).mul(weights[pyramid]) return loss, recall @@ -100,57 +119,108 @@ class KeypointLoss(object): """ Loss function class encapsulating the location loss, the descriptor loss, and the score loss. """ + def __init__(self, config): self.score_weight = config.score_weight self.loc_weight = config.loc_weight self.desc_weight = config.desc_weight self.corres_weight = config.corres_weight self.corres_threshold = config.corres_threshold - + def __call__(self, data): - B, _, hc, wc = data['source_score'].shape - - loc_mat_abs = torch.abs(data['target_coord_warped'].view(B, 2, -1).unsqueeze(3) - data['target_coord'].view(B, 2, -1).unsqueeze(2)) + B, _, hc, wc = data["source_score"].shape + + loc_mat_abs = torch.abs( + data["target_coord_warped"].view(B, 2, -1).unsqueeze(3) + - data["target_coord"].view(B, 2, -1).unsqueeze(2) + ) l2_dist_loc_mat = torch.norm(loc_mat_abs, p=2, dim=1) l2_dist_loc_min, l2_dist_loc_min_index = l2_dist_loc_mat.min(dim=2) # construct pseudo ground truth matching matrix - loc_min_mat = torch.repeat_interleave(l2_dist_loc_min.unsqueeze(dim=-1), repeats=l2_dist_loc_mat.shape[-1], dim=-1) - pos_mask = l2_dist_loc_mat.eq(loc_min_mat) & l2_dist_loc_mat.le(1.) - neg_mask = l2_dist_loc_mat.ge(4.) - - pos_corres = - torch.log(data['confidence_matrix'][pos_mask]) - neg_corres = - torch.log(1.0 - data['confidence_matrix'][neg_mask]) + loc_min_mat = torch.repeat_interleave( + l2_dist_loc_min.unsqueeze(dim=-1), repeats=l2_dist_loc_mat.shape[-1], dim=-1 + ) + pos_mask = l2_dist_loc_mat.eq(loc_min_mat) & l2_dist_loc_mat.le(1.0) + neg_mask = l2_dist_loc_mat.ge(4.0) + + pos_corres = -torch.log(data["confidence_matrix"][pos_mask]) + neg_corres = -torch.log(1.0 - data["confidence_matrix"][neg_mask]) corres_loss = pos_corres.mean() + 5e5 * neg_corres.mean() # corresponding distance threshold is 4 - dist_norm_valid_mask = l2_dist_loc_min.lt(self.corres_threshold) & data['border_mask'].view(B, hc * wc) - + dist_norm_valid_mask = l2_dist_loc_min.lt(self.corres_threshold) & data[ + "border_mask" + ].view(B, hc * wc) + # location loss loc_loss = l2_dist_loc_min[dist_norm_valid_mask].mean() - + # desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf. 
- desc_loss, _ = build_descriptor_loss(data['source_desc'], data['target_desc_warped'], data['target_coord_warped'].detach(), top_kk=data['border_mask'], relax_field=8) - + desc_loss, _ = build_descriptor_loss( + data["source_desc"], + data["target_desc_warped"], + data["target_coord_warped"].detach(), + top_kk=data["border_mask"], + relax_field=8, + ) + # score loss - target_score_associated = data['target_score'].view(B, hc * wc).gather(1, l2_dist_loc_min_index).view(B, hc, wc).unsqueeze(1) - dist_norm_valid_mask = dist_norm_valid_mask.view(B, hc, wc).unsqueeze(1) & data['border_mask'].unsqueeze(1) + target_score_associated = ( + data["target_score"] + .view(B, hc * wc) + .gather(1, l2_dist_loc_min_index) + .view(B, hc, wc) + .unsqueeze(1) + ) + dist_norm_valid_mask = dist_norm_valid_mask.view(B, hc, wc).unsqueeze(1) & data[ + "border_mask" + ].unsqueeze(1) l2_dist_loc_min = l2_dist_loc_min.view(B, hc, wc).unsqueeze(1) loc_err = l2_dist_loc_min[dist_norm_valid_mask] - + # repeatable_constrain in score loss - repeatable_constrain = ((target_score_associated[dist_norm_valid_mask] + data['source_score'][dist_norm_valid_mask]) * (loc_err - loc_err.mean())).mean() + repeatable_constrain = ( + ( + target_score_associated[dist_norm_valid_mask] + + data["source_score"][dist_norm_valid_mask] + ) + * (loc_err - loc_err.mean()) + ).mean() # consistent_constrain in score_loss - consistent_constrain = torch.nn.functional.mse_loss(data['target_score_warped'][data['border_mask'].unsqueeze(1)], data['source_score'][data['border_mask'].unsqueeze(1)]).mean() * 2 - aware_consistent_loss = torch.nn.functional.mse_loss(data['target_aware_warped'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)], data['source_aware'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)]).mean() * 2 - - score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss - - loss = self.loc_weight * loc_loss + self.desc_weight * desc_loss + self.score_weight * score_loss + self.corres_weight * corres_loss - - return loss, self.loc_weight * loc_loss, self.desc_weight * desc_loss, self.score_weight * score_loss, self.corres_weight * corres_loss - - + consistent_constrain = ( + torch.nn.functional.mse_loss( + data["target_score_warped"][data["border_mask"].unsqueeze(1)], + data["source_score"][data["border_mask"].unsqueeze(1)], + ).mean() + * 2 + ) + aware_consistent_loss = ( + torch.nn.functional.mse_loss( + data["target_aware_warped"][ + data["border_mask"].unsqueeze(1).repeat(1, 2, 1, 1) + ], + data["source_aware"][ + data["border_mask"].unsqueeze(1).repeat(1, 2, 1, 1) + ], + ).mean() + * 2 + ) + score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss + loss = ( + self.loc_weight * loc_loss + + self.desc_weight * desc_loss + + self.score_weight * score_loss + + self.corres_weight * corres_loss + ) + + return ( + loss, + self.loc_weight * loc_loss, + self.desc_weight * desc_loss, + self.score_weight * score_loss, + self.corres_weight * corres_loss, + ) diff --git a/third_party/lanet/main.py b/third_party/lanet/main.py index 2aa81d8104c19ea1d8c4ce7d1dd547f8b35a4a72..b48dc074a2fd6d4240e126268bcd8e0d8d313d1c 100644 --- a/third_party/lanet/main.py +++ b/third_party/lanet/main.py @@ -5,6 +5,7 @@ from config import get_config from utils import prepare_dirs from data_loader import get_data_loader + def main(config): # ensure directories are setup prepare_dirs(config) @@ -20,6 +21,7 @@ def main(config): trainer = Trainer(config, train_loader=train_loader) trainer.train() -if __name__ == 
'__main__': + +if __name__ == "__main__": config, unparsed = get_config() - main(config) \ No newline at end of file + main(config) diff --git a/third_party/lanet/network_v0/model.py b/third_party/lanet/network_v0/model.py index 564000330ddd5e9f18821e8606d23cd12dc847bc..6f22e015449dd7bcc8e060a2cd72a794befd2ccb 100644 --- a/third_party/lanet/network_v0/model.py +++ b/third_party/lanet/network_v0/model.py @@ -4,6 +4,7 @@ import torchvision.transforms as tvf from .modules import InterestPointModule, CorrespondenceModule + def warp_homography_batch(sources, homographies): """ Batch warp keypoints given homographies. From https://github.com/TRI-ML/KP2D. @@ -24,18 +25,19 @@ def warp_homography_batch(sources, homographies): warped_sources = [] for b in range(B): source = sources[b].clone() - source = source.view(-1,2) - ''' + source = source.view(-1, 2) + """ [X, [M11, M12, M13 [x, M11*x + M12*y + M13 [M11, M12 [M13, Y, = M21, M22, M23 * y, = M21*x + M22*y + M23 = [x, y] * M21, M22 + M23, Z] M31, M32, M33] 1] M31*x + M32*y + M33 M31, M32].T M33] - ''' - source = torch.addmm(homographies[b,:,2], source, homographies[b,:,:2].t()) - source.mul_(1/source[:,2].unsqueeze(1)) - source = source[:,:2].contiguous().view(H,W,2) + """ + source = torch.addmm(homographies[b, :, 2], source, homographies[b, :, :2].t()) + source.mul_(1 / source[:, 2].unsqueeze(1)) + source = source[:, :2].contiguous().view(H, W, 2) warped_sources.append(source) return torch.stack(warped_sources, dim=0) - + + class PointModel(nn.Module): def __init__(self, is_test=True): super(PointModel, self).__init__() @@ -43,7 +45,7 @@ class PointModel(nn.Module): self.interestpoint_module = InterestPointModule(is_test=self.is_test) self.correspondence_module = CorrespondenceModule() self.norm_rgb = tvf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225]) - + def forward(self, *args): if self.is_test: img = args[0] @@ -51,8 +53,12 @@ class PointModel(nn.Module): score, coord, desc = self.interestpoint_module(img) return score, coord, desc else: - source_score, source_coord, source_desc_block = self.interestpoint_module(args[0]) - target_score, target_coord, target_desc_block = self.interestpoint_module(args[1]) + source_score, source_coord, source_desc_block = self.interestpoint_module( + args[0] + ) + target_score, target_coord, target_desc_block = self.interestpoint_module( + args[1] + ) B, _, H, W = args[0].shape B, _, hc, wc = source_score.shape @@ -60,21 +66,33 @@ class PointModel(nn.Module): # Normalize the coordinates from ([0, h], [0, w]) to ([0, 1], [0, 1]). source_coord_norm = source_coord.clone() - source_coord_norm[:, 0] = (source_coord_norm[:, 0] / (float(W - 1) / 2.)) - 1. - source_coord_norm[:, 1] = (source_coord_norm[:, 1] / (float(H - 1) / 2.)) - 1. + source_coord_norm[:, 0] = ( + source_coord_norm[:, 0] / (float(W - 1) / 2.0) + ) - 1.0 + source_coord_norm[:, 1] = ( + source_coord_norm[:, 1] / (float(H - 1) / 2.0) + ) - 1.0 source_coord_norm = source_coord_norm.permute(0, 2, 3, 1) target_coord_norm = target_coord.clone() - target_coord_norm[:, 0] = (target_coord_norm[:, 0] / (float(W - 1) / 2.)) - 1. - target_coord_norm[:, 1] = (target_coord_norm[:, 1] / (float(H - 1) / 2.)) - 1. 
+ target_coord_norm[:, 0] = ( + target_coord_norm[:, 0] / (float(W - 1) / 2.0) + ) - 1.0 + target_coord_norm[:, 1] = ( + target_coord_norm[:, 1] / (float(H - 1) / 2.0) + ) - 1.0 target_coord_norm = target_coord_norm.permute(0, 2, 3, 1) - + target_coord_warped_norm = warp_homography_batch(source_coord_norm, args[2]) target_coord_warped = target_coord_warped_norm.clone() - + # de-normlize the coordinates - target_coord_warped[:, :, :, 0] = (target_coord_warped[:, :, :, 0] + 1) * (float(W - 1) / 2.) - target_coord_warped[:, :, :, 1] = (target_coord_warped[:, :, :, 1] + 1) * (float(H - 1) / 2.) + target_coord_warped[:, :, :, 0] = (target_coord_warped[:, :, :, 0] + 1) * ( + float(W - 1) / 2.0 + ) + target_coord_warped[:, :, :, 1] = (target_coord_warped[:, :, :, 1] + 1) * ( + float(H - 1) / 2.0 + ) target_coord_warped = target_coord_warped.permute(0, 3, 1, 2) # Border mask @@ -85,44 +103,79 @@ class PointModel(nn.Module): border_mask_ori[:, :, wc - 1] = 0 border_mask_ori = border_mask_ori.gt(1e-3).to(device) - oob_mask2 = target_coord_warped_norm[:, :, :, 0].lt(1) & target_coord_warped_norm[:, :, :, 0].gt(-1) & target_coord_warped_norm[:, :, :, 1].lt(1) & target_coord_warped_norm[:, :, :, 1].gt(-1) + oob_mask2 = ( + target_coord_warped_norm[:, :, :, 0].lt(1) + & target_coord_warped_norm[:, :, :, 0].gt(-1) + & target_coord_warped_norm[:, :, :, 1].lt(1) + & target_coord_warped_norm[:, :, :, 1].gt(-1) + ) border_mask = border_mask_ori & oob_mask2 # score - target_score_warped = torch.nn.functional.grid_sample(target_score, target_coord_warped_norm.detach(), align_corners=False) + target_score_warped = torch.nn.functional.grid_sample( + target_score, target_coord_warped_norm.detach(), align_corners=False + ) # descriptor - source_desc2 = torch.nn.functional.grid_sample(source_desc_block[0], source_coord_norm.detach()) - source_desc3 = torch.nn.functional.grid_sample(source_desc_block[1], source_coord_norm.detach()) + source_desc2 = torch.nn.functional.grid_sample( + source_desc_block[0], source_coord_norm.detach() + ) + source_desc3 = torch.nn.functional.grid_sample( + source_desc_block[1], source_coord_norm.detach() + ) source_aware = source_desc_block[2] - source_desc = torch.mul(source_desc2, source_aware[:, 0, :, :].unsqueeze(1).contiguous()) + torch.mul(source_desc3, source_aware[:, 1, :, :].unsqueeze(1).contiguous()) + source_desc = torch.mul( + source_desc2, source_aware[:, 0, :, :].unsqueeze(1).contiguous() + ) + torch.mul( + source_desc3, source_aware[:, 1, :, :].unsqueeze(1).contiguous() + ) - target_desc2 = torch.nn.functional.grid_sample(target_desc_block[0], target_coord_norm.detach()) - target_desc3 = torch.nn.functional.grid_sample(target_desc_block[1], target_coord_norm.detach()) + target_desc2 = torch.nn.functional.grid_sample( + target_desc_block[0], target_coord_norm.detach() + ) + target_desc3 = torch.nn.functional.grid_sample( + target_desc_block[1], target_coord_norm.detach() + ) target_aware = target_desc_block[2] - target_desc = torch.mul(target_desc2, target_aware[:, 0, :, :].unsqueeze(1).contiguous()) + torch.mul(target_desc3, target_aware[:, 1, :, :].unsqueeze(1).contiguous()) - - target_desc2_warped = torch.nn.functional.grid_sample(target_desc_block[0], target_coord_warped_norm.detach()) - target_desc3_warped = torch.nn.functional.grid_sample(target_desc_block[1], target_coord_warped_norm.detach()) - target_aware_warped = torch.nn.functional.grid_sample(target_desc_block[2], target_coord_warped_norm.detach()) - target_desc_warped = torch.mul(target_desc2_warped, 
target_aware_warped[:, 0, :, :].unsqueeze(1).contiguous()) + torch.mul(target_desc3_warped, target_aware_warped[:, 1, :, :].unsqueeze(1).contiguous()) - + target_desc = torch.mul( + target_desc2, target_aware[:, 0, :, :].unsqueeze(1).contiguous() + ) + torch.mul( + target_desc3, target_aware[:, 1, :, :].unsqueeze(1).contiguous() + ) + + target_desc2_warped = torch.nn.functional.grid_sample( + target_desc_block[0], target_coord_warped_norm.detach() + ) + target_desc3_warped = torch.nn.functional.grid_sample( + target_desc_block[1], target_coord_warped_norm.detach() + ) + target_aware_warped = torch.nn.functional.grid_sample( + target_desc_block[2], target_coord_warped_norm.detach() + ) + target_desc_warped = torch.mul( + target_desc2_warped, + target_aware_warped[:, 0, :, :].unsqueeze(1).contiguous(), + ) + torch.mul( + target_desc3_warped, + target_aware_warped[:, 1, :, :].unsqueeze(1).contiguous(), + ) + confidence_matrix = self.correspondence_module(source_desc, target_desc) confidence_matrix = torch.clamp(confidence_matrix, 1e-12, 1 - 1e-12) - + output = { - 'source_score': source_score, - 'source_coord': source_coord, - 'source_desc': source_desc, - 'source_aware': source_aware, - 'target_score': target_score, - 'target_coord': target_coord, - 'target_score_warped': target_score_warped, - 'target_coord_warped': target_coord_warped, - 'target_desc_warped': target_desc_warped, - 'target_aware_warped': target_aware_warped, - 'border_mask': border_mask, - 'confidence_matrix': confidence_matrix + "source_score": source_score, + "source_coord": source_coord, + "source_desc": source_desc, + "source_aware": source_aware, + "target_score": target_score, + "target_coord": target_coord, + "target_score_warped": target_score_warped, + "target_coord_warped": target_coord_warped, + "target_desc_warped": target_desc_warped, + "target_aware_warped": target_aware_warped, + "border_mask": border_mask, + "confidence_matrix": confidence_matrix, } - + return output diff --git a/third_party/lanet/network_v0/modules.py b/third_party/lanet/network_v0/modules.py index a38c53133aff8769f267cc054174361296cb3e7d..1e5410d4340369e1d701cfc65cf6e168e776d1f9 100644 --- a/third_party/lanet/network_v0/modules.py +++ b/third_party/lanet/network_v0/modules.py @@ -4,30 +4,53 @@ import torch.nn.functional as F from utils import image_grid + class ConvBlock(nn.Module): def __init__(self, in_channels, out_channels): super(ConvBlock, self).__init__() - + self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) - + def forward(self, x): return self.conv(x) - + class DilationConv3x3(nn.Module): def __init__(self, in_channels, out_channels): super(DilationConv3x3, self).__init__() - - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=2, dilation=2, bias=False) + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=2, + dilation=2, + bias=False, + ) self.bn = nn.BatchNorm2d(out_channels) - + def forward(self, x): x = self.conv(x) x = self.bn(x) @@ -38,22 +61,26 @@ class 
InterestPointModule(nn.Module): def __init__(self, is_test=False): super(InterestPointModule, self).__init__() self.is_test = is_test - + self.conv1 = ConvBlock(3, 32) self.conv2 = ConvBlock(32, 64) self.conv3 = ConvBlock(64, 128) self.conv4 = ConvBlock(128, 256) - + self.maxpool2x2 = nn.MaxPool2d(2, 2) - + # score head - self.score_conv = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False) + self.score_conv = nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=False + ) self.score_norm = nn.BatchNorm2d(256) self.score_out = nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1) self.softmax = nn.Softmax(dim=1) - + # location head - self.loc_conv = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False) + self.loc_conv = nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=False + ) self.loc_norm = nn.BatchNorm2d(256) self.loc_out = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) @@ -63,9 +90,9 @@ class InterestPointModule(nn.Module): # cross_head: self.shift_out = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) - + self.relu = nn.ReLU(inplace=True) - + def forward(self, x): B, _, H, W = x.shape @@ -78,12 +105,12 @@ class InterestPointModule(nn.Module): x = self.conv4(x) B, _, Hc, Wc = x.shape - + # score head score_x = self.score_out(self.relu(self.score_norm(self.score_conv(x)))) aware = self.softmax(score_x[:, 0:2, :, :]) score = score_x[:, 2, :, :].unsqueeze(1).sigmoid() - + border_mask = torch.ones(B, Hc, Wc) border_mask[:, 0] = 0 border_mask[:, Hc - 1] = 0 @@ -91,23 +118,31 @@ class InterestPointModule(nn.Module): border_mask[:, :, Wc - 1] = 0 border_mask = border_mask.unsqueeze(1) score = score * border_mask.to(score.device) - - # location head + + # location head coord_x = self.relu(self.loc_norm(self.loc_conv(x))) coord_cell = self.loc_out(coord_x).tanh() - + shift_ratio = self.shift_out(coord_x).sigmoid() * 2.0 - step = ((H/Hc)-1) / 2. - center_base = image_grid(B, Hc, Wc, - dtype=coord_cell.dtype, - device=coord_cell.device, - ones=False, normalized=False).mul(H/Hc) + step + step = ((H / Hc) - 1) / 2.0 + center_base = ( + image_grid( + B, + Hc, + Wc, + dtype=coord_cell.dtype, + device=coord_cell.device, + ones=False, + normalized=False, + ).mul(H / Hc) + + step + ) coord_un = center_base.add(coord_cell.mul(shift_ratio * step)) coord = coord_un.clone() - coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W-1) - coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H-1) + coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W - 1) + coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H - 1) # descriptor block desc_block = [] @@ -117,16 +152,20 @@ class InterestPointModule(nn.Module): if self.is_test: coord_norm = coord[:, :2].clone() - coord_norm[:, 0] = (coord_norm[:, 0] / (float(W-1)/2.)) - 1. - coord_norm[:, 1] = (coord_norm[:, 1] / (float(H-1)/2.)) - 1. + coord_norm[:, 0] = (coord_norm[:, 0] / (float(W - 1) / 2.0)) - 1.0 + coord_norm[:, 1] = (coord_norm[:, 1] / (float(H - 1) / 2.0)) - 1.0 coord_norm = coord_norm.permute(0, 2, 3, 1) - desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm) + desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm) desc3 = torch.nn.functional.grid_sample(desc_block[1], coord_norm) aware = desc_block[2] - - desc = torch.mul(desc2, aware[:, 0, :, :]) + torch.mul(desc3, aware[:, 1, :, :]) - desc = desc.div(torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1)) # Divide by norm to normalize. 
+ + desc = torch.mul(desc2, aware[:, 0, :, :]) + torch.mul( + desc3, aware[:, 1, :, :] + ) + desc = desc.div( + torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1) + ) # Divide by norm to normalize. return score, coord, desc @@ -134,25 +173,32 @@ class InterestPointModule(nn.Module): class CorrespondenceModule(nn.Module): - def __init__(self, match_type='dual_softmax'): + def __init__(self, match_type="dual_softmax"): super(CorrespondenceModule, self).__init__() self.match_type = match_type - if self.match_type == 'dual_softmax': + if self.match_type == "dual_softmax": self.temperature = 0.1 else: raise NotImplementedError() - - def forward(self, source_desc, target_desc): - b, c, h, w = source_desc.size() - - source_desc = source_desc.div(torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1)).view(b, -1, h*w) - target_desc = target_desc.div(torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1)).view(b, -1, h*w) - if self.match_type == 'dual_softmax': - sim_mat = torch.einsum("bcm, bcn -> bmn", source_desc, target_desc) / self.temperature + def forward(self, source_desc, target_desc): + b, c, h, w = source_desc.size() + + source_desc = source_desc.div( + torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1) + ).view(b, -1, h * w) + target_desc = target_desc.div( + torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1) + ).view(b, -1, h * w) + + if self.match_type == "dual_softmax": + sim_mat = ( + torch.einsum("bcm, bcn -> bmn", source_desc, target_desc) + / self.temperature + ) confidence_matrix = F.softmax(sim_mat, 1) * F.softmax(sim_mat, 2) else: raise NotImplementedError() - - return confidence_matrix \ No newline at end of file + + return confidence_matrix diff --git a/third_party/lanet/network_v1/model.py b/third_party/lanet/network_v1/model.py index baeb37c563852340fe9278ed5c2dccea4b3b693a..51ca366db1d8afd76722f5c51ccfbf8b081c61e2 100644 --- a/third_party/lanet/network_v1/model.py +++ b/third_party/lanet/network_v1/model.py @@ -4,6 +4,7 @@ import torchvision.transforms as tvf from .modules import InterestPointModule, CorrespondenceModule + def warp_homography_batch(sources, homographies): """ Batch warp keypoints given homographies. From https://github.com/TRI-ML/KP2D. 
@@ -24,27 +25,29 @@ def warp_homography_batch(sources, homographies): warped_sources = [] for b in range(B): source = sources[b].clone() - source = source.view(-1,2) - ''' + source = source.view(-1, 2) + """ [X, [M11, M12, M13 [x, M11*x + M12*y + M13 [M11, M12 [M13, Y, = M21, M22, M23 * y, = M21*x + M22*y + M23 = [x, y] * M21, M22 + M23, Z] M31, M32, M33] 1] M31*x + M32*y + M33 M31, M32].T M33] - ''' - source = torch.addmm(homographies[b,:,2], source, homographies[b,:,:2].t()) - source.mul_(1/source[:,2].unsqueeze(1)) - source = source[:,:2].contiguous().view(H,W,2) + """ + source = torch.addmm(homographies[b, :, 2], source, homographies[b, :, :2].t()) + source.mul_(1 / source[:, 2].unsqueeze(1)) + source = source[:, :2].contiguous().view(H, W, 2) warped_sources.append(source) return torch.stack(warped_sources, dim=0) - + class PointModel(nn.Module): def __init__(self, is_test=False): super(PointModel, self).__init__() self.is_test = is_test self.interestpoint_module = InterestPointModule(is_test=self.is_test) self.correspondence_module = CorrespondenceModule() - self.norm_rgb = tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - + self.norm_rgb = tvf.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + def forward(self, *args): img = args[0] img = self.norm_rgb(img) diff --git a/third_party/lanet/network_v1/modules.py b/third_party/lanet/network_v1/modules.py index 4daed5f12c40e40f6fc8347f701235e141839ada..583076eba72ea6f79f4ca55ffcef82ebbdecd91c 100644 --- a/third_party/lanet/network_v1/modules.py +++ b/third_party/lanet/network_v1/modules.py @@ -6,29 +6,53 @@ import torch.nn.functional as F from torchvision import models from utils import image_grid + class ConvBlock(nn.Module): def __init__(self, in_channels, out_channels): super(ConvBlock, self).__init__() - + self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) - + def forward(self, x): return self.conv(x) + class DilationConv3x3(nn.Module): def __init__(self, in_channels, out_channels): super(DilationConv3x3, self).__init__() - - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=2, dilation=2, bias=False) + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=2, + dilation=2, + bias=False, + ) self.bn = nn.BatchNorm2d(out_channels) - + def forward(self, x): x = self.conv(x) x = self.bn(x) @@ -43,19 +67,17 @@ class InterestPointModule(nn.Module): model = models.vgg16_bn(pretrained=True) # use the first 23 layers as encoder - self.encoder = nn.Sequential( - *list(model.features.children())[: 33] - ) - + self.encoder = nn.Sequential(*list(model.features.children())[:33]) + # score head self.score_head = nn.Sequential( nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), - nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1), ) self.softmax = nn.Softmax(dim=1) - + # location head self.loc_head = nn.Sequential( nn.Conv2d(512, 
256, kernel_size=3, stride=1, padding=1, bias=False), @@ -65,18 +87,18 @@ class InterestPointModule(nn.Module): # location out self.loc_out = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) self.shift_out = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) - + # descriptor out self.des_out2 = DilationConv3x3(128, 256) self.des_out3 = DilationConv3x3(256, 256) self.des_out4 = DilationConv3x3(512, 256) - + def forward(self, x): B, _, H, W = x.shape x = self.encoder[2](self.encoder[1](self.encoder[0](x))) x = self.encoder[5](self.encoder[4](self.encoder[3](x))) - + x = self.encoder[6](x) x = self.encoder[9](self.encoder[8](self.encoder[7](x))) x2 = self.encoder[12](self.encoder[11](self.encoder[10](x))) @@ -85,20 +107,19 @@ class InterestPointModule(nn.Module): x = self.encoder[16](self.encoder[15](self.encoder[14](x))) x = self.encoder[19](self.encoder[18](self.encoder[17](x))) x3 = self.encoder[22](self.encoder[21](self.encoder[20](x))) - + x = self.encoder[23](x3) x = self.encoder[26](self.encoder[25](self.encoder[24](x))) x = self.encoder[29](self.encoder[28](self.encoder[27](x))) x = self.encoder[32](self.encoder[31](self.encoder[30](x))) - B, _, Hc, Wc = x.shape - + # score head score_x = self.score_head(x) aware = self.softmax(score_x[:, 0:3, :, :]) score = score_x[:, 3, :, :].unsqueeze(1).sigmoid() - + border_mask = torch.ones(B, Hc, Wc) border_mask[:, 0] = 0 border_mask[:, Hc - 1] = 0 @@ -106,23 +127,31 @@ class InterestPointModule(nn.Module): border_mask[:, :, Wc - 1] = 0 border_mask = border_mask.unsqueeze(1) score = score * border_mask.to(score.device) - + # location head - coord_x = self.loc_head(x) + coord_x = self.loc_head(x) coord_cell = self.loc_out(coord_x).tanh() - + shift_ratio = self.shift_out(coord_x).sigmoid() * 2.0 - step = ((H/Hc)-1) / 2. - center_base = image_grid(B, Hc, Wc, - dtype=coord_cell.dtype, - device=coord_cell.device, - ones=False, normalized=False).mul(H/Hc) + step + step = ((H / Hc) - 1) / 2.0 + center_base = ( + image_grid( + B, + Hc, + Wc, + dtype=coord_cell.dtype, + device=coord_cell.device, + ones=False, + normalized=False, + ).mul(H / Hc) + + step + ) coord_un = center_base.add(coord_cell.mul(shift_ratio * step)) coord = coord_un.clone() - coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W-1) - coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H-1) + coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W - 1) + coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H - 1) # descriptor block desc_block = [] @@ -133,42 +162,56 @@ class InterestPointModule(nn.Module): if self.is_test: coord_norm = coord[:, :2].clone() - coord_norm[:, 0] = (coord_norm[:, 0] / (float(W-1)/2.)) - 1. - coord_norm[:, 1] = (coord_norm[:, 1] / (float(H-1)/2.)) - 1. + coord_norm[:, 0] = (coord_norm[:, 0] / (float(W - 1) / 2.0)) - 1.0 + coord_norm[:, 1] = (coord_norm[:, 1] / (float(H - 1) / 2.0)) - 1.0 coord_norm = coord_norm.permute(0, 2, 3, 1) - desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm) + desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm) desc3 = torch.nn.functional.grid_sample(desc_block[1], coord_norm) desc4 = torch.nn.functional.grid_sample(desc_block[2], coord_norm) aware = desc_block[3] - - desc = torch.mul(desc2, aware[:, 0, :, :]) + torch.mul(desc3, aware[:, 1, :, :]) + torch.mul(desc4, aware[:, 2, :, :]) - desc = desc.div(torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1)) # Divide by norm to normalize. 
+ + desc = ( + torch.mul(desc2, aware[:, 0, :, :]) + + torch.mul(desc3, aware[:, 1, :, :]) + + torch.mul(desc4, aware[:, 2, :, :]) + ) + desc = desc.div( + torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1) + ) # Divide by norm to normalize. return score, coord, desc return score, coord, desc_block + class CorrespondenceModule(nn.Module): - def __init__(self, match_type='dual_softmax'): + def __init__(self, match_type="dual_softmax"): super(CorrespondenceModule, self).__init__() self.match_type = match_type - if self.match_type == 'dual_softmax': + if self.match_type == "dual_softmax": self.temperature = 0.1 else: raise NotImplementedError() - - def forward(self, source_desc, target_desc): - b, c, h, w = source_desc.size() - - source_desc = source_desc.div(torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1)).view(b, -1, h*w) - target_desc = target_desc.div(torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1)).view(b, -1, h*w) - if self.match_type == 'dual_softmax': - sim_mat = torch.einsum("bcm, bcn -> bmn", source_desc, target_desc) / self.temperature + def forward(self, source_desc, target_desc): + b, c, h, w = source_desc.size() + + source_desc = source_desc.div( + torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1) + ).view(b, -1, h * w) + target_desc = target_desc.div( + torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1) + ).view(b, -1, h * w) + + if self.match_type == "dual_softmax": + sim_mat = ( + torch.einsum("bcm, bcn -> bmn", source_desc, target_desc) + / self.temperature + ) confidence_matrix = F.softmax(sim_mat, 1) * F.softmax(sim_mat, 2) else: raise NotImplementedError() - + return confidence_matrix diff --git a/third_party/lanet/test.py b/third_party/lanet/test.py index cc9365f5c92cbd69c3ee9250ff66b07bd1eed1c6..d54b60f6669ac02ca16aacd94bb9145050a99a05 100644 --- a/third_party/lanet/test.py +++ b/third_party/lanet/test.py @@ -14,9 +14,9 @@ from evaluation.evaluate import evaluate_keypoint_net def main(): - parser = argparse.ArgumentParser(description='Testing') - parser.add_argument('--device', default=0, type=int, help='which gpu to run on.') - parser.add_argument('--test_dir', required=True, type=str, help='Test data path.') + parser = argparse.ArgumentParser(description="Testing") + parser.add_argument("--device", default=0, type=int, help="which gpu to run on.") + parser.add_argument("--test_dir", required=True, type=str, help="Test data path.") opt = parser.parse_args() torch.manual_seed(0) @@ -25,63 +25,67 @@ def main(): torch.cuda.set_device(opt.device) # Load data in 320x240 - hp_dataset_320x240 = PatchesDataset(root_dir=opt.test_dir, use_color=True, output_shape=(320, 240), type='all') - data_loader_320x240 = DataLoader(hp_dataset_320x240, - batch_size=1, - pin_memory=False, - shuffle=False, - num_workers=4, - worker_init_fn=None, - sampler=None) + hp_dataset_320x240 = PatchesDataset( + root_dir=opt.test_dir, use_color=True, output_shape=(320, 240), type="all" + ) + data_loader_320x240 = DataLoader( + hp_dataset_320x240, + batch_size=1, + pin_memory=False, + shuffle=False, + num_workers=4, + worker_init_fn=None, + sampler=None, + ) # Load data in 640x480 - hp_dataset_640x480 = PatchesDataset(root_dir=opt.test_dir, use_color=True, output_shape=(640, 480), type='all') - data_loader_640x480 = DataLoader(hp_dataset_640x480, - batch_size=1, - pin_memory=False, - shuffle=False, - num_workers=4, - worker_init_fn=None, - sampler=None) + hp_dataset_640x480 = PatchesDataset( + root_dir=opt.test_dir, use_color=True, output_shape=(640, 480), type="all" + ) + 
data_loader_640x480 = DataLoader( + hp_dataset_640x480, + batch_size=1, + pin_memory=False, + shuffle=False, + num_workers=4, + worker_init_fn=None, + sampler=None, + ) # Load model model = PointModel(is_test=True) - ckpt = torch.load('./checkpoints/PointModel_v0.pth') - model.load_state_dict(ckpt['model_state']) + ckpt = torch.load("./checkpoints/PointModel_v0.pth") + model.load_state_dict(ckpt["model_state"]) model = model.eval() if use_gpu: model = model.cuda() - - print('Evaluating in 320x240, 300 points') + print("Evaluating in 320x240, 300 points") rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net( - data_loader_320x240, - model, - output_shape=(320, 240), - top_k=300) + data_loader_320x240, model, output_shape=(320, 240), top_k=300 + ) - print('Repeatability: {0:.3f}'.format(rep)) - print('Localization Error: {0:.3f}'.format(loc)) - print('H-1 Accuracy: {:.3f}'.format(c1)) - print('H-3 Accuracy: {:.3f}'.format(c3)) - print('H-5 Accuracy: {:.3f}'.format(c5)) - print('Matching Score: {:.3f}'.format(mscore)) - print('\n') + print("Repeatability: {0:.3f}".format(rep)) + print("Localization Error: {0:.3f}".format(loc)) + print("H-1 Accuracy: {:.3f}".format(c1)) + print("H-3 Accuracy: {:.3f}".format(c3)) + print("H-5 Accuracy: {:.3f}".format(c5)) + print("Matching Score: {:.3f}".format(mscore)) + print("\n") - print('Evaluating in 640x480, 1000 points') + print("Evaluating in 640x480, 1000 points") rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net( - data_loader_640x480, - model, - output_shape=(640, 480), - top_k=1000) + data_loader_640x480, model, output_shape=(640, 480), top_k=1000 + ) + + print("Repeatability: {0:.3f}".format(rep)) + print("Localization Error: {0:.3f}".format(loc)) + print("H-1 Accuracy: {:.3f}".format(c1)) + print("H-3 Accuracy: {:.3f}".format(c3)) + print("H-5 Accuracy: {:.3f}".format(c5)) + print("Matching Score: {:.3f}".format(mscore)) + print("\n") - print('Repeatability: {0:.3f}'.format(rep)) - print('Localization Error: {0:.3f}'.format(loc)) - print('H-1 Accuracy: {:.3f}'.format(c1)) - print('H-3 Accuracy: {:.3f}'.format(c3)) - print('H-5 Accuracy: {:.3f}'.format(c5)) - print('Matching Score: {:.3f}'.format(mscore)) - print('\n') -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/third_party/lanet/train.py b/third_party/lanet/train.py index 3076a0fdb78a59bfd64367399c0f2b0de1297653..e82900a3b27f8954c65f7bf4127f38a65ac76fff 100644 --- a/third_party/lanet/train.py +++ b/third_party/lanet/train.py @@ -8,6 +8,7 @@ from torch.autograd import Variable from network_v0.model import PointModel from loss_function import KeypointLoss + class Trainer(object): def __init__(self, config, train_loader=None): self.config = config @@ -28,56 +29,76 @@ class Trainer(object): self.random_seed = config.seed self.gpu = config.gpu self.ckpt_dir = config.ckpt_dir - self.ckpt_name = '{}-{}'.format(config.ckpt_name, config.seed) - + self.ckpt_name = "{}-{}".format(config.ckpt_name, config.seed) + # build model self.model = PointModel(is_test=False) - + # training on GPU if self.use_gpu: torch.cuda.set_device(self.gpu) self.model.cuda() - print('Number of model parameters: {:,}'.format(sum([p.data.nelement() for p in self.model.parameters()]))) - + print( + "Number of model parameters: {:,}".format( + sum([p.data.nelement() for p in self.model.parameters()]) + ) + ) + # build loss functional self.loss_func = KeypointLoss(config) - + # build optimizer and scheduler self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) - self.lr_scheduler = 
optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[4, 8], gamma=self.lr_factor) + self.lr_scheduler = optim.lr_scheduler.MultiStepLR( + self.optimizer, milestones=[4, 8], gamma=self.lr_factor + ) # resume if int(self.config.start_epoch) > 0: - self.config.start_epoch, self.model, self.optimizer, self.lr_scheduler = self.load_checkpoint(int(self.config.start_epoch), self.model, self.optimizer, self.lr_scheduler) - + ( + self.config.start_epoch, + self.model, + self.optimizer, + self.lr_scheduler, + ) = self.load_checkpoint( + int(self.config.start_epoch), + self.model, + self.optimizer, + self.lr_scheduler, + ) + def train(self): print("\nTrain on {} samples".format(self.num_train)) self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler) for epoch in range(self.start_epoch, self.max_epoch): - print("\nEpoch: {}/{} --lr: {:.6f}".format(epoch+1, self.max_epoch, self.lr)) + print( + "\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, self.lr) + ) # train for one epoch self.train_one_epoch(epoch) if self.lr_scheduler: self.lr_scheduler.step() - self.save_checkpoint(epoch+1, self.model, self.optimizer, self.lr_scheduler) - + self.save_checkpoint( + epoch + 1, self.model, self.optimizer, self.lr_scheduler + ) + def train_one_epoch(self, epoch): self.model.train() for (i, data) in enumerate(tqdm(self.train_loader)): if self.use_gpu: - source_img = data['image_aug'].cuda() - target_img = data['image'].cuda() - homography = data['homography'].cuda() - + source_img = data["image_aug"].cuda() + target_img = data["image"].cuda() + homography = data["homography"].cuda() + source_img = Variable(source_img) target_img = Variable(target_img) homography = Variable(homography) - + # forward propogation output = self.model(source_img, target_img, homography) - + # compute loss loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output) @@ -87,43 +108,45 @@ class Trainer(object): self.optimizer.step() # print training info - msg_batch = "Epoch:{} Iter:{} lr:{:.4f} "\ - "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "\ - "loss={:.4f} "\ - .format((epoch + 1), i, self.lr, loc_loss.data, desc_loss.data, score_loss.data, corres_loss.data, loss.data) + msg_batch = ( + "Epoch:{} Iter:{} lr:{:.4f} " + "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} " + "loss={:.4f} ".format( + (epoch + 1), + i, + self.lr, + loc_loss.data, + desc_loss.data, + score_loss.data, + corres_loss.data, + loss.data, + ) + ) - if((i % self.display) == 0): + if (i % self.display) == 0: print(msg_batch) return def save_checkpoint(self, epoch, model, optimizer, lr_scheduler): - filename = self.ckpt_name + '_' + str(epoch) + '.pth' + filename = self.ckpt_name + "_" + str(epoch) + ".pth" torch.save( - {'epoch': epoch, - 'model_state': model.state_dict(), - 'optimizer_state': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict()}, - os.path.join(self.ckpt_dir, filename)) + { + "epoch": epoch, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + }, + os.path.join(self.ckpt_dir, filename), + ) def load_checkpoint(self, epoch, model, optimizer, lr_scheduler): - filename = self.ckpt_name + '_' + str(epoch) + '.pth' + filename = self.ckpt_name + "_" + str(epoch) + ".pth" ckpt = torch.load(os.path.join(self.ckpt_dir, filename)) - epoch = ckpt['epoch'] - model.load_state_dict(ckpt['model_state']) - optimizer.load_state_dict(ckpt['optimizer_state']) - 
        lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
-
-        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt['epoch']))
-
-        return epoch, model, optimizer, lr_scheduler
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
+        epoch = ckpt["epoch"]
+        model.load_state_dict(ckpt["model_state"])
+        optimizer.load_state_dict(ckpt["optimizer_state"])
+        lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
+
+        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))
+
+        return epoch, model, optimizer, lr_scheduler
diff --git a/third_party/lanet/utils.py b/third_party/lanet/utils.py
index d5422ebcfc2847be047391791d891a09388ca7d1..6f1ead467c166a95e6782a8112bafe363f948f9b 100644
--- a/third_party/lanet/utils.py
+++ b/third_party/lanet/utils.py
@@ -4,6 +4,7 @@ import torch
 import torchvision.transforms as transforms
 from functools import lru_cache
 
+
 @lru_cache(maxsize=None)
 def meshgrid(B, H, W, dtype, device, normalized=False):
     """
@@ -35,8 +36,8 @@ def meshgrid(B, H, W, dtype, device, normalized=False):
         xs = torch.linspace(-1, 1, W, device=device, dtype=dtype)
         ys = torch.linspace(-1, 1, H, device=device, dtype=dtype)
     else:
-        xs = torch.linspace(0, W-1, W, device=device, dtype=dtype)
-        ys = torch.linspace(0, H-1, H, device=device, dtype=dtype)
+        xs = torch.linspace(0, W - 1, W, device=device, dtype=dtype)
+        ys = torch.linspace(0, H - 1, H, device=device, dtype=dtype)
 
     ys, xs = torch.meshgrid([ys, xs])
     return xs.repeat([B, 1, 1]), ys.repeat([B, 1, 1])
@@ -75,7 +76,8 @@ def image_grid(B, H, W, dtype, device, ones=True, normalized=False):
     grid = torch.stack(coords, dim=1)  # B3HW
     return grid
 
-def to_tensor_sample(sample, tensor_type='torch.FloatTensor'):
+
+def to_tensor_sample(sample, tensor_type="torch.FloatTensor"):
     """
     Casts the keys of sample to tensors.
     From https://github.com/TRI-ML/KP2D.
@@ -92,11 +94,11 @@ def to_tensor_sample(sample, tensor_type='torch.FloatTensor'): Sample with keys cast as tensors """ transform = transforms.ToTensor() - sample['image'] = transform(sample['image']).type(tensor_type) + sample["image"] = transform(sample["image"]).type(tensor_type) return sample + def prepare_dirs(config): for path in [config.ckpt_dir]: if not os.path.exists(path): os.makedirs(path) - diff --git a/third_party/r2d2/datasets/__init__.py b/third_party/r2d2/datasets/__init__.py index 8f11df21be72856ea365f6efd7a389aba267562b..f538fb5372197bcdba9db28c861af39c541539ee 100644 --- a/third_party/r2d2/datasets/__init__.py +++ b/third_party/r2d2/datasets/__init__.py @@ -10,6 +10,7 @@ from .aachen import * # try to instanciate datasets import sys + try: web_images = RandomWebImages(0, 52) except AssertionError as e: @@ -23,11 +24,12 @@ except AssertionError as e: try: aachen_style_transfer_pairs = AachenPairs_StyleTransferDayNight() except AssertionError as e: - print(f"Dataset aachen_style_transfer_pairs not available, reason: {e}", file=sys.stderr) + print( + f"Dataset aachen_style_transfer_pairs not available, reason: {e}", + file=sys.stderr, + ) try: aachen_flow_pairs = AachenPairs_OpticalFlow() except AssertionError as e: print(f"Dataset aachen_flow_pairs not available, reason: {e}", file=sys.stderr) - - diff --git a/third_party/r2d2/datasets/aachen.py b/third_party/r2d2/datasets/aachen.py index 4ddb324cea01da2430ee89b32c7627b34c01a41f..fbe2364a51c648ee48989f1725cf0033cd0c0547 100644 --- a/third_party/r2d2/datasets/aachen.py +++ b/third_party/r2d2/datasets/aachen.py @@ -10,61 +10,61 @@ from .dataset import Dataset from .pair_dataset import PairDataset, StillPairDataset -class AachenImages (Dataset): - """ Loads all images from the Aachen Day-Night dataset - """ - def __init__(self, select='db day night', root='data/aachen'): +class AachenImages(Dataset): + """Loads all images from the Aachen Day-Night dataset""" + + def __init__(self, select="db day night", root="data/aachen"): Dataset.__init__(self) self.root = root - self.img_dir = 'images_upright' + self.img_dir = "images_upright" self.select = set(select.split()) - assert self.select, 'Nothing was selected' - + assert self.select, "Nothing was selected" + self.imgs = [] root = os.path.join(root, self.img_dir) for dirpath, _, filenames in os.walk(root): - r = dirpath[len(root)+1:] - if not(self.select & set(r.split('/'))): continue - self.imgs += [os.path.join(r,f) for f in filenames if f.endswith('.jpg')] - + r = dirpath[len(root) + 1 :] + if not (self.select & set(r.split("/"))): + continue + self.imgs += [os.path.join(r, f) for f in filenames if f.endswith(".jpg")] + self.nimg = len(self.imgs) - assert self.nimg, 'Empty Aachen dataset' + assert self.nimg, "Empty Aachen dataset" def get_key(self, idx): return self.imgs[idx] +class AachenImages_DB(AachenImages): + """Only database (db) images.""" -class AachenImages_DB (AachenImages): - """ Only database (db) images. 
- """ def __init__(self, **kw): - AachenImages.__init__(self, select='db', **kw) - self.db_image_idxs = {self.get_tag(i) : i for i,f in enumerate(self.imgs)} - - def get_tag(self, idx): - # returns image tag == img number (name) - return os.path.split( self.imgs[idx][:-4] )[1] + AachenImages.__init__(self, select="db", **kw) + self.db_image_idxs = {self.get_tag(i): i for i, f in enumerate(self.imgs)} + def get_tag(self, idx): + # returns image tag == img number (name) + return os.path.split(self.imgs[idx][:-4])[1] -class AachenPairs_StyleTransferDayNight (AachenImages_DB, StillPairDataset): - """ synthetic day-night pairs of images - (night images obtained using autoamtic style transfer from web night images) +class AachenPairs_StyleTransferDayNight(AachenImages_DB, StillPairDataset): + """synthetic day-night pairs of images + (night images obtained using autoamtic style transfer from web night images) """ - def __init__(self, root='data/aachen/style_transfer', **kw): + + def __init__(self, root="data/aachen/style_transfer", **kw): StillPairDataset.__init__(self) AachenImages_DB.__init__(self, **kw) old_root = os.path.join(self.root, self.img_dir) self.root = os.path.commonprefix((old_root, root)) - self.img_dir = '' + self.img_dir = "" - newpath = lambda folder, f: os.path.join(folder, f)[len(self.root):] + newpath = lambda folder, f: os.path.join(folder, f)[len(self.root) :] self.imgs = [newpath(old_root, f) for f in self.imgs] self.image_pairs = [] for fname in os.listdir(root): - tag = fname.split('.jpg.st_')[0] + tag = fname.split(".jpg.st_")[0] self.image_pairs.append((self.db_image_idxs[tag], len(self.imgs))) self.imgs.append(newpath(root, fname)) @@ -73,42 +73,45 @@ class AachenPairs_StyleTransferDayNight (AachenImages_DB, StillPairDataset): assert self.nimg and self.npairs +class AachenPairs_OpticalFlow(AachenImages_DB, PairDataset): + """Image pairs from Aachen db with optical flow.""" -class AachenPairs_OpticalFlow (AachenImages_DB, PairDataset): - """ Image pairs from Aachen db with optical flow. 
- """ - def __init__(self, root='data/aachen/optical_flow', **kw): + def __init__(self, root="data/aachen/optical_flow", **kw): PairDataset.__init__(self) AachenImages_DB.__init__(self, **kw) self.root_flow = root # find out the subsest of valid pairs from the list of flow files - flows = {f for f in os.listdir(os.path.join(root, 'flow')) if f.endswith('.png')} - masks = {f for f in os.listdir(os.path.join(root, 'mask')) if f.endswith('.png')} - assert flows == masks, 'Missing flow or mask pairs' - - make_pair = lambda f: tuple(self.db_image_idxs[v] for v in f[:-4].split('_')) + flows = { + f for f in os.listdir(os.path.join(root, "flow")) if f.endswith(".png") + } + masks = { + f for f in os.listdir(os.path.join(root, "mask")) if f.endswith(".png") + } + assert flows == masks, "Missing flow or mask pairs" + + make_pair = lambda f: tuple(self.db_image_idxs[v] for v in f[:-4].split("_")) self.image_pairs = [make_pair(f) for f in flows] self.npairs = len(self.image_pairs) assert self.nimg and self.npairs def get_mask_filename(self, pair_idx): tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx]) - return os.path.join(self.root_flow, 'mask', f'{tag_a}_{tag_b}.png') + return os.path.join(self.root_flow, "mask", f"{tag_a}_{tag_b}.png") def get_mask(self, pair_idx): return np.asarray(Image.open(self.get_mask_filename(pair_idx))) def get_flow_filename(self, pair_idx): tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx]) - return os.path.join(self.root_flow, 'flow', f'{tag_a}_{tag_b}.png') + return os.path.join(self.root_flow, "flow", f"{tag_a}_{tag_b}.png") def get_flow(self, pair_idx): fname = self.get_flow_filename(pair_idx) try: return self._png2flow(fname) except IOError: - flow = open(fname[:-4], 'rb') + flow = open(fname[:-4], "rb") help = np.fromfile(flow, np.float32, 1) assert help == 202021.25 W, H = np.fromfile(flow, np.int32, 2) @@ -116,30 +119,28 @@ class AachenPairs_OpticalFlow (AachenImages_DB, PairDataset): return self._flow2png(flow, fname) def get_pair(self, idx, output=()): - if isinstance(output, str): + if isinstance(output, str): output = output.split() img1, img2 = map(self.get_image, self.image_pairs[idx]) meta = {} - - if 'flow' in output or 'aflow' in output: + + if "flow" in output or "aflow" in output: flow = self.get_flow(idx) assert flow.shape[:2] == img1.size[::-1] - meta['flow'] = flow + meta["flow"] = flow H, W = flow.shape[:2] - meta['aflow'] = flow + np.mgrid[:H,:W][::-1].transpose(1,2,0) - - if 'mask' in output: + meta["aflow"] = flow + np.mgrid[:H, :W][::-1].transpose(1, 2, 0) + + if "mask" in output: mask = self.get_mask(idx) assert mask.shape[:2] == img1.size[::-1] - meta['mask'] = mask - - return img1, img2, meta - + meta["mask"] = mask + return img1, img2, meta -if __name__ == '__main__': +if __name__ == "__main__": print(aachen_db_images) print(aachen_style_transfer_pairs) print(aachen_flow_pairs) diff --git a/third_party/r2d2/datasets/dataset.py b/third_party/r2d2/datasets/dataset.py index 80d893b8ea4ead7845f35c4fe82c9f5a9b849de3..5f4474e7dc8b81f091cac1e13f431c5c9f1840f3 100644 --- a/third_party/r2d2/datasets/dataset.py +++ b/third_party/r2d2/datasets/dataset.py @@ -9,10 +9,10 @@ import numpy as np class Dataset(object): - ''' Base class for a dataset. To be overloaded. - ''' - root = '' - img_dir = '' + """Base class for a dataset. 
To be overloaded.""" + + root = "" + img_dir = "" nimg = 0 def __len__(self): @@ -26,23 +26,23 @@ class Dataset(object): def get_image(self, img_idx): from PIL import Image + fname = self.get_filename(img_idx) try: - return Image.open(fname).convert('RGB') + return Image.open(fname).convert("RGB") except Exception as e: raise IOError("Could not load image %s (reason: %s)" % (fname, str(e))) def __repr__(self): - res = 'Dataset: %s\n' % self.__class__.__name__ - res += ' %d images' % self.nimg - res += '\n root: %s...\n' % self.root + res = "Dataset: %s\n" % self.__class__.__name__ + res += " %d images" % self.nimg + res += "\n root: %s...\n" % self.root return res +class CatDataset(Dataset): + """Concatenation of several datasets.""" -class CatDataset (Dataset): - ''' Concatenation of several datasets. - ''' def __init__(self, *datasets): assert len(datasets) >= 1 self.datasets = datasets @@ -54,8 +54,8 @@ class CatDataset (Dataset): self.root = None def which(self, i): - pos = np.searchsorted(self.offsets, i, side='right')-1 - assert pos < self.nimg, 'Bad image index %d >= %d' % (i, self.nimg) + pos = np.searchsorted(self.offsets, i, side="right") - 1 + assert pos < self.nimg, "Bad image index %d >= %d" % (i, self.nimg) return pos, i - self.offsets[pos] def get_key(self, i): @@ -69,9 +69,5 @@ class CatDataset (Dataset): def __repr__(self): fmt_str = "CatDataset(" for db in self.datasets: - fmt_str += str(db).replace("\n"," ") + ', ' - return fmt_str[:-2] + ')' - - - - + fmt_str += str(db).replace("\n", " ") + ", " + return fmt_str[:-2] + ")" diff --git a/third_party/r2d2/datasets/imgfolder.py b/third_party/r2d2/datasets/imgfolder.py index 45f7bc9ee4c3ba5f04380dbc02ad17b6463cf32f..40168f00e8ad177f3d94f75578dba2e640944c4c 100644 --- a/third_party/r2d2/datasets/imgfolder.py +++ b/third_party/r2d2/datasets/imgfolder.py @@ -8,10 +8,10 @@ from .dataset import Dataset from .pair_dataset import SyntheticPairDataset -class ImgFolder (Dataset): - """ load all images in a folder (no recursion). - """ - def __init__(self, root, imgs=None, exts=('.jpg','.png','.ppm')): +class ImgFolder(Dataset): + """load all images in a folder (no recursion).""" + + def __init__(self, root, imgs=None, exts=(".jpg", ".png", ".ppm")): Dataset.__init__(self) self.root = root self.imgs = imgs or [f for f in os.listdir(root) if f.endswith(exts)] @@ -19,5 +19,3 @@ class ImgFolder (Dataset): def get_key(self, idx): return self.imgs[idx] - - diff --git a/third_party/r2d2/datasets/pair_dataset.py b/third_party/r2d2/datasets/pair_dataset.py index aeed98b6700e0ba108bb44abccc20351d16f3295..ba178c18a0a6fbb1decfe4a797dbcab0636dbeaf 100644 --- a/third_party/r2d2/datasets/pair_dataset.py +++ b/third_party/r2d2/datasets/pair_dataset.py @@ -11,20 +11,24 @@ from tools.transforms import instanciate_transformation from tools.transforms_tools import persp_apply -class PairDataset (Dataset): - """ A dataset that serves image pairs with ground-truth pixel correspondences. 
- """ +class PairDataset(Dataset): + """A dataset that serves image pairs with ground-truth pixel correspondences.""" + def __init__(self): Dataset.__init__(self) self.npairs = 0 def get_filename(self, img_idx, root=None): - if is_pair(img_idx): # if img_idx is a pair of indices, we return a pair of filenames + if is_pair( + img_idx + ): # if img_idx is a pair of indices, we return a pair of filenames return tuple(Dataset.get_filename(self, i, root) for i in img_idx) return Dataset.get_filename(self, img_idx, root) def get_image(self, img_idx): - if is_pair(img_idx): # if img_idx is a pair of indices, we return a pair of images + if is_pair( + img_idx + ): # if img_idx is a pair of indices, we return a pair of images return tuple(Dataset.get_image(self, i) for i in img_idx) return Dataset.get_image(self, img_idx) @@ -41,8 +45,8 @@ class PairDataset (Dataset): raise NotImplementedError() def get_pair(self, idx, output=()): - """ returns (img1, img2, `metadata`) - + """returns (img1, img2, `metadata`) + `metadata` is a dict() that can contain: flow: optical flow aflow: absolute flow @@ -55,24 +59,24 @@ class PairDataset (Dataset): def get_paired_images(self): fns = set() for i in range(self.npairs): - a,b = self.image_pairs[i] + a, b = self.image_pairs[i] fns.add(self.get_filename(a)) fns.add(self.get_filename(b)) return fns def __len__(self): - return self.npairs # size should correspond to the number of pairs, not images - + return self.npairs # size should correspond to the number of pairs, not images + def __repr__(self): - res = 'Dataset: %s\n' % self.__class__.__name__ - res += ' %d images,' % self.nimg - res += ' %d image pairs' % self.npairs - res += '\n root: %s...\n' % self.root + res = "Dataset: %s\n" % self.__class__.__name__ + res += " %d images," % self.nimg + res += " %d image pairs" % self.npairs + res += "\n root: %s...\n" % self.root return res @staticmethod def _flow2png(flow, path): - flow = np.clip(np.around(16*flow), -2**15, 2**15-1) + flow = np.clip(np.around(16 * flow), -(2**15), 2**15 - 1) bytes = np.int16(flow).view(np.uint8) Image.fromarray(bytes).save(path) return flow / 16 @@ -86,41 +90,42 @@ class PairDataset (Dataset): raise IOError("Error loading flow for %s" % path) - -class StillPairDataset (PairDataset): - """ A dataset of 'still' image pairs. - By overloading a normal image dataset, it appends the get_pair(i) function - that serves trivial image pairs (img1, img2) where img1 == img2 == get_image(i). +class StillPairDataset(PairDataset): + """A dataset of 'still' image pairs. + By overloading a normal image dataset, it appends the get_pair(i) function + that serves trivial image pairs (img1, img2) where img1 == img2 == get_image(i). 
""" + def get_pair(self, pair_idx, output=()): - if isinstance(output, str): output = output.split() + if isinstance(output, str): + output = output.split() img1, img2 = map(self.get_image, self.image_pairs[pair_idx]) - W,H = img1.size + W, H = img1.size sx = img2.size[0] / float(W) sy = img2.size[1] / float(H) meta = {} - if 'aflow' in output or 'flow' in output: - mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1,2,0).astype(np.float32) - meta['aflow'] = mgrid * (sx,sy) - meta['flow'] = meta['aflow'] - mgrid + if "aflow" in output or "flow" in output: + mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1, 2, 0).astype(np.float32) + meta["aflow"] = mgrid * (sx, sy) + meta["flow"] = meta["aflow"] - mgrid - if 'mask' in output: - meta['mask'] = np.ones((H,W), np.uint8) + if "mask" in output: + meta["mask"] = np.ones((H, W), np.uint8) - if 'homography' in output: - meta['homography'] = np.diag(np.float32([sx, sy, 1])) + if "homography" in output: + meta["homography"] = np.diag(np.float32([sx, sy, 1])) return img1, img2, meta - -class SyntheticPairDataset (PairDataset): - """ A synthetic generator of image pairs. - Given a normal image dataset, it constructs pairs using random homographies & noise. +class SyntheticPairDataset(PairDataset): + """A synthetic generator of image pairs. + Given a normal image dataset, it constructs pairs using random homographies & noise. """ - def __init__(self, dataset, scale='', distort=''): + + def __init__(self, dataset, scale="", distort=""): self.attach_dataset(dataset) self.distort = instanciate_transformation(distort) self.scale = instanciate_transformation(scale) @@ -133,56 +138,57 @@ class SyntheticPairDataset (PairDataset): self.get_key = dataset.get_key self.get_filename = dataset.get_filename self.root = None - + def make_pair(self, img): return img, img - def get_pair(self, i, output=('aflow')): - """ Procedure: - This function applies a series of random transformations to one original image + def get_pair(self, i, output=("aflow")): + """Procedure: + This function applies a series of random transformations to one original image to form a synthetic image pairs with perfect ground-truth. """ - if isinstance(output, str): + if isinstance(output, str): output = output.split() - + original_img = self.dataset.get_image(i) - + scaled_image = self.scale(original_img) scaled_image, scaled_image2 = self.make_pair(scaled_image) scaled_and_distorted_image = self.distort( - dict(img=scaled_image2, persp=(1,0,0,0,1,0,0,0))) + dict(img=scaled_image2, persp=(1, 0, 0, 0, 1, 0, 0, 0)) + ) W, H = scaled_image.size - trf = scaled_and_distorted_image['persp'] + trf = scaled_and_distorted_image["persp"] meta = dict() - if 'aflow' in output or 'flow' in output: + if "aflow" in output or "flow" in output: # compute optical flow - xy = np.mgrid[0:H,0:W][::-1].reshape(2,H*W).T - aflow = np.float32(persp_apply(trf, xy).reshape(H,W,2)) - meta['flow'] = aflow - xy.reshape(H,W,2) - meta['aflow'] = aflow - - if 'homography' in output: - meta['homography'] = np.float32(trf+(1,)).reshape(3,3) - - return scaled_image, scaled_and_distorted_image['img'], meta - - def __repr__(self): - res = 'Dataset: %s\n' % self.__class__.__name__ - res += ' %d images and pairs' % self.npairs - res += '\n root: %s...' 
% self.dataset.root - res += '\n Scale: %s' % (repr(self.scale).replace('\n','')) - res += '\n Distort: %s' % (repr(self.distort).replace('\n','')) - return res + '\n' + xy = np.mgrid[0:H, 0:W][::-1].reshape(2, H * W).T + aflow = np.float32(persp_apply(trf, xy).reshape(H, W, 2)) + meta["flow"] = aflow - xy.reshape(H, W, 2) + meta["aflow"] = aflow + if "homography" in output: + meta["homography"] = np.float32(trf + (1,)).reshape(3, 3) + return scaled_image, scaled_and_distorted_image["img"], meta -class TransformedPairs (PairDataset): - """ Automatic data augmentation for pre-existing image pairs. - Given an image pair dataset, it generates synthetically jittered pairs - using random transformations (e.g. homographies & noise). + def __repr__(self): + res = "Dataset: %s\n" % self.__class__.__name__ + res += " %d images and pairs" % self.npairs + res += "\n root: %s..." % self.dataset.root + res += "\n Scale: %s" % (repr(self.scale).replace("\n", "")) + res += "\n Distort: %s" % (repr(self.distort).replace("\n", "")) + return res + "\n" + + +class TransformedPairs(PairDataset): + """Automatic data augmentation for pre-existing image pairs. + Given an image pair dataset, it generates synthetically jittered pairs + using random transformations (e.g. homographies & noise). """ - def __init__(self, dataset, trf=''): + + def __init__(self, dataset, trf=""): self.attach_dataset(dataset) self.trf = instanciate_transformation(trf) @@ -195,48 +201,47 @@ class TransformedPairs (PairDataset): self.get_key = dataset.get_key self.get_filename = dataset.get_filename self.root = None - - def get_pair(self, i, output=''): - """ Procedure: - This function applies a series of random transformations to one original image + + def get_pair(self, i, output=""): + """Procedure: + This function applies a series of random transformations to one original image to form a synthetic image pairs with perfect ground-truth. """ img_a, img_b_, metadata = self.dataset.get_pair(i, output) - img_b = self.trf({'img': img_b_, 'persp':(1,0,0,0,1,0,0,0)}) - trf = img_b['persp'] + img_b = self.trf({"img": img_b_, "persp": (1, 0, 0, 0, 1, 0, 0, 0)}) + trf = img_b["persp"] - if 'aflow' in metadata or 'flow' in metadata: - aflow = metadata['aflow'] - aflow[:] = persp_apply(trf, aflow.reshape(-1,2)).reshape(aflow.shape) + if "aflow" in metadata or "flow" in metadata: + aflow = metadata["aflow"] + aflow[:] = persp_apply(trf, aflow.reshape(-1, 2)).reshape(aflow.shape) W, H = img_a.size - flow = metadata['flow'] - mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1,2,0).astype(np.float32) + flow = metadata["flow"] + mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1, 2, 0).astype(np.float32) flow[:] = aflow - mgrid - if 'corres' in metadata: - corres = metadata['corres'] - corres[:,1] = persp_apply(trf, corres[:,1]) - - if 'homography' in metadata: + if "corres" in metadata: + corres = metadata["corres"] + corres[:, 1] = persp_apply(trf, corres[:, 1]) + + if "homography" in metadata: # p_b = homography * p_a - trf_ = np.float32(trf+(1,)).reshape(3,3) - metadata['homography'] = np.float32(trf_ @ metadata['homography']) + trf_ = np.float32(trf + (1,)).reshape(3, 3) + metadata["homography"] = np.float32(trf_ @ metadata["homography"]) - return img_a, img_b['img'], metadata + return img_a, img_b["img"], metadata def __repr__(self): - res = 'Transformed Pairs from %s\n' % type(self.dataset).__name__ - res += ' %d images and pairs' % self.npairs - res += '\n root: %s...' 
% self.dataset.root - res += '\n transform: %s' % (repr(self.trf).replace('\n','')) - return res + '\n' + res = "Transformed Pairs from %s\n" % type(self.dataset).__name__ + res += " %d images and pairs" % self.npairs + res += "\n root: %s..." % self.dataset.root + res += "\n transform: %s" % (repr(self.trf).replace("\n", "")) + return res + "\n" +class CatPairDataset(CatDataset): + """Concatenation of several pair datasets.""" -class CatPairDataset (CatDataset): - ''' Concatenation of several pair datasets. - ''' def __init__(self, *datasets): CatDataset.__init__(self, *datasets) pair_offsets = [0] @@ -251,12 +256,12 @@ class CatPairDataset (CatDataset): def __repr__(self): fmt_str = "CatPairDataset(" for db in self.datasets: - fmt_str += str(db).replace("\n"," ") + ', ' - return fmt_str[:-2] + ')' + fmt_str += str(db).replace("\n", " ") + ", " + return fmt_str[:-2] + ")" def pair_which(self, i): - pos = np.searchsorted(self.pair_offsets, i, side='right')-1 - assert pos < self.npairs, 'Bad pair index %d >= %d' % (i, self.npairs) + pos = np.searchsorted(self.pair_offsets, i, side="right") - 1 + assert pos < self.npairs, "Bad pair index %d >= %d" % (i, self.npairs) return pos, i - self.pair_offsets[pos] def pair_call(self, func, i, *args, **kwargs): @@ -268,20 +273,18 @@ class CatPairDataset (CatDataset): return self.datasets[b].get_pair(i, output) def get_flow_filename(self, pair_idx, *args, **kwargs): - return self.pair_call('get_flow_filename', pair_idx, *args, **kwargs) + return self.pair_call("get_flow_filename", pair_idx, *args, **kwargs) def get_mask_filename(self, pair_idx, *args, **kwargs): - return self.pair_call('get_mask_filename', pair_idx, *args, **kwargs) + return self.pair_call("get_mask_filename", pair_idx, *args, **kwargs) def get_corres_filename(self, pair_idx, *args, **kwargs): - return self.pair_call('get_corres_filename', pair_idx, *args, **kwargs) - + return self.pair_call("get_corres_filename", pair_idx, *args, **kwargs) def is_pair(x): - if isinstance(x, (tuple,list)) and len(x) == 2: + if isinstance(x, (tuple, list)) and len(x) == 2: return True if isinstance(x, np.ndarray) and x.ndim == 1 and x.shape[0] == 2: return True return False - diff --git a/third_party/r2d2/datasets/web_images.py b/third_party/r2d2/datasets/web_images.py index 7c17fbe956f3b4db25d9a4148e8f7c615f122478..f22580f44a9b2488980ab88b656073d8531c3362 100644 --- a/third_party/r2d2/datasets/web_images.py +++ b/third_party/r2d2/datasets/web_images.py @@ -8,42 +8,47 @@ from tqdm import trange from .dataset import Dataset -class RandomWebImages (Dataset): - """ 1 million distractors from Oxford and Paris Revisited - see http://ptak.felk.cvut.cz/revisitop/revisitop1m/ +class RandomWebImages(Dataset): + """1 million distractors from Oxford and Paris Revisited + see http://ptak.felk.cvut.cz/revisitop/revisitop1m/ """ + def __init__(self, start=0, end=1024, root="data/revisitop1m"): Dataset.__init__(self) self.root = root - + bar = None - self.imgs = [] + self.imgs = [] for i in range(start, end): - try: + try: # read cached list - img_list_path = os.path.join(self.root, "image_list_%d.txt"%i) + img_list_path = os.path.join(self.root, "image_list_%d.txt" % i) cached_imgs = [e.strip() for e in open(img_list_path)] assert cached_imgs, f"Cache '{img_list_path}' is empty!" 
self.imgs += cached_imgs except IOError: - if bar is None: - bar = trange(start, 4*end, desc='Caching') - bar.update(4*i) - + if bar is None: + bar = trange(start, 4 * end, desc="Caching") + bar.update(4 * i) + # create it imgs = [] - for d in range(i*4,(i+1)*4): # 4096 folders in total, on average 256 each + for d in range( + i * 4, (i + 1) * 4 + ): # 4096 folders in total, on average 256 each key = hex(d)[2:].zfill(3) folder = os.path.join(self.root, key) - if not os.path.isdir(folder): continue - imgs += [f for f in os.listdir(folder) if verify_img(folder,f)] + if not os.path.isdir(folder): + continue + imgs += [f for f in os.listdir(folder) if verify_img(folder, f)] bar.update(1) assert imgs, f"No images found in {folder}/" - open(img_list_path,'w').write('\n'.join(imgs)) + open(img_list_path, "w").write("\n".join(imgs)) self.imgs += imgs - if bar: bar.update(bar.total - bar.n) + if bar: + bar.update(bar.total - bar.n) self.nimg = len(self.imgs) def get_key(self, i): @@ -53,12 +58,12 @@ class RandomWebImages (Dataset): def verify_img(folder, f): path = os.path.join(folder, f) - if not f.endswith('.jpg'): return False - try: + if not f.endswith(".jpg"): + return False + try: from PIL import Image - Image.open(path).convert('RGB') # try to open it + + Image.open(path).convert("RGB") # try to open it return True - except: + except: return False - - diff --git a/third_party/r2d2/extract.py b/third_party/r2d2/extract.py index c3fea02f87c0615504e3648bfd590e413ab13898..14f6d5cf4899bb5abccbb91ca324d264d4c27d7f 100644 --- a/third_party/r2d2/extract.py +++ b/third_party/r2d2/extract.py @@ -13,97 +13,105 @@ from tools.dataloader import norm_RGB from nets.patchnet import * -def load_network(model_fn): +def load_network(model_fn): checkpoint = torch.load(model_fn) - print("\n>> Creating net = " + checkpoint['net']) - net = eval(checkpoint['net']) + print("\n>> Creating net = " + checkpoint["net"]) + net = eval(checkpoint["net"]) nb_of_weights = common.model_size(net) print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )") # initialization - weights = checkpoint['state_dict'] - net.load_state_dict({k.replace('module.',''):v for k,v in weights.items()}) + weights = checkpoint["state_dict"] + net.load_state_dict({k.replace("module.", ""): v for k, v in weights.items()}) return net.eval() -class NonMaxSuppression (torch.nn.Module): +class NonMaxSuppression(torch.nn.Module): def __init__(self, rel_thr=0.7, rep_thr=0.7): nn.Module.__init__(self) self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1) self.rel_thr = rel_thr self.rep_thr = rep_thr - + def forward(self, reliability, repeatability, **kw): assert len(reliability) == len(repeatability) == 1 reliability, repeatability = reliability[0], repeatability[0] # local maxima - maxima = (repeatability == self.max_filter(repeatability)) + maxima = repeatability == self.max_filter(repeatability) # remove low peaks - maxima *= (repeatability >= self.rep_thr) - maxima *= (reliability >= self.rel_thr) + maxima *= repeatability >= self.rep_thr + maxima *= reliability >= self.rel_thr return maxima.nonzero().t()[2:4] -def extract_multiscale( net, img, detector, scale_f=2**0.25, - min_scale=0.0, max_scale=1, - min_size=256, max_size=1024, - verbose=False): - old_bm = torch.backends.cudnn.benchmark - torch.backends.cudnn.benchmark = False # speedup - +def extract_multiscale( + net, + img, + detector, + scale_f=2**0.25, + min_scale=0.0, + max_scale=1, + min_size=256, + max_size=1024, + verbose=False, +): + old_bm = 
torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = False # speedup + # extract keypoints at multiple scales B, three, H, W = img.shape assert B == 1 and three == 3, "should be a batch with a single RGB image" - + assert max_scale <= 1 - s = 1.0 # current scale factor - - X,Y,S,C,Q,D = [],[],[],[],[],[] - while s+0.001 >= max(min_scale, min_size / max(H,W)): - if s-0.001 <= min(max_scale, max_size / max(H,W)): + s = 1.0 # current scale factor + + X, Y, S, C, Q, D = [], [], [], [], [], [] + while s + 0.001 >= max(min_scale, min_size / max(H, W)): + if s - 0.001 <= min(max_scale, max_size / max(H, W)): nh, nw = img.shape[2:] - if verbose: print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}") + if verbose: + print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}") # extract descriptors with torch.no_grad(): res = net(imgs=[img]) - + # get output and reliability map - descriptors = res['descriptors'][0] - reliability = res['reliability'][0] - repeatability = res['repeatability'][0] + descriptors = res["descriptors"][0] + reliability = res["reliability"][0] + repeatability = res["repeatability"][0] # normalize the reliability for nms # extract maxima and descs - y,x = detector(**res) # nms - c = reliability[0,0,y,x] - q = repeatability[0,0,y,x] - d = descriptors[0,:,y,x].t() + y, x = detector(**res) # nms + c = reliability[0, 0, y, x] + q = repeatability[0, 0, y, x] + d = descriptors[0, :, y, x].t() n = d.shape[0] # accumulate multiple scales - X.append(x.float() * W/nw) - Y.append(y.float() * H/nh) - S.append((32/s) * torch.ones(n, dtype=torch.float32, device=d.device)) + X.append(x.float() * W / nw) + Y.append(y.float() * H / nh) + S.append((32 / s) * torch.ones(n, dtype=torch.float32, device=d.device)) C.append(c) Q.append(q) D.append(d) s /= scale_f # down-scale the image for next iteration - nh, nw = round(H*s), round(W*s) - img = F.interpolate(img, (nh,nw), mode='bilinear', align_corners=False) + nh, nw = round(H * s), round(W * s) + img = F.interpolate(img, (nh, nw), mode="bilinear", align_corners=False) # restore value torch.backends.cudnn.benchmark = old_bm Y = torch.cat(Y) X = torch.cat(X) - S = torch.cat(S) # scale - scores = torch.cat(C) * torch.cat(Q) # scores = reliability * repeatability - XYS = torch.stack([X,Y,S], dim=-1) + S = torch.cat(S) # scale + scores = torch.cat(C) * torch.cat(Q) # scores = reliability * repeatability + XYS = torch.stack([X, Y, S], dim=-1) D = torch.cat(D) return XYS, D, scores @@ -113,71 +121,82 @@ def extract_keypoints(args): # load the network... 
net = load_network(args.model) - if iscuda: net = net.cuda() + if iscuda: + net = net.cuda() # create the non-maxima detector detector = NonMaxSuppression( - rel_thr = args.reliability_thr, - rep_thr = args.repeatability_thr) + rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr + ) while args.images: img_path = args.images.pop(0) - - if img_path.endswith('.txt'): + + if img_path.endswith(".txt"): args.images = open(img_path).read().splitlines() + args.images continue - + print(f"\nExtracting features for {img_path}") - img = Image.open(img_path).convert('RGB') + img = Image.open(img_path).convert("RGB") W, H = img.size - img = norm_RGB(img)[None] - if iscuda: img = img.cuda() - + img = norm_RGB(img)[None] + if iscuda: + img = img.cuda() + # extract keypoints/descriptors for a single image - xys, desc, scores = extract_multiscale(net, img, detector, - scale_f = args.scale_f, - min_scale = args.min_scale, - max_scale = args.max_scale, - min_size = args.min_size, - max_size = args.max_size, - verbose = True) + xys, desc, scores = extract_multiscale( + net, + img, + detector, + scale_f=args.scale_f, + min_scale=args.min_scale, + max_scale=args.max_scale, + min_size=args.min_size, + max_size=args.max_size, + verbose=True, + ) xys = xys.cpu().numpy() desc = desc.cpu().numpy() scores = scores.cpu().numpy() - idxs = scores.argsort()[-args.top_k or None:] - - outpath = img_path + '.' + args.tag - print(f"Saving {len(idxs)} keypoints to {outpath}") - np.savez(open(outpath,'wb'), - imsize = (W,H), - keypoints = xys[idxs], - descriptors = desc[idxs], - scores = scores[idxs]) + idxs = scores.argsort()[-args.top_k or None :] + outpath = img_path + "." + args.tag + print(f"Saving {len(idxs)} keypoints to {outpath}") + np.savez( + open(outpath, "wb"), + imsize=(W, H), + keypoints=xys[idxs], + descriptors=desc[idxs], + scores=scores[idxs], + ) -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser("Extract keypoints for a given image") - parser.add_argument("--model", type=str, required=True, help='model path') - - parser.add_argument("--images", type=str, required=True, nargs='+', help='images / list') - parser.add_argument("--tag", type=str, default='r2d2', help='output file tag') - - parser.add_argument("--top-k", type=int, default=5000, help='number of keypoints') + parser.add_argument("--model", type=str, required=True, help="model path") + + parser.add_argument( + "--images", type=str, required=True, nargs="+", help="images / list" + ) + parser.add_argument("--tag", type=str, default="r2d2", help="output file tag") + + parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints") parser.add_argument("--scale-f", type=float, default=2**0.25) parser.add_argument("--min-size", type=int, default=256) parser.add_argument("--max-size", type=int, default=1024) parser.add_argument("--min-scale", type=float, default=0) parser.add_argument("--max-scale", type=float, default=1) - + parser.add_argument("--reliability-thr", type=float, default=0.7) parser.add_argument("--repeatability-thr", type=float, default=0.7) - parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU') + parser.add_argument( + "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU" + ) args = parser.parse_args() extract_keypoints(args) - diff --git a/third_party/r2d2/extract_kapture.py b/third_party/r2d2/extract_kapture.py index 51b2403b8a1730eaee32d099d0b6dd5d091ccdda..8e46bb5306c943ce985a13168934105b1978deb9 100644 --- 
a/third_party/r2d2/extract_kapture.py +++ b/third_party/r2d2/extract_kapture.py @@ -20,9 +20,21 @@ from extract import load_network, NonMaxSuppression, extract_multiscale import kapture from kapture.io.records import get_image_fullpath from kapture.io.csv import kapture_from_dir -from kapture.io.csv import get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file -from kapture.io.features import get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file -from kapture.io.features import get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file +from kapture.io.csv import ( + get_feature_csv_fullpath, + keypoints_to_file, + descriptors_to_file, +) +from kapture.io.features import ( + get_keypoints_fullpath, + keypoints_check_dir, + image_keypoints_to_file, +) +from kapture.io.features import ( + get_descriptors_fullpath, + descriptors_check_dir, + image_descriptors_to_file, +) from kapture.io.csv import get_all_tar_handlers @@ -30,41 +42,60 @@ def extract_kapture_keypoints(args): """ Extract r2d2 keypoints and descritors to the kapture format directly """ - print('extract_kapture_keypoints...') - with get_all_tar_handlers(args.kapture_root, - mode={kapture.Keypoints: 'a', - kapture.Descriptors: 'a', - kapture.GlobalFeatures: 'r', - kapture.Matches: 'r'}) as tar_handlers: - kdata = kapture_from_dir(args.kapture_root, None, - skip_list=[kapture.GlobalFeatures, - kapture.Matches, - kapture.Points3d, - kapture.Observations], - tar_handlers=tar_handlers) + print("extract_kapture_keypoints...") + with get_all_tar_handlers( + args.kapture_root, + mode={ + kapture.Keypoints: "a", + kapture.Descriptors: "a", + kapture.GlobalFeatures: "r", + kapture.Matches: "r", + }, + ) as tar_handlers: + kdata = kapture_from_dir( + args.kapture_root, + None, + skip_list=[ + kapture.GlobalFeatures, + kapture.Matches, + kapture.Points3d, + kapture.Observations, + ], + tar_handlers=tar_handlers, + ) assert kdata.records_camera is not None - image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)] + image_list = [ + filename for _, _, filename in kapture.flatten(kdata.records_camera) + ] if args.keypoints_type is None: args.keypoints_type = path.splitext(path.basename(args.model))[0] - print(f'keypoints_type set to {args.keypoints_type}') + print(f"keypoints_type set to {args.keypoints_type}") if args.descriptors_type is None: args.descriptors_type = path.splitext(path.basename(args.model))[0] - print(f'descriptors_type set to {args.descriptors_type}') - - if kdata.keypoints is not None and args.keypoints_type in kdata.keypoints \ - and kdata.descriptors is not None and args.descriptors_type in kdata.descriptors: - print('detected already computed features of same keypoints_type/descriptors_type, resuming extraction...') - image_list = [name - for name in image_list - if name not in kdata.keypoints[args.keypoints_type] or - name not in kdata.descriptors[args.descriptors_type]] + print(f"descriptors_type set to {args.descriptors_type}") + + if ( + kdata.keypoints is not None + and args.keypoints_type in kdata.keypoints + and kdata.descriptors is not None + and args.descriptors_type in kdata.descriptors + ): + print( + "detected already computed features of same keypoints_type/descriptors_type, resuming extraction..." 
+ ) + image_list = [ + name + for name in image_list + if name not in kdata.keypoints[args.keypoints_type] + or name not in kdata.descriptors[args.descriptors_type] + ] if len(image_list) == 0: - print('All features were already extracted') + print("All features were already extracted") return else: - print(f'Extracting r2d2 features for {len(image_list)} images') + print(f"Extracting r2d2 features for {len(image_list)} images") iscuda = common.torch_set_gpu(args.gpu) @@ -75,8 +106,8 @@ def extract_kapture_keypoints(args): # create the non-maxima detector detector = NonMaxSuppression( - rel_thr=args.reliability_thr, - rep_thr=args.repeatability_thr) + rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr + ) if kdata.keypoints is None: kdata.keypoints = {} @@ -99,25 +130,29 @@ def extract_kapture_keypoints(args): for image_name in image_list: img_path = get_image_fullpath(args.kapture_root, image_name) print(f"\nExtracting features for {img_path}") - img = Image.open(img_path).convert('RGB') + img = Image.open(img_path).convert("RGB") W, H = img.size img = norm_RGB(img)[None] if iscuda: img = img.cuda() # extract keypoints/descriptors for a single image - xys, desc, scores = extract_multiscale(net, img, detector, - scale_f=args.scale_f, - min_scale=args.min_scale, - max_scale=args.max_scale, - min_size=args.min_size, - max_size=args.max_size, - verbose=True) + xys, desc, scores = extract_multiscale( + net, + img, + detector, + scale_f=args.scale_f, + min_scale=args.min_scale, + max_scale=args.max_scale, + min_size=args.min_size, + max_size=args.max_size, + verbose=True, + ) xys = xys.cpu().numpy() desc = desc.cpu().numpy() scores = scores.cpu().numpy() - idxs = scores.argsort()[-args.top_k or None:] + idxs = scores.argsort()[-args.top_k or None :] xys = xys[idxs] desc = desc[idxs] @@ -128,56 +163,93 @@ def extract_kapture_keypoints(args): keypoints_dsize = xys.shape[1] descriptors_dsize = desc.shape[1] - kdata.keypoints[args.keypoints_type] = kapture.Keypoints('r2d2', keypoints_dtype, keypoints_dsize) - kdata.descriptors[args.descriptors_type] = kapture.Descriptors('r2d2', descriptors_dtype, - descriptors_dsize, - args.keypoints_type, 'L2') - keypoints_config_absolute_path = get_feature_csv_fullpath(kapture.Keypoints, - args.keypoints_type, - args.kapture_root) - descriptors_config_absolute_path = get_feature_csv_fullpath(kapture.Descriptors, - args.descriptors_type, - args.kapture_root) - keypoints_to_file(keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type]) - descriptors_to_file(descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type]) + kdata.keypoints[args.keypoints_type] = kapture.Keypoints( + "r2d2", keypoints_dtype, keypoints_dsize + ) + kdata.descriptors[args.descriptors_type] = kapture.Descriptors( + "r2d2", + descriptors_dtype, + descriptors_dsize, + args.keypoints_type, + "L2", + ) + keypoints_config_absolute_path = get_feature_csv_fullpath( + kapture.Keypoints, args.keypoints_type, args.kapture_root + ) + descriptors_config_absolute_path = get_feature_csv_fullpath( + kapture.Descriptors, args.descriptors_type, args.kapture_root + ) + keypoints_to_file( + keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type] + ) + descriptors_to_file( + descriptors_config_absolute_path, + kdata.descriptors[args.descriptors_type], + ) else: assert kdata.keypoints[args.keypoints_type].dtype == xys.dtype assert kdata.descriptors[args.descriptors_type].dtype == desc.dtype assert kdata.keypoints[args.keypoints_type].dsize == xys.shape[1] assert 
kdata.descriptors[args.descriptors_type].dsize == desc.shape[1] - assert kdata.descriptors[args.descriptors_type].keypoints_type == args.keypoints_type - assert kdata.descriptors[args.descriptors_type].metric_type == 'L2' - - keypoints_fullpath = get_keypoints_fullpath(args.keypoints_type, args.kapture_root, - image_name, tar_handlers) + assert ( + kdata.descriptors[args.descriptors_type].keypoints_type + == args.keypoints_type + ) + assert kdata.descriptors[args.descriptors_type].metric_type == "L2" + + keypoints_fullpath = get_keypoints_fullpath( + args.keypoints_type, args.kapture_root, image_name, tar_handlers + ) print(f"Saving {xys.shape[0]} keypoints to {keypoints_fullpath}") image_keypoints_to_file(keypoints_fullpath, xys) kdata.keypoints[args.keypoints_type].add(image_name) - descriptors_fullpath = get_descriptors_fullpath(args.descriptors_type, args.kapture_root, - image_name, tar_handlers) + descriptors_fullpath = get_descriptors_fullpath( + args.descriptors_type, args.kapture_root, image_name, tar_handlers + ) print(f"Saving {desc.shape[0]} descriptors to {descriptors_fullpath}") image_descriptors_to_file(descriptors_fullpath, desc) kdata.descriptors[args.descriptors_type].add(image_name) - if not keypoints_check_dir(kdata.keypoints[args.keypoints_type], args.keypoints_type, - args.kapture_root, tar_handlers) or \ - not descriptors_check_dir(kdata.descriptors[args.descriptors_type], args.descriptors_type, - args.kapture_root, tar_handlers): - print('local feature extraction ended successfully but not all files were saved') - - -if __name__ == '__main__': + if not keypoints_check_dir( + kdata.keypoints[args.keypoints_type], + args.keypoints_type, + args.kapture_root, + tar_handlers, + ) or not descriptors_check_dir( + kdata.descriptors[args.descriptors_type], + args.descriptors_type, + args.kapture_root, + tar_handlers, + ): + print( + "local feature extraction ended successfully but not all files were saved" + ) + + +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser( - "Extract r2d2 local features for all images in a dataset stored in the kapture format") - parser.add_argument("--model", type=str, required=True, help='model path') - parser.add_argument('--keypoints-type', default=None, help='keypoint type_name, default is filename of model') - parser.add_argument('--descriptors-type', default=None, help='descriptors type_name, default is filename of model') - - parser.add_argument("--kapture-root", type=str, required=True, help='path to kapture root directory') - parser.add_argument("--top-k", type=int, default=5000, help='number of keypoints') + parser = argparse.ArgumentParser( + "Extract r2d2 local features for all images in a dataset stored in the kapture format" + ) + parser.add_argument("--model", type=str, required=True, help="model path") + parser.add_argument( + "--keypoints-type", + default=None, + help="keypoint type_name, default is filename of model", + ) + parser.add_argument( + "--descriptors-type", + default=None, + help="descriptors type_name, default is filename of model", + ) + + parser.add_argument( + "--kapture-root", type=str, required=True, help="path to kapture root directory" + ) + + parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints") parser.add_argument("--scale-f", type=float, default=2**0.25) parser.add_argument("--min-size", type=int, default=256) @@ -188,7 +260,9 @@ if __name__ == '__main__': parser.add_argument("--reliability-thr", type=float, default=0.7) 
parser.add_argument("--repeatability-thr", type=float, default=0.7) - parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU') + parser.add_argument( + "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU" + ) args = parser.parse_args() extract_kapture_keypoints(args) diff --git a/third_party/r2d2/nets/ap_loss.py b/third_party/r2d2/nets/ap_loss.py index 251815cd97009a5feb6a815c20caca0c40daaccd..deb59e4c067aa25c834caf4d0a3c06f9d470ecd4 100644 --- a/third_party/r2d2/nets/ap_loss.py +++ b/third_party/r2d2/nets/ap_loss.py @@ -8,15 +8,16 @@ import torch import torch.nn as nn -class APLoss (nn.Module): - """ differentiable AP loss, through quantization. - - Input: (N, M) values in [min, max] - label: (N, M) values in {0, 1} - - Returns: list of query AP (for each n in {1..N}) - Note: typically, you want to minimize 1 - mean(AP) +class APLoss(nn.Module): + """differentiable AP loss, through quantization. + + Input: (N, M) values in [min, max] + label: (N, M) values in {0, 1} + + Returns: list of query AP (for each n in {1..N}) + Note: typically, you want to minimize 1 - mean(AP) """ + def __init__(self, nq=25, min=0, max=1, euc=False): nn.Module.__init__(self) assert isinstance(nq, int) and 2 <= nq <= 100 @@ -26,16 +27,20 @@ class APLoss (nn.Module): self.euc = euc gap = max - min assert gap > 0 - + # init quantizer = non-learnable (fixed) convolution - self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True) - a = (nq-1) / gap - #1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1) + self.quantizer = q = nn.Conv1d(1, 2 * nq, kernel_size=1, bias=True) + a = (nq - 1) / gap + # 1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1) q.weight.data[:nq] = -a - q.bias.data[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1)) # b = 1 + a*(min+x) - #2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1) + q.bias.data[:nq] = torch.from_numpy( + a * min + np.arange(nq, 0, -1) + ) # b = 1 + a*(min+x) + # 2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1) q.weight.data[nq:] = a - q.bias.data[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min) # b = 1 - a*(min+x) + q.bias.data[nq:] = torch.from_numpy( + np.arange(2 - nq, 2, 1) - a * min + ) # b = 1 - a*(min+x) # first and last one are special: just horizontal straight line q.weight.data[0] = q.weight.data[-1] = 0 q.bias.data[0] = q.bias.data[-1] = 1 @@ -43,25 +48,22 @@ class APLoss (nn.Module): def compute_AP(self, x, label): N, M = x.shape if self.euc: # euclidean distance in same range than similarities - x = 1 - torch.sqrt(2.001 - 2*x) + x = 1 - torch.sqrt(2.001 - 2 * x) # quantize all predictions q = self.quantizer(x.unsqueeze(1)) - q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M + q = torch.min(q[:, : self.nq], q[:, self.nq :]).clamp(min=0) # N x Q x M - nbs = q.sum(dim=-1) # number of samples N x Q = c - rec = (q * label.view(N,1,M).float()).sum(dim=-1) # nb of correct samples = c+ N x Q - prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision - rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1] + nbs = q.sum(dim=-1) # number of samples N x Q = c + rec = (q * label.view(N, 1, M).float()).sum( + dim=-1 + ) # nb of correct samples = c+ N x Q + prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision + rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1] - ap = (prec * rec).sum(dim=-1) # per-image AP + ap = (prec * 
rec).sum(dim=-1) # per-image AP return ap def forward(self, x, label): - assert x.shape == label.shape # N x M + assert x.shape == label.shape # N x M return self.compute_AP(x, label) - - - - - diff --git a/third_party/r2d2/nets/losses.py b/third_party/r2d2/nets/losses.py index f8eea8f6e82835e22d2bb445125f7dc722db85b2..973c592aab3f8f1c69b4001d1d324f1ad46ebe2d 100644 --- a/third_party/r2d2/nets/losses.py +++ b/third_party/r2d2/nets/losses.py @@ -13,44 +13,40 @@ from nets.repeatability_loss import * from nets.reliability_loss import * -class MultiLoss (nn.Module): - """ Combines several loss functions for convenience. +class MultiLoss(nn.Module): + """Combines several loss functions for convenience. *args: [loss weight (float), loss creator, ... ] - + Example: loss = MultiLoss( 1, MyFirstLoss(), 0.5, MySecondLoss() ) """ + def __init__(self, *args, dbg=()): nn.Module.__init__(self) - assert len(args) % 2 == 0, 'args must be a list of (float, loss)' + assert len(args) % 2 == 0, "args must be a list of (float, loss)" self.weights = [] self.losses = nn.ModuleList() - for i in range(len(args)//2): - weight = float(args[2*i+0]) - loss = args[2*i+1] + for i in range(len(args) // 2): + weight = float(args[2 * i + 0]) + loss = args[2 * i + 1] assert isinstance(loss, nn.Module), "%s is not a loss!" % loss self.weights.append(weight) self.losses.append(loss) def forward(self, select=None, **variables): - assert not select or all(1<=n<=len(self.losses) for n in select) + assert not select or all(1 <= n <= len(self.losses) for n in select) d = dict() cum_loss = 0 - for num, (weight, loss_func) in enumerate(zip(self.weights, self.losses),1): - if select is not None and num not in select: continue - l = loss_func(**{k:v for k,v in variables.items()}) + for num, (weight, loss_func) in enumerate(zip(self.weights, self.losses), 1): + if select is not None and num not in select: + continue + l = loss_func(**{k: v for k, v in variables.items()}) if isinstance(l, tuple): assert len(l) == 2 and isinstance(l[1], dict) else: - l = l, {loss_func.name:l} + l = l, {loss_func.name: l} cum_loss = cum_loss + weight * l[0] - for key,val in l[1].items(): - d['loss_'+key] = float(val) - d['loss'] = float(cum_loss) + for key, val in l[1].items(): + d["loss_" + key] = float(val) + d["loss"] = float(cum_loss) return cum_loss, d - - - - - - diff --git a/third_party/r2d2/nets/patchnet.py b/third_party/r2d2/nets/patchnet.py index 854c61ecf9b879fa7f420255296c4fbbfd665181..8ed3fdbd55ccbbd58f0cea3dad9384a402ec5e9d 100644 --- a/third_party/r2d2/nets/patchnet.py +++ b/third_party/r2d2/nets/patchnet.py @@ -8,22 +8,25 @@ import torch.nn as nn import torch.nn.functional as F -class BaseNet (nn.Module): - """ Takes a list of images as input, and returns for each image: - - a pixelwise descriptor - - a pixelwise confidence +class BaseNet(nn.Module): + """Takes a list of images as input, and returns for each image: + - a pixelwise descriptor + - a pixelwise confidence """ + def softmax(self, ux): if ux.shape[1] == 1: x = F.softplus(ux) return x / (1 + x) # for sure in [0,1], much less plateaus than softmax elif ux.shape[1] == 2: - return F.softmax(ux, dim=1)[:,1:2] + return F.softmax(ux, dim=1)[:, 1:2] def normalize(self, x, ureliability, urepeatability): - return dict(descriptors = F.normalize(x, p=2, dim=1), - repeatability = self.softmax( urepeatability ), - reliability = self.softmax( ureliability )) + return dict( + descriptors=F.normalize(x, p=2, dim=1), + repeatability=self.softmax(urepeatability), + 
reliability=self.softmax(ureliability), + ) def forward_one(self, x): raise NotImplementedError() @@ -31,15 +34,15 @@ class BaseNet (nn.Module): def forward(self, imgs, **kw): res = [self.forward_one(img) for img in imgs] # merge all dictionaries into one - res = {k:[r[k] for r in res if k in r] for k in {k for r in res for k in r}} + res = {k: [r[k] for r in res if k in r] for k in {k for r in res for k in r}} return dict(res, imgs=imgs, **kw) - -class PatchNet (BaseNet): - """ Helper class to construct a fully-convolutional network that - extract a l2-normalized patch descriptor. +class PatchNet(BaseNet): + """Helper class to construct a fully-convolutional network that + extract a l2-normalized patch descriptor. """ + def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False): BaseNet.__init__(self) self.inchan = inchan @@ -53,41 +56,54 @@ class PatchNet (BaseNet): def _make_bn(self, outd): return nn.BatchNorm2d(outd, affine=self.bn_affine) - def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max'): + def _add_conv( + self, + outd, + k=3, + stride=1, + dilation=1, + bn=True, + relu=True, + k_pool=1, + pool_type="max", + ): # as in the original implementation, dilation is applied at the end of layer, so it will have impact only from next layer d = self.dilation * dilation - if self.dilated: - conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=1) + if self.dilated: + conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=1) self.dilation *= stride else: - conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride) - self.ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) ) - if bn and self.bn: self.ops.append( self._make_bn(outd) ) - if relu: self.ops.append( nn.ReLU(inplace=True) ) + conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=stride) + self.ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params)) + if bn and self.bn: + self.ops.append(self._make_bn(outd)) + if relu: + self.ops.append(nn.ReLU(inplace=True)) self.curchan = outd - + if k_pool > 1: - if pool_type == 'avg': + if pool_type == "avg": self.ops.append(torch.nn.AvgPool2d(kernel_size=k_pool)) - elif pool_type == 'max': + elif pool_type == "max": self.ops.append(torch.nn.MaxPool2d(kernel_size=k_pool)) else: print(f"Error, unknown pooling type {pool_type}...") - + def forward_one(self, x): assert self.ops, "You need to add convolutions first" - for n,op in enumerate(self.ops): + for n, op in enumerate(self.ops): x = op(x) return self.normalize(x) -class L2_Net (PatchNet): - """ Compute a 128D descriptor for all overlapping 32x32 patches. - From the L2Net paper (CVPR'17). +class L2_Net(PatchNet): + """Compute a 128D descriptor for all overlapping 32x32 patches. + From the L2Net paper (CVPR'17). """ - def __init__(self, dim=128, **kw ): + + def __init__(self, dim=128, **kw): PatchNet.__init__(self, **kw) - add_conv = lambda n,**kw: self._add_conv((n*dim)//128,**kw) + add_conv = lambda n, **kw: self._add_conv((n * dim) // 128, **kw) add_conv(32) add_conv(32) add_conv(64, stride=2) @@ -98,35 +114,34 @@ class L2_Net (PatchNet): self.out_dim = dim -class Quad_L2Net (PatchNet): - """ Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs. 
- """ - def __init__(self, dim=128, mchan=4, relu22=False, **kw ): +class Quad_L2Net(PatchNet): + """Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs.""" + + def __init__(self, dim=128, mchan=4, relu22=False, **kw): PatchNet.__init__(self, **kw) - self._add_conv( 8*mchan) - self._add_conv( 8*mchan) - self._add_conv( 16*mchan, stride=2) - self._add_conv( 16*mchan) - self._add_conv( 32*mchan, stride=2) - self._add_conv( 32*mchan) + self._add_conv(8 * mchan) + self._add_conv(8 * mchan) + self._add_conv(16 * mchan, stride=2) + self._add_conv(16 * mchan) + self._add_conv(32 * mchan, stride=2) + self._add_conv(32 * mchan) # replace last 8x8 convolution with 3 2x2 convolutions - self._add_conv( 32*mchan, k=2, stride=2, relu=relu22) - self._add_conv( 32*mchan, k=2, stride=2, relu=relu22) + self._add_conv(32 * mchan, k=2, stride=2, relu=relu22) + self._add_conv(32 * mchan, k=2, stride=2, relu=relu22) self._add_conv(dim, k=2, stride=2, bn=False, relu=False) self.out_dim = dim +class Quad_L2Net_ConfCFS(Quad_L2Net): + """Same than Quad_L2Net, with 2 confidence maps for repeatability and reliability.""" -class Quad_L2Net_ConfCFS (Quad_L2Net): - """ Same than Quad_L2Net, with 2 confidence maps for repeatability and reliability. - """ - def __init__(self, **kw ): + def __init__(self, **kw): Quad_L2Net.__init__(self, **kw) # reliability classifier self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1) # repeatability classifier: for some reasons it's a softplus, not a softmax! # Why? I guess it's a mistake that was left unnoticed in the code for a long time... - self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1) + self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1) def forward_one(self, x): assert self.ops, "You need to add convolutions first" @@ -138,44 +153,51 @@ class Quad_L2Net_ConfCFS (Quad_L2Net): return self.normalize(x, ureliability, urepeatability) -class Fast_Quad_L2Net (PatchNet): - """ Faster version of Quad l2 net, replacing one dilated conv with one pooling to diminish image resolution thus increase inference time +class Fast_Quad_L2Net(PatchNet): + """Faster version of Quad l2 net, replacing one dilated conv with one pooling to diminish image resolution thus increase inference time Dilation factors and pooling: 1,1,1, pool2, 1,1, 2,2, 4, 8, upsample2 """ - def __init__(self, dim=128, mchan=4, relu22=False, downsample_factor=2, **kw ): + + def __init__(self, dim=128, mchan=4, relu22=False, downsample_factor=2, **kw): PatchNet.__init__(self, **kw) - self._add_conv( 8*mchan) - self._add_conv( 8*mchan) - self._add_conv( 16*mchan, k_pool = downsample_factor) # added avg pooling to decrease img resolution - self._add_conv( 16*mchan) - self._add_conv( 32*mchan, stride=2) - self._add_conv( 32*mchan) - + self._add_conv(8 * mchan) + self._add_conv(8 * mchan) + self._add_conv( + 16 * mchan, k_pool=downsample_factor + ) # added avg pooling to decrease img resolution + self._add_conv(16 * mchan) + self._add_conv(32 * mchan, stride=2) + self._add_conv(32 * mchan) + # replace last 8x8 convolution with 3 2x2 convolutions - self._add_conv( 32*mchan, k=2, stride=2, relu=relu22) - self._add_conv( 32*mchan, k=2, stride=2, relu=relu22) + self._add_conv(32 * mchan, k=2, stride=2, relu=relu22) + self._add_conv(32 * mchan, k=2, stride=2, relu=relu22) self._add_conv(dim, k=2, stride=2, bn=False, relu=False) - + # Go back to initial image resolution with upsampling - self.ops.append(torch.nn.Upsample(scale_factor=downsample_factor, mode='bilinear', align_corners=False)) - + 
self.ops.append( + torch.nn.Upsample( + scale_factor=downsample_factor, mode="bilinear", align_corners=False + ) + ) + self.out_dim = dim - - -class Fast_Quad_L2Net_ConfCFS (Fast_Quad_L2Net): - """ Fast r2d2 architecture - """ - def __init__(self, **kw ): + + +class Fast_Quad_L2Net_ConfCFS(Fast_Quad_L2Net): + """Fast r2d2 architecture""" + + def __init__(self, **kw): Fast_Quad_L2Net.__init__(self, **kw) # reliability classifier self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1) - + # repeatability classifier: for some reasons it's a softplus, not a softmax! # Why? I guess it's a mistake that was left unnoticed in the code for a long time... - self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1) - + self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1) + def forward_one(self, x): assert self.ops, "You need to add convolutions first" for op in self.ops: @@ -183,4 +205,4 @@ class Fast_Quad_L2Net_ConfCFS (Fast_Quad_L2Net): # compute the confidence maps ureliability = self.clf(x**2) urepeatability = self.sal(x**2) - return self.normalize(x, ureliability, urepeatability) \ No newline at end of file + return self.normalize(x, ureliability, urepeatability) diff --git a/third_party/r2d2/nets/reliability_loss.py b/third_party/r2d2/nets/reliability_loss.py index 52d5383b0eaa52bcf2111eabb4b45e39b63b976f..e560d1ea1b4dc27d81031c62cc4c0aed9161cc67 100644 --- a/third_party/r2d2/nets/reliability_loss.py +++ b/third_party/r2d2/nets/reliability_loss.py @@ -9,18 +9,19 @@ import torch.nn.functional as F from nets.ap_loss import APLoss -class PixelAPLoss (nn.Module): - """ Computes the pixel-wise AP loss: - Given two images and ground-truth optical flow, computes the AP per pixel. - - feat1: (B, C, H, W) pixel-wise features extracted from img1 - feat2: (B, C, H, W) pixel-wise features extracted from img2 - aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2 +class PixelAPLoss(nn.Module): + """Computes the pixel-wise AP loss: + Given two images and ground-truth optical flow, computes the AP per pixel. + + feat1: (B, C, H, W) pixel-wise features extracted from img1 + feat2: (B, C, H, W) pixel-wise features extracted from img2 + aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2 """ + def __init__(self, sampler, nq=20): nn.Module.__init__(self) self.aploss = APLoss(nq, min=0, max=1, euc=False) - self.name = 'pixAP' + self.name = "pixAP" self.sampler = sampler def loss_from_ap(self, ap, rel): @@ -28,32 +29,31 @@ class PixelAPLoss (nn.Module): def forward(self, descriptors, aflow, **kw): # subsample things - scores, gt, msk, qconf = self.sampler(descriptors, kw.get('reliability'), aflow) - + scores, gt, msk, qconf = self.sampler(descriptors, kw.get("reliability"), aflow) + # compute pixel-wise AP n = qconf.numel() - if n == 0: return 0 - scores, gt = scores.view(n,-1), gt.view(n,-1) + if n == 0: + return 0 + scores, gt = scores.view(n, -1), gt.view(n, -1) ap = self.aploss(scores, gt).view(msk.shape) pixel_loss = self.loss_from_ap(ap, qconf) - + loss = pixel_loss[msk].mean() return loss -class ReliabilityLoss (PixelAPLoss): - """ same than PixelAPLoss, but also train a pixel-wise confidence - that this pixel is going to have a good AP. +class ReliabilityLoss(PixelAPLoss): + """same than PixelAPLoss, but also train a pixel-wise confidence + that this pixel is going to have a good AP. 
""" + def __init__(self, sampler, base=0.5, **kw): PixelAPLoss.__init__(self, sampler, **kw) assert 0 <= base < 1 self.base = base - self.name = 'reliability' + self.name = "reliability" def loss_from_ap(self, ap, rel): - return 1 - ap*rel - (1-rel)*self.base - - - + return 1 - ap * rel - (1 - rel) * self.base diff --git a/third_party/r2d2/nets/repeatability_loss.py b/third_party/r2d2/nets/repeatability_loss.py index 5cda0b6d036f98af88a88780fe39da0c5c0b610e..af49e77f444c5b4b035cd43d0c065096e8dd7c1b 100644 --- a/third_party/r2d2/nets/repeatability_loss.py +++ b/third_party/r2d2/nets/repeatability_loss.py @@ -10,27 +10,28 @@ import torch.nn.functional as F from nets.sampler import FullSampler -class CosimLoss (nn.Module): - """ Try to make the repeatability repeatable from one image to the other. - """ + +class CosimLoss(nn.Module): + """Try to make the repeatability repeatable from one image to the other.""" + def __init__(self, N=16): nn.Module.__init__(self) - self.name = f'cosim{N}' - self.patches = nn.Unfold(N, padding=0, stride=N//2) + self.name = f"cosim{N}" + self.patches = nn.Unfold(N, padding=0, stride=N // 2) def extract_patches(self, sal): - patches = self.patches(sal).transpose(1,2) # flatten - patches = F.normalize(patches, p=2, dim=2) # norm + patches = self.patches(sal).transpose(1, 2) # flatten + patches = F.normalize(patches, p=2, dim=2) # norm return patches - + def forward(self, repeatability, aflow, **kw): - B,two,H,W = aflow.shape + B, two, H, W = aflow.shape assert two == 2 # normalize sali1, sali2 = repeatability grid = FullSampler._aflow_to_grid(aflow) - sali2 = F.grid_sample(sali2, grid, mode='bilinear', padding_mode='border') + sali2 = F.grid_sample(sali2, grid, mode="bilinear", padding_mode="border") patches1 = self.extract_patches(sali1) patches2 = self.extract_patches(sali2) @@ -38,29 +39,25 @@ class CosimLoss (nn.Module): return 1 - cosim.mean() -class PeakyLoss (nn.Module): - """ Try to make the repeatability locally peaky. +class PeakyLoss(nn.Module): + """Try to make the repeatability locally peaky. Mechanism: we maximize, for each pixel, the difference between the local mean and the local max. 
""" + def __init__(self, N=16): nn.Module.__init__(self) - self.name = f'peaky{N}' - assert N % 2 == 0, 'N must be pair' + self.name = f"peaky{N}" + assert N % 2 == 0, "N must be pair" self.preproc = nn.AvgPool2d(3, stride=1, padding=1) - self.maxpool = nn.MaxPool2d(N+1, stride=1, padding=N//2) - self.avgpool = nn.AvgPool2d(N+1, stride=1, padding=N//2) + self.maxpool = nn.MaxPool2d(N + 1, stride=1, padding=N // 2) + self.avgpool = nn.AvgPool2d(N + 1, stride=1, padding=N // 2) def forward_one(self, sali): - sali = self.preproc(sali) # remove super high frequency + sali = self.preproc(sali) # remove super high frequency return 1 - (self.maxpool(sali) - self.avgpool(sali)).mean() def forward(self, repeatability, **kw): sali1, sali2 = repeatability - return (self.forward_one(sali1) + self.forward_one(sali2)) /2 - - - - - + return (self.forward_one(sali1) + self.forward_one(sali2)) / 2 diff --git a/third_party/r2d2/nets/sampler.py b/third_party/r2d2/nets/sampler.py index 9fede70d3a04d7f31a1d414eace0aaf3729e8235..3f2e5a276a80b997561549ed3e8466da3876e382 100644 --- a/third_party/r2d2/nets/sampler.py +++ b/third_party/r2d2/nets/sampler.py @@ -15,65 +15,69 @@ import torch.nn.functional as F class FullSampler(nn.Module): - """ all pixels are selected - - feats: keypoint descriptors - - confs: reliability values + """all pixels are selected + - feats: keypoint descriptors + - confs: reliability values """ + def __init__(self): nn.Module.__init__(self) - self.mode = 'bilinear' - self.padding = 'zeros' + self.mode = "bilinear" + self.padding = "zeros" @staticmethod def _aflow_to_grid(aflow): H, W = aflow.shape[2:] - grid = aflow.permute(0,2,3,1).clone() - grid[:,:,:,0] *= 2/(W-1) - grid[:,:,:,1] *= 2/(H-1) + grid = aflow.permute(0, 2, 3, 1).clone() + grid[:, :, :, 0] *= 2 / (W - 1) + grid[:, :, :, 1] *= 2 / (H - 1) grid -= 1 - grid[torch.isnan(grid)] = 9e9 # invalids + grid[torch.isnan(grid)] = 9e9 # invalids return grid - + def _warp(self, feats, confs, aflow): - if isinstance(aflow, tuple): return aflow # result was precomputed + if isinstance(aflow, tuple): + return aflow # result was precomputed feat1, feat2 = feats - conf1, conf2 = confs if confs else (None,None) - + conf1, conf2 = confs if confs else (None, None) + B, two, H, W = aflow.shape D = feat1.shape[1] - assert feat1.shape == feat2.shape == (B, D, H, W) # D = 128, B = batch + assert feat1.shape == feat2.shape == (B, D, H, W) # D = 128, B = batch assert conf1.shape == conf2.shape == (B, 1, H, W) if confs else True # warp img2 to img1 grid = self._aflow_to_grid(aflow) - ones2 = feat2.new_ones(feat2[:,0:1].shape) + ones2 = feat2.new_ones(feat2[:, 0:1].shape) feat2to1 = F.grid_sample(feat2, grid, mode=self.mode, padding_mode=self.padding) - mask2to1 = F.grid_sample(ones2, grid, mode='nearest', padding_mode='zeros') - conf2to1 = F.grid_sample(conf2, grid, mode=self.mode, padding_mode=self.padding) \ - if confs else None + mask2to1 = F.grid_sample(ones2, grid, mode="nearest", padding_mode="zeros") + conf2to1 = ( + F.grid_sample(conf2, grid, mode=self.mode, padding_mode=self.padding) + if confs + else None + ) return feat2to1, mask2to1.byte(), conf2to1 def _warp_positions(self, aflow): B, two, H, W = aflow.shape assert two == 2 - + Y = torch.arange(H, device=aflow.device) X = torch.arange(W, device=aflow.device) - XY = torch.stack(torch.meshgrid(Y,X)[::-1], dim=0) + XY = torch.stack(torch.meshgrid(Y, X)[::-1], dim=0) XY = XY[None].expand(B, 2, H, W).float() - + grid = self._aflow_to_grid(aflow) - XY2 = F.grid_sample(XY, grid, mode='bilinear', 
padding_mode='zeros') + XY2 = F.grid_sample(XY, grid, mode="bilinear", padding_mode="zeros") return XY, XY2 +class SubSampler(FullSampler): + """pixels are selected in an uniformly spaced grid""" -class SubSampler (FullSampler): - """ pixels are selected in an uniformly spaced grid - """ def __init__(self, border, subq, subd, perimage=False): FullSampler.__init__(self) - assert subq % subd == 0, 'subq must be multiple of subd' + assert subq % subd == 0, "subq must be multiple of subd" self.sub_q = subq self.sub_d = subd self.border = border @@ -81,13 +85,17 @@ class SubSampler (FullSampler): def __repr__(self): return "SubSampler(border=%d, subq=%d, subd=%d, perimage=%d)" % ( - self.border, self.sub_q, self.sub_d, self.perimage) + self.border, + self.sub_q, + self.sub_d, + self.perimage, + ) def __call__(self, feats, confs, aflow): feat1, conf1 = feats[0], (confs[0] if confs else None) # warp with optical flow in img1 coords feat2, mask2, conf2 = self._warp(feats, confs, aflow) - + # subsample img1 slq = slice(self.border, -self.border or None, self.sub_q) feat1 = feat1[:, :, slq, slq] @@ -97,47 +105,50 @@ class SubSampler (FullSampler): feat2 = feat2[:, :, sld, sld] mask2 = mask2[:, :, sld, sld] conf2 = conf2[:, :, sld, sld] if confs else None - + B, D, Hq, Wq = feat1.shape B, D, Hd, Wd = feat2.shape - + # compute gt if self.perimage or self.sub_q != self.sub_d: # compute ground-truth by comparing pixel indices - f = feats[0][0:1,0] if self.perimage else feats[0][:,0] - idxs = torch.arange(f.numel(), dtype=torch.int64, device=feat1.device).view(f.shape) - idxs1 = idxs[:, slq, slq].reshape(-1,Hq*Wq) - idxs2 = idxs[:, sld, sld].reshape(-1,Hd*Wd) + f = feats[0][0:1, 0] if self.perimage else feats[0][:, 0] + idxs = torch.arange(f.numel(), dtype=torch.int64, device=feat1.device).view( + f.shape + ) + idxs1 = idxs[:, slq, slq].reshape(-1, Hq * Wq) + idxs2 = idxs[:, sld, sld].reshape(-1, Hd * Wd) if self.perimage: - gt = (idxs1[0].view(-1,1) == idxs2[0].view(1,-1)) - gt = gt[None,:,:].expand(B, Hq*Wq, Hd*Wd) - else : - gt = (idxs1.view(-1,1) == idxs2.view(1,-1)) + gt = idxs1[0].view(-1, 1) == idxs2[0].view(1, -1) + gt = gt[None, :, :].expand(B, Hq * Wq, Hd * Wd) + else: + gt = idxs1.view(-1, 1) == idxs2.view(1, -1) else: - gt = torch.eye(feat1[:,0].numel(), dtype=torch.uint8, device=feat1.device) # always binary for AP loss - + gt = torch.eye( + feat1[:, 0].numel(), dtype=torch.uint8, device=feat1.device + ) # always binary for AP loss + # compute all images together - queries = feat1.reshape(B,D,-1) # B x D x (Hq x Wq) - database = feat2.reshape(B,D,-1) # B x D x (Hd x Wd) + queries = feat1.reshape(B, D, -1) # B x D x (Hq x Wq) + database = feat2.reshape(B, D, -1) # B x D x (Hd x Wd) if self.perimage: - queries = queries.transpose(1,2) # B x (Hd x Wd) x D - scores = torch.bmm(queries, database) # B x (Hq x Wq) x (Hd x Wd) + queries = queries.transpose(1, 2) # B x (Hd x Wd) x D + scores = torch.bmm(queries, database) # B x (Hq x Wq) x (Hd x Wd) else: - queries = queries .transpose(1,2).reshape(-1,D) # (B x Hq x Wq) x D - database = database.transpose(1,0).reshape(D,-1) # D x (B x Hd x Wd) - scores = torch.matmul(queries, database) # (B x Hq x Wq) x (B x Hd x Wd) + queries = queries.transpose(1, 2).reshape(-1, D) # (B x Hq x Wq) x D + database = database.transpose(1, 0).reshape(D, -1) # D x (B x Hd x Wd) + scores = torch.matmul(queries, database) # (B x Hq x Wq) x (B x Hd x Wd) # compute reliability - qconf = (conf1 + conf2)/2 if confs else None + qconf = (conf1 + conf2) / 2 if confs else None 
assert gt.shape == scores.shape return scores, gt, mask2, qconf +class NghSampler(FullSampler): + """all pixels in a small neighborhood""" -class NghSampler (FullSampler): - """ all pixels in a small neighborhood - """ def __init__(self, ngh, subq=1, subd=1, ignore=1, border=None): FullSampler.__init__(self) assert 0 <= ignore < ngh @@ -146,86 +157,96 @@ class NghSampler (FullSampler): assert subd <= ngh self.sub_q = subq self.sub_d = subd - if border is None: border = ngh - assert border >= ngh, 'border has to be larger than ngh' + if border is None: + border = ngh + assert border >= ngh, "border has to be larger than ngh" self.border = border def __repr__(self): return "NghSampler(ngh=%d, subq=%d, subd=%d, ignore=%d, border=%d)" % ( - self.ngh, self.sub_q, self.sub_d, self.ignore, self.border) + self.ngh, + self.sub_q, + self.sub_d, + self.ignore, + self.border, + ) def trans(self, arr, i, j): - s = lambda i: slice(self.border+i, i-self.border or None, self.sub_q) - return arr[:,:,s(j),s(i)] + s = lambda i: slice(self.border + i, i - self.border or None, self.sub_q) + return arr[:, :, s(j), s(i)] def __call__(self, feats, confs, aflow): feat1, conf1 = feats[0], (confs[0] if confs else None) # warp with optical flow in img1 coords feat2, mask2, conf2 = self._warp(feats, confs, aflow) - - qfeat = self.trans(feat1,0,0) - qconf = (self.trans(conf1,0,0) + self.trans(conf2,0,0)) / 2 if confs else None - mask2 = self.trans(mask2,0,0) - scores_at = lambda i,j: (qfeat * self.trans(feat2,i,j)).sum(dim=1) - + + qfeat = self.trans(feat1, 0, 0) + qconf = ( + (self.trans(conf1, 0, 0) + self.trans(conf2, 0, 0)) / 2 if confs else None + ) + mask2 = self.trans(mask2, 0, 0) + scores_at = lambda i, j: (qfeat * self.trans(feat2, i, j)).sum(dim=1) + # compute scores for all neighbors B, D = feat1.shape[:2] min_d = self.ignore**2 max_d = self.ngh**2 - rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple + rad = (self.ngh // self.sub_d) * self.ngh # make an integer multiple negs = [] offsets = [] - for j in range(-rad, rad+1, self.sub_d): - for i in range(-rad, rad+1, self.sub_d): - if not(min_d < i*i + j*j <= max_d): - continue # out of scope - offsets.append((i,j)) # Note: this list is just for debug - negs.append( scores_at(i,j) ) - - scores = torch.stack([scores_at(0,0)] + negs, dim=-1) + for j in range(-rad, rad + 1, self.sub_d): + for i in range(-rad, rad + 1, self.sub_d): + if not (min_d < i * i + j * j <= max_d): + continue # out of scope + offsets.append((i, j)) # Note: this list is just for debug + negs.append(scores_at(i, j)) + + scores = torch.stack([scores_at(0, 0)] + negs, dim=-1) gt = scores.new_zeros(scores.shape, dtype=torch.uint8) - gt[..., 0] = 1 # only the center point is positive + gt[..., 0] = 1 # only the center point is positive return scores, gt, mask2, qconf +class FarNearSampler(FullSampler): + """Sample pixels from *both* a small neighborhood *and* far-away pixels. -class FarNearSampler (FullSampler): - """ Sample pixels from *both* a small neighborhood *and* far-away pixels. - How it works? 
1) Queries are sampled from img1, - - at least `border` pixels from borders and + - at least `border` pixels from borders and - on a grid with step = `subq` - - 2) Close database pixels + + 2) Close database pixels - from the corresponding image (img2), - - within a `ngh` distance radius + - within a `ngh` distance radius - on a grid with step = `subd_ngh` - ignored if distance to query is >0 and <=`ignore` - + 3) Far-away database pixels from , - from all batch images in `img2` - at least `border` pixels from borders - on a grid with step = `subd_far` """ - def __init__(self, subq, ngh, subd_ngh, subd_far, border=None, ignore=1, - maxpool_ngh=False ): + + def __init__( + self, subq, ngh, subd_ngh, subd_far, border=None, ignore=1, maxpool_ngh=False + ): FullSampler.__init__(self) border = border or ngh - assert ignore < ngh < subd_far, 'neighborhood needs to be smaller than far step' - self.close_sampler = NghSampler(ngh=ngh, subq=subq, subd=subd_ngh, - ignore=not(maxpool_ngh), border=border) + assert ignore < ngh < subd_far, "neighborhood needs to be smaller than far step" + self.close_sampler = NghSampler( + ngh=ngh, subq=subq, subd=subd_ngh, ignore=not (maxpool_ngh), border=border + ) self.faraway_sampler = SubSampler(border=border, subq=subq, subd=subd_far) self.maxpool_ngh = maxpool_ngh def __repr__(self): - c,f = self.close_sampler, self.faraway_sampler + c, f = self.close_sampler, self.faraway_sampler res = "FarNearSampler(subq=%d, ngh=%d" % (c.sub_q, c.ngh) res += ", subd_ngh=%d, subd_far=%d" % (c.sub_d, f.sub_d) res += ", border=%d, ign=%d" % (f.border, c.ignore) res += ", maxpool_ngh=%d" % self.maxpool_ngh - return res+')' + return res + ")" def __call__(self, feats, confs, aflow): # warp with optical flow in img1 coords @@ -233,10 +254,10 @@ class FarNearSampler (FullSampler): # sample ngh pixels scores1, gt1, msk1, conf1 = self.close_sampler(feats, confs, aflow) - scores1, gt1 = scores1.view(-1,scores1.shape[-1]), gt1.view(-1,gt1.shape[-1]) + scores1, gt1 = scores1.view(-1, scores1.shape[-1]), gt1.view(-1, gt1.shape[-1]) if self.maxpool_ngh: # we consider all scores from ngh as potential positives - scores1, self._cached_maxpool_ngh = scores1.max(dim=1,keepdim=True) + scores1, self._cached_maxpool_ngh = scores1.max(dim=1, keepdim=True) gt1 = gt1[:, 0:1] # sample far pixels @@ -244,22 +265,35 @@ class FarNearSampler (FullSampler): # assert (msk1 == msk2).all() # assert (conf1 == conf2).all() - return (torch.cat((scores1,scores2),dim=1), - torch.cat((gt1, gt2), dim=1), - msk1, conf1 if confs else None) + return ( + torch.cat((scores1, scores2), dim=1), + torch.cat((gt1, gt2), dim=1), + msk1, + conf1 if confs else None, + ) -class NghSampler2 (nn.Module): - """ Similar to NghSampler, but doesnt warp the 2nd image. +class NghSampler2(nn.Module): + """Similar to NghSampler, but doesnt warp the 2nd image. Distance to GT => 0 ... pos_d ... neg_d ... 
ngh Pixel label => + + + + + + 0 0 - - - - - - - - + Subsample on query side: if > 0, regular grid - < 0, random points + < 0, random points In both cases, the number of query points is = W*H/subq**2 """ - def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None, - maxpool_pos=True, subd_neg=0): + + def __init__( + self, + ngh, + subq=1, + subd=1, + pos_d=0, + neg_d=2, + border=None, + maxpool_pos=True, + subd_neg=0, + ): nn.Module.__init__(self) assert 0 <= pos_d < neg_d <= (ngh if ngh else 99) self.ngh = ngh @@ -270,8 +304,9 @@ class NghSampler2 (nn.Module): self.sub_q = subq self.sub_d = subd self.sub_d_neg = subd_neg - if border is None: border = ngh - assert border >= ngh, 'border has to be larger than ngh' + if border is None: + border = ngh + assert border >= ngh, "border has to be larger than ngh" self.border = border self.maxpool_pos = maxpool_pos self.precompute_offsets() @@ -280,19 +315,19 @@ class NghSampler2 (nn.Module): pos_d2 = self.pos_d**2 neg_d2 = self.neg_d**2 rad2 = self.ngh**2 - rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple + rad = (self.ngh // self.sub_d) * self.ngh # make an integer multiple pos = [] neg = [] - for j in range(-rad, rad+1, self.sub_d): - for i in range(-rad, rad+1, self.sub_d): - d2 = i*i + j*j - if d2 <= pos_d2: - pos.append( (i,j) ) - elif neg_d2 <= d2 <= rad2: - neg.append( (i,j) ) + for j in range(-rad, rad + 1, self.sub_d): + for i in range(-rad, rad + 1, self.sub_d): + d2 = i * i + j * j + if d2 <= pos_d2: + pos.append((i, j)) + elif neg_d2 <= d2 <= rad2: + neg.append((i, j)) - self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t()) - self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t()) + self.register_buffer("pos_offsets", torch.LongTensor(pos).view(-1, 2).t()) + self.register_buffer("neg_offsets", torch.LongTensor(neg).view(-1, 2).t()) def gen_grid(self, step, aflow): B, two, H, W = aflow.shape @@ -300,21 +335,21 @@ class NghSampler2 (nn.Module): b1 = torch.arange(B, device=dev) if step > 0: # regular grid - x1 = torch.arange(self.border, W-self.border, step, device=dev) - y1 = torch.arange(self.border, H-self.border, step, device=dev) + x1 = torch.arange(self.border, W - self.border, step, device=dev) + y1 = torch.arange(self.border, H - self.border, step, device=dev) H1, W1 = len(y1), len(x1) - x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1) - y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1) - b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1) + x1 = x1[None, None, :].expand(B, H1, W1).reshape(-1) + y1 = y1[None, :, None].expand(B, H1, W1).reshape(-1) + b1 = b1[:, None, None].expand(B, H1, W1).reshape(-1) shape = (B, H1, W1) else: # randomly spread - n = (H - 2*self.border) * (W - 2*self.border) // step**2 - x1 = torch.randint(self.border, W-self.border, (n,), device=dev) - y1 = torch.randint(self.border, H-self.border, (n,), device=dev) - x1 = x1[None,:].expand(B,n).reshape(-1) - y1 = y1[None,:].expand(B,n).reshape(-1) - b1 = b1[:,None].expand(B,n).reshape(-1) + n = (H - 2 * self.border) * (W - 2 * self.border) // step**2 + x1 = torch.randint(self.border, W - self.border, (n,), device=dev) + y1 = torch.randint(self.border, H - self.border, (n,), device=dev) + x1 = x1[None, :].expand(B, n).reshape(-1) + y1 = y1[None, :].expand(B, n).reshape(-1) + b1 = b1[:, None].expand(B, n).reshape(-1) shape = (B, n) return b1, y1, x1, shape @@ -323,41 +358,41 @@ class NghSampler2 (nn.Module): assert two == 2 feat1, conf1 = feats[0], (confs[0] if confs else None) feat2, conf2 = 
feats[1], (confs[1] if confs else None) - + # positions in the first image b1, y1, x1, shape = self.gen_grid(self.sub_q, aflow) # sample features from first image feat1 = feat1[b1, :, y1, x1] qconf = conf1[b1, :, y1, x1].view(shape) if confs else None - - #sample GT from second image + + # sample GT from second image b2 = b1 xy2 = (aflow[b1, :, y1, x1] + 0.5).long().t() mask = (0 <= xy2[0]) * (0 <= xy2[1]) * (xy2[0] < W) * (xy2[1] < H) mask = mask.view(shape) - + def clamp(xy): - torch.clamp(xy[0], 0, W-1, out=xy[0]) - torch.clamp(xy[1], 0, H-1, out=xy[1]) + torch.clamp(xy[0], 0, W - 1, out=xy[0]) + torch.clamp(xy[1], 0, H - 1, out=xy[1]) return xy - + # compute positive scores - xy2p = clamp(xy2[:,None,:] + self.pos_offsets[:,:,None]) - pscores = (feat1[None,:,:] * feat2[b2, :, xy2p[1], xy2p[0]]).sum(dim=-1).t() -# xy1p = clamp(torch.stack((x1,y1))[:,None,:] + self.pos_offsets[:,:,None]) -# grid = FullSampler._aflow_to_grid(aflow) -# feat2p = F.grid_sample(feat2, grid, mode='bilinear', padding_mode='border') -# pscores = (feat1[None,:,:] * feat2p[b1,:,xy1p[1], xy1p[0]]).sum(dim=-1).t() + xy2p = clamp(xy2[:, None, :] + self.pos_offsets[:, :, None]) + pscores = (feat1[None, :, :] * feat2[b2, :, xy2p[1], xy2p[0]]).sum(dim=-1).t() + # xy1p = clamp(torch.stack((x1,y1))[:,None,:] + self.pos_offsets[:,:,None]) + # grid = FullSampler._aflow_to_grid(aflow) + # feat2p = F.grid_sample(feat2, grid, mode='bilinear', padding_mode='border') + # pscores = (feat1[None,:,:] * feat2p[b1,:,xy1p[1], xy1p[0]]).sum(dim=-1).t() if self.maxpool_pos: pscores, pos = pscores.max(dim=1, keepdim=True) - if confs: - sel = clamp(xy2 + self.pos_offsets[:,pos.view(-1)]) - qconf = (qconf + conf2[b2, :, sel[1], sel[0]].view(shape))/2 - + if confs: + sel = clamp(xy2 + self.pos_offsets[:, pos.view(-1)]) + qconf = (qconf + conf2[b2, :, sel[1], sel[0]].view(shape)) / 2 + # compute negative scores - xy2n = clamp(xy2[:,None,:] + self.neg_offsets[:,:,None]) - nscores = (feat1[None,:,:] * feat2[b2, :, xy2n[1], xy2n[0]]).sum(dim=-1).t() + xy2n = clamp(xy2[:, None, :] + self.neg_offsets[:, :, None]) + nscores = (feat1[None, :, :] * feat2[b2, :, xy2n[1], xy2n[0]]).sum(dim=-1).t() if self.sub_d_neg: # add distractors from a grid @@ -365,26 +400,18 @@ class NghSampler2 (nn.Module): distractors = feat2[b3, :, y3, x3] dscores = torch.matmul(feat1, distractors.t()) del distractors - + # remove scores that corresponds to positives or nulls - dis2 = (x3 - xy2[0][:,None])**2 + (y3 - xy2[1][:,None])**2 - dis2 += (b3 != b2[:,None]).long() * self.neg_d**2 + dis2 = (x3 - xy2[0][:, None]) ** 2 + (y3 - xy2[1][:, None]) ** 2 + dis2 += (b3 != b2[:, None]).long() * self.neg_d**2 dscores[dis2 < self.neg_d**2] = 0 - + scores = torch.cat((pscores, nscores, dscores), dim=1) else: # concat everything scores = torch.cat((pscores, nscores), dim=1) gt = scores.new_zeros(scores.shape, dtype=torch.uint8) - gt[:, :pscores.shape[1]] = 1 + gt[:, : pscores.shape[1]] = 1 return scores, gt, mask, qconf - - - - - - - - diff --git a/third_party/r2d2/tools/common.py b/third_party/r2d2/tools/common.py index a7875ddd714b1d08efb0d1369c3a856490796288..be5137c60e3fb71cbbf180d0058de20a508ff140 100644 --- a/third_party/r2d2/tools/common.py +++ b/third_party/r2d2/tools/common.py @@ -2,7 +2,7 @@ # CC BY-NC-SA 3.0 # Available only for non-commercial use -import os, pdb#, shutil +import os, pdb # , shutil import numpy as np import torch @@ -12,8 +12,7 @@ def mkdir_for(file_path): def model_size(model): - ''' Computes the number of parameters of the model - ''' + """Computes the 
number of parameters of the model""" size = 0 for weights in model.state_dict().values(): size += np.prod(weights.shape) @@ -24,18 +23,19 @@ def torch_set_gpu(gpus): if type(gpus) is int: gpus = [gpus] - cuda = all(gpu>=0 for gpu in gpus) + cuda = all(gpu >= 0 for gpu in gpus) if cuda: - os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus]) + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in gpus]) assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % ( - os.environ['HOSTNAME'],os.environ['CUDA_VISIBLE_DEVICES']) - torch.backends.cudnn.benchmark = True # speed-up cudnn - torch.backends.cudnn.fastest = True # even more speed-up? - print( 'Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES'] ) + os.environ["HOSTNAME"], + os.environ["CUDA_VISIBLE_DEVICES"], + ) + torch.backends.cudnn.benchmark = True # speed-up cudnn + torch.backends.cudnn.fastest = True # even more speed-up? + print("Launching on GPUs " + os.environ["CUDA_VISIBLE_DEVICES"]) else: - print( 'Launching on CPU' ) + print("Launching on CPU") return cuda - diff --git a/third_party/r2d2/tools/dataloader.py b/third_party/r2d2/tools/dataloader.py index f6d9fff5f8dfb8d9d3b243a57555779de33d0818..a0fc97d8085c1e7c5fc5c14cc4e0818bd343595f 100644 --- a/third_party/r2d2/tools/dataloader.py +++ b/third_party/r2d2/tools/dataloader.py @@ -14,99 +14,113 @@ from tools.transforms_tools import persp_apply RGB_mean = [0.485, 0.456, 0.406] -RGB_std = [0.229, 0.224, 0.225] +RGB_std = [0.229, 0.224, 0.225] norm_RGB = tvf.Compose([tvf.ToTensor(), tvf.Normalize(mean=RGB_mean, std=RGB_std)]) class PairLoader: - """ On-the-fly jittering of pairs of image with dense pixel ground-truth correspondences. - + """On-the-fly jittering of pairs of image with dense pixel ground-truth correspondences. 
+ crop: random crop applied to both images scale: random scaling applied to img2 distort: random ditorsion applied to img2 - + self[idx] returns a dictionary with keys: img1, img2, aflow, mask - img1: cropped original - img2: distorted cropped original - aflow: 'absolute' optical flow = (x,y) position of each pixel from img1 in img2 - mask: (binary image) valid pixels of img1 """ - def __init__(self, dataset, crop='', scale='', distort='', norm = norm_RGB, - what = 'aflow mask', idx_as_rng_seed = False): - assert hasattr(dataset, 'npairs') - assert hasattr(dataset, 'get_pair') + + def __init__( + self, + dataset, + crop="", + scale="", + distort="", + norm=norm_RGB, + what="aflow mask", + idx_as_rng_seed=False, + ): + assert hasattr(dataset, "npairs") + assert hasattr(dataset, "get_pair") self.dataset = dataset self.distort = instanciate_transformation(distort) self.crop = instanciate_transformation(crop) self.norm = instanciate_transformation(norm) self.scale = instanciate_transformation(scale) - self.idx_as_rng_seed = idx_as_rng_seed # to remove randomness + self.idx_as_rng_seed = idx_as_rng_seed # to remove randomness self.what = what.split() if isinstance(what, str) else what - self.n_samples = 5 # number of random trials per image + self.n_samples = 5 # number of random trials per image def __len__(self): - assert len(self.dataset) == self.dataset.npairs, pdb.set_trace() # and not nimg + assert len(self.dataset) == self.dataset.npairs, pdb.set_trace() # and not nimg return len(self.dataset) def __repr__(self): - fmt_str = 'PairLoader\n' + fmt_str = "PairLoader\n" fmt_str += repr(self.dataset) - fmt_str += ' npairs: %d\n' % self.dataset.npairs - short_repr = lambda s: repr(s).strip().replace('\n',', ')[14:-1].replace(' ',' ') - fmt_str += ' Distort: %s\n' % short_repr(self.distort) - fmt_str += ' Crop: %s\n' % short_repr(self.crop) - fmt_str += ' Norm: %s\n' % short_repr(self.norm) + fmt_str += " npairs: %d\n" % self.dataset.npairs + short_repr = ( + lambda s: repr(s).strip().replace("\n", ", ")[14:-1].replace(" ", " ") + ) + fmt_str += " Distort: %s\n" % short_repr(self.distort) + fmt_str += " Crop: %s\n" % short_repr(self.crop) + fmt_str += " Norm: %s\n" % short_repr(self.norm) return fmt_str def __getitem__(self, i): - #from time import time as now; t0 = now() + # from time import time as now; t0 = now() if self.idx_as_rng_seed: import random + random.seed(i) np.random.seed(i) # Retrieve an image pair and their absolute flow img_a, img_b, metadata = self.dataset.get_pair(i, self.what) - - # aflow contains pixel coordinates indicating where each + + # aflow contains pixel coordinates indicating where each # pixel from the left image ended up in the right image # as (x,y) pairs, but its shape is (H,W,2) - aflow = np.float32(metadata['aflow']) - mask = metadata.get('mask', np.ones(aflow.shape[:2],np.uint8)) + aflow = np.float32(metadata["aflow"]) + mask = metadata.get("mask", np.ones(aflow.shape[:2], np.uint8)) # apply transformations to the second image - img_b = {'img': img_b, 'persp':(1,0,0,0,1,0,0,0)} + img_b = {"img": img_b, "persp": (1, 0, 0, 0, 1, 0, 0, 0)} if self.scale: img_b = self.scale(img_b) if self.distort: img_b = self.distort(img_b) - + # apply the same transformation to the flow - aflow[:] = persp_apply(img_b['persp'], aflow.reshape(-1,2)).reshape(aflow.shape) + aflow[:] = persp_apply(img_b["persp"], aflow.reshape(-1, 2)).reshape( + aflow.shape + ) corres = None - if 'corres' in metadata: - corres = np.float32(metadata['corres']) - corres[:,1] = 
persp_apply(img_b['persp'], corres[:,1]) - + if "corres" in metadata: + corres = np.float32(metadata["corres"]) + corres[:, 1] = persp_apply(img_b["persp"], corres[:, 1]) + # apply the same transformation to the homography homography = None - if 'homography' in metadata: - homography = np.float32(metadata['homography']) + if "homography" in metadata: + homography = np.float32(metadata["homography"]) # p_b = homography * p_a - persp = np.float32(img_b['persp']+(1,)).reshape(3,3) + persp = np.float32(img_b["persp"] + (1,)).reshape(3, 3) homography = persp @ homography # determine crop size - img_b = img_b['img'] - crop_size = self.crop({'imsize':(10000,10000)})['imsize'] + img_b = img_b["img"] + crop_size = self.crop({"imsize": (10000, 10000)})["imsize"] output_size_a = min(img_a.size, crop_size) output_size_b = min(img_b.size, crop_size) img_a = np.array(img_a) img_b = np.array(img_b) - ah,aw,p1 = img_a.shape - bh,bw,p2 = img_b.shape + ah, aw, p1 = img_a.shape + bh, bw, p2 = img_b.shape assert p1 == 3 assert p2 == 3 assert aflow.shape == (ah, aw, 2) @@ -114,68 +128,82 @@ class PairLoader: # Let's start by computing the scale of the # optical flow and applying a median filter: - dx = np.gradient(aflow[:,:,0]) - dy = np.gradient(aflow[:,:,1]) - scale = np.sqrt(np.clip(np.abs(dx[1]*dy[0] - dx[0]*dy[1]), 1e-16, 1e16)) + dx = np.gradient(aflow[:, :, 0]) + dy = np.gradient(aflow[:, :, 1]) + scale = np.sqrt(np.clip(np.abs(dx[1] * dy[0] - dx[0] * dy[1]), 1e-16, 1e16)) - accu2 = np.zeros((16,16), bool) + accu2 = np.zeros((16, 16), bool) Q = lambda x, w: np.int32(16 * (x - w.start) / (w.stop - w.start)) - + def window1(x, size, w): l = x - int(0.5 + size / 2) r = l + int(0.5 + size) - if l < 0: l,r = (0, r - l) - if r > w: l,r = (l + w - r, w) - if l < 0: l,r = 0,w # larger than width - return slice(l,r) + if l < 0: + l, r = (0, r - l) + if r > w: + l, r = (l + w - r, w) + if l < 0: + l, r = 0, w # larger than width + return slice(l, r) + def window(cx, cy, win_size, scale, img_shape): - return (window1(cy, win_size[1]*scale, img_shape[0]), - window1(cx, win_size[0]*scale, img_shape[1])) + return ( + window1(cy, win_size[1] * scale, img_shape[0]), + window1(cx, win_size[0] * scale, img_shape[1]), + ) n_valid_pixel = mask.sum() sample_w = mask / (1e-16 + n_valid_pixel) + def sample_valid_pixel(): n = np.random.choice(sample_w.size, p=sample_w.ravel()) y, x = np.unravel_index(n, sample_w.shape) return x, y - + # Find suitable left and right windows - trials = 0 # take the best out of few trials + trials = 0 # take the best out of few trials best = -np.inf, None - for _ in range(50*self.n_samples): - if trials >= self.n_samples: break # finished! + for _ in range(50 * self.n_samples): + if trials >= self.n_samples: + break # finished! 
# pick a random valid point from the first image - if n_valid_pixel == 0: break + if n_valid_pixel == 0: + break c1x, c1y = sample_valid_pixel() - + # Find in which position the center of the left # window ended up being placed in the right image c2x, c2y = (aflow[c1y, c1x] + 0.5).astype(np.int32) - if not(0 <= c2x < bw and 0 <= c2y < bh): continue + if not (0 <= c2x < bw and 0 <= c2y < bh): + continue # Get the flow scale sigma = scale[c1y, c1x] # Determine sampling windows - if 0.2 < sigma < 1: - win1 = window(c1x, c1y, output_size_a, 1/sigma, img_a.shape) + if 0.2 < sigma < 1: + win1 = window(c1x, c1y, output_size_a, 1 / sigma, img_a.shape) win2 = window(c2x, c2y, output_size_b, 1, img_b.shape) elif 1 <= sigma < 5: win1 = window(c1x, c1y, output_size_a, 1, img_a.shape) win2 = window(c2x, c2y, output_size_b, sigma, img_b.shape) else: - continue # bad scale + continue # bad scale # compute a score based on the flow - x2,y2 = aflow[win1].reshape(-1, 2).T.astype(np.int32) + x2, y2 = aflow[win1].reshape(-1, 2).T.astype(np.int32) # Check the proportion of valid flow vectors - valid = (win2[1].start <= x2) & (x2 < win2[1].stop) \ - & (win2[0].start <= y2) & (y2 < win2[0].stop) + valid = ( + (win2[1].start <= x2) + & (x2 < win2[1].stop) + & (win2[0].start <= y2) + & (y2 < win2[0].stop) + ) score1 = (valid * mask[win1].ravel()).mean() # check the coverage of the second window accu2[:] = False - accu2[Q(y2[valid],win2[0]), Q(x2[valid],win2[1])] = True + accu2[Q(y2[valid], win2[0]), Q(x2[valid], win2[1])] = True score2 = accu2.mean() # Check how many hits we got score = min(score1, score2) @@ -183,12 +211,12 @@ class PairLoader: trials += 1 if score > best[0]: best = score, win1, win2 - - if None in best: # counldn't find a good window - img_a = np.zeros(output_size_a[::-1]+(3,), dtype=np.uint8) - img_b = np.zeros(output_size_b[::-1]+(3,), dtype=np.uint8) - aflow = np.nan * np.ones((2,)+output_size_a[::-1], dtype=np.float32) - homography = np.nan * np.ones((3,3), dtype=np.float32) + + if None in best: # counldn't find a good window + img_a = np.zeros(output_size_a[::-1] + (3,), dtype=np.uint8) + img_b = np.zeros(output_size_b[::-1] + (3,), dtype=np.uint8) + aflow = np.nan * np.ones((2,) + output_size_a[::-1], dtype=np.float32) + homography = np.nan * np.ones((3, 3), dtype=np.float32) else: win1, win2 = best[1:] @@ -196,92 +224,103 @@ class PairLoader: img_b = img_b[win2] aflow = aflow[win1] - np.float32([[[win2[1].start, win2[0].start]]]) mask = mask[win1] - aflow[~mask.view(bool)] = np.nan # mask bad pixels! - aflow = aflow.transpose(2,0,1) # --> (2,H,W) - + aflow[~mask.view(bool)] = np.nan # mask bad pixels! 
+ aflow = aflow.transpose(2, 0, 1) # --> (2,H,W) + if corres is not None: - corres[:,0] -= (win1[1].start, win1[0].start) - corres[:,1] -= (win2[1].start, win2[0].start) - + corres[:, 0] -= (win1[1].start, win1[0].start) + corres[:, 1] -= (win2[1].start, win2[0].start) + if homography is not None: trans1 = np.eye(3, dtype=np.float32) - trans1[:2,2] = (win1[1].start, win1[0].start) + trans1[:2, 2] = (win1[1].start, win1[0].start) trans2 = np.eye(3, dtype=np.float32) - trans2[:2,2] = (-win2[1].start, -win2[0].start) + trans2[:2, 2] = (-win2[1].start, -win2[0].start) homography = trans2 @ homography @ trans1 - homography /= homography[2,2] - + homography /= homography[2, 2] + # rescale if necessary if img_a.shape[:2][::-1] != output_size_a: - sx, sy = (np.float32(output_size_a)-1)/(np.float32(img_a.shape[:2][::-1])-1) - img_a = np.asarray(Image.fromarray(img_a).resize(output_size_a, Image.ANTIALIAS)) - mask = np.asarray(Image.fromarray(mask).resize(output_size_a, Image.NEAREST)) + sx, sy = (np.float32(output_size_a) - 1) / ( + np.float32(img_a.shape[:2][::-1]) - 1 + ) + img_a = np.asarray( + Image.fromarray(img_a).resize(output_size_a, Image.ANTIALIAS) + ) + mask = np.asarray( + Image.fromarray(mask).resize(output_size_a, Image.NEAREST) + ) afx = Image.fromarray(aflow[0]).resize(output_size_a, Image.NEAREST) afy = Image.fromarray(aflow[1]).resize(output_size_a, Image.NEAREST) aflow = np.stack((np.float32(afx), np.float32(afy))) - + if corres is not None: - corres[:,0] *= (sx, sy) - + corres[:, 0] *= (sx, sy) + if homography is not None: - homography = homography @ np.diag(np.float32([1/sx,1/sy,1])) - homography /= homography[2,2] + homography = homography @ np.diag(np.float32([1 / sx, 1 / sy, 1])) + homography /= homography[2, 2] if img_b.shape[:2][::-1] != output_size_b: - sx, sy = (np.float32(output_size_b)-1)/(np.float32(img_b.shape[:2][::-1])-1) - img_b = np.asarray(Image.fromarray(img_b).resize(output_size_b, Image.ANTIALIAS)) + sx, sy = (np.float32(output_size_b) - 1) / ( + np.float32(img_b.shape[:2][::-1]) - 1 + ) + img_b = np.asarray( + Image.fromarray(img_b).resize(output_size_b, Image.ANTIALIAS) + ) aflow *= [[[sx]], [[sy]]] - + if corres is not None: - corres[:,1] *= (sx, sy) - + corres[:, 1] *= (sx, sy) + if homography is not None: - homography = np.diag(np.float32([sx,sy,1])) @ homography - homography /= homography[2,2] - + homography = np.diag(np.float32([sx, sy, 1])) @ homography + homography /= homography[2, 2] + assert aflow.dtype == np.float32, pdb.set_trace() assert homography is None or homography.dtype == np.float32, pdb.set_trace() - if 'flow' in self.what: + if "flow" in self.what: H, W = img_a.shape[:2] mgrid = np.mgrid[0:H, 0:W][::-1].astype(np.float32) flow = aflow - mgrid - + result = dict(img1=self.norm(img_a), img2=self.norm(img_b)) for what in self.what: - try: result[what] = eval(what) - except NameError: pass + try: + result[what] = eval(what) + except NameError: + pass return result +def threaded_loader(loader, iscuda, threads, batch_size=1, shuffle=True): + """Get a data loader, given the dataset and some parameters. -def threaded_loader( loader, iscuda, threads, batch_size=1, shuffle=True): - """ Get a data loader, given the dataset and some parameters. - Parameters ---------- loader : object[i] returns the i-th training example. - + iscuda : bool - + batch_size : int - + threads : int - + shuffle : int - + Returns ------- a multi-threaded pytorch loader. 
""" return torch.utils.data.DataLoader( loader, - batch_size = batch_size, - shuffle = shuffle, - sampler = None, - num_workers = threads, - pin_memory = iscuda, - collate_fn=collate) - + batch_size=batch_size, + shuffle=shuffle, + sampler=None, + num_workers=threads, + pin_memory=iscuda, + collate_fn=collate, + ) def collate(batch, _use_shared_memory=True): @@ -289,6 +328,7 @@ def collate(batch, _use_shared_memory=True): Copied from https://github.com/pytorch in torch/utils/data/_utils/collate.py """ import re + error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" elem_type = type(batch[0]) if isinstance(batch[0], torch.Tensor): @@ -300,12 +340,15 @@ def collate(batch, _use_shared_memory=True): storage = batch[0].storage()._new_shared(numel) out = batch[0].new(storage) return torch.stack(batch, 0, out=out) - elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ - and elem_type.__name__ != 'string_': + elif ( + elem_type.__module__ == "numpy" + and elem_type.__name__ != "str_" + and elem_type.__name__ != "string_" + ): elem = batch[0] - assert elem_type.__name__ == 'ndarray' + assert elem_type.__name__ == "ndarray" # array of string classes and object - if re.search('[SaUO]', elem.dtype.str) is not None: + if re.search("[SaUO]", elem.dtype.str) is not None: raise TypeError(error_msg.format(elem.dtype)) batch = [torch.from_numpy(b) for b in batch] try: @@ -322,46 +365,52 @@ def collate(batch, _use_shared_memory=True): return batch elif isinstance(batch[0], dict): return {key: collate([d[key] for d in batch]) for key in batch[0]} - elif isinstance(batch[0], (tuple,list)): + elif isinstance(batch[0], (tuple, list)): transposed = zip(*batch) return [collate(samples) for samples in transposed] raise TypeError((error_msg.format(type(batch[0])))) - def tensor2img(tensor, model=None): - """ convert back a torch/numpy tensor to a PIL Image - by undoing the ToTensor() and Normalize() transforms. + """convert back a torch/numpy tensor to a PIL Image + by undoing the ToTensor() and Normalize() transforms. 
""" mean = norm_RGB.transforms[1].mean - std = norm_RGB.transforms[1].std + std = norm_RGB.transforms[1].std if isinstance(tensor, torch.Tensor): tensor = tensor.detach().cpu().numpy() - - res = np.uint8(np.clip(255*((tensor.transpose(1,2,0) * std) + mean), 0, 255)) + + res = np.uint8(np.clip(255 * ((tensor.transpose(1, 2, 0) * std) + mean), 0, 255)) from PIL import Image + return Image.fromarray(res) -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser("Tool to debug/visualize the data loader") - parser.add_argument("dataloader", type=str, help="command to create the data loader") + parser.add_argument( + "dataloader", type=str, help="command to create the data loader" + ) args = parser.parse_args() from datasets import * - auto_pairs = lambda db: SyntheticPairDataset(db, - 'RandomScale(256,1024,can_upscale=True)', - 'RandomTilting(0.5), PixelNoise(25)') - + + auto_pairs = lambda db: SyntheticPairDataset( + db, + "RandomScale(256,1024,can_upscale=True)", + "RandomTilting(0.5), PixelNoise(25)", + ) + loader = eval(args.dataloader) print("Data loader =", loader) from tools.viz import show_flow + for data in loader: - aflow = data['aflow'] + aflow = data["aflow"] H, W = aflow.shape[-2:] - flow = (aflow - np.mgrid[:H, :W][::-1]).transpose(1,2,0) - show_flow(tensor2img(data['img1']), tensor2img(data['img2']), flow) - + flow = (aflow - np.mgrid[:H, :W][::-1]).transpose(1, 2, 0) + show_flow(tensor2img(data["img1"]), tensor2img(data["img2"]), flow) diff --git a/third_party/r2d2/tools/trainer.py b/third_party/r2d2/tools/trainer.py index 9f893395efdeb8e13cc00539325572553168c5ce..d71ef137f556b7709ebed37a6ea4c865e5ab6c37 100644 --- a/third_party/r2d2/tools/trainer.py +++ b/third_party/r2d2/tools/trainer.py @@ -10,15 +10,16 @@ import torch import torch.nn as nn -class Trainer (nn.Module): - """ Helper class to train a deep network. +class Trainer(nn.Module): + """Helper class to train a deep network. Overload this class `forward_backward` for your actual needs. 
- - Usage: + + Usage: train = Trainer(net, loader, loss, optimizer) for epoch in range(n_epochs): train() """ + def __init__(self, net, loader, loss, optimizer): nn.Module.__init__(self) self.net = net @@ -27,50 +28,48 @@ class Trainer (nn.Module): self.optimizer = optimizer def iscuda(self): - return next(self.net.parameters()).device != torch.device('cpu') + return next(self.net.parameters()).device != torch.device("cpu") def todevice(self, x): if isinstance(x, dict): - return {k:self.todevice(v) for k,v in x.items()} - if isinstance(x, (tuple,list)): - return [self.todevice(v) for v in x] - - if self.iscuda(): + return {k: self.todevice(v) for k, v in x.items()} + if isinstance(x, (tuple, list)): + return [self.todevice(v) for v in x] + + if self.iscuda(): return x.contiguous().cuda(non_blocking=True) else: return x.cpu() def __call__(self): self.net.train() - + stats = defaultdict(list) - - for iter,inputs in enumerate(tqdm(self.loader)): + + for iter, inputs in enumerate(tqdm(self.loader)): inputs = self.todevice(inputs) - + # compute gradient and do model update self.optimizer.zero_grad() - + loss, details = self.forward_backward(inputs) if torch.isnan(loss): - raise RuntimeError('Loss is NaN') - + raise RuntimeError("Loss is NaN") + self.optimizer.step() - + for key, val in details.items(): - stats[key].append( val ) - + stats[key].append(val) + print(" Summary of losses during this epoch:") mean = lambda lis: sum(lis) / len(lis) for loss_name, vals in stats.items(): - N = 1 + len(vals)//10 - print(f" - {loss_name:20}:", end='') - print(f" {mean(vals[:N]):.3f} --> {mean(vals[-N:]):.3f} (avg: {mean(vals):.3f})") - return mean(stats['loss']) # return average loss + N = 1 + len(vals) // 10 + print(f" - {loss_name:20}:", end="") + print( + f" {mean(vals[:N]):.3f} --> {mean(vals[-N:]):.3f} (avg: {mean(vals):.3f})" + ) + return mean(stats["loss"]) # return average loss def forward_backward(self, inputs): raise NotImplementedError() - - - - diff --git a/third_party/r2d2/tools/transforms.py b/third_party/r2d2/tools/transforms.py index 87275276310191a7da3fc14f606345d9616208e0..604a7c2a3ec6da955c1e85b7505103c694232458 100644 --- a/third_party/r2d2/tools/transforms.py +++ b/third_party/r2d2/tools/transforms.py @@ -11,23 +11,23 @@ from math import ceil from . import transforms_tools as F -''' +""" Example command to try out some transformation chain: python -m tools.transforms --trfs "Scale(384), ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1), RandomRotation(10), RandomTilting(0.5, 'all'), RandomScale(240,320), RandomCrop(224)" -''' +""" def instanciate_transformation(cmd_line): - ''' Create a sequence of transformations. - + """Create a sequence of transformations. + cmd_line: (str) Comma-separated list of transformations. Ex: "Rotate(10), Scale(256)" - ''' + """ if not isinstance(cmd_line, str): - return cmd_line # already instanciated - + return cmd_line # already instanciated + cmd_line = "tvf.Compose([%s])" % cmd_line try: return eval(cmd_line) @@ -35,19 +35,26 @@ def instanciate_transformation(cmd_line): print("Cannot interpret this transform list: %s\nReason: %s" % (cmd_line, e)) -class Scale (object): - """ Rescale the input PIL.Image to a given size. +class Scale(object): + """Rescale the input PIL.Image to a given size. Copied from https://github.com/pytorch in torchvision/transforms/transforms.py - + The smallest dimension of the resulting image will be = size. - + if largest == True: same behaviour for the largest dimension. 
- + if not can_upscale: don't upscale if not can_downscale: don't downscale """ - def __init__(self, size, interpolation=Image.BILINEAR, largest=False, - can_upscale=True, can_downscale=True): + + def __init__( + self, + size, + interpolation=Image.BILINEAR, + largest=False, + can_upscale=True, + can_downscale=True, + ): assert isinstance(size, int) or (len(size) == 2) self.size = size self.interpolation = interpolation @@ -57,15 +64,18 @@ class Scale (object): def __repr__(self): fmt_str = "RandomScale(%s" % str(self.size) - if self.largest: fmt_str += ', largest=True' - if not self.can_upscale: fmt_str += ', can_upscale=False' - if not self.can_downscale: fmt_str += ', can_downscale=False' - return fmt_str+')' + if self.largest: + fmt_str += ", largest=True" + if not self.can_upscale: + fmt_str += ", can_upscale=False" + if not self.can_downscale: + fmt_str += ", can_downscale=False" + return fmt_str + ")" def get_params(self, imsize): - w,h = imsize + w, h = imsize if isinstance(self.size, int): - cmp = lambda a,b: (a>=b) if self.largest else (a<=b) + cmp = lambda a, b: (a >= b) if self.largest else (a <= b) if (cmp(w, h) and w == self.size) or (cmp(h, w) and h == self.size): ow, oh = w, h elif cmp(w, h): @@ -81,19 +91,22 @@ class Scale (object): def __call__(self, inp): img = F.grab_img(inp) w, h = img.size - + size2 = ow, oh = self.get_params(img.size) - + if size2 != img.size: a1, a2 = img.size, size2 - if (self.can_upscale and min(a1) < min(a2)) or (self.can_downscale and min(a1) > min(a2)): + if (self.can_upscale and min(a1) < min(a2)) or ( + self.can_downscale and min(a1) > min(a2) + ): img = img.resize(size2, self.interpolation) - return F.update_img_and_labels(inp, img, persp=(ow/w,0,0,0,oh/h,0,0,0)) - + return F.update_img_and_labels( + inp, img, persp=(ow / w, 0, 0, 0, oh / h, 0, 0, 0) + ) -class RandomScale (Scale): +class RandomScale(Scale): """Rescale the input PIL.Image to a random size. 
Copied from https://github.com/pytorch in torchvision/transforms/transforms.py @@ -108,53 +121,79 @@ class RandomScale (Scale): ``PIL.Image.BILINEAR`` """ - def __init__(self, min_size, max_size, ar=1, - can_upscale=False, can_downscale=True, interpolation=Image.BILINEAR): - Scale.__init__(self, 0, can_upscale=can_upscale, can_downscale=can_downscale, interpolation=interpolation) - assert type(min_size) == type(max_size), 'min_size and max_size can only be 2 ints or 2 floats' - assert isinstance(min_size, int) and min_size >= 1 or isinstance(min_size, float) and min_size>0 - assert isinstance(max_size, (int,float)) and min_size <= max_size + def __init__( + self, + min_size, + max_size, + ar=1, + can_upscale=False, + can_downscale=True, + interpolation=Image.BILINEAR, + ): + Scale.__init__( + self, + 0, + can_upscale=can_upscale, + can_downscale=can_downscale, + interpolation=interpolation, + ) + assert type(min_size) == type( + max_size + ), "min_size and max_size can only be 2 ints or 2 floats" + assert ( + isinstance(min_size, int) + and min_size >= 1 + or isinstance(min_size, float) + and min_size > 0 + ) + assert isinstance(max_size, (int, float)) and min_size <= max_size self.min_size = min_size self.max_size = max_size - if type(ar) in (float,int): ar = (min(1/ar,ar),max(1/ar,ar)) + if type(ar) in (float, int): + ar = (min(1 / ar, ar), max(1 / ar, ar)) assert 0.2 < ar[0] <= ar[1] < 5 self.ar = ar def get_params(self, imsize): - w,h = imsize + w, h = imsize if isinstance(self.min_size, float): - min_size = int(self.min_size*min(w,h) + 0.5) + min_size = int(self.min_size * min(w, h) + 0.5) if isinstance(self.max_size, float): - max_size = int(self.max_size*min(w,h) + 0.5) + max_size = int(self.max_size * min(w, h) + 0.5) if isinstance(self.min_size, int): min_size = self.min_size if isinstance(self.max_size, int): max_size = self.max_size - + if not self.can_upscale: - max_size = min(max_size,min(w,h)) - - size = int(0.5 + F.rand_log_uniform(min_size,max_size)) - ar = F.rand_log_uniform(*self.ar) # change of aspect ratio + max_size = min(max_size, min(w, h)) + + size = int(0.5 + F.rand_log_uniform(min_size, max_size)) + ar = F.rand_log_uniform(*self.ar) # change of aspect ratio - if w < h: # image is taller + if w < h: # image is taller ow = size oh = int(0.5 + size * h / w / ar) if oh < min_size: - ow,oh = int(0.5 + ow*float(min_size)/oh),min_size - else: # image is wider + ow, oh = int(0.5 + ow * float(min_size) / oh), min_size + else: # image is wider oh = size ow = int(0.5 + size * w / h * ar) if ow < min_size: - ow,oh = min_size,int(0.5 + oh*float(min_size)/ow) - - assert ow >= min_size, 'image too small (width=%d < min_size=%d)' % (ow, min_size) - assert oh >= min_size, 'image too small (height=%d < min_size=%d)' % (oh, min_size) + ow, oh = min_size, int(0.5 + oh * float(min_size) / ow) + + assert ow >= min_size, "image too small (width=%d < min_size=%d)" % ( + ow, + min_size, + ) + assert oh >= min_size, "image too small (height=%d < min_size=%d)" % ( + oh, + min_size, + ) return ow, oh - -class RandomCrop (object): +class RandomCrop(object): """Crop the given PIL Image at a random location. 
Copied from https://github.com/pytorch in torchvision/transforms/transforms.py @@ -182,7 +221,12 @@ class RandomCrop (object): def get_params(img, output_size): w, h = img.size th, tw = output_size - assert h >= th and w >= tw, "Image of %dx%d is too small for crop %dx%d" % (w,h,tw,th) + assert h >= th and w >= tw, "Image of %dx%d is too small for crop %dx%d" % ( + w, + h, + tw, + th, + ) y = np.random.randint(0, h - th) if h > th else 0 x = np.random.randint(0, w - tw) if w > tw else 0 @@ -204,12 +248,14 @@ class RandomCrop (object): padl, padt = self.padding[0:2] i, j, tw, th = self.get_params(img, self.size) - img = img.crop((i, j, i+tw, j+th)) - - return F.update_img_and_labels(inp, img, persp=(1,0,padl-i,0,1,padt-j,0,0)) + img = img.crop((i, j, i + tw, j + th)) + return F.update_img_and_labels( + inp, img, persp=(1, 0, padl - i, 0, 1, padt - j, 0, 0) + ) -class CenterCrop (RandomCrop): + +class CenterCrop(RandomCrop): """Crops the given PIL Image at the center. Copied from https://github.com/pytorch in torchvision/transforms/transforms.py @@ -218,16 +264,16 @@ class CenterCrop (RandomCrop): int instead of sequence like (h, w), a square crop (size, size) is made. """ + @staticmethod def get_params(img, output_size): w, h = img.size th, tw = output_size - y = int(0.5 +((h - th) / 2.)) - x = int(0.5 +((w - tw) / 2.)) + y = int(0.5 + ((h - th) / 2.0)) + x = int(0.5 + ((w - tw) / 2.0)) return x, y, tw, th - class RandomRotation(object): """Rescale the input PIL.Image to a random size. Copied from https://github.com/pytorch in torchvision/transforms/transforms.py @@ -247,19 +293,18 @@ class RandomRotation(object): def __call__(self, inp): img = F.grab_img(inp) w, h = img.size - + angle = np.random.uniform(-self.degrees, self.degrees) - + img = img.rotate(angle, resample=self.interpolation) w2, h2 = img.size - trf = F.translate(-w/2,-h/2) - trf = F.persp_mul(trf, F.rotate(-angle * np.pi/180)) - trf = F.persp_mul(trf, F.translate(w2/2,h2/2)) + trf = F.translate(-w / 2, -h / 2) + trf = F.persp_mul(trf, F.rotate(-angle * np.pi / 180)) + trf = F.persp_mul(trf, F.translate(w2 / 2, h2 / 2)) return F.update_img_and_labels(inp, img, persp=trf) - class RandomTilting(object): """Apply a random tilting (left, right, up, down) to the input PIL.Image Copied from https://github.com/pytorch in torchvision/transforms/transforms.py @@ -272,34 +317,34 @@ class RandomTilting(object): examples: "all", "left,right", "up-down-right" """ - def __init__(self, magnitude, directions='all'): + def __init__(self, magnitude, directions="all"): self.magnitude = magnitude - self.directions = directions.lower().replace(',',' ').replace('-',' ') + self.directions = directions.lower().replace(",", " ").replace("-", " ") def __repr__(self): - return "RandomTilt(%g, '%s')" % (self.magnitude,self.directions) + return "RandomTilt(%g, '%s')" % (self.magnitude, self.directions) def __call__(self, inp): img = F.grab_img(inp) w, h = img.size - x1,y1,x2,y2 = 0,0,h,w + x1, y1, x2, y2 = 0, 0, h, w original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)] max_skew_amount = max(w, h) max_skew_amount = int(ceil(max_skew_amount * self.magnitude)) skew_amount = random.randint(1, max_skew_amount) - if self.directions == 'all': - choices = [0,1,2,3] + if self.directions == "all": + choices = [0, 1, 2, 3] else: - dirs = ['left', 'right', 'up', 'down'] + dirs = ["left", "right", "up", "down"] choices = [] for d in self.directions.split(): try: choices.append(dirs.index(d)) except: - raise ValueError('Tilting direction %s not recognized' % d) + 
raise ValueError("Tilting direction %s not recognized" % d) skew_direction = random.choice(choices) @@ -307,28 +352,36 @@ class RandomTilting(object): if skew_direction == 0: # Left Tilt - new_plane = [(y1, x1 - skew_amount), # Top Left - (y2, x1), # Top Right - (y2, x2), # Bottom Right - (y1, x2 + skew_amount)] # Bottom Left + new_plane = [ + (y1, x1 - skew_amount), # Top Left + (y2, x1), # Top Right + (y2, x2), # Bottom Right + (y1, x2 + skew_amount), + ] # Bottom Left elif skew_direction == 1: # Right Tilt - new_plane = [(y1, x1), # Top Left - (y2, x1 - skew_amount), # Top Right - (y2, x2 + skew_amount), # Bottom Right - (y1, x2)] # Bottom Left + new_plane = [ + (y1, x1), # Top Left + (y2, x1 - skew_amount), # Top Right + (y2, x2 + skew_amount), # Bottom Right + (y1, x2), + ] # Bottom Left elif skew_direction == 2: # Forward Tilt - new_plane = [(y1 - skew_amount, x1), # Top Left - (y2 + skew_amount, x1), # Top Right - (y2, x2), # Bottom Right - (y1, x2)] # Bottom Left + new_plane = [ + (y1 - skew_amount, x1), # Top Left + (y2 + skew_amount, x1), # Top Right + (y2, x2), # Bottom Right + (y1, x2), + ] # Bottom Left elif skew_direction == 3: # Backward Tilt - new_plane = [(y1, x1), # Top Left - (y2, x1), # Top Right - (y2 + skew_amount, x2), # Bottom Right - (y1 - skew_amount, x2)] # Bottom Left + new_plane = [ + (y1, x1), # Top Left + (y2, x1), # Top Right + (y2 + skew_amount, x2), # Bottom Right + (y1 - skew_amount, x2), + ] # Bottom Left # To calculate the coefficients required by PIL for the perspective skew, # see the following Stack Overflow discussion: https://goo.gl/sSgJdj @@ -343,42 +396,49 @@ class RandomTilting(object): homography = np.dot(np.linalg.pinv(A), B) homography = tuple(np.array(homography).reshape(8)) - #print(homography) + # print(homography) - img = img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC) + img = img.transform( + img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC + ) - homography = np.linalg.pinv(np.float32(homography+(1,)).reshape(3,3)).ravel()[:8] + homography = np.linalg.pinv( + np.float32(homography + (1,)).reshape(3, 3) + ).ravel()[:8] return F.update_img_and_labels(inp, img, persp=tuple(homography)) -RandomTilt = RandomTilting # redefinition +RandomTilt = RandomTilting # redefinition class Tilt(object): - """Apply a known tilting to an image - """ + """Apply a known tilting to an image""" + def __init__(self, *homography): assert len(homography) == 8 self.homography = homography - + def __call__(self, inp): img = F.grab_img(inp) homography = self.homography - #print(homography) - - img = img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC) - - homography = np.linalg.pinv(np.float32(homography+(1,)).reshape(3,3)).ravel()[:8] + # print(homography) + + img = img.transform( + img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC + ) + + homography = np.linalg.pinv( + np.float32(homography + (1,)).reshape(3, 3) + ).ravel()[:8] return F.update_img_and_labels(inp, img, persp=tuple(homography)) +class StillTransform(object): + """Takes and return an image, without changing its shape or geometry.""" -class StillTransform (object): - """ Takes and return an image, without changing its shape or geometry. 
- """ def _transform(self, img): raise NotImplementedError() - + def __call__(self, inp): img = F.grab_img(inp) @@ -388,13 +448,12 @@ class StillTransform (object): except TypeError: pass - return F.update_img_and_labels(inp, img, persp=(1,0,0,0,1,0,0,0)) + return F.update_img_and_labels(inp, img, persp=(1, 0, 0, 0, 1, 0, 0, 0)) +class PixelNoise(StillTransform): + """Takes an image, and add random white noise.""" -class PixelNoise (StillTransform): - """ Takes an image, and add random white noise. - """ def __init__(self, ampl=20): StillTransform.__init__(self) assert 0 <= ampl < 255 @@ -405,12 +464,13 @@ class PixelNoise (StillTransform): def _transform(self, img): img = np.float32(img) - img += np.random.uniform(0.5-self.ampl/2, 0.5+self.ampl/2, size=img.shape) - return Image.fromarray(np.uint8(img.clip(0,255))) - + img += np.random.uniform( + 0.5 - self.ampl / 2, 0.5 + self.ampl / 2, size=img.shape + ) + return Image.fromarray(np.uint8(img.clip(0, 255))) -class ColorJitter (StillTransform): +class ColorJitter(StillTransform): """Randomly change the brightness, contrast and saturation of an image. Copied from https://github.com/pytorch in torchvision/transforms/transforms.py @@ -424,6 +484,7 @@ class ColorJitter (StillTransform): hue(float): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue]. Should be >=0 and <= 0.5. """ + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): self.brightness = brightness self.contrast = contrast @@ -432,8 +493,12 @@ class ColorJitter (StillTransform): def __repr__(self): return "ColorJitter(%g,%g,%g,%g)" % ( - self.brightness, self.contrast, self.saturation, self.hue) - + self.brightness, + self.contrast, + self.saturation, + self.hue, + ) + @staticmethod def get_params(brightness, contrast, saturation, hue): """Get a randomized transform to be applied on image. 
@@ -444,16 +509,26 @@ class ColorJitter (StillTransform): """ transforms = [] if brightness > 0: - brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) - transforms.append(tvf.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) + brightness_factor = np.random.uniform( + max(0, 1 - brightness), 1 + brightness + ) + transforms.append( + tvf.Lambda(lambda img: F.adjust_brightness(img, brightness_factor)) + ) if contrast > 0: contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) - transforms.append(tvf.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) + transforms.append( + tvf.Lambda(lambda img: F.adjust_contrast(img, contrast_factor)) + ) if saturation > 0: - saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) - transforms.append(tvf.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) + saturation_factor = np.random.uniform( + max(0, 1 - saturation), 1 + saturation + ) + transforms.append( + tvf.Lambda(lambda img: F.adjust_saturation(img, saturation_factor)) + ) if hue > 0: hue_factor = np.random.uniform(-hue, hue) @@ -467,47 +542,52 @@ class ColorJitter (StillTransform): return transform def _transform(self, img): - transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + transform = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) return transform(img) - -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser("Script to try out and visualize transformations") - parser.add_argument('--img', type=str, default='imgs/test.png', help='input image') - parser.add_argument('--trfs', type=str, required=True, help='list of transformations') - parser.add_argument('--layout', type=int, nargs=2, default=(3,3), help='nb of rows,cols') + parser.add_argument("--img", type=str, default="imgs/test.png", help="input image") + parser.add_argument( + "--trfs", type=str, required=True, help="list of transformations" + ) + parser.add_argument( + "--layout", type=int, nargs=2, default=(3, 3), help="nb of rows,cols" + ) args = parser.parse_args() - + import os - args.img = args.img.replace('$HERE',os.path.dirname(__file__)) + + args.img = args.img.replace("$HERE", os.path.dirname(__file__)) img = Image.open(args.img) img = dict(img=img) - + trfs = instanciate_transformation(args.trfs) - + from matplotlib import pyplot as pl + pl.ion() - pl.subplots_adjust(0,0,1,1) - - nr,nc = args.layout - + pl.subplots_adjust(0, 0, 1, 1) + + nr, nc = args.layout + while True: for j in range(nr): for i in range(nc): - pl.subplot(nr,nc,i+j*nc+1) - if i==j==0: + pl.subplot(nr, nc, i + j * nc + 1) + if i == j == 0: img2 = img else: img2 = trfs(img.copy()) if isinstance(img2, dict): - img2 = img2['img'] + img2 = img2["img"] pl.imshow(img2) pl.xlabel("%d x %d" % img2.size) pl.xticks(()) pl.yticks(()) pdb.set_trace() - - - diff --git a/third_party/r2d2/tools/transforms_tools.py b/third_party/r2d2/tools/transforms_tools.py index 294c22228a88f70480af52f79a77d73f9e5b3e1a..77eb1da2306116d789cdcf6b957a6c144a746a4f 100644 --- a/third_party/r2d2/tools/transforms_tools.py +++ b/third_party/r2d2/tools/transforms_tools.py @@ -8,31 +8,31 @@ from PIL import Image, ImageOps, ImageEnhance class DummyImg: - ''' This class is a dummy image only defined by its size. 
- ''' + """This class is a dummy image only defined by its size.""" + def __init__(self, size): self.size = size - + def resize(self, size, *args, **kwargs): return DummyImg(size) - + def expand(self, border): w, h = self.size if isinstance(border, int): - size = (w+2*border, h+2*border) + size = (w + 2 * border, h + 2 * border) else: - l,t,r,b = border - size = (w+l+r, h+t+b) + l, t, r, b = border + size = (w + l + r, h + t + b) return DummyImg(size) def crop(self, border): w, h = self.size - l,t,r,b = border + l, t, r, b = border assert 0 <= l <= r <= w assert 0 <= t <= b <= h - size = (r-l, b-t) + size = (r - l, b - t) return DummyImg(size) - + def rotate(self, angle): raise NotImplementedError @@ -40,89 +40,85 @@ class DummyImg: return DummyImg(size) -def grab_img( img_and_label ): - ''' Called to extract the image from an img_and_label input +def grab_img(img_and_label): + """Called to extract the image from an img_and_label input (a dictionary). Also compatible with old-style PIL images. - ''' + """ if isinstance(img_and_label, dict): # if input is a dictionary, then # it must contains the img or its size. try: - return img_and_label['img'] + return img_and_label["img"] except KeyError: - return DummyImg(img_and_label['imsize']) - + return DummyImg(img_and_label["imsize"]) + else: # or it must be the img directly return img_and_label def update_img_and_labels(img_and_label, img, persp=None): - ''' Called to update the img_and_label - ''' + """Called to update the img_and_label""" if isinstance(img_and_label, dict): - img_and_label['img'] = img - img_and_label['imsize'] = img.size + img_and_label["img"] = img + img_and_label["imsize"] = img.size if persp: - if 'persp' not in img_and_label: - img_and_label['persp'] = (1,0,0,0,1,0,0,0) - img_and_label['persp'] = persp_mul(persp, img_and_label['persp']) - + if "persp" not in img_and_label: + img_and_label["persp"] = (1, 0, 0, 0, 1, 0, 0, 0) + img_and_label["persp"] = persp_mul(persp, img_and_label["persp"]) + return img_and_label - + else: # or it must be the img directly return img def rand_log_uniform(a, b): - return np.exp(np.random.uniform(np.log(a),np.log(b))) + return np.exp(np.random.uniform(np.log(a), np.log(b))) def translate(tx, ty): - return (1,0,tx, - 0,1,ty, - 0,0) + return (1, 0, tx, 0, 1, ty, 0, 0) + def rotate(angle): - return (np.cos(angle),-np.sin(angle), 0, - np.sin(angle), np.cos(angle), 0, - 0, 0) + return (np.cos(angle), -np.sin(angle), 0, np.sin(angle), np.cos(angle), 0, 0, 0) def persp_mul(mat, mat2): - ''' homography (perspective) multiplication. + """homography (perspective) multiplication. mat: 8-tuple (homography transform) mat2: 8-tuple (homography transform) or 2-tuple (point) - ''' + """ assert isinstance(mat, tuple) assert isinstance(mat2, tuple) - mat = np.float32(mat+(1,)).reshape(3,3) - mat2 = np.array(mat2+(1,)).reshape(3,3) + mat = np.float32(mat + (1,)).reshape(3, 3) + mat2 = np.array(mat2 + (1,)).reshape(3, 3) res = np.dot(mat, mat2) - return tuple((res/res[2,2]).ravel()[:8]) + return tuple((res / res[2, 2]).ravel()[:8]) def persp_apply(mat, pts): - ''' homography (perspective) transformation. + """homography (perspective) transformation. 
mat: 8-tuple (homography transform) pts: numpy array - ''' + """ assert isinstance(mat, tuple) assert isinstance(pts, np.ndarray) assert pts.shape[-1] == 2 - mat = np.float32(mat+(1,)).reshape(3,3) + mat = np.float32(mat + (1,)).reshape(3, 3) if pts.ndim == 1: - pt = np.dot(pts, mat[:,:2].T).ravel() + mat[:,2] - pt /= pt[2] # homogeneous coordinates + pt = np.dot(pts, mat[:, :2].T).ravel() + mat[:, 2] + pt /= pt[2] # homogeneous coordinates return tuple(pt[:2]) else: - pt = np.dot(pts, mat[:,:2].T) + mat[:,2] - pt[:,:2] /= pt[:,2:3] # homogeneous coordinates - return pt[:,:2] + pt = np.dot(pts, mat[:, :2].T) + mat[:, 2] + pt[:, :2] /= pt[:, 2:3] # homogeneous coordinates + return pt[:, :2] def is_pil_image(img): @@ -141,7 +137,7 @@ def adjust_brightness(img, brightness_factor): Copied from https://github.com/pytorch in torchvision/transforms/functional.py """ if not is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Brightness(img) img = enhancer.enhance(brightness_factor) @@ -160,7 +156,7 @@ def adjust_contrast(img, contrast_factor): Copied from https://github.com/pytorch in torchvision/transforms/functional.py """ if not is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Contrast(img) img = enhancer.enhance(contrast_factor) @@ -179,7 +175,7 @@ def adjust_saturation(img, saturation_factor): Copied from https://github.com/pytorch in torchvision/transforms/functional.py """ if not is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Color(img) img = enhancer.enhance(saturation_factor) @@ -205,26 +201,23 @@ def adjust_hue(img, hue_factor): PIL Image: Hue adjusted image. Copied from https://github.com/pytorch in torchvision/transforms/functional.py """ - if not(-0.5 <= hue_factor <= 0.5): - raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError("hue_factor is not in [-0.5, 0.5].".format(hue_factor)) if not is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) input_mode = img.mode - if input_mode in {'L', '1', 'I', 'F'}: + if input_mode in {"L", "1", "I", "F"}: return img - h, s, v = img.convert('HSV').split() + h, s, v = img.convert("HSV").split() np_h = np.array(h, dtype=np.uint8) # uint8 addition take cares of rotation across boundaries - with np.errstate(over='ignore'): + with np.errstate(over="ignore"): np_h += np.uint8(hue_factor * 255) - h = Image.fromarray(np_h, 'L') + h = Image.fromarray(np_h, "L") - img = Image.merge('HSV', (h, s, v)).convert(input_mode) + img = Image.merge("HSV", (h, s, v)).convert(input_mode) return img - - - diff --git a/third_party/r2d2/tools/viz.py b/third_party/r2d2/tools/viz.py index c86103f3aeb468fca8b0ac9a412f22b85239361b..4cf4b90a670ee448d9d6d1ba4137abae32def005 100644 --- a/third_party/r2d2/tools/viz.py +++ b/third_party/r2d2/tools/viz.py @@ -8,16 +8,16 @@ import matplotlib.pyplot as pl def make_colorwheel(): - ''' + """ Generates a color wheel for optical flow visualization as presented in: Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf According to the C++ source code of Daniel Scharstein According to the Matlab source code of Deqing Sun - + Copied from https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/flow_vis.py Copyright (c) 2018 Tom Runia - ''' + """ RY = 15 YG = 6 @@ -32,32 +32,32 @@ def make_colorwheel(): # RY colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) - col = col+RY + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY # YG - colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) - colorwheel[col:col+YG, 1] = 255 - col = col+YG + colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG # GC - colorwheel[col:col+GC, 1] = 255 - colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) - col = col+GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC # CB - colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) - colorwheel[col:col+CB, 2] = 255 - col = col+CB + colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB # BM - colorwheel[col:col+BM, 2] = 255 - colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) - col = col+BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM # MR - colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) - colorwheel[col:col+MR, 0] = 255 + colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 return colorwheel def flow_compute_color(u, v, convert_to_bgr=False): - ''' + """ Applies the flow color wheel to (possibly clipped) flow components u and v. According to the C++ source code of Daniel Scharstein According to the Matlab source code of Deqing Sun @@ -65,10 +65,10 @@ def flow_compute_color(u, v, convert_to_bgr=False): :param v: np.ndarray, input vertical flow :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB :return: - + Copied from https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/flow_vis.py Copyright (c) 2018 Tom Runia - ''' + """ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) @@ -76,9 +76,9 @@ def flow_compute_color(u, v, convert_to_bgr=False): ncols = colorwheel.shape[0] rad = np.sqrt(np.square(u) + np.square(v)) - a = np.arctan2(-v, -u)/np.pi + a = np.arctan2(-v, -u) / np.pi - fk = (a+1) / 2*(ncols-1) + fk = (a + 1) / 2 * (ncols - 1) k0 = np.floor(fk).astype(np.int32) k1 = k0 + 1 k1[k1 == ncols] = 0 @@ -86,43 +86,43 @@ def flow_compute_color(u, v, convert_to_bgr=False): for i in range(colorwheel.shape[1]): - tmp = colorwheel[:,i] + tmp = colorwheel[:, i] col0 = tmp[k0] / 255.0 col1 = tmp[k1] / 255.0 - col = (1-f)*col0 + f*col1 + col = (1 - f) * col0 + f * col1 - idx = (rad <= 1) - col[idx] = 1 - rad[idx] * (1-col[idx]) - col[~idx] = col[~idx] * 0.75 # out of range? + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range? 
# Note the 2-i => BGR instead of RGB - ch_idx = 2-i if convert_to_bgr else i - flow_image[:,:,ch_idx] = np.floor(255 * col) + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) return flow_image def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): - ''' + """ Expects a two dimensional flow image of shape [H,W,2] According to the C++ source code of Daniel Scharstein According to the Matlab source code of Deqing Sun :param flow_uv: np.ndarray of shape [H,W,2] :param clip_flow: float, maximum clipping value for flow :return: - + Copied from https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/flow_vis.py Copyright (c) 2018 Tom Runia - ''' + """ - assert flow_uv.ndim == 3, 'input flow must have three dimensions' - assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + assert flow_uv.ndim == 3, "input flow must have three dimensions" + assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" if clip_flow is not None: flow_uv = np.clip(flow_uv, 0, clip_flow) - u = flow_uv[:,:,0] - v = flow_uv[:,:,1] + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] rad = np.sqrt(np.square(u) + np.square(v)) rad_max = np.max(rad) @@ -134,58 +134,59 @@ def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): return flow_compute_color(u, v, convert_to_bgr) - -def show_flow( img0, img1, flow, mask=None ): +def show_flow(img0, img1, flow, mask=None): img0 = np.asarray(img0) img1 = np.asarray(img1) - if mask is None: mask = 1 + if mask is None: + mask = 1 mask = np.asarray(mask) - if mask.ndim == 2: mask = mask[:,:,None] + if mask.ndim == 2: + mask = mask[:, :, None] assert flow.ndim == 3 assert flow.shape[:2] == img0.shape[:2] and flow.shape[2] == 2 - + def noticks(): - pl.xticks([]) - pl.yticks([]) + pl.xticks([]) + pl.yticks([]) + fig = pl.figure("showing correspondences") ax1 = pl.subplot(221) ax1.numaxis = 0 - pl.imshow(img0*mask) + pl.imshow(img0 * mask) noticks() ax2 = pl.subplot(222) ax2.numaxis = 1 pl.imshow(img1) noticks() - + ax = pl.subplot(212) ax.numaxis = 0 flow_img = flow_to_color(np.where(np.isnan(flow), 0, flow)) pl.imshow(flow_img * mask) noticks() - + pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, wspace=0.02, hspace=0.02) - + def motion_notify_callback(event): - if event.inaxes is None: return - x,y = event.xdata, event.ydata - ax1.lines = [] - ax2.lines = [] - try: - x,y = int(x+0.5), int(y+0.5) - ax1.plot(x,y,'+',ms=10,mew=2,color='blue',scalex=False,scaley=False) - x,y = flow[y,x] + (x,y) - ax2.plot(x,y,'+',ms=10,mew=2,color='red',scalex=False,scaley=False) - # we redraw only the concerned axes - renderer = fig.canvas.get_renderer() - ax1.draw(renderer) - ax2.draw(renderer) - fig.canvas.blit(ax1.bbox) - fig.canvas.blit(ax2.bbox) - except IndexError: - return - - cid_move = fig.canvas.mpl_connect('motion_notify_event',motion_notify_callback) + if event.inaxes is None: + return + x, y = event.xdata, event.ydata + ax1.lines = [] + ax2.lines = [] + try: + x, y = int(x + 0.5), int(y + 0.5) + ax1.plot(x, y, "+", ms=10, mew=2, color="blue", scalex=False, scaley=False) + x, y = flow[y, x] + (x, y) + ax2.plot(x, y, "+", ms=10, mew=2, color="red", scalex=False, scaley=False) + # we redraw only the concerned axes + renderer = fig.canvas.get_renderer() + ax1.draw(renderer) + ax2.draw(renderer) + fig.canvas.blit(ax1.bbox) + fig.canvas.blit(ax2.bbox) + except IndexError: + return + + cid_move = fig.canvas.mpl_connect("motion_notify_event", motion_notify_callback) print("Move your mouse over the images to show matches 
(ctrl-C to quit)") pl.show() - - diff --git a/third_party/r2d2/train.py b/third_party/r2d2/train.py index 10d23d9e40ebe8cb10c4d548b7fcb5c1c0fd7739..232d61d0eb830454b4f785cfb82536b6cfba7071 100644 --- a/third_party/r2d2/train.py +++ b/third_party/r2d2/train.py @@ -35,12 +35,12 @@ db_aachen_style_transfer = """TransformedPairs( db_aachen_flow = "aachen_flow_pairs" data_sources = dict( - D = toy_db_debug, - W = db_web_images, - A = db_aachen_images, - F = db_aachen_flow, - S = db_aachen_style_transfer, - ) + D=toy_db_debug, + W=db_web_images, + A=db_aachen_images, + F=db_aachen_flow, + S=db_aachen_style_transfer, +) default_dataloader = """PairLoader(CatPairDataset(`data`), scale = 'RandomScale(256,1024,can_upscale=True)', @@ -57,75 +57,101 @@ default_loss = """MultiLoss( class MyTrainer(trainer.Trainer): - """ This class implements the network training. - Below is the function I need to overload to explain how to do the backprop. + """This class implements the network training. + Below is the function I need to overload to explain how to do the backprop. """ + def forward_backward(self, inputs): - output = self.net(imgs=[inputs.pop('img1'),inputs.pop('img2')]) + output = self.net(imgs=[inputs.pop("img1"), inputs.pop("img2")]) allvars = dict(inputs, **output) loss, details = self.loss_func(**allvars) - if torch.is_grad_enabled(): loss.backward() + if torch.is_grad_enabled(): + loss.backward() return loss, details - -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser("Train R2D2") parser.add_argument("--data-loader", type=str, default=default_dataloader) - parser.add_argument("--train-data", type=str, default=list('WASF'), nargs='+', - choices = set(data_sources.keys())) - parser.add_argument("--net", type=str, default=default_net, help='network architecture') + parser.add_argument( + "--train-data", + type=str, + default=list("WASF"), + nargs="+", + choices=set(data_sources.keys()), + ) + parser.add_argument( + "--net", type=str, default=default_net, help="network architecture" + ) + + parser.add_argument( + "--pretrained", type=str, default="", help="pretrained model path" + ) + parser.add_argument( + "--save-path", type=str, required=True, help="model save_path path" + ) - parser.add_argument("--pretrained", type=str, default="", help='pretrained model path') - parser.add_argument("--save-path", type=str, required=True, help='model save_path path') - parser.add_argument("--loss", type=str, default=default_loss, help="loss function") - parser.add_argument("--sampler", type=str, default=default_sampler, help="AP sampler") - parser.add_argument("--N", type=int, default=16, help="patch size for repeatability") + parser.add_argument( + "--sampler", type=str, default=default_sampler, help="AP sampler" + ) + parser.add_argument( + "--N", type=int, default=16, help="patch size for repeatability" + ) - parser.add_argument("--epochs", type=int, default=25, help='number of training epochs') + parser.add_argument( + "--epochs", type=int, default=25, help="number of training epochs" + ) parser.add_argument("--batch-size", "--bs", type=int, default=8, help="batch size") parser.add_argument("--learning-rate", "--lr", type=str, default=1e-4) parser.add_argument("--weight-decay", "--wd", type=float, default=5e-4) - - parser.add_argument("--threads", type=int, default=8, help='number of worker threads') - parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='-1 for CPU') - + + parser.add_argument( + "--threads", type=int, default=8, 
help="number of worker threads" + ) + parser.add_argument("--gpu", type=int, nargs="+", default=[0], help="-1 for CPU") + args = parser.parse_args() - + iscuda = common.torch_set_gpu(args.gpu) common.mkdir_for(args.save_path) # Create data loader from datasets import * + db = [data_sources[key] for key in args.train_data] - db = eval(args.data_loader.replace('`data`',','.join(db)).replace('\n','')) + db = eval(args.data_loader.replace("`data`", ",".join(db)).replace("\n", "")) print("Training image database =", db) loader = threaded_loader(db, iscuda, args.threads, args.batch_size, shuffle=True) # create network - print("\n>> Creating net = " + args.net) + print("\n>> Creating net = " + args.net) net = eval(args.net) print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )") # initialization if args.pretrained: - checkpoint = torch.load(args.pretrained, lambda a,b:a) - net.load_pretrained(checkpoint['state_dict']) - + checkpoint = torch.load(args.pretrained, lambda a, b: a) + net.load_pretrained(checkpoint["state_dict"]) + # create losses - loss = args.loss.replace('`sampler`',args.sampler).replace('`N`',str(args.N)) + loss = args.loss.replace("`sampler`", args.sampler).replace("`N`", str(args.N)) print("\n>> Creating loss = " + loss) - loss = eval(loss.replace('\n','')) - + loss = eval(loss.replace("\n", "")) + # create optimizer - optimizer = optim.Adam( [p for p in net.parameters() if p.requires_grad], - lr=args.learning_rate, weight_decay=args.weight_decay) + optimizer = optim.Adam( + [p for p in net.parameters() if p.requires_grad], + lr=args.learning_rate, + weight_decay=args.weight_decay, + ) train = MyTrainer(net, loader, loss, optimizer) - if iscuda: train = train.cuda() + if iscuda: + train = train.cuda() # Training loop # for epoch in range(args.epochs): @@ -133,6 +159,4 @@ if __name__ == '__main__': train() print(f"\n>> Saving model to {args.save_path}") - torch.save({'net': args.net, 'state_dict': net.state_dict()}, args.save_path) - - + torch.save({"net": args.net, "state_dict": net.state_dict()}, args.save_path) diff --git a/third_party/r2d2/viz_heatmaps.py b/third_party/r2d2/viz_heatmaps.py index 42705e70ecea82696a0d784b274f7f387fdf6595..e5cb8b3bb502ce4d9e5169c55be3f479f8f8fce4 100644 --- a/third_party/r2d2/viz_heatmaps.py +++ b/third_party/r2d2/viz_heatmaps.py @@ -7,116 +7,134 @@ import numpy as np import torch from PIL import Image -from matplotlib import pyplot as pl; pl.ion() +from matplotlib import pyplot as pl + +pl.ion() from scipy.ndimage import uniform_filter + smooth = lambda arr: uniform_filter(arr, 3) + def transparent(img, alpha, cmap, **kw): from matplotlib.colors import Normalize - colored_img = cmap(Normalize(clip=True,**kw)(img)) - colored_img[:,:,-1] = alpha + + colored_img = cmap(Normalize(clip=True, **kw)(img)) + colored_img[:, :, -1] = alpha return colored_img + from tools import common from tools.dataloader import norm_RGB from nets.patchnet import * from extract import NonMaxSuppression -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser("Visualize the patch detector and descriptor") - + parser.add_argument("--img", type=str, default="imgs/brooklyn.png") parser.add_argument("--resize", type=int, default=512) parser.add_argument("--out", type=str, default="viz.png") - parser.add_argument("--checkpoint", type=str, required=True, help='network path') - parser.add_argument("--net", type=str, default="", help='network command') + parser.add_argument("--checkpoint", type=str, 
required=True, help="network path") + parser.add_argument("--net", type=str, default="", help="network command") parser.add_argument("--max-kpts", type=int, default=200) parser.add_argument("--reliability-thr", type=float, default=0.8) parser.add_argument("--repeatability-thr", type=float, default=0.7) - parser.add_argument("--border", type=int, default=20,help='rm keypoints close to border') + parser.add_argument( + "--border", type=int, default=20, help="rm keypoints close to border" + ) + + parser.add_argument("--gpu", type=int, nargs="+", required=True, help="-1 for CPU") + parser.add_argument("--dbg", type=str, nargs="+", default=(), help="debug options") - parser.add_argument("--gpu", type=int, nargs='+', required=True, help='-1 for CPU') - parser.add_argument("--dbg", type=str, nargs='+', default=(), help='debug options') - args = parser.parse_args() args.dbg = set(args.dbg) - + iscuda = common.torch_set_gpu(args.gpu) - device = torch.device('cuda' if iscuda else 'cpu') + device = torch.device("cuda" if iscuda else "cpu") # create network - checkpoint = torch.load(args.checkpoint, lambda a,b:a) - args.net = args.net or checkpoint['net'] - print("\n>> Creating net = " + args.net) + checkpoint = torch.load(args.checkpoint, lambda a, b: a) + args.net = args.net or checkpoint["net"] + print("\n>> Creating net = " + args.net) net = eval(args.net) - net.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()}) - if iscuda: net = net.cuda() + net.load_state_dict( + {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()} + ) + if iscuda: + net = net.cuda() print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )") - img = Image.open(args.img).convert('RGB') - if args.resize: img.thumbnail((args.resize,args.resize)) + img = Image.open(args.img).convert("RGB") + if args.resize: + img.thumbnail((args.resize, args.resize)) img = np.asarray(img) - + detector = NonMaxSuppression( - rel_thr = args.reliability_thr, - rep_thr = args.repeatability_thr) + rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr + ) with torch.no_grad(): print(">> computing features...") res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)]) - rela = res.get('reliability') - repe = res.get('repeatability') - kpts = detector(**res).T[:,[1,0]] - kpts = kpts[repe[0][0,0][kpts[:,1],kpts[:,0]].argsort()[-args.max_kpts:]] + rela = res.get("reliability") + repe = res.get("repeatability") + kpts = detector(**res).T[:, [1, 0]] + kpts = kpts[repe[0][0, 0][kpts[:, 1], kpts[:, 0]].argsort()[-args.max_kpts :]] fig = pl.figure("viz") kw = dict(cmap=pl.cm.RdYlGn, vmax=1) - crop = (slice(args.border,-args.border or 1),)*2 - - if 'reliability' in args.dbg: - + crop = (slice(args.border, -args.border or 1),) * 2 + + if "reliability" in args.dbg: + ax1 = pl.subplot(131) pl.imshow(img[crop], cmap=pl.cm.gray) - pl.xticks(()); pl.yticks(()) + pl.xticks(()) + pl.yticks(()) pl.subplot(132) pl.imshow(img[crop], cmap=pl.cm.gray, alpha=0) - pl.xticks(()); pl.yticks(()) + pl.xticks(()) + pl.yticks(()) - x,y = kpts[:,0:2].cpu().numpy().T - args.border - pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0) + x, y = kpts[:, 0:2].cpu().numpy().T - args.border + pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0) ax1 = pl.subplot(133) - rela = rela[0][0,0].cpu().numpy() + rela = rela[0][0, 0].cpu().numpy() pl.imshow(rela[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9) - pl.xticks(()); pl.yticks(()) + pl.xticks(()) + pl.yticks(()) else: ax1 = pl.subplot(131) pl.imshow(img[crop], 
cmap=pl.cm.gray) - pl.xticks(()); pl.yticks(()) + pl.xticks(()) + pl.yticks(()) - x,y = kpts[:,0:2].cpu().numpy().T - args.border - pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0) + x, y = kpts[:, 0:2].cpu().numpy().T - args.border + pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0) pl.subplot(132) pl.imshow(img[crop], cmap=pl.cm.gray) - pl.xticks(()); pl.yticks(()) - c = repe[0][0,0].cpu().numpy() + pl.xticks(()) + pl.yticks(()) + c = repe[0][0, 0].cpu().numpy() pl.imshow(transparent(smooth(c)[crop], 0.5, vmin=0, **kw)) ax1 = pl.subplot(133) pl.imshow(img[crop], cmap=pl.cm.gray) - pl.xticks(()); pl.yticks(()) - rela = rela[0][0,0].cpu().numpy() + pl.xticks(()) + pl.yticks(()) + rela = rela[0][0, 0].cpu().numpy() pl.imshow(transparent(rela[crop], 0.5, vmin=0.9, **kw)) pl.gcf().set_size_inches(9, 2.73) - pl.subplots_adjust(0.01,0.01,0.99,0.99,hspace=0.1) + pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, hspace=0.1) pl.savefig(args.out) pdb.set_trace() -