# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict
from layers import *


class DepthDecoder(nn.Module):
    """Five-stage decoder producing sigmoid-activated maps at chosen stages.

    Each stage ``i`` (run from 4 down to 0) applies ``("upconv", i, 0)``,
    upsamples, optionally concatenates the encoder feature ``i - 1`` along
    the channel axis, applies ``("upconv", i, 1)``, optionally applies a
    batch-norm layer, and - when ``i`` is in ``scales`` - emits
    ``sigmoid(dispconv(x))`` under the key ``("disp", i)``.

    Args:
        num_ch_enc: Indexable sequence of encoder channel counts.
            ``num_ch_enc[-1]`` is the input width of stage 4 and
            ``num_ch_enc[i - 1]`` is the width of the skip connection
            used at stage ``i``.
        scales: Stage indices at which an output is produced. Each value
            indexes ``num_ch_dec``, so it must lie in ``0..4``.
        num_output_channels: Channel count of every produced output.
        use_skips: Whether to concatenate encoder features at stages 1-4.
        batch_norm: Whether to apply a ``batchNorm`` layer (from ``layers``)
            at the end of every stage.
    """

    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1,
                 use_skips=True, batch_norm=True):
        super(DepthDecoder, self).__init__()

        # Kept for callers that read it; forward() no longer relies on it.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales
        self.batch_norm = batch_norm

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # Insertion order matters: self.decoder below is built from
        # self.convs.values(), so it fixes each layer's position (and thus
        # its state_dict key) inside the ModuleList.
        self.convs = OrderedDict()

        # NOTE(review): self.bn is a plain dict, so (assuming batchNorm
        # returns nn.Module instances - confirm in layers) these layers are
        # not registered submodules: they do not appear in parameters() or
        # state_dict(), and .to()/.train()/.eval() on the decoder do not
        # reach them. Registering them would add state_dict keys and break
        # strict loading of existing checkpoints, so it is left unchanged.
        self.bn = {}

        for i in range(4, -1, -1):
            # The transposed convolutions are not called by forward(), but
            # they stay in self.convs so the positions inside self.decoder
            # are unchanged.
            self.convs[("deconv", i, 0)] = nn.ConvTranspose2d(
                self.num_ch_dec[i], self.num_ch_dec[i], 3,
                stride=2, padding=1, output_padding=1)
            if self.batch_norm:
                self.bn[('bn', i)] = batchNorm(self.num_ch_dec[i])

        # decoder
        for i in range(4, -1, -1):
            # upconv_0: takes the encoder output at the deepest stage,
            # otherwise the previous (deeper) decoder stage.
            num_ch_in = (self.num_ch_enc[-1] if i == 4
                         else self.num_ch_dec[i + 1])
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1: input is widened by the skip connection, if any.
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(
                self.num_ch_dec[s], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        """Decode encoder features into per-scale outputs.

        Args:
            input_features: Indexable sequence of encoder feature tensors.
                ``input_features[-1]`` feeds the deepest stage and
                ``input_features[i - 1]`` is concatenated at stage ``i``
                when skips are enabled.

        Returns:
            dict mapping ``("disp", i)`` to the sigmoid-activated output of
            stage ``i`` for every ``i`` in ``self.scales``. The same dict is
            also stored on ``self.outputs``.
        """
        self.outputs = {}

        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)

            # Upsampling uses the `upsample` helper from layers rather than
            # the ("deconv", i, 0) transposed convolution.
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)

            if self.batch_norm:
                # The batch-norm layers are not registered submodules, so
                # they must be moved by hand. Follow the activations' device
                # instead of the construction-time guess in self.device,
                # which mismatches when the decoder runs on CPU on a CUDA
                # machine or on a non-default GPU.
                x = self.bn[('bn', i)].to(x.device)(x)

            if i in self.scales:
                self.outputs[("disp", i)] = self.sigmoid(
                    self.convs[("dispconv", i)](x))

        return self.outputs