Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.init as init | |
from .nets.backbone import HourglassBackbone, SuperpointBackbone | |
from .nets.junction_decoder import SuperpointDecoder | |
from .nets.heatmap_decoder import PixelShuffleDecoder | |
from .nets.descriptor_decoder import SuperpointDescriptor | |
def get_model(model_cfg=None, loss_weights=None, mode="train"):
    """Build the model described by ``model_cfg``.

    Args:
        model_cfg: Mapping with the model settings. This function reads the
            keys "model_architecture", "backbone", "junction_decoder" and
            "heatmap_decoder"; the whole mapping is handed to ``SOLD2Net``.
        loss_weights: Mapping from parameter name to loss weight. Required
            when ``mode == "train"``. Only values that are ``nn.Parameter``
            instances are registered on the model; other values are skipped.
        mode: Loss weights are only validated and registered when this is
            exactly "train".

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_cfg`` is None, if the architecture is not
            supported, or if ``mode == "train"`` and ``loss_weights`` is None.
    """
    if model_cfg is None:
        raise ValueError("[Error] The model config is required!")

    print("\n\n\t--------Initializing model----------")
    architecture = model_cfg["model_architecture"]
    supported_arch = ["simple"]
    if architecture not in supported_arch:
        raise ValueError("[Error] The model architecture is not in supported arch!")

    # Fail fast: reject a missing loss-weight mapping before paying for the
    # construction of the whole network.
    is_training = mode == "train"
    if is_training and loss_weights is None:
        raise ValueError(
            "[Error] the loss weights can not be None in dynamic weighting mode during training."
        )

    if architecture == "simple":
        model = SOLD2Net(model_cfg)
    else:
        # Guard for architectures listed as supported but not yet wired in,
        # so ``model`` can never be used unbound.
        raise ValueError("[Error] The model architecture is not in supported arch!")

    # Register the learnable loss weights so they belong to the model's
    # parameters. Non-Parameter entries are deliberately left alone.
    if is_training:
        for param_name, param in loss_weights.items():
            if isinstance(param, nn.Parameter):
                print(
                    "\t [Debug] Adding %s with value %f to model"
                    % (param_name, param.item())
                )
                model.register_parameter(param_name, param)

    # Display some summary info.
    print("\tModel architecture: %s" % model_cfg["model_architecture"])
    print("\tBackbone: %s" % model_cfg["backbone"])
    print("\tJunction decoder: %s" % model_cfg["junction_decoder"])
    print("\tHeatmap decoder: %s" % model_cfg["heatmap_decoder"])
    print("\t-------------------------------------")
    return model
class SOLD2Net(nn.Module):
    """Full network for SOLD².

    The network is a shared backbone followed by a junction decoder, a
    heatmap decoder and, when the config contains a "descriptor_decoder"
    entry, a descriptor decoder. All sub-networks are selected by name from
    the configuration mapping.
    """

    def __init__(self, model_cfg):
        """Assemble the sub-networks named in ``model_cfg``.

        Args:
            model_cfg: Mapping read for the keys "model_name", "backbone",
                "junction_decoder", "heatmap_decoder", optionally
                "descriptor_decoder", and "backbone_cfg" when the backbone
                is "lcnn".

        Raises:
            ValueError: If any selected component is not supported.
        """
        super().__init__()
        self.name = model_cfg["model_name"]
        self.cfg = model_cfg

        # Backbone; ``feat_channel`` is the channel count handed to every
        # decoder below.
        self.supported_backbone = ["lcnn", "superpoint"]
        self.backbone_net, self.feat_channel = self.get_backbone()

        # Junction decoder
        self.supported_junction_decoder = ["superpoint_decoder"]
        self.junction_decoder = self.get_junction_decoder()

        # Heatmap decoder
        self.supported_heatmap_decoder = ["pixel_shuffle", "pixel_shuffle_single"]
        self.heatmap_decoder = self.get_heatmap_decoder()

        # Descriptor decoder (only built when the config asks for one)
        if "descriptor_decoder" in self.cfg:
            self.supported_descriptor_decoder = ["superpoint_descriptor"]
            self.descriptor_decoder = self.get_descriptor_decoder()

        # Initialize the model weights
        self.apply(weight_init)

    def forward(self, input_images):
        """Run the backbone once and feed its features to every decoder.

        Args:
            input_images: Input passed unchanged to the backbone.

        Returns:
            A dict with the keys "junctions" and "heatmap", plus
            "descriptors" when the config contains "descriptor_decoder".
        """
        features = self.backbone_net(input_images)

        outputs = {
            "junctions": self.junction_decoder(features),
            "heatmap": self.heatmap_decoder(features),
        }

        if "descriptor_decoder" in self.cfg:
            outputs["descriptors"] = self.descriptor_decoder(features)

        return outputs

    def get_backbone(self):
        """Retrieve the backbone encoder network.

        Returns:
            A ``(backbone, feat_channel)`` tuple; ``feat_channel`` is 256
            for "lcnn" and 128 for "superpoint".

        Raises:
            ValueError: If the configured backbone is not supported.
        """
        backbone_name = self.cfg["backbone"]
        if backbone_name not in self.supported_backbone:
            raise ValueError("[Error] The backbone selection is not supported.")

        # lcnn backbone (stacked hourglass)
        if backbone_name == "lcnn":
            backbone = HourglassBackbone(**self.cfg["backbone_cfg"])
            feat_channel = 256
        elif backbone_name == "superpoint":
            # SuperpointBackbone takes no arguments, so "backbone_cfg" is not
            # read here and does not have to be present in the config.
            backbone = SuperpointBackbone()
            feat_channel = 128
        else:
            raise ValueError("[Error] The backbone selection is not supported.")

        return backbone, feat_channel

    def get_junction_decoder(self):
        """Get the junction decoder.

        Raises:
            ValueError: If the configured junction decoder is not supported.
        """
        decoder_name = self.cfg["junction_decoder"]
        if decoder_name not in self.supported_junction_decoder:
            raise ValueError("[Error] The junction decoder selection is not supported.")

        # superpoint decoder
        if decoder_name == "superpoint_decoder":
            return SuperpointDecoder(self.feat_channel, self.cfg["backbone"])

        raise ValueError("[Error] The junction decoder selection is not supported.")

    def get_heatmap_decoder(self):
        """Get the heatmap decoder.

        Both decoder variants are ``PixelShuffleDecoder`` instances; the
        "pixel_shuffle_single" variant additionally passes
        ``output_channel=1``. The number of upsampling steps depends on the
        backbone (2 for "lcnn", 3 for "superpoint").

        Raises:
            ValueError: If the heatmap decoder or the backbone is unknown.
        """
        decoder_name = self.cfg["heatmap_decoder"]
        if decoder_name not in self.supported_heatmap_decoder:
            raise ValueError("[Error] The heatmap decoder selection is not supported.")

        # Extra keyword arguments that distinguish the decoder variants.
        if decoder_name == "pixel_shuffle":
            extra_kwargs = {}
        elif decoder_name == "pixel_shuffle_single":
            extra_kwargs = {"output_channel": 1}
        else:
            raise ValueError("[Error] The heatmap decoder selection is not supported.")

        # Upsampling steps required for each backbone.
        upsample_steps = {"lcnn": 2, "superpoint": 3}
        backbone_name = self.cfg["backbone"]
        if backbone_name not in upsample_steps:
            raise ValueError("[Error] Unknown backbone option.")

        return PixelShuffleDecoder(
            self.feat_channel,
            num_upsample=upsample_steps[backbone_name],
            **extra_kwargs
        )

    def get_descriptor_decoder(self):
        """Get the descriptor decoder.

        Raises:
            ValueError: If the configured descriptor decoder is not
                supported.
        """
        decoder_name = self.cfg["descriptor_decoder"]
        if decoder_name not in self.supported_descriptor_decoder:
            raise ValueError(
                "[Error] The descriptor decoder selection is not supported."
            )

        # SuperPoint descriptor
        if decoder_name == "superpoint_descriptor":
            return SuperpointDescriptor(self.feat_channel)

        raise ValueError(
            "[Error] The descriptor decoder selection is not supported."
        )
def weight_init(m):
    """Initialize the weights of a single module in place.

    Intended to be passed to ``nn.Module.apply``. Modules of any other type
    are left untouched.

    * ``nn.Conv2d`` / ``nn.Linear``: Xavier-normal weights; the bias, when
      present, is drawn from a standard normal distribution.
    * ``nn.BatchNorm2d``: weights drawn from N(1, 0.02) and bias set to 0,
      each only when the layer actually has that parameter.

    Args:
        m: The module to initialize.
    """
    # Conv2D and Linear share the same scheme.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.xavier_normal_(m.weight.data)
        # Layers built with bias=False have ``bias`` set to None.
        if m.bias is not None:
            init.normal_(m.bias.data)
    # Batchnorm
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False both parameters are None.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)