import math
import os

import torch
import torch.nn as nn
import numpy as np
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R

from nerf.utils import *


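# Minimal orbit camera: stores an orbit center, radius and rotation, and exposes
# the pose and pinhole intrinsics consumed by the renderer.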
class OrbitCamera:
    def __init__(self, W, H, r=2, fovy=60):
        self.W = W
        self.H = H
        self.radius = r  # distance from the orbit center
        self.fovy = fovy  # vertical field of view, in degrees
        self.center = np.array([0, 0, 0], dtype=np.float32)  # look-at point
        self.rot = R.from_quat([1, 0, 0, 0])  # scipy quaternions are [x, y, z, w]
        self.up = np.array([0, 1, 0], dtype=np.float32)  # world up axis

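    # 4x4 camera-to-world pose assembled from radius, rotation and orbit center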
    @property
    def pose(self):
        # start with the camera pulled back along z by `radius`
        res = np.eye(4, dtype=np.float32)
        res[2, 3] -= self.radius
        # apply the orbit rotation
        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res
        # shift relative to the orbit center
        res[:3, 3] -= self.center
        return res

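    # pinhole intrinsics as [fx, fy, cx, cy]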
    @property
    def intrinsics(self):
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2])

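    # mouse-drag orbiting: rotate around the world up axis (dx) and the camera's side axis (dy)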
    def orbit(self, dx, dy):
        side = self.rot.as_matrix()[:3, 0]  # camera x-axis in world space (already normalized)
        rotvec_x = self.up * np.deg2rad(-0.1 * dx)
        rotvec_y = side * np.deg2rad(-0.1 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

    def scale(self, delta):
        # zoom: each wheel step scales the orbit radius by a factor of 1.1
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        # pan the orbit center in the camera frame (note the small sensitivity factor)
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])


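# Dear PyGui viewer that couples an OrbitCamera with a NeRF trainer: it runs training
# steps, renders frames into a raw texture, and exposes training/rendering controls.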
class NeRFGUI:
    def __init__(self, opt, trainer, debug=True):
        self.opt = opt  # options namespace; the sliders below modify it in place
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.debug = debug
        self.bg_color = torch.ones(3, dtype=torch.float32)  # default white background
        self.training = False
        self.step = 0  # training step counter

        self.trainer = trainer
        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # camera moved: re-render and restart accumulation
        self.spp = 1  # samples per pixel accumulated so far
        self.light_dir = np.array([opt.light_theta, opt.light_phi])
        self.ambient_ratio = 1.0
        self.mode = 'image'  # choose from ['image', 'depth']
        self.shading = 'albedo'

        self.dynamic_resolution = True
        self.downscale = 1
        self.train_steps = 16  # training iterations per GUI frame (adapted at runtime)

        dpg.create_context()
        self.register_dpg()
        self.test_step()

    def __del__(self):
        dpg.destroy_context()

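    # run a few training iterations and report timing / loss to the UI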
    def train_step(self):
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()

        outputs = self.trainer.train_gui(self.trainer.train_loader, step=self.train_steps)

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.step += self.train_steps
        self.need_update = True

        dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
        dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

        # adapt train_steps so one GUI frame spends roughly 500 ms on training,
        # clamped to [4, 16] and only updated when the estimate drifts by more than 20%
        full_t = t / self.train_steps * 16
        train_steps = min(16, max(4, int(16 * 500 / full_t)))
        if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
            self.train_steps = train_steps

    # convert trainer outputs into the HxWx3 float buffer shown by the raw texture
    def prepare_buffer(self, outputs):
        if self.mode == 'image':
            return outputs['image']
        else:
            # tile the single-channel depth into three channels for display
            return np.expand_dims(outputs['depth'], -1).repeat(3, -1)

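    # render one frame from the current camera; when the view is static, keep
    # accumulating samples up to opt.max_spp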
    def test_step(self):
        if self.need_update or self.spp < self.opt.max_spp:
            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            starter.record()

            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)

            ender.record()
            torch.cuda.synchronize()
            t = starter.elapsed_time(ender)

            if self.dynamic_resolution:
                # choose a downscale factor in [1/4, 1] targeting roughly 200 ms per frame,
                # and only switch when the estimate drifts by more than 20%
                full_t = t / (self.downscale ** 2)
                downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
                if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
                    self.downscale = downscale

            if self.need_update:
                # camera moved: restart accumulation from this frame
                self.render_buffer = self.prepare_buffer(outputs)
                self.spp = 1
                self.need_update = False
            else:
                # static view: running average over the accumulated samples
                self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
                self.spp += 1

            dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
            dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
            dpg.set_value("_log_spp", self.spp)
            dpg.set_value("_texture", self.render_buffer)

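    # build the Dear PyGui UI: texture, windows, widgets and mouse handlers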
    def register_dpg(self):

        # register the texture that receives rendered frames
        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")

        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):
            dpg.add_image("_texture")

        dpg.set_primary_window("_primary_window", True)

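        # control window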
        with dpg.window(label="Control", tag="_control_window", width=400, height=300):

            # text prompt
            if self.opt.text is not None:
                dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")

            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # timing and sample counters
            if not self.opt.test:
                with dpg.group(horizontal=True):
                    dpg.add_text("Train time: ")
                    dpg.add_text("no data", tag="_log_train_time")

            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            with dpg.group(horizontal=True):
                dpg.add_text("SPP: ")
                dpg.add_text("1", tag="_log_spp")

            # training controls
            if not self.opt.test:
                with dpg.collapsing_header(label="Train", default_open=True):
                    with dpg.group(horizontal=True):
                        dpg.add_text("Train: ")

                        def callback_train(sender, app_data):
                            # toggle training; the render loop checks self.training every frame
                            if self.training:
                                self.training = False
                                dpg.configure_item("_button_train", label="start")
                            else:
                                self.training = True
                                dpg.configure_item("_button_train", label="stop")

                        dpg.add_button(label="start", tag="_button_train", callback=callback_train)
                        dpg.bind_item_theme("_button_train", theme_button)

                        def callback_reset(sender, app_data):
                            # re-initialize every module that exposes reset_parameters()
                            @torch.no_grad()
                            def weight_reset(m: nn.Module):
                                reset_parameters = getattr(m, "reset_parameters", None)
                                if callable(reset_parameters):
                                    m.reset_parameters()
                            self.trainer.model.apply(fn=weight_reset)
                            self.trainer.model.reset_extra_state()
                            self.need_update = True

                        dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
                        dpg.bind_item_theme("_button_reset", theme_button)

                    # save checkpoint
                    with dpg.group(horizontal=True):
                        dpg.add_text("Checkpoint: ")

                        def callback_save(sender, app_data):
                            self.trainer.save_checkpoint(full=True, best=False)
                            dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
                            self.trainer.epoch += 1  # bump epoch so successive saves get distinct names

                        dpg.add_button(label="save", tag="_button_save", callback=callback_save)
                        dpg.bind_item_theme("_button_save", theme_button)

                        dpg.add_text("", tag="_log_ckpt")

                    # extract a mesh via marching cubes
                    with dpg.group(horizontal=True):
                        dpg.add_text("Marching Cubes: ")

                        def callback_mesh(sender, app_data):
                            self.trainer.save_mesh(resolution=256, threshold=10)
                            dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
                            self.trainer.epoch += 1  # bump epoch so successive saves get distinct names

                        dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
                        dpg.bind_item_theme("_button_mesh", theme_button)

                        dpg.add_text("", tag="_log_mesh")

                    with dpg.group(horizontal=True):
                        dpg.add_text("", tag="_log_train_log")

            # rendering options
            with dpg.collapsing_header(label="Options", default_open=True):

                # dynamic rendering resolution
                with dpg.group(horizontal=True):

                    def callback_set_dynamic_resolution(sender, app_data):
                        if self.dynamic_resolution:
                            self.dynamic_resolution = False
                            self.downscale = 1
                        else:
                            self.dynamic_resolution = True
                        self.need_update = True

                    dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
                    dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

                # render mode
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)

                # background color
                def callback_change_bg(sender, app_data):
                    self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32)  # only RGB in [0, 1] is needed
                    self.need_update = True

                dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)

                # field of view
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = app_data
                    self.need_update = True

                dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)

                # dt_gamma (adaptive ray-marching step size)
                def callback_set_dt_gamma(sender, app_data):
                    self.opt.dt_gamma = app_data
                    self.need_update = True

                dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

                # max ray-marching steps
                def callback_set_max_steps(sender, app_data):
                    self.opt.max_steps = app_data
                    self.need_update = True

                dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

                # inference bounding box
                def callback_set_aabb(sender, app_data, user_data):
                    # user_data indexes into the aabb as (xmin, ymin, zmin, xmax, ymax, zmax)
                    self.trainer.model.aabb_infer[user_data] = app_data
                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Axis-aligned bounding box:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)

                # light direction (theta, phi) in degrees
                def callback_set_light_dir(sender, app_data, user_data):
                    self.light_dir[user_data] = app_data
                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Plane Light Direction:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)

                # ambient lighting ratio
                def callback_set_abm_ratio(sender, app_data):
                    self.ambient_ratio = app_data
                    self.need_update = True

                dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)

                # shading mode
                def callback_change_shading(sender, app_data):
                    self.shading = app_data
                    self.need_update = True

                dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)

            # debug info
            if self.debug:
                with dpg.collapsing_header(label="Debug"):
                    dpg.add_separator()
                    dpg.add_text("Camera Pose:")
                    dpg.add_text(str(self.cam.pose), tag="_log_pose")

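        # camera interaction callbacks, registered globally below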
        def callback_camera_drag_rotate(sender, app_data):
            # only react when the render window is focused
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        with dpg.handler_registry():
            # left drag orbits, wheel zooms, middle drag pans
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan)

        dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)

        # zero out window/frame/cell padding so the image fills the primary window
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()
        dpg.show_viewport()

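    # main loop: interleave training (if enabled) with rendering and UI updates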
    def render(self):
        while dpg.is_dearpygui_running():
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()
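
# Usage sketch (hypothetical; the actual entry point lives elsewhere in the repo).
# `opt` is assumed to provide at least W, H, radius, fovy, max_spp, bound, dt_gamma,
# max_steps, light_theta, light_phi, text and test, and `trainer` is assumed to
# implement train_gui(), test_gui(), save_checkpoint(), save_mesh() and to expose
# model, train_loader, stats, name and epoch:
#
#   gui = NeRFGUI(opt, trainer)
#   gui.render()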