Spaces:
Running
Running
import einops | |
import gradio as gr | |
import matplotlib.cm as cm | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import plotly.graph_objects as go | |
import torch | |
import torch.nn.functional as F | |
# Pick the compute backend in order of preference: CUDA, then MPS, then CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# torch.hub entry point "pretrained_r2dm" of the "kazuto1011/r2dm" repository.
# It returns three objects; only the first two (the diffusion model and the
# LiDAR helper) are used in this file, the third one is discarded.
ddpm, lidar_utils, _ = torch.hub.load(
    "kazuto1011/r2dm",
    "pretrained_r2dm",
    device=device,
)
def colorize(tensor, cmap_fn=cm.turbo):
    """Map a batch of scalar images to 8-bit RGB images through a colormap.

    Args:
        tensor: Scalar image batch. A 4-D input has its dimension 1 squeezed;
            the result must be 3-D because the colorized output is permuted
            as a 4-D tensor. Values are scaled by 256 and clamped to the
            index range [0, 255], so anything outside [0, 1) saturates.
        cmap_fn: Callable that maps an array of positions in [0, 1] to RGBA
            rows (a matplotlib colormap); only the first three columns are
            used.

    Returns:
        A ``uint8`` tensor with the color channel moved to dimension 1.
    """
    # 256-entry RGB lookup table (alpha column dropped), converted to the
    # same dtype and device as the input.
    palette = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    palette = torch.from_numpy(palette).to(tensor)

    if tensor.ndim == 4:
        tensor = tensor.squeeze(1)

    # Quantize every value to a palette row index.
    bins = torch.clamp(tensor * 256, 0, 255).to(torch.long)

    # The lookup appends the RGB axis last; move it right after the batch axis.
    rgb = F.embedding(bins, palette).permute(0, 3, 1, 2)
    return torch.clamp(rgb * 255, 0, 255).to(torch.uint8)
def render_point_cloud(output, cmap):
    """Turn a generated sample into two colorized images and a 3-D scatter plot.

    Args:
        output: Model output batch. Channel 0 is treated as depth and channel 1
            as reflectance. The batch size must be 1, since the rearrange
            pattern below pins the leading axis to 1.
        cmap: Colormap callable forwarded to ``colorize``.

    Returns:
        A tuple ``(depth_image, reflectance_image, figure)`` where both images
        are height x width x RGB NumPy arrays taken from the first batch item
        and ``figure`` is a ``plotly.graph_objects.Figure``.
    """
    # NOTE(review): ``denormalize`` / ``revert_depth`` / ``to_xyz`` come from
    # the hub-loaded ``lidar_utils``; their exact semantics are not visible
    # here -- presumably they map the network range back to metric depth and
    # then to Cartesian coordinates.
    decoded = lidar_utils.denormalize(output.clamp(-1, 1))
    depth = lidar_utils.revert_depth(decoded[:, [0]])
    rflct = decoded[:, [1]]

    # Flatten the image grid into a (channels, num_points) array.
    xyz = lidar_utils.to_xyz(depth).cpu().numpy()
    xyz = einops.rearrange(xyz, "1 c h w -> c (h w)")

    # Points are colored by their z coordinate on a fixed [-2, 0.5] scale.
    marker = {
        "size": 1,
        "color": xyz[2],
        "colorscale": "viridis",
        "autocolorscale": False,
        "cauto": False,
        "cmin": -2,
        "cmax": 0.5,
    }
    # Per-point hover text (depth / reflectance / ray angles) is not attached.
    cloud = go.Scatter3d(
        x=-xyz[0],  # x and y are negated for display
        y=-xyz[1],
        z=xyz[2],
        mode="markers",
        marker=marker,
    )

    # Hide all axes and use a transparent background with no margins.
    hidden_axis = {"showticklabels": False, "visible": False}
    layout = {
        "scene": {
            "xaxis": dict(hidden_axis),
            "yaxis": dict(hidden_axis),
            "zaxis": dict(hidden_axis),
            "aspectmode": "data",
        },
        "margin": {"l": 0, "r": 0, "b": 0, "t": 0},
        "paper_bgcolor": "rgba(0,0,0,0)",
        "plot_bgcolor": "rgba(0,0,0,0)",
    }
    fig = go.Figure(data=[cloud], layout=layout)

    # Depth is scaled by the maximum depth before being color-mapped.
    depth_image = colorize(depth / lidar_utils.max_depth, cmap)[0]
    depth_image = depth_image.permute(1, 2, 0).cpu().numpy()
    rflct_image = colorize(rflct, cmap)[0]
    rflct_image = rflct_image.permute(1, 2, 0).cpu().numpy()
    return depth_image, rflct_image, fig
def generate(num_steps, cmap_name, progress=gr.Progress()):
    """Sample one LiDAR scan with the diffusion model and render it.

    Args:
        num_steps: Number of sampling steps; converted with ``int``.
        cmap_name: Name of the matplotlib colormap used for the images.
        progress: Gradio progress tracker. It is declared as a default
            argument so the UI can display the sampling loop's progress.

    Returns:
        Whatever ``render_point_cloud`` returns for the final sample: the
        range image, the reflectance image, and the point-cloud figure.
    """
    total = int(num_steps)

    # Start from a single random sample on the model's device.
    sample = ddpm.randn(1, *ddpm.sampling_shape, device=ddpm.device)

    # Evenly spaced schedule from 1.0 down to 0.0 with a leading batch axis,
    # i.e. ``total + 1`` time points delimiting ``total`` steps.
    schedule = torch.linspace(1.0, 0.0, total + 1, device=ddpm.device).unsqueeze(0)

    for idx in progress.tqdm(range(total), desc="Generating LiDAR data"):
        t_from = schedule[:, idx]
        t_to = schedule[:, idx + 1]
        # NOTE(review): presumably one reverse-diffusion update from time
        # ``t_from`` to ``t_to`` -- confirm against the r2dm implementation.
        sample = ddpm.p_step(sample, t_from, t_to)

    colormap = plt.colormaps.get_cmap(cmap_name)
    return render_point_cloud(sample, colormap)
# Gradio UI: controls in the left column, generated outputs in the right one.
with gr.Blocks() as demo:
    # Header with the paper title, authors, venue and project links.
    gr.Markdown(
        """
# R2DM
> **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
Kazuto Nakashima, Ryo Kurazume<br>
ICRA 2024<br>
[[Project]](https://kazuto1011.github.io/r2dm/) [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
"""
    )

    with gr.Row():
        # Left column: read-only device display and sampling options.
        with gr.Column():
            gr.Textbox(value=device, label="Device")
            # Powers of two from 4 to 512.
            num_steps = gr.Dropdown(
                label="number of sampling steps (>256 is recommended)",
                choices=[1 << shift for shift in range(2, 10)],
                value=16,
            )
            # Every colormap name registered in matplotlib.
            cmap_name = gr.Dropdown(
                label="colormap for range/reflectance images",
                choices=plt.colormaps(),
                value="turbo",
            )
            btn = gr.Button(value="Generate random samples")

        # Right column: the two colorized images and the 3-D point cloud.
        with gr.Column():
            range_view = gr.Image(
                label="Range image",
                type="numpy",
                image_mode="RGB",
                scale=1,
            )
            rflct_view = gr.Image(
                label="Reflectance image",
                type="numpy",
                image_mode="RGB",
                scale=1,
            )
            point_view = gr.Plot(label="Point cloud", scale=1)

    # The output order matches the tuple returned by ``generate``.
    btn.click(
        fn=generate,
        inputs=[num_steps, cmap_name],
        outputs=[range_view, rflct_view, point_view],
    )

# Enable the request queue, then start the server.
demo.queue().launch()