Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class DenseFeatureExtractionModule(nn.Module): | |
def __init__(self, use_relu=True, use_cuda=True): | |
super(DenseFeatureExtractionModule, self).__init__() | |
self.model = nn.Sequential( | |
nn.Conv2d(3, 64, 3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(64, 64, 3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.MaxPool2d(2, stride=2), | |
nn.Conv2d(64, 128, 3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(128, 128, 3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.MaxPool2d(2, stride=2), | |
nn.Conv2d(128, 256, 3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(256, 256, 3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(256, 256, 3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.AvgPool2d(2, stride=1), | |
nn.Conv2d(256, 512, 3, padding=2, dilation=2), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(512, 512, 3, padding=2, dilation=2), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(512, 512, 3, padding=2, dilation=2), | |
) | |
self.num_channels = 512 | |
self.use_relu = use_relu | |
if use_cuda: | |
self.model = self.model.cuda() | |
def forward(self, batch): | |
output = self.model(batch) | |
if self.use_relu: | |
output = F.relu(output) | |
return output | |
class D2Net(nn.Module):
    """Combine dense feature extraction, detection and localization.

    Args:
        model_file: optional path to a checkpoint; when given, the entry
            stored under the ``'model'`` key is loaded as this module's
            state dict.
        use_relu: forwarded to ``DenseFeatureExtractionModule``.
        use_cuda: forwarded to ``DenseFeatureExtractionModule``; when false
            the checkpoint is loaded with ``map_location='cpu'``.
    """

    def __init__(self, model_file=None, use_relu=True, use_cuda=True):
        super(D2Net, self).__init__()

        self.dense_feature_extraction = DenseFeatureExtractionModule(
            use_relu=use_relu, use_cuda=use_cuda
        )
        self.detection = HardDetectionModule()
        self.localization = HandcraftedLocalizationModule()

        if model_file is None:
            return

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        load_kwargs = {} if use_cuda else {'map_location': 'cpu'}
        checkpoint = torch.load(model_file, **load_kwargs)
        self.load_state_dict(checkpoint['model'])

    def forward(self, batch):
        """Return dense features, detections and displacements for ``batch``.

        The result is a dict with the keys ``'dense_features'``,
        ``'detections'`` and ``'displacements'``.
        """
        # Unpacking into four names raises if the input does not have
        # exactly four dimensions; the values themselves are not used.
        _, _, _, _ = batch.size()

        features = self.dense_feature_extraction(batch)

        result = {'dense_features': features}
        result['detections'] = self.detection(features)
        result['displacements'] = self.localization(features)
        return result
class HardDetectionModule(nn.Module):
    """Select feature-map entries that pass three hard detection tests.

    An entry of the (b, c, h, w) input is detected when it is
      * the maximum over the channel axis at its spatial position,
      * the maximum of its 3x3 spatial neighbourhood within its channel, and
      * not edge-like: the determinant of the 2x2 matrix of second
        differences is positive and ``trace ** 2 / det`` does not exceed
        ``(edge_threshold + 1) ** 2 / edge_threshold``.

    Args:
        edge_threshold: value ``r`` used to build the edge-rejection bound
            ``(r + 1) ** 2 / r``.
    """

    def __init__(self, edge_threshold=5):
        super(HardDetectionModule, self).__init__()

        self.edge_threshold = edge_threshold

        # The 3x3 finite-difference kernels are stored as plain tensor
        # attributes (not parameters or buffers), so they do not appear in
        # the state dict; forward() moves them to the input's device.
        self.dii_filter = torch.tensor(
            [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
        ).view(1, 1, 3, 3)
        self.dij_filter = 0.25 * torch.tensor(
            [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
        ).view(1, 1, 3, 3)
        self.djj_filter = torch.tensor(
            [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
        ).view(1, 1, 3, 3)

    def forward(self, batch):
        """Return a mask with the same shape as ``batch`` marking detections.

        Args:
            batch: feature tensor of shape (b, c, h, w).

        Returns:
            Tensor of shape (b, c, h, w) that is true where all three
            detection tests pass.
        """
        b, c, h, w = batch.size()
        device = batch.device

        # keepdim=True keeps the maximum at shape (b, 1, h, w) so that it
        # broadcasts over the channel axis.  Without it the (b, h, w) result
        # is aligned against (c, h, w), which is only correct when b == 1.
        depth_wise_max = torch.max(batch, dim=1, keepdim=True)[0]
        is_depth_wise_max = (batch == depth_wise_max)
        del depth_wise_max

        local_max = F.max_pool2d(batch, 3, stride=1, padding=1)
        is_local_max = (batch == local_max)
        del local_max

        # Every channel is filtered independently by folding the channel
        # axis into the batch axis of a single-channel convolution.
        dii = F.conv2d(
            batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
        ).view(b, c, h, w)
        dij = F.conv2d(
            batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
        ).view(b, c, h, w)
        djj = F.conv2d(
            batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
        ).view(b, c, h, w)

        det = dii * djj - dij * dij
        tr = dii + djj
        del dii, dij, djj

        threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold
        # Where det == 0 the ratio is NaN or infinite; the det > 0 test
        # rejects those entries regardless of the ratio comparison.
        is_not_edge = (tr * tr / det <= threshold) & (det > 0)
        del det, tr

        detected = is_depth_wise_max & is_local_max & is_not_edge
        del is_depth_wise_max, is_local_max, is_not_edge

        return detected
class HandcraftedLocalizationModule(nn.Module):
    """Compute a per-entry displacement from fixed finite-difference filters.

    For every entry of the (b, c, h, w) input the module evaluates first and
    second differences with 3x3 kernels, inverts the 2x2 matrix of second
    differences in closed form and returns the negated product of that
    inverse with the first-difference vector.
    """

    def __init__(self):
        super(HandcraftedLocalizationModule, self).__init__()

        # First differences along the row (i) and column (j) directions.
        self.di_filter = torch.tensor(
            [[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]]
        ).view(1, 1, 3, 3)
        self.dj_filter = torch.tensor(
            [[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]]
        ).view(1, 1, 3, 3)

        # Second differences: pure i, mixed, pure j.
        self.dii_filter = torch.tensor(
            [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
        ).view(1, 1, 3, 3)
        self.dij_filter = 0.25 * torch.tensor(
            [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
        ).view(1, 1, 3, 3)
        self.djj_filter = torch.tensor(
            [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
        ).view(1, 1, 3, 3)

    def forward(self, batch):
        """Return the displacement field for ``batch``.

        Args:
            batch: feature tensor of shape (b, c, h, w).

        Returns:
            Tensor of shape (b, 2, c, h, w); index 0 along dim 1 is the step
            in the i direction, index 1 the step in the j direction.
        """
        b, c, h, w = batch.size()
        device = batch.device

        def respond(kernel):
            # Fold channels into the batch axis so each channel is filtered
            # on its own by the single-channel kernel.
            planes = batch.view(-1, 1, h, w)
            filtered = F.conv2d(planes, kernel.to(device), padding=1)
            return filtered.view(b, c, h, w)

        hess_ii = respond(self.dii_filter)
        hess_ij = respond(self.dij_filter)
        hess_jj = respond(self.djj_filter)

        # Closed-form inverse of [[ii, ij], [ij, jj]].
        det = hess_ii * hess_jj - hess_ij * hess_ij
        inv_00 = hess_jj / det
        inv_01 = -hess_ij / det
        inv_11 = hess_ii / det
        del hess_ii, hess_ij, hess_jj, det

        grad_i = respond(self.di_filter)
        grad_j = respond(self.dj_filter)

        step_i = -(inv_00 * grad_i + inv_01 * grad_j)
        step_j = -(inv_01 * grad_i + inv_11 * grad_j)
        del inv_00, inv_01, inv_11, grad_i, grad_j

        return torch.stack([step_i, step_j], dim=1)