|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  | import os | 
					
						
						|  |  | 
					
						
						|  | from mast3r.catmlp_dpt_head import mast3r_head_factory | 
					
						
						|  |  | 
					
						
						|  | import mast3r.utils.path_to_dust3r | 
					
						
						|  | from dust3r.model import AsymmetricCroCo3DStereo | 
					
						
						|  | from dust3r.utils.misc import transpose_to_landscape | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | inf = float('inf') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_model(model_path, device, verbose=True): | 
					
						
						|  | if verbose: | 
					
						
						|  | print('... loading model from', model_path) | 
					
						
						|  | ckpt = torch.load(model_path, map_location='cpu') | 
					
						
						|  | args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") | 
					
						
						|  | if 'landscape_only' not in args: | 
					
						
						|  | args = args[:-1] + ', landscape_only=False)' | 
					
						
						|  | else: | 
					
						
						|  | args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False') | 
					
						
						|  | assert "landscape_only=False" in args | 
					
						
						|  | if verbose: | 
					
						
						|  | print(f"instantiating : {args}") | 
					
						
						|  | net = eval(args) | 
					
						
						|  | s = net.load_state_dict(ckpt['model'], strict=False) | 
					
						
						|  | if verbose: | 
					
						
						|  | print(s) | 
					
						
						|  | return net.to(device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class AsymmetricMASt3R(AsymmetricCroCo3DStereo): | 
					
						
						|  | def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs): | 
					
						
						|  | self.desc_mode = desc_mode | 
					
						
						|  | self.two_confs = two_confs | 
					
						
						|  | self.desc_conf_mode = desc_conf_mode | 
					
						
						|  | super().__init__(**kwargs) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def from_pretrained(cls, pretrained_model_name_or_path, **kw): | 
					
						
						|  | if os.path.isfile(pretrained_model_name_or_path): | 
					
						
						|  | return load_model(pretrained_model_name_or_path, device='cpu') | 
					
						
						|  | else: | 
					
						
						|  | return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw) | 
					
						
						|  |  | 
					
						
						|  | def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw): | 
					
						
						|  | assert img_size[0] % patch_size == 0 and img_size[ | 
					
						
						|  | 1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}' | 
					
						
						|  | self.output_mode = output_mode | 
					
						
						|  | self.head_type = head_type | 
					
						
						|  | self.depth_mode = depth_mode | 
					
						
						|  | self.conf_mode = conf_mode | 
					
						
						|  | if self.desc_conf_mode is None: | 
					
						
						|  | self.desc_conf_mode = conf_mode | 
					
						
						|  |  | 
					
						
						|  | self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) | 
					
						
						|  | self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) | 
					
						
						|  |  | 
					
						
						|  | self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only) | 
					
						
						|  | self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only) | 
					
						
						|  |  |