|
import torch |
|
import torch.nn as nn |
|
import torch.nn.init as init |
|
|
|
from .nets.backbone import HourglassBackbone, SuperpointBackbone |
|
from .nets.junction_decoder import SuperpointDecoder |
|
from .nets.heatmap_decoder import PixelShuffleDecoder |
|
from .nets.descriptor_decoder import SuperpointDescriptor |
|
|
|
|
|
def get_model(model_cfg=None, loss_weights=None, mode="train"):
    """Build the network selected by ``model_cfg`` and print a short summary.

    Args:
        model_cfg: Mapping read with the keys "model_architecture",
            "backbone", "junction_decoder" and "heatmap_decoder"; it is also
            handed unchanged to the network constructor.
        loss_weights: Mapping from parameter name to value. Entries that are
            ``nn.Parameter`` instances are registered on the model; other
            entries are skipped. Only consulted when ``mode`` is "train".
        mode: When equal to "train", ``loss_weights`` must be provided.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_cfg`` is None, the architecture is not
            supported, or ``loss_weights`` is None in "train" mode.
    """
    if model_cfg is None:
        raise ValueError("[Error] The model config is required!")

    print("\n\n\t--------Initializing model----------")

    # Validate the requested architecture before building anything.
    architecture = model_cfg["model_architecture"]
    known_architectures = ["simple"]
    if architecture not in known_architectures:
        raise ValueError("[Error] The model architecture is not in supported arch!")
    if architecture != "simple":
        raise ValueError("[Error] The model architecture is not in supported arch!")
    model = SOLD2Net(model_cfg)

    # In training mode the loss weights are attached to the model as
    # parameters.
    if mode == "train":
        if loss_weights is None:
            raise ValueError(
                "[Error] the loss weights can not be None in dynamic weighting mode during training."
            )
        for weight_name, weight in loss_weights.items():
            if not isinstance(weight, nn.Parameter):
                continue
            print(
                "\t [Debug] Adding %s with value %f to model"
                % (weight_name, weight.item())
            )
            model.register_parameter(weight_name, weight)

    # Summary of the selected components.
    print("\tModel architecture: %s" % model_cfg["model_architecture"])
    print("\tBackbone: %s" % model_cfg["backbone"])
    print("\tJunction decoder: %s" % model_cfg["junction_decoder"])
    print("\tHeatmap decoder: %s" % model_cfg["heatmap_decoder"])
    print("\t-------------------------------------")

    return model
|
|
|
|
|
class SOLD2Net(nn.Module):
    """Full network for SOLD².

    Runs a shared backbone and feeds its features to a junction decoder, a
    heatmap decoder and, when the configuration contains a
    "descriptor_decoder" entry, a descriptor decoder.
    """

    def __init__(self, model_cfg):
        """Build the sub-networks selected by ``model_cfg``.

        Args:
            model_cfg: Mapping read with the keys "model_name", "backbone",
                "junction_decoder" and "heatmap_decoder". "backbone_cfg" is
                read only for the "lcnn" backbone, and "descriptor_decoder"
                is optional.

        Raises:
            ValueError: If a selected component is not supported.
        """
        super().__init__()
        self.name = model_cfg["model_name"]
        self.cfg = model_cfg

        # Each supported-options list is set right before the getter that
        # validates against it.
        self.supported_backbone = ["lcnn", "superpoint"]
        self.backbone_net, self.feat_channel = self.get_backbone()

        self.supported_junction_decoder = ["superpoint_decoder"]
        self.junction_decoder = self.get_junction_decoder()

        self.supported_heatmap_decoder = ["pixel_shuffle", "pixel_shuffle_single"]
        self.heatmap_decoder = self.get_heatmap_decoder()

        # The descriptor branch only exists when the config asks for it.
        if "descriptor_decoder" in self.cfg:
            self.supported_descriptor_decoder = ["superpoint_descriptor"]
            self.descriptor_decoder = self.get_descriptor_decoder()

        # Apply the module-level initializer to every sub-module.
        self.apply(weight_init)

    def forward(self, input_images):
        """Run the backbone and every configured decoder.

        Args:
            input_images: Input passed directly to the backbone.

        Returns:
            Dict with the keys "junctions" and "heatmap", plus "descriptors"
            when the config contains "descriptor_decoder".
        """
        features = self.backbone_net(input_images)

        outputs = {
            "junctions": self.junction_decoder(features),
            "heatmap": self.heatmap_decoder(features),
        }

        if "descriptor_decoder" in self.cfg:
            outputs["descriptors"] = self.descriptor_decoder(features)

        return outputs

    def get_backbone(self):
        """Retrieve the backbone encoder network.

        Returns:
            Tuple ``(backbone, feat_channel)``: the backbone module and the
            channel count handed to the decoders (256 for "lcnn", 128 for
            "superpoint").

        Raises:
            ValueError: If the configured backbone is not supported.
        """
        backbone_name = self.cfg["backbone"]
        if backbone_name not in self.supported_backbone:
            raise ValueError("[Error] The backbone selection is not supported.")

        if backbone_name == "lcnn":
            return HourglassBackbone(**self.cfg["backbone_cfg"]), 256

        if backbone_name == "superpoint":
            # SuperpointBackbone is built without arguments, so the config is
            # not required to carry a "backbone_cfg" entry for it.
            return SuperpointBackbone(), 128

        raise ValueError("[Error] The backbone selection is not supported.")

    def get_junction_decoder(self):
        """Get the junction decoder.

        Raises:
            ValueError: If the configured junction decoder is not supported.
        """
        decoder_name = self.cfg["junction_decoder"]
        if decoder_name not in self.supported_junction_decoder:
            raise ValueError("[Error] The junction decoder selection is not supported.")

        if decoder_name == "superpoint_decoder":
            return SuperpointDecoder(self.feat_channel, self.cfg["backbone"])

        raise ValueError("[Error] The junction decoder selection is not supported.")

    def get_heatmap_decoder(self):
        """Get the heatmap decoder.

        "pixel_shuffle" uses the decoder's default output channels, while
        "pixel_shuffle_single" requests ``output_channel=1``. The number of
        upsampling steps depends on the backbone (2 for "lcnn", 3 for
        "superpoint").

        Raises:
            ValueError: If the configured heatmap decoder or backbone is not
                supported.
        """
        decoder_name = self.cfg["heatmap_decoder"]
        if decoder_name not in self.supported_heatmap_decoder:
            raise ValueError("[Error] The heatmap decoder selection is not supported.")

        if decoder_name == "pixel_shuffle":
            extra_kwargs = {}
        elif decoder_name == "pixel_shuffle_single":
            extra_kwargs = {"output_channel": 1}
        else:
            raise ValueError("[Error] The heatmap decoder selection is not supported.")

        # The upsampling depth is shared by both decoder variants.
        backbone_name = self.cfg["backbone"]
        if backbone_name == "lcnn":
            num_upsample = 2
        elif backbone_name == "superpoint":
            num_upsample = 3
        else:
            raise ValueError("[Error] Unknown backbone option.")

        return PixelShuffleDecoder(
            self.feat_channel, num_upsample=num_upsample, **extra_kwargs
        )

    def get_descriptor_decoder(self):
        """Get the descriptor decoder.

        Raises:
            ValueError: If the configured descriptor decoder is not supported.
        """
        decoder_name = self.cfg["descriptor_decoder"]
        if decoder_name not in self.supported_descriptor_decoder:
            raise ValueError(
                "[Error] The descriptor decoder selection is not supported."
            )

        if decoder_name == "superpoint_descriptor":
            return SuperpointDescriptor(self.feat_channel)

        raise ValueError(
            "[Error] The descriptor decoder selection is not supported."
        )
|
|
|
|
|
def weight_init(m):
    """Weight initialization function.

    Initializes one module in place; modules of other types are left
    untouched.

    * ``nn.Conv2d`` / ``nn.Linear``: Xavier-normal weights, and a bias drawn
      from ``init.normal_`` with its default arguments.
    * ``nn.BatchNorm2d``: weights drawn from a normal with mean 1 and
      std 0.02, bias set to 0.

    Parameters that are ``None`` (a layer built with ``bias=False`` or a
    batch norm built with ``affine=False``) are skipped instead of raising.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.xavier_normal_(m.weight)
        # The bias is None when the layer was created with bias=False.
        if m.bias is not None:
            init.normal_(m.bias)

    elif isinstance(m, nn.BatchNorm2d):
        # Weight and bias are both None when affine=False.
        if m.weight is not None:
            init.normal_(m.weight, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0)
|
|