Yiting1009 committed · Commit 5d87992 · Parent: 8b1d8da

Upload 26 files

Files changed:
- .gitattributes +2 -0
- app.py +27 -0
- canyons_intrinsics.json +5 -0
- flatiron_1.tiff +3 -0
- flatiron_2.tiff +3 -0
- horse_canyon_1.tiff +0 -0
- horse_canyon_2.tiff +0 -0
- src/.DS_Store +0 -0
- src/networks/__init__.py +7 -0
- src/networks/__pycache__/__init__.cpython-39.pyc +0 -0
- src/networks/__pycache__/depth_decoder.cpython-39.pyc +0 -0
- src/networks/__pycache__/pose_cnn.cpython-39.pyc +0 -0
- src/networks/__pycache__/pose_decoder.cpython-39.pyc +0 -0
- src/networks/__pycache__/resnet_encoder.cpython-39.pyc +0 -0
- src/networks/__pycache__/restoration_model.cpython-39.pyc +0 -0
- src/networks/depth_decoder.py +101 -0
- src/networks/pose_cnn.py +47 -0
- src/networks/pose_decoder.py +52 -0
- src/networks/resnet_encoder.py +431 -0
- src/networks/restoration_model.py +273 -0
- src/weights/depth.pth +3 -0
- src/weights/encoder.pth +3 -0
- src/weights/pose.pth +3 -0
- src/weights/pose_encoder.pth +3 -0
- src/weights/uie_model.pth +3 -0
- test_simple.py +169 -0
- utils.py +118 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+flatiron_1.tiff filter=lfs diff=lfs merge=lfs -text
+flatiron_2.tiff filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,27 @@
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import warnings
warnings.filterwarnings('ignore')
import gradio as gr
from PIL import Image
from test_simple import test_simple

def predict(image: Image):
    return test_simple(image)

title = "Underwater Image Restoration"

iface = gr.Interface(
    predict,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title=title,
    allow_flagging="never",
    examples=[
        ["flatiron_1.tiff"],
        ["flatiron_2.tiff"],
        ["horse_canyon_1.tiff"],
        ["horse_canyon_2.tiff"],
    ],
)
iface.launch(share=True)
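Note: the Space entry point is a plain Gradio app, so it can also be exercised locally. A minimal sketch, assuming the bundled weights and example .tiff files above sit in the working directory (the output filename is made up for illustration):

# Run the app:  python app.py
# Or call the predictor directly on one of the bundled examples:
from PIL import Image
from test_simple import test_simple

restored = test_simple(Image.open("flatiron_1.tiff"))  # returns a PIL image
restored.save("flatiron_1_restored.png")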
canyons_intrinsics.json
ADDED
@@ -0,0 +1,5 @@
[
  [1.21, 0, 0.5],
  [0, 1.93, 0.5],
  [0, 0, 1.0]
]
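Note: these are normalised intrinsics (focal lengths and principal point divided by image width/height). A minimal sketch of how they become the 1/4-resolution pixel intrinsics the cost volume expects, mirroring load_and_preprocess_intrinsics in test_simple.py; the 320x192 resize is an assumption for illustration, the real values come from the encoder checkpoint:

import json
import numpy as np

# Scale normalised intrinsics to pixel units at 1/4 of the network input size.
K = np.eye(4)
K[:3, :3] = np.array(json.load(open("canyons_intrinsics.json")))
resize_width, resize_height = 320, 192   # assumed; test_simple.py reads these from encoder.pth
K[0, :] *= resize_width // 4             # fx, cx in pixels at 80 px width
K[1, :] *= resize_height // 4            # fy, cy in pixels at 48 px height
invK = np.linalg.pinv(K)                 # inverse used for backprojection
print(K[:3, :3])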
flatiron_1.tiff
ADDED
Git LFS Details
flatiron_2.tiff
ADDED
Git LFS Details
horse_canyon_1.tiff
ADDED
horse_canyon_2.tiff
ADDED
src/.DS_Store
ADDED
Binary file (6.15 kB).
src/networks/__init__.py
ADDED
@@ -0,0 +1,7 @@
# flake8: noqa: F401
from .resnet_encoder import ResnetEncoder, ResnetEncoderMatching
from .depth_decoder import DepthDecoder
from .pose_decoder import PoseDecoder
from .pose_cnn import PoseCNN
from .restoration_model import MainModel
# from .layers import BackprojectDepth, Project3D, ConvBlock, Conv3x3, upsample
src/networks/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (422 Bytes).

src/networks/__pycache__/depth_decoder.cpython-39.pyc
ADDED
Binary file (3.17 kB).

src/networks/__pycache__/pose_cnn.cpython-39.pyc
ADDED
Binary file (1.35 kB).

src/networks/__pycache__/pose_decoder.cpython-39.pyc
ADDED
Binary file (1.77 kB).

src/networks/__pycache__/resnet_encoder.cpython-39.pyc
ADDED
Binary file (12.7 kB).

src/networks/__pycache__/restoration_model.cpython-39.pyc
ADDED
Binary file (7.88 kB).
src/networks/depth_decoder.py
ADDED
@@ -0,0 +1,101 @@
# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


class ConvBlock(nn.Module):
    """Layer to perform a convolution followed by ELU
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.nonlin(out)
        return out


class Conv3x3(nn.Module):
    """Layer to pad and convolve input
    """

    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()

        if use_refl:
            self.pad = nn.ReflectionPad2d(1)
        else:
            self.pad = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        out = self.pad(x)
        out = self.conv(out)
        return out


def upsample(x):
    """Upsample input tensor by a factor of 2
    """
    return F.interpolate(x, scale_factor=2, mode="nearest")


class DepthDecoder(nn.Module):
    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
        super(DepthDecoder, self).__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        self.outputs = {}

        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
            if i in self.scales:
                self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))

        return self.outputs
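Note: a minimal shape-check sketch of how this decoder pairs with the ResnetEncoder defined in resnet_encoder.py below; the 192x320 input size and the untrained weights are assumptions for illustration only.

import torch
from src.networks import ResnetEncoder, DepthDecoder

encoder = ResnetEncoder(18, pretrained=False)
decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))

with torch.no_grad():
    feats = encoder(torch.rand(1, 3, 192, 320))   # 5 feature maps, strides 2..32
    outputs = decoder(feats)

# ("disp", 0) is the full-resolution sigmoid disparity map: (1, 1, 192, 320)
print(outputs[("disp", 0)].shape)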
src/networks/pose_cnn.py
ADDED
@@ -0,0 +1,47 @@
# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import torch.nn as nn


class PoseCNN(nn.Module):
    def __init__(self, num_input_frames):
        super(PoseCNN, self).__init__()

        self.num_input_frames = num_input_frames

        self.convs = {}
        self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3)
        self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2)
        self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1)
        self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1)
        self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1)
        self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1)
        self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1)

        self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1)

        self.num_convs = len(self.convs)

        self.relu = nn.ReLU(True)

        self.net = nn.ModuleList(list(self.convs.values()))

    def forward(self, out):

        for i in range(self.num_convs):
            out = self.convs[i](out)
            out = self.relu(out)

        out = self.pose_conv(out)
        out = out.mean(3).mean(2)

        out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6)

        axisangle = out[..., :3]
        translation = out[..., 3:]

        return axisangle, translation
src/networks/pose_decoder.py
ADDED
@@ -0,0 +1,52 @@
# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import torch
import torch.nn as nn
from collections import OrderedDict


class PoseDecoder(nn.Module):
    def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
        super(PoseDecoder, self).__init__()

        self.num_ch_enc = num_ch_enc
        self.num_input_features = num_input_features

        if num_frames_to_predict_for is None:
            num_frames_to_predict_for = num_input_features - 1
        self.num_frames_to_predict_for = num_frames_to_predict_for

        self.convs = OrderedDict()
        self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
        self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
        self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
        self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)

        self.relu = nn.ReLU()

        self.net = nn.ModuleList(list(self.convs.values()))

    def forward(self, input_features):
        last_features = [f[-1] for f in input_features]

        cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
        cat_features = torch.cat(cat_features, 1)

        out = cat_features
        for i in range(3):
            out = self.convs[("pose", i)](out)
            if i != 2:
                out = self.relu(out)

        out = out.mean(3).mean(2)

        out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)

        axisangle = out[..., :3]
        translation = out[..., 3:]

        return axisangle, translation
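Note: a sketch of the pose branch wired the same way test_simple.py wires it (pose ResnetEncoder on two concatenated frames, then this decoder, then utils.transformation_from_parameters). The random frames and the 192x320 size are stand-ins for illustration, not the real inputs.

import torch
from src.networks import ResnetEncoder, PoseDecoder
from utils import transformation_from_parameters

pose_enc = ResnetEncoder(18, pretrained=False, num_input_images=2)
pose_dec = PoseDecoder(pose_enc.num_ch_enc, num_input_features=1, num_frames_to_predict_for=2)

frame_prev = torch.rand(1, 3, 192, 320)
frame_curr = torch.rand(1, 3, 192, 320)

with torch.no_grad():
    feats = pose_enc(torch.cat([frame_prev, frame_curr], 1))   # 6-channel input
    axisangle, translation = pose_dec([feats])                 # decoder expects a list of feature lists
    pose = transformation_from_parameters(axisangle[:, 0], translation[:, 0], invert=True)

print(pose.shape)  # (1, 4, 4) relative camera transform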
src/networks/resnet_encoder.py
ADDED
@@ -0,0 +1,431 @@
# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import os
os.environ["MKL_NUM_THREADS"] = "1"  # noqa F402
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # noqa F402
os.environ["OMP_NUM_THREADS"] = "1"  # noqa F402

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.utils.model_zoo as model_zoo


class BackprojectDepth(nn.Module):
    """Layer to transform a depth image into a point cloud
    """

    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
        cam_points = torch.cat([cam_points, self.ones], 1)

        return cam_points


class Project3D(nn.Module):
    """Layer which projects 3D points into a camera with intrinsics K and at position T
    """

    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        P = torch.matmul(K, T)[:, :3, :]

        cam_points = torch.matmul(P, points)

        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
        pix_coords = pix_coords.permute(0, 2, 3, 1)
        pix_coords[..., 0] /= self.width - 1
        pix_coords[..., 1] /= self.height - 1
        pix_coords = (pix_coords - 0.5) * 2
        return pix_coords


class ResNetMultiImageInput(models.ResNet):
    """Constructs a resnet model with varying number of input images.
    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    """

    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
        super(ResNetMultiImageInput, self).__init__(block, layers)
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
    """Constructs a ResNet model.
    Args:
        num_layers (int): Number of resnet layers. Must be 18 or 50
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        num_input_images (int): Number of frames stacked as input
    """
    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
    blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
    block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
    model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)

    if pretrained:
        loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
        loaded['conv1.weight'] = torch.cat(
            [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
        model.load_state_dict(loaded)
    return model


class ResnetEncoderMatching(nn.Module):
    """Resnet encoder adapted to include a cost volume after the 2nd block.

    Setting adaptive_bins=True will recompute the depth bins used for matching upon each
    forward pass - this is required for training from monocular video as there is an unknown scale.
    """

    def __init__(self, num_layers, pretrained, input_height, input_width,
                 min_depth_bin=0.1, max_depth_bin=20.0, num_depth_bins=96,
                 adaptive_bins=False, depth_binning='linear'):

        super(ResnetEncoderMatching, self).__init__()

        self.adaptive_bins = adaptive_bins
        self.depth_binning = depth_binning
        self.set_missing_to_max = True

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])
        self.num_depth_bins = num_depth_bins
        # we build the cost volume at 1/4 resolution
        self.matching_height, self.matching_width = input_height // 4, input_width // 4

        self.is_cuda = False
        self.warp_depths = None
        self.depth_bins = None

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        encoder = resnets[num_layers](pretrained)
        self.layer0 = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
        self.layer1 = nn.Sequential(encoder.maxpool, encoder.layer1)
        self.layer2 = encoder.layer2
        self.layer3 = encoder.layer3
        self.layer4 = encoder.layer4

        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

        self.backprojector = BackprojectDepth(batch_size=self.num_depth_bins,
                                              height=self.matching_height,
                                              width=self.matching_width)
        self.projector = Project3D(batch_size=self.num_depth_bins,
                                   height=self.matching_height,
                                   width=self.matching_width)

        self.compute_depth_bins(min_depth_bin, max_depth_bin)

        self.prematching_conv = nn.Sequential(nn.Conv2d(64, out_channels=16,
                                                        kernel_size=1, stride=1, padding=0),
                                              nn.ReLU(inplace=True)
                                              )

        self.reduce_conv = nn.Sequential(nn.Conv2d(self.num_ch_enc[1] + self.num_depth_bins,
                                                   out_channels=self.num_ch_enc[1],
                                                   kernel_size=3, stride=1, padding=1),
                                         nn.ReLU(inplace=True)
                                         )

    def compute_depth_bins(self, min_depth_bin, max_depth_bin):
        """Compute the depths bins used to build the cost volume. Bins will depend upon
        self.depth_binning, to either be linear in depth (linear) or linear in inverse depth
        (inverse)"""

        if self.depth_binning == 'inverse':
            self.depth_bins = 1 / np.linspace(1 / max_depth_bin,
                                              1 / min_depth_bin,
                                              self.num_depth_bins)[::-1]  # maintain depth order

        elif self.depth_binning == 'linear':
            self.depth_bins = np.linspace(min_depth_bin, max_depth_bin, self.num_depth_bins)
        else:
            raise NotImplementedError
        self.depth_bins = torch.from_numpy(self.depth_bins).float()

        self.warp_depths = []
        for depth in self.depth_bins:
            depth = torch.ones((1, self.matching_height, self.matching_width)) * depth
            self.warp_depths.append(depth)
        self.warp_depths = torch.stack(self.warp_depths, 0).float()
        if self.is_cuda:
            self.warp_depths = self.warp_depths.cuda()

    def match_features(self, current_feats, lookup_feats, relative_poses, K, invK):
        """Compute a cost volume based on L1 difference between current_feats and lookup_feats.

        We backwards warp the lookup_feats into the current frame using the estimated relative
        pose, known intrinsics and using hypothesised depths self.warp_depths (which are either
        linear in depth or linear in inverse depth).

        If relative_pose == 0 then this indicates that the lookup frame is missing (i.e. we are
        at the start of a sequence), and so we skip it"""

        batch_cost_volume = []  # store all cost volumes of the batch
        cost_volume_masks = []  # store locations of '0's in cost volume for confidence

        for batch_idx in range(len(current_feats)):

            volume_shape = (self.num_depth_bins, self.matching_height, self.matching_width)
            cost_volume = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)
            counts = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)

            # select an item from batch of ref feats
            _lookup_feats = lookup_feats[batch_idx:batch_idx + 1]
            _lookup_poses = relative_poses[batch_idx:batch_idx + 1]

            _K = K[batch_idx:batch_idx + 1]
            _invK = invK[batch_idx:batch_idx + 1]
            world_points = self.backprojector(self.warp_depths, _invK)

            # loop through ref images adding to the current cost volume
            for lookup_idx in range(_lookup_feats.shape[1]):
                lookup_feat = _lookup_feats[:, lookup_idx]  # 1 x C x H x W
                lookup_pose = _lookup_poses[:, lookup_idx]

                # ignore missing images
                if lookup_pose.sum() == 0:
                    continue

                lookup_feat = lookup_feat.repeat([self.num_depth_bins, 1, 1, 1])
                pix_locs = self.projector(world_points, _K, lookup_pose)
                warped = F.grid_sample(lookup_feat, pix_locs, padding_mode='zeros', mode='bilinear',
                                       align_corners=True)

                # mask values landing outside the image (and near the border)
                # we want to ignore edge pixels of the lookup images and the current image
                # because of zero padding in ResNet
                # Masking of ref image border
                x_vals = (pix_locs[..., 0].detach() / 2 + 0.5) * (
                    self.matching_width - 1)  # convert from (-1, 1) to pixel values
                y_vals = (pix_locs[..., 1].detach() / 2 + 0.5) * (self.matching_height - 1)

                edge_mask = (x_vals >= 2.0) * (x_vals <= self.matching_width - 2) * \
                            (y_vals >= 2.0) * (y_vals <= self.matching_height - 2)
                edge_mask = edge_mask.float()

                # masking of current image
                current_mask = torch.zeros_like(edge_mask)
                current_mask[:, 2:-2, 2:-2] = 1.0
                edge_mask = edge_mask * current_mask

                diffs = torch.abs(warped - current_feats[batch_idx:batch_idx + 1]).mean(
                    1) * edge_mask

                # integrate into cost volume
                cost_volume = cost_volume + diffs
                counts = counts + (diffs > 0).float()
            # average over lookup images
            cost_volume = cost_volume / (counts + 1e-7)

            # if some missing values for a pixel location (i.e. some depths landed outside) then
            # set to max of existing values
            missing_val_mask = (cost_volume == 0).float()
            if self.set_missing_to_max:
                cost_volume = cost_volume * (1 - missing_val_mask) + \
                    cost_volume.max(0)[0].unsqueeze(0) * missing_val_mask
            batch_cost_volume.append(cost_volume)
            cost_volume_masks.append(missing_val_mask)

        batch_cost_volume = torch.stack(batch_cost_volume, 0)
        cost_volume_masks = torch.stack(cost_volume_masks, 0)

        return batch_cost_volume, cost_volume_masks

    def feature_extraction(self, image, return_all_feats=False):
        """ Run feature extraction on an image - first 2 blocks of ResNet"""

        image = (image - 0.45) / 0.225  # imagenet normalisation
        feats_0 = self.layer0(image)
        feats_1 = self.layer1(feats_0)

        if return_all_feats:
            return [feats_0, feats_1]
        else:
            return feats_1

    def indices_to_disparity(self, indices):
        """Convert cost volume indices to 1/depth for visualisation"""

        batch, height, width = indices.shape
        depth = self.depth_bins[indices.reshape(-1).cpu()]
        disp = 1 / depth.reshape((batch, height, width))
        return disp

    def compute_confidence_mask(self, cost_volume, num_bins_threshold=None):
        """ Returns a 'confidence' mask based on how many times a depth bin was observed"""

        if num_bins_threshold is None:
            num_bins_threshold = self.num_depth_bins
        confidence_mask = ((cost_volume > 0).sum(1) == num_bins_threshold).float()

        return confidence_mask

    def forward(self, current_image, lookup_images, poses, K, invK,
                min_depth_bin=None, max_depth_bin=None
                ):

        # feature extraction
        self.features = self.feature_extraction(current_image, return_all_feats=True)
        current_feats = self.features[-1]
        # print('current_feats:', current_feats.shape)

        # feature extraction on lookup images - disable gradients to save memory
        with torch.no_grad():
            if self.adaptive_bins:
                self.compute_depth_bins(min_depth_bin, max_depth_bin)

            batch_size, num_frames, chns, height, width = lookup_images.shape
            lookup_images = lookup_images.reshape(batch_size * num_frames, chns, height, width)
            lookup_feats = self.feature_extraction(lookup_images,
                                                   return_all_feats=False)
            _, chns, height, width = lookup_feats.shape
            lookup_feats = lookup_feats.reshape(batch_size, num_frames, chns, height, width)
            # print('lookup_feats:', lookup_feats.shape)

            # warp features to find cost volume
            cost_volume, missing_mask = \
                self.match_features(current_feats, lookup_feats, poses, K, invK)
            confidence_mask = self.compute_confidence_mask(cost_volume.detach() *
                                                           (1 - missing_mask.detach()))

        # for visualisation - ignore 0s in cost volume for minimum
        viz_cost_vol = cost_volume.clone().detach()
        viz_cost_vol[viz_cost_vol == 0] = 100
        mins, argmin = torch.min(viz_cost_vol, 1)
        lowest_cost = self.indices_to_disparity(argmin)

        # mask the cost volume based on the confidence
        cost_volume *= confidence_mask.unsqueeze(1)
        post_matching_feats = self.reduce_conv(torch.cat([self.features[-1], cost_volume], 1))
        # print('post_matching_feats:', post_matching_feats.shape)

        self.features.append(self.layer2(post_matching_feats))
        self.features.append(self.layer3(self.features[-1]))
        self.features.append(self.layer4(self.features[-1]))

        return self.features, lowest_cost, confidence_mask

    def cuda(self):
        super().cuda()
        self.backprojector.cuda()
        self.projector.cuda()
        self.is_cuda = True
        if self.warp_depths is not None:
            self.warp_depths = self.warp_depths.cuda()

    def cpu(self):
        super().cpu()
        self.backprojector.cpu()
        self.projector.cpu()
        self.is_cuda = False
        if self.warp_depths is not None:
            self.warp_depths = self.warp_depths.cpu()

    def to(self, device):
        if str(device) == 'cpu':
            self.cpu()
        elif str(device) == 'cuda':
            self.cuda()
        else:
            raise NotImplementedError


class ResnetEncoder(nn.Module):
    """Pytorch module for a resnet encoder
    """

    def __init__(self, num_layers, pretrained, num_input_images=1, **kwargs):
        super(ResnetEncoder, self).__init__()

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        if num_input_images > 1:
            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
        else:
            self.encoder = resnets[num_layers](pretrained)

        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

    def forward(self, input_image):
        self.features = []
        x = (input_image - 0.45) / 0.225
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        self.features.append(self.encoder.relu(x))
        self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
        self.features.append(self.encoder.layer2(self.features[-1]))
        self.features.append(self.encoder.layer3(self.features[-1]))
        self.features.append(self.encoder.layer4(self.features[-1]))

        return self.features
src/networks/restoration_model.py
ADDED
@@ -0,0 +1,273 @@
import os
os.environ["MKL_NUM_THREADS"] = "1"  # noqa F402
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # noqa F402
os.environ["OMP_NUM_THREADS"] = "1"  # noqa F402

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def make_model():
    return MainModel()


class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class InDoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(InDoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 9, stride=4, padding=4, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class InConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(InConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 7, stride=4, padding=3, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=64, num_groups=8, affine=True),
            nn.ReLU(inplace=True)
        )
        self.convf = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False, padding_mode='reflect'),
            nn.GroupNorm(num_channels=64, num_groups=8, affine=True),
            nn.ReLU(inplace=False)
        )

    def forward(self, x):
        R = x[:, 0:1, :, :]
        G = x[:, 1:2, :, :]
        B = x[:, 2:3, :, :]
        xR = torch.unsqueeze(self.conv(R), 1)
        xG = torch.unsqueeze(self.conv(G), 1)
        xB = torch.unsqueeze(self.conv(B), 1)
        x = torch.cat([xR, xG, xB], 1)
        x, _ = torch.min(x, dim=1)
        return self.convf(x)


class SKConv(nn.Module):
    def __init__(self, outfeatures=64, infeatures=1, M=4, L=32):

        super(SKConv, self).__init__()
        self.M = M
        self.convs = nn.ModuleList([])
        in_conv = InConv(in_ch=infeatures, out_ch=outfeatures)
        for i in range(M):
            if i == 0:
                self.convs.append(in_conv)
            else:
                self.convs.append(nn.Sequential(
                    nn.Upsample(scale_factor=1/(2**i), mode='bilinear', align_corners=True),
                    in_conv,
                    nn.Upsample(scale_factor=2**i, mode='bilinear', align_corners=True)
                ))
        self.fc = nn.Linear(outfeatures, L)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(
                nn.Linear(L, outfeatures)
            )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        for i, conv in enumerate(self.convs):
            fea = conv(x).unsqueeze(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)
        fea_U = torch.sum(feas, dim=1)    # fea_U: (1, 64, H, W)
        fea_s = fea_U.mean(-1).mean(-1)   # (1, 64)
        fea_z = self.fc(fea_s)            # (1, 32)
        for i, fc in enumerate(self.fcs):
            vector = fc(fea_z).unsqueeze(dim=1)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector], dim=1)
        attention_vectors = self.softmax(attention_vectors)               # (1, 3, 64)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)  # (1, 3, 64, 1, 1)
        fea_v = (feas * attention_vectors).sum(dim=1)                     # (1, 64, H, W)
        return fea_v


class estimation(nn.Module):
    def __init__(self):
        super(estimation, self).__init__()

        self.InConv = SKConv(outfeatures=64, infeatures=1, M=3, L=32)

        self.convt_1 = DoubleConv(64, 64)
        self.up_1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_1 = nn.Conv2d(64, 6, 3, padding=1, stride=1, bias=False, padding_mode='reflect')

        self.convt_2 = DoubleConv(64, 64)
        self.up_2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.OutConv_2 = nn.Conv2d(64, 3, 3, padding=1, stride=1, bias=False, padding_mode='reflect')

        self.inconv_1 = InDoubleConv(3, 64)
        self.maxpool_1 = nn.MaxPool2d(15, 7)
        self.doubleconv_1 = DoubleConv(64, 64)
        self.pool_1 = nn.AdaptiveAvgPool2d(1)
        self.dense_1 = nn.Linear(64, 3, bias=False)

        self.inconv_2 = InDoubleConv(3, 64)
        self.maxpool_2 = nn.MaxPool2d(15, 7)
        self.doubleconv_2 = DoubleConv(64, 64)
        self.pool_2 = nn.AdaptiveAvgPool2d(1)
        self.dense_2 = nn.Linear(64, 3, bias=False)

    def forward(self, x):

        xmin = self.InConv(x)

        beta = self.OutConv_1(self.up_1(self.convt_1(xmin)))
        beta = torch.sigmoid(beta) + 1e-12

        atm = self.inconv_2(x)
        atm = torch.mul(atm, xmin)
        atm = self.pool_2(self.doubleconv_2(self.maxpool_2(atm)))
        atm = atm.view(-1, 64)
        atm = torch.sigmoid(self.dense_2(atm))

        return beta, atm


class JNet(torch.nn.Module):
    def __init__(self, num=64):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(3, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv4 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.final = torch.nn.Sequential(
            torch.nn.Conv2d(num, 3, 1, 1, 0),
            torch.nn.Sigmoid()
        )

    def forward(self, data):
        data = self.conv1(data)
        data = self.conv2(data)
        data = self.conv3(data)
        data = self.conv4(data)
        data1 = self.final(data)

        return data1


class TNet(torch.nn.Module):
    def __init__(self, num=64):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(3, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.conv4 = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(1),
            torch.nn.Conv2d(num, num, 3, 1, 0),
            torch.nn.InstanceNorm2d(num),
            torch.nn.ReLU()
        )
        self.final = torch.nn.Sequential(
            torch.nn.Conv2d(num, 6, 1, 1, 0),
            torch.nn.Sigmoid()
        )

    def forward(self, data):
        data = self.conv1(data)
        data = self.conv2(data)
        data = self.conv3(data)
        data = self.conv4(data)
        data1 = self.final(data)

        return data1


class MainModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.estimation = estimation()
        self.Jnet = JNet()
        # self.unet_J = UNet(n_channels=3, n_classes=3, bilinear=True)
        # self.Tnet = TNet()

    def forward(self, img):

        beta, A = self.estimation(img)
        beta_d = beta[:, :3, :, :]
        beta_b = beta[:, 3:, :, :]
        J = self.Jnet(img)
        A = torch.unsqueeze(torch.unsqueeze(A, 2), 2)
        A = A.expand_as(J)

        return [beta_d, beta_b], J, A


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        m.weight.data.normal_(0.0, 0.001)
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.001)
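Note: MainModel only predicts the ingredients ([beta_d, beta_b], J, A); test_simple.py combines them with the predicted depth d via the underwater image-formation model I = J·exp(-beta_d·d) + A·(1 - exp(-beta_b·d)). A minimal sketch of that composition step, mirroring test_simple.py; the random image, random depth, 256x256 size and untrained weights are assumptions for illustration only.

import torch
from src.networks import MainModel

model = MainModel().eval()
img = torch.rand(1, 3, 256, 256)          # stand-in for the underwater image in [0, 1]
depth = torch.rand(1, 1, 256, 256) * 20   # stand-in for the depth map from the depth branch

with torch.no_grad():
    (beta_d, beta_b), J, A = model(img)
    beta_d, beta_b = 5.0 * beta_d, 5.0 * beta_b        # same scaling as test_simple.py
    direct = J * torch.exp(-beta_d * depth)            # attenuated scene radiance
    backscatter = (1 - torch.exp(-beta_b * depth)) * A # veiling light
    I_rec = direct + backscatter                       # reconstructed underwater image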
src/weights/depth.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:185fdc2788a039352584f942fbd7c47e70eb32472ee92770e4eb90c9ee8f3cd7
size 12621521

src/weights/encoder.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b1c1df619da0aed60d1bb68b4a0012b0d6f541836f64ffe6a10aa098ef4c0732
size 76780611

src/weights/pose.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31360cc629502594dd329c756ad71b2ce6e2c42ae580b52bf07f399b3d9a2322
size 5260687

src/weights/pose_encoder.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c6b0a44f764c42ea39d272abf6a072ba13430da7efea43dfdbd6a2e73a0562e
size 46875213

src/weights/uie_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b13897691126aa5aa806aa95dc1599186546e5c585c26fad9c30e60d6b2d7a5f
size 2300389
test_simple.py
ADDED
@@ -0,0 +1,169 @@
# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import os
import time
import json
import argparse
import numpy as np
from PIL import Image
import matplotlib as mpl
import matplotlib.cm as cm

import torch
from torch import Tensor
import torchvision
from torchvision import transforms
import torch.nn.functional as F

from src.networks import *
from utils import transformation_from_parameters, disp_to_depth, line


def load_and_preprocess_image(image, resize_width, resize_height):
    image_ori = image.convert('RGB')
    W, H = image_ori.size
    W_resized = W - W % 32
    H_resized = H - H % 32
    img_ori_npy = np.array(image_ori)[0:H_resized, 0:W_resized]

    image = image_ori.resize((resize_width, resize_height), Image.Resampling.LANCZOS)
    image = transforms.ToTensor()(image)
    image_ori = transforms.ToTensor()(img_ori_npy).unsqueeze(0)
    image = line(image).unsqueeze(0)
    if torch.cuda.is_available():
        return image_ori.cuda(), image.cuda(), (H, W)
    return image_ori, image, (H, W)


def load_and_preprocess_intrinsics(intrinsics_path, resize_width, resize_height):
    K = np.eye(4)
    with open(intrinsics_path, 'r') as f:
        K[:3, :3] = np.array(json.load(f))

    # Convert normalised intrinsics to 1/4 size unnormalised intrinsics.
    # (The cost volume construction expects the intrinsics corresponding to 1/4 size images)
    K[0, :] *= resize_width // 4
    K[1, :] *= resize_height // 4

    invK = torch.Tensor(np.linalg.pinv(K)).unsqueeze(0)
    K = torch.Tensor(K).unsqueeze(0)

    if torch.cuda.is_available():
        return K.cuda(), invK.cuda()
    return K, invK


def tensor2img(img: Tensor) -> np.ndarray:
    return (255.0 * img.permute(1, 2, 0).cpu().detach().numpy()).astype(np.uint8)


def test_simple(image: Image):
    """Function to predict for a single image or folder of images
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Loading pretrained model
    encoder_dict = torch.load("src/weights/encoder.pth", map_location=device)
    encoder = ResnetEncoderMatching(18, False,
                                    input_width=encoder_dict['width'],
                                    input_height=encoder_dict['height'],
                                    adaptive_bins=True,
                                    min_depth_bin=encoder_dict['min_depth_bin'],
                                    max_depth_bin=encoder_dict['max_depth_bin'],
                                    depth_binning='linear',
                                    num_depth_bins=96)

    filtered_dict_enc = {k: v for k, v in encoder_dict.items() if k in encoder.state_dict()}
    encoder.load_state_dict(filtered_dict_enc)

    depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))

    loaded_dict = torch.load("src/weights/depth.pth", map_location=device)
    depth_decoder.load_state_dict(loaded_dict)

    pose_enc_dict = torch.load("src/weights/pose_encoder.pth", map_location=device)
    pose_dec_dict = torch.load("src/weights/pose.pth", map_location=device)

    pose_enc = ResnetEncoder(18, False, num_input_images=2)
    pose_dec = PoseDecoder(pose_enc.num_ch_enc,
                           num_input_features=1,
                           num_frames_to_predict_for=2)

    pose_enc.load_state_dict(pose_enc_dict, strict=True)
    pose_dec.load_state_dict(pose_dec_dict, strict=True)

    restoration_dict = torch.load("src/weights/uie_model.pth", map_location=device)
    uie_model = MainModel()
    uie_model.load_state_dict(restoration_dict, strict=False)

    # Setting states of networks
    encoder.eval()
    depth_decoder.eval()
    pose_enc.eval()
    pose_dec.eval()
    uie_model.eval()
    if torch.cuda.is_available():
        encoder.cuda()
        depth_decoder.cuda()
        pose_enc.cuda()
        pose_dec.cuda()
        uie_model.cuda()

    # Load input data
    input_image_ori, input_image, original_size = load_and_preprocess_image(image,
                                                                            resize_width=encoder_dict['width'],
                                                                            resize_height=encoder_dict['height'])
    source_image_ori, source_image, _ = load_and_preprocess_image(image,
                                                                  resize_width=encoder_dict['width'],
                                                                  resize_height=encoder_dict['height'])

    K, invK = load_and_preprocess_intrinsics('canyons_intrinsics.json',
                                             resize_width=encoder_dict['width'],
                                             resize_height=encoder_dict['height'])

    with torch.no_grad():

        # Estimate poses
        pose_inputs = [source_image, input_image]
        pose_inputs = [pose_enc(torch.cat(pose_inputs, 1))]
        axisangle, translation = pose_dec(pose_inputs)
        pose = transformation_from_parameters(axisangle[:, 0], translation[:, 0], invert=True)

        pose *= 0  # zero poses are a signal to the encoder not to construct a cost volume
        source_image *= 0

        # Estimate depth
        output, lowest_cost, _ = encoder(current_image=input_image,
                                         lookup_images=source_image.unsqueeze(1),
                                         poses=pose.unsqueeze(1),
                                         K=K,
                                         invK=invK,
                                         min_depth_bin=encoder_dict['min_depth_bin'],
                                         max_depth_bin=encoder_dict['max_depth_bin'])

        output = depth_decoder(output)

        sigmoid_output = output[("disp", 0)]
        _, depth_output = disp_to_depth(sigmoid_output, min_depth=0.1, max_depth=20)
        sigmoid_output_resized = F.interpolate(
            sigmoid_output, original_size, mode="bilinear", align_corners=False)
        sigmoid_output_resized = sigmoid_output_resized.cpu().numpy()[:, 0]
        depth = F.interpolate(
            depth_output, input_image_ori.shape[2:], mode="bilinear", align_corners=False)

        beta, J, A = uie_model(input_image_ori)

        beta[0] = 5.0 * beta[0]
        beta[1] = 5.0 * beta[1]

        t1 = torch.exp(-beta[0] * depth)
        D1 = J * t1
        B1 = (1 - torch.exp(-beta[1] * depth)) * A
        I_rec = D1 + B1

        # tensor2img returns a uint8 array, so use Image.fromarray (Image.open expects a path or file)
        J_out = Image.fromarray(tensor2img(J[0]))

    return J_out
utils.py
ADDED
@@ -0,0 +1,118 @@
# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


def disp_to_depth(disp, min_depth=0.1, max_depth=100):
    """Convert network's sigmoid output into depth prediction
    The formula for this conversion is given in the 'additional considerations'
    section of the paper.
    """
    min_disp = 1 / max_depth  # 0.05
    max_disp = 1 / min_depth  # 10
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    depth = 1 / scaled_disp
    return scaled_disp, depth


def transformation_from_parameters(axisangle, translation, invert=False):
    """Convert the network's (axisangle, translation) output into a 4x4 matrix
    """
    R = rot_from_axisangle(axisangle)
    t = translation.clone()

    if invert:
        R = R.transpose(1, 2)
        t *= -1

    T = get_translation_matrix(t)

    if invert:
        M = torch.matmul(R, T)
    else:
        M = torch.matmul(T, R)

    return M


def get_translation_matrix(translation_vector):
    """Convert a translation vector into a 4x4 transformation matrix
    """
    T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)

    t = translation_vector.contiguous().view(-1, 3, 1)

    T[:, 0, 0] = 1
    T[:, 1, 1] = 1
    T[:, 2, 2] = 1
    T[:, 3, 3] = 1
    T[:, :3, 3, None] = t

    return T


def rot_from_axisangle(vec):
    """Convert an axisangle rotation into a 4x4 transformation matrix
    (adapted from https://github.com/Wallacoloo/printipi)
    Input 'vec' has to be Bx1x3
    """
    angle = torch.norm(vec, 2, 2, True)
    axis = vec / (angle + 1e-7)

    ca = torch.cos(angle)
    sa = torch.sin(angle)
    C = 1 - ca

    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    xs = x * sa
    ys = y * sa
    zs = z * sa
    xC = x * C
    yC = y * C
    zC = z * C
    xyC = x * yC
    yzC = y * zC
    zxC = z * xC

    rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)

    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
    rot[:, 0, 1] = torch.squeeze(xyC - zs)
    rot[:, 0, 2] = torch.squeeze(zxC + ys)
    rot[:, 1, 0] = torch.squeeze(xyC + zs)
    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
    rot[:, 1, 2] = torch.squeeze(yzC - xs)
    rot[:, 2, 0] = torch.squeeze(zxC - ys)
    rot[:, 2, 1] = torch.squeeze(yzC + xs)
    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
    rot[:, 3, 3] = 1

    return rot


def normalize(img):
    return (img - img.min()) / (img.max() - img.min())


def line(img):
    img = img.unsqueeze(0)
    if img.shape[1] == 1:
        q5, q95 = torch.quantile(img.flatten(), q=torch.tensor((0.05, 0.95), device=img.device))

        img[img < q5] = q5
        img[img > q95] = q95

        return normalize(img)
    elif img.shape[1] == 3:
        for c in range(3):
            img[:, c:c+1] = line(img[:, c:c+1])
        return img.squeeze()
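Note: a quick worked check of disp_to_depth with the values test_simple.py passes (min_depth=0.1, max_depth=20); the numbers below follow directly from the formula above.

import torch
from utils import disp_to_depth

disp = torch.tensor([0.0, 0.5, 1.0])           # sigmoid outputs from the depth decoder
scaled_disp, depth = disp_to_depth(disp, min_depth=0.1, max_depth=20)
# scaled_disp = 0.05 + (10 - 0.05) * disp  ->  [0.05, 5.025, 10.0]
# depth = 1 / scaled_disp                  ->  [20.0, ~0.199, 0.1]
print(scaled_disp, depth)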