Spaces:
Sleeping
Sleeping
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# MASt3R model class
# --------------------------------------------------------
import os
import urllib
import urllib.parse

import torch
import torch.nn.functional as F

from mast3r.catmlp_dpt_head import mast3r_head_factory
import mast3r.utils.path_to_dust3r  # noqa
from dust3r.model import AsymmetricCroCo3DStereo  # noqa
from dust3r.utils.misc import transpose_to_landscape  # noqa
# Module-level shorthand for positive infinity.
inf = float('inf')
def load_model(model_url, device, landscape_only=False, verbose=True):
    """Load a checkpoint and instantiate the network it describes.

    The checkpoint is expected to be a mapping with an ``'args'`` entry whose
    ``model`` attribute is a Python constructor expression (a string), and a
    ``'model'`` entry holding the state dict.

    Args:
        model_url: URL of the checkpoint, or a path to a checkpoint file on
            disk. Local files are read directly; anything else is fetched
            through ``torch.hub.load_state_dict_from_url``.
        device: device the instantiated network is moved to before returning.
        landscape_only: accepted for interface compatibility but not read;
            the constructor expression is always forced to
            ``landscape_only=False``.
        verbose: when true, print progress messages, the constructor
            expression and the result of ``load_state_dict``.

    Returns:
        The instantiated network, with the checkpoint weights loaded
        non-strictly, moved to ``device``.
    """
    if verbose:
        print('... loading model from', model_url)

    if os.path.isfile(model_url):
        # torch.hub can only download; a checkpoint already on disk has to be
        # read with torch.load. weights_only=False mirrors the hub loader's
        # default, since the checkpoint stores a pickled args object.
        ckpt = torch.load(model_url, map_location='cpu', weights_only=False)
    else:
        ckpt = torch.hub.load_state_dict_from_url(model_url, map_location='cpu', progress=verbose)

    # Rewrite the stored constructor expression: swap the patch-embedding
    # class name and make sure landscape_only=False is passed.
    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if 'landscape_only' not in args:
        # Drop the closing parenthesis and append the extra keyword argument.
        args = args[:-1] + ', landscape_only=False)'
    else:
        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
    assert "landscape_only=False" in args

    if verbose:
        print(f"instantiating : {args}")

    # SECURITY NOTE: the expression comes straight from the checkpoint and is
    # evaluated as Python code (and the checkpoint itself is unpickled above).
    # Only load checkpoints from sources you trust.
    net = eval(args)

    # strict=False: missing/unexpected keys are reported in `s`, not raised.
    s = net.load_state_dict(ckpt['model'], strict=False)
    if verbose:
        print(s)
    return net.to(device)
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """AsymmetricCroCo3DStereo subclass whose heads come from ``mast3r_head_factory``.

    On top of the parent class it stores three extra settings
    (``desc_mode``, ``two_confs``, ``desc_conf_mode``) and overrides how
    pretrained weights are located and how the two downstream heads are built.
    """

    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        """Store the MASt3R-specific settings, then run the parent constructor.

        Args:
            desc_mode: stored on the instance. Note that the default
                ``('norm')`` is the plain string ``'norm'`` (no trailing
                comma, so not a tuple); kept unchanged for compatibility.
            two_confs: stored on the instance.
            desc_conf_mode: stored on the instance; when left as ``None`` it
                is replaced by ``conf_mode`` in ``set_downstream_head``.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        # Set before super().__init__ -- presumably the parent constructor
        # calls set_downstream_head, which reads these; TODO confirm.
        self.desc_mode = desc_mode
        self.two_confs = two_confs
        self.desc_conf_mode = desc_conf_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Build a model from a local checkpoint file, an http(s) URL, or a model name.

        Args:
            pretrained_model_name_or_path: path to an existing file, an
                ``http``/``https`` URL, or any other identifier understood by
                the inherited ``from_pretrained``.
            **kw: forwarded to ``load_model`` (file/URL case, so only the
                keywords ``load_model`` accepts are valid there) or to the
                inherited ``from_pretrained`` otherwise.

        Returns:
            The loaded model. In the file/URL case it is placed on the CPU.
        """
        is_local_file = os.path.isfile(pretrained_model_name_or_path)
        scheme = urllib.parse.urlparse(pretrained_model_name_or_path).scheme
        if is_local_file or scheme in ('http', 'https'):
            return load_model(pretrained_model_name_or_path, device='cpu', **kw)
        else:
            # NOTE(review): super() is anchored at AsymmetricCroCo3DStereo, so
            # that class's own from_pretrained (if any) is skipped and lookup
            # continues with its bases -- looks intentional; confirm.
            return super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw)

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        """Record the head configuration and create the two downstream heads.

        Args:
            output_mode: stored and passed to ``mast3r_head_factory``.
            head_type: stored and passed to ``mast3r_head_factory``.
            landscape_only: passed as ``activate`` to ``transpose_to_landscape``.
            depth_mode: stored on the instance.
            conf_mode: stored on the instance; its truthiness decides
                ``has_conf`` for both heads, and it is the fallback for
                ``desc_conf_mode``.
            patch_size: both entries of ``img_size`` must be multiples of it.
            img_size: two-element size, indexed as ``img_size[0]`` and
                ``img_size[1]``.
            **kw: accepted and ignored.

        Raises:
            AssertionError: if ``img_size`` is not a multiple of ``patch_size``.
        """
        assert img_size[0] % patch_size == 0 and img_size[
            1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode
        # allocate heads
        self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # magic wrapper
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)