import numpy as np
import torch
import imageio

from my.utils.tqdm import tqdm
from my.utils.event import EventStorage, read_stats, get_event_storage
from my.utils.heartbeat import HeartBeat, get_heartbeat
from my.utils.debug import EarlyLoopBreak

from .utils import PSNR, Scrambler, every, at
from .data import load_blender
from .render import (
    as_torch_tsrs, scene_box_filter, render_ray_bundle, render_one_view, rays_from_img
)
from .vis import vis, stitch_vis


device_glb = torch.device("cuda")


def all_train_rays(scene):
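    """Collect the rays of every training image: per-pixel ray origins, ray
    directions, and ground-truth RGB values, concatenated over all views."""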
    imgs, K, poses = load_blender("train", scene)
    num_imgs = len(imgs)
    ro, rd, rgbs = [], [], []
    for i in tqdm(range(num_imgs)):
        img, pose = imgs[i], poses[i]
        H, W = img.shape[:2]
        _ro, _rd = rays_from_img(H, W, K, pose)
        ro.append(_ro)
        rd.append(_rd)
        rgbs.append(img.reshape(-1, 3))

    ro, rd, rgbs = [
        np.concatenate(xs, axis=0) for xs in (ro, rd, rgbs)
    ]
    return ro, rd, rgbs


class OneTestView:
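    """Cycles through the test split one view at a time, rendering the current
    view with the given model and reporting its PSNR against the ground truth."""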
    def __init__(self, scene):
        imgs, K, poses = load_blender("test", scene)
        self.imgs, self.K, self.poses = imgs, K, poses
        self.i = 0

    def render(self, model):
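        """Render the current test view without gradients, advance to the next
        view, and return (ground-truth image, rendered rgb, depth, psnr)."""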
        i = self.i
        img, K, pose = self.imgs[i], self.K, self.poses[i]
        with torch.no_grad():
            aabb = model.aabb.T.cpu().numpy()
            H, W = img.shape[:2]
            rgbs, depth = render_one_view(model, aabb, H, W, K, pose)
            psnr = PSNR.psnr(img, rgbs)

        self.i = (self.i + 1) % len(self.imgs)

        return img, rgbs, depth, psnr


def train(
    model, n_epoch=2, bs=4096, lr=0.02, scene="lego"
):
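    """Fit the voxel model to the training rays of a Blender scene.

    n_epoch: number of passes over the (box-filtered) training rays
    bs: number of rays per optimization step
    lr: initial Adam learning rate (decayed by update_lr)
    scene: name of the Blender synthetic scene, e.g. "lego"
    """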
    fuse = EarlyLoopBreak(500)

    aabb = model.aabb.T.numpy()
    model = model.to(device_glb)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    test_view = OneTestView(scene)
    all_ro, all_rd, all_rgbs = all_train_rays(scene)
    print(f"{n_epoch} epochs, {len(all_ro)} training rays, batch size {bs}")
    with tqdm(total=(n_epoch * len(all_ro) // bs)) as pbar, \
            HeartBeat(pbar) as hbeat, EventStorage() as metric:
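        # drop rays that never intersect the scene AABB; t_min/t_max bound the
        # ray segment inside the box, and intsct_inds selects the surviving rays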
        ro, rd, t_min, t_max, intsct_inds = scene_box_filter(all_ro, all_rd, aabb)
        rgbs = all_rgbs[intsct_inds]
        print(f"{len(ro)} rays hit the scene box")
        for epc in range(n_epoch):
            n = len(ro)
            scrambler = Scrambler(n)
            ro, rd, t_min, t_max, rgbs = scrambler.apply(ro, rd, t_min, t_max, rgbs)

            num_batch = int(np.ceil(n / bs))
            for i in range(num_batch):
                if fuse.on_break():
                    break
                s = i * bs
                e = min(n, s + bs)

                optim.zero_grad()
                _ro, _rd, _t_min, _t_max, _rgbs = as_torch_tsrs(
                    model.device, ro[s:e], rd[s:e], t_min[s:e], t_max[s:e], rgbs[s:e]
                )
                pred, _, _ = render_ray_bundle(model, _ro, _rd, _t_min, _t_max)
                loss = ((pred - _rgbs) ** 2).mean()
                loss.backward()
                optim.step()

                pbar.update()

                psnr = PSNR.psnr_from_mse(loss.item())
                metric.put_scalars(psnr=psnr, d_scale=model.d_scale.item())

                if every(pbar, step=50):
                    pbar.set_description(f"TRAIN: psnr {psnr:.2f}")

                if every(pbar, percent=1):
                    gimg, rimg, depth, psnr = test_view.render(model)
                    pane = vis(
                        gimg, rimg, depth,
                        msg=f"psnr: {psnr:.2f}", return_buffer=True
                    )
                    metric.put_artifact(
                        "vis", ".png", lambda fn: imageio.imwrite(fn, pane)
                    )
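                # training-schedule hooks: build the alpha mask once at the 30%
                # mark, and at every 35% of progress grow the voxel grid by
                # ~1.328x per side, recreating the optimizer for the new params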
                if at(pbar, percent=30):
                    model.make_alpha_mask()

                if every(pbar, percent=35):
                    target_xyz = (model.grid_size * 1.328).int().tolist()
                    model.resample(target_xyz)
                    optim = torch.optim.Adam(model.parameters(), lr=lr)
                    print(f"resampled the voxel grid to {model.grid_size}")

                curr_lr = update_lr(pbar, optim, lr)
                metric.put_scalars(lr=curr_lr)
                metric.step()
                hbeat.beat()

            metric.put_artifact(
                "ckpt", ".pt", lambda fn: torch.save(model.state_dict(), fn)
            )

        metric.put_artifact(
            "train_seq", ".mp4",
            lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "vis")[1])
        )

        with EventStorage("test"):
            final_psnr = test(model, scene)
            metric.put("test_psnr", final_psnr)

        metric.step()

    hbeat.done()


def update_lr(pbar, optimizer, init_lr):
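    """Exponentially decay the learning rate so that it falls by 10x over the
    full run: lr = init_lr * 0.1 ** (pbar.n / pbar.total)."""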
    i, N = pbar.n, pbar.total
    factor = 0.1 ** (1 / N)
    lr = init_lr * (factor ** i)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def last_ckpt():
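    """Load the latest "ckpt" artifact from the current directory onto the CPU;
    return None if no checkpoint has been written yet."""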
    ts, ckpts = read_stats("./", "ckpt")
    if len(ckpts) > 0:
        fname = ckpts[-1]
        last = torch.load(fname, map_location="cpu")
        print(f"loaded ckpt from iter {ts[-1]}")
        return last


def __evaluate_ckpt(model, scene):
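    """Restore the latest checkpoint (if any) into the model and run the test
    pipeline under a separate "test" event storage."""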
    metric = get_event_storage()

    state = last_ckpt()
    if state is not None:
        model.load_state_dict(state)
    model.to(device_glb)

    with EventStorage("test"):
        final_psnr = test(model, scene)
        metric.put("test_psnr", final_psnr)


def test(model, scene):
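    """Render every test view, log per-view PSNR and visualizations, stitch the
    frames into a video artifact, and return the mean PSNR."""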
    fuse = EarlyLoopBreak(5)
    metric = get_event_storage()
    hbeat = get_heartbeat()

    aabb = model.aabb.T.cpu().numpy()
    model = model.to(device_glb)

    imgs, K, poses = load_blender("test", scene)
    num_imgs = len(imgs)

    stats = []

    for i in (pbar := tqdm(range(num_imgs))):
        if fuse.on_break():
            break

        img, pose = imgs[i], poses[i]
        H, W = img.shape[:2]
        rgbs, depth = render_one_view(model, aabb, H, W, K, pose)
        psnr = PSNR.psnr(img, rgbs)

        stats.append(psnr)
        metric.put_scalars(psnr=psnr)
        pbar.set_description(f"TEST: mean psnr {np.mean(stats):.2f}")

        plot = vis(img, rgbs, depth, msg=f"PSNR: {psnr:.2f}", return_buffer=True)
        metric.put_artifact("test_vis", ".png", lambda fn: imageio.imwrite(fn, plot))
        metric.step()
        hbeat.beat()

    metric.put_artifact(
        "test_seq", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "test_vis")[1])
    )

    final_psnr = np.mean(stats)
    metric.put("final_psnr", final_psnr)
    metric.step()

    return final_psnr