HumanWild / app_mast3r.py
geyongtao's picture
Rename app.py to app_mast3r.py
007db37 verified
raw
history blame contribute delete
No virus
7.89 kB
import gradio as gr
import spaces
import torch
from gradio_rerun import Rerun
import rerun as rr
import rerun.blueprint as rrb
from pathlib import Path
import uuid
from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.misc import (
fill_default_args,
freeze_all_params,
is_symmetrized,
interleave,
transpose_to_landscape,
)
import os
from mini_dust3r.model import load_model
from catmlp_dpt_head import Cat_MLP_LocalFeatures_DPT_Pts3d, postprocess
# Torch device strings are lowercase ("cpu", "cuda"); an uppercase "CPU" is
# rejected by torch.device / nn.Module.to, which would break CPU-only hosts.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
from mini_dust3r.heads.linear_head import LinearPts3d
from mini_dust3r.heads.dpt_head import create_dpt_head
def head_factory(head_type, output_mode, net, has_conf=False):
    """Build a prediction head for the decoder.

    Args:
        head_type: One of ``'linear'``, ``'dpt'`` or ``'catmlp+dpt'``.
        output_mode: ``'pts3d'`` for the linear / DPT heads, or
            ``'pts3d+desc<N>'`` for the catmlp+dpt head, where the integer
            suffix ``<N>`` is passed on as ``local_feat_dim``.
        net: The model the head is attached to. The catmlp+dpt branch reads
            ``dec_depth``, ``enc_embed_dim``, ``dec_embed_dim``,
            ``depth_mode`` and ``conf_mode`` from it.
        has_conf: Whether the head also predicts a confidence channel.

    Returns:
        The constructed head module.

    Raises:
        ValueError: If the catmlp+dpt head is requested with a non-integer
            descriptor suffix or for a decoder with ``dec_depth <= 9``.
        NotImplementedError: For any other ``head_type`` / ``output_mode``
            combination.
    """
    if head_type == 'linear' and output_mode == 'pts3d':
        return LinearPts3d(net, has_conf)
    elif head_type == 'dpt' and output_mode == 'pts3d':
        return create_dpt_head(net, has_conf=has_conf)
    elif head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'):
        # The text after 'pts3d+desc' is the local-feature dimension.
        local_feat_dim = int(output_mode.removeprefix('pts3d+desc'))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if net.dec_depth <= 9:
            raise ValueError(
                f"catmlp+dpt head requires dec_depth > 9, got {net.dec_depth}"
            )
        l2 = net.dec_depth
        feature_dim = 256
        last_dim = feature_dim // 2
        # 3 output channels, plus one more when a confidence channel is requested.
        out_nchan = 3
        ed = net.enc_embed_dim
        dd = net.dec_embed_dim
        return Cat_MLP_LocalFeatures_DPT_Pts3d(
            net,
            local_feat_dim=local_feat_dim,
            has_conf=has_conf,
            num_channels=out_nchan + has_conf,
            feature_dim=feature_dim,
            last_dim=last_dim,
            # Layer indices tapped by the head; index 0 pairs with the encoder
            # embedding dim, the others with the decoder dim (see dim_tokens).
            hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
            dim_tokens=[ed, dd, dd, dd],
            postprocess=postprocess,
            depth_mode=net.depth_mode,
            conf_mode=net.conf_mode,
            head_type='regression',
        )
    else:
        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """AsymmetricCroCo3DStereo variant whose heads are built by ``head_factory``.

    On top of the parent model it stores descriptor settings (``desc_mode``,
    ``two_confs``, ``desc_conf_mode``). ``desc_conf_mode`` falls back to the
    head's ``conf_mode`` when left as ``None``.
    """

    def __init__(self, desc_mode='norm', two_confs=False, desc_conf_mode=None, **kwargs):
        # These attributes are assigned *before* the parent constructor runs
        # because set_downstream_head reads self.desc_conf_mode.
        # NOTE(review): this presumes the parent __init__ calls
        # set_downstream_head — confirm against mini_dust3r.
        self.desc_mode = desc_mode
        self.two_confs = two_confs
        self.desc_conf_mode = desc_conf_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load weights from a local checkpoint file, else defer to the parent loader.

        For a local file, ``load_model`` is called with ``device='cpu'`` and
        ``**kw`` is not forwarded. Any other value is handed to the parent
        class's ``from_pretrained`` together with ``**kw``.
        """
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device='cpu')
        return super().from_pretrained(pretrained_model_name_or_path, **kw)

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        """Create both prediction heads and their landscape wrappers.

        Stores ``output_mode``, ``head_type``, ``depth_mode`` and ``conf_mode``
        on the instance, builds ``downstream_head1`` / ``downstream_head2`` with
        ``head_factory`` (with a confidence channel iff ``conf_mode`` is
        truthy), and wraps them as ``head1`` / ``head2`` via
        ``transpose_to_landscape``.

        Raises:
            ValueError: If either dimension of ``img_size`` is not a multiple
                of ``patch_size``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if img_size[0] % patch_size != 0 or img_size[1] % patch_size != 0:
            raise ValueError(f'{img_size=} must be multiple of {patch_size=}')
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode
        # allocate heads
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # Wrap the heads with transpose_to_landscape; `activate` is
        # landscape_only.
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
# Load the MASt3R checkpoint once at import time; `predict` reuses this
# module-level model for every request.
model = AsymmetricMASt3R.from_pretrained(
    "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
)
model = model.to(DEVICE)
def create_blueprint(image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
    """Build the Rerun viewer layout for a reconstruction.

    The 3D view is rooted at ``log_path``.

    - With up to four images, a column of 2D views sits beside the 3D view,
      one per ``camera_<i>/pinhole/`` entity, in a 3:1 width split.
    - With more than four images, only the 3D view is shown so the layout
      does not get cluttered.

    Args:
        image_name_list: The input images; only the count is used.
        log_path: Root entity path the reconstruction is logged under.

    Returns:
        The blueprint to send to the viewer.
    """
    root = f"{log_path}"
    main_view = rrb.Spatial3DView(origin=root)

    if len(image_name_list) > 4:
        # Too many images for per-camera 2D panels: 3D view only.
        return rrb.Blueprint(
            rrb.Horizontal(main_view),
            collapse_panels=True,
        )

    camera_views = [
        rrb.Spatial2DView(
            origin=f"{root}/camera_{idx}/pinhole/",
            contents=["+ $origin/**"],
        )
        for idx, _ in enumerate(image_name_list)
    ]
    layout = rrb.Horizontal(
        contents=[main_view, rrb.Vertical(contents=camera_views)],
        column_shares=[3, 1],
    )
    return rrb.Blueprint(layout, collapse_panels=True)
@spaces.GPU
def predict(image_name_list: list[str] | str):
    """Reconstruct a scene from one or more images and save it as a Rerun recording.

    Args:
        image_name_list: A single image file path, or a list of image file
            paths, as delivered by the Gradio input components.

    Returns:
        POSIX path of the ``.rrd`` recording written for this request, which
        the ``Rerun`` output component displays.

    Raises:
        gr.Error: If no image was supplied, or the input is neither a string
            nor a list.
    """
    # Normalise the input to a list of paths. An empty Gradio input is
    # presumably delivered as None; treat it like an empty list.
    if isinstance(image_name_list, str):
        image_name_list = [image_name_list]
    elif image_name_list is None:
        image_name_list = []
    elif not isinstance(image_name_list, list):
        raise gr.Error(
            f"Input must be a list of strings or a string, got: {type(image_name_list)}"
        )
    if not image_name_list:
        raise gr.Error("Please provide at least one image.")

    # One fresh recording and one output file per request.
    uuid_str = str(uuid.uuid4())
    filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
    # Ensure the output directory exists before rr.save writes into it.
    filename.parent.mkdir(parents=True, exist_ok=True)
    rr.init(uuid_str)
    log_path = Path("world")

    optimized_results: OptimizedResult = inferece_dust3r(
        image_dir_or_list=image_name_list,
        model=model,
        device=DEVICE,
        batch_size=1,
    )

    blueprint: rrb.Blueprint = create_blueprint(image_name_list, log_path)
    rr.send_blueprint(blueprint)
    rr.set_time_sequence("sequence", 0)
    log_optimized_result(optimized_results, log_path)
    rr.save(filename.as_posix())
    return filename.as_posix()
# No stray ';' after the closing brace: outside a declaration block it is
# invalid CSS and would swallow the selector of any rule concatenated after it.
with gr.Blocks(
    css=""".gradio-container {margin: 0 !important; min-width: 100%}""",
    title="Mini-DUSt3R Demo",
) as demo:
    # Page header, then one tab for single-image and one for multi-image input.
    gr.HTML('<h2 style="text-align: center;">Mini-DUSt3R Demo</h2>')
    gr.HTML(
        '<p style="text-align: center;">Unofficial DUSt3R demo using the mini-dust3r pip package</p>'
    )
    gr.HTML(
        '<p style="text-align: center;">More info <a href="https://github.com/pablovela5620/mini-dust3r">here</a></p>'
    )
    with gr.Tab(label="Single Image"):
        with gr.Column():
            single_image = gr.Image(type="filepath", height=300)
            run_btn_single = gr.Button("Run")
            rerun_viewer_single = Rerun(height=900)
            run_btn_single.click(
                fn=predict, inputs=[single_image], outputs=[rerun_viewer_single]
            )

            # Bundled example images; results are cached lazily on first use.
            example_single_dir = Path("examples/single_image")
            example_single_files = sorted(example_single_dir.glob("*.png"))
            examples_single = gr.Examples(
                examples=example_single_files,
                inputs=[single_image],
                outputs=[rerun_viewer_single],
                fn=predict,
                cache_examples="lazy",
            )
    with gr.Tab(label="Multi Image"):
        with gr.Column():
            multi_files = gr.File(file_count="multiple")
            run_btn_multi = gr.Button("Run")
            rerun_viewer_multi = Rerun(height=900)
            run_btn_multi.click(
                fn=predict, inputs=[multi_files], outputs=[rerun_viewer_multi]
            )
demo.launch()