kazuto1011 committed on
Commit 1ee5104 • 1 Parent(s): 95e4642
Files changed (3)
  1. README.md +2 -2
  2. app.py +76 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
- title: R2dm
- emoji: 📈
+ title: R2DM
+ emoji: 🚗
  colorFrom: indigo
  colorTo: green
  sdk: gradio
app.py ADDED
@@ -0,0 +1,76 @@
+ import gradio as gr
+ import matplotlib.cm as cm
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ torch.set_grad_enabled(False)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ ddpm, lidar_utils, _ = torch.hub.load(
+     "kazuto1011/r2dm", "pretrained_r2dm", device=device
+ )
+
+
+ def colorize(tensor, cmap_fn=cm.turbo):
+     colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
+     colors = torch.from_numpy(colors).to(tensor)
+     tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
+     ids = (tensor * 256).clamp(0, 255).long()
+     tensor = F.embedding(ids, colors).permute(0, 3, 1, 2)
+     tensor = tensor.mul(255).clamp(0, 255).byte()
+     return tensor
+
+
+ @torch.no_grad()
+ def generate(num_steps) -> tuple[np.ndarray, np.ndarray]:
+     output = ddpm.sample(batch_size=1, num_steps=int(num_steps))
+     output = lidar_utils.denormalize(output.clamp(-1, 1))
+     range_image = lidar_utils.revert_depth(output[:, [0]])
+     range_image = (range_image / lidar_utils.max_depth).clamp(0, 1)
+     reflectance_image = output[:, [1]]
+     range_image = colorize(range_image)[0].permute(1, 2, 0)
+     reflectance_image = colorize(reflectance_image)[0].permute(1, 2, 0)
+     return range_image.cpu().numpy(), reflectance_image.cpu().numpy()
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """
+         # R2DM Demo
+         **LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
+         Kazuto Nakashima, Ryo Kurazume<br>
+         [[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
+         """
+     )
+     with gr.Row():
+         with gr.Column():
+             gr.Text(f"Device: {device}", label="device")
+             num_steps = gr.Dropdown(
+                 choices=[2**i for i in range(3, 11)],
+                 value=8,
+                 label="number of sampling steps",
+             )
+             btn = gr.Button(value="Generate random samples")
+
+         with gr.Column():
+             range_view = gr.Image(
+                 type="numpy",
+                 image_mode="RGB",
+                 label="Range image",
+                 scale=1,
+             )
+             rflct_view = gr.Image(
+                 type="numpy",
+                 image_mode="RGB",
+                 label="Reflectance image",
+                 scale=1,
+             )
+
+     btn.click(
+         generate,
+         inputs=[num_steps],
+         outputs=[range_view, rflct_view],
+     )
+
+
+ demo.launch()
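For reference, the sampling step can also be driven outside the Gradio UI. The sketch below is not part of the commit; it reuses only calls that appear in app.py above (`torch.hub.load("kazuto1011/r2dm", "pretrained_r2dm", ...)`, `ddpm.sample`, `lidar_utils.denormalize`, `lidar_utils.revert_depth`, `lidar_utils.max_depth`), while the choice of 32 steps, the output filename, and the matplotlib save call are illustrative assumptions.

```python
# Standalone sampling sketch (assumption: same torch.hub entry point as app.py).
import matplotlib.pyplot as plt
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ddpm, lidar_utils, _ = torch.hub.load(
    "kazuto1011/r2dm", "pretrained_r2dm", device=device
)

with torch.no_grad():
    # Draw one sample with 32 denoising steps (any of the demo's power-of-two choices works).
    x = ddpm.sample(batch_size=1, num_steps=32)
    x = lidar_utils.denormalize(x.clamp(-1, 1))
    depth = lidar_utils.revert_depth(x[:, [0]]) / lidar_utils.max_depth

# Save the normalized range image to disk; the colormap choice is illustrative.
plt.imsave("range_image.png", depth[0, 0].clamp(0, 1).cpu().numpy(), cmap="turbo")
```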
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ einops
+ kornia
+ matplotlib
+ numpy
+ torch
+ torchvision
+ tqdm
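Note that `gradio` itself is not listed; on Hugging Face Spaces it is presumably provided via the `sdk: gradio` setting in the README front matter, so running the demo locally would additionally require installing Gradio before launching `app.py`.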