"""Training and evaluation pipeline: ray batching, optimisation loop, test rendering."""

import numpy as np
import torch
import imageio

from my.utils.tqdm import tqdm
from my.utils.event import EventStorage, read_stats, get_event_storage
from my.utils.heartbeat import HeartBeat, get_heartbeat
from my.utils.debug import EarlyLoopBreak

from .utils import PSNR, Scrambler, every, at
from .data import load_blender
from .render import (
    as_torch_tsrs, scene_box_filter, render_ray_bundle, render_one_view, rays_from_img
)
from .vis import vis, stitch_vis


device_glb = torch.device("cuda")


def all_train_rays(scene):
    """Collect the rays and target colours of every training image of ``scene``.

    Args:
        scene: scene identifier forwarded to ``load_blender``.

    Returns:
        Tuple ``(ro, rd, rgbs)`` of numpy arrays, each the concatenation along
        axis 0 of the per-image results. ``ro`` / ``rd`` are whatever
        ``rays_from_img`` returns (presumably ray origins and directions --
        TODO confirm); ``rgbs`` is each image reshaped to ``(-1, 3)``.
    """
    imgs, K, poses = load_blender("train", scene)

    ro, rd, rgbs = [], [], []
    # Index-based iteration is kept so that only ``len()`` / ``[]`` are required
    # from the containers returned by ``load_blender``.
    for i in tqdm(range(len(imgs))):
        img, pose = imgs[i], poses[i]
        H, W = img.shape[:2]
        _ro, _rd = rays_from_img(H, W, K, pose)
        ro.append(_ro)
        rd.append(_rd)
        rgbs.append(img.reshape(-1, 3))

    ro, rd, rgbs = [
        np.concatenate(xs, axis=0) for xs in (ro, rd, rgbs)
    ]
    return ro, rd, rgbs


class OneTestView():
    """Round-robin renderer over the test split, one view per ``render`` call."""

    def __init__(self, scene):
        """Load the test split of ``scene`` and start at view index 0."""
        imgs, K, poses = load_blender("test", scene)
        self.imgs, self.K, self.poses = imgs, K, poses
        self.i = 0  # index of the next view to render

    def render(self, model):
        """Render the current test view with ``model`` and advance the cursor.

        Returns:
            Tuple ``(img, rgbs, depth, psnr)``: the reference image, the two
            outputs of ``render_one_view``, and ``PSNR.psnr(img, rgbs)``.
        """
        i = self.i
        img, K, pose = self.imgs[i], self.K, self.poses[i]
        with torch.no_grad():
            aabb = model.aabb.T.cpu().numpy()
            H, W = img.shape[:2]
            rgbs, depth = render_one_view(model, aabb, H, W, K, pose)
            psnr = PSNR.psnr(img, rgbs)

        # Wrap around so the view cycles through the whole test split.
        self.i = (self.i + 1) % len(self.imgs)
        return img, rgbs, depth, psnr


def train(
    model, n_epoch=2, bs=4096, lr=0.02, scene="lego"
):
    """Optimise ``model`` on the training rays of ``scene``, then evaluate it.

    Per batch: MSE between ``render_ray_bundle`` predictions and target colours,
    one Adam step, then logging and the ``every`` / ``at`` gated hooks
    (preview render, ``make_alpha_mask``, ``resample``), followed by a
    learning-rate update. After all epochs a checkpoint and a stitched preview
    video are registered as artifacts and ``test`` is run.

    Args:
        model: module exposing ``aabb``, ``d_scale``, ``grid_size``, ``device``,
            ``make_alpha_mask()`` and ``resample()``.
        n_epoch: number of passes over the filtered rays.
        bs: number of rays per batch.
        lr: initial Adam learning rate; decayed by ``update_lr``.
        scene: scene identifier forwarded to ``load_blender``.
    """
    fuse = EarlyLoopBreak(500)

    # ``.cpu()`` makes this work even when the caller hands in a model that is
    # already on the GPU; it matches ``test`` and ``OneTestView.render``.
    aabb = model.aabb.T.cpu().numpy()
    model = model.to(device_glb)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    test_view = OneTestView(scene)
    all_ro, all_rd, all_rgbs = all_train_rays(scene)

    # NOTE: the progress total is derived from the unfiltered ray count, and the
    # ``every`` / ``at`` hooks as well as ``update_lr`` are driven by this bar.
    with tqdm(total=(n_epoch * len(all_ro) // bs)) as pbar, \
            HeartBeat(pbar) as hbeat, EventStorage() as metric:

        ro, rd, t_min, t_max, intsct_inds = scene_box_filter(all_ro, all_rd, aabb)
        rgbs = all_rgbs[intsct_inds]

        for _ in range(n_epoch):
            n = len(ro)
            scrambler = Scrambler(n)
            ro, rd, t_min, t_max, rgbs = scrambler.apply(ro, rd, t_min, t_max, rgbs)

            num_batch = int(np.ceil(n / bs))
            for i in range(num_batch):
                if fuse.on_break():
                    break
                s = i * bs
                e = min(n, s + bs)

                optim.zero_grad()
                _ro, _rd, _t_min, _t_max, _rgbs = as_torch_tsrs(
                    model.device, ro[s:e], rd[s:e], t_min[s:e], t_max[s:e], rgbs[s:e]
                )
                pred, _, _ = render_ray_bundle(model, _ro, _rd, _t_min, _t_max)
                loss = ((pred - _rgbs) ** 2).mean()
                loss.backward()
                optim.step()

                pbar.update()

                psnr = PSNR.psnr_from_mse(loss.item())
                metric.put_scalars(psnr=psnr, d_scale=model.d_scale.item())

                if every(pbar, step=50):
                    pbar.set_description(f"TRAIN: psnr {psnr:.2f}")

                if every(pbar, percent=1):
                    # Separate name so the training PSNR above is not clobbered.
                    gimg, rimg, depth, view_psnr = test_view.render(model)
                    pane = vis(
                        gimg, rimg, depth,
                        msg=f"psnr: {view_psnr:.2f}", return_buffer=True
                    )
                    metric.put_artifact(
                        "vis", ".png", lambda fn: imageio.imwrite(fn, pane)
                    )

                if at(pbar, percent=30):
                    model.make_alpha_mask()

                if every(pbar, percent=35):
                    target_xyz = (model.grid_size * 1.328).int().tolist()
                    model.resample(target_xyz)
                    # A fresh optimizer is built after resampling so that it
                    # tracks whatever ``model.parameters()`` yields now.
                    optim = torch.optim.Adam(model.parameters(), lr=lr)
                    print(f"resamp the voxel to {model.grid_size}")

                curr_lr = update_lr(pbar, optim, lr)
                metric.put_scalars(lr=curr_lr)

                metric.step()
                hbeat.beat()

        metric.put_artifact(
            "ckpt", ".pt", lambda fn: torch.save(model.state_dict(), fn)
        )
        # No metric.step(flush=True) here: the test routine below receives the
        # in-memory model directly, so nothing has to be flushed first.

        metric.put_artifact(
            "train_seq", ".mp4",
            lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "vis")[1])
        )

        with EventStorage("test"):
            final_psnr = test(model, scene)
        metric.put("test_psnr", final_psnr)

        metric.step()

        hbeat.done()


def update_lr(pbar, optimizer, init_lr):
    """Set an exponentially decayed learning rate on every param group.

    The rate is ``init_lr * 0.1 ** (pbar.n / pbar.total)``, i.e. it reaches one
    tenth of ``init_lr`` when the progress bar hits its total.

    Args:
        pbar: progress bar providing ``n`` (current step) and ``total``.
        optimizer: optimizer whose ``param_groups`` are updated in place.
        init_lr: learning rate at step 0.

    Returns:
        The learning rate that was just applied.
    """
    i, N = pbar.n, pbar.total
    factor = 0.1 ** (1 / N)
    lr = init_lr * (factor ** i)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def last_ckpt():
    """Load the most recent ``ckpt`` artifact recorded under the current directory.

    Returns:
        The object returned by ``torch.load`` (on CPU) for the last checkpoint
        listed by ``read_stats("./", "ckpt")``, or ``None`` when none exists.
    """
    ts, ckpts = read_stats("./", "ckpt")
    # ``len()`` rather than truthiness: the container type returned by
    # ``read_stats`` is not visible here and may not support ``bool()``.
    if len(ckpts) == 0:
        return None

    fname = ckpts[-1]
    last = torch.load(fname, map_location="cpu")
    print(f"loaded ckpt from iter {ts[-1]}")
    return last


def __evaluate_ckpt(model, scene):
    """Evaluate the latest checkpoint, if any, on the test split of ``scene``.

    The PSNR returned by ``test`` is stored under ``test_psnr`` in the event
    storage that is current when this function is entered.
    """
    # this is for external script that needs to evaluate a checkpoint
    # currently not used
    metric = get_event_storage()

    state = last_ckpt()
    if state is not None:
        model.load_state_dict(state)
    model.to(device_glb)

    with EventStorage("test"):
        final_psnr = test(model, scene)
    metric.put("test_psnr", final_psnr)


def test(model, scene):
    """Render the test views of ``scene`` with ``model`` and report the mean PSNR.

    Logs the per-view PSNR and a visualisation artifact for each rendered view,
    then a stitched ``test_seq`` video and the ``final_psnr`` value, all to the
    current event storage.

    Args:
        model: module exposing ``aabb``; it is moved to ``device_glb``.
        scene: scene identifier forwarded to ``load_blender``.

    Returns:
        ``np.mean`` of the PSNR over the views that were rendered.
    """
    fuse = EarlyLoopBreak(5)
    metric = get_event_storage()
    hbeat = get_heartbeat()

    aabb = model.aabb.T.cpu().numpy()
    model = model.to(device_glb)

    imgs, K, poses = load_blender("test", scene)
    num_imgs = len(imgs)

    stats = []

    for i in (pbar := tqdm(range(num_imgs))):
        if fuse.on_break():
            break

        img, pose = imgs[i], poses[i]
        H, W = img.shape[:2]
        # Evaluation needs no gradients; same guard as ``OneTestView.render``.
        with torch.no_grad():
            rgbs, depth = render_one_view(model, aabb, H, W, K, pose)
            psnr = PSNR.psnr(img, rgbs)

        stats.append(psnr)
        metric.put_scalars(psnr=psnr)
        pbar.set_description(f"TEST: mean psnr {np.mean(stats):.2f}")

        plot = vis(img, rgbs, depth, msg=f"PSNR: {psnr:.2f}", return_buffer=True)
        metric.put_artifact("test_vis", ".png", lambda fn: imageio.imwrite(fn, plot))
        metric.step()
        hbeat.beat()

    metric.put_artifact(
        "test_seq", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "test_vis")[1])
    )

    final_psnr = np.mean(stats)
    metric.put("final_psnr", final_psnr)
    metric.step()

    return final_psnr