younesbelkada
add first files
7708d0d
import numpy as np
import torch
import torch.nn as nn
class Interpolate(nn.Module):
    """Wrap ``nn.functional.interpolate`` as a module so it can sit in ``nn.Sequential``.

    Args:
        scale_factor: Spatial multiplier forwarded to ``interpolate``.
        mode: Interpolation algorithm name forwarded to ``interpolate``
            (e.g. ``"bilinear"``, ``"nearest"``).
        align_corners: Forwarded to ``interpolate`` only for the modes that
            accept it (``linear``, ``bilinear``, ``bicubic``, ``trilinear``);
            for any other mode it is ignored and ``None`` is passed instead.
    """

    # Modes for which torch accepts a non-None ``align_corners``. For other modes
    # (e.g. "nearest", "area") ``interpolate`` raises ValueError if it is set, which
    # previously made the default ``align_corners=False`` unusable with them.
    _ALIGN_CORNERS_MODES = frozenset({"linear", "bilinear", "bicubic", "trilinear"})

    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` by ``self.scale_factor`` using ``self.mode`` and return it."""
        align_corners = (
            self.align_corners if self.mode in self._ALIGN_CORNERS_MODES else None
        )
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
        )
class HeadDepth(nn.Module):
    """Single-channel prediction head with a 2x spatial upsample.

    Pipeline: 3x3 conv (features -> features // 2), bilinear 2x upsample
    (align_corners=True), 3x3 conv to 32 channels, ReLU, 1x1 conv to one
    channel, then a sigmoid so every output value lies in (0, 1).

    Args:
        features: Number of input channels.
    """

    def __init__(self, features):
        super().__init__()
        half = features // 2
        layers = [
            nn.Conv2d(features, half, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(half, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        # A single Sequential named ``head`` keeps the state_dict keys
        # (head.0.*, head.2.*, head.4.*) identical to the original layout.
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the head and return the sigmoid-activated map."""
        return self.head(x)
class HeadSeg(nn.Module):
    """Segmentation head producing ``nclasses`` un-activated score channels at 2x size.

    Pipeline: 3x3 conv (features -> features // 2), bilinear 2x upsample
    (align_corners=True), 3x3 conv to 32 channels, ReLU, then a 1x1 conv to
    ``nclasses`` channels. No final activation is applied.

    Args:
        features: Number of input channels.
        nclasses: Number of output channels (one per class). Defaults to 2.
    """

    def __init__(self, features, nclasses=2):
        super().__init__()
        mid_channels = features // 2
        reduce = nn.Conv2d(features, mid_channels, kernel_size=3, stride=1, padding=1)
        upsample = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
        refine = nn.Conv2d(mid_channels, 32, kernel_size=3, stride=1, padding=1)
        classify = nn.Conv2d(32, nclasses, kernel_size=1, stride=1, padding=0)
        # Module order inside ``head`` is what fixes the state_dict keys
        # (head.0.*, head.2.*, head.4.*), so it must stay as below.
        self.head = nn.Sequential(reduce, upsample, refine, nn.ReLU(), classify)

    def forward(self, x):
        """Run ``x`` through the head and return the per-class score map."""
        return self.head(x)