"""Text-to-3D training script: optimizes a voxel NeRF so its renders score
highly under a frozen 2D diffusion model (score distillation via manual
gradient injection).  Entry point is ``dispatch(SJC)`` at the bottom.
"""
import math
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from imageio import imwrite
from pydantic import validator

from my.utils import (
    tqdm, EventStorage, HeartBeat, EarlyLoopBreak,
    get_event_storage, get_heartbeat, read_stats
)
from my.config import BaseConf, dispatch, optional_load_config
from my.utils.seed import seed_everything

from adapt import ScoreAdapter, karras_t_schedule
from run_img_sampling import GDDPM, SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig

from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.render import (
    as_torch_tsrs, rays_from_img, ray_box_intersect, render_ray_bundle
)
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis

# Single global device; the script assumes a CUDA-capable machine.
device_glb = torch.device("cuda")


def tsr_stats(tsr):
    """Return summary scalars (mean/std/max) of a tensor, for metric logging."""
    return {
        "mean": tsr.mean().item(),
        "std": tsr.std().item(),
        "max": tsr.max().item(),
    }


class SJC(BaseConf):
    """Top-level experiment config; `run()` assembles the parts and trains.

    `family` selects which diffusion backend field ("sd" or "gddpm") is
    instantiated via ``getattr(self, family).make()``.
    """
    family: str = "sd"
    gddpm: GDDPM = GDDPM()
    sd: SD = SD(
        variant="v1",
        prompt="A high quality photo of a delicious burger",
        scale=100.0
    )
    lr: float = 0.05
    n_steps: int = 10000
    vox: VoxConfig = VoxConfig(
        model_type="V_SD", grid_size=100, density_shift=-1.0, c=3,
        blend_bg_texture=True, bg_texture_hw=4,
        bbox_len=1.0
    )
    pose: PoseConfig = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)

    # Emptiness regularizer hyperparameters (see sjc_3d loop).
    # NOTE(review): `emptiness_weight` is annotated int but defaulted to 1e4;
    # pydantic will coerce — confirm int is intended.
    emptiness_scale: int = 10
    emptiness_weight: int = 1e4
    emptiness_step: float = 0.5          # fraction of n_steps after which the multiplier kicks in
    emptiness_multiplier: float = 20.0

    depth_weight: int = 0                # 0 disables the center-vs-border depth loss

    var_red: bool = True                 # variance-reduced gradient: (D - y) vs (D - z)

    @validator("vox")
    def check_vox(cls, vox_cfg, values):
        # Stable Diffusion operates in a 4-channel latent space, so the voxel
        # grid must emit 4 channels instead of RGB.
        family = values['family']
        if family == "sd":
            vox_cfg.c = 4
        return vox_cfg

    def run(self):
        """Build model / voxel grid / pose sampler and launch training."""
        cfgs = self.dict()
        family = cfgs.pop("family")
        model = getattr(self, family).make()

        cfgs.pop("vox")
        vox = self.vox.make()

        cfgs.pop("pose")
        poser = self.pose.make()

        # Remaining cfgs entries (lr, n_steps, emptiness_*, ...) are passed
        # through as keyword arguments to sjc_3d.
        sjc_3d(**cfgs, poser=poser, model=model, vox=vox)


def sjc_3d(
    poser, vox, model: ScoreAdapter,
    lr, n_steps, emptiness_scale, emptiness_weight, emptiness_step, emptiness_multiplier,
    depth_weight, var_red, **kwargs
):
    """Train the voxel radiance field `vox` against the frozen score model.

    Per step: render one view, add noise at a sampled level, have the
    diffusion model denoise it, and backpropagate the (negated) denoiser
    residual through the render into the voxel parameters, plus emptiness
    and optional depth regularizers.  Artifacts/metrics go to EventStorage.
    """
    del kwargs  # absorb unused config keys forwarded by SJC.run

    assert model.samps_centered()
    _, target_H, target_W = model.data_shape()
    bs = 1
    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)
    opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

    H, W = poser.H, poser.W
    # One camera (intrinsics, pose, prompt prefix) pre-sampled per step.
    Ks, poses, prompt_prefixes = poser.sample_train(n_steps)

    # Trim the extremes of the model's noise-level schedule.
    # NOTE(review): presumably model.us is a descending σ schedule — confirm.
    ts = model.us[30:-10]
    fuse = EarlyLoopBreak(5)

    # Kept for experiments with a fixed noise realization; currently unused.
    same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)

    with tqdm(total=n_steps) as pbar, \
        HeartBeat(pbar) as hbeat, \
            EventStorage() as metric:
        for i in range(n_steps):
            if fuse.on_break():
                break

            p = f"{prompt_prefixes[i]} {model.prompt}"
            score_conds = model.prompts_emb([p])

            y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)

            if isinstance(model, StableDiffusion):
                # SD's latent resolution already matches the render; no resize.
                pass
            else:
                y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

            opt.zero_grad()

            # The denoiser pass is done without autograd; its output is used
            # only to *construct* a gradient, injected below via y.backward().
            with torch.no_grad():
                chosen_σs = np.random.choice(ts, bs, replace=False)
                chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
                chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)
                # chosen_σs = us[i]

                noise = torch.randn(bs, *y.shape[1:], device=model.device)

                zs = y + chosen_σs * noise
                Ds = model.denoise(zs, chosen_σs, **score_conds)

                if var_red:
                    # variance-reduced form: subtract the clean render
                    grad = (Ds - y) / chosen_σs
                else:
                    grad = (Ds - zs) / chosen_σs

                grad = grad.mean(0, keepdim=True)

            # Inject -grad as dL/dy; chain rule carries it into vox params.
            # retain_graph: the same render graph is reused by the
            # regularizer backward passes below.
            y.backward(-grad, retain_graph=True)

            if depth_weight > 0:
                # Encourage the object (center crop) to sit closer than the
                # border.  Constants assume a 64x64 render with a 50x50
                # center (64 - 2*7 = 50), matching PoseConfig(rend_hw=64).
                center_depth = depth[7:-7, 7:-7]
                border_depth_mean = (depth.sum() - center_depth.sum()) / (64*64-50*50)
                center_depth_mean = center_depth.mean()
                depth_diff = center_depth_mean - border_depth_mean
                depth_loss = - torch.log(depth_diff + 1e-12)
                depth_loss = depth_weight * depth_loss
                depth_loss.backward(retain_graph=True)

            # Penalize ray-termination weights to discourage floaters;
            # strengthened after `emptiness_step` fraction of training.
            emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
            emptiness_loss = emptiness_weight * emptiness_loss
            if emptiness_step * n_steps <= i:
                emptiness_loss *= emptiness_multiplier
            emptiness_loss.backward()

            opt.step()

            metric.put_scalars(**tsr_stats(y))

            if every(pbar, percent=1):
                with torch.no_grad():
                    if isinstance(model, StableDiffusion):
                        # decode latents to RGB for visualization only
                        y = model.decode(y)
                    vis_routine(metric, y, depth)

            # if every(pbar, step=2500):
            #     metric.put_artifact(
            #         "ckpt", ".pt", lambda fn: torch.save(vox.state_dict(), fn)
            #     )
            #     with EventStorage("test"):
            #         evaluate(model, vox, poser)

            metric.step()
            pbar.update()
            pbar.set_description(p)
            hbeat.beat()

        # Final checkpoint + held-out test renders.
        metric.put_artifact(
            "ckpt", ".pt", lambda fn: torch.save(vox.state_dict(), fn)
        )
        with EventStorage("test"):
            evaluate(model, vox, poser)

        metric.step()

        hbeat.done()


@torch.no_grad()
def evaluate(score_model, vox, poser):
    """Render 100 test poses of the trained `vox`, log frames and a video."""
    H, W = poser.H, poser.W
    vox.eval()
    K, poses = poser.sample_test(100)

    fuse = EarlyLoopBreak(5)
    metric = get_event_storage()
    hbeat = get_heartbeat()

    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)

    num_imgs = len(poses)

    for i in (pbar := tqdm(range(num_imgs))):
        if fuse.on_break():
            break

        pose = poses[i]
        y, depth = render_one_view(vox, aabb, H, W, K, pose)
        if isinstance(score_model, StableDiffusion):
            y = score_model.decode(y)
        vis_routine(metric, y, depth)

        metric.step()
        hbeat.beat()

    metric.flush_history()

    # Stitch the per-step "view" images written above into an mp4.
    metric.put_artifact(
        "view_seq", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "view")[1])
    )

    metric.step()


def render_one_view(vox, aabb, H, W, K, pose, return_w=False):
    """Render the voxel field from one camera.

    Returns an image tensor shaped (1, c, H, W) and a depth map (H, W);
    with return_w=True also returns the per-ray termination weights used by
    the emptiness loss.
    """
    N = H * W
    ro, rd = rays_from_img(H, W, K, pose)
    ro, rd, t_min, t_max = scene_box_filter(ro, rd, aabb)
    assert len(ro) == N, "for now all pixels must be in"
    ro, rd, t_min, t_max = as_torch_tsrs(vox.device, ro, rd, t_min, t_max)
    rgbs, depth, weights = render_ray_bundle(vox, ro, rd, t_min, t_max)

    # flat per-ray outputs -> image layout
    rgbs = rearrange(rgbs, "(h w) c -> 1 c h w", h=H, w=W)
    depth = rearrange(depth, "(h w) 1 -> h w", h=H, w=W)
    if return_w:
        return rgbs, depth, weights
    else:
        return rgbs, depth


def scene_box_filter(ro, rd, aabb):
    """Clamp each ray's entry/exit distances against the scene AABB."""
    _, t_min, t_max = ray_box_intersect(ro, rd, aabb)
    # do not render what's behind the ray origin
    t_min, t_max = np.maximum(t_min, 0), np.maximum(t_max, 0)
    return ro, rd, t_min, t_max


def vis_routine(metric, y, depth):
    """Log a composite pane, the raw image, and the depth map as artifacts."""
    pane = nerf_vis(y, depth, final_H=256)
    im = torch_samps_to_imgs(y)[0]
    depth = depth.cpu().numpy()
    metric.put_artifact("view", ".png", lambda fn: imwrite(fn, pane))
    metric.put_artifact("img", ".png", lambda fn: imwrite(fn, im))
    metric.put_artifact("depth", ".npy", lambda fn: np.save(fn, depth))


def evaluate_ckpt():
    """Standalone evaluation: load full_config.yml + latest checkpoint, render tests."""
    cfg = optional_load_config(fname="full_config.yml")
    assert len(cfg) > 0, "can't find cfg file"
    mod = SJC(**cfg)

    family = cfg.pop("family")
    model: ScoreAdapter = getattr(mod, family).make()
    vox = mod.vox.make()
    poser = mod.pose.make()

    pbar = tqdm(range(1))

    with EventStorage(), HeartBeat(pbar):
        ckpt_fname = latest_ckpt()
        state = torch.load(ckpt_fname, map_location="cpu")
        vox.load_state_dict(state)
        vox.to(device_glb)

        with EventStorage("test"):
            evaluate(model, vox, poser)


def latest_ckpt():
    """Return the filename of the most recent "ckpt" artifact in cwd."""
    ts, ys = read_stats("./", "ckpt")
    assert len(ys) > 0
    return ys[-1]


if __name__ == "__main__":
    seed_everything(0)
    dispatch(SJC)
    # evaluate_ckpt()