hugoycj commited on
Commit
3d78bd9
1 Parent(s): 22ec042

fix: Fix mast3r

Browse files
Files changed (1) hide show
  1. mast3r/model.py +9 -7
mast3r/model.py CHANGED
@@ -13,15 +13,17 @@ from mast3r.catmlp_dpt_head import mast3r_head_factory
13
  import mast3r.utils.path_to_dust3r # noqa
14
  from dust3r.model import AsymmetricCroCo3DStereo # noqa
15
  from dust3r.utils.misc import transpose_to_landscape # noqa
16
-
17
 
18
  inf = float('inf')
19
 
20
 
21
- def load_model(model_path, device, verbose=True):
22
  if verbose:
23
- print('... loading model from', model_path)
24
- ckpt = torch.load(model_path, map_location='cpu')
 
 
25
  args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
26
  if 'landscape_only' not in args:
27
  args = args[:-1] + ', landscape_only=False)'
@@ -46,10 +48,10 @@ class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
46
 
47
  @classmethod
48
  def from_pretrained(cls, pretrained_model_name_or_path, **kw):
49
- if os.path.isfile(pretrained_model_name_or_path):
50
- return load_model(pretrained_model_name_or_path, device='cpu')
51
  else:
52
- return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)
53
 
54
  def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
55
  assert img_size[0] % patch_size == 0 and img_size[
 
13
  import mast3r.utils.path_to_dust3r # noqa
14
  from dust3r.model import AsymmetricCroCo3DStereo # noqa
15
  from dust3r.utils.misc import transpose_to_landscape # noqa
16
+ import urllib
17
 
18
  inf = float('inf')
19
 
20
 
21
+ def load_model(model_url, device, landscape_only=False, verbose=True):
22
  if verbose:
23
+ print('... loading model from', model_url)
24
+
25
+ ckpt = torch.hub.load_state_dict_from_url(model_url, map_location='cpu', progress=verbose)
26
+
27
  args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
28
  if 'landscape_only' not in args:
29
  args = args[:-1] + ', landscape_only=False)'
 
48
 
49
  @classmethod
50
  def from_pretrained(cls, pretrained_model_name_or_path, **kw):
51
+ if os.path.isfile(pretrained_model_name_or_path) or urllib.parse.urlparse(pretrained_model_name_or_path).scheme in ('http', 'https'):
52
+ return load_model(pretrained_model_name_or_path, device='cpu', **kw)
53
  else:
54
+ return super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw)
55
 
56
  def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
57
  assert img_size[0] % patch_size == 0 and img_size[