Spaces:
Sleeping
Sleeping
File size: 2,917 Bytes
a50312e d7357cf a50312e d7357cf a50312e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
from __future__ import absolute_import, division, print_function
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from layers import *
class DepthDecoder(nn.Module):
    """Monodepth2-style decoder that turns encoder features into disparity maps.

    Five decoder levels (i = 4 .. 0) each run ``upconv_0`` -> ``upsample`` ->
    optional skip concatenation -> ``upconv_1`` -> optional batch norm, and emit a
    sigmoid disparity map at every level listed in ``scales``.
    """

    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, batch_norm = True):
        """Build the decoder layers.

        Args:
            num_ch_enc: Encoder channel counts. ``num_ch_enc[-1]`` must match the
                deepest feature given to ``forward``; ``num_ch_enc[i - 1]`` must match
                the skip feature concatenated at level ``i``.
            scales: Decoder levels (within 0..4) at which a disparity map is produced.
            num_output_channels: Channel count of each disparity map.
            use_skips: Concatenate encoder features at levels 1..4.
            batch_norm: Apply a ``batchNorm`` layer after each ``upconv_1`` block.
        """
        super(DepthDecoder, self).__init__()
        # Kept for callers that read ``.device``; forward() no longer uses it, because
        # the BN layers are now registered submodules that follow the model's device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales
        self.batch_norm = batch_norm

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        self.convs = OrderedDict()
        self.bn = {}
        for i in range(4, -1, -1):
            # NOTE: forward() does not use these transposed convolutions. They are
            # still created first so that the indices of ``self.decoder``, and hence
            # existing checkpoint keys, stay unchanged.
            self.convs[("deconv", i, 0)] = nn.ConvTranspose2d(self.num_ch_dec[i], self.num_ch_dec[i], 3, stride=2, padding = 1, output_padding = 1)
            if self.batch_norm:
                self.bn[('bn', i)] = batchNorm(self.num_ch_dec[i])

        # decoder
        for i in range(4, -1, -1):
            # upconv_0: consumes the encoder bottleneck (i == 4) or the previous level.
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1: input widens by the skip feature's channels when skips are on.
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        # A plain dict hides modules from nn.Module. Registering the same layer
        # objects here means that:
        #   - their parameters are included in parameters() and the optimizer;
        #   - their running statistics are saved in state_dict;
        #   - .to() moves them with the model;
        #   - .train()/.eval() switch their mode.
        # Checkpoints saved before this change lack the ``bn_layers.*`` keys, so load
        # those with ``load_state_dict(..., strict=False)``.
        self.bn_layers = nn.ModuleDict(
            {str(key[1]): layer for key, layer in self.bn.items()}
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        """Decode encoder features into disparity maps.

        Args:
            input_features: Sequence of encoder feature tensors, with the deepest one
                last. Assumed NCHW, since skips are concatenated on dim 1; TODO confirm.

        Returns:
            Dict mapping ``("disp", i)`` for each ``i`` in ``self.scales`` to a
            sigmoid-activated tensor (values in (0, 1)). The same dict is also stored
            in ``self.outputs``.
        """
        self.outputs = {}

        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
            if self.batch_norm:
                # Registered submodule: already on the model's device and in the
                # model's train/eval mode, so no per-call .to() is needed.
                x = self.bn_layers[str(i)](x)
            if i in self.scales:
                self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))

        return self.outputs
|