Spaces:
Sleeping
Sleeping
import einops | |
import gradio as gr | |
import matplotlib.cm as cm | |
import numpy as np | |
import plotly.graph_objects as go | |
import torch | |
import torch.nn.functional as F | |
from rendering import estimate_surface_normal | |
# HTML header rendered at the top of the demo page via gr.HTML: paper title,
# authors, affiliation, venue, links, and a short description. The CSS classes
# used here are presumably defined in ./style.css (passed to gr.Blocks) — confirm.
DESCRIPTION = """
<div class="head">
<div class="title">LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models</div>
<div class="authors">
<a href="https://kazuto1011.github.io/" target="_blank" rel="noopener"> Kazuto Nakashima</a>

<a href="https://robotics.ait.kyushu-u.ac.jp/kurazume/en/" target="_blank" rel="noopener"> Ryo Kurazume</a>
</div>
<div class="affiliations">Kyushu University</div>
<div class="conference">ICRA 2024</div>
<div class="materials">
<a href="https://kazuto1011.github.io/r2dm">Project</a> |
<a href="https://arxiv.org/abs/2309.09256">Paper</a> |
<a href="https://github.com/kazuto1011/r2dm">Code</a>
</div>
<br>
<div class="description">
This is a demo of our paper "LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models" presented at ICRA 2024.<br>
We propose <strong>R2DM</strong>, a continuous-time diffusion model for LiDAR data generation based on the equirectangular range/reflectance image representation.<br>
</div>
<br>
</div>
"""
# Markdown instructions for running this demo on a local machine; rendered in
# the bottom panel of the page via gr.Markdown.
RUN_LOCALLY = """
To run this demo locally:
```bash
git clone https://huggingface.co/spaces/kazuto1011/r2dm
```
```bash
cd r2dm
```
```bash
pip install -r requirements.txt
```
```bash
pip install gradio
```
```bash
gradio app.py
```
"""
# Gradio's default theme with the "Titillium Web" Google Font; passed to
# gr.Blocks when the UI is built.
THEME = gr.themes.Default(font=gr.themes.GoogleFont("Titillium Web"))
# This app only runs inference: switch autograd off globally and let cuDNN
# benchmark and cache the fastest convolution algorithms.
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

# Accelerator preference: CUDA first, then Apple-silicon MPS, otherwise CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
# Pretrained R2DM checkpoints fetched through torch.hub, keyed by the label shown
# in the "Dataset" dropdown. Every model is loaded onto the CPU at startup;
# generate() moves the selected one to the accelerator per request and unpacks
# each entry as a 3-tuple (model, lidar_utils, <unused>).
model_dict = {
    label: torch.hub.load(
        "kazuto1011/r2dm",
        "pretrained_r2dm",
        config=checkpoint,
        device="cpu",
        show_info=False,
    )
    for label, checkpoint in (
        ("KITTI Raw (64x512)", "r2dm-h-kittiraw-300k"),
        ("KITTI-360 (64x1024)", "r2dm-h-kitti360-300k"),
    )
}
def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo):
    """Turn a single-channel tensor into an 8-bit RGB image via a colormap.

    Args:
        tensor: values expected in [0, 1], shaped (B, 1, H, W) or (B, H, W).
            Out-of-range values saturate to the first/last colormap entry.
        cmap_fn: matplotlib-style colormap callable; it is sampled at 256
            evenly spaced points and only the RGB columns are kept.

    Returns:
        A uint8 tensor of shape (B, 3, H, W) with the input's device.
    """
    # 256-entry RGB lookup table, cast to the input's dtype and device.
    palette = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    lut = torch.from_numpy(palette).to(tensor)

    if tensor.ndim == 4:
        tensor = tensor.squeeze(1)

    # Quantize to a palette index, then gather the RGB rows.
    index = torch.clamp(tensor * 256, 0, 255).long()
    rgb = F.embedding(index, lut)  # (B, H, W, 3)
    rgb = rgb.permute(0, 3, 1, 2)  # (B, 3, H, W)
    return torch.clamp(rgb * 255, 0, 255).byte()
def _run_sampler(model, num_steps: int, sampling_mode: str, progress) -> torch.Tensor:
    """Run ``num_steps`` reverse-diffusion steps starting from Gaussian noise."""
    x = model.randn(1, *model.sampling_shape, device=model.device)
    # num_steps + 1 evenly spaced times from t=1.0 (where x starts as noise)
    # down to t=0.0; iteration i steps the sample from steps[i] to steps[i + 1].
    steps = torch.linspace(1.0, 0.0, num_steps + 1, device=model.device)[None]
    mode = sampling_mode.lower()  # loop-invariant; p_step expects lower case
    for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
        x = model.p_step(x, steps[:, i], steps[:, i + 1], mode=mode)
    return x


def _point_cloud_figure(point: np.ndarray, color: np.ndarray) -> go.Figure:
    """Build an axis-free, transparent-background Plotly 3D scatter of the points."""
    return go.Figure(
        data=[
            go.Scatter3d(
                # Negating x and y rotates the cloud 180 degrees about z,
                # presumably to orient it for the default camera — confirm.
                x=-point[..., 0],
                y=-point[..., 1],
                z=point[..., 2],
                mode="markers",
                marker=dict(size=1, color=color),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(showticklabels=False, visible=False),
                yaxis=dict(showticklabels=False, visible=False),
                zaxis=dict(showticklabels=False, visible=False),
                aspectmode="data",
            ),
            margin=dict(l=0, r=0, b=0, t=0),
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
        ),
    )


def _to_rgb_image(tensor: torch.Tensor) -> np.ndarray:
    """Turbo-colorize the first sample of a batch as an (H, W, 3) uint8 array."""
    return colorize(tensor, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()


def generate(num_steps: int, sampling_mode: str, dataset: str, progress=gr.Progress()):
    """Sample one LiDAR scan with R2DM and render it for the Gradio outputs.

    Args:
        num_steps: number of reverse-diffusion steps; cast with ``int()``
            because the value arrives from a Gradio dropdown.
        sampling_mode: sampler name (the UI offers "DDPM" / "DDIM"); it is
            lower-cased before being passed to ``model.p_step``.
        dataset: key into ``model_dict`` selecting the pretrained model.
        progress: Gradio progress tracker, injected by Gradio at call time.

    Returns:
        ``(range_image, reflectance_image, figure)``: two (H, W, 3) uint8
        turbo-colorized images and a Plotly 3D scatter of the point cloud.

    The selected model is moved to ``device`` for the request and always
    moved back to the CPU afterwards — also when sampling or rendering
    raises — so a failed request does not leave it on the accelerator.
    """
    model, lidar_utils, _ = model_dict[dataset]
    try:
        model.to(device)
        lidar_utils.to(device)

        x = _run_sampler(model, int(num_steps), sampling_mode, progress)

        # Clamp to the model's [-1, 1] output range, then map back to data
        # space: channel 0 holds depth (after revert_depth), channel 1
        # reflectance.
        x = lidar_utils.denormalize(x.clamp(-1, 1))
        depth = lidar_utils.revert_depth(x[:, [0]])
        rflct = x[:, [1]]

        point = lidar_utils.to_xyz(depth)
        # Shade points by negated surface normal mapped into [0, 1] RGB
        # (assumes normal components lie in [-1, 1]).
        color = (-estimate_surface_normal(point) + 1) / 2
        point = einops.rearrange(point, "1 c h w -> (h w) c").cpu().numpy()
        color = einops.rearrange(color, "1 c h w -> (h w) c").cpu().numpy()
        fig = _point_cloud_figure(point, color)

        # colorize expects [0, 1]; normalize depth by the sensor's max depth.
        range_image = _to_rgb_image(depth / lidar_utils.max_depth)
        rflct_image = _to_rgb_image(rflct)
    finally:
        model.cpu()
        lidar_utils.cpu()
    return range_image, rflct_image, fig
with gr.Blocks(css="./style.css", theme=THEME) as demo:
    gr.HTML(DESCRIPTION)
    with gr.Row(variant="panel"):
        # Left column: sampling controls.
        with gr.Column():
            # Informational only: reports where sampling runs, so it is
            # read-only, and it is given a plain string rather than a
            # torch.device object.
            gr.Textbox(str(device), label="Running device", interactive=False)
            dataset = gr.Dropdown(
                choices=list(model_dict),
                value=list(model_dict)[0],
                label="Dataset",
            )
            sampling_mode = gr.Dropdown(
                choices=["DDPM", "DDIM"],
                value="DDPM",
                label="Sampler",
            )
            # NOTE(review): the default (32) is below the recommended >256,
            # presumably to keep the hosted demo fast — confirm intent.
            num_steps = gr.Dropdown(
                choices=[2**i for i in range(5, 11)],
                value=32,
                label="Number of sampling steps (>256 is recommended)",
            )
            btn = gr.Button(value="Generate")
        # Right column: generated range/reflectance images and point cloud.
        with gr.Column():
            range_view = gr.Image(type="numpy", label="Range image")
            rflct_view = gr.Image(type="numpy", label="Reflectance image")
            point_view = gr.Plot(label="Point cloud")
    with gr.Row(variant="panel"):
        gr.Markdown(RUN_LOCALLY)
    # Input order must match generate(num_steps, sampling_mode, dataset).
    btn.click(
        generate,
        inputs=[num_steps, sampling_mode, dataset],
        outputs=[range_view, rflct_view, point_view],
    )

# Enable request queuing (Gradio's progress tracking relies on it), then serve.
demo.queue()
demo.launch()