Kazuto Nakashima commited on
Commit
365f709
1 Parent(s): ffb8a4e
Files changed (1) hide show
  1. app.py +34 -22
app.py CHANGED
@@ -1,16 +1,26 @@
 
1
  import gradio as gr
2
  import matplotlib.cm as cm
3
  import matplotlib.pyplot as plt
4
  import numpy as np
 
5
  import torch
6
  import torch.nn.functional as F
7
- import einops
8
- import plotly.graph_objects as go
 
 
 
 
 
9
 
10
  torch.set_grad_enabled(False)
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
12
  ddpm, lidar_utils, _ = torch.hub.load(
13
- "kazuto1011/r2dm", "pretrained_r2dm", device=device
 
 
14
  )
15
 
16
 
@@ -30,19 +40,7 @@ def render_point_cloud(output, cmap):
30
  rflct = output[:, [1]]
31
  point = lidar_utils.to_xyz(depth).cpu().numpy()
32
  point = einops.rearrange(point, "1 c h w -> c (h w)")
33
- angle = lidar_utils.ray_angles.rad2deg()
34
- label = [
35
- f"depth: {float(d):.2f}m<br>"
36
- + f"reflectance: {float(r):.2f}<br>"
37
- + f"elevation: {float(e):.2f}°<br>"
38
- + f"azimuth: {float(a):.2f}°"
39
- for d, r, e, a in zip(
40
- einops.rearrange(depth, "1 1 h w -> (h w)"),
41
- einops.rearrange(rflct, "1 1 h w -> (h w)"),
42
- einops.rearrange(angle[0, 0], "h w -> (h w)"),
43
- einops.rearrange(angle[0, 1], "h w -> (h w)"),
44
- )
45
- ]
46
  fig = go.Figure(
47
  data=[
48
  go.Scatter3d(
@@ -59,8 +57,19 @@ def render_point_cloud(output, cmap):
59
  cmin=-2,
60
  cmax=0.5,
61
  ),
62
- text=label,
63
- hoverinfo="text",
 
 
 
 
 
 
 
 
 
 
 
64
  )
65
  ],
66
  layout=dict(
@@ -88,7 +97,7 @@ def generate(num_steps, cmap_name, progress=gr.Progress()):
88
  for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
89
  step_t = steps[:, i]
90
  step_s = steps[:, i + 1]
91
- x = ddpm.p_sample(x, step_t, step_s)
92
  return render_point_cloud(x, plt.colormaps.get_cmap(cmap_name))
93
 
94
 
@@ -96,14 +105,17 @@ with gr.Blocks() as demo:
96
  gr.Markdown(
97
  """
98
  # R2DM
99
- R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
100
  > **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
101
  Kazuto Nakashima, Ryo Kurazume<br>
102
- [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
 
 
 
103
  """
104
  )
105
  with gr.Row():
106
  with gr.Column():
 
107
  num_steps = gr.Dropdown(
108
  choices=[2**i for i in range(2, 10)],
109
  value=16,
 
1
+ import einops
2
  import gradio as gr
3
  import matplotlib.cm as cm
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
+ import plotly.graph_objects as go
7
  import torch
8
  import torch.nn.functional as F
9
+
10
+ if torch.cuda.is_available():
11
+ device = "cuda"
12
+ elif torch.backends.mps.is_available():
13
+ device = "mps"
14
+ else:
15
+ device = "cpu"
16
 
17
  torch.set_grad_enabled(False)
18
+ device = torch.device(device)
19
+
20
  ddpm, lidar_utils, _ = torch.hub.load(
21
+ "kazuto1011/r2dm",
22
+ "pretrained_r2dm",
23
+ device=device,
24
  )
25
 
26
 
 
40
  rflct = output[:, [1]]
41
  point = lidar_utils.to_xyz(depth).cpu().numpy()
42
  point = einops.rearrange(point, "1 c h w -> c (h w)")
43
+ # angle = lidar_utils.ray_angles.rad2deg()
 
 
 
 
 
 
 
 
 
 
 
 
44
  fig = go.Figure(
45
  data=[
46
  go.Scatter3d(
 
57
  cmin=-2,
58
  cmax=0.5,
59
  ),
60
+ # text=[
61
+ # f"depth: {float(d):.2f}m<br>"
62
+ # + f"reflectance: {float(r):.2f}<br>"
63
+ # + f"elevation: {float(e):.2f}°<br>"
64
+ # + f"azimuth: {float(a):.2f}°"
65
+ # for d, r, e, a in zip(
66
+ # einops.rearrange(depth, "1 1 h w -> (h w)"),
67
+ # einops.rearrange(rflct, "1 1 h w -> (h w)"),
68
+ # einops.rearrange(angle[0, 0], "h w -> (h w)"),
69
+ # einops.rearrange(angle[0, 1], "h w -> (h w)"),
70
+ # )
71
+ # ],
72
+ # hoverinfo="text",
73
  )
74
  ],
75
  layout=dict(
 
97
  for i in progress.tqdm(range(num_steps), desc="Generating LiDAR data"):
98
  step_t = steps[:, i]
99
  step_s = steps[:, i + 1]
100
+ x = ddpm.p_step(x, step_t, step_s)
101
  return render_point_cloud(x, plt.colormaps.get_cmap(cmap_name))
102
 
103
 
 
105
  gr.Markdown(
106
  """
107
  # R2DM
 
108
  > **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
109
  Kazuto Nakashima, Ryo Kurazume<br>
110
+ ICRA 2024<br>
111
+ [[Project]](https://kazuto1011.github.io/r2dm/) [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
112
+
113
+ R2DM is a denoising diffusion probabilistic model (DDPM) for LiDAR range/reflectance generation based on the equirectangular representation.
114
  """
115
  )
116
  with gr.Row():
117
  with gr.Column():
118
+ gr.Textbox(device, label="Device")
119
  num_steps = gr.Dropdown(
120
  choices=[2**i for i in range(2, 10)],
121
  value=16,