# sjc / run_sjc.py
# Last commit: "Updated app.py" (7a11626) by amankishore
import math
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from imageio import imwrite
from pydantic import validator
from my.utils import (
tqdm, EventStorage, HeartBeat, EarlyLoopBreak,
get_event_storage, get_heartbeat, read_stats
)
from my.config import BaseConf, dispatch, optional_load_config
from my.utils.seed import seed_everything
from adapt import ScoreAdapter, karras_t_schedule
from run_img_sampling import GDDPM, SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig
from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.render import (
as_torch_tsrs, rays_from_img, ray_box_intersect, render_ray_bundle
)
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis
# Device that the voxel model is moved to in sjc_3d, evaluate and evaluate_ckpt.
# NOTE(review): hard-coded to CUDA, so moving the model here fails on machines without a GPU.
device_glb = torch.device("cuda")
def tsr_stats(tsr):
    """Summarize a tensor as plain Python floats for scalar logging.

    Args:
        tsr: any torch tensor.

    Returns:
        dict with keys ``"mean"``, ``"std"`` and ``"max"`` (in that order), each the
        ``.item()`` of the corresponding whole-tensor reduction.
    """
    reducers = (
        ("mean", torch.mean),
        ("std", torch.std),
        ("max", torch.max),
    )
    return {key: reduce_fn(tsr).item() for key, reduce_fn in reducers}
class SJC(BaseConf):
    """Configuration for an SJC run.

    ``family`` names which score-model config attribute (``gddpm`` or ``sd``) is
    built in ``run``; ``vox`` and ``pose`` are built into the voxel model and the
    camera sampler, and every other field is forwarded to ``sjc_3d``.
    """
    family: str = "sd"
    gddpm: GDDPM = GDDPM()
    sd: SD = SD(
        variant="v1",
        prompt="A high quality photo of a delicious burger",
        scale=100.0
    )
    lr: float = 0.05
    n_steps: int = 10000
    vox: VoxConfig = VoxConfig(
        model_type="V_SD", grid_size=100, density_shift=-1.0, c=3,
        blend_bg_texture=True, bg_texture_hw=4,
        bbox_len=1.0
    )
    pose: PoseConfig = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)

    # The loss scales/weights below are only used arithmetically in sjc_3d, so they
    # are typed float: an `int` annotation makes pydantic truncate (v1) or reject (v2)
    # fractional values coming from a config file or the command line.
    emptiness_scale: float = 10
    emptiness_weight: float = 1e4
    emptiness_step: float = 0.5  # fraction of n_steps after which the multiplier kicks in
    emptiness_multiplier: float = 20.0
    depth_weight: float = 0  # depth loss is skipped when this is <= 0
    var_red: bool = True

    # always=True: also run when `vox` is left at its default (c=3); otherwise the
    # default "sd" family would keep a channel count that this validator overrides.
    @validator("vox", always=True)
    def check_vox(cls, vox_cfg, values):
        """Force ``vox_cfg.c = 4`` when the "sd" family is selected.

        ``family`` is declared before ``vox``, so it is in ``values`` unless it
        failed its own validation; ``.get`` lets pydantic report that error
        instead of this validator raising a bare KeyError.
        """
        family = values.get('family')
        if family == "sd":
            vox_cfg.c = 4
        return vox_cfg

    def run(self):
        """Build the score model, voxel model and poser, then call ``sjc_3d``.

        The config dict is forwarded as keyword arguments; ``family``, ``vox`` and
        ``pose`` are removed first because they are replaced by built objects, and
        ``sjc_3d`` discards any leftover keys (e.g. the unused family's config).
        """
        cfgs = self.dict()

        family = cfgs.pop("family")
        model = getattr(self, family).make()

        cfgs.pop("vox")
        vox = self.vox.make()

        cfgs.pop("pose")
        poser = self.pose.make()

        sjc_3d(**cfgs, poser=poser, model=model, vox=vox)
def sjc_3d(
    poser, vox, model: ScoreAdapter,
    lr, n_steps, emptiness_scale, emptiness_weight, emptiness_step, emptiness_multiplier,
    depth_weight, var_red, **kwargs
):
    """Optimize the voxel model ``vox`` against the score model ``model``.

    Each step:
      1. renders one training view ``y``;
      2. adds Gaussian noise at a level drawn from ``model.us[30:-10]``;
      3. calls ``model.denoise`` to get ``Ds``;
      4. back-propagates ``-(Ds - y) / σ`` into the render (``-(Ds - zs) / σ`` when
         ``var_red`` is False);
      5. adds the optional depth and emptiness losses and takes an Adamax step.

    After the loop, ``vox.state_dict()`` is saved as a "ckpt" artifact and ``evaluate``
    is run.

    Args:
        poser: camera sampler; provides ``H``, ``W`` and ``sample_train(n_steps)``.
        vox: voxel model being optimized; moved to ``device_glb``.
        model: score adapter; ``model.samps_centered()`` must be true.
        lr: Adamax learning rate.
        n_steps: number of optimization steps.
        emptiness_scale: scale inside ``log(1 + emptiness_scale * ws)``.
        emptiness_weight: weight of the emptiness loss.
        emptiness_step: fraction of ``n_steps`` after which the emptiness loss is
            multiplied by ``emptiness_multiplier``.
        emptiness_multiplier: late-phase multiplier for the emptiness loss.
        depth_weight: weight of the center-vs-border depth loss; skipped when <= 0.
        var_red: use the clean render ``y`` instead of the noisy input ``zs`` when
            forming the gradient.
        **kwargs: ignored (remaining entries of the SJC config).
    """
    del kwargs

    assert model.samps_centered()
    _, target_H, target_W = model.data_shape()
    bs = 1
    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)
    opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

    H, W = poser.H, poser.W
    Ks, poses, prompt_prefixes = poser.sample_train(n_steps)

    # Candidate noise levels: model.us with the first 30 and last 10 entries dropped.
    ts = model.us[30:-10]
    fuse = EarlyLoopBreak(5)

    # NOTE(review): `same_noise` is never read, and its 4 channels are hard-coded.
    # It is kept only because drawing it advances the torch RNG; removing it would
    # change the results of seeded runs.
    same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)

    with tqdm(total=n_steps) as pbar, \
            HeartBeat(pbar) as hbeat, \
            EventStorage() as metric:
        for i in range(n_steps):
            if fuse.on_break():
                break

            p = f"{prompt_prefixes[i]} {model.prompt}"
            score_conds = model.prompts_emb([p])

            y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)

            # Non-SD models receive the render resized to the model's data shape;
            # SD consumes the render as-is.
            if not isinstance(model, StableDiffusion):
                y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

            opt.zero_grad()

            with torch.no_grad():
                chosen_σs = np.random.choice(ts, bs, replace=False)
                chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
                chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)

                noise = torch.randn(bs, *y.shape[1:], device=model.device)

                zs = y + chosen_σs * noise
                Ds = model.denoise(zs, chosen_σs, **score_conds)

                if var_red:
                    grad = (Ds - y) / chosen_σs
                else:
                    grad = (Ds - zs) / chosen_σs

                grad = grad.mean(0, keepdim=True)

            # The graph is retained because the depth/emptiness losses below share it.
            y.backward(-grad, retain_graph=True)

            if depth_weight > 0:
                # Compare mean depth of the central region (7-pixel margin removed) with
                # the mean over the surrounding border pixels. The border count is derived
                # from the tensors, so this works for any render size (it equals
                # 64*64 - 50*50 for the default 64x64 render).
                center_depth = depth[7:-7, 7:-7]
                n_border = depth.numel() - center_depth.numel()
                border_depth_mean = (depth.sum() - center_depth.sum()) / n_border
                center_depth_mean = center_depth.mean()
                depth_diff = center_depth_mean - border_depth_mean
                # NOTE(review): -log(...) is NaN once depth_diff < -1e-12.
                depth_loss = - torch.log(depth_diff + 1e-12)
                depth_loss = depth_weight * depth_loss
                depth_loss.backward(retain_graph=True)

            emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
            emptiness_loss = emptiness_weight * emptiness_loss
            if emptiness_step * n_steps <= i:
                emptiness_loss = emptiness_loss * emptiness_multiplier
            emptiness_loss.backward()

            opt.step()

            metric.put_scalars(**tsr_stats(y))

            if every(pbar, percent=1):
                with torch.no_grad():
                    if isinstance(model, StableDiffusion):
                        y = model.decode(y)
                    vis_routine(metric, y, depth)

            metric.step()
            pbar.update()
            pbar.set_description(p)
            hbeat.beat()

        metric.put_artifact(
            "ckpt", ".pt", lambda fn: torch.save(vox.state_dict(), fn)
        )
        with EventStorage("test"):
            evaluate(model, vox, poser)

        metric.step()

        hbeat.done()
@torch.no_grad()
def evaluate(score_model, vox, poser):
    """Render the poser's 100 test views and log them, plus a stitched video.

    Runs without gradients. ``vox`` is put in eval mode for the sweep, and its
    previous train/eval mode is restored afterwards (even on error). This lets the
    function be called in the middle of training without leaving ``vox`` in eval mode.

    Args:
        score_model: decodes each render first when it is a ``StableDiffusion``.
        vox: voxel model to render; moved to ``device_glb``.
        poser: provides ``H``, ``W`` and ``sample_test(100)`` (shared K, list of poses).

    Uses the current event storage and heartbeat, logging each view through
    ``vis_routine`` and finally a "view_seq" .mp4 artifact.
    """
    H, W = poser.H, poser.W
    # NOTE(review): assumes vox is a torch.nn.Module (it exposes eval/to/state_dict).
    was_training = vox.training
    vox.eval()
    try:
        K, poses = poser.sample_test(100)

        fuse = EarlyLoopBreak(5)
        metric = get_event_storage()
        hbeat = get_heartbeat()

        aabb = vox.aabb.T.cpu().numpy()
        vox = vox.to(device_glb)

        num_imgs = len(poses)

        for i in tqdm(range(num_imgs)):
            if fuse.on_break():
                break

            pose = poses[i]
            y, depth = render_one_view(vox, aabb, H, W, K, pose)
            if isinstance(score_model, StableDiffusion):
                y = score_model.decode(y)
            vis_routine(metric, y, depth)

            metric.step()
            hbeat.beat()
    finally:
        vox.train(was_training)

    metric.flush_history()

    metric.put_artifact(
        "view_seq", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "view")[1])
    )

    metric.step()
def render_one_view(vox, aabb, H, W, K, pose, return_w=False):
    """Render a single ``H`` x ``W`` view of ``vox`` from camera ``K`` / ``pose``.

    Every pixel must produce a ray (asserted). Rays are clipped to ``aabb`` by
    ``scene_box_filter`` and rendered with ``render_ray_bundle`` on ``vox.device``.

    Returns:
        ``(rgbs, depth)`` where ``rgbs`` is reshaped to ``(1, c, H, W)`` and ``depth``
        to ``(H, W)``; when ``return_w`` is true, the per-ray ``weights`` from
        ``render_ray_bundle`` are appended as a third element.
    """
    n_pixels = H * W

    origins, dirs = rays_from_img(H, W, K, pose)
    origins, dirs, near, far = scene_box_filter(origins, dirs, aabb)
    assert len(origins) == n_pixels, "for now all pixels must be in"

    origins, dirs, near, far = as_torch_tsrs(vox.device, origins, dirs, near, far)
    rgbs, depth, weights = render_ray_bundle(vox, origins, dirs, near, far)

    image = rearrange(rgbs, "(h w) c -> 1 c h w", h=H, w=W)
    depth_map = rearrange(depth, "(h w) 1 -> h w", h=H, w=W)

    return (image, depth_map, weights) if return_w else (image, depth_map)
def scene_box_filter(ro, rd, aabb):
    """Intersect rays with the bounding box and clamp hit distances at zero.

    No rays are removed: ``ro`` and ``rd`` are returned unchanged together with the
    entry/exit distances from ``ray_box_intersect`` (its first output is unused).
    Negative distances are clamped to 0 so nothing behind the ray origin is rendered.
    """
    _, t_enter, t_exit = ray_box_intersect(ro, rd, aabb)
    t_enter, t_exit = (np.maximum(t, 0) for t in (t_enter, t_exit))
    return ro, rd, t_enter, t_exit
def vis_routine(metric, y, depth):
    """Log one rendered view as artifacts in ``metric``.

    Writes, in this order:
      - "view" (.png): the ``nerf_vis`` pane of ``y`` and ``depth`` at height 256;
      - "img" (.png): the first image from ``torch_samps_to_imgs(y)``;
      - "depth" (.npy): ``depth`` copied to a CPU numpy array.
    """
    pane = nerf_vis(y, depth, final_H=256)
    first_img = torch_samps_to_imgs(y)[0]
    depth_np = depth.cpu().numpy()

    artifacts = (
        ("view", ".png", lambda fn: imwrite(fn, pane)),
        ("img", ".png", lambda fn: imwrite(fn, first_img)),
        ("depth", ".npy", lambda fn: np.save(fn, depth_np)),
    )
    for name, ext, writer in artifacts:
        metric.put_artifact(name, ext, writer)
def evaluate_ckpt():
    """Rebuild a run from ``full_config.yml`` and evaluate its latest checkpoint.

    Constructs the ``SJC`` config, builds the score model, voxel model and poser,
    loads the last "ckpt" artifact found by ``latest_ckpt`` into the voxel model,
    and runs ``evaluate`` under a "test" event storage.
    """
    cfg = optional_load_config(fname="full_config.yml")
    assert len(cfg) > 0, "can't find cfg file"
    mod = SJC(**cfg)

    # Use the validated value: the saved config may omit `family`, in which case
    # SJC's default applies (popping it from `cfg` would raise KeyError).
    family = mod.family
    model: ScoreAdapter = getattr(mod, family).make()
    vox = mod.vox.make()
    poser = mod.pose.make()

    # Use the progress bar as a context manager so it is closed on exit, nested the
    # same way as in sjc_3d.
    with tqdm(range(1)) as pbar, EventStorage(), HeartBeat(pbar):
        ckpt_fname = latest_ckpt()
        state = torch.load(ckpt_fname, map_location="cpu")
        vox.load_state_dict(state)
        vox.to(device_glb)

        with EventStorage("test"):
            evaluate(model, vox, poser)
def latest_ckpt():
    """Return the last "ckpt" entry that ``read_stats`` reports for the current directory.

    Presumably this is the most recently saved checkpoint path (confirm ordering in
    ``my.utils.read_stats``). Asserts that at least one checkpoint was recorded.
    """
    _, ckpt_entries = read_stats("./", "ckpt")
    assert len(ckpt_entries) > 0
    *_, newest = ckpt_entries
    return newest
if __name__ == "__main__":
    # Fixed seed so runs are reproducible.
    seed_everything(0)
    # NOTE(review): presumably dispatch parses config/CLI overrides into an SJC and
    # calls its run() — confirm in my.config.
    dispatch(SJC)
    # Swap in the line below to re-evaluate the latest checkpoint instead of training.
    # evaluate_ckpt()