|
import sys |
|
import gradio as gr |
|
import os |
|
import shutil |
|
import json |
|
import argparse |
|
from PIL import Image |
|
import subprocess |
|
from sparseags.dust3r_utils import infer_dust3r |
|
from run import main |
|
import functools |
|
|
|
# Make the `dust3r` subdirectory the first import-search entry so that
# `dust3r.model` below is resolved from inside it.
# NOTE(review): this *replaces* sys.path[0] (normally the directory of the
# launched script) instead of prepending a new entry, so modules living next to
# this script are no longer importable after this line. The project imports
# above (`sparseags`, `run`) therefore have to stay before it — confirm this is
# intentional before reordering imports.
sys.path[0] = sys.path[0] + '/dust3r'

from dust3r.model import AsymmetricCroCo3DStereo
|
|
|
|
|
def info_fn():
    """Pop up a Gradio info toast once preprocessing has finished successfully."""
    message = "Data preprocessing done!"
    gr.Info(message)
|
|
|
|
|
def get_select_index(evt: gr.SelectData):
    """Load the bundled example the user clicked in the examples gallery.

    Updates the module-level ``args`` so that the next reconstruction runs on
    the chosen example category with the right number of views.

    Args:
        evt: Gradio selection event; ``evt.index`` is the clicked gallery slot.

    Returns:
        The example's image paths twice: once for the file input and once for
        the "Preprocessed images" gallery.
    """
    # Must stay in the same order as the example folders loaded in __main__,
    # because the gallery index is used to look up both.
    categories = ['toy', 'butter', 'robot', 'jordan', 'eagle']

    chosen = evt.index
    example_paths = examples_full[chosen][0]

    args.num_views = len(example_paths)
    args.category = categories[chosen]

    return example_paths, example_paths
|
|
|
|
|
|
|
def check_img_input(control_image):
    """Stop the event chain with a user-facing error when no input is present.

    Args:
        control_image: Current value of the file input; ``None`` means nothing
            has been uploaded or selected.

    Raises:
        gr.Error: if ``control_image`` is ``None``.
    """
    if control_image is not None:
        return
    raise gr.Error("Please select or upload an input image")
|
|
|
|
|
def preprocess(args, dust3r_model, image_block: list):
    """Prepare user-uploaded images as the ``custom`` demo category.

    Wipes any previous ``custom`` data/output. Copies every uploaded image into
    ``data/demo/custom`` and runs ``process.py`` on each copy. Estimates
    initial camera poses with DUSt3R from the ``source`` images and stores them
    in ``data/demo/custom/cameras.json``. Finally, points ``args`` at the new
    ``custom`` category.

    Args:
        args: Parsed CLI namespace; ``num_views`` and ``category`` are updated
            in place.
        dust3r_model: Loaded DUSt3R model, passed through to ``infer_dust3r``.
        image_block: File paths of the uploaded images (value of ``gr.File``).

    Returns:
        Paths of the ``<name>_rgba.png`` images under
        ``data/demo/custom/processed`` for the "Preprocessed images" gallery.

    Raises:
        gr.Error: if no images were provided or ``process.py`` exits with a
            non-zero status for one of them.
    """
    if not image_block:
        raise gr.Error("Please select or upload an input image")

    data_dir = 'data/demo/custom'
    source_dir = os.path.join(data_dir, 'source')
    processed_dir = os.path.join(data_dir, 'processed')

    # Start from a clean slate so files from a previous upload cannot leak into this run.
    for stale_dir in (data_dir, 'output/demo/custom'):
        if os.path.exists(stale_dir):
            shutil.rmtree(stale_dir)
    os.makedirs(source_dir)
    os.makedirs(processed_dir)

    file_names = []
    processed_image_block = []
    for file_path in image_block:
        file_name = os.path.basename(file_path)
        # Kept as split('.')[0] (not os.path.splitext) so the stem matches the
        # names expected from process.py's outputs — assumed; confirm in process.py.
        out_base = file_name.split('.')[0]
        staged_path = os.path.join(data_dir, file_name)

        # Context manager closes the source file instead of leaking one handle per upload.
        with Image.open(file_path) as img_pil:
            try:
                img_pil.save(staged_path)
            except OSError:
                # Some modes cannot be written in the target format (e.g. RGBA as JPEG);
                # fall back to RGB.
                img_pil.convert('RGB').save(staged_path)

        # Paths this function expects process.py to produce for the staged image.
        file_names.append(os.path.join(source_dir, out_base + '.png'))
        processed_image_block.append(os.path.join(processed_dir, out_base + '_rgba.png'))

        # Argument list without a shell: the uploaded file name is user-controlled and
        # must not be interpreted by /bin/sh. sys.executable reuses this interpreter
        # (and its environment) rather than whatever `python` is first on PATH.
        cmd = [sys.executable, 'process.py', staged_path]
        print(' '.join(cmd))
        result = subprocess.run(cmd)
        if result.returncode != 0:
            raise gr.Error(
                f"Preprocessing failed for {file_name} "
                f"(process.py exited with code {result.returncode})"
            )

    camera_data = infer_dust3r(dust3r_model, file_names)
    with open(os.path.join(data_dir, 'cameras.json'), "w") as f:
        json.dump(camera_data, f)

    args.num_views = len(file_names)
    args.category = "custom"

    return processed_image_block
|
|
|
|
|
def run_single_reconstruction(image_block: list):
    """Run one reconstruction pass without the outlier-handling loop.

    Args:
        image_block: Current file-input value. It is unused here and is only
            wired as an input so the button is tied to the current selection;
            the data to use is taken from the module-level ``args``.

    Returns:
        Path of the GLB mesh that ``main`` is expected to write for round 0.
    """
    args.enable_loop = False
    main(args)

    category = args.category
    return f'output/demo/{category}/round_0/{category}.glb'
|
|
|
|
|
def run_full_reconstruction(image_block: list):
    """Run the full method, including outlier detection and correction.

    Args:
        image_block: Current file-input value. It is unused here; the data to
            use is taken from the module-level ``args``.

    Returns:
        Path of the GLB mesh to display. Which stage's mesh is returned depends
        on which ``cameras_final_*.json`` file exists after ``main`` finishes,
        checked in priority order. It falls back to the round-1 mesh.
    """
    args.enable_loop = True
    main(args)

    root = f'output/demo/{args.category}'
    mesh_name = f'{args.category}.glb'

    # (marker file written by main, sub-directory holding the matching mesh);
    # presumably the markers record which pose set the loop settled on — confirm in run.py.
    stages = (
        ('cameras_final_recovered.json', 'check_recovered_poses'),
        ('cameras_final_init.json', 'reconsider_init_poses'),
    )
    for marker, subdir in stages:
        if os.path.exists(f'{root}/{marker}'):
            return f'{root}/{subdir}/{mesh_name}'

    return f'{root}/round_1/{mesh_name}'
|
|
|
|
|
if __name__ == "__main__":
    # `args` and `examples_full` are module-level globals on purpose: the Gradio
    # callbacks (get_select_index, run_*_reconstruction) read and mutate them.
    parser = argparse.ArgumentParser()
    parser.add_argument('--output', default='output/demo', type=str, help='Directory where obj files will be saved')
    parser.add_argument('--category', default='jordan', type=str, help='Name of the demo category (subfolder of data/demo) to reconstruct')
    parser.add_argument('--num_pts', default=25000, type=int, help='Number of points at initialization')
    parser.add_argument('--num_views', default=8, type=int, help='Number of input images')
    parser.add_argument('--mesh_format', default='glb', type=str, help='Format of output mesh')
    # NOTE(review): no `type`, so any value given on the command line arrives as a
    # (truthy) string. The reconstruction callbacks overwrite args.enable_loop
    # before every run, so in this demo it only sets the initial value.
    parser.add_argument('--enable_loop', default=True, help='Enable the loop-based strategy to detect and correct outliers')
    parser.add_argument('--config', default='navi.yaml', type=str, help='Path to config file')
    args = parser.parse_args()

    _TITLE = '''Sparse-view Pose Estimation and Reconstruction via Analysis by Generative Synthesis'''

    # NOTE(review): the paper badge is labelled "2309.16653" while linking to OpenReview,
    # and the stars badge points at dreamgaussian/dreamgaussian; both look copied from
    # another project — confirm the intended links/labels.
    _DESCRIPTION = '''
    <div>
    <a style="display:inline-block" href="https://qitaozhao.github.io/SparseAGS"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
    <a style="display:inline-block; margin-left: .5em" href="https://openreview.net/pdf?id=wgpmDyJgsg"><img src="https://img.shields.io/badge/2309.16653-f9f7f7?logo="></a>
    <a style="display:inline-block; margin-left: .5em" href='https://github.com/dreamgaussian/dreamgaussian'><img src='https://img.shields.io/github/stars/dreamgaussian/dreamgaussian?style=social'/></a>
    </div>
    Given a set of unposed input images, SparseAGS jointly infers the corresponding camera poses and underlying 3D, allowing high-fidelity 3D inference in the wild.
    '''

    _IMG_USER_GUIDE = (
        "Once you see the preprocessed images, you can click **Run Single 3D Reconstruction**. "
        "If the reconstructed 3D looks bad, you can try to click **Outlier Removal & Correction** "
        "to run the full method to deal with outlier camera poses."
    )

    # Bundled examples live in data/demo/<category>/processed next to this script.
    # The order must match the category list in get_select_index, which maps a
    # gallery index back to its category.
    examples_full = []
    for example in ['toy', 'butter', 'robot', 'jordan', 'eagle']:
        example_folder = os.path.join(os.path.dirname(__file__), 'data/demo', example, 'processed')
        example_fns = sorted(os.listdir(example_folder))
        examples = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')]
        examples_full.append([examples])

    dust3r_model = AsymmetricCroCo3DStereo.from_pretrained('naver/DUSt3R_ViTLarge_BaseDecoder_224_linear').to('cuda')
    print("Loaded DUSt3R model!")

    # Bind the CLI args and the loaded model so Gradio only supplies the uploaded files.
    # A separate name keeps `preprocess` referring to the undecorated function.
    preprocess_fn = functools.partial(preprocess, args, dust3r_model)

    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft()) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
                gr.Markdown(_DESCRIPTION)

        with gr.Row(variant='panel'):
            with gr.Column(scale=5):
                image_block = gr.File(file_count="multiple")
                preprocess_btn = gr.Button("Preprocess Images")

                gr.Markdown(
                    "You have two options to run our model! (1) Upload your own images in the block above and then click **Preprocess Images** to initialize camera poses using "
                    "DUSt3R; (2) Choose one of the preprocessed examples below (no need to click **Preprocess Images**).")

                # Gradio treats a string height as a CSS value, so a bare "256" has no
                # unit and is invalid; an int is interpreted as pixels.
                gallery = gr.Gallery(
                    value=[example[0][0] for example in examples_full], label="Examples", show_label=True, elem_id="gallery",
                    columns=[5], rows=[1], object_fit="contain", height=256, preview=None, allow_preview=None)

                # Distinct elem_id: two components sharing "gallery" would produce duplicate DOM ids.
                preprocessed_data = gr.Gallery(
                    label="Preprocessed images", show_label=True, elem_id="preprocessed_gallery",
                    columns=[4], rows=[2], object_fit="contain", height=256, preview=None, allow_preview=None)

                with gr.Row(variant='panel'):
                    run_single_btn = gr.Button("Run Single 3D Reconstruction")
                    outlier_detect_btn = gr.Button("Outlier Removal & Correction")

                img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

            with gr.Column(scale=5):
                obj_single_recon = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model (Single Reconstruction)")
                obj_outlier_detect = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model (Full Method, w/ Outlier Removal & Correction)")

        # Picking an example fills both the file input and the preprocessed gallery.
        gallery.select(get_select_index, None, outputs=[image_block, preprocessed_data])

        preprocess_btn.click(preprocess_fn, inputs=[image_block], outputs=[preprocessed_data], queue=False, show_progress='full').success(info_fn, None, None)

        # Validate the input cheaply (unqueued) before starting the long queued job.
        run_single_btn.click(check_img_input, inputs=[image_block], queue=False).success(
            run_single_reconstruction,
            inputs=[image_block],
            outputs=[obj_single_recon])

        outlier_detect_btn.click(check_img_input, inputs=[image_block], queue=False).success(
            run_full_reconstruction,
            inputs=[image_block],
            outputs=[obj_outlier_detect])

    demo.queue().launch(share=True)