kazuto1011 committed
Commit: 059842e
Parent: 365f709
Files changed (4)
  1. app.py +127 -93
  2. rendering.py +96 -0
  3. requirements.txt +0 -1
  4. style.css +15 -0
app.py CHANGED
@@ -1,12 +1,60 @@
  import einops
  import gradio as gr
  import matplotlib.cm as cm
- import matplotlib.pyplot as plt
  import numpy as np
  import plotly.graph_objects as go
  import torch
  import torch.nn.functional as F

+ from rendering import estimate_surface_normal
+
+ DESCRIPTION = """
+ <div class="head">
+ <div class="title">LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models</div>
+ <div class="authors">
+ <a href="https://kazuto1011.github.io/" target="_blank" rel="noopener">Kazuto Nakashima</a>
+ &nbsp;&nbsp;&nbsp;
+ <a href="https://robotics.ait.kyushu-u.ac.jp/kurazume/en/" target="_blank" rel="noopener">Ryo Kurazume</a>
+ </div>
+ <div class="affiliations">Kyushu University</div>
+ <div class="conference">ICRA 2024</div>
+ <div class="materials">
+ <a href="https://kazuto1011.github.io/r2dm">Project</a> |
+ <a href="https://arxiv.org/abs/2309.09256">Paper</a> |
+ <a href="https://github.com/kazuto1011/r2dm">Code</a>
+ </div>
+ <br>
+ <div class="description">
+ This is a demo of our paper "LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models" presented at ICRA 2024.<br>
+ We propose <strong>R2DM</strong>, a continuous-time diffusion model for LiDAR data generation based on the equirectangular range/reflectance image representation.<br>
+ </div>
+ <br>
+ </div>
+ """
+
+ RUN_LOCALLY = """
+ To run this demo locally:
+
+ ```bash
+ git clone https://huggingface.co/spaces/kazuto1011/r2dm
+ ```
+ ```bash
+ cd r2dm
+ ```
+ ```bash
+ pip install -r requirements.txt
+ ```
+ ```bash
+ pip install gradio
+ ```
+ ```bash
+ gradio app.py
+ ```
+ """
+
+ THEME = gr.themes.Default(font=gr.themes.GoogleFont("Titillium Web"))
+
+
  if torch.cuda.is_available():
      device = "cuda"
  elif torch.backends.mps.is_available():
@@ -15,16 +63,28 @@ else:
      device = "cpu"

  torch.set_grad_enabled(False)
+ torch.backends.cudnn.benchmark = True
  device = torch.device(device)

- ddpm, lidar_utils, _ = torch.hub.load(
-     "kazuto1011/r2dm",
-     "pretrained_r2dm",
-     device=device,
- )
-
-
- def colorize(tensor, cmap_fn=cm.turbo):
+ model_dict = {
+     "KITTI Raw (64x512)": torch.hub.load(
+         "kazuto1011/r2dm",
+         "pretrained_r2dm",
+         config="r2dm-h-kittiraw-300k",
+         device="cpu",
+         show_info=False,
+     ),
+     "KITTI-360 (64x1024)": torch.hub.load(
+         "kazuto1011/r2dm",
+         "pretrained_r2dm",
+         config="r2dm-h-kitti360-300k",
+         device="cpu",
+         show_info=False,
+     ),
+ }
+
+
+ def colorize(tensor: torch.Tensor, cmap_fn=cm.turbo):
      colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
      colors = torch.from_numpy(colors).to(tensor)
      tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
@@ -34,42 +94,37 @@ def colorize(tensor, cmap_fn=cm.turbo):
      return tensor


- def render_point_cloud(output, cmap):
-     output = lidar_utils.denormalize(output.clamp(-1, 1))
-     depth = lidar_utils.revert_depth(output[:, [0]])
-     rflct = output[:, [1]]
-     point = lidar_utils.to_xyz(depth).cpu().numpy()
-     point = einops.rearrange(point, "1 c h w -> c (h w)")
-     # angle = lidar_utils.ray_angles.rad2deg()
+ def generate(num_steps: int, sampling_mode: str, dataset: str, progress=gr.Progress()):
+     # model setup
+     model, lidar_utils, _ = model_dict[dataset]
+     model.to(device)
+     lidar_utils.to(device)
+
+     # sampling
+     num_steps = int(num_steps)
+     x = model.randn(1, *model.sampling_shape, device=model.device)
+     steps = torch.linspace(1.0, 0.0, num_steps + 1, device=model.device)[None]
+     for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
+         step_t = steps[:, i]
+         step_s = steps[:, i + 1]
+         x = model.p_step(x, step_t, step_s, mode=sampling_mode.lower())
+
+     # rendering point cloud
+     x = lidar_utils.denormalize(x.clamp(-1, 1))
+     depth = lidar_utils.revert_depth(x[:, [0]])
+     rflct = x[:, [1]]
+     point = lidar_utils.to_xyz(depth)
+     color = (-estimate_surface_normal(point) + 1) / 2
+     point = einops.rearrange(point, "1 c h w -> (h w) c").cpu().numpy()
+     color = einops.rearrange(color, "1 c h w -> (h w) c").cpu().numpy()
      fig = go.Figure(
          data=[
              go.Scatter3d(
-                 x=-point[0],
-                 y=-point[1],
-                 z=point[2],
+                 x=-point[..., 0],
+                 y=-point[..., 1],
+                 z=point[..., 2],
                  mode="markers",
-                 marker=dict(
-                     size=1,
-                     color=point[2],
-                     colorscale="viridis",
-                     autocolorscale=False,
-                     cauto=False,
-                     cmin=-2,
-                     cmax=0.5,
-                 ),
-                 # text=[
-                 #     f"depth: {float(d):.2f}m<br>"
-                 #     + f"reflectance: {float(r):.2f}<br>"
-                 #     + f"elevation: {float(e):.2f}°<br>"
-                 #     + f"azimuth: {float(a):.2f}°"
-                 #     for d, r, e, a in zip(
-                 #         einops.rearrange(depth, "1 1 h w -> (h w)"),
-                 #         einops.rearrange(rflct, "1 1 h w -> (h w)"),
-                 #         einops.rearrange(angle[0, 0], "h w -> (h w)"),
-                 #         einops.rearrange(angle[0, 1], "h w -> (h w)"),
-                 #     )
-                 # ],
-                 # hoverinfo="text",
+                 marker=dict(size=1, color=color),
              )
          ],
          layout=dict(
@@ -85,71 +140,50 @@
          ),
      )
      depth = depth / lidar_utils.max_depth
-     depth = colorize(depth, cmap)[0].permute(1, 2, 0).cpu().numpy()
-     rflct = colorize(rflct, cmap)[0].permute(1, 2, 0).cpu().numpy()
-     return depth, rflct, fig
-
+     depth = colorize(depth, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()
+     rflct = colorize(rflct, cm.turbo)[0].permute(1, 2, 0).cpu().numpy()

- def generate(num_steps, cmap_name, progress=gr.Progress()):
-     num_steps = int(num_steps)
-     x = ddpm.randn(1, *ddpm.sampling_shape, device=ddpm.device)
-     steps = torch.linspace(1.0, 0.0, num_steps + 1, device=ddpm.device)[None]
-     for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
-         step_t = steps[:, i]
-         step_s = steps[:, i + 1]
-         x = ddpm.p_step(x, step_t, step_s)
-     return render_point_cloud(x, plt.colormaps.get_cmap(cmap_name))
+     model.cpu()
+     lidar_utils.cpu()
+     return depth, rflct, fig


- with gr.Blocks() as demo:
-     gr.Markdown(
-         """
-         # R2DM
-         > **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
-         Kazuto Nakashima, Ryo Kurazume<br>
-         ICRA 2024<br>
-         [[Project]](https://kazuto1011.github.io/r2dm/) [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
-
-         R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
-         """
-     )
-     with gr.Row():
+ with gr.Blocks(css="./style.css", theme=THEME) as demo:
+     gr.HTML(DESCRIPTION)
+
+     with gr.Row(variant="panel"):
          with gr.Column():
-             gr.Textbox(device, label="Device")
-             num_steps = gr.Dropdown(
-                 choices=[2**i for i in range(2, 10)],
-                 value=16,
-                 label="number of sampling steps (>256 is recommended)",
+             gr.Textbox(device, label="Running device")
+             dataset = gr.Dropdown(
+                 choices=list(model_dict.keys()),
+                 value=list(model_dict.keys())[0],
+                 label="Dataset",
              )
-             cmap_name = gr.Dropdown(
-                 choices=plt.colormaps(),
-                 value="turbo",
-                 label="colormap for range/reflectance images",
+             sampling_mode = gr.Dropdown(
+                 choices=["DDPM", "DDIM"],
+                 value="DDPM",
+                 label="Sampler",
              )
-             btn = gr.Button(value="Generate random samples")
+             num_steps = gr.Dropdown(
+                 choices=[2**i for i in range(5, 11)],
+                 value=32,
+                 label="Number of sampling steps (>256 is recommended)",
+             )
+             btn = gr.Button(value="Generate")

          with gr.Column():
-             range_view = gr.Image(
-                 type="numpy",
-                 image_mode="RGB",
-                 label="Range image",
-                 scale=1,
-             )
-             rflct_view = gr.Image(
-                 type="numpy",
-                 image_mode="RGB",
-                 label="Reflectance image",
-                 scale=1,
-             )
-             point_view = gr.Plot(
-                 label="Point cloud",
-                 scale=1,
-             )
+             range_view = gr.Image(type="numpy", label="Range image")
+             rflct_view = gr.Image(type="numpy", label="Reflectance image")
+             point_view = gr.Plot(label="Point cloud")
+
+     with gr.Row(variant="panel"):
+         gr.Markdown(RUN_LOCALLY)

      btn.click(
          generate,
-         inputs=[num_steps, cmap_name],
+         inputs=[num_steps, sampling_mode, dataset],
          outputs=[range_view, rflct_view, point_view],
+         concurrency_limit=1,
      )
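For reference, the new sampling path in `generate()` can also be exercised without the Gradio UI. Below is a minimal headless sketch, assuming the `torch.hub` entry point returns `(model, lidar_utils, ...)` exactly as app.py unpacks it (the third value is unused there); every call shown appears in the diff above.

```python
import torch

torch.set_grad_enabled(False)

# Load one of the two pretrained configs added in this commit.
model, lidar_utils, _ = torch.hub.load(
    "kazuto1011/r2dm",
    "pretrained_r2dm",
    config="r2dm-h-kitti360-300k",  # or "r2dm-h-kittiraw-300k"
    device="cpu",
    show_info=False,
)

# Reverse diffusion from t=1.0 to t=0.0, mirroring the loop in generate().
num_steps = 256  # the UI recommends >256 steps
x = model.randn(1, *model.sampling_shape, device=model.device)
steps = torch.linspace(1.0, 0.0, num_steps + 1, device=model.device)[None]
for i in range(num_steps):
    x = model.p_step(x, steps[:, i], steps[:, i + 1], mode="ddpm")  # or "ddim"

# Decode the normalized range/reflectance image into metric depth and xyz points.
x = lidar_utils.denormalize(x.clamp(-1, 1))
depth = lidar_utils.revert_depth(x[:, [0]])  # (1, 1, H, W), meters
points = lidar_utils.to_xyz(depth)  # (1, 3, H, W)
```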
rendering.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ def estimate_surface_normal(
+     points: torch.Tensor, d: int = 2, mode: str = "closest"
+ ) -> torch.Tensor:
+     # estimate surface normals from an organized (image-structured) point cloud
+     # a PyTorch re-implementation of:
+     # https://github.com/wkentaro/morefusion/blob/master/morefusion/geometry/estimate_pointcloud_normals.py
+     # https://github.com/jmccormac/pySceneNetRGBD/blob/master/calculate_surface_normals.py
+
+     assert points.dim() == 4, f"expected (B,3,H,W), but got {points.shape}"
+     B, C, H, W = points.shape
+     assert C == 3, f"expected C==3, but got {C}"
+     device = points.device
+
+     # points = F.pad(points, (0, 0, d, d), mode="constant", value=float("inf"))
+     points = F.pad(points, (0, 0, d, d), mode="replicate")
+     points = F.pad(points, (d, d, 0, 0), mode="circular")
+     points = points.permute(0, 2, 3, 1)  # (B,H,W,3)
+
+     # 8 adjacent offsets
+     # -------------
+     # | 7 | 6 | 5 |
+     # -------------
+     # | 0 |   | 4 |
+     # -------------
+     # | 1 | 2 | 3 |
+     # -------------
+     offsets = torch.tensor(
+         [
+             # (dh, dw)
+             (-d, 0),  # 0
+             (-d, d),  # 1
+             (0, d),  # 2
+             (d, d),  # 3
+             (d, 0),  # 4
+             (d, -d),  # 5
+             (0, -d),  # 6
+             (-d, -d),  # 7
+         ],
+         device=device,
+     )
+
+     # (B,H,W) indices
+     b = torch.arange(B, device=device)[:, None, None]
+     h = torch.arange(H, device=device)[None, :, None]
+     w = torch.arange(W, device=device)[None, None, :]
+     k = torch.arange(8, device=device)
+
+     # anchor points
+     b1 = b[:, None]  # (B,1,1,1)
+     h1 = h[:, None] + d  # (1,1,H,1)
+     w1 = w[:, None] + d  # (1,1,1,W)
+     anchors = points[b1, h1, w1]  # (B,H,W,3) -> (B,1,H,W,3)
+
+     # neighbor points
+     offset = offsets[k]  # (8,2)
+     b2 = b1
+     h2 = h1 + offset[None, :, 0, None, None]  # (1,8,H,1)
+     w2 = w1 + offset[None, :, 1, None, None]  # (1,8,1,W)
+     points1 = points[b2, h2, w2]  # (B,8,H,W,3)
+
+     # another set of neighbor points
+     offset = offsets[(k + 2) % 8]
+     b3 = b1
+     h3 = h1 + offset[None, :, 0, None, None]
+     w3 = w1 + offset[None, :, 1, None, None]
+     points2 = points[b3, h3, w3]  # (B,8,H,W,3)
+
+     if mode == "closest":
+         # find the closest neighbor pair
+         diff = torch.norm(points1 - anchors, dim=4)
+         diff = diff + torch.norm(points2 - anchors, dim=4)
+         i = torch.argmin(diff, dim=1)  # (B,H,W)
+         # get normals by cross product
+         anchors = anchors[b, 0, h, w]  # (B,H,W,3)
+         points1 = points1[b, i, h, w]  # (B,H,W,3)
+         points2 = points2[b, i, h, w]  # (B,H,W,3)
+         vector1 = points1 - anchors
+         vector2 = points2 - anchors
+         normals = torch.cross(vector1, vector2, dim=-1)  # (B,H,W,3)
+     elif mode == "mean":
+         # get normals by cross product
+         vector1 = points1 - anchors
+         vector2 = points2 - anchors
+         normals = torch.cross(vector1, vector2, dim=-1)  # (B,8,H,W,3)
+         normals = normals.mean(dim=1)  # (B,H,W,3)
+     else:
+         raise NotImplementedError(mode)
+
+     normals = normals / (torch.norm(normals, dim=3, keepdim=True) + 1e-8)
+     normals = normals.permute(0, 3, 1, 2)  # (B,3,H,W)
+
+     return normals
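As a quick sanity check of the helper's contract (input `(B, 3, H, W)`, unit-normal output of the same layout), here is a minimal usage sketch; the 64x512 grid is just an arbitrary example resolution, and the last line mirrors how app.py maps normals to Plotly marker colors.

```python
import torch

from rendering import estimate_surface_normal

# Toy organized point map: batch of 1, xyz channels, on a 64x512 range-image grid.
points = torch.randn(1, 3, 64, 512)

normals = estimate_surface_normal(points, d=2, mode="closest")
assert normals.shape == (1, 3, 64, 512)  # unit normals, (B, 3, H, W)

# app.py turns normals in [-1, 1] into RGB colors in [0, 1] for the point cloud.
color = (-normals + 1) / 2
```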
requirements.txt CHANGED
@@ -1,5 +1,4 @@
  einops
- kornia
  matplotlib
  numpy
  torch
style.css ADDED
@@ -0,0 +1,15 @@
+ .head {
+     text-align: center;
+     display: block;
+     font-size: var(--text-xl);
+ }
+
+ .title {
+     font-size: var(--text-xxl);
+     font-weight: bold;
+     margin-top: 2rem;
+ }
+
+ .description {
+     font-size: var(--text-lg);
+ }