import torch
import torch.nn as nn
import torch.nn.init as init

from .nets.backbone import HourglassBackbone, SuperpointBackbone
from .nets.junction_decoder import SuperpointDecoder
from .nets.heatmap_decoder import PixelShuffleDecoder
from .nets.descriptor_decoder import SuperpointDescriptor


def get_model(model_cfg=None, loss_weights=None, mode="train"):
    """Build the model described by ``model_cfg``.

    Args:
        model_cfg: Model configuration dict. This function reads
            "model_architecture", "backbone", "junction_decoder" and
            "heatmap_decoder"; SOLD2Net reads further keys.
        loss_weights: Mapping of name -> loss weight. Entries that are
            ``nn.Parameter`` instances are registered on the returned model
            (so they are part of ``model.parameters()``); other entries are
            skipped here. Must not be None when ``mode == "train"``.
        mode: When "train", loss weights are registered on the model; any
            other value skips that step.

    Returns:
        The constructed ``SOLD2Net``.

    Raises:
        ValueError: If ``model_cfg`` is None, the architecture is not
            supported, or ``mode == "train"`` and ``loss_weights`` is None.
    """
    if model_cfg is None:
        raise ValueError("[Error] The model config is required!")

    print("\n\n\t--------Initializing model----------")
    # Architecture name -> model class.
    supported_arch = {"simple": SOLD2Net}
    architecture = model_cfg["model_architecture"]
    if architecture not in supported_arch:
        raise ValueError("[Error] The model architecture is not in supported arch!")
    model = supported_arch[architecture](model_cfg)

    # Optionally register (learnable) loss weights on the model.
    if mode == "train":
        if loss_weights is None:
            raise ValueError(
                "[Error] the loss weights can not be None in dynamic weighting mode during training."
            )
        for param_name, param in loss_weights.items():
            if isinstance(param, nn.Parameter):
                print(
                    "\t [Debug] Adding %s with value %f to model"
                    % (param_name, param.item())
                )
                model.register_parameter(param_name, param)

    # Display some summary info.
    print("\tModel architecture: %s" % model_cfg["model_architecture"])
    print("\tBackbone: %s" % model_cfg["backbone"])
    print("\tJunction decoder: %s" % model_cfg["junction_decoder"])
    print("\tHeatmap decoder: %s" % model_cfg["heatmap_decoder"])
    print("\t-------------------------------------")
    return model


class SOLD2Net(nn.Module):
    """Full network for SOLD².

    A backbone encoder whose features are fed to a junction decoder, a
    heatmap decoder and, when configured, a descriptor decoder.
    """

    def __init__(self, model_cfg):
        super().__init__()
        self.name = model_cfg["model_name"]
        self.cfg = model_cfg

        # Backbone encoder; feat_channel is the channel count handed to the
        # decoders built below.
        self.supported_backbone = ["lcnn", "superpoint"]
        self.backbone_net, self.feat_channel = self.get_backbone()

        self.supported_junction_decoder = ["superpoint_decoder"]
        self.junction_decoder = self.get_junction_decoder()

        self.supported_heatmap_decoder = ["pixel_shuffle", "pixel_shuffle_single"]
        self.heatmap_decoder = self.get_heatmap_decoder()

        # The descriptor head is optional: only built if the config asks for it.
        if "descriptor_decoder" in self.cfg:
            self.supported_descriptor_decoder = ["superpoint_descriptor"]
            self.descriptor_decoder = self.get_descriptor_decoder()

        # Initialize all submodule weights (see weight_init).
        self.apply(weight_init)

    def forward(self, input_images):
        """Run the backbone once and every decoder head on its features.

        Args:
            input_images: Image batch accepted by the backbone (presumably
                an (N, C, H, W) tensor — confirm against the backbone).

        Returns:
            Dict with "junctions" and "heatmap" outputs, plus "descriptors"
            when a descriptor decoder is configured.
        """
        features = self.backbone_net(input_images)
        outputs = {
            "junctions": self.junction_decoder(features),
            "heatmap": self.heatmap_decoder(features),
        }
        if "descriptor_decoder" in self.cfg:
            outputs["descriptors"] = self.descriptor_decoder(features)
        return outputs

    def get_backbone(self):
        """Build the backbone encoder selected by cfg["backbone"].

        Returns:
            Tuple ``(backbone, feat_channel)``: 256 channels for "lcnn"
            (stacked hourglass, built from cfg["backbone_cfg"]) and 128 for
            "superpoint".

        Raises:
            ValueError: If the backbone is not supported.
        """
        backbone_name = self.cfg["backbone"]
        if backbone_name not in self.supported_backbone:
            raise ValueError("[Error] The backbone selection is not supported.")

        if backbone_name == "lcnn":
            backbone = HourglassBackbone(**self.cfg["backbone_cfg"])
            feat_channel = 256
        elif backbone_name == "superpoint":
            backbone = SuperpointBackbone()
            feat_channel = 128
        else:
            raise ValueError("[Error] The backbone selection is not supported.")
        return backbone, feat_channel

    def get_junction_decoder(self):
        """Build the junction decoder selected by cfg["junction_decoder"].

        Raises:
            ValueError: If the junction decoder is not supported.
        """
        decoder_name = self.cfg["junction_decoder"]
        if decoder_name not in self.supported_junction_decoder:
            raise ValueError("[Error] The junction decoder selection is not supported.")

        if decoder_name == "superpoint_decoder":
            return SuperpointDecoder(self.feat_channel, self.cfg["backbone"])
        raise ValueError("[Error] The junction decoder selection is not supported.")

    def get_heatmap_decoder(self):
        """Build the pixel-shuffle heatmap decoder selected by cfg["heatmap_decoder"].

        "pixel_shuffle" uses the decoder's default output channels;
        "pixel_shuffle_single" requests a single output channel. The number of
        upsampling stages depends on the backbone (2 for "lcnn", 3 for
        "superpoint"), presumably matching each backbone's downsampling.

        Raises:
            ValueError: If the decoder or the backbone is not supported.
        """
        decoder_name = self.cfg["heatmap_decoder"]
        if decoder_name not in self.supported_heatmap_decoder:
            raise ValueError("[Error] The heatmap decoder selection is not supported.")

        if decoder_name == "pixel_shuffle":
            extra_kwargs = {}
        elif decoder_name == "pixel_shuffle_single":
            extra_kwargs = {"output_channel": 1}
        else:
            raise ValueError("[Error] The heatmap decoder selection is not supported.")

        num_upsample_by_backbone = {"lcnn": 2, "superpoint": 3}
        backbone_name = self.cfg["backbone"]
        if backbone_name not in num_upsample_by_backbone:
            raise ValueError("[Error] Unknown backbone option.")

        return PixelShuffleDecoder(
            self.feat_channel,
            num_upsample=num_upsample_by_backbone[backbone_name],
            **extra_kwargs,
        )

    def get_descriptor_decoder(self):
        """Build the descriptor decoder selected by cfg["descriptor_decoder"].

        Raises:
            ValueError: If the descriptor decoder is not supported.
        """
        decoder_name = self.cfg["descriptor_decoder"]
        if decoder_name not in self.supported_descriptor_decoder:
            raise ValueError(
                "[Error] The descriptor decoder selection is not supported."
            )

        if decoder_name == "superpoint_descriptor":
            return SuperpointDescriptor(self.feat_channel)
        raise ValueError(
            "[Error] The descriptor decoder selection is not supported."
        )


def weight_init(m):
    """Initialize the parameters of a single module (used with ``Module.apply``).

    - Conv2d / Linear: Xavier-normal weights, standard-normal bias.
    - BatchNorm2d: weights ~ N(1, 0.02), bias = 0.
    Other module types are left untouched. Missing parameters (``bias=False``
    layers, ``affine=False`` batch norms) are skipped instead of crashing.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False, BatchNorm2d registers weight/bias as None.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)