"""Gradio demo for R2DM: sample a LiDAR scan and show range, reflectance and 3D views."""

import einops
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F

from rendering import estimate_surface_normal

DESCRIPTION = """
LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models
Kazuto Nakashima     Ryo Kurazume
Kyushu University
ICRA 2024
Project | Paper | Code

This is a demo of our paper "LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models" presented at ICRA 2024.
We propose R2DM, a continuous-time diffusion model for LiDAR data generation based on the equirectangular range/reflectance image representation.

"""

# Markdown fences must sit on their own lines to render as code blocks.
RUN_LOCALLY = """
To run this demo locally:
```bash
git clone https://huggingface.co/spaces/kazuto1011/r2dm
```
```bash
cd r2dm
```
```bash
pip install -r requirements.txt
```
```bash
pip install gradio
```
```bash
gradio app.py
```
"""

THEME = gr.themes.Default(font=gr.themes.GoogleFont("Titillium Web"))

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Inference only: disable autograd globally and let cuDNN autotune kernels.
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True
device = torch.device(device)

# Models are loaded on the CPU once at import time and moved to `device` only
# for the duration of a request (see `generate`).
# NOTE(review): each entry is unpacked as a 3-tuple in `generate`; presumably
# (model, lidar_utils, config) — confirm against the hub entry point.
model_dict = {
    "KITTI Raw (64x512)": torch.hub.load(
        "kazuto1011/r2dm",
        "pretrained_r2dm",
        config="r2dm-h-kittiraw-300k",
        device="cpu",
        show_info=False,
    ),
    "KITTI-360 (64x1024)": torch.hub.load(
        "kazuto1011/r2dm",
        "pretrained_r2dm",
        config="r2dm-h-kitti360-300k",
        device="cpu",
        show_info=False,
    ),
}


def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo):
    """Map a batch of scalar images to uint8 RGB images via a matplotlib colormap.

    Args:
        tensor: ``(B, 1, H, W)`` or ``(B, H, W)`` tensor. Values are expected in
            ``[0, 1]``; anything outside is clamped to the colormap's ends.
        cmap_fn: A matplotlib colormap callable (default ``cm.turbo``).

    Returns:
        ``(B, 3, H, W)`` uint8 tensor on the same device as ``tensor``.
    """
    # 256-entry RGB lookup table (alpha dropped), cast to the input's dtype/device.
    colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    colors = torch.from_numpy(colors).to(tensor)
    tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
    # Quantize to LUT indices; clamp so 1.0 (and above) maps to the last entry.
    ids = (tensor * 256).clamp(0, 255).long()
    # Embedding lookup gives (B, H, W, 3); move channels first.
    tensor = F.embedding(ids, colors).permute(0, 3, 1, 2)
    tensor = tensor.mul(255).clamp(0, 255).byte()
    return tensor


def _sample(model, num_steps: int, sampling_mode: str, progress) -> torch.Tensor:
    """Denoise from pure noise over ``num_steps`` uniform time steps from t=1 to t=0."""
    x = model.randn(1, *model.sampling_shape, device=model.device)
    # (1, num_steps + 1) schedule; step i goes from steps[:, i] to steps[:, i + 1].
    steps = torch.linspace(1.0, 0.0, num_steps + 1, device=model.device)[None]
    for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
        step_t = steps[:, i]
        step_s = steps[:, i + 1]
        x = model.p_step(x, step_t, step_s, mode=sampling_mode.lower())
    return x


def _point_cloud_figure(point: np.ndarray, color: np.ndarray) -> go.Figure:
    """Build a bare (axis-free, transparent) plotly 3D scatter of an ``(N, 3)`` cloud."""
    return go.Figure(
        data=[
            go.Scatter3d(
                # x and y are negated for display.
                # NOTE(review): presumably to match the viewer's default camera
                # orientation — confirm.
                x=-point[..., 0],
                y=-point[..., 1],
                z=point[..., 2],
                mode="markers",
                marker=dict(size=1, color=color),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(showticklabels=False, visible=False),
                yaxis=dict(showticklabels=False, visible=False),
                zaxis=dict(showticklabels=False, visible=False),
                aspectmode="data",
            ),
            margin=dict(l=0, r=0, b=0, t=0),
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
        ),
    )


def _render(x: torch.Tensor, lidar_utils):
    """Convert a sampled 2-channel image into (range RGB, reflectance RGB, 3D figure)."""
    x = lidar_utils.denormalize(x.clamp(-1, 1))
    # Channel 0 encodes depth, channel 1 reflectance.
    depth = lidar_utils.revert_depth(x[:, [0]])
    rflct = x[:, [1]]
    point = lidar_utils.to_xyz(depth)
    # Shade points by surface normal, mapped from [-1, 1] to [0, 1].
    # NOTE(review): assumes estimate_surface_normal returns unit vectors — confirm.
    color = (-estimate_surface_normal(point) + 1) / 2
    point = einops.rearrange(point, "1 c h w -> (h w) c").cpu().numpy()
    color = einops.rearrange(color, "1 c h w -> (h w) c").cpu().numpy()
    fig = _point_cloud_figure(point, color)

    depth = depth / lidar_utils.max_depth
    depth = colorize(depth, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
    rflct = colorize(rflct, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
    return depth, rflct, fig


def generate(num_steps: int, sampling_mode: str, dataset: str, progress=gr.Progress()):
    """Sample one LiDAR scan with the selected model and render it.

    Args:
        num_steps: Number of sampling steps (cast to ``int``).
        sampling_mode: Sampler name (e.g. ``"DDPM"``/``"DDIM"``); lower-cased
            before being passed to ``model.p_step``.
        dataset: Key into ``model_dict`` selecting the pretrained model.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        ``(range_image, reflectance_image, figure)``: two ``(H, W, 3)`` uint8
        numpy images and a plotly 3D point-cloud figure.
    """
    model, lidar_utils, _ = model_dict[dataset]
    model.to(device)
    lidar_utils.to(device)
    try:
        x = _sample(model, int(num_steps), sampling_mode, progress)
        return _render(x, lidar_utils)
    finally:
        # Always park the shared models back on the CPU — even when sampling or
        # rendering raises (e.g. out of memory) — so device memory is not held
        # between requests.
        model.cpu()
        lidar_utils.cpu()


with gr.Blocks(css="./style.css", theme=THEME) as demo:
    gr.HTML(DESCRIPTION)
    with gr.Row(variant="panel"):
        with gr.Column():
            gr.Textbox(str(device), label="Running device")
            dataset = gr.Dropdown(
                choices=list(model_dict.keys()),
                value=list(model_dict.keys())[0],
                label="Dataset",
            )
            sampling_mode = gr.Dropdown(
                choices=["DDPM", "DDIM"],
                value="DDPM",
                label="Sampler",
            )
            num_steps = gr.Dropdown(
                choices=[2**i for i in range(5, 11)],
                value=32,
                label="Number of sampling steps (>256 is recommended)",
            )
            btn = gr.Button(value="Generate")
        with gr.Column():
            range_view = gr.Image(type="numpy", label="Range image")
            rflct_view = gr.Image(type="numpy", label="Reflectance image")
            point_view = gr.Plot(label="Point cloud")

    with gr.Row(variant="panel"):
        gr.Markdown(RUN_LOCALLY)

    btn.click(
        generate,
        inputs=[num_steps, sampling_mode, dataset],
        outputs=[range_view, rflct_view, point_view],
    )

demo.queue()
demo.launch()