import einops
import gradio as gr
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
# This demo only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
# Load the pretrained R2DM sampler and its LiDAR helper object from torch.hub
# (may fetch the repo/weights on first use); the third return value is unused.
ddpm, lidar_utils, _ = torch.hub.load(
    "kazuto1011/r2dm",
    "pretrained_r2dm",
    device=device,
)
def colorize(tensor, cmap_fn=cm.turbo):
    """Convert a scalar image batch into 8-bit RGB using a 256-entry colormap LUT.

    Args:
        tensor: scalar images shaped (B, H, W) or (B, 1, H, W). Values are
            expected in [0, 1]; anything outside maps to the first/last LUT entry.
        cmap_fn: matplotlib-style colormap callable; only its RGB columns are used.

    Returns:
        uint8 tensor of shape (B, 3, H, W) on the same device as ``tensor``.
    """
    # 256 RGB rows sampled evenly from the colormap, cast to tensor's dtype/device.
    lut = torch.from_numpy(cmap_fn(np.linspace(0, 1, 256))[:, :3]).to(tensor)
    if tensor.ndim == 4:
        tensor = tensor.squeeze(1)  # drop the singleton channel axis
    indices = tensor.mul(256).clamp(0, 255).long()
    # Embedding lookup acts as a per-pixel LUT: (B, H, W) -> (B, H, W, 3).
    rgb = F.embedding(indices, lut).permute(0, 3, 1, 2)
    return rgb.mul(255).clamp(0, 255).byte()
def render_point_cloud(output, cmap):
    """Render a normalized model sample as range/reflectance images and a 3D plot.

    Args:
        output: normalized model output; channel 0 is treated as depth and
            channel 1 as reflectance. The batch size must be 1 (the
            ``"1 c h w"`` rearrange below enforces this).
        cmap: matplotlib colormap callable used to colorize both 2D images.

    Returns:
        (depth_image, rflct_image, fig): two (H, W, 3) uint8 numpy arrays and a
        plotly ``Figure`` holding the point cloud.
    """
    output = lidar_utils.denormalize(output.clamp(-1, 1))
    depth = lidar_utils.revert_depth(output[:, [0]])
    rflct = output[:, [1]]

    # Project depth to 3D; x, y, z are read from channels 0-2 after flattening
    # the spatial grid into one point list.
    point = lidar_utils.to_xyz(depth).cpu().numpy()
    point = einops.rearrange(point, "1 c h w -> c (h w)")

    # TODO: per-point hover text (depth / reflectance / elevation / azimuth) was
    # prototyped here but is disabled. The earlier commented-out draft was
    # garbled, and its stray quote lines broke parsing of this module.
    fig = go.Figure(
        data=[
            go.Scatter3d(
                # x and y are negated, presumably to orient the default camera view.
                x=-point[0],
                y=-point[1],
                z=point[2],
                mode="markers",
                marker=dict(
                    size=1,
                    # Color points by height with a fixed range so colors are
                    # comparable across samples.
                    color=point[2],
                    colorscale="viridis",
                    autocolorscale=False,
                    cauto=False,
                    cmin=-2,
                    cmax=0.5,
                ),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(showticklabels=False, visible=False),
                yaxis=dict(showticklabels=False, visible=False),
                zaxis=dict(showticklabels=False, visible=False),
                aspectmode="data",
            ),
            margin=dict(l=0, r=0, b=0, t=0),
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
        ),
    )

    # Scale depth into [0, 1] for the colormap, then convert (1, 3, H, W) uint8
    # tensors into (H, W, 3) numpy images for gr.Image.
    depth = depth / lidar_utils.max_depth
    depth = colorize(depth, cmap)[0].permute(1, 2, 0).cpu().numpy()
    rflct = colorize(rflct, cmap)[0].permute(1, 2, 0).cpu().numpy()
    return depth, rflct, fig
def generate(num_steps, cmap_name, progress=gr.Progress()):
    """Sample one LiDAR scan from the diffusion model and render it.

    Args:
        num_steps: number of reverse-diffusion steps (coerced with ``int``).
        cmap_name: name of a registered matplotlib colormap.
        progress: Gradio progress tracker, injected by Gradio at call time.

    Returns:
        Whatever ``render_point_cloud`` returns: range image, reflectance image,
        and a plotly figure.
    """
    n = int(num_steps)
    x = ddpm.randn(1, *ddpm.sampling_shape, device=ddpm.device)
    # Time grid from 1.0 down to 0.0 with n + 1 knots. The extra leading axis
    # makes each [:, k] slice a shape-(1,) tensor, matching the batch of one.
    timesteps = torch.linspace(1.0, 0.0, n + 1, device=ddpm.device)[None]
    for k in progress.tqdm(range(n), desc="Generating LiDAR data"):
        x = ddpm.p_step(x, timesteps[:, k], timesteps[:, k + 1])
    colormap = plt.colormaps.get_cmap(cmap_name)
    return render_point_cloud(x, colormap)
with gr.Blocks() as demo:
    gr.Markdown(
        """
# R2DM
> **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**
Kazuto Nakashima, Ryo Kurazume
ICRA 2024
[[Project]](https://kazuto1011.github.io/r2dm/) [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
"""
    )
    with gr.Row():
        # Left column: runtime info and sampling controls.
        with gr.Column():
            gr.Textbox(device, label="Device")
            num_steps = gr.Dropdown(
                label="number of sampling steps (>256 is recommended)",
                choices=[1 << k for k in range(2, 10)],  # 4, 8, ..., 512
                value=16,
            )
            cmap_name = gr.Dropdown(
                label="colormap for range/reflectance images",
                choices=plt.colormaps(),
                value="turbo",
            )
            btn = gr.Button(value="Generate random samples")
        # Right column: generated outputs.
        with gr.Column():
            image_opts = dict(type="numpy", image_mode="RGB", scale=1)
            range_view = gr.Image(label="Range image", **image_opts)
            rflct_view = gr.Image(label="Reflectance image", **image_opts)
            point_view = gr.Plot(label="Point cloud", scale=1)
    # Each click runs one full sampling pass and refreshes all three outputs.
    btn.click(
        generate,
        inputs=[num_steps, cmap_name],
        outputs=[range_view, rflct_view, point_view],
    )

demo.queue()
demo.launch()