# r2dm / app.py
# Author: kazuto1011, revision 1ee5104 ("update"), 2.4 kB
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import torch
import torch.nn.functional as F
# This demo only runs inference, so autograd is switched off for the whole process.
torch.set_grad_enabled(False)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Download/load the pretrained R2DM model and its LiDAR helper via torch.hub.
# The third value returned by the hub entry point is not used by this app.
ddpm, lidar_utils, _ = torch.hub.load(
    "kazuto1011/r2dm",
    "pretrained_r2dm",
    device=device,
)
def colorize(tensor, cmap_fn=cm.turbo):
    """Render a batch of single-channel images through a colormap.

    Args:
        tensor: image batch shaped (B, 1, H, W) or (B, H, W). Values in [0, 1]
            span the colormap; anything outside is clamped to the end colors.
        cmap_fn: matplotlib colormap callable (default: ``cm.turbo``).

    Returns:
        A uint8 tensor shaped (B, 3, H, W) on the same device as ``tensor``.
    """
    # 256-entry RGB lookup table (alpha dropped), cast to the input's dtype/device.
    lut = torch.from_numpy(cmap_fn(np.linspace(0, 1, 256))[:, :3]).to(tensor)

    if tensor.ndim == 4:
        tensor = tensor.squeeze(1)  # drop the singleton channel axis

    # Quantize intensities into table indices 0..255.
    indices = torch.clamp(tensor * 256, 0, 255).long()

    rgb = F.embedding(indices, lut)  # (B, H, W, 3)
    rgb = rgb.permute(0, 3, 1, 2)  # channels-first: (B, 3, H, W)
    return torch.clamp(rgb * 255, 0, 255).to(torch.uint8)
@torch.no_grad()
def generate(num_steps) -> "tuple[np.ndarray, np.ndarray]":
    """Sample one LiDAR scan with the pretrained R2DM model and colorize it.

    Args:
        num_steps: number of sampling steps passed to ``ddpm.sample``; it is
            converted with ``int()`` so string values from the UI also work.

    Returns:
        ``(range_rgb, reflectance_rgb)``: two (H, W, 3) uint8 NumPy arrays
        produced by ``colorize``, ready for ``gr.Image(type="numpy")``.
    """
    sample = ddpm.sample(batch_size=1, num_steps=int(num_steps))
    # Clamp to [-1, 1] before denormalizing; presumably the model works in that
    # normalized range (TODO confirm against lidar_utils.denormalize).
    sample = lidar_utils.denormalize(sample.clamp(-1, 1))

    # Channel 0 is converted back to depth and scaled by max_depth into [0, 1].
    depth = lidar_utils.revert_depth(sample[:, [0]])
    depth = (depth / lidar_utils.max_depth).clamp(0, 1)
    # Channel 1 is used directly as reflectance; colorize clamps it into range.
    reflectance = sample[:, [1]]

    # colorize -> (B, 3, H, W) uint8; take the single sample, move to HWC.
    range_rgb = colorize(depth)[0].permute(1, 2, 0)
    reflectance_rgb = colorize(reflectance)[0].permute(1, 2, 0)
    return range_rgb.cpu().numpy(), reflectance_rgb.cpu().numpy()
# Markdown header rendered at the top of the demo page.
_HEADER = """
# R2DM Demo
**LiDAR Data Synthesis with Denoising Diffusion Probabilistic Models**<br>
Kazuto Nakashima, Ryo Kurazume<br>
[[arXiv]](https://arxiv.org/abs/2309.09256) [[Code]](https://github.com/kazuto1011/r2dm)
"""

# Options shared by both output panels: RGB NumPy images with equal layout weight.
_IMAGE_OPTS = dict(type="numpy", image_mode="RGB", scale=1)

with gr.Blocks() as demo:
    gr.Markdown(_HEADER)
    with gr.Row():
        # First column: device info and sampling controls.
        with gr.Column():
            gr.Text(f"Device: {device}", label="device")
            steps_dropdown = gr.Dropdown(
                choices=[1 << k for k in range(3, 11)],  # 8, 16, ..., 1024
                value=8,
                label="number of sampling steps",
            )
            generate_btn = gr.Button(value="Generate random samples")
        # Second column: the generated range and reflectance images.
        with gr.Column():
            range_output = gr.Image(label="Range image", **_IMAGE_OPTS)
            reflectance_output = gr.Image(label="Reflectance image", **_IMAGE_OPTS)
    # Clicking the button calls `generate` with the chosen step count and
    # fills both image panels with its two returned arrays.
    generate_btn.click(
        generate,
        inputs=[steps_dropdown],
        outputs=[range_output, reflectance_output],
    )

demo.launch()