Spaces:
Running
Running
kazuto1011
committed on
Commit
•
059842e
1
Parent(s):
365f709
update
Browse files- app.py +127 -93
- rendering.py +96 -0
- requirements.txt +0 -1
- style.css +15 -0
app.py
CHANGED
@@ -1,12 +1,60 @@
|
|
1 |
import einops
|
2 |
import gradio as gr
|
3 |
import matplotlib.cm as cm
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
import numpy as np
|
6 |
import plotly.graph_objects as go
|
7 |
import torch
|
8 |
import torch.nn.functional as F
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
if torch.cuda.is_available():
|
11 |
device = "cuda"
|
12 |
elif torch.backends.mps.is_available():
|
@@ -15,16 +63,28 @@ else:
|
|
15 |
device = "cpu"
|
16 |
|
17 |
torch.set_grad_enabled(False)
|
|
|
18 |
device = torch.device(device)
|
19 |
|
20 |
-
|
21 |
-
"
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
|
29 |
colors = torch.from_numpy(colors).to(tensor)
|
30 |
tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
|
@@ -34,42 +94,37 @@ def colorize(tensor, cmap_fn=cm.turbo):
|
|
34 |
return tensor
|
35 |
|
36 |
|
37 |
-
def
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
fig = go.Figure(
|
45 |
data=[
|
46 |
go.Scatter3d(
|
47 |
-
x=-point[0],
|
48 |
-
y=-point[1],
|
49 |
-
z=point[2],
|
50 |
mode="markers",
|
51 |
-
marker=dict(
|
52 |
-
size=1,
|
53 |
-
color=point[2],
|
54 |
-
colorscale="viridis",
|
55 |
-
autocolorscale=False,
|
56 |
-
cauto=False,
|
57 |
-
cmin=-2,
|
58 |
-
cmax=0.5,
|
59 |
-
),
|
60 |
-
# text=[
|
61 |
-
# f"depth: {float(d):.2f}m<br>"
|
62 |
-
# + f"reflectance: {float(r):.2f}<br>"
|
63 |
-
# + f"elevation: {float(e):.2f}°<br>"
|
64 |
-
# + f"azimuth: {float(a):.2f}°"
|
65 |
-
# for d, r, e, a in zip(
|
66 |
-
# einops.rearrange(depth, "1 1 h w -> (h w)"),
|
67 |
-
# einops.rearrange(rflct, "1 1 h w -> (h w)"),
|
68 |
-
# einops.rearrange(angle[0, 0], "h w -> (h w)"),
|
69 |
-
# einops.rearrange(angle[0, 1], "h w -> (h w)"),
|
70 |
-
# )
|
71 |
-
# ],
|
72 |
-
# hoverinfo="text",
|
73 |
)
|
74 |
],
|
75 |
layout=dict(
|
@@ -85,71 +140,50 @@ def render_point_cloud(output, cmap):
|
|
85 |
),
|
86 |
)
|
87 |
depth = depth / lidar_utils.max_depth
|
88 |
-
depth = colorize(depth,
|
89 |
-
rflct = colorize(rflct,
|
90 |
-
return depth, rflct, fig
|
91 |
-
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
steps = torch.linspace(1.0, 0.0, num_steps + 1, device=ddpm.device)[None]
|
97 |
-
for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
|
98 |
-
step_t = steps[:, i]
|
99 |
-
step_s = steps[:, i + 1]
|
100 |
-
x = ddpm.p_step(x, step_t, step_s)
|
101 |
-
return render_point_cloud(x, plt.colormaps.get_cmap(cmap_name))
|
102 |
|
103 |
|
104 |
-
with gr.Blocks() as demo:
|
105 |
-
gr.
|
106 |
-
"""
|
107 |
-
# R2DM
|
108 |
-
> **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
|
109 |
-
Kazuto Nakashima, Ryo Kurazume<br>
|
110 |
-
ICRA 2024<br>
|
111 |
-
[[Project]](https://kazuto1011.github.io/r2dm/) [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
|
112 |
|
113 |
-
|
114 |
-
"""
|
115 |
-
)
|
116 |
-
with gr.Row():
|
117 |
with gr.Column():
|
118 |
-
gr.Textbox(device, label="
|
119 |
-
|
120 |
-
choices=
|
121 |
-
value=
|
122 |
-
label="
|
123 |
)
|
124 |
-
|
125 |
-
choices=
|
126 |
-
value="
|
127 |
-
label="
|
128 |
)
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
with gr.Column():
|
132 |
-
range_view = gr.Image(
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
rflct_view = gr.Image(
|
139 |
-
type="numpy",
|
140 |
-
image_mode="RGB",
|
141 |
-
label="Reflectance image",
|
142 |
-
scale=1,
|
143 |
-
)
|
144 |
-
point_view = gr.Plot(
|
145 |
-
label="Point cloud",
|
146 |
-
scale=1,
|
147 |
-
)
|
148 |
|
149 |
btn.click(
|
150 |
generate,
|
151 |
-
inputs=[num_steps,
|
152 |
outputs=[range_view, rflct_view, point_view],
|
|
|
153 |
)
|
154 |
|
155 |
|
|
|
1 |
import einops
|
2 |
import gradio as gr
|
3 |
import matplotlib.cm as cm
|
|
|
4 |
import numpy as np
|
5 |
import plotly.graph_objects as go
|
6 |
import torch
|
7 |
import torch.nn.functional as F
|
8 |
|
9 |
+
from rendering import estimate_surface_normal
|
10 |
+
|
11 |
+
DESCRIPTION = """
|
12 |
+
<div class="head">
|
13 |
+
<div class="title">LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models</div>
|
14 |
+
<div class="authors">
|
15 |
+
<a href="https://kazuto1011.github.io/" target="_blank" rel="noopener"> Kazuto Nakashima</a>
|
16 |
+
|
17 |
+
<a href="https://robotics.ait.kyushu-u.ac.jp/kurazume/en/" target="_blank" rel="noopener"> Ryo Kurazume</a>
|
18 |
+
</div>
|
19 |
+
<div class="affiliations">Kyushu University</div>
|
20 |
+
<div class="conference">ICRA 2024</div>
|
21 |
+
<div class="materials">
|
22 |
+
<a href="https://kazuto1011.github.io/r2dm">Project</a> |
|
23 |
+
<a href="https://arxiv.org/abs/2309.09256">Paper</a> |
|
24 |
+
<a href="https://github.com/kazuto1011/r2dm">Code</a>
|
25 |
+
</div>
|
26 |
+
<br>
|
27 |
+
<div class="description">
|
28 |
+
This is a demo of our paper "LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models" presented at ICRA 2024.<br>
|
29 |
+
We propose <strong>R2DM</strong>, a continuous-time diffusion model for LiDAR data generation based on the equirectangular range/reflectance image representation.<br>
|
30 |
+
</div>
|
31 |
+
<br>
|
32 |
+
</div>
|
33 |
+
"""
|
34 |
+
|
35 |
+
RUN_LOCALLY = """
|
36 |
+
To run this demo locally:
|
37 |
+
|
38 |
+
```bash
|
39 |
+
git clone https://huggingface.co/spaces/kazuto1011/r2dm
|
40 |
+
```
|
41 |
+
```bash
|
42 |
+
cd r2dm
|
43 |
+
```
|
44 |
+
```bash
|
45 |
+
pip install -r requirements.txt
|
46 |
+
```
|
47 |
+
```bash
|
48 |
+
pip install gradio
|
49 |
+
```
|
50 |
+
```bash
|
51 |
+
gradio app.py
|
52 |
+
```
|
53 |
+
"""
|
54 |
+
|
55 |
+
THEME = gr.themes.Default(font=gr.themes.GoogleFont("Titillium Web"))
|
56 |
+
|
57 |
+
|
58 |
if torch.cuda.is_available():
|
59 |
device = "cuda"
|
60 |
elif torch.backends.mps.is_available():
|
|
|
63 |
device = "cpu"
|
64 |
|
65 |
torch.set_grad_enabled(False)
|
66 |
+
torch.backends.cudnn.benchmark = True
|
67 |
device = torch.device(device)
|
68 |
|
69 |
+
# Pretrained R2DM checkpoints, keyed by the label shown in the "Dataset"
# dropdown. Each value is whatever the `pretrained_r2dm` hub entry point
# returns; `generate` unpacks it as `(model, lidar_utils, _)`.
# Everything is loaded on the CPU here; `generate` moves the selected pair to
# the compute device for one request and back to the CPU afterwards.
_R2DM_CONFIGS = {
    "KITTI Raw (64x512)": "r2dm-h-kittiraw-300k",
    "KITTI-360 (64x1024)": "r2dm-h-kitti360-300k",
}

model_dict = {
    label: torch.hub.load(
        "kazuto1011/r2dm",
        "pretrained_r2dm",
        config=config,
        device="cpu",
        show_info=False,
    )
    for label, config in _R2DM_CONFIGS.items()
}
|
85 |
+
|
86 |
+
|
87 |
+
def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo):
|
88 |
colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
|
89 |
colors = torch.from_numpy(colors).to(tensor)
|
90 |
tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
|
|
|
94 |
return tensor
|
95 |
|
96 |
|
97 |
+
def generate(num_steps: int, sampling_mode: str, dataset: str, progress=gr.Progress()):
|
98 |
+
# model setup
|
99 |
+
model, lidar_utils, _ = model_dict[dataset]
|
100 |
+
model.to(device)
|
101 |
+
lidar_utils.to(device)
|
102 |
+
|
103 |
+
# sampling
|
104 |
+
num_steps = int(num_steps)
|
105 |
+
x = model.randn(1, *model.sampling_shape, device=model.device)
|
106 |
+
steps = torch.linspace(1.0, 0.0, num_steps + 1, device=model.device)[None]
|
107 |
+
for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
|
108 |
+
step_t = steps[:, i]
|
109 |
+
step_s = steps[:, i + 1]
|
110 |
+
x = model.p_step(x, step_t, step_s, mode=sampling_mode.lower())
|
111 |
+
|
112 |
+
# rendering point cloud
|
113 |
+
x = lidar_utils.denormalize(x.clamp(-1, 1))
|
114 |
+
depth = lidar_utils.revert_depth(x[:, [0]])
|
115 |
+
rflct = x[:, [1]]
|
116 |
+
point = lidar_utils.to_xyz(depth)
|
117 |
+
color = (-estimate_surface_normal(point) + 1) / 2
|
118 |
+
point = einops.rearrange(point, "1 c h w -> (h w) c").cpu().numpy()
|
119 |
+
color = einops.rearrange(color, "1 c h w -> (h w) c").cpu().numpy()
|
120 |
fig = go.Figure(
|
121 |
data=[
|
122 |
go.Scatter3d(
|
123 |
+
x=-point[..., 0],
|
124 |
+
y=-point[..., 1],
|
125 |
+
z=point[..., 2],
|
126 |
mode="markers",
|
127 |
+
marker=dict(size=1, color=color),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
)
|
129 |
],
|
130 |
layout=dict(
|
|
|
140 |
),
|
141 |
)
|
142 |
depth = depth / lidar_utils.max_depth
|
143 |
+
depth = colorize(depth, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
|
144 |
+
rflct = colorize(rflct, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
|
|
|
|
|
145 |
|
146 |
+
model.cpu()
|
147 |
+
lidar_utils.cpu()
|
148 |
+
return depth, rflct, fig
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
|
151 |
+
# Gradio UI: a header, a two-column panel (controls on the left, outputs on
# the right), and a panel with instructions for running the demo locally.
with gr.Blocks(css="./style.css", theme=THEME) as demo:
    gr.HTML(DESCRIPTION)

    with gr.Row(variant="panel"):
        with gr.Column():
            gr.Textbox(device, label="Running device")
            # Dropdown labels are the keys of `model_dict`; the first one is
            # preselected.
            dataset_names = list(model_dict)
            dataset = gr.Dropdown(
                choices=dataset_names,
                value=dataset_names[0],
                label="Dataset",
            )
            sampling_mode = gr.Dropdown(
                choices=["DDPM", "DDIM"],
                value="DDPM",
                label="Sampler",
            )
            # Powers of two from 32 to 1024.
            num_steps = gr.Dropdown(
                choices=[2**i for i in range(5, 11)],
                value=32,
                label="Number of sampling steps (>256 is recommended)",
            )
            btn = gr.Button(value="Generate")

        with gr.Column():
            range_view = gr.Image(type="numpy", label="Range image")
            rflct_view = gr.Image(type="numpy", label="Reflectance image")
            point_view = gr.Plot(label="Point cloud")

    with gr.Row(variant="panel"):
        gr.Markdown(RUN_LOCALLY)

    # `generate` moves the selected model onto the compute device and back,
    # so only one request is allowed to run at a time.
    btn.click(
        generate,
        inputs=[num_steps, sampling_mode, dataset],
        outputs=[range_view, rflct_view, point_view],
        concurrency_limit=1,
    )
|
188 |
|
189 |
|
rendering.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def estimate_surface_normal(
    points: torch.Tensor, d: int = 2, mode: str = "closest"
) -> torch.Tensor:
    """Estimate per-pixel surface normals from an organized point map.

    PyTorch re-implementation of:
    https://github.com/wkentaro/morefusion/blob/master/morefusion/geometry/estimate_pointcloud_normals.py
    https://github.com/jmccormac/pySceneNetRGBD/blob/master/calculate_surface_normals.py

    For every pixel (the anchor), eight neighbours at a stride of ``d`` pixels
    are gathered. Neighbour ``k`` is paired with neighbour ``(k + 2) % 8`` and
    the cross product of the two anchor-to-neighbour vectors gives a candidate
    normal. Rows are padded by replication and columns by wrap-around
    (circular padding) before gathering.

    Args:
        points: ``(B, 3, H, W)`` tensor of point coordinates.
        d: Pixel stride between the anchor and its neighbours; must be >= 1.
        mode: ``"closest"`` uses, per pixel, the neighbour pair with the
            smallest summed distance to the anchor; ``"mean"`` averages the
            eight candidate cross products.

    Returns:
        ``(B, 3, H, W)`` tensor of normals scaled to unit length. A small
        epsilon is added to the norm, so a zero cross product stays zero.

    Raises:
        AssertionError: If ``points`` is not 4-D or has other than 3 channels.
        NotImplementedError: If ``mode`` is not ``"closest"`` or ``"mean"``.
        ValueError: If ``d`` is smaller than 1.
    """
    assert points.dim() == 4, f"expected (B,3,H,W), but got {points.shape}"
    B, C, H, W = points.shape
    assert C == 3, f"expected C==3, but got {C}"
    # Reject bad arguments before doing any padding or gathering work.
    if mode not in ("closest", "mean"):
        raise NotImplementedError(mode)
    if d < 1:
        raise ValueError(f"expected d >= 1, but got {d}")
    device = points.device

    # Pad rows by replication and columns circularly so every anchor has all
    # eight neighbours available.
    padded = F.pad(points, (0, 0, d, d), mode="replicate")
    padded = F.pad(padded, (d, d, 0, 0), mode="circular")
    padded = padded.permute(0, 2, 3, 1)  # (B,H+2d,W+2d,3)

    # 8 neighbour offsets as (dh, dw); dh indexes rows, dw indexes columns:
    #  -----------
    # | 7 | 0 | 1 |
    #  -----------
    # | 6 |   | 2 |
    #  -----------
    # | 5 | 4 | 3 |
    #  -----------
    offsets = torch.tensor(
        [
            # (dh,dw)
            (-d, 0),  # 0
            (-d, d),  # 1
            (0, d),  # 2
            (d, d),  # 3
            (d, 0),  # 4
            (d, -d),  # 5
            (0, -d),  # 6
            (-d, -d),  # 7
        ],
        device=device,
    )

    # (B,H,W) broadcastable indices
    b = torch.arange(B, device=device)[:, None, None]
    h = torch.arange(H, device=device)[None, :, None]
    w = torch.arange(W, device=device)[None, None, :]
    k = torch.arange(8, device=device)

    # Anchor indices, shifted by d to account for the padding.
    b1 = b[:, None]  # (B,1,1,1)
    h1 = h[:, None] + d  # (1,1,H,1)
    w1 = w[:, None] + d  # (1,1,1,W)
    anchors = padded[b1, h1, w1]  # (B,1,H,W,3)

    def gather_neighbors(offset: torch.Tensor) -> torch.Tensor:
        # offset: (8,2) -> neighbour points of shape (B,8,H,W,3)
        dh = offset[None, :, 0, None, None]  # (1,8,1,1)
        dw = offset[None, :, 1, None, None]  # (1,8,1,1)
        return padded[b1, h1 + dh, w1 + dw]

    # First neighbour of each pair, and the one two steps further around.
    points1 = gather_neighbors(offsets[k])
    points2 = gather_neighbors(offsets[(k + 2) % 8])

    if mode == "closest":
        # Pick, per pixel, the pair whose two neighbours lie closest to the
        # anchor.
        cost = torch.linalg.vector_norm(points1 - anchors, dim=4)
        cost = cost + torch.linalg.vector_norm(points2 - anchors, dim=4)
        i = torch.argmin(cost, dim=1)  # (B,H,W)
        # Normal is the cross product of the two selected edge vectors.
        anchors = anchors[b, 0, h, w]  # (B,H,W,3)
        points1 = points1[b, i, h, w]  # (B,H,W,3)
        points2 = points2[b, i, h, w]  # (B,H,W,3)
        normals = torch.cross(points1 - anchors, points2 - anchors, dim=-1)
    else:  # mode == "mean"
        # Average the cross products of all eight pairs.
        normals = torch.cross(
            points1 - anchors, points2 - anchors, dim=-1
        )  # (B,8,H,W,3)
        normals = normals.mean(dim=1)  # (B,H,W,3)

    # The epsilon avoids division by zero for degenerate (zero) normals.
    length = torch.linalg.vector_norm(normals, dim=3, keepdim=True)
    normals = normals / (length + 1e-8)
    normals = normals.permute(0, 3, 1, 2)  # (B,3,H,W)

    return normals
|
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
einops
|
2 |
-
kornia
|
3 |
matplotlib
|
4 |
numpy
|
5 |
torch
|
|
|
1 |
einops
|
|
|
2 |
matplotlib
|
3 |
numpy
|
4 |
torch
|
style.css
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.head {
|
2 |
+
text-align: center;
|
3 |
+
display: block;
|
4 |
+
font-size: var(--text-xl);
|
5 |
+
}
|
6 |
+
|
7 |
+
.title {
|
8 |
+
font-size: var(--text-xxl);
|
9 |
+
font-weight: bold;
|
10 |
+
margin-top: 2rem;
|
11 |
+
}
|
12 |
+
|
13 |
+
.description {
|
14 |
+
font-size: var(--text-lg);
|
15 |
+
}
|