import os
import time
import shutil
import argparse
import functools

import torch
import torchvision
from PIL import Image
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import trimesh

from diffusionsfm.dataset.custom import CustomDataset
from diffusionsfm.dataset.co3d_v2 import unnormalize_image
from diffusionsfm.inference.load_model import load_model
from diffusionsfm.inference.predict import predict_cameras
from diffusionsfm.utils.visualization import add_scene_cam


def info_fn():
    gr.Info("Data preprocessing completed!")


def get_select_index(evt: gr.SelectData):
    selected = evt.index
    return examples_full[selected][0], selected


def check_img_input(control_image):
    if control_image is None:
        raise gr.Error("Please select or upload an input image.")


def preprocess(args, image_block, selected):
    cate_name = time.strftime("%m%d_%H%M%S") if selected is None else examples_list[selected]

    demo_dir = os.path.join(args.output_dir, f'demo/{cate_name}')
    shutil.rmtree(demo_dir, ignore_errors=True)
    os.makedirs(os.path.join(demo_dir, 'source'), exist_ok=True)
    os.makedirs(os.path.join(demo_dir, 'processed'), exist_ok=True)

    dataset = CustomDataset(image_block)
    batch = dataset.get_data()
    batch['cate_name'] = cate_name

    processed_image_block = []
    for i, file_path in enumerate(image_block):
        file_name = os.path.basename(file_path)
        raw_img = Image.open(file_path)
        try:
            raw_img.save(os.path.join(demo_dir, 'source', file_name))
        except OSError:
            # Some modes (e.g. palette or RGBA) cannot be written to every
            # format, so fall back to an RGB copy.
            raw_img.convert('RGB').save(os.path.join(demo_dir, 'source', file_name))
        batch['image_for_vis'][i].save(os.path.join(demo_dir, 'processed', file_name))
        processed_image_block.append(os.path.join(demo_dir, 'processed', file_name))

    return processed_image_block, batch


def transform_cameras(pred_cameras):
    num_cameras = pred_cameras.R.shape[0]
    Rs = pred_cameras.R.transpose(1, 2).detach()
    ts = pred_cameras.T.unsqueeze(-1).detach()

    # Assemble world-to-camera matrices, flip the x/y axes to go from the
    # PyTorch3D to the OpenCV convention, then invert to get camera-to-world.
    c2ws = torch.zeros(num_cameras, 4, 4)
    c2ws[:, :3, :3] = Rs
    c2ws[:, :3, -1:] = ts
    c2ws[:, 3, 3] = 1
    c2ws[:, :2] *= -1  # PyTorch3D to OpenCV
    c2ws = torch.linalg.inv(c2ws).numpy()

    return c2ws


def run_inference(args, cfg, model, batch):
    device = args.device
    images = batch["image"].to(device)
    crop_parameters = batch["crop_parameters"].to(device)

    (pred_cameras, pred_rays), _ = predict_cameras(
        model=model,
        images=images,
        device=device,
        crop_parameters=crop_parameters,
        stop_iteration=90,
        num_patches_x=cfg.training.full_num_patches_x,
        num_patches_y=cfg.training.full_num_patches_y,
        calculate_intrinsics=True,
        max_num_images=8,
        mode="segment",
        return_rays=True,
        use_homogeneous=True,
        seed=0,
    )

    # Unnormalize and resize input images
    images = unnormalize_image(images, return_numpy=False, return_int=False)
    images = torchvision.transforms.Resize(256)(images)
    rgbs = images.permute(0, 2, 3, 1).contiguous().view(-1, 3)
    xyzs = pred_rays.get_segments().view(-1, 3).cpu()

    # Create point cloud and scene
    scene = trimesh.Scene()
    point_cloud = trimesh.points.PointCloud(xyzs, colors=rgbs)
    scene.add_geometry(point_cloud)

    # Add predicted cameras to the scene
    num_images = images.shape[0]
    c2ws = transform_cameras(pred_cameras)
    cmap = plt.get_cmap("hsv")
    for i, c2w in enumerate(c2ws):
        color_rgb = (np.array(cmap(i / num_images))[:3] * 255).astype(int)
        add_scene_cam(
            scene=scene,
            c2w=c2w,
            edge_color=color_rgb,
            image=None,
            focal=None,
            imsize=(256, 256),
            screen_width=0.1,
        )

    # Export GLB
    cate_name = batch['cate_name']
    output_path = os.path.join(args.output_dir, f'demo/{cate_name}/{cate_name}.glb')
    scene.export(output_path)

    return output_path
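
# Illustrative sketch, not wired into the demo: a minimal sanity check for the
# PyTorch3D-to-OpenCV conversion in transform_cameras. The helper name and the
# local pytorch3d import are assumptions for illustration (pytorch3d is already
# a dependency of the inference code); an identity PyTorch3D camera should map
# to a camera-to-world pose whose center sits at the world origin.
def _example_identity_camera():
    from pytorch3d.renderer import PerspectiveCameras

    cameras = PerspectiveCameras(R=torch.eye(3)[None], T=torch.zeros(1, 3))
    c2ws = transform_cameras(cameras)
    assert np.allclose(c2ws[0, :3, 3], 0.0, atol=1e-6)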

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', default='output/multi_diffusionsfm_dense', type=str, help='Output directory')
    parser.add_argument('--device', default='cuda', type=str, help='Device to run inference on')
    args = parser.parse_args()

    _TITLE = "DiffusionSfM: Predicting Structure and Motion via Ray Origin and Endpoint Diffusion"
    _DESCRIPTION = """
DiffusionSfM learns to predict scene geometry and camera poses as pixel-wise ray origins and endpoints using a denoising diffusion model.
"""
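    # Example invocation (illustrative; "app.py" stands in for whatever name
    # this script is saved under, and the pretrained checkpoint is assumed to
    # live under the default --output_dir):
    #
    #   python app.py --output_dir output/multi_diffusionsfm_dense --device cuda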
    # Load demo examples
    examples_list = ["kew_gardens_ruined_arch", "jellycat", "kotor_cathedral", "jordan"]
    examples_full = []
    for example in examples_list:
        folder = os.path.join(os.path.dirname(__file__), "data/demo", example)
        examples = sorted(os.path.join(folder, x) for x in os.listdir(folder))
        examples_full.append([examples])

    model, cfg = load_model(args.output_dir, device=args.device)
    print("Loaded DiffusionSfM model!")

    # Bind the fixed arguments so the Gradio callbacks only receive UI inputs.
    preprocess = functools.partial(preprocess, args)
    run_inference = functools.partial(run_inference, args, cfg, model)

    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {_TITLE}")
        gr.Markdown(_DESCRIPTION)

        with gr.Row(variant='panel'):
            with gr.Column(scale=2):
                image_block = gr.Files(file_count="multiple", label="Upload Images")
                gr.Markdown(
                    "You can run our model by either: (1) **Uploading images** above "
                    "or (2) selecting a **pre-collected example** below."
                )
                gallery = gr.Gallery(
                    value=[example[0][0] for example in examples_full],
                    label="Examples",
                    show_label=True,
                    columns=[4],
                    rows=[1],
                    object_fit="contain",
                    height="256",
                )
                selected = gr.State()
                batch = gr.State()
                preprocessed_data = gr.Gallery(
                    label="Preprocessed Images",
                    show_label=True,
                    columns=[4],
                    rows=[1],
                    object_fit="contain",
                    height="256",
                )
                with gr.Row(variant='panel'):
                    run_inference_btn = gr.Button("Run Inference")

            with gr.Column(scale=4):
                output_3D = gr.Model3D(
                    clear_color=[0.0, 0.0, 0.0, 0.0],
                    height=520,
                    zoom_speed=0.5,
                    pan_speed=0.5,
                    label="3D Point Clouds and Recovered Cameras",
                )

        # Link image gallery selection
        gallery.select(
            fn=get_select_index,
            inputs=None,
            outputs=[image_block, selected],
        ).success(
            fn=preprocess,
            inputs=[image_block, selected],
            outputs=[preprocessed_data, batch],
            queue=False,
            show_progress="full",
        )

        # Handle user uploads
        image_block.upload(
            preprocess,
            inputs=[image_block],
            outputs=[preprocessed_data, batch],
            queue=False,
            show_progress="full",
        ).success(info_fn, None, None)

        # Run 3D reconstruction
        run_inference_btn.click(
            check_img_input,
            inputs=[image_block],
            queue=False,
        ).success(
            run_inference,
            inputs=[batch],
            outputs=[output_3D],
        )

    demo.queue().launch(share=True)
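
# Offline inspection of an exported reconstruction (illustrative sketch; the
# path follows the f'demo/{cate_name}/{cate_name}.glb' pattern used by
# run_inference, with "jellycat" being one of the bundled demo categories):
#
#   import trimesh
#   scene = trimesh.load("output/multi_diffusionsfm_dense/demo/jellycat/jellycat.glb")
#   scene.show()  # requires a viewer backend such as pyglet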