Spaces:
Build error
Build error
| import spaces | |
| import os | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from pipe.cfgs import load_cfg | |
| from pipe.c2f_recons import Pipeline | |
| from ops.gs.basic import Gaussian_Scene | |
| from datetime import datetime | |
# Shared, process-wide VistaDream pipeline, built once from the basic config.
# The request handlers in this file assign attributes on this single instance
# (scene, trajectory and MCS options), so concurrent requests would presumably
# interfere with each other.
_BASIC_CFG_PATH = 'pipe/cfgs/basic.yaml'
cfg = load_cfg(_BASIC_CFG_PATH)
vistadream = Pipeline(cfg)

from ops.visual_check import Check

# NOTE(review): this module-level checker is not referenced by the functions in
# this file view (render_video goes through ``vistadream.checkor``); confirm it
# is still needed before removing it.
checkor = Check()
def get_temp_path(root='data/gradio_temp'):
    """Create and return a fresh, timestamped output directory.

    The directory ``<root>/<YYYYmmdd_HHMMSS>/`` is created (together with any
    missing parents) before it is returned, so callers can write files into it
    immediately.

    Args:
        root: Parent directory for all run outputs. Defaults to
            ``'data/gradio_temp'``, which matches the original behaviour.

    Returns:
        The output directory path as a string ending in ``'/'``, so callers
        can append file names with ``+``.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = f"{root}/{timestamp}/"
    # Previously only ``root`` was created, so writing into the timestamped
    # sub-directory (e.g. torch.save(...)) failed with FileNotFoundError.
    # exist_ok=True also replaces the racy exists()-then-makedirs() check.
    os.makedirs(output_path, exist_ok=True)
    return output_path
def scene_generate(rgb, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps):
    """Configure the shared ``vistadream`` pipeline and build the coarse scene.

    Replaces ``vistadream.scene`` with a fresh ``Gaussian_Scene(cfg)``, sets the
    trajectory, outpainting and MCS-refinement options from the UI values, and
    then runs ``vistadream._coarse_scene(rgb)``.

    Args:
        rgb: Input image from the ``gr.Image(type="pil")`` component.
        num_coarse_views: Stored on the pipeline as ``n_sample``.
        num_mcs_views: Stored on the pipeline as ``mcs_n_view``.
        mcs_rect_w: Stored on the pipeline as ``mcs_rect_w``.
        mcs_steps: Stored on the pipeline as ``mcs_iterations``.
    """
    torch.cuda.init()
    # Gradio sliders hand their values to callbacks as floats (e.g. 10.0
    # rather than 10). The count-like options are presumably used as
    # view/iteration counts downstream, so coerce them to int, and the weight
    # to float.
    num_coarse_views = int(num_coarse_views)
    num_mcs_views = int(num_mcs_views)
    mcs_steps = int(mcs_steps)
    mcs_rect_w = float(mcs_rect_w)
    try:
        # Start every request from an empty scene on the shared pipeline.
        vistadream.scene = Gaussian_Scene(cfg)
        # Trajectory generation: spiral camera path, set on both the pipeline
        # and the scene.
        vistadream.traj_type = 'spiral'
        vistadream.scene.traj_type = 'spiral'
        vistadream.n_sample = num_coarse_views
        # Scene generation: fixed settings that are not exposed in the UI.
        vistadream.opt_iters_per_frame = 512
        vistadream.outpaint_extend_times = 0.45
        vistadream.outpaint_selections = ['Left', 'Right', 'Top', 'Bottom']
        # Scene refinement (MCS) options taken from the UI.
        vistadream.mcs_n_view = num_mcs_views
        vistadream.mcs_rect_w = mcs_rect_w
        vistadream.mcs_iterations = mcs_steps
        # Build the coarse scene.
        vistadream._coarse_scene(rgb)
    finally:
        # Release cached GPU memory even if coarse generation raises.
        torch.cuda.empty_cache()
def scene_refinement():
    """Run MCS refinement on the current scene and save it to a new run directory.

    Calls ``vistadream._MCS_Refinement()``, allocates a timestamped output
    directory via ``get_temp_path()``, and saves ``vistadream.scene`` there
    with ``torch.save``.

    Returns:
        The run directory (a string ending in ``'/'``) that contains
        ``scene.pth``.
    """
    vistadream._MCS_Refinement()
    output_path = get_temp_path()
    # get_temp_path() may only create the parent folder, so make sure the
    # timestamped directory exists before torch.save writes into it.
    os.makedirs(output_path, exist_ok=True)
    torch.cuda.empty_cache()
    torch.save(vistadream.scene, output_path + 'scene.pth')
    return output_path
def render_video(output_path):
    """Render videos of the current ``vistadream.scene`` into ``output_path``.

    Args:
        output_path: Run directory string ending in ``'/'``, as returned by
            ``scene_refinement``.

    Returns:
        A ``(rgb_file, dpt_file)`` tuple: ``output_path`` + ``video_rgb.mp4``
        and ``output_path`` + ``video_dpt.mp4``. This assumes
        ``_render_video`` writes files with these names into ``save_dir``;
        confirm against ``ops.visual_check``.
    """
    renderer = vistadream.checkor
    renderer._render_video(vistadream.scene, save_dir=f'{output_path}.')
    rgb_file = f'{output_path}video_rgb.mp4'
    dpt_file = f'{output_path}video_dpt.mp4'
    return rgb_file, dpt_file
# ZeroGPU: ``spaces`` is imported for this decorator, which requests a GPU for
# the duration of the call and is documented as a no-op outside ZeroGPU Spaces.
@spaces.GPU
def process(rgb, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps):
    """Gradio callback: build, refine and render a scene from one image.

    Args:
        rgb: Input PIL image from the ``gr.Image(type="pil")`` component, or
            ``None`` when nothing has been uploaded.
        num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps: Slider values,
            forwarded to ``scene_generate``.

    Returns:
        The ``(rgb_video_path, dpt_video_path)`` tuple from ``render_video``,
        matching the ``outputs=[rgb_video, dpt_video]`` wiring.

    Raises:
        gr.Error: If no input image was provided.
    """
    if rgb is None:
        # Show a readable message in the UI instead of a traceback.
        raise gr.Error("Please upload an input image before clicking Run.")
    scene_generate(rgb, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps)
    path = scene_refinement()
    # Keep a copy of the input next to the saved scene. This line previously
    # used an undefined name ``output_path`` (NameError); the run directory
    # returned by scene_refinement() is ``path``.
    rgb.save(path + 'input.png')
    return render_video(path)
# Gradio UI: input image and advanced sliders on the left, rendered RGB and DPT
# videos on the right. "Run" feeds the image and slider values into ``process``.
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## VistaDream")
        gr.Markdown("### Sampling multiview consistent images for single-view scene reconstruction")
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/WHU-USI3DV/VistaDream">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://vistadream-project-page.github.io/">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://arxiv.org/abs/2410.16892">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
</div>
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    num_coarse_views = gr.Slider(label="Coarse-Expand", minimum=5, maximum=25, value=10, step=1)
                    num_mcs_views = gr.Slider(label="MCS Optimization Views", minimum=4, maximum=10, value=8, step=1)
                    mcs_rect_w = gr.Slider(label="MCS Rectification Weight", minimum=0.3, maximum=0.8, value=0.7, step=0.1)
                    mcs_steps = gr.Slider(label="MCS Steps", minimum=8, maximum=15, value=10, step=1)
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # The first positional parameter of gr.Video is
                        # ``value`` (a video file/URL), not the label, so the
                        # caption must be passed as ``label=``.
                        rgb_video = gr.Video(label="Output RGB renderings")
                    with gr.Column():
                        dpt_video = gr.Video(label="Output DPT renderings")
                # NOTE(review): the example list is empty, so this widget has
                # no rows. Populate it with [image, rgb_video, dpt_video]
                # entries or remove it.
                examples = gr.Examples(
                    examples=[],
                    inputs=[input_image, rgb_video, dpt_video],
                )
    # Event wiring must happen inside the Blocks context.
    ips = [input_image, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps]
    run_button.click(fn=process, inputs=ips, outputs=[rgb_video, dpt_video])
demo.launch(server_name='0.0.0.0')