mkalia committed
Commit a50312e · verified · 1 Parent(s): bf4a223

Upload 4 files

Files changed (4):
  1. depth_decoder.py +80 -0
  2. pose_cnn.py +52 -0
  3. pose_decoder.py +54 -0
  4. resnet_encoder.py +114 -0
depth_decoder.py ADDED
@@ -0,0 +1,80 @@
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict
from layers import *  # provides ConvBlock, Conv3x3, upsample


class DepthDecoder(nn.Module):
    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, batch_norm=True):
        super(DepthDecoder, self).__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales
        self.batch_norm = batch_norm

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        self.convs = OrderedDict()
        # BatchNorm layers live in an nn.ModuleDict so they are registered
        # with the module and follow .to(device)/.cuda() automatically
        self.bn = nn.ModuleDict()
        for i in range(4, -1, -1):
            self.convs[("deconv", i, 0)] = nn.ConvTranspose2d(
                self.num_ch_dec[i], self.num_ch_dec[i], 3, stride=2, padding=1, output_padding=1)
            if self.batch_norm:
                self.bn[str(i)] = nn.BatchNorm2d(self.num_ch_dec[i])

        # decoder
        for i in range(4, -1, -1):
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        self.outputs = {}

        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]
            # x = [self.convs[("deconv", i, 0)](x)]  # learned-upsampling alternative
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
            if self.batch_norm:
                x = self.bn[str(i)](x)

            # sigmoid-normalised disparity at the requested output scales
            if i in self.scales:
                self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))

        return self.outputs
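For context, a minimal usage sketch, assuming this file sits alongside resnet_encoder.py below and a monodepth2-style layers.py providing ConvBlock, Conv3x3 and upsample (layers.py is not part of this commit):

import torch
from resnet_encoder import ResnetEncoder
from depth_decoder import DepthDecoder

encoder = ResnetEncoder(num_layers=18, pretrained=False)
decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc)

image = torch.randn(1, 3, 192, 640)   # one RGB frame at a typical KITTI resolution
features = encoder(image)             # five feature maps, coarse to fine
outputs = decoder(features)           # dict keyed by ("disp", scale)
print(outputs[("disp", 0)].shape)     # torch.Size([1, 1, 192, 640])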
pose_cnn.py ADDED
@@ -0,0 +1,52 @@
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn


class PoseCNN(nn.Module):
    def __init__(self, num_input_frames):
        super(PoseCNN, self).__init__()

        self.num_input_frames = num_input_frames

        self.convs = {}
        self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3)
        self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2)
        self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1)
        self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1)
        self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1)
        self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1)
        self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1)

        self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1)

        self.num_convs = len(self.convs)

        self.relu = nn.ReLU(True)

        # the ModuleList registers the convs' parameters with the module
        self.net = nn.ModuleList(list(self.convs.values()))

    def forward(self, out):
        for i in range(self.num_convs):
            out = self.convs[i](out)
            out = self.relu(out)

        out = self.pose_conv(out)
        out = out.mean(3).mean(2)  # global average pool over the spatial dimensions

        # out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6)  # original monodepth2 scaling
        out = out.view(-1, self.num_input_frames - 1, 1, 6)

        axisangle = out[..., :3]
        translation = out[..., 3:]

        return axisangle, translation
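A quick shape check, as a sketch: the network takes the colour frames stacked along the channel axis and predicts one 6-DoF motion (3 axis-angle plus 3 translation components) per frame pair. Note that the 0.01 output scaling from the original monodepth2 is disabled in this version.

import torch
from pose_cnn import PoseCNN

pose_net = PoseCNN(num_input_frames=2)
frames = torch.randn(1, 3 * 2, 192, 640)   # two RGB frames, channel-stacked
axisangle, translation = pose_net(frames)
print(axisangle.shape, translation.shape)  # both torch.Size([1, 1, 1, 3])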
pose_decoder.py ADDED
@@ -0,0 +1,54 @@
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn
from collections import OrderedDict


class PoseDecoder(nn.Module):
    def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
        super(PoseDecoder, self).__init__()

        self.num_ch_enc = num_ch_enc
        self.num_input_features = num_input_features

        if num_frames_to_predict_for is None:
            num_frames_to_predict_for = num_input_features - 1
        self.num_frames_to_predict_for = num_frames_to_predict_for

        self.convs = OrderedDict()
        self.convs["squeeze"] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
        self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
        self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
        self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)

        self.relu = nn.ReLU()

        self.net = nn.ModuleList(list(self.convs.values()))

    def forward(self, input_features):
        last_features = [f[-1] for f in input_features]

        cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
        cat_features = torch.cat(cat_features, 1)

        out = cat_features
        for i in range(3):
            out = self.convs[("pose", i)](out)
            if i != 2:
                out = self.relu(out)

        out = out.mean(3).mean(2)

        out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)

        axisangle = out[..., :3]
        translation = out[..., 3:]

        return axisangle, translation
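A sketch of the expected calling convention: the decoder takes a list of encoder feature pyramids and reads the coarsest map from each. The two-frame shared-encoder setup below follows the usual monodepth2 configuration and is an assumption, not part of this commit.

import torch
from resnet_encoder import ResnetEncoder
from pose_decoder import PoseDecoder

encoder = ResnetEncoder(18, pretrained=False, num_input_images=2)
decoder = PoseDecoder(encoder.num_ch_enc, num_input_features=1,
                      num_frames_to_predict_for=2)

pair = torch.randn(1, 6, 192, 640)     # two frames stacked channel-wise
features = [encoder(pair)]             # list containing one feature pyramid
axisangle, translation = decoder(features)
print(axisangle.shape)                 # torch.Size([1, 2, 1, 3])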
resnet_encoder.py ADDED
@@ -0,0 +1,114 @@
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

import numpy as np

import torch
import torch.nn as nn
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
from torchvision.models.resnet import ResNet18_Weights, ResNet50_Weights


class ResNetMultiImageInput(models.ResNet):
    """Constructs a resnet model with a varying number of input images.
    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    """
    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
        super(ResNetMultiImageInput, self).__init__(block, layers)
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
    """Constructs a ResNet model.
    Args:
        num_layers (int): Number of resnet layers. Must be 18 or 50
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        num_input_images (int): Number of frames stacked as input
    """
    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
    blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
    block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
    model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)

    if pretrained:
        # pick the ImageNet weights matching num_layers rather than hard-coding resnet18
        weights = {18: ResNet18_Weights.IMAGENET1K_V1,
                   50: ResNet50_Weights.IMAGENET1K_V1}[num_layers]
        loaded = model_zoo.load_url(weights.url)
        # tile the conv1 weights across the stacked input frames, rescaling so
        # activations keep the same magnitude as in the single-image network
        loaded['conv1.weight'] = torch.cat(
            [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
        model.load_state_dict(loaded)
    return model


class ResnetEncoder(nn.Module):
    """Pytorch module for a resnet encoder
    """
    def __init__(self, num_layers, pretrained, num_input_images=1, batch_norm_apply=False):
        super(ResnetEncoder, self).__init__()

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        if num_input_images > 1:
            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
        else:
            self.encoder = resnets[num_layers](pretrained)

        if num_layers > 34:
            self.num_ch_enc[1:] *= 4  # Bottleneck blocks have 4x the output channels

        self.drop = True
        self.dropout = torch.nn.Dropout(p=0.2)

    def forward(self, input_image):
        self.features = []
        # x = (input_image - 0.45) / 0.225  # monodepth2's input normalisation, disabled here
        x = input_image
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        self.features.append(self.encoder.relu(x))
        self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
        # apply dropout between the later residual stages when self.drop is set
        for layer in (self.encoder.layer2, self.encoder.layer3, self.encoder.layer4):
            x = self.dropout(self.features[-1]) if self.drop else self.features[-1]
            self.features.append(layer(x))

        return self.features
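For illustration, a small sketch of both entry points; the resolution and layer count are arbitrary choices, and the pretrained=True call downloads the ImageNet checkpoint:

import torch
from resnet_encoder import ResnetEncoder, resnet_multiimage_input

# multi-image variant: ImageNet conv1 weights get tiled over 2 frames (6 channels)
enc = resnet_multiimage_input(num_layers=18, pretrained=True, num_input_images=2)
print(enc.conv1.weight.shape)       # torch.Size([64, 6, 7, 7])

# standard single-image wrapper, returning the five-scale feature pyramid
encoder = ResnetEncoder(18, pretrained=False)
feats = encoder(torch.randn(1, 3, 192, 640))
print([f.shape[1] for f in feats])  # [64, 64, 128, 256, 512]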