"""D2-Net model: dense feature extraction plus hand-crafted detection/localization.

The dense features are produced by a stack of 3x3 convolutions whose channel
widths (64-64-128-128-256-256-256-512-512-512) follow the first ten conv layers
of VGG16. The second pooling stage is followed by a stride-1 average pool and
dilated convolutions instead of a further down-sampling max-pool. The feature
map is then scored by a non-learned detector and a Hessian-based sub-pixel
localizer.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def _filter_each_channel(batch, kernel):
    """Correlate every channel of ``batch`` independently with a 3x3 ``kernel``.

    Args:
        batch: tensor of shape (b, c, h, w).
        kernel: tensor of shape (1, 1, 3, 3).

    Returns:
        Tensor of shape (b, c, h, w). One pixel of zero padding keeps the
        spatial size unchanged.
    """
    b, c, h, w = batch.shape
    # Fold channels into the batch dimension so a single-channel kernel is
    # applied depth-wise. ``reshape`` (not ``view``) tolerates non-contiguous
    # inputs. ``.to(batch)`` matches both the device and the dtype of the input.
    filtered = F.conv2d(
        batch.reshape(b * c, 1, h, w), kernel.to(batch), padding=1
    )
    return filtered.reshape(b, c, h, w)


class DenseFeatureExtractionModule(nn.Module):
    """Convolutional backbone producing a 512-channel dense feature map.

    Args:
        use_relu: apply a final ReLU to the output of the last convolution
            (which has no activation inside the ``nn.Sequential``).
        use_cuda: move the convolution stack to CUDA with ``.cuda()`` at
            construction time.
    """

    def __init__(self, use_relu=True, use_cuda=True):
        super(DenseFeatureExtractionModule, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            # Stride-1 average pool: smooths without further down-sampling.
            nn.AvgPool2d(2, stride=1),
            # Dilated convolutions widen the receptive field at fixed resolution.
            nn.Conv2d(256, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
        )

        self.num_channels = 512

        self.use_relu = use_relu

        if use_cuda:
            self.model = self.model.cuda()

    def forward(self, batch):
        """Map images (b, 3, h, w) to features (b, 512, h', w').

        h' = floor(floor(h / 2) / 2) - 1, and likewise for w'.
        """
        output = self.model(batch)
        if self.use_relu:
            output = F.relu(output)
        return output


class D2Net(nn.Module):
    """Full model: dense features, hard detections and sub-pixel displacements.

    Args:
        model_file: optional checkpoint path. The file must contain a dict with
            a ``'model'`` entry holding this module's ``state_dict``.
        use_relu: forwarded to :class:`DenseFeatureExtractionModule`.
        use_cuda: forwarded to the backbone. When False, the checkpoint is
            loaded with ``map_location='cpu'``.
    """

    def __init__(self, model_file=None, use_relu=True, use_cuda=False):
        super(D2Net, self).__init__()

        self.dense_feature_extraction = DenseFeatureExtractionModule(
            use_relu=use_relu, use_cuda=use_cuda
        )

        self.detection = HardDetectionModule()

        self.localization = HandcraftedLocalizationModule()

        if model_file is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # map_location=None is torch.load's default (keep saved devices).
            map_location = None if use_cuda else 'cpu'
            checkpoint = torch.load(model_file, map_location=map_location)
            self.load_state_dict(checkpoint['model'])

    def forward(self, batch):
        """Run the full pipeline on an image batch of shape (b, 3, h, w).

        Returns:
            dict with:
              * ``'dense_features'``: (b, 512, h', w') feature map.
              * ``'detections'``: boolean mask of the same shape.
              * ``'displacements'``: (b, 2, 512, h', w') sub-pixel offsets
                (row, column).
        """
        dense_features = self.dense_feature_extraction(batch)

        detections = self.detection(dense_features)

        displacements = self.localization(dense_features)

        return {
            'dense_features': dense_features,
            'detections': detections,
            'displacements': displacements
        }


class HardDetectionModule(nn.Module):
    """Non-learned keypoint detector over a dense feature map.

    A location (channel k, pixel p) is detected when all three hold:
      * it is the maximum across channels at p;
      * it is a 3x3 spatial local maximum in channel k;
      * it passes the edge test: the Hessian determinant is positive and
        tr(H)^2 / det(H) <= (r + 1)^2 / r, with r = ``edge_threshold``.
        This is the same form as Lowe's SIFT curvature-ratio test.

    Args:
        edge_threshold: the curvature ratio r used in the edge test.
    """

    def __init__(self, edge_threshold=5):
        super(HardDetectionModule, self).__init__()

        self.edge_threshold = edge_threshold

        # Finite-difference kernels for the second derivatives. They are kept
        # as plain attributes (not buffers) so state_dict keys are unchanged.
        self.dii_filter = torch.tensor(
            [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
        ).view(1, 1, 3, 3)
        self.dij_filter = 0.25 * torch.tensor(
            [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
        ).view(1, 1, 3, 3)
        self.djj_filter = torch.tensor(
            [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
        ).view(1, 1, 3, 3)

    def forward(self, batch):
        """Return a boolean detection mask with the shape of ``batch`` (b, c, h, w)."""
        # keepdim=True keeps the channel max as (b, 1, h, w), so it broadcasts
        # over channels for every batch size. A (b, h, w) max would line up b
        # with c and only be correct when b == 1.
        is_depth_wise_max = batch == batch.max(dim=1, keepdim=True)[0]

        is_local_max = batch == F.max_pool2d(batch, 3, stride=1, padding=1)

        dii = _filter_each_channel(batch, self.dii_filter)
        dij = _filter_each_channel(batch, self.dij_filter)
        djj = _filter_each_channel(batch, self.djj_filter)

        det = dii * djj - dij * dij
        tr = dii + djj
        del dii, dij, djj

        threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold
        # Where det == 0 the ratio is inf/nan; the comparison is then False,
        # and `det > 0` rejects those pixels anyway.
        is_not_edge = (tr * tr / det <= threshold) & (det > 0)

        return is_depth_wise_max & is_local_max & is_not_edge


class HandcraftedLocalizationModule(nn.Module):
    """Sub-pixel refinement by a single Newton step on each feature channel.

    For every pixel and channel, the step is -H^{-1} g. Here g is the
    central-difference gradient and H is the finite-difference Hessian.
    """

    def __init__(self):
        super(HandcraftedLocalizationModule, self).__init__()

        self.di_filter = torch.tensor(
            [[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]]
        ).view(1, 1, 3, 3)
        self.dj_filter = torch.tensor(
            [[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]]
        ).view(1, 1, 3, 3)

        self.dii_filter = torch.tensor(
            [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
        ).view(1, 1, 3, 3)
        self.dij_filter = 0.25 * torch.tensor(
            [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
        ).view(1, 1, 3, 3)
        self.djj_filter = torch.tensor(
            [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
        ).view(1, 1, 3, 3)

    def forward(self, batch):
        """Return displacements of shape (b, 2, c, h, w) for ``batch`` (b, c, h, w).

        Index 0 along dim 1 is the row (i) step and index 1 is the column (j)
        step. Where the Hessian is singular (det == 0) the result is inf/nan.
        """
        dii = _filter_each_channel(batch, self.dii_filter)
        dij = _filter_each_channel(batch, self.dij_filter)
        djj = _filter_each_channel(batch, self.djj_filter)

        # Closed-form inverse of the symmetric 2x2 Hessian.
        det = dii * djj - dij * dij
        inv_hess_00 = djj / det
        inv_hess_01 = -dij / det
        inv_hess_11 = dii / det
        del dii, dij, djj, det

        di = _filter_each_channel(batch, self.di_filter)
        dj = _filter_each_channel(batch, self.dj_filter)

        step_i = -(inv_hess_00 * di + inv_hess_01 * dj)
        step_j = -(inv_hess_01 * di + inv_hess_11 * dj)
        del inv_hess_00, inv_hess_01, inv_hess_11, di, dj

        return torch.stack([step_i, step_j], dim=1)