# Page header captured together with this file: "Spaces: Running"
import subprocess
import sys
from pathlib import Path

import torch

from .. import logger
from ..utils.base_model import BaseModel
# Expose the vendored SOLD2 checkout on sys.path so the ``sold2`` package
# below can be imported; the import has to come after this path tweak.
sold2_path = Path(__file__).parent.joinpath("..", "..", "third_party", "SOLD2")
sys.path.append(f"{sold2_path}")
from sold2.model.line_matcher import LineMatcher  # noqa: E402

# Run the matcher on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class SOLD2(BaseModel):
    """Line detection and matching between two images using SOLD2's ``LineMatcher``.

    Configuration keys (see ``default_conf``):
      - ``weights`` / ``checkpoint_dir``: checkpoint file name and directory;
        the file is downloaded from ``weight_urls`` when missing.
      - ``detect_thresh``, ``valid_thresh``, ``num_blocks``, ``overlap_ratio``,
        ``multiscale``: forwarded into the line-detector configuration.
      - ``match_threshold``: not read by this class.
    """

    default_conf = {
        "weights": "sold2_wireframe.tar",
        "match_threshold": 0.2,
        "checkpoint_dir": sold2_path / "pretrained",
        "detect_thresh": 0.25,
        "multiscale": False,
        "valid_thresh": 1e-3,
        "num_blocks": 20,
        "overlap_ratio": 0.5,
    }
    required_inputs = [
        "image0",
        "image1",
    ]
    weight_urls = {
        "sold2_wireframe.tar": "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download",
    }

    def _init(self, conf):
        """Ensure the checkpoint is on disk, then build ``self.net``.

        Raises ``KeyError`` if ``conf["weights"]`` has no entry in
        ``weight_urls`` and the checkpoint is missing, and
        ``subprocess.CalledProcessError`` if the download fails.
        """
        # Path() also accepts a plain string for checkpoint_dir.
        checkpoint_path = Path(conf["checkpoint_dir"]) / conf["weights"]
        if not checkpoint_path.exists():
            self._download_weights(conf["weights"], checkpoint_path)
        match_config = self._build_match_config(conf)
        self.net = LineMatcher(
            match_config["model_cfg"],
            checkpoint_path,
            device,
            match_config["line_detector_cfg"],
            match_config["line_matcher_cfg"],
            match_config["multiscale"],
        )

    def _download_weights(self, weights, checkpoint_path):
        """Download ``weight_urls[weights]`` to ``checkpoint_path`` with wget.

        The data is first written to a ``.part`` sibling and renamed only after
        wget succeeds. ``wget -O`` creates its output file even when the
        transfer fails, so writing straight to ``checkpoint_path`` would leave a
        truncated file that the ``exists()`` check in ``_init`` accepts forever.
        """
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        link = self.weight_urls[weights]
        tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".part")
        cmd = ["wget", link, "-O", str(tmp_path)]
        logger.info(f"Downloading the SOLD2 model with `{cmd}`.")
        try:
            subprocess.run(cmd, check=True)
            tmp_path.replace(checkpoint_path)
        finally:
            # After a successful rename this is a no-op; on failure it removes
            # the partial download.
            if tmp_path.exists():
                tmp_path.unlink()

    @staticmethod
    def _build_match_config(conf):
        """Return the nested config dict consumed by ``LineMatcher``.

        Values exposed in ``default_conf`` are taken from ``conf``. Their
        defaults match the literals previously hard-coded here. Loss-related
        entries carry zero weights (``w_heatmap``, ``w_junc``, ``w_desc``).
        """
        mode = "dynamic"  # 'dynamic' or 'static'
        return {
            "model_cfg": {
                "model_name": "lcnn_simple",
                "model_architecture": "simple",
                # Backbone related config
                "backbone": "lcnn",
                "backbone_cfg": {
                    # Number of input channels; 1 presumably means grayscale
                    # (RGB would be 3). Confirm against the checkpoint.
                    "input_channel": 1,
                    "depth": 4,
                    "num_stacks": 2,
                    "num_blocks": 1,
                    "num_classes": 5,
                },
                # Junction decoder related config
                "junction_decoder": "superpoint_decoder",
                "junc_decoder_cfg": {},
                # Heatmap decoder related config
                "heatmap_decoder": "pixel_shuffle",
                "heatmap_decoder_cfg": {},
                # Descriptor decoder related config
                "descriptor_decoder": "superpoint_descriptor",
                "descriptor_decoder_cfg": {},
                # Shared configurations
                "grid_size": 8,
                "keep_border_valid": True,
                # Threshold of junction detection
                "detection_thresh": 0.0153846,  # 1/65
                "max_num_junctions": 300,
                # Threshold of heatmap detection
                "prob_thresh": 0.5,
                # Weighting related parameters
                "weighting_policy": mode,
                # [Heatmap loss]
                "w_heatmap": 0.0,
                "w_heatmap_class": 1,
                "heatmap_loss_func": "cross_entropy",
                "heatmap_loss_cfg": {"policy": mode},
                # [Junction loss]
                "w_junc": 0.0,
                "junction_loss_func": "superpoint",
                "junction_loss_cfg": {"policy": mode},
                # [Descriptor loss]
                "w_desc": 0.0,
                "descriptor_loss_func": "regular_sampling",
                "descriptor_loss_cfg": {
                    "dist_threshold": 8,
                    "grid_size": 4,
                    "margin": 1,
                    "policy": mode,
                },
            },
            "line_detector_cfg": {
                # Depending on your images, you might need to tune this.
                "detect_thresh": conf["detect_thresh"],
                "num_samples": 64,
                "sampling_method": "local_max",
                "inlier_thresh": 0.9,
                "use_candidate_suppression": True,
                "nms_dist_tolerance": 3.0,
                "use_heatmap_refinement": True,
                "heatmap_refine_cfg": {
                    "mode": "local",
                    "ratio": 0.2,
                    "valid_thresh": conf["valid_thresh"],
                    "num_blocks": conf["num_blocks"],
                    "overlap_ratio": conf["overlap_ratio"],
                },
            },
            "multiscale": conf["multiscale"],
            "line_matcher_cfg": {
                "cross_check": True,
                "num_samples": 5,
                "min_dist_pts": 8,
                "top_k_candidates": 10,
                "grid_size": 4,
            },
        }

    def _forward(self, data):
        """Detect and match line segments between ``image0`` and ``image1``.

        Returns the matcher's prediction dict, extended with:
          - ``raw_lines0`` / ``raw_lines1``: all segments detected per image;
          - ``lines0`` / ``lines1``: matched segments, paired row by row, with
            the last axis reversed (presumably (row, col) -> (x, y); TODO
            confirm against SOLD2's output convention);
        and finally every entry of ``data``, which wins on key collisions.
        """
        pred = self.net([data["image0"], data["image1"]])
        line_seg1, line_seg2 = pred["line_segments"][0], pred["line_segments"][1]
        matches = pred["matches"]
        # matches[i] is the image1 segment index paired with image0 segment i,
        # or -1 when segment i has no match.
        valid_matches = matches != -1
        match_indices = matches[valid_matches]
        pred["raw_lines0"], pred["raw_lines1"] = line_seg1, line_seg2
        pred["lines0"] = line_seg1[valid_matches][:, :, ::-1]
        pred["lines1"] = line_seg2[match_indices][:, :, ::-1]
        return {**pred, **data}