r2dm / app.py
Kazuto Nakashima
update
365f709
raw
history blame
No virus
5.22 kB
import einops
import gradio as gr
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
# Choose the compute backend in order of preference: CUDA, then MPS, then CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# The app only runs sampling, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Fetch the pretrained model and its LiDAR helper object via torch.hub;
# the third returned value is not used by this script.
ddpm, lidar_utils, _ = torch.hub.load(
    "kazuto1011/r2dm",
    "pretrained_r2dm",
    device=device,
)
def colorize(tensor, cmap_fn=cm.turbo):
    """Map a scalar image tensor to an RGB uint8 tensor through a colormap.

    Args:
        tensor: Scalar values to colorize. A 4-D input has its dim 1 squeezed
            (presumably shaped (B, 1, H, W)); otherwise it is used as-is and
            must be 3-D for the final permute to work. Values are scaled by
            256 and clamped to the index range [0, 255], so inputs are
            effectively expected in [0, 1].
        cmap_fn: Callable that maps an array of values in [0, 1] to RGBA (or
            RGB) rows; only the first three columns are used.

    Returns:
        A uint8 tensor with the color channel moved to dim 1.
    """
    # 256-entry RGB lookup table, converted to the input's dtype and device.
    palette = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    palette = torch.from_numpy(palette).to(tensor)

    if tensor.ndim == 4:
        tensor = tensor.squeeze(1)

    # Quantize to table indices, look up colors, then go channels-first.
    bins = (tensor * 256).clamp(0, 255).long()
    rgb = F.embedding(bins, palette).permute(0, 3, 1, 2)
    return rgb.mul(255).clamp(0, 255).byte()
def render_point_cloud(output, cmap):
    """Convert one generated sample into two images and a 3D scatter figure.

    Args:
        output: Model output tensor. Channel 0 is treated as depth and
            channel 1 as reflectance. The einops pattern below requires a
            batch size of exactly 1.
        cmap: Colormap callable forwarded to ``colorize``.

    Returns:
        A tuple ``(depth_image, reflectance_image, figure)`` where the images
        are NumPy arrays in height-width-channel order and ``figure`` is a
        Plotly ``go.Figure`` holding the point cloud.
    """
    denorm = lidar_utils.denormalize(output.clamp(-1, 1))
    depth = lidar_utils.revert_depth(denorm[:, [0]])
    rflct = denorm[:, [1]]

    # Flatten the per-pixel coordinates to a (c, h*w) array for plotting.
    xyz = lidar_utils.to_xyz(depth).cpu().numpy()
    xyz = einops.rearrange(xyz, "1 c h w -> c (h w)")
    # angle = lidar_utils.ray_angles.rad2deg()

    def hidden_axis():
        return dict(showticklabels=False, visible=False)

    scatter = go.Scatter3d(
        # x and y are negated before plotting; z is plotted as-is.
        x=-xyz[0],
        y=-xyz[1],
        z=xyz[2],
        mode="markers",
        marker=dict(
            size=1,
            # Points are colored by z with a fixed color range.
            color=xyz[2],
            colorscale="viridis",
            autocolorscale=False,
            cauto=False,
            cmin=-2,
            cmax=0.5,
        ),
        # text=[
        #     f"depth: {float(d):.2f}m<br>"
        #     + f"reflectance: {float(r):.2f}<br>"
        #     + f"elevation: {float(e):.2f}°<br>"
        #     + f"azimuth: {float(a):.2f}°"
        #     for d, r, e, a in zip(
        #         einops.rearrange(depth, "1 1 h w -> (h w)"),
        #         einops.rearrange(rflct, "1 1 h w -> (h w)"),
        #         einops.rearrange(angle[0, 0], "h w -> (h w)"),
        #         einops.rearrange(angle[0, 1], "h w -> (h w)"),
        #     )
        # ],
        # hoverinfo="text",
    )
    layout = dict(
        scene=dict(
            xaxis=hidden_axis(),
            yaxis=hidden_axis(),
            zaxis=hidden_axis(),
            aspectmode="data",
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
    )
    fig = go.Figure(data=[scatter], layout=layout)

    def to_rgb_image(channel):
        # First batch item, channels-last, as a NumPy array.
        return colorize(channel, cmap)[0].permute(1, 2, 0).cpu().numpy()

    # Depth is scaled by the maximum depth before colorizing.
    depth_image = to_rgb_image(depth / lidar_utils.max_depth)
    rflct_image = to_rgb_image(rflct)
    return depth_image, rflct_image, fig
def generate(num_steps, cmap_name, progress=gr.Progress()):
    """Sample one LiDAR scan from the diffusion model and render it.

    Args:
        num_steps: Number of sampling steps; cast to ``int`` before use.
        cmap_name: Name of a Matplotlib colormap for the output images.
        progress: Gradio progress tracker used to wrap the sampling loop.

    Returns:
        The ``(depth_image, reflectance_image, figure)`` tuple produced by
        ``render_point_cloud``.
    """
    total = int(num_steps)
    sample = ddpm.randn(1, *ddpm.sampling_shape, device=ddpm.device)

    # total + 1 evenly spaced values from 1.0 down to 0.0, with a leading
    # batch dim so each slice below has shape (1,).
    schedule = torch.linspace(1.0, 0.0, total + 1, device=ddpm.device)
    schedule = schedule.unsqueeze(0)

    for idx in progress.tqdm(range(total), desc="Generating LiDAR data"):
        current, following = schedule[:, idx], schedule[:, idx + 1]
        sample = ddpm.p_step(sample, current, following)

    colormap = plt.colormaps.get_cmap(cmap_name)
    return render_point_cloud(sample, colormap)
# Intro text shown at the top of the page.
_intro_markdown = """
# R2DM
> **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
Kazuto Nakashima, Ryo Kurazume<br>
ICRA 2024<br>
[[Project]](https://kazuto1011.github.io/r2dm/) [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
"""

# Settings shared by the two image outputs.
_image_kwargs = dict(type="numpy", image_mode="RGB", scale=1)

with gr.Blocks() as demo:
    gr.Markdown(_intro_markdown)

    with gr.Row():
        # Left column: controls.
        with gr.Column():
            gr.Textbox(device, label="Device")
            num_steps = gr.Dropdown(
                choices=[2**exponent for exponent in range(2, 10)],
                value=16,
                label="number of sampling steps (>256 is recommended)",
            )
            cmap_name = gr.Dropdown(
                choices=plt.colormaps(),
                value="turbo",
                label="colormap for range/reflectance images",
            )
            btn = gr.Button(value="Generate random samples")

        # Right column: outputs.
        with gr.Column():
            range_view = gr.Image(label="Range image", **_image_kwargs)
            rflct_view = gr.Image(label="Reflectance image", **_image_kwargs)
            point_view = gr.Plot(label="Point cloud", scale=1)

    # The three values returned by generate() fill the three output widgets
    # in this order.
    btn.click(
        fn=generate,
        inputs=[num_steps, cmap_name],
        outputs=[range_view, rflct_view, point_view],
    )

demo.queue()
demo.launch()