# LiDAR-Diffusion / app.py: Gradio demo for camera-to-LiDAR generation.
# Source: Hugging Face Space by Hancy, commit 851751e ("init"), 2.59 kB.
import gradio as gr
import spaces
import tempfile
import os
import torch
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from app_config import CSS, TITLE, DESCRIPTION, DEVICE
import sample_cond
# Load the LiDAR-Diffusion model once at import time; every request handled by
# on_submit (and the cached examples) reuses this single instance.
model = sample_cond.load_model()
def create_custom_colormap():
    """Build the 256-entry colormap used to render range images.

    Colour stops (position -> RGB), linearly interpolated by matplotlib:
    0.00 green, 0.38 cyan, 0.60 blue, 0.70 magenta, 1.00 yellow.
    """
    stops = [
        (0, (0, 1, 0)),     # green
        (0.38, (0, 1, 1)),  # cyan
        (0.6, (0, 0, 1)),   # blue
        (0.7, (1, 0, 1)),   # magenta
        (1, (1, 1, 0)),     # yellow
    ]
    return LinearSegmentedColormap.from_list('custom_colormap', stops, N=256)
def colorize_depth(depth, log_scale):
    """Render a 2-D range image as an RGB image using the custom colormap.

    Args:
        depth: 2-D range image. In the log-scale branch it is assumed to hold
            values in [0, 255] (it is divided by 255); values outside that
            range are clipped first so the uint8 cast cannot wrap around.
        log_scale: if true, remap depth logarithmically and quantise it to
            uint8 before colouring.

    Returns:
        H x W x 3 float RGB array (alpha dropped). Pixels whose (remapped)
        depth is exactly 0 are painted black.
    """
    if log_scale:
        # Without the clip, out-of-range values silently wrap in astype(uint8).
        # Integer bounds keep depth's dtype, so in-range inputs are unchanged.
        depth = np.clip(depth, 0, 255)
        # log2(56 + 1) ~= 5.83, so dividing by 5.84 maps [0, 255] back onto
        # [0, ~255) while stretching the low (near) end of the range.
        depth = ((np.log2((depth / 255.) * 56. + 1) / 5.84) * 255.).astype(np.uint8)
    # Zero-valued pixels are blacked out after colouring.
    mask = depth == 0
    colormap = create_custom_colormap()
    # Integer input indexes the 256-entry lookup table directly; drop alpha.
    rgb = colormap(depth)[:, :, :3]
    rgb[mask] = 0.
    return rgb
@spaces.GPU
@torch.no_grad()
def generate_lidar(model, cond):
    """Sample a LiDAR scan from ``model`` conditioned on ``cond``.

    Runs with autograd disabled and is decorated with ``spaces.GPU`` so that
    Hugging Face Spaces can attach a GPU for the call.

    Returns:
        ``(range_image, point_cloud)`` as produced by ``sample_cond.sample``.
    """
    range_image, point_cloud = sample_cond.sample(model, cond)
    return range_image, point_cloud
def load_camera(image, split_per_view=4, device=None):
    """Turn a side-by-side multi-view camera image into a condition tensor.

    The image is split along its width into ``split_per_view`` equal-width
    chunks, each treated as a separate view.

    Args:
        image: H x W x C image (anything ``np.array`` accepts, e.g. the numpy
            array from ``gr.Image(type='numpy')``); pixel values are scaled
            by 1/255.
        split_per_view: number of views packed horizontally into ``image``.
            Defaults to 4, the previously hard-coded value.
        device: torch device for the result; ``None`` means ``DEVICE`` from
            app_config, as before.

    Returns:
        float32 tensor of shape (1, split_per_view, C, H, W // split_per_view).

    Raises:
        ValueError: if ``image`` is missing or not 3-D, or its width is not
            divisible by ``split_per_view``.
    """
    if image is None:
        raise ValueError("No camera image was provided.")
    camera = np.array(image).astype(np.float32) / 255.
    if camera.ndim != 3:
        raise ValueError(f"Expected an H x W x C image, got shape {camera.shape}.")
    width = camera.shape[1]
    if split_per_view < 1 or width % split_per_view != 0:
        raise ValueError(
            f"Image width {width} cannot be split into {split_per_view} equal views.")
    if device is None:
        device = DEVICE
    camera = camera.transpose(2, 0, 1)  # H x W x C -> C x H x W
    # Split along the width axis: each chunk is one camera view.
    views = np.split(camera, split_per_view, axis=2)
    stacked = np.stack(views, axis=0)  # (views, C, H, W // views)
    return torch.from_numpy(stacked).unsqueeze(0).to(device)
with gr.Blocks(css=CSS) as demo:
    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)
    gr.Markdown("### Camera-to-LiDAR Demo")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        output_image = gr.Image(label="Range Map", elem_id='img-display-output')
    raw_file = gr.File(label="Point Cloud (.txt file). Can be viewed through Meshlab")
    submit = gr.Button("Submit")

    def on_submit(image):
        """Generate a LiDAR scan from one uploaded multi-view camera image.

        Returns ``[colorized range map, path to the saved point-cloud .txt]``,
        matching ``outputs=[output_image, raw_file]``.

        Raises:
            gr.Error: if no image was uploaded or it cannot be converted into
                views, so the message is shown in the UI instead of a traceback.
        """
        if image is None:
            raise gr.Error("Please upload a camera image before submitting.")
        try:
            cond = load_camera(image)
        except ValueError as err:
            raise gr.Error(f"Invalid camera image: {err}") from err
        img, pcd = generate_lidar(model, cond)
        # Only reserve a unique path here and close the handle immediately:
        # pcd.save() reopens the file by name, which Windows does not allow
        # while the NamedTemporaryFile is still open, and an unclosed handle
        # would otherwise linger until garbage collection.
        with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as tmp:
            pcd_path = tmp.name
        pcd.save(pcd_path)
        rgb_img = colorize_depth(img, log_scale=True)
        return [rgb_img, pcd_path]

    submit.click(on_submit, inputs=[input_image], outputs=[output_image, raw_file])

    example_files = sorted(os.listdir('cam_examples'))
    example_files = [os.path.join('cam_examples', filename) for filename in example_files]
    # cache_examples=True has Gradio precompute outputs by calling on_submit
    # on each example file.
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[output_image, raw_file],
                           fn=on_submit, cache_examples=True)
if __name__ == '__main__':
    # Enable Gradio's request queue on the demo, then start the web server.
    app = demo.queue()
    app.launch()