Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import torch | |
from pathlib import Path | |
import torchvision.transforms as tfm | |
import torch.nn.functional as F | |
import urllib.request | |
import numpy as np | |
from ..utils.base_model import BaseModel | |
from .. import logger | |
# Put the vendored DUSt3R checkout on the import path; this must happen
# before any of the `dust3r.*` imports below are executed.
duster_path = Path(__file__).parent.joinpath(
    "..", "..", "third_party", "dust3r"
)
sys.path.append(os.fspath(duster_path))

from dust3r.inference import inference  # noqa: E402
from dust3r.model import load_model  # noqa: E402
from dust3r.image_pairs import make_pairs  # noqa: E402
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode  # noqa: E402
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid  # noqa: E402

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class Duster(BaseModel):
    """Two-view matcher built on DUSt3R.

    Both images are run through DUSt3R as a pair. The per-view 3D points kept
    by the scene's confidence masks are matched reciprocally across the two
    views, and the pixel locations of those matches are returned as keypoint
    correspondences.
    """

    default_conf = {
        "name": "Duster3r",
        # Local checkpoint location; fetched by `download_weights` if missing.
        "model_path": duster_path / "model_weights/duster_vit_large.pth",
        # Upper bound on returned matches; None disables subsampling.
        "max_keypoints": 3000,
        # `preprocess` snaps image sides to a multiple of this value.
        "vit_patch_size": 16,
    }

    def _init(self, conf):
        """Create the input normalizer, ensure the checkpoint exists, load it."""
        # Maps [0, 1] inputs to [-1, 1]: (x - 0.5) / 0.5 per channel.
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # Coerce to Path so a plain-string `model_path` in a user conf works.
        self.model_path = Path(self.conf["model_path"])
        self.download_weights()
        self.net = load_model(self.model_path, device)
        logger.info("Loaded Dust3r model")

    def download_weights(self):
        """Download the DUSt3R ViT-Large checkpoint unless it already exists.

        The file is fetched into a sibling ``.part`` file and only renamed to
        ``self.model_path`` once the transfer finished. An interrupted
        download therefore never leaves a truncated checkpoint that later
        runs would mistake for a complete one.
        """
        url = "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
        model_path = Path(self.model_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        if model_path.is_file():
            return
        logger.info("Downloading Duster(ViT large)... (takes a while)")
        tmp_path = model_path.with_name(model_path.name + ".part")
        try:
            urllib.request.urlretrieve(url, tmp_path)
            # Atomic rename: the final path only ever holds a full download.
            os.replace(tmp_path, model_path)
        finally:
            # Remove whatever a failed/aborted download left behind.
            if tmp_path.exists():
                tmp_path.unlink()

    def preprocess(self, img):
        """Resize ``img`` to patch-size multiples, normalize, add a batch dim.

        Args:
            img: a 3-D image tensor; assumed to be (C, H, W) — TODO confirm
                against callers (the unpacking below requires exactly 3 dims).

        Returns:
            The resized, normalized image with a leading dimension of size 1.
        """
        # the super-class already makes sure that img0,img1 have
        # same resolution and that h == w
        patch = self.conf["vit_patch_size"]
        _, h, _ = img.shape
        imsize = h
        if h % patch != 0:
            # Snap to the nearest multiple, but never below one patch
            # (a 0-sized target would make `resize` fail).
            imsize = max(patch, int(patch * round(h / patch)))
        img = tfm.functional.resize(img, imsize, antialias=True)
        _, new_h, new_w = img.shape
        if new_w % patch != 0:
            safe_w = max(patch, int(patch * round(new_w / patch)))
            img = tfm.functional.resize(img, (new_h, safe_w), antialias=True)
        return self.normalize(img).unsqueeze(0)

    @staticmethod
    def _subsample_matches(mkpts0, mkpts1, top_k):
        """Evenly thin two paired match arrays down to at most ``top_k`` rows."""
        if top_k is None or len(mkpts0) <= top_k:
            return mkpts0, mkpts1
        keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
        return mkpts0[keep], mkpts1[keep]

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]`` with DUSt3R.

        Returns:
            dict with ``keypoints0`` / ``keypoints1``: tensors of matched
            pixel coordinates from dust3r's ``xy_grid`` (presumably (x, y)
            order), one row per match. Row i of both tensors forms a
            correspondence. At most ``conf["max_keypoints"]`` rows are
            returned (evenly subsampled when exceeded).
        """
        img0, img1 = data["image0"], data["image1"]
        # NOTE(review): `preprocess` (resize + [-1, 1] normalization) is not
        # applied here; confirm the inputs already match what DUSt3R expects.
        images = [
            {"img": img0, "idx": 0, "instance": 0},
            {"img": img1, "idx": 1, "instance": 1},
        ]
        pairs = make_pairs(
            images, scene_graph="complete", prefilter=None, symmetrize=True
        )
        output = inference(pairs, self.net, device, batch_size=1)
        scene = global_aligner(
            output, device=device, mode=GlobalAlignerMode.PairViewer
        )
        # NOTE(review): with PairViewer this optimization may have little or
        # no effect; kept to preserve existing behavior. The returned value
        # was never used.
        scene.compute_global_alignment(
            init="mst", niter=300, schedule="cosine", lr=0.01
        )

        # Per view: keep the 2D pixel grid and 3D points selected by the
        # scene's confidence mask.
        confidence_masks = scene.get_masks()
        pts3d = scene.get_pts3d()
        imgs = scene.imgs
        pts2d_list, pts3d_list = [], []
        for i in range(2):
            conf_i = confidence_masks[i].cpu().numpy()
            # imgs[i].shape[:2] = (H, W); passed to xy_grid reversed, as (W, H).
            pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i])
            pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])

        # Reciprocal matching of the two views' 3D points (dust3r helper).
        reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
            *pts3d_list
        )
        logger.info(f"Found {num_matches} matches")
        mkpts1 = pts2d_list[1][reciprocal_in_P2]
        mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
        mkpts0, mkpts1 = self._subsample_matches(
            mkpts0, mkpts1, self.conf["max_keypoints"]
        )
        return {
            "keypoints0": torch.from_numpy(mkpts0),
            "keypoints1": torch.from_numpy(mkpts1),
        }