unimatch demo
- app.py +160 -0
- dataloader/__init__.py +0 -0
- dataloader/stereo/transforms.py +434 -0
- demo/flow_davis_skate-jump_00059.jpg +0 -0
- demo/flow_davis_skate-jump_00060.jpg +0 -0
- demo/flow_kitti_test_000197_10.png +0 -0
- demo/flow_kitti_test_000197_11.png +0 -0
- demo/flow_sintel_cave_3_frame_0049.png +0 -0
- demo/flow_sintel_cave_3_frame_0050.png +0 -0
- demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_left.jpg +0 -0
- demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_right.jpg +0 -0
- pretrained/tmp.txt +0 -0
- requirements.txt +5 -0
- unimatch/__init__.py +0 -0
- unimatch/attention.py +253 -0
- unimatch/backbone.py +117 -0
- unimatch/geometry.py +195 -0
- unimatch/matching.py +279 -0
- unimatch/position.py +46 -0
- unimatch/reg_refine.py +119 -0
- unimatch/transformer.py +294 -0
- unimatch/trident_conv.py +90 -0
- unimatch/unimatch.py +367 -0
- unimatch/utils.py +216 -0
- utils/flow_viz.py +290 -0
- utils/visualization.py +110 -0
app.py
ADDED
@@ -0,0 +1,160 @@
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr

from unimatch.unimatch import UniMatch
from utils.flow_viz import flow_to_image
from dataloader.stereo import transforms
from utils.visualization import vis_disparity

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


@torch.no_grad()
def inference(image1, image2, task='flow'):
    """Inference on an image pair for optical flow or stereo disparity prediction"""

    model = UniMatch(feature_channels=128,
                     num_scales=2,
                     upsample_factor=4,
                     ffn_dim_expansion=4,
                     num_transformer_layers=6,
                     reg_refine=True,
                     task=task)

    model.eval()

    if task == 'flow':
        checkpoint_path = 'pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth'
    else:
        checkpoint_path = 'pretrained/gmstereo-scale2-regrefine3-resumeflowthings-mixdata-train320x640-ft640x960-e4e291fd.pth'

    checkpoint_flow = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint_flow['model'], strict=True)

    padding_factor = 32
    attn_type = 'swin' if task == 'flow' else 'self_swin2d_cross_swin1d'
    attn_splits_list = [2, 8]
    corr_radius_list = [-1, 4]
    prop_radius_list = [-1, 1]
    num_reg_refine = 6 if task == 'flow' else 3

    # smaller inference size for faster speed
    max_inference_size = [384, 768] if task == 'flow' else [640, 960]

    transpose_img = False

    image1 = np.array(image1).astype(np.float32)
    image2 = np.array(image2).astype(np.float32)

    if len(image1.shape) == 2:  # gray image
        image1 = np.tile(image1[..., None], (1, 1, 3))
        image2 = np.tile(image2[..., None], (1, 1, 3))
    else:
        image1 = image1[..., :3]
        image2 = image2[..., :3]

    if task == 'flow':
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float().unsqueeze(0)
        image2 = torch.from_numpy(image2).permute(2, 0, 1).float().unsqueeze(0)
    else:
        val_transform_list = [transforms.ToTensor(),
                              transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
                              ]

        val_transform = transforms.Compose(val_transform_list)

        sample = {'left': image1, 'right': image2}
        sample = val_transform(sample)

        image1 = sample['left'].unsqueeze(0)  # [1, 3, H, W]
        image2 = sample['right'].unsqueeze(0)  # [1, 3, H, W]

    # the model is trained with size: width > height
    if task == 'flow' and image1.size(-2) > image1.size(-1):
        image1 = torch.transpose(image1, -2, -1)
        image2 = torch.transpose(image2, -2, -1)
        transpose_img = True

    nearest_size = [int(np.ceil(image1.size(-2) / padding_factor)) * padding_factor,
                    int(np.ceil(image1.size(-1) / padding_factor)) * padding_factor]

    inference_size = [min(max_inference_size[0], nearest_size[0]), min(max_inference_size[1], nearest_size[1])]

    assert isinstance(inference_size, list) or isinstance(inference_size, tuple)
    ori_size = image1.shape[-2:]

    # resize before inference
    if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
        image1 = F.interpolate(image1, size=inference_size, mode='bilinear',
                               align_corners=True)
        image2 = F.interpolate(image2, size=inference_size, mode='bilinear',
                               align_corners=True)

    results_dict = model(image1, image2,
                         attn_type=attn_type,
                         attn_splits_list=attn_splits_list,
                         corr_radius_list=corr_radius_list,
                         prop_radius_list=prop_radius_list,
                         num_reg_refine=num_reg_refine,
                         task=task,
                         )

    flow_pr = results_dict['flow_preds'][-1]  # [1, 2, H, W] or [1, H, W]

    # resize back
    if task == 'flow':
        if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
            flow_pr = F.interpolate(flow_pr, size=ori_size, mode='bilinear',
                                    align_corners=True)
            flow_pr[:, 0] = flow_pr[:, 0] * ori_size[-1] / inference_size[-1]
            flow_pr[:, 1] = flow_pr[:, 1] * ori_size[-2] / inference_size[-2]
    else:
        if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
            pred_disp = F.interpolate(flow_pr.unsqueeze(1), size=ori_size,
                                      mode='bilinear',
                                      align_corners=True).squeeze(1)  # [1, H, W]
            pred_disp = pred_disp * ori_size[-1] / float(inference_size[-1])

    if task == 'flow':
        if transpose_img:
            flow_pr = torch.transpose(flow_pr, -2, -1)

        flow = flow_pr[0].permute(1, 2, 0).cpu().numpy()  # [H, W, 2]

        output = flow_to_image(flow)  # [H, W, 3]
    else:
        disp = pred_disp[0].cpu().numpy()

        output = vis_disparity(disp, return_rgb=True)

    return Image.fromarray(output)


title = "UniMatch"

description = "<p style='text-align: center'>Optical flow and stereo matching demo for <a href='https://haofeixu.github.io/unimatch/' target='_blank'>Unifying Flow, Stereo and Depth Estimation</a> | <a href='https://arxiv.org/abs/2211.05783' target='_blank'>Paper</a> | <a href='https://github.com/autonomousvision/unimatch' target='_blank'>Code</a> | <a href='https://colab.research.google.com/drive/1r5m-xVy3Kw60U-m5VB-aQ98oqqg_6cab?usp=sharing' target='_blank'>Colab</a><br>Simply upload your images or click one of the provided examples.<br>The <strong>first three</strong> examples are video frames for <strong>flow</strong> task, and the <strong>last three</strong> are stereo pairs for <strong>stereo</strong> task.<br><strong>Select the task type according to your input images</strong>.</p>"

examples = [
    ['demo/flow_kitti_test_000197_10.png', 'demo/flow_kitti_test_000197_11.png'],
    ['demo/flow_sintel_cave_3_frame_0049.png', 'demo/flow_sintel_cave_3_frame_0050.png'],
    ['demo/flow_davis_skate-jump_00059.jpg', 'demo/flow_davis_skate-jump_00060.jpg'],
    ['demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_left.jpg',
     'demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_right.jpg'],
    ['demo/stereo_middlebury_plants_im0.png', 'demo/stereo_middlebury_plants_im1.png'],
    ['demo/stereo_holopix_left.png', 'demo/stereo_holopix_right.png']
]

gr.Interface(
    inference,
    [gr.Image(type="pil", label="Image1"), gr.Image(type="pil", label="Image2"), gr.Radio(choices=['flow', 'stereo'], value='flow', label='Task')],
    gr.Image(type="pil", label="Flow/Disparity"),
    title=title,
    description=description,
    examples=examples,
    thumbnail="https://haofeixu.github.io/unimatch/resources/teaser.svg",
    allow_flagging="auto",
).launch(debug=True)
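For reference, a minimal sketch of calling the demo's inference() directly, bypassing the Gradio UI. This is my own illustration, not part of the commit, and it assumes two things: the function is available in the session without executing the blocking gr.Interface(...).launch() at the bottom of app.py (e.g. the function body was pasted or imported separately), and the checkpoints referenced in inference() have already been downloaded into pretrained/ (the commit only adds a tmp.txt placeholder there).

from PIL import Image

# example frame pair shipped with the Space (flow task)
image1 = Image.open('demo/flow_kitti_test_000197_10.png')
image2 = Image.open('demo/flow_kitti_test_000197_11.png')

flow_vis = inference(image1, image2, task='flow')  # returns a PIL image with the colorized flow
flow_vis.save('flow_vis.png')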
dataloader/__init__.py
ADDED
File without changes
dataloader/stereo/transforms.py
ADDED
@@ -0,0 +1,434 @@
from __future__ import division
import torch
import numpy as np
from PIL import Image
import torchvision.transforms.functional as F
import random
import cv2


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample):
        for t in self.transforms:
            sample = t(sample)
        return sample


class ToTensor(object):
    """Convert numpy array to torch tensor"""

    def __init__(self, no_normalize=False):
        self.no_normalize = no_normalize

    def __call__(self, sample):
        left = np.transpose(sample['left'], (2, 0, 1))  # [3, H, W]
        if self.no_normalize:
            sample['left'] = torch.from_numpy(left)
        else:
            sample['left'] = torch.from_numpy(left) / 255.
        right = np.transpose(sample['right'], (2, 0, 1))

        if self.no_normalize:
            sample['right'] = torch.from_numpy(right)
        else:
            sample['right'] = torch.from_numpy(right) / 255.

        # disp = np.expand_dims(sample['disp'], axis=0)  # [1, H, W]
        if 'disp' in sample.keys():
            disp = sample['disp']  # [H, W]
            sample['disp'] = torch.from_numpy(disp)

        return sample


class Normalize(object):
    """Normalize image, with type tensor"""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, sample):

        norm_keys = ['left', 'right']

        for key in norm_keys:
            # Images have been converted to tensors, with shape [C, H, W]
            for t, m, s in zip(sample[key], self.mean, self.std):
                t.sub_(m).div_(s)

        return sample


class RandomCrop(object):
    def __init__(self, img_height, img_width):
        self.img_height = img_height
        self.img_width = img_width

    def __call__(self, sample):
        ori_height, ori_width = sample['left'].shape[:2]

        # pad zero when crop size is larger than original image size
        if self.img_height > ori_height or self.img_width > ori_width:

            # can be used for only pad one side
            top_pad = max(self.img_height - ori_height, 0)
            right_pad = max(self.img_width - ori_width, 0)

            # try edge padding
            sample['left'] = np.lib.pad(sample['left'],
                                        ((top_pad, 0), (0, right_pad), (0, 0)),
                                        mode='edge')
            sample['right'] = np.lib.pad(sample['right'],
                                         ((top_pad, 0), (0, right_pad), (0, 0)),
                                         mode='edge')

            if 'disp' in sample.keys():
                sample['disp'] = np.lib.pad(sample['disp'],
                                            ((top_pad, 0), (0, right_pad)),
                                            mode='constant',
                                            constant_values=0)

            # update image resolution
            ori_height, ori_width = sample['left'].shape[:2]

        assert self.img_height <= ori_height and self.img_width <= ori_width

        # Training: random crop
        self.offset_x = np.random.randint(ori_width - self.img_width + 1)

        start_height = 0
        assert ori_height - start_height >= self.img_height

        self.offset_y = np.random.randint(start_height, ori_height - self.img_height + 1)

        sample['left'] = self.crop_img(sample['left'])
        sample['right'] = self.crop_img(sample['right'])
        if 'disp' in sample.keys():
            sample['disp'] = self.crop_img(sample['disp'])

        return sample

    def crop_img(self, img):
        return img[self.offset_y:self.offset_y + self.img_height,
                   self.offset_x:self.offset_x + self.img_width]


class RandomVerticalFlip(object):
    """Randomly vertically flips"""

    def __call__(self, sample):
        if np.random.random() < 0.5:
            sample['left'] = np.copy(np.flipud(sample['left']))
            sample['right'] = np.copy(np.flipud(sample['right']))

            sample['disp'] = np.copy(np.flipud(sample['disp']))

        return sample


class ToPILImage(object):

    def __call__(self, sample):
        sample['left'] = Image.fromarray(sample['left'].astype('uint8'))
        sample['right'] = Image.fromarray(sample['right'].astype('uint8'))

        return sample


class ToNumpyArray(object):

    def __call__(self, sample):
        sample['left'] = np.array(sample['left']).astype(np.float32)
        sample['right'] = np.array(sample['right']).astype(np.float32)

        return sample


# Random coloring
class RandomContrast(object):
    """Random contrast"""

    def __init__(self,
                 asymmetric_color_aug=True,
                 ):

        self.asymmetric_color_aug = asymmetric_color_aug

    def __call__(self, sample):
        if np.random.random() < 0.5:
            contrast_factor = np.random.uniform(0.8, 1.2)

            sample['left'] = F.adjust_contrast(sample['left'], contrast_factor)

            if self.asymmetric_color_aug and np.random.random() < 0.5:
                contrast_factor = np.random.uniform(0.8, 1.2)

            sample['right'] = F.adjust_contrast(sample['right'], contrast_factor)

        return sample


class RandomGamma(object):

    def __init__(self,
                 asymmetric_color_aug=True,
                 ):

        self.asymmetric_color_aug = asymmetric_color_aug

    def __call__(self, sample):
        if np.random.random() < 0.5:
            gamma = np.random.uniform(0.7, 1.5)  # adopted from FlowNet

            sample['left'] = F.adjust_gamma(sample['left'], gamma)

            if self.asymmetric_color_aug and np.random.random() < 0.5:
                gamma = np.random.uniform(0.7, 1.5)  # adopted from FlowNet

            sample['right'] = F.adjust_gamma(sample['right'], gamma)

        return sample


class RandomBrightness(object):

    def __init__(self,
                 asymmetric_color_aug=True,
                 ):

        self.asymmetric_color_aug = asymmetric_color_aug

    def __call__(self, sample):
        if np.random.random() < 0.5:
            brightness = np.random.uniform(0.5, 2.0)

            sample['left'] = F.adjust_brightness(sample['left'], brightness)

            if self.asymmetric_color_aug and np.random.random() < 0.5:
                brightness = np.random.uniform(0.5, 2.0)

            sample['right'] = F.adjust_brightness(sample['right'], brightness)

        return sample


class RandomHue(object):

    def __init__(self,
                 asymmetric_color_aug=True,
                 ):

        self.asymmetric_color_aug = asymmetric_color_aug

    def __call__(self, sample):
        if np.random.random() < 0.5:
            hue = np.random.uniform(-0.1, 0.1)

            sample['left'] = F.adjust_hue(sample['left'], hue)

            if self.asymmetric_color_aug and np.random.random() < 0.5:
                hue = np.random.uniform(-0.1, 0.1)

            sample['right'] = F.adjust_hue(sample['right'], hue)

        return sample


class RandomSaturation(object):

    def __init__(self,
                 asymmetric_color_aug=True,
                 ):

        self.asymmetric_color_aug = asymmetric_color_aug

    def __call__(self, sample):
        if np.random.random() < 0.5:
            saturation = np.random.uniform(0.8, 1.2)

            sample['left'] = F.adjust_saturation(sample['left'], saturation)

            if self.asymmetric_color_aug and np.random.random() < 0.5:
                saturation = np.random.uniform(0.8, 1.2)

            sample['right'] = F.adjust_saturation(sample['right'], saturation)

        return sample


class RandomColor(object):

    def __init__(self,
                 asymmetric_color_aug=True,
                 ):

        self.asymmetric_color_aug = asymmetric_color_aug

    def __call__(self, sample):
        transforms = [RandomContrast(asymmetric_color_aug=self.asymmetric_color_aug),
                      RandomGamma(asymmetric_color_aug=self.asymmetric_color_aug),
                      RandomBrightness(asymmetric_color_aug=self.asymmetric_color_aug),
                      RandomHue(asymmetric_color_aug=self.asymmetric_color_aug),
                      RandomSaturation(asymmetric_color_aug=self.asymmetric_color_aug)]

        sample = ToPILImage()(sample)

        if np.random.random() < 0.5:
            # A single transform
            t = random.choice(transforms)
            sample = t(sample)
        else:
            # Combination of transforms
            # Random order
            random.shuffle(transforms)
            for t in transforms:
                sample = t(sample)

        sample = ToNumpyArray()(sample)

        return sample


class RandomScale(object):
    def __init__(self,
                 min_scale=-0.4,
                 max_scale=0.4,
                 crop_width=512,
                 nearest_interp=False,  # for sparse gt
                 ):
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.crop_width = crop_width
        self.nearest_interp = nearest_interp

    def __call__(self, sample):
        if np.random.rand() < 0.5:
            h, w = sample['disp'].shape

            scale_x = 2 ** np.random.uniform(self.min_scale, self.max_scale)

            scale_x = np.clip(scale_x, self.crop_width / float(w), None)

            # only random scale x axis
            sample['left'] = cv2.resize(sample['left'], None, fx=scale_x, fy=1., interpolation=cv2.INTER_LINEAR)
            sample['right'] = cv2.resize(sample['right'], None, fx=scale_x, fy=1., interpolation=cv2.INTER_LINEAR)

            sample['disp'] = cv2.resize(
                sample['disp'], None, fx=scale_x, fy=1.,
                interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
            ) * scale_x

            if 'pseudo_disp' in sample and sample['pseudo_disp'] is not None:
                sample['pseudo_disp'] = cv2.resize(sample['pseudo_disp'], None, fx=scale_x, fy=1.,
                                                   interpolation=cv2.INTER_LINEAR) * scale_x

        return sample


class Resize(object):
    def __init__(self,
                 scale_x=1,
                 scale_y=1,
                 nearest_interp=True,  # for sparse gt
                 ):
        """
        Resize low-resolution data to high-res for mixed dataset training
        """
        self.scale_x = scale_x
        self.scale_y = scale_y
        self.nearest_interp = nearest_interp

    def __call__(self, sample):
        scale_x = self.scale_x
        scale_y = self.scale_y

        sample['left'] = cv2.resize(sample['left'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
        sample['right'] = cv2.resize(sample['right'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)

        sample['disp'] = cv2.resize(
            sample['disp'], None, fx=scale_x, fy=scale_y,
            interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
        ) * scale_x

        return sample


class RandomGrayscale(object):
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, sample):
        if np.random.random() < self.p:
            sample = ToPILImage()(sample)

            # only supported in higher version pytorch
            # default output channels is 1
            sample['left'] = F.rgb_to_grayscale(sample['left'], num_output_channels=3)
            sample['right'] = F.rgb_to_grayscale(sample['right'], num_output_channels=3)

            sample = ToNumpyArray()(sample)

        return sample


class RandomRotateShiftRight(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        if np.random.random() < self.p:
            angle, pixel = 0.1, 2
            px = np.random.uniform(-pixel, pixel)
            ag = np.random.uniform(-angle, angle)

            right_img = sample['right']

            image_center = (
                np.random.uniform(0, right_img.shape[0]),
                np.random.uniform(0, right_img.shape[1])
            )

            rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
            right_img = cv2.warpAffine(
                right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
            )
            trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
            right_img = cv2.warpAffine(
                right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
            )

            sample['right'] = right_img

        return sample


class RandomOcclusion(object):
    def __init__(self, p=0.5,
                 occlusion_mask_zero=False):
        self.p = p
        self.occlusion_mask_zero = occlusion_mask_zero

    def __call__(self, sample):
        bounds = [50, 100]
        if np.random.random() < self.p:
            img2 = sample['right']
            ht, wd = img2.shape[:2]

            if self.occlusion_mask_zero:
                mean_color = 0
            else:
                mean_color = np.mean(img2.reshape(-1, 3), axis=0)

            x0 = np.random.randint(0, wd)
            y0 = np.random.randint(0, ht)
            dx = np.random.randint(bounds[0], bounds[1])
            dy = np.random.randint(bounds[0], bounds[1])
            img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color

            sample['right'] = img2

        return sample
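A minimal sketch of how app.py uses these transforms for the stereo task (the dummy arrays below are my own assumption for illustration): ToTensor converts HxWx3 float arrays to [3, H, W] tensors in [0, 1], then Normalize applies the ImageNet statistics in place.

import numpy as np

from dataloader.stereo import transforms

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

val_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])

# stand-in left/right views; in the demo these come from the uploaded image pair
sample = {'left': np.random.rand(240, 320, 3).astype(np.float32) * 255.,
          'right': np.random.rand(240, 320, 3).astype(np.float32) * 255.}
sample = val_transform(sample)

print(sample['left'].shape)  # torch.Size([3, 240, 320])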
demo/flow_davis_skate-jump_00059.jpg
ADDED
demo/flow_davis_skate-jump_00060.jpg
ADDED
demo/flow_kitti_test_000197_10.png
ADDED
demo/flow_kitti_test_000197_11.png
ADDED
demo/flow_sintel_cave_3_frame_0049.png
ADDED
demo/flow_sintel_cave_3_frame_0050.png
ADDED
demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_left.jpg
ADDED
demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_right.jpg
ADDED
pretrained/tmp.txt
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
torchvision
matplotlib
opencv-python
pillow
unimatch/__init__.py
ADDED
File without changes
unimatch/attention.py
ADDED
@@ -0,0 +1,253 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d


def single_head_full_attention(q, k, v):
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5)  # [B, L, L]
    attn = torch.softmax(scores, dim=2)  # [B, L, L]
    out = torch.matmul(attn, v)  # [B, L, C]

    return out


def single_head_full_attention_1d(q, k, v,
                                  h=None,
                                  w=None,
                                  ):
    # q, k, v: [B, L, C]

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c ** 0.5

    scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor  # [B, H, W, W]

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v).view(b, -1, c)  # [B, H*W, C]

    return out


def single_head_split_window_attention(q, k, v,
                                       num_splits=1,
                                       with_shift=False,
                                       h=None,
                                       w=None,
                                       attn_mask=None,
                                       ):
    # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c ** 0.5

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
                          ) / scale_factor  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

    out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
                       num_splits=num_splits, channel_last=True)  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)

    return out


def single_head_split_window_attention_1d(q, k, v,
                                          relative_position_bias=None,
                                          num_splits=1,
                                          with_shift=False,
                                          h=None,
                                          w=None,
                                          attn_mask=None,
                                          ):
    # q, k, v: [B, L, C]

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * h

    window_size_w = w // num_splits

    q = q.view(b * h, w, c)  # [B*H, W, C]
    k = k.view(b * h, w, c)
    v = v.view(b * h, w, c)

    scale_factor = c ** 0.5

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=-shift_size_w, dims=1)
        k = torch.roll(k, shifts=-shift_size_w, dims=1)
        v = torch.roll(v, shifts=-shift_size_w, dims=1)

    q = split_feature_1d(q, num_splits=num_splits)  # [B*H*K, W/K, C]
    k = split_feature_1d(k, num_splits=num_splits)
    v = split_feature_1d(v, num_splits=num_splits)

    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
                          ) / scale_factor  # [B*H*K, W/K, W/K]

    if with_shift:
        # attn_mask: [K, W/K, W/K]
        scores += attn_mask.repeat(b * h, 1, 1)  # [B*H*K, W/K, W/K]

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*H*K, W/K, C]

    out = merge_splits_1d(out, h, num_splits=num_splits)  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=shift_size_w, dims=2)

    out = out.view(b, -1, c)

    return out


class SelfAttnPropagation(nn.Module):
    """
    flow propagation with self-attention on feature
    query: feature0, key: feature0, value: flow
    """

    def __init__(self, in_channels,
                 **kwargs,
                 ):
        super(SelfAttnPropagation, self).__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, flow,
                local_window_attn=False,
                local_window_radius=1,
                **kwargs,
                ):
        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow,
                                                  local_window_radius=local_window_radius)

        b, c, h, w = feature0.size()

        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # a note: the ``correct'' implementation should be:
        # ``query = self.q_proj(query), key = self.k_proj(query)''
        # this problem is observed while cleaning up the code
        # however, this doesn't affect the performance since the projection is a linear operation,
        # thus the two projection matrices for key can be merged
        # so I just leave it as is in order to not re-train all models :)
        query = self.q_proj(query)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(self, feature0, flow,
                                  local_window_radius=1,
                                  ):
        assert flow.size(1) == 2 or flow.size(1) == 1  # flow or disparity or depth
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        value_channel = flow.size(1)

        feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
                                       ).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]

        kernel_size = 2 * local_window_radius + 1

        feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)

        feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
                                   padding=local_window_radius)  # [B, C*(2R+1)^2), H*W]

        feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
            0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2)  # [B*H*W, C, (2R+1)^2]

        flow_window = F.unfold(flow, kernel_size=kernel_size,
                               padding=local_window_radius)  # [B, 2*(2R+1)^2), H*W]

        flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute(
            0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel)  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, flow_window).view(b, h, w, value_channel
                                                   ).permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]

        return out
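A quick shape-check sketch for the simplest helper above, single_head_full_attention, on random tensors (my own illustration; the sizes are arbitrary and follow the [B, L, C] convention noted in the code comments):

import torch

from unimatch.attention import single_head_full_attention

b, h, w, c = 1, 32, 48, 128
q = torch.randn(b, h * w, c)  # flattened 2D feature map, L = H*W
k = torch.randn(b, h * w, c)
v = torch.randn(b, h * w, c)

out = single_head_full_attention(q, k, v)  # scaled dot-product attention, output [B, L, C]
print(out.shape)  # torch.Size([1, 1536, 128])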
unimatch/backbone.py
ADDED
@@ -0,0 +1,117 @@
import torch.nn as nn

from .trident_conv import MultiScaleTridentConv


class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
                 ):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               dilation=dilation, padding=dilation, stride=stride, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               dilation=dilation, padding=dilation, bias=False)
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)
        if not stride == 1 or in_planes != planes:
            self.norm3 = norm_layer(planes)

        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class CNNEncoder(nn.Module):
    def __init__(self, output_dim=128,
                 norm_layer=nn.InstanceNorm2d,
                 num_output_scales=1,
                 **kwargs,
                 ):
        super(CNNEncoder, self).__init__()
        self.num_branch = num_output_scales

        feature_dims = [64, 96, 128]

        self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False)  # 1/2
        self.norm1 = norm_layer(feature_dims[0])
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = feature_dims[0]
        self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer)  # 1/2
        self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer)  # 1/4

        # highest resolution 1/4 or 1/8
        stride = 2 if num_output_scales == 1 else 1
        self.layer3 = self._make_layer(feature_dims[2], stride=stride,
                                       norm_layer=norm_layer,
                                       )  # 1/4 or 1/8

        self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)

        if self.num_branch > 1:
            if self.num_branch == 4:
                strides = (1, 2, 4, 8)
            elif self.num_branch == 3:
                strides = (1, 2, 4)
            elif self.num_branch == 2:
                strides = (1, 2)
            else:
                raise ValueError

            self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
                                                      kernel_size=3,
                                                      strides=strides,
                                                      paddings=1,
                                                      num_branch=self.num_branch,
                                                      )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
        layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)

        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)  # 1/2
        x = self.layer2(x)  # 1/4
        x = self.layer3(x)  # 1/8 or 1/4

        x = self.conv2(x)

        if self.num_branch > 1:
            out = self.trident_conv([x] * self.num_branch)  # high to low res
        else:
            out = [x]

        return out
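A minimal sketch of running the encoder on a dummy image (my own example; the 320x576 size is just the demo's flow training resolution). With the default num_output_scales=1 the encoder returns a single-element list with a 1/8-resolution feature map:

import torch

from unimatch.backbone import CNNEncoder

encoder = CNNEncoder(output_dim=128, num_output_scales=1)
img = torch.randn(1, 3, 320, 576)  # [B, 3, H, W]

with torch.no_grad():
    features = encoder(img)  # list with one [B, 128, H/8, W/8] tensor

print(features[0].shape)  # torch.Size([1, 128, 40, 72])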
unimatch/geometry.py
ADDED
@@ -0,0 +1,195 @@
import torch
import torch.nn.functional as F


def coords_grid(b, h, w, homogeneous=False, device=None):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]

    stacks = [x, y]

    if homogeneous:
        ones = torch.ones_like(x)  # [H, W]
        stacks.append(ones)

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]

    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid


def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    assert device is not None

    x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
                           torch.linspace(h_min, h_max, len_h, device=device)],
                          )
    grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]

    return grid


def normalize_coords(coords, h, w):
    # coords: [B, H, W, 2]
    c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
    return (coords - c) / c  # [-1, 1]


def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
    # img: [B, C, H, W]
    # sample_coords: [B, 2, H, W] in image scale
    if sample_coords.size(1) != 2:  # [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    b, _, h, w = sample_coords.shape

    # Normalize to [-1, 1]
    x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1

    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)

    if return_mask:
        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1)  # [B, H, W]

        return img, mask

    return img


def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
    b, c, h, w = feature.size()
    assert flow.size(1) == 2

    grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]

    return bilinear_sample(feature, grid, padding_mode=padding_mode,
                           return_mask=mask)


def forward_backward_consistency_check(fwd_flow, bwd_flow,
                                       alpha=0.01,
                                       beta=0.5
                                       ):
    # fwd_flow, bwd_flow: [B, 2, H, W]
    # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]

    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)

    threshold = alpha * flow_mag + beta

    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
    bwd_occ = (diff_bwd > threshold).float()

    return fwd_occ, bwd_occ


def back_project(depth, intrinsics):
    # Back project 2D pixel coords to 3D points
    # depth: [B, H, W]
    # intrinsics: [B, 3, 3]
    b, h, w = depth.shape
    grid = coords_grid(b, h, w, homogeneous=True, device=depth.device)  # [B, 3, H, W]

    intrinsics_inv = torch.inverse(intrinsics)  # [B, 3, 3]

    points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1)  # [B, 3, H, W]

    return points


def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None):
    # Transform 3D points from reference camera to target camera
    # points_ref: [B, 3, H, W]
    # extrinsics_ref: [B, 4, 4]
    # extrinsics_tgt: [B, 4, 4]
    # extrinsics_rel: [B, 4, 4], relative pose transform
    b, _, h, w = points_ref.shape

    if extrinsics_rel is None:
        extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref))  # [B, 4, 4]

    points_tgt = torch.bmm(extrinsics_rel[:, :3, :3],
                           points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:]  # [B, 3, H*W]

    points_tgt = points_tgt.view(b, 3, h, w)  # [B, 3, H, W]

    return points_tgt


def reproject(points_tgt, intrinsics, return_mask=False):
    # reproject to target view
    # points_tgt: [B, 3, H, W]
    # intrinsics: [B, 3, 3]

    b, _, h, w = points_tgt.shape

    proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w)  # [B, 3, H, W]

    X = proj_points[:, 0]
    Y = proj_points[:, 1]
    Z = proj_points[:, 2].clamp(min=1e-3)

    pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w)  # [B, 2, H, W] in image scale

    if return_mask:
        # valid mask in pixel space
        mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & (
                pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1))  # [B, H, W]

        return pixel_coords, mask

    return pixel_coords


def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
                     return_mask=False):
    # Compute reprojection sample coords
    points_ref = back_project(depth_ref, intrinsics)  # [B, 3, H, W]
    points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel)

    if return_mask:
        reproj_coords, mask = reproject(points_tgt, intrinsics,
                                        return_mask=return_mask)  # [B, 2, H, W] in image scale

        return reproj_coords, mask

    reproj_coords = reproject(points_tgt, intrinsics,
                              return_mask=return_mask)  # [B, 2, H, W] in image scale

    return reproj_coords


def compute_flow_with_depth_pose(depth_ref, intrinsics,
                                 extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
                                 return_mask=False):
    b, h, w = depth_ref.shape
    coords_init = coords_grid(b, h, w, device=depth_ref.device)  # [B, 2, H, W]

    if return_mask:
        reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
                                               extrinsics_rel=extrinsics_rel,
                                               return_mask=return_mask)  # [B, 2, H, W]
        rigid_flow = reproj_coords - coords_init

        return rigid_flow, mask

    reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
                                     extrinsics_rel=extrinsics_rel,
                                     return_mask=return_mask)  # [B, 2, H, W]

    rigid_flow = reproj_coords - coords_init

    return rigid_flow
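A small illustrative sketch (my own example, with random inputs) of two of the helpers above: warping a feature map with a flow field, and deriving occlusion masks from a forward/backward flow pair via the consistency check:

import torch

from unimatch.geometry import flow_warp, forward_backward_consistency_check

b, c, h, w = 1, 8, 64, 96
feature = torch.randn(b, c, h, w)
fwd_flow = torch.randn(b, 2, h, w)
bwd_flow = torch.randn(b, 2, h, w)

warped = flow_warp(feature, fwd_flow)  # [B, C, H, W], bilinearly sampled at coords + flow
fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow)  # [B, H, W] binary masks

print(warped.shape, fwd_occ.shape)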
unimatch/matching.py
ADDED
@@ -0,0 +1,279 @@
import torch
import torch.nn.functional as F

from .geometry import coords_grid, generate_window_grid, normalize_coords


def global_correlation_softmax(feature0, feature1,
                               pred_bidir_flow=False,
                               ):
    # global correlation
    b, c, h, w = feature0.shape
    feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.view(b, c, -1)  # [B, C, H*W]

    correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5)  # [B, H, W, H, W]

    # flow from softmax
    init_grid = coords_grid(b, h, w).to(correlation.device)  # [B, 2, H, W]
    grid = init_grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]

    if pred_bidir_flow:
        correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0)  # [2*B, H*W, H*W]
        init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, 2, H, W]
        grid = grid.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]

    correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2)  # [B, 2, H, W]

    # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
    flow = correspondence - init_grid

    return flow, prob


def local_correlation_softmax(feature0, feature1, local_radius,
                              padding_mode='zeros',
                              ):
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1

    window_grid = generate_window_grid(-local_radius, local_radius,
                                       -local_radius, local_radius,
                                       local_h, local_w, device=feature0.device)  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1)^2, 2]

    sample_coords_softmax = sample_coords

    # exclude coords that are out of image space
    valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w)  # [B, H*W, (2R+1)^2]
    valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h)  # [B, H*W, (2R+1)^2]

    valid = valid_x & valid_y  # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)  # [-1, 1]
    window_feature = F.grid_sample(feature1, sample_coords_norm,
                                   padding_mode=padding_mode, align_corners=True
                                   ).permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5)  # [B, H*W, (2R+1)^2]

    # mask invalid locations
    corr[~valid] = -1e9

    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
        b, h, w, 2).permute(0, 3, 1, 2)  # [B, 2, H, W]

    flow = correspondence - coords_init
    match_prob = prob

    return flow, match_prob


def local_correlation_with_flow(feature0, feature1,
                                flow,
                                local_radius,
                                padding_mode='zeros',
                                dilation=1,
                                ):
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1

    window_grid = generate_window_grid(-local_radius, local_radius,
                                       -local_radius, local_radius,
                                       local_h, local_w, device=feature0.device)  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid * dilation  # [B, H*W, (2R+1)^2, 2]

    # flow can be zero when using features after transformer
    if not isinstance(flow, float):
        sample_coords = sample_coords + flow.view(
            b, 2, -1).permute(0, 2, 1).unsqueeze(-2)  # [B, H*W, (2R+1)^2, 2]
    else:
        assert flow == 0.

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)  # [-1, 1]
    window_feature = F.grid_sample(feature1, sample_coords_norm,
                                   padding_mode=padding_mode, align_corners=True
                                   ).permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]
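A minimal usage sketch for global_correlation_softmax, which is fully shown above (my own example with random feature maps; in the model these would come from the transformer-enhanced 1/8-scale features):

import torch

from unimatch.matching import global_correlation_softmax

b, c, h, w = 1, 128, 40, 72
feature0 = torch.randn(b, c, h, w)
feature1 = torch.randn(b, c, h, w)

flow, prob = global_correlation_softmax(feature0, feature1)  # [B, 2, H, W] flow, [B, H*W, H*W] matching prob
print(flow.shape, prob.shape)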