"""Gradio demo that runs MASt3R through mini-dust3r and shows the result in Rerun.

Users upload one image (Single Image tab) or several (Multi Image tab). The
images go through ``inferece_dust3r``. The optimized result is logged to a
per-request ``.rrd`` recording, which the ``Rerun`` component then displays.
"""

import gradio as gr
import spaces
import torch
from gradio_rerun import Rerun
import rerun as rr
import rerun.blueprint as rrb
from pathlib import Path
import uuid
from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.misc import (
    fill_default_args,
    freeze_all_params,
    is_symmetrized,
    interleave,
    transpose_to_landscape,
)
import os
from mini_dust3r.model import load_model
from catmlp_dpt_head import Cat_MLP_LocalFeatures_DPT_Pts3d, postprocess
from mini_dust3r.heads.linear_head import LinearPts3d
from mini_dust3r.heads.dpt_head import create_dpt_head

# torch device type strings are lower-case ("cuda", "cpu"). torch.device rejects
# "CPU", which used to break `.to(DEVICE)` on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The plain DUSt3R checkpoint ("naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt") can be
# loaded with AsymmetricCroCo3DStereo.from_pretrained(...) instead of MASt3R below.


def head_factory(head_type, output_mode, net, has_conf=False):
    """Build a prediction head for the decoder.

    Args:
        head_type: One of ``"linear"``, ``"dpt"`` or ``"catmlp+dpt"``.
        output_mode: ``"pts3d"`` for the linear/dpt heads. For the catmlp+dpt
            head it is ``"pts3d+desc<N>"``, where ``<N>`` is parsed as the
            local-feature dimension.
        net: The model the head is attached to. For ``"catmlp+dpt"`` its
            ``dec_depth``, ``enc_embed_dim``, ``dec_embed_dim``, ``depth_mode``
            and ``conf_mode`` attributes are read.
        has_conf: Whether the head also predicts a confidence channel.

    Returns:
        The constructed head module.

    Raises:
        NotImplementedError: If the (head_type, output_mode) pair is unsupported.
    """
    if head_type == "linear" and output_mode == "pts3d":
        return LinearPts3d(net, has_conf)
    if head_type == "dpt" and output_mode == "pts3d":
        return create_dpt_head(net, has_conf=has_conf)
    if head_type == "catmlp+dpt" and output_mode.startswith("pts3d+desc"):
        # Everything after the "pts3d+desc" prefix is the descriptor dimension.
        local_feat_dim = int(output_mode[len("pts3d+desc"):])
        assert net.dec_depth > 9
        l2 = net.dec_depth
        feature_dim = 256
        last_dim = feature_dim // 2
        out_nchan = 3
        ed = net.enc_embed_dim
        dd = net.dec_embed_dim
        return Cat_MLP_LocalFeatures_DPT_Pts3d(
            net,
            local_feat_dim=local_feat_dim,
            has_conf=has_conf,
            # 3 point coordinates, plus 1 confidence channel when has_conf is True.
            num_channels=out_nchan + has_conf,
            feature_dim=feature_dim,
            last_dim=last_dim,
            # Tap the first block, then the 1/2, 3/4 and final decoder depths.
            hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
            dim_tokens=[ed, dd, dd, dd],
            postprocess=postprocess,
            depth_mode=net.depth_mode,
            conf_mode=net.conf_mode,
            head_type="regression",
        )
    raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")


class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """AsymmetricCroCo3DStereo whose downstream heads are built by ``head_factory``.

    This adds support for the ``"catmlp+dpt"`` head, which also outputs
    local features (``output_mode="pts3d+desc<N>"``).
    """

    def __init__(self, desc_mode=("norm"), two_confs=False, desc_conf_mode=None, **kwargs):
        # NOTE: ("norm") is the plain string "norm", not a tuple.
        # These attributes are set before super().__init__() because
        # set_downstream_head reads self.desc_conf_mode. Presumably the base
        # constructor calls it; TODO confirm in mini_dust3r.
        self.desc_mode = desc_mode
        self.two_confs = two_confs
        self.desc_conf_mode = desc_conf_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load from a local checkpoint file, or defer to the base hub loader.

        A local file is loaded with ``load_model`` onto the CPU, and ``kw`` is
        ignored in that case. Any other value is passed to the parent class's
        ``from_pretrained``.
        """
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device="cpu")
        return super().from_pretrained(pretrained_model_name_or_path, **kw)

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode,
                            conf_mode, patch_size, img_size, **kw):
        """Store the head configuration and create the two prediction heads.

        Both dimensions of ``img_size`` must be multiples of ``patch_size``.
        If ``desc_conf_mode`` was not given, it defaults to ``conf_mode``.
        """
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \
            f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode
        # One head per view of the asymmetric pair.
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # Wrap each head with transpose_to_landscape (enabled by landscape_only).
        # Presumably this lets portrait inputs be handled by transposing; see
        # mini_dust3r.utils.misc.
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)


model = AsymmetricMASt3R.from_pretrained(
    "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
).to(DEVICE)


def create_blueprint(image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
    """Build the Rerun viewer layout for a reconstruction logged under ``log_path``.

    With more than 4 images, only the 3D view is shown. Otherwise a column of
    2D views, one per ``camera_{i}/pinhole``, sits to the right of the 3D view.
    """
    # Don't show 2D views when there are more than 4 images, to avoid clutter.
    if len(image_name_list) > 4:
        blueprint = rrb.Blueprint(
            rrb.Horizontal(
                rrb.Spatial3DView(origin=f"{log_path}"),
            ),
            collapse_panels=True,
        )
    else:
        blueprint = rrb.Blueprint(
            rrb.Horizontal(
                contents=[
                    rrb.Spatial3DView(origin=f"{log_path}"),
                    rrb.Vertical(
                        contents=[
                            rrb.Spatial2DView(
                                origin=f"{log_path}/camera_{i}/pinhole/",
                                contents=[
                                    "+ $origin/**",
                                ],
                            )
                            for i in range(len(image_name_list))
                        ]
                    ),
                ],
                column_shares=[3, 1],
            ),
            collapse_panels=True,
        )
    return blueprint


@spaces.GPU
def predict(image_name_list: list[str] | str):
    """Run inference on the given image path(s) and return a ``.rrd`` file path.

    Args:
        image_name_list: A single image path or a list of image paths, as
            provided by the Gradio inputs.

    Returns:
        The POSIX path of the Rerun recording written for this request.

    Raises:
        gr.Error: If the input is not a string or list, or the list is empty.
    """
    # Gradio passes None when nothing was uploaded; reject anything that is not
    # a path or a list of paths.
    if not isinstance(image_name_list, (list, str)):
        raise gr.Error(
            f"Input must be a list of strings or a string, got: {type(image_name_list)}"
        )
    if isinstance(image_name_list, str):
        image_name_list = [image_name_list]
    if not image_name_list:
        raise gr.Error("Please provide at least one image.")

    # One recording and one output file per request, keyed by a random UUID.
    uuid_str = str(uuid.uuid4())
    filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
    # Make sure the output directory exists before rerun writes the recording.
    filename.parent.mkdir(parents=True, exist_ok=True)
    rr.init(f"{uuid_str}")
    log_path = Path("world")

    optimized_results: OptimizedResult = inferece_dust3r(
        image_dir_or_list=image_name_list,
        model=model,
        device=DEVICE,
        batch_size=1,
    )

    blueprint: rrb.Blueprint = create_blueprint(image_name_list, log_path)
    rr.send_blueprint(blueprint)

    rr.set_time_sequence("sequence", 0)
    log_optimized_result(optimized_results, log_path)
    rr.save(filename.as_posix())
    return filename.as_posix()


with gr.Blocks(
    css=""".gradio-container {margin: 0 !important; min-width: 100%};""",
    title="Mini-DUSt3R Demo",
) as demo:
    # NOTE(review): these banners appear to have lost their original HTML markup
    # and links ("More info here"). Restore them if needed.
    gr.HTML("\nUnofficial DUSt3R demo using the mini-dust3r pip package\n")
    gr.HTML("More info here\n")

    with gr.Tab(label="Single Image"):
        with gr.Column():
            single_image = gr.Image(type="filepath", height=300)
            run_btn_single = gr.Button("Run")
            rerun_viewer_single = Rerun(height=900)
            run_btn_single.click(
                fn=predict, inputs=[single_image], outputs=[rerun_viewer_single]
            )

            example_single_dir = Path("examples/single_image")
            example_single_files = sorted(example_single_dir.glob("*.png"))

            # Example outputs are computed on first use ("lazy"), not at startup.
            examples_single = gr.Examples(
                examples=example_single_files,
                inputs=[single_image],
                outputs=[rerun_viewer_single],
                fn=predict,
                cache_examples="lazy",
            )

    with gr.Tab(label="Multi Image"):
        with gr.Column():
            multi_files = gr.File(file_count="multiple")
            run_btn_multi = gr.Button("Run")
            rerun_viewer_multi = Rerun(height=900)
            run_btn_multi.click(
                fn=predict, inputs=[multi_files], outputs=[rerun_viewer_multi]
            )

demo.launch()