import einops
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
from rendering import estimate_surface_normal
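# Page text: DESCRIPTION is shown at the top of the demo (rendered as HTML),
# and RUN_LOCALLY appears in the bottom panel (rendered as Markdown).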
DESCRIPTION = """
<h1>LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models</h1>
<p>Kyushu University &middot; ICRA 2024</p>
<p>
This is a demo of our paper "LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models" presented at ICRA 2024.
We propose R2DM, a continuous-time diffusion model for LiDAR data generation based on the equirectangular range/reflectance image representation.
</p>
"""
RUN_LOCALLY = """
To run this demo locally:
```bash
git clone https://huggingface.co/spaces/kazuto1011/r2dm
cd r2dm
pip install -r requirements.txt
pip install gradio
gradio app.py
```
"""
THEME = gr.themes.Default(font=gr.themes.GoogleFont("Titillium Web"))
# Prefer CUDA, fall back to Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
torch.set_grad_enabled(False)  # inference only; no autograd needed
torch.backends.cudnn.benchmark = True  # auto-tune conv kernels for fixed input shapes
device = torch.device(device)
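# Pre-trained R2DM checkpoints fetched via torch.hub. Each entry is a
# (model, lidar_utils, _) tuple; models are kept on CPU and moved to the
# compute device only while generating.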
model_dict = {
    "KITTI Raw (64x512)": torch.hub.load(
        "kazuto1011/r2dm",
        "pretrained_r2dm",
        config="r2dm-h-kittiraw-300k",
        device="cpu",
        show_info=False,
    ),
    "KITTI-360 (64x1024)": torch.hub.load(
        "kazuto1011/r2dm",
        "pretrained_r2dm",
        config="r2dm-h-kitti360-300k",
        device="cpu",
        show_info=False,
    ),
}
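# For reference, a checkpoint can also be loaded outside this demo. A minimal
# sketch using the same torch.hub entry point and config names as above,
# assuming the entry point provides defaults for `device` and `show_info`:
#
#   model, lidar_utils, _ = torch.hub.load(
#       "kazuto1011/r2dm", "pretrained_r2dm", config="r2dm-h-kitti360-300k"
#   )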
def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo):
    """Map a tensor of values in [0, 1] to uint8 RGB images via a matplotlib colormap."""
    # Build a 256-entry RGB lookup table from the colormap.
    colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    colors = torch.from_numpy(colors).to(tensor)
    # Quantize values to table indices and look up their colors.
    tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
    ids = (tensor * 256).clamp(0, 255).long()
    tensor = F.embedding(ids, colors).permute(0, 3, 1, 2)
    tensor = tensor.mul(255).clamp(0, 255).byte()
    return tensor
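# Reverse diffusion: starting from Gaussian noise, the sampler walks the
# continuous time axis from t=1 down to t=0, denoising one step per iteration
# with either the stochastic (DDPM) or deterministic (DDIM) update.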
def generate(num_steps: int, sampling_mode: str, dataset: str, progress=gr.Progress()):
    # Move the selected model to the compute device for this request.
    model, lidar_utils, _ = model_dict[dataset]
    model.to(device)
    lidar_utils.to(device)

    # Sampling: iterate the reverse diffusion process over `num_steps` steps.
    num_steps = int(num_steps)
    x = model.randn(1, *model.sampling_shape, device=model.device)
    steps = torch.linspace(1.0, 0.0, num_steps + 1, device=model.device)[None]
    for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
        step_t = steps[:, i]
        step_s = steps[:, i + 1]
        x = model.p_step(x, step_t, step_s, mode=sampling_mode.lower())

    # Decode the generated image into metric depth, reflectance, and a point cloud.
    x = lidar_utils.denormalize(x.clamp(-1, 1))
    depth = lidar_utils.revert_depth(x[:, [0]])
    rflct = x[:, [1]]
    point = lidar_utils.to_xyz(depth)
    color = (-estimate_surface_normal(point) + 1) / 2  # surface normals mapped to [0, 1] RGB
    point = einops.rearrange(point, "1 c h w -> (h w) c").cpu().numpy()
    color = einops.rearrange(color, "1 c h w -> (h w) c").cpu().numpy()

    # Render the point cloud as an interactive 3D scatter plot.
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=-point[..., 0],
                y=-point[..., 1],
                z=point[..., 2],
                mode="markers",
                marker=dict(size=1, color=color),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(showticklabels=False, visible=False),
                yaxis=dict(showticklabels=False, visible=False),
                zaxis=dict(showticklabels=False, visible=False),
                aspectmode="data",
            ),
            margin=dict(l=0, r=0, b=0, t=0),
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
        ),
    )

    # Colorize the 2D range/reflectance views for display.
    depth = depth / lidar_utils.max_depth
    depth = colorize(depth, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
    rflct = colorize(rflct, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()

    # Release the device memory until the next request.
    model.cpu()
    lidar_utils.cpu()
    return depth, rflct, fig
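# Gradio UI: controls on the left; the generated range/reflectance images and
# the 3D point cloud on the right.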
with gr.Blocks(css="./style.css", theme=THEME) as demo:
    gr.HTML(DESCRIPTION)
    with gr.Row(variant="panel"):
        with gr.Column():
            gr.Textbox(str(device), label="Running device")
            dataset = gr.Dropdown(
                choices=list(model_dict.keys()),
                value=list(model_dict.keys())[0],
                label="Dataset",
            )
            sampling_mode = gr.Dropdown(
                choices=["DDPM", "DDIM"],
                value="DDPM",
                label="Sampler",
            )
            num_steps = gr.Dropdown(
                choices=[2**i for i in range(5, 11)],
                value=32,
                label="Number of sampling steps (>256 is recommended)",
            )
            btn = gr.Button(value="Generate")
        with gr.Column():
            range_view = gr.Image(type="numpy", label="Range image")
            rflct_view = gr.Image(type="numpy", label="Reflectance image")
            point_view = gr.Plot(label="Point cloud")
    with gr.Row(variant="panel"):
        gr.Markdown(RUN_LOCALLY)
    btn.click(
        generate,
        inputs=[num_steps, sampling_mode, dataset],
        outputs=[range_view, rflct_view, point_view],
        concurrency_limit=1,
    )
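# Queue incoming requests; combined with concurrency_limit=1 above, at most
# one generation runs at a time.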
demo.queue()
demo.launch()