"""Gradio demo for R2DM.

Loads the pretrained R2DM model through ``torch.hub``, runs the sampling loop
for a user-selected number of steps, and displays the generated sample as a
range image, a reflectance image and an interactive 3D point cloud.
"""

import einops
import gradio as gr
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F


def _select_device():
    """Return the torch device to run on: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# The demo only runs inference, so gradients are switched off globally.
torch.set_grad_enabled(False)
device = _select_device()

# The hub entry point returns three objects; only the first two are used here.
ddpm, lidar_utils, _ = torch.hub.load(
    "kazuto1011/r2dm",
    "pretrained_r2dm",
    device=device,
)


def colorize(tensor, cmap_fn=cm.turbo):
    """Convert a batch of single-channel maps into 8-bit RGB images.

    Args:
        tensor: Tensor of shape (B, H, W) or (B, 1, H, W). Values are expected
            in [0, 1]; anything outside that range saturates at the first or
            last colormap entry.
        cmap_fn: Callable that maps an array of values in [0, 1] to colors
            whose first three columns are used as RGB (e.g. a matplotlib
            colormap).

    Returns:
        A ``uint8`` tensor of shape (B, 3, H, W).
    """
    # Build a 256-entry RGB lookup table, dropping any columns after the third.
    colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    # ``.to(tensor)`` matches the input's dtype and device.
    colors = torch.from_numpy(colors).to(tensor)
    tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
    # Quantize to table indices 0..255.
    ids = (tensor * 256).clamp(0, 255).long()
    # Table lookup yields (B, H, W, 3); move the color axis to channel-first.
    tensor = F.embedding(ids, colors).permute(0, 3, 1, 2)
    tensor = tensor.mul(255).clamp(0, 255).byte()
    return tensor


def render_point_cloud(output, cmap):
    """Turn one generated sample into two images and a 3D scatter plot.

    Args:
        output: Model output with batch size 1 (the rearrange pattern below
            requires it). Channel 0 is treated as depth and channel 1 as
            reflectance. Values are clamped to [-1, 1] before denormalizing.
        cmap: Colormap callable forwarded to ``colorize`` for both images.

    Returns:
        A tuple ``(depth, rflct, fig)``: two (H, W, 3) ``uint8`` NumPy arrays
        and a ``plotly.graph_objects.Figure`` of the point cloud.
    """
    output = lidar_utils.denormalize(output.clamp(-1, 1))
    depth = lidar_utils.revert_depth(output[:, [0]])
    rflct = output[:, [1]]
    point = lidar_utils.to_xyz(depth).cpu().numpy()
    point = einops.rearrange(point, "1 c h w -> c (h w)")
    # Only needed by the disabled hover text below.
    # angle = lidar_utils.ray_angles.rad2deg()
    fig = go.Figure(
        data=[
            go.Scatter3d(
                # x and y are negated for display.
                x=-point[0],
                y=-point[1],
                z=point[2],
                mode="markers",
                marker=dict(
                    size=1,
                    # Color by height with a fixed range so the colors are
                    # comparable between samples.
                    color=point[2],
                    colorscale="viridis",
                    autocolorscale=False,
                    cauto=False,
                    cmin=-2,
                    cmax=0.5,
                ),
                # Per-point hover text is currently disabled.
                # NOTE(review): the first three strings below ended with a
                # line-break marker that is not recoverable from the reviewed
                # text; "<br>" (Plotly's hover-text line break) is assumed -
                # confirm before re-enabling.
                # text=[
                #     f"depth: {float(d):.2f}m<br>"
                #     + f"reflectance: {float(r):.2f}<br>"
                #     + f"elevation: {float(e):.2f}°<br>"
                #     + f"azimuth: {float(a):.2f}°"
                #     for d, r, e, a in zip(
                #         einops.rearrange(depth, "1 1 h w -> (h w)"),
                #         einops.rearrange(rflct, "1 1 h w -> (h w)"),
                #         einops.rearrange(angle[0, 0], "h w -> (h w)"),
                #         einops.rearrange(angle[0, 1], "h w -> (h w)"),
                #     )
                # ],
                # hoverinfo="text",
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(showticklabels=False, visible=False),
                yaxis=dict(showticklabels=False, visible=False),
                zaxis=dict(showticklabels=False, visible=False),
                aspectmode="data",
            ),
            margin=dict(l=0, r=0, b=0, t=0),
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
        ),
    )
    # Scale depth by the maximum depth before the colormap lookup.
    depth = depth / lidar_utils.max_depth
    # Take the single batch item and convert channel-first -> (H, W, 3).
    depth = colorize(depth, cmap)[0].permute(1, 2, 0).cpu().numpy()
    rflct = colorize(rflct, cmap)[0].permute(1, 2, 0).cpu().numpy()
    return depth, rflct, fig


def generate(num_steps, cmap_name, progress=gr.Progress()):
    """Sample one LiDAR scan and render it for the UI.

    Args:
        num_steps: Number of sampling steps; converted with ``int``.
        cmap_name: Name of a registered matplotlib colormap.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            how Gradio recognizes that the handler wants progress reporting,
            so it must stay in the signature.

    Returns:
        The ``(depth, rflct, fig)`` tuple produced by ``render_point_cloud``.
    """
    num_steps = int(num_steps)
    x = ddpm.randn(1, *ddpm.sampling_shape, device=ddpm.device)
    # num_steps + 1 evenly spaced time values from 1.0 down to 0.0, with a
    # leading batch axis of size 1.
    steps = torch.linspace(1.0, 0.0, num_steps + 1, device=ddpm.device)[None]
    for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
        # Each iteration moves the sample from time step_t to time step_s.
        step_t = steps[:, i]
        step_s = steps[:, i + 1]
        x = ddpm.p_step(x, step_t, step_s)
    return render_point_cloud(x, plt.colormaps.get_cmap(cmap_name))


with gr.Blocks() as demo:
    # NOTE(review): the newlines inside the blockquote below are soft breaks
    # in Markdown. If the title, authors, venue and links are meant to appear
    # on separate lines, explicit hard breaks are needed - confirm.
    gr.Markdown(
        """
        # R2DM

        > **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**
        Kazuto Nakashima, Ryo Kurazume
        ICRA 2024
        [[Project]](https://kazuto1011.github.io/r2dm/) [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)

        R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
        """
    )
    with gr.Row():
        # Left column: controls.
        with gr.Column():
            # Pass the device name as a string, the value type Textbox
            # documents; the displayed text is the same.
            gr.Textbox(str(device), label="Device")
            num_steps = gr.Dropdown(
                choices=[2**i for i in range(2, 10)],
                value=16,
                label="number of sampling steps (>256 is recommended)",
            )
            cmap_name = gr.Dropdown(
                choices=plt.colormaps(),
                value="turbo",
                label="colormap for range/reflectance images",
            )
            btn = gr.Button(value="Generate random samples")
        # Right column: outputs.
        with gr.Column():
            range_view = gr.Image(
                type="numpy",
                image_mode="RGB",
                label="Range image",
                scale=1,
            )
            rflct_view = gr.Image(
                type="numpy",
                image_mode="RGB",
                label="Reflectance image",
                scale=1,
            )
            point_view = gr.Plot(
                label="Point cloud",
                scale=1,
            )
    btn.click(
        generate,
        inputs=[num_steps, cmap_name],
        outputs=[range_view, rflct_view, point_view],
    )

demo.queue()
demo.launch()