Spaces:
Runtime error
Runtime error
import os | |
import tyro | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from core.options import AllConfigs, Options | |
from core.gs import GaussianRenderer | |
import dearpygui.dearpygui as dpg | |
import kiui | |
from kiui.cam import OrbitCamera | |
class GUI:
    """Interactive Dear PyGui viewer for gaussians loaded from a ``.ply`` file.

    The gaussians are loaded once from ``opt.test_path`` through
    ``GaussianRenderer.load_ply`` and re-rendered from a ``kiui`` orbit camera
    whenever the view or a rendering option changes (``need_update``).
    """

    def __init__(self, opt: "Options"):
        """Set up camera, renderer and Dear PyGui context, then draw one frame.

        Args:
            opt: parsed options. This class reads ``output_size``,
                ``cam_radius``, ``fovy``, ``znear``, ``zfar`` and
                ``test_path`` directly; the whole object is also handed to
                ``GaussianRenderer``.
        """
        self.opt = opt
        self.W = opt.output_size
        self.H = opt.output_size
        self.cam = OrbitCamera(self.W, self.H, r=opt.cam_radius, fovy=opt.fovy)

        self.device = torch.device("cuda")
        # The projection is built once from opt.fovy and never rebuilt.
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
        self.proj_matrix = self._build_proj_matrix(
            self.tan_half_fov, opt.znear, opt.zfar, self.device
        )

        self.mode = "image"
        # Row-major (height, width, channels), matching the rendered buffer
        # produced in test_step and the raw texture registered in register_dpg.
        self.buffer_image = np.ones((self.H, self.W, 3), dtype=np.float32)
        self.need_update = True  # re-render buffer_image on the next test_step

        # renderer
        self.renderer = GaussianRenderer(opt)
        self.gaussain_scale_factor = 1
        self.gaussians = self.renderer.load_ply(opt.test_path).to(self.device)

        dpg.create_context()
        # Recorded so __del__ only destroys a context that actually exists.
        self._dpg_context_created = True
        self.register_dpg()
        self.test_step()

    def __del__(self):
        """Destroy the Dear PyGui context, if this instance created one."""
        # __del__ also runs when __init__ raised part-way (e.g. load_ply failed
        # before dpg.create_context()); destroying a context that was never
        # created must be avoided. Clearing the flag first makes repeat calls
        # harmless.
        if getattr(self, "_dpg_context_created", False):
            self._dpg_context_created = False
            dpg.destroy_context()

    @staticmethod
    def _build_proj_matrix(tan_half_fov, znear, zfar, device):
        """Return the float32 4x4 projection matrix used in ``test_step``.

        The perspective terms sit at [2, 3] / [3, 2], i.e. the matrix is laid
        out transposed, consistent with ``test_step`` right-multiplying it onto
        the transposed view matrix (``cam_view @ proj_matrix``).

        Args:
            tan_half_fov: tangent of half the vertical field of view.
            znear: near clipping distance.
            zfar: far clipping distance.
            device: torch device (or device string) for the result.
        """
        proj = torch.zeros(4, 4, dtype=torch.float32, device=device)
        proj[0, 0] = 1 / tan_half_fov
        proj[1, 1] = 1 / tan_half_fov
        proj[2, 2] = (zfar + znear) / (zfar - znear)
        proj[3, 2] = -(zfar * znear) / (zfar - znear)
        proj[2, 3] = 1
        return proj

    @staticmethod
    def _format_infer_time(t_ms):
        """Format a frame time in milliseconds as ``"<ms>ms (<fps> FPS)"``.

        A non-positive time (timer resolution too coarse) is shown as
        ``inf`` FPS instead of raising ZeroDivisionError.
        """
        fps = int(1000 / t_ms) if t_ms > 0 else "inf"
        return f"{t_ms:.4f}ms ({fps} FPS)"

    def test_step(self):
        """Re-render the gaussians if needed and upload the frame to the GUI.

        Does nothing unless ``need_update`` is set. Otherwise renders from the
        current orbit-camera pose in the selected ``mode``, stores the result
        in ``buffer_image``, updates the infer-time label and pushes the frame
        to the ``"_texture"`` raw texture.
        """
        # ignore if no need to update
        if not self.need_update:
            return

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        cam_poses = torch.from_numpy(self.cam.pose).unsqueeze(0).to(self.device)
        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
        cam_pos = -cam_poses[:, :3, 3]  # [V, 3]

        buffer_image = self.renderer.render(
            self.gaussians.unsqueeze(0),
            cam_view.unsqueeze(0),
            cam_view_proj.unsqueeze(0),
            cam_pos.unsqueeze(0),
            scale_modifier=self.gaussain_scale_factor,
        )[self.mode]
        buffer_image = buffer_image.squeeze(1)  # [B, C, H, W]
        if self.mode == "alpha":
            # Expand alpha (presumably single-channel) to 3 channels so it fits
            # the RGB texture.
            buffer_image = buffer_image.repeat(1, 3, 1, 1)
        buffer_image = F.interpolate(
            buffer_image,
            size=(self.H, self.W),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

        # [C, H, W] -> [H, W, C]. The buffer handed to dpg must be contiguous,
        # else seg fault!
        self.buffer_image = (
            buffer_image.permute(1, 2, 0)
            .clamp(0, 1)
            .contiguous()
            .detach()
            .cpu()
            .numpy()
        )
        self.need_update = False

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        dpg.set_value("_log_infer_time", self._format_infer_time(t))
        dpg.set_value("_texture", self.buffer_image)

    def register_dpg(self):
        """Create the texture, windows, widgets, mouse handlers and viewport."""
        ### register texture
        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(
                self.W,
                self.H,
                self.buffer_image,
                format=dpg.mvFormat_Float_rgb,
                tag="_texture",
            )

        ### register window
        # the rendered image, as the primary window
        with dpg.window(
            tag="_primary_window",
            width=self.W,
            height=self.H,
            pos=[0, 0],
            no_move=True,
            no_title_bar=True,
            no_scrollbar=True,
        ):
            # add the texture
            dpg.add_image("_texture")

        # control window, placed to the right of the rendered image
        with dpg.window(
            label="Control",
            tag="_control_window",
            width=600,
            height=self.H,
            pos=[self.W, 0],
            no_move=True,
            no_title_bar=True,
        ):
            # timer stuff
            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            # rendering options
            with dpg.collapsing_header(label="Rendering", default_open=True):
                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(
                    ("image", "alpha"),
                    label="mode",
                    default_value=self.mode,
                    callback=callback_change_mode,
                )

                # fov slider
                def callback_set_fovy(sender, app_data):
                    # NOTE(review): this only updates self.cam.fovy, but
                    # test_step projects with self.proj_matrix, which is built
                    # once from opt.fovy in __init__, so the slider does not
                    # appear to change the rendered view. GaussianRenderer may
                    # also hold its own fov; confirm before wiring it through.
                    self.cam.fovy = np.deg2rad(app_data)
                    self.need_update = True

                dpg.add_slider_int(
                    label="FoV (vertical)",
                    min_value=1,
                    max_value=120,
                    format="%d deg",
                    # cam.fovy is kept in radians; an int slider needs an int.
                    default_value=int(round(np.rad2deg(self.cam.fovy))),
                    callback=callback_set_fovy,
                )

                def callback_set_gaussain_scale(sender, app_data):
                    self.gaussain_scale_factor = app_data
                    self.need_update = True

                dpg.add_slider_float(
                    label="gaussain scale",
                    min_value=0,
                    max_value=1,
                    format="%.2f",
                    default_value=self.gaussain_scale_factor,
                    callback=callback_set_gaussain_scale,
                )

        ### register camera handler
        # Each handler ignores input unless the render window is focused, so
        # dragging widgets in the control window does not move the camera.
        def callback_camera_drag_rotate(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.orbit(dx, dy)
            self.need_update = True

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            delta = app_data
            self.cam.scale(delta)
            self.need_update = True

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.pan(dx, dy)
            self.need_update = True

        with dpg.handler_registry():
            # for camera moving
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Left,
                callback=callback_camera_drag_rotate,
            )
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
            )

        dpg.create_viewport(
            title="Gaussian3D",
            width=self.W + 600,
            # extra 45 px on Windows (os.name == "nt"), presumably for the
            # native title bar
            height=self.H + (45 if os.name == "nt" else 0),
            resizable=False,
        )

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(
                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
                )

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        ### register a larger font
        # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
        # (looked up relative to the current working directory)
        if os.path.exists("LXGWWenKai-Regular.ttf"):
            with dpg.font_registry():
                with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
                    dpg.bind_font(default_font)

        dpg.show_viewport()

    def render(self):
        """Run the Dear PyGui event loop until the viewport is closed."""
        while dpg.is_dearpygui_running():
            # re-renders only when need_update is set (checked by test_step)
            self.test_step()
            dpg.render_dearpygui_frame()
if __name__ == "__main__":
    opt = tyro.cli(AllConfigs)

    # load a saved ply and visualize
    # Explicit check rather than `assert`, so it still runs under `python -O`
    # and an unset / non-string --test_path gets this message instead of an
    # AttributeError from `.endswith`.
    if opt.test_path is None or not str(opt.test_path).endswith('.ply'):
        raise SystemExit('--test_path must be a .ply file saved by infer.py')

    gui = GUI(opt)
    gui.render()