Yiting1009 committed on
Commit 5d87992 · 1 Parent(s): 8b1d8da

Upload 26 files

.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ flatiron_1.tiff filter=lfs diff=lfs merge=lfs -text
+ flatiron_2.tiff filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,27 @@
+ import os
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+ import warnings
+ warnings.filterwarnings('ignore')
+ import gradio as gr
+ from PIL import Image
+ from test_simple import test_simple
+
+ def predict(image: Image):
+     return test_simple(image)
+
+ title = "Underwater Image Restoration"
+
+ iface = gr.Interface(
+     predict,
+     inputs=gr.Image(type="pil"),
+     outputs="image",
+     title=title,
+     allow_flagging="never",
+     examples=[
+         ["flatiron_1.tiff"],
+         ["flatiron_2.tiff"],
+         ["horse_canyon_1.tiff"],
+         ["horse_canyon_2.tiff"],
+     ],
+ )
+ iface.launch(share=True)
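
For a quick local check, the Gradio UI can be bypassed and the predictor called directly. A minimal sketch, not part of the committed files; it assumes the repo root is the working directory, the example TIFFs are present, and src/weights/ is populated (the output filename is just an example):

    from PIL import Image
    from test_simple import test_simple

    restored = test_simple(Image.open("flatiron_1.tiff"))  # returns a PIL image
    restored.save("flatiron_1_restored.png")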
canyons_intrinsics.json ADDED
@@ -0,0 +1,5 @@
+ [
+     [1.21, 0, 0.5],
+     [0, 1.93, 0.5],
+     [0, 0, 1.0]
+ ]
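
This 3x3 matrix appears to hold normalised intrinsics (focal lengths and principal point divided by image width/height). load_and_preprocess_intrinsics in test_simple.py scales them to quarter-resolution pixel units for the cost-volume encoder. A sketch of that conversion, with a hypothetical 512x256 network input size standing in for the values test_simple reads from encoder.pth:

    import json
    import numpy as np

    K = np.eye(4)
    with open("canyons_intrinsics.json") as f:
        K[:3, :3] = np.array(json.load(f))

    width, height = 512, 256      # hypothetical; test_simple reads these from encoder.pth
    K[0, :] *= width // 4         # fx, cx -> pixels at 1/4 resolution
    K[1, :] *= height // 4        # fy, cy -> pixels at 1/4 resolution
    inv_K = np.linalg.pinv(K)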
flatiron_1.tiff ADDED

Git LFS Details

  • SHA256: d631bcde7c7e35e9545cca059f0d3db88516417c8a1e7f6307bf04bcc0655a1d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
flatiron_2.tiff ADDED

Git LFS Details

  • SHA256: 0459cd1c4e592a2b2d0b35ba94a18ff7e810b90f4f99c08c22339975152ceafe
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
horse_canyon_1.tiff ADDED
horse_canyon_2.tiff ADDED
src/.DS_Store ADDED
Binary file (6.15 kB).
 
src/networks/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # flake8: noqa: F401
+ from .resnet_encoder import ResnetEncoder, ResnetEncoderMatching
+ from .depth_decoder import DepthDecoder
+ from .pose_decoder import PoseDecoder
+ from .pose_cnn import PoseCNN
+ from .restoration_model import MainModel
+ # from .layers import BackprojectDepth, Project3D, ConvBlock, Conv3x3, upsample
src/networks/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (422 Bytes).
 
src/networks/__pycache__/depth_decoder.cpython-39.pyc ADDED
Binary file (3.17 kB).
 
src/networks/__pycache__/pose_cnn.cpython-39.pyc ADDED
Binary file (1.35 kB).
 
src/networks/__pycache__/pose_decoder.cpython-39.pyc ADDED
Binary file (1.77 kB).
 
src/networks/__pycache__/resnet_encoder.cpython-39.pyc ADDED
Binary file (12.7 kB).
 
src/networks/__pycache__/restoration_model.cpython-39.pyc ADDED
Binary file (7.88 kB).
 
src/networks/depth_decoder.py ADDED
@@ -0,0 +1,101 @@
+ # Copyright Niantic 2021. Patent Pending. All rights reserved.
+ #
+ # This software is licensed under the terms of the ManyDepth licence
+ # which allows for non-commercial use only, the full terms of which are made
+ # available in the LICENSE file.
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from collections import OrderedDict
+
+ class ConvBlock(nn.Module):
+     """Layer to perform a convolution followed by ELU
+     """
+     def __init__(self, in_channels, out_channels):
+         super(ConvBlock, self).__init__()
+
+         self.conv = Conv3x3(in_channels, out_channels)
+         self.nonlin = nn.ELU(inplace=True)
+
+     def forward(self, x):
+         out = self.conv(x)
+         out = self.nonlin(out)
+         return out
+
+
+ class Conv3x3(nn.Module):
+     """Layer to pad and convolve input
+     """
+     def __init__(self, in_channels, out_channels, use_refl=True):
+         super(Conv3x3, self).__init__()
+
+         if use_refl:
+             self.pad = nn.ReflectionPad2d(1)
+         else:
+             self.pad = nn.ZeroPad2d(1)
+         self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
+
+     def forward(self, x):
+         out = self.pad(x)
+         out = self.conv(out)
+         return out
+
+ def upsample(x):
+     """Upsample input tensor by a factor of 2
+     """
+     return F.interpolate(x, scale_factor=2, mode="nearest")
+
+ class DepthDecoder(nn.Module):
+     def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
+         super(DepthDecoder, self).__init__()
+
+         self.num_output_channels = num_output_channels
+         self.use_skips = use_skips
+         self.upsample_mode = 'nearest'
+         self.scales = scales
+
+         self.num_ch_enc = num_ch_enc
+         self.num_ch_dec = np.array([16, 32, 64, 128, 256])
+
+         # decoder
+         self.convs = OrderedDict()
+         for i in range(4, -1, -1):
+             # upconv_0
+             num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
+             num_ch_out = self.num_ch_dec[i]
+             self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
+
+             # upconv_1
+             num_ch_in = self.num_ch_dec[i]
+             if self.use_skips and i > 0:
+                 num_ch_in += self.num_ch_enc[i - 1]
+             num_ch_out = self.num_ch_dec[i]
+             self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
+
+         for s in self.scales:
+             self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
+
+         self.decoder = nn.ModuleList(list(self.convs.values()))
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, input_features):
+         self.outputs = {}
+
+         # decoder
+         x = input_features[-1]
+         for i in range(4, -1, -1):
+             x = self.convs[("upconv", i, 0)](x)
+             x = [upsample(x)]
+             if self.use_skips and i > 0:
+                 x += [input_features[i - 1]]
+             x = torch.cat(x, 1)
+             x = self.convs[("upconv", i, 1)](x)
+             if i in self.scales:
+                 self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
+
+         return self.outputs
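
DepthDecoder returns a dict keyed by ("disp", scale), with sigmoid outputs at full, 1/2, 1/4 and 1/8 resolution. A standalone shape check with dummy ResNet-18-sized features, not part of the commit; the 192x320 input size is only an assumption for the example:

    import torch
    from src.networks import DepthDecoder

    num_ch_enc = [64, 64, 128, 256, 512]                        # ResNet-18 feature widths
    sizes = [(96, 160), (48, 80), (24, 40), (12, 20), (6, 10)]  # 1/2 ... 1/32 of 192x320
    feats = [torch.randn(1, c, h, w) for c, (h, w) in zip(num_ch_enc, sizes)]

    decoder = DepthDecoder(num_ch_enc=num_ch_enc, scales=range(4))
    outputs = decoder(feats)
    print(outputs[("disp", 0)].shape)   # torch.Size([1, 1, 192, 320])
    print(outputs[("disp", 3)].shape)   # torch.Size([1, 1, 24, 40])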
src/networks/pose_cnn.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright Niantic 2021. Patent Pending. All rights reserved.
+ #
+ # This software is licensed under the terms of the ManyDepth licence
+ # which allows for non-commercial use only, the full terms of which are made
+ # available in the LICENSE file.
+
+ import torch.nn as nn
+
+
+ class PoseCNN(nn.Module):
+     def __init__(self, num_input_frames):
+         super(PoseCNN, self).__init__()
+
+         self.num_input_frames = num_input_frames
+
+         self.convs = {}
+         self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3)
+         self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2)
+         self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1)
+         self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1)
+         self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1)
+         self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1)
+         self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1)
+
+         self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1)
+
+         self.num_convs = len(self.convs)
+
+         self.relu = nn.ReLU(True)
+
+         self.net = nn.ModuleList(list(self.convs.values()))
+
+     def forward(self, out):
+         for i in range(self.num_convs):
+             out = self.convs[i](out)
+             out = self.relu(out)
+
+         out = self.pose_conv(out)
+         out = out.mean(3).mean(2)
+
+         out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6)
+
+         axisangle = out[..., :3]
+         translation = out[..., 3:]
+
+         return axisangle, translation
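
PoseCNN is exported by the package but not referenced in test_simple.py. A minimal shape check with random inputs (the 192x320 resolution is only an assumption):

    import torch
    from src.networks import PoseCNN

    pose_cnn = PoseCNN(num_input_frames=2)
    frames = torch.randn(1, 6, 192, 320)       # two RGB frames stacked along channels
    axisangle, translation = pose_cnn(frames)
    print(axisangle.shape, translation.shape)  # torch.Size([1, 1, 1, 3]) each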
src/networks/pose_decoder.py ADDED
@@ -0,0 +1,52 @@
+ # Copyright Niantic 2021. Patent Pending. All rights reserved.
+ #
+ # This software is licensed under the terms of the ManyDepth licence
+ # which allows for non-commercial use only, the full terms of which are made
+ # available in the LICENSE file.
+
+ import torch
+ import torch.nn as nn
+ from collections import OrderedDict
+
+
+ class PoseDecoder(nn.Module):
+     def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
+         super(PoseDecoder, self).__init__()
+
+         self.num_ch_enc = num_ch_enc
+         self.num_input_features = num_input_features
+
+         if num_frames_to_predict_for is None:
+             num_frames_to_predict_for = num_input_features - 1
+         self.num_frames_to_predict_for = num_frames_to_predict_for
+
+         self.convs = OrderedDict()
+         self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
+         self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
+         self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
+         self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
+
+         self.relu = nn.ReLU()
+
+         self.net = nn.ModuleList(list(self.convs.values()))
+
+     def forward(self, input_features):
+         last_features = [f[-1] for f in input_features]
+
+         cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
+         cat_features = torch.cat(cat_features, 1)
+
+         out = cat_features
+         for i in range(3):
+             out = self.convs[("pose", i)](out)
+             if i != 2:
+                 out = self.relu(out)
+
+         out = out.mean(3).mean(2)
+
+         out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
+
+         axisangle = out[..., :3]
+         translation = out[..., 3:]
+
+         return axisangle, translation
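
The pose branch pairs this decoder with a two-image ResnetEncoder, as wired in test_simple.py below. A minimal shape check with untrained weights (192x320 is again just an example resolution):

    import torch
    from src.networks import ResnetEncoder, PoseDecoder

    pose_enc = ResnetEncoder(18, False, num_input_images=2)
    pose_dec = PoseDecoder(pose_enc.num_ch_enc, num_input_features=1,
                           num_frames_to_predict_for=2)

    pair = torch.randn(1, 6, 192, 320)          # two RGB frames stacked along channels
    axisangle, translation = pose_dec([pose_enc(pair)])
    print(axisangle.shape, translation.shape)   # torch.Size([1, 2, 1, 3]) each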
src/networks/resnet_encoder.py ADDED
@@ -0,0 +1,431 @@
+ # Copyright Niantic 2021. Patent Pending. All rights reserved.
+ #
+ # This software is licensed under the terms of the ManyDepth licence
+ # which allows for non-commercial use only, the full terms of which are made
+ # available in the LICENSE file.
+
+ import os
+ os.environ["MKL_NUM_THREADS"] = "1"  # noqa F402
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"  # noqa F402
+ os.environ["OMP_NUM_THREADS"] = "1"  # noqa F402
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ import torch.utils.model_zoo as model_zoo
+
+
+ class BackprojectDepth(nn.Module):
+     """Layer to transform a depth image into a point cloud
+     """
+     def __init__(self, batch_size, height, width):
+         super(BackprojectDepth, self).__init__()
+
+         self.batch_size = batch_size
+         self.height = height
+         self.width = width
+
+         meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
+         self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
+         self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
+                                       requires_grad=False)
+
+         self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
+                                  requires_grad=False)
+
+         self.pix_coords = torch.unsqueeze(torch.stack(
+             [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
+         self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
+         self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
+                                        requires_grad=False)
+
+     def forward(self, depth, inv_K):
+         cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
+         cam_points = depth.view(self.batch_size, 1, -1) * cam_points
+         cam_points = torch.cat([cam_points, self.ones], 1)
+
+         return cam_points
+
+
+ class Project3D(nn.Module):
+     """Layer which projects 3D points into a camera with intrinsics K and at position T
+     """
+     def __init__(self, batch_size, height, width, eps=1e-7):
+         super(Project3D, self).__init__()
+
+         self.batch_size = batch_size
+         self.height = height
+         self.width = width
+         self.eps = eps
+
+     def forward(self, points, K, T):
+         P = torch.matmul(K, T)[:, :3, :]
+
+         cam_points = torch.matmul(P, points)
+
+         pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
+         pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
+         pix_coords = pix_coords.permute(0, 2, 3, 1)
+         pix_coords[..., 0] /= self.width - 1
+         pix_coords[..., 1] /= self.height - 1
+         pix_coords = (pix_coords - 0.5) * 2
+         return pix_coords
+
+
+ class ResNetMultiImageInput(models.ResNet):
+     """Constructs a resnet model with varying number of input images.
+     Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+     """
+     def __init__(self, block, layers, num_classes=1000, num_input_images=1):
+         super(ResNetMultiImageInput, self).__init__(block, layers)
+         self.inplanes = 64
+         self.conv1 = nn.Conv2d(
+             num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+
+ def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
+     """Constructs a ResNet model.
+     Args:
+         num_layers (int): Number of resnet layers. Must be 18 or 50
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         num_input_images (int): Number of frames stacked as input
+     """
+     assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
+     blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
+     block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
+     model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
+
+     if pretrained:
+         loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
+         loaded['conv1.weight'] = torch.cat(
+             [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
+         model.load_state_dict(loaded)
+     return model
+
+
+ class ResnetEncoderMatching(nn.Module):
+     """Resnet encoder adapted to include a cost volume after the 2nd block.
+
+     Setting adaptive_bins=True will recompute the depth bins used for matching upon each
+     forward pass - this is required for training from monocular video as there is an unknown scale.
+     """
+
+     def __init__(self, num_layers, pretrained, input_height, input_width,
+                  min_depth_bin=0.1, max_depth_bin=20.0, num_depth_bins=96,
+                  adaptive_bins=False, depth_binning='linear'):
+
+         super(ResnetEncoderMatching, self).__init__()
+
+         self.adaptive_bins = adaptive_bins
+         self.depth_binning = depth_binning
+         self.set_missing_to_max = True
+
+         self.num_ch_enc = np.array([64, 64, 128, 256, 512])
+         self.num_depth_bins = num_depth_bins
+         # we build the cost volume at 1/4 resolution
+         self.matching_height, self.matching_width = input_height // 4, input_width // 4
+
+         self.is_cuda = False
+         self.warp_depths = None
+         self.depth_bins = None
+
+         resnets = {18: models.resnet18,
+                    34: models.resnet34,
+                    50: models.resnet50,
+                    101: models.resnet101,
+                    152: models.resnet152}
+
+         if num_layers not in resnets:
+             raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
+
+         encoder = resnets[num_layers](pretrained)
+         self.layer0 = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
+         self.layer1 = nn.Sequential(encoder.maxpool, encoder.layer1)
+         self.layer2 = encoder.layer2
+         self.layer3 = encoder.layer3
+         self.layer4 = encoder.layer4
+
+         if num_layers > 34:
+             self.num_ch_enc[1:] *= 4
+
+         self.backprojector = BackprojectDepth(batch_size=self.num_depth_bins,
+                                               height=self.matching_height,
+                                               width=self.matching_width)
+         self.projector = Project3D(batch_size=self.num_depth_bins,
+                                    height=self.matching_height,
+                                    width=self.matching_width)
+
+         self.compute_depth_bins(min_depth_bin, max_depth_bin)
+
+         self.prematching_conv = nn.Sequential(nn.Conv2d(64, out_channels=16,
+                                                         kernel_size=1, stride=1, padding=0),
+                                               nn.ReLU(inplace=True)
+                                               )
+
+         self.reduce_conv = nn.Sequential(nn.Conv2d(self.num_ch_enc[1] + self.num_depth_bins,
+                                                    out_channels=self.num_ch_enc[1],
+                                                    kernel_size=3, stride=1, padding=1),
+                                          nn.ReLU(inplace=True)
+                                          )
+
+     def compute_depth_bins(self, min_depth_bin, max_depth_bin):
+         """Compute the depths bins used to build the cost volume. Bins will depend upon
+         self.depth_binning, to either be linear in depth (linear) or linear in inverse depth
+         (inverse)"""
+
+         if self.depth_binning == 'inverse':
+             self.depth_bins = 1 / np.linspace(1 / max_depth_bin,
+                                               1 / min_depth_bin,
+                                               self.num_depth_bins)[::-1]  # maintain depth order
+
+         elif self.depth_binning == 'linear':
+             self.depth_bins = np.linspace(min_depth_bin, max_depth_bin, self.num_depth_bins)
+         else:
+             raise NotImplementedError
+         self.depth_bins = torch.from_numpy(self.depth_bins).float()
+
+         self.warp_depths = []
+         for depth in self.depth_bins:
+             depth = torch.ones((1, self.matching_height, self.matching_width)) * depth
+             self.warp_depths.append(depth)
+         self.warp_depths = torch.stack(self.warp_depths, 0).float()
+         if self.is_cuda:
+             self.warp_depths = self.warp_depths.cuda()
+
+     def match_features(self, current_feats, lookup_feats, relative_poses, K, invK):
+         """Compute a cost volume based on L1 difference between current_feats and lookup_feats.
+
+         We backwards warp the lookup_feats into the current frame using the estimated relative
+         pose, known intrinsics and using hypothesised depths self.warp_depths (which are either
+         linear in depth or linear in inverse depth).
+
+         If relative_pose == 0 then this indicates that the lookup frame is missing (i.e. we are
+         at the start of a sequence), and so we skip it"""
+
+         batch_cost_volume = []  # store all cost volumes of the batch
+         cost_volume_masks = []  # store locations of '0's in cost volume for confidence
+
+         for batch_idx in range(len(current_feats)):
+
+             volume_shape = (self.num_depth_bins, self.matching_height, self.matching_width)
+             cost_volume = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)
+             counts = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)
+
+             # select an item from batch of ref feats
+             _lookup_feats = lookup_feats[batch_idx:batch_idx + 1]
+             _lookup_poses = relative_poses[batch_idx:batch_idx + 1]
+
+             _K = K[batch_idx:batch_idx + 1]
+             _invK = invK[batch_idx:batch_idx + 1]
+             world_points = self.backprojector(self.warp_depths, _invK)
+
+             # loop through ref images adding to the current cost volume
+             for lookup_idx in range(_lookup_feats.shape[1]):
+                 lookup_feat = _lookup_feats[:, lookup_idx]  # 1 x C x H x W
+                 lookup_pose = _lookup_poses[:, lookup_idx]
+
+                 # ignore missing images
+                 if lookup_pose.sum() == 0:
+                     continue
+
+                 lookup_feat = lookup_feat.repeat([self.num_depth_bins, 1, 1, 1])
+                 pix_locs = self.projector(world_points, _K, lookup_pose)
+                 warped = F.grid_sample(lookup_feat, pix_locs, padding_mode='zeros', mode='bilinear',
+                                        align_corners=True)
+
+                 # mask values landing outside the image (and near the border)
+                 # we want to ignore edge pixels of the lookup images and the current image
+                 # because of zero padding in ResNet
+                 # Masking of ref image border
+                 x_vals = (pix_locs[..., 0].detach() / 2 + 0.5) * (
+                     self.matching_width - 1)  # convert from (-1, 1) to pixel values
+                 y_vals = (pix_locs[..., 1].detach() / 2 + 0.5) * (self.matching_height - 1)
+
+                 edge_mask = (x_vals >= 2.0) * (x_vals <= self.matching_width - 2) * \
+                             (y_vals >= 2.0) * (y_vals <= self.matching_height - 2)
+                 edge_mask = edge_mask.float()
+
+                 # masking of current image
+                 current_mask = torch.zeros_like(edge_mask)
+                 current_mask[:, 2:-2, 2:-2] = 1.0
+                 edge_mask = edge_mask * current_mask
+
+                 diffs = torch.abs(warped - current_feats[batch_idx:batch_idx + 1]).mean(
+                     1) * edge_mask
+
+                 # integrate into cost volume
+                 cost_volume = cost_volume + diffs
+                 counts = counts + (diffs > 0).float()
+             # average over lookup images
+             cost_volume = cost_volume / (counts + 1e-7)
+
+             # if some missing values for a pixel location (i.e. some depths landed outside) then
+             # set to max of existing values
+             missing_val_mask = (cost_volume == 0).float()
+             if self.set_missing_to_max:
+                 cost_volume = cost_volume * (1 - missing_val_mask) + \
+                     cost_volume.max(0)[0].unsqueeze(0) * missing_val_mask
+             batch_cost_volume.append(cost_volume)
+             cost_volume_masks.append(missing_val_mask)
+
+         batch_cost_volume = torch.stack(batch_cost_volume, 0)
+         cost_volume_masks = torch.stack(cost_volume_masks, 0)
+
+         return batch_cost_volume, cost_volume_masks
+
+     def feature_extraction(self, image, return_all_feats=False):
+         """ Run feature extraction on an image - first 2 blocks of ResNet"""
+
+         image = (image - 0.45) / 0.225  # imagenet normalisation
+         feats_0 = self.layer0(image)
+         feats_1 = self.layer1(feats_0)
+
+         if return_all_feats:
+             return [feats_0, feats_1]
+         else:
+             return feats_1
+
+     def indices_to_disparity(self, indices):
+         """Convert cost volume indices to 1/depth for visualisation"""
+
+         batch, height, width = indices.shape
+         depth = self.depth_bins[indices.reshape(-1).cpu()]
+         disp = 1 / depth.reshape((batch, height, width))
+         return disp
+
+     def compute_confidence_mask(self, cost_volume, num_bins_threshold=None):
+         """ Returns a 'confidence' mask based on how many times a depth bin was observed"""
+
+         if num_bins_threshold is None:
+             num_bins_threshold = self.num_depth_bins
+         confidence_mask = ((cost_volume > 0).sum(1) == num_bins_threshold).float()
+
+         return confidence_mask
+
+     def forward(self, current_image, lookup_images, poses, K, invK,
+                 min_depth_bin=None, max_depth_bin=None
+                 ):
+
+         # feature extraction
+         self.features = self.feature_extraction(current_image, return_all_feats=True)
+         current_feats = self.features[-1]
+         # print('current_feats:', current_feats.shape)
+
+         # feature extraction on lookup images - disable gradients to save memory
+         with torch.no_grad():
+             if self.adaptive_bins:
+                 self.compute_depth_bins(min_depth_bin, max_depth_bin)
+
+             batch_size, num_frames, chns, height, width = lookup_images.shape
+             lookup_images = lookup_images.reshape(batch_size * num_frames, chns, height, width)
+             lookup_feats = self.feature_extraction(lookup_images,
+                                                    return_all_feats=False)
+             _, chns, height, width = lookup_feats.shape
+             lookup_feats = lookup_feats.reshape(batch_size, num_frames, chns, height, width)
+             # print('lookup_feats:', lookup_feats.shape)
+
+             # warp features to find cost volume
+             cost_volume, missing_mask = \
+                 self.match_features(current_feats, lookup_feats, poses, K, invK)
+             confidence_mask = self.compute_confidence_mask(cost_volume.detach() *
+                                                            (1 - missing_mask.detach()))
+
+         # for visualisation - ignore 0s in cost volume for minimum
+         viz_cost_vol = cost_volume.clone().detach()
+         viz_cost_vol[viz_cost_vol == 0] = 100
+         mins, argmin = torch.min(viz_cost_vol, 1)
+         lowest_cost = self.indices_to_disparity(argmin)
+
+         # mask the cost volume based on the confidence
+         cost_volume *= confidence_mask.unsqueeze(1)
+         post_matching_feats = self.reduce_conv(torch.cat([self.features[-1], cost_volume], 1))
+         # print('post_matching_feats:', post_matching_feats.shape)
+
+         self.features.append(self.layer2(post_matching_feats))
+         self.features.append(self.layer3(self.features[-1]))
+         self.features.append(self.layer4(self.features[-1]))
+
+         return self.features, lowest_cost, confidence_mask
+
+     def cuda(self):
+         super().cuda()
+         self.backprojector.cuda()
+         self.projector.cuda()
+         self.is_cuda = True
+         if self.warp_depths is not None:
+             self.warp_depths = self.warp_depths.cuda()
+
+     def cpu(self):
+         super().cpu()
+         self.backprojector.cpu()
+         self.projector.cpu()
+         self.is_cuda = False
+         if self.warp_depths is not None:
+             self.warp_depths = self.warp_depths.cpu()
+
+     def to(self, device):
+         if str(device) == 'cpu':
+             self.cpu()
+         elif str(device) == 'cuda':
+             self.cuda()
+         else:
+             raise NotImplementedError
+
+
+ class ResnetEncoder(nn.Module):
+     """Pytorch module for a resnet encoder
+     """
+     def __init__(self, num_layers, pretrained, num_input_images=1, **kwargs):
+         super(ResnetEncoder, self).__init__()
+
+         self.num_ch_enc = np.array([64, 64, 128, 256, 512])
+
+         resnets = {18: models.resnet18,
+                    34: models.resnet34,
+                    50: models.resnet50,
+                    101: models.resnet101,
+                    152: models.resnet152}
+
+         if num_layers not in resnets:
+             raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
+
+         if num_input_images > 1:
+             self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
+         else:
+             self.encoder = resnets[num_layers](pretrained)
+
+         if num_layers > 34:
+             self.num_ch_enc[1:] *= 4
+
+     def forward(self, input_image):
+         self.features = []
+         x = (input_image - 0.45) / 0.225
+         x = self.encoder.conv1(x)
+         x = self.encoder.bn1(x)
+         self.features.append(self.encoder.relu(x))
+         self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
+         self.features.append(self.encoder.layer2(self.features[-1]))
+         self.features.append(self.encoder.layer3(self.features[-1]))
+         self.features.append(self.encoder.layer4(self.features[-1]))
+
+         return self.features
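
In this demo the multi-frame matching is effectively disabled: test_simple.py zeroes both the relative pose and the lookup image, and match_features skips any lookup frame whose pose sums to zero, so the encoder runs single-frame. A minimal sketch of that mode with untrained weights, not part of the commit; the 192x320 input and the identity-based intrinsics are assumptions for the example:

    import torch
    from src.networks import ResnetEncoderMatching, DepthDecoder

    enc = ResnetEncoderMatching(18, False, input_height=192, input_width=320,
                                min_depth_bin=0.1, max_depth_bin=20.0, num_depth_bins=96)
    dec = DepthDecoder(enc.num_ch_enc, scales=range(4))

    img = torch.rand(1, 3, 192, 320)
    lookup = torch.zeros(1, 1, 3, 192, 320)   # one all-zero lookup frame
    pose = torch.zeros(1, 1, 4, 4)            # zero pose => cost volume is skipped
    K = torch.eye(4).unsqueeze(0)
    K[:, 0, :] *= 320 // 4                    # quarter-resolution intrinsics, as in test_simple
    K[:, 1, :] *= 192 // 4
    invK = torch.linalg.pinv(K)

    with torch.no_grad():
        feats, lowest_cost, confidence = enc(img, lookup, pose, K, invK, 0.1, 20.0)
        disp = dec(feats)[("disp", 0)]        # (1, 1, 192, 320) sigmoid disparity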
src/networks/restoration_model.py ADDED
@@ -0,0 +1,273 @@
+ import os
+ os.environ["MKL_NUM_THREADS"] = "1"  # noqa F402
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"  # noqa F402
+ os.environ["OMP_NUM_THREADS"] = "1"  # noqa F402
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+
+
+ def make_model():
+     return MainModel()
+
+
+ class DoubleConv(nn.Module):
+     def __init__(self, in_ch, out_ch):
+         super(DoubleConv, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
+             nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
+             nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+
+
+ class InDoubleConv(nn.Module):
+     def __init__(self, in_ch, out_ch):
+         super(InDoubleConv, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_ch, out_ch, 9, stride=4, padding=4, bias=False, padding_mode='reflect'),
+             nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False, padding_mode='reflect'),
+             nn.GroupNorm(num_channels=out_ch, num_groups=8, affine=True),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+
+
+ class InConv(nn.Module):
+     def __init__(self, in_ch, out_ch):
+         super(InConv, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(1, 64, 7, stride=4, padding=3, bias=False, padding_mode='reflect'),
+             nn.GroupNorm(num_channels=64, num_groups=8, affine=True),
+             nn.ReLU(inplace=True)
+         )
+         self.convf = nn.Sequential(
+             nn.Conv2d(64, 64, 3, padding=1, bias=False, padding_mode='reflect'),
+             nn.GroupNorm(num_channels=64, num_groups=8, affine=True),
+             nn.ReLU(inplace=False)
+         )
+
+     def forward(self, x):
+         R = x[:, 0:1, :, :]
+         G = x[:, 1:2, :, :]
+         B = x[:, 2:3, :, :]
+         xR = torch.unsqueeze(self.conv(R), 1)
+         xG = torch.unsqueeze(self.conv(G), 1)
+         xB = torch.unsqueeze(self.conv(B), 1)
+         x = torch.cat([xR, xG, xB], 1)
+         x, _ = torch.min(x, dim=1)
+         return self.convf(x)
+
+
+ class SKConv(nn.Module):
+     def __init__(self, outfeatures=64, infeatures=1, M=4, L=32):
+         super(SKConv, self).__init__()
+         self.M = M
+         self.convs = nn.ModuleList([])
+         in_conv = InConv(in_ch=infeatures, out_ch=outfeatures)
+         for i in range(M):
+             if i == 0:
+                 self.convs.append(in_conv)
+             else:
+                 self.convs.append(nn.Sequential(
+                     nn.Upsample(scale_factor=1 / (2 ** i), mode='bilinear', align_corners=True),
+                     in_conv,
+                     nn.Upsample(scale_factor=2 ** i, mode='bilinear', align_corners=True)
+                 ))
+         self.fc = nn.Linear(outfeatures, L)
+         self.fcs = nn.ModuleList([])
+         for i in range(M):
+             self.fcs.append(
+                 nn.Linear(L, outfeatures)
+             )
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, x):
+         for i, conv in enumerate(self.convs):
+             fea = conv(x).unsqueeze(dim=1)
+             if i == 0:
+                 feas = fea
+             else:
+                 feas = torch.cat([feas, fea], dim=1)
+         fea_U = torch.sum(feas, dim=1)   # fea_U: (1, 64, H, W)
+         fea_s = fea_U.mean(-1).mean(-1)  # (1, 64)
+         fea_z = self.fc(fea_s)           # (1, 32)
+         for i, fc in enumerate(self.fcs):
+             vector = fc(fea_z).unsqueeze(dim=1)
+             if i == 0:
+                 attention_vectors = vector
+             else:
+                 attention_vectors = torch.cat([attention_vectors, vector], dim=1)
+         attention_vectors = self.softmax(attention_vectors)             # (1, 3, 64)
+         attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)  # (1, 3, 64, 1, 1)
+         fea_v = (feas * attention_vectors).sum(dim=1)                    # (1, 64, H, W)
+         return fea_v
+
+
+ class estimation(nn.Module):
+     def __init__(self):
+         super(estimation, self).__init__()
+
+         self.InConv = SKConv(outfeatures=64, infeatures=1, M=3, L=32)
+
+         self.convt_1 = DoubleConv(64, 64)
+         self.up_1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
+         self.OutConv_1 = nn.Conv2d(64, 6, 3, padding=1, stride=1, bias=False, padding_mode='reflect')
+
+         self.convt_2 = DoubleConv(64, 64)
+         self.up_2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
+         self.OutConv_2 = nn.Conv2d(64, 3, 3, padding=1, stride=1, bias=False, padding_mode='reflect')
+
+         self.inconv_1 = InDoubleConv(3, 64)
+         self.maxpool_1 = nn.MaxPool2d(15, 7)
+         self.doubleconv_1 = DoubleConv(64, 64)
+         self.pool_1 = nn.AdaptiveAvgPool2d(1)
+         self.dense_1 = nn.Linear(64, 3, bias=False)
+
+         self.inconv_2 = InDoubleConv(3, 64)
+         self.maxpool_2 = nn.MaxPool2d(15, 7)
+         self.doubleconv_2 = DoubleConv(64, 64)
+         self.pool_2 = nn.AdaptiveAvgPool2d(1)
+         self.dense_2 = nn.Linear(64, 3, bias=False)
+
+     def forward(self, x):
+         xmin = self.InConv(x)
+
+         beta = self.OutConv_1(self.up_1(self.convt_1(xmin)))
+         beta = torch.sigmoid(beta) + 1e-12
+
+         atm = self.inconv_2(x)
+         atm = torch.mul(atm, xmin)
+         atm = self.pool_2(self.doubleconv_2(self.maxpool_2(atm)))
+         atm = atm.view(-1, 64)
+         atm = torch.sigmoid(self.dense_2(atm))
+
+         return beta, atm
+
+
+ class JNet(torch.nn.Module):
+     def __init__(self, num=64):
+         super().__init__()
+         self.conv1 = torch.nn.Sequential(
+             torch.nn.ReflectionPad2d(1),
+             torch.nn.Conv2d(3, num, 3, 1, 0),
+             torch.nn.InstanceNorm2d(num),
+             torch.nn.ReLU()
+         )
+         self.conv2 = torch.nn.Sequential(
+             torch.nn.ReflectionPad2d(1),
+             torch.nn.Conv2d(num, num, 3, 1, 0),
+             torch.nn.InstanceNorm2d(num),
+             torch.nn.ReLU()
+         )
+         self.conv3 = torch.nn.Sequential(
+             torch.nn.ReflectionPad2d(1),
+             torch.nn.Conv2d(num, num, 3, 1, 0),
+             torch.nn.InstanceNorm2d(num),
+             torch.nn.ReLU()
+         )
+         self.conv4 = torch.nn.Sequential(
+             torch.nn.ReflectionPad2d(1),
+             torch.nn.Conv2d(num, num, 3, 1, 0),
+             torch.nn.InstanceNorm2d(num),
+             torch.nn.ReLU()
+         )
+         self.final = torch.nn.Sequential(
+             torch.nn.Conv2d(num, 3, 1, 1, 0),
+             torch.nn.Sigmoid()
+         )
+
+     def forward(self, data):
+         data = self.conv1(data)
+         data = self.conv2(data)
+         data = self.conv3(data)
+         data = self.conv4(data)
+         data1 = self.final(data)
+
+         return data1
+
+
+ class TNet(torch.nn.Module):
+     def __init__(self, num=64):
+         super().__init__()
+         self.conv1 = torch.nn.Sequential(
+             torch.nn.ReflectionPad2d(1),
+             torch.nn.Conv2d(3, num, 3, 1, 0),
+             torch.nn.InstanceNorm2d(num),
+             torch.nn.ReLU()
+         )
+         self.conv2 = torch.nn.Sequential(
+             torch.nn.ReflectionPad2d(1),
+             torch.nn.Conv2d(num, num, 3, 1, 0),
+             torch.nn.InstanceNorm2d(num),
+             torch.nn.ReLU()
+         )
+         self.conv3 = torch.nn.Sequential(
+             torch.nn.ReflectionPad2d(1),
+             torch.nn.Conv2d(num, num, 3, 1, 0),
+             torch.nn.InstanceNorm2d(num),
+             torch.nn.ReLU()
+         )
+         self.conv4 = torch.nn.Sequential(
+             torch.nn.ReflectionPad2d(1),
+             torch.nn.Conv2d(num, num, 3, 1, 0),
+             torch.nn.InstanceNorm2d(num),
+             torch.nn.ReLU()
+         )
+         self.final = torch.nn.Sequential(
+             torch.nn.Conv2d(num, 6, 1, 1, 0),
+             torch.nn.Sigmoid()
+         )
+
+     def forward(self, data):
+         data = self.conv1(data)
+         data = self.conv2(data)
+         data = self.conv3(data)
+         data = self.conv4(data)
+         data1 = self.final(data)
+
+         return data1
+
+
+ class MainModel(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         self.estimation = estimation()
+         self.Jnet = JNet()
+         # self.unet_J = UNet(n_channels=3, n_classes=3, bilinear=True)
+         # self.Tnet = TNet()
+
+     def forward(self, img):
+         beta, A = self.estimation(img)
+         beta_d = beta[:, :3, :, :]
+         beta_b = beta[:, 3:, :, :]
+         J = self.Jnet(img)
+         A = torch.unsqueeze(torch.unsqueeze(A, 2), 2)
+         A = A.expand_as(J)
+
+         return [beta_d, beta_b], J, A
+
+
+ def weights_init(m):
+     classname = m.__class__.__name__
+     if classname.find('Conv2d') != -1:
+         m.weight.data.normal_(0.0, 0.001)
+     if classname.find('Linear') != -1:
+         m.weight.data.normal_(0.0, 0.001)
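
MainModel returns the attenuation and backscatter coefficients [beta_d, beta_b], the restored scene radiance J, and the veiling light A; test_simple.py plugs them into the underwater image formation model I = J * exp(-beta_d * z) + A * (1 - exp(-beta_b * z)). A minimal re-synthesis sketch with random tensors standing in for a real image and depth map (a real depth z comes from the ManyDepth branch):

    import torch
    from src.networks import MainModel

    model = MainModel().eval()
    img = torch.rand(1, 3, 256, 256)         # degraded underwater image in [0, 1]
    depth = torch.rand(1, 1, 256, 256) * 20  # stand-in for the predicted depth map

    with torch.no_grad():
        (beta_d, beta_b), J, A = model(img)
        # test_simple.py additionally scales both betas by 5.0 before this step
        I_rec = J * torch.exp(-beta_d * depth) + A * (1 - torch.exp(-beta_b * depth))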
src/weights/depth.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:185fdc2788a039352584f942fbd7c47e70eb32472ee92770e4eb90c9ee8f3cd7
+ size 12621521
src/weights/encoder.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b1c1df619da0aed60d1bb68b4a0012b0d6f541836f64ffe6a10aa098ef4c0732
+ size 76780611
src/weights/pose.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31360cc629502594dd329c756ad71b2ce6e2c42ae580b52bf07f399b3d9a2322
+ size 5260687
src/weights/pose_encoder.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2c6b0a44f764c42ea39d272abf6a072ba13430da7efea43dfdbd6a2e73a0562e
+ size 46875213
src/weights/uie_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b13897691126aa5aa806aa95dc1599186546e5c585c26fad9c30e60d6b2d7a5f
+ size 2300389
test_simple.py ADDED
@@ -0,0 +1,169 @@
+ # Copyright Niantic 2021. Patent Pending. All rights reserved.
+ #
+ # This software is licensed under the terms of the ManyDepth licence
+ # which allows for non-commercial use only, the full terms of which are made
+ # available in the LICENSE file.
+
+ import os
+ import time
+ import json
+ import argparse
+ import numpy as np
+ from PIL import Image
+ import matplotlib as mpl
+ import matplotlib.cm as cm
+
+ import torch
+ from torch import Tensor
+ import torchvision
+ from torchvision import transforms
+ import torch.nn.functional as F
+
+ from src.networks import *
+ from utils import transformation_from_parameters, disp_to_depth, line
+
+
+ def load_and_preprocess_image(image, resize_width, resize_height):
+     image_ori = image.convert('RGB')
+     W, H = image_ori.size
+     W_resized = W - W % 32
+     H_resized = H - H % 32
+     img_ori_npy = np.array(image_ori)[0:H_resized, 0:W_resized]
+
+     image = image_ori.resize((resize_width, resize_height), Image.Resampling.LANCZOS)
+     image = transforms.ToTensor()(image)
+     image_ori = transforms.ToTensor()(img_ori_npy).unsqueeze(0)
+     image = line(image).unsqueeze(0)
+     if torch.cuda.is_available():
+         return image_ori.cuda(), image.cuda(), (H, W)
+     return image_ori, image, (H, W)
+
+
+ def load_and_preprocess_intrinsics(intrinsics_path, resize_width, resize_height):
+     K = np.eye(4)
+     with open(intrinsics_path, 'r') as f:
+         K[:3, :3] = np.array(json.load(f))
+
+     # Convert normalised intrinsics to 1/4 size unnormalised intrinsics.
+     # (The cost volume construction expects the intrinsics corresponding to 1/4 size images)
+     K[0, :] *= resize_width // 4
+     K[1, :] *= resize_height // 4
+
+     invK = torch.Tensor(np.linalg.pinv(K)).unsqueeze(0)
+     K = torch.Tensor(K).unsqueeze(0)
+
+     if torch.cuda.is_available():
+         return K.cuda(), invK.cuda()
+     return K, invK
+
+
+ def tensor2img(img: Tensor) -> np.ndarray:
+     return (255.0 * img.permute(1, 2, 0).cpu().detach().numpy()).astype(np.uint8)
+
+
+ def test_simple(image: Image):
+     """Function to predict for a single image
+     """
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+     # Loading pretrained model
+     encoder_dict = torch.load("src/weights/encoder.pth", map_location=device)
+     encoder = ResnetEncoderMatching(18, False,
+                                     input_width=encoder_dict['width'],
+                                     input_height=encoder_dict['height'],
+                                     adaptive_bins=True,
+                                     min_depth_bin=encoder_dict['min_depth_bin'],
+                                     max_depth_bin=encoder_dict['max_depth_bin'],
+                                     depth_binning='linear',
+                                     num_depth_bins=96)
+
+     filtered_dict_enc = {k: v for k, v in encoder_dict.items() if k in encoder.state_dict()}
+     encoder.load_state_dict(filtered_dict_enc)
+
+     depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
+
+     loaded_dict = torch.load("src/weights/depth.pth", map_location=device)
+     depth_decoder.load_state_dict(loaded_dict)
+
+     pose_enc_dict = torch.load("src/weights/pose_encoder.pth", map_location=device)
+     pose_dec_dict = torch.load("src/weights/pose.pth", map_location=device)
+
+     pose_enc = ResnetEncoder(18, False, num_input_images=2)
+     pose_dec = PoseDecoder(pose_enc.num_ch_enc,
+                            num_input_features=1,
+                            num_frames_to_predict_for=2)
+
+     pose_enc.load_state_dict(pose_enc_dict, strict=True)
+     pose_dec.load_state_dict(pose_dec_dict, strict=True)
+
+     restoration_dict = torch.load("src/weights/uie_model.pth", map_location=device)
+     uie_model = MainModel()
+     uie_model.load_state_dict(restoration_dict, strict=False)
+
+     # Setting states of networks
+     encoder.eval()
+     depth_decoder.eval()
+     pose_enc.eval()
+     pose_dec.eval()
+     uie_model.eval()
+     if torch.cuda.is_available():
+         encoder.cuda()
+         depth_decoder.cuda()
+         pose_enc.cuda()
+         pose_dec.cuda()
+         uie_model.cuda()
+
+     # Load input data
+     input_image_ori, input_image, original_size = load_and_preprocess_image(image,
+                                                                             resize_width=encoder_dict['width'],
+                                                                             resize_height=encoder_dict['height'])
+     source_image_ori, source_image, _ = load_and_preprocess_image(image,
+                                                                   resize_width=encoder_dict['width'],
+                                                                   resize_height=encoder_dict['height'])
+
+     K, invK = load_and_preprocess_intrinsics('canyons_intrinsics.json',
+                                              resize_width=encoder_dict['width'],
+                                              resize_height=encoder_dict['height'])
+
+     with torch.no_grad():
+
+         # Estimate poses
+         pose_inputs = [source_image, input_image]
+         pose_inputs = [pose_enc(torch.cat(pose_inputs, 1))]
+         axisangle, translation = pose_dec(pose_inputs)
+         pose = transformation_from_parameters(axisangle[:, 0], translation[:, 0], invert=True)
+
+         pose *= 0  # zero poses are a signal to the encoder not to construct a cost volume
+         source_image *= 0
+
+         # Estimate depth
+         output, lowest_cost, _ = encoder(current_image=input_image,
+                                          lookup_images=source_image.unsqueeze(1),
+                                          poses=pose.unsqueeze(1),
+                                          K=K,
+                                          invK=invK,
+                                          min_depth_bin=encoder_dict['min_depth_bin'],
+                                          max_depth_bin=encoder_dict['max_depth_bin'])
+
+         output = depth_decoder(output)
+
+         sigmoid_output = output[("disp", 0)]
+         _, depth_output = disp_to_depth(sigmoid_output, min_depth=0.1, max_depth=20)
+         sigmoid_output_resized = F.interpolate(
+             sigmoid_output, original_size, mode="bilinear", align_corners=False)
+         sigmoid_output_resized = sigmoid_output_resized.cpu().numpy()[:, 0]
+         depth = F.interpolate(
+             depth_output, input_image_ori.shape[2:], mode="bilinear", align_corners=False)
+
+         beta, J, A = uie_model(input_image_ori)
+
+         beta[0] = 5.0 * beta[0]
+         beta[1] = 5.0 * beta[1]
+
+         t1 = torch.exp(-beta[0] * depth)
+         D1 = J * t1
+         B1 = (1 - torch.exp(-beta[1] * depth)) * A
+         I_rec = D1 + B1
+
+         # tensor2img returns an ndarray, so the PIL image is built with Image.fromarray
+         # (Image.open expects a file path or file object)
+         J_out = Image.fromarray(tensor2img(J[0]))
+
+     return J_out
utils.py ADDED
@@ -0,0 +1,118 @@
+ # Copyright Niantic 2021. Patent Pending. All rights reserved.
+ #
+ # This software is licensed under the terms of the ManyDepth licence
+ # which allows for non-commercial use only, the full terms of which are made
+ # available in the LICENSE file.
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def disp_to_depth(disp, min_depth=0.1, max_depth=100):
+     """Convert network's sigmoid output into depth prediction
+     The formula for this conversion is given in the 'additional considerations'
+     section of the paper.
+     """
+     min_disp = 1 / max_depth  # 0.05
+     max_disp = 1 / min_depth  # 10
+     scaled_disp = min_disp + (max_disp - min_disp) * disp
+     depth = 1 / scaled_disp
+     return scaled_disp, depth
+
+
+ def transformation_from_parameters(axisangle, translation, invert=False):
+     """Convert the network's (axisangle, translation) output into a 4x4 matrix
+     """
+     R = rot_from_axisangle(axisangle)
+     t = translation.clone()
+
+     if invert:
+         R = R.transpose(1, 2)
+         t *= -1
+
+     T = get_translation_matrix(t)
+
+     if invert:
+         M = torch.matmul(R, T)
+     else:
+         M = torch.matmul(T, R)
+
+     return M
+
+
+ def get_translation_matrix(translation_vector):
+     """Convert a translation vector into a 4x4 transformation matrix
+     """
+     T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
+
+     t = translation_vector.contiguous().view(-1, 3, 1)
+
+     T[:, 0, 0] = 1
+     T[:, 1, 1] = 1
+     T[:, 2, 2] = 1
+     T[:, 3, 3] = 1
+     T[:, :3, 3, None] = t
+
+     return T
+
+
+ def rot_from_axisangle(vec):
+     """Convert an axisangle rotation into a 4x4 transformation matrix
+     (adapted from https://github.com/Wallacoloo/printipi)
+     Input 'vec' has to be Bx1x3
+     """
+     angle = torch.norm(vec, 2, 2, True)
+     axis = vec / (angle + 1e-7)
+
+     ca = torch.cos(angle)
+     sa = torch.sin(angle)
+     C = 1 - ca
+
+     x = axis[..., 0].unsqueeze(1)
+     y = axis[..., 1].unsqueeze(1)
+     z = axis[..., 2].unsqueeze(1)
+
+     xs = x * sa
+     ys = y * sa
+     zs = z * sa
+     xC = x * C
+     yC = y * C
+     zC = z * C
+     xyC = x * yC
+     yzC = y * zC
+     zxC = z * xC
+
+     rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
+
+     rot[:, 0, 0] = torch.squeeze(x * xC + ca)
+     rot[:, 0, 1] = torch.squeeze(xyC - zs)
+     rot[:, 0, 2] = torch.squeeze(zxC + ys)
+     rot[:, 1, 0] = torch.squeeze(xyC + zs)
+     rot[:, 1, 1] = torch.squeeze(y * yC + ca)
+     rot[:, 1, 2] = torch.squeeze(yzC - xs)
+     rot[:, 2, 0] = torch.squeeze(zxC - ys)
+     rot[:, 2, 1] = torch.squeeze(yzC + xs)
+     rot[:, 2, 2] = torch.squeeze(z * zC + ca)
+     rot[:, 3, 3] = 1
+
+     return rot
+
+
+ def normalize(img):
+     return (img - img.min()) / (img.max() - img.min())
+
+
+ def line(img):
+     img = img.unsqueeze(0)
+     if img.shape[1] == 1:
+         q5, q95 = torch.quantile(img.flatten(), q=torch.tensor((0.05, 0.95), device=img.device))
+
+         img[img < q5] = q5
+         img[img > q95] = q95
+
+         return normalize(img)
+     elif img.shape[1] == 3:
+         for c in range(3):
+             img[:, c:c+1] = line(img[:, c:c+1])
+         return img.squeeze()
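
A quick numeric check of disp_to_depth over the 0.1-20 m range used in test_simple.py (the values follow directly from the formula above):

    import torch
    from utils import disp_to_depth

    disp = torch.tensor([0.0, 0.5, 1.0])   # raw sigmoid outputs
    scaled_disp, depth = disp_to_depth(disp, min_depth=0.1, max_depth=20)
    # scaled_disp = 0.05 + 9.95 * disp  ->  [0.05, 5.025, 10.0]
    # depth = 1 / scaled_disp           ->  [20.0, 0.199..., 0.1]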