import numpy as np
import torch
import torch.nn as nn

# Interpolation modes for which torch.nn.functional.interpolate rejects any
# non-None ``align_corners`` value (it raises ValueError).
_MODES_WITHOUT_ALIGN_CORNERS = ("nearest", "nearest-exact", "area")


class Interpolate(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Lets a fixed resampling step be placed inside an ``nn.Sequential``.

    Args:
        scale_factor: Spatial multiplier forwarded to ``interpolate``.
        mode: Interpolation mode forwarded to ``interpolate``
            (e.g. ``"bilinear"``, ``"nearest"``).
        align_corners: Forwarded to ``interpolate`` for the interpolating
            modes. For ``nearest``/``nearest-exact``/``area`` a falsy value
            is translated to ``None``, because PyTorch rejects any explicit
            value for those modes. An explicit ``True`` is still passed
            through, so that misconfiguration keeps raising.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resample ``x`` by ``scale_factor`` using ``mode``."""
        align_corners = self.align_corners
        if self.mode in _MODES_WITHOUT_ALIGN_CORNERS and not align_corners:
            # F.interpolate raises ValueError if align_corners is anything but
            # None for these modes, even False.
            align_corners = None
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
        )

    def extra_repr(self):
        return (
            f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
            f"align_corners={self.align_corners}"
        )


def _make_head(features, out_channels, *tail):
    """Build the shared conv -> 2x bilinear upsample -> conv -> ReLU -> 1x1 stack.

    Module order (and therefore state_dict indices 0, 2, 4 for the convs) is
    kept identical for every head so existing checkpoints still load.
    Any ``tail`` modules (e.g. a final activation) are appended at the end.
    """
    return nn.Sequential(
        nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
        Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
        nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, out_channels, kernel_size=1, stride=1, padding=0),
        *tail,
    )


class HeadDepth(nn.Module):
    """Single-channel prediction head with a sigmoid output.

    Takes a feature map with ``features`` channels, presumably shaped
    (N, features, H, W). It returns a (N, 1, 2H, 2W) map whose values lie in
    (0, 1) because of the final ``Sigmoid``. Presumably this is a
    normalised/relative depth map; confirm against the training loss.

    Args:
        features: Number of input channels.
    """

    def __init__(self, features):
        super().__init__()
        # Sigmoid bounds the output to (0, 1).
        self.head = _make_head(features, 1, nn.Sigmoid())

    def forward(self, x):
        return self.head(x)


class HeadSeg(nn.Module):
    """Per-pixel classification head producing raw (unnormalised) logits.

    Takes a feature map with ``features`` channels, presumably shaped
    (N, features, H, W). It returns a (N, nclasses, 2H, 2W) tensor. No
    softmax is applied here.

    Args:
        features: Number of input channels.
        nclasses: Number of output channels (classes). Defaults to 2.
    """

    def __init__(self, features, nclasses=2):
        super().__init__()
        self.head = _make_head(features, nclasses)

    def forward(self, x):
        return self.head(x)