Spaces:
Sleeping
Sleeping
hugoycj
commited on
Commit
•
3d78bd9
1
Parent(s):
22ec042
fix: Fix mast3r
Browse files- mast3r/model.py +9 -7
mast3r/model.py
CHANGED
@@ -13,15 +13,17 @@ from mast3r.catmlp_dpt_head import mast3r_head_factory
|
|
13 |
import mast3r.utils.path_to_dust3r # noqa
|
14 |
from dust3r.model import AsymmetricCroCo3DStereo # noqa
|
15 |
from dust3r.utils.misc import transpose_to_landscape # noqa
|
16 |
-
|
17 |
|
18 |
inf = float('inf')
|
19 |
|
20 |
|
21 |
-
def load_model(
|
22 |
if verbose:
|
23 |
-
print('... loading model from',
|
24 |
-
|
|
|
|
|
25 |
args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
|
26 |
if 'landscape_only' not in args:
|
27 |
args = args[:-1] + ', landscape_only=False)'
|
@@ -46,10 +48,10 @@ class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
|
|
46 |
|
47 |
@classmethod
|
48 |
def from_pretrained(cls, pretrained_model_name_or_path, **kw):
|
49 |
-
if os.path.isfile(pretrained_model_name_or_path):
|
50 |
-
return load_model(pretrained_model_name_or_path, device='cpu')
|
51 |
else:
|
52 |
-
return super(
|
53 |
|
54 |
def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
|
55 |
assert img_size[0] % patch_size == 0 and img_size[
|
|
|
13 |
import mast3r.utils.path_to_dust3r # noqa
|
14 |
from dust3r.model import AsymmetricCroCo3DStereo # noqa
|
15 |
from dust3r.utils.misc import transpose_to_landscape # noqa
|
16 |
+
import urllib
|
17 |
|
18 |
inf = float('inf')
|
19 |
|
20 |
|
21 |
+
def load_model(model_url, device, landscape_only=False, verbose=True):
|
22 |
if verbose:
|
23 |
+
print('... loading model from', model_url)
|
24 |
+
|
25 |
+
ckpt = torch.hub.load_state_dict_from_url(model_url, map_location='cpu', progress=verbose)
|
26 |
+
|
27 |
args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
|
28 |
if 'landscape_only' not in args:
|
29 |
args = args[:-1] + ', landscape_only=False)'
|
|
|
48 |
|
49 |
@classmethod
|
50 |
def from_pretrained(cls, pretrained_model_name_or_path, **kw):
|
51 |
+
if os.path.isfile(pretrained_model_name_or_path) or urllib.parse.urlparse(pretrained_model_name_or_path).scheme in ('http', 'https'):
|
52 |
+
return load_model(pretrained_model_name_or_path, device='cpu', **kw)
|
53 |
else:
|
54 |
+
return super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw)
|
55 |
|
56 |
def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
|
57 |
assert img_size[0] % patch_size == 0 and img_size[
|