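# sjc / app.py
# Hosting-page metadata captured alongside this file (not part of the program):
#   uploader avatar: amankishore | commit: "Requirements" (65ab1b5)
#   page links: raw / history / blame | scan status: No virus | size: 7.72 kB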
import numpy as np
import time
from pathlib import Path
import torch
import imageio
from my.utils import tqdm
from my.utils.seed import seed_everything
from run_img_sampling import SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig
from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis
from run_sjc import render_one_view, tsr_stats
import gradio as gr
import gc
import os
# Device the voxel model is moved to inside `submit` (see `vox.to(device_glb)`).
device_glb = torch.device("cuda")
# Import-time side effect: shallow-clone taming-transformers into the current
# working directory and install it in editable mode. The exit status of the
# shell command is ignored.
# NOTE(review): if the `taming-transformers` directory already exists,
# `git clone` fails and the `&&` skips the `pip install` step -- confirm this is
# acceptable for restarts in a persistent working directory.
os.system("git clone --depth 1 https://github.com/CompVis/taming-transformers.git && pip install -e taming-transformers")
def vis_routine(y, depth):
    """Produce the display artefacts for a single rendered view.

    Args:
        y: Rendered sample batch, forwarded to `nerf_vis` and
            `torch_samps_to_imgs`.
        depth: Depth tensor for the same view; must support `.cpu().numpy()`.

    Returns:
        A 3-tuple `(pane, im, depth)`: the `nerf_vis` result built with
        `final_H=256`, the first image returned by `torch_samps_to_imgs(y)`,
        and the depth map converted to a NumPy array on the CPU.
    """
    stitched_pane = nerf_vis(y, depth, final_H=256)
    converted_imgs = torch_samps_to_imgs(y)
    depth_np = depth.cpu().numpy()
    return stitched_pane, converted_imgs[0], depth_np
with gr.Blocks(css=".gradio-container {max-width: 512px; margin: auto;}") as demo:
    # title
    gr.Markdown('[Score Jacobian Chaining](https://github.com/pals-ttic/sjc): Lifting Pretrained 2D Diffusion Models for 3D Generation')

    # inputs
    prompt = gr.Textbox(label="Prompt", max_lines=1, value="A high quality photo of a delicious burger")
    iters = gr.Slider(label="Iters", minimum=100, maximum=20000, value=10000, step=100)
    seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
    button = gr.Button('Generate')

    # outputs
    image = gr.Image(label="image", visible=True)
    # depth = gr.Image(label="depth", visible=True)
    video = gr.Video(label="video", visible=False)
    logs = gr.Textbox(label="logging")

    def submit(prompt, iters, seed):
        """Optimise a voxel model for `prompt` and stream progress to the UI.

        This is a generator used as the click handler of `button`. It runs
        `iters` optimisation steps, periodically yielding a preview image,
        then renders 100 test views, writes them to an mp4 and yields the
        video.

        Args:
            prompt: Text prompt passed to the `SD` config.
            iters: Number of optimisation steps (slider value; cast to int).
            seed: Random seed (slider value; cast to int).

        Yields:
            Dicts mapping the `image`, `video` and `logs` components to their
            updates.
        """
        start_t = time.time()

        # Slider values are numeric and may arrive as floats; `range`,
        # `sample_train` and the seeding helpers need real integers.
        n_steps = int(iters)
        seed_everything(int(seed))

        # cfgs = {'gddpm': {'model': 'm_lsun_256', 'lsun_cat': 'bedroom', 'imgnet_cat': -1}, 'sd': {'variant': 'v1', 'v2_highres': False, 'prompt': 'A high quality photo of a delicious burger', 'scale': 100.0, 'precision': 'autocast'}, 'lr': 0.05, 'n_steps': 10000, 'emptiness_scale': 10, 'emptiness_weight': 10000, 'emptiness_step': 0.5, 'emptiness_multiplier': 20.0, 'depth_weight': 0, 'var_red': True}
        pose_cfg = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)
        poser = pose_cfg.make()

        sd_model = SD(variant='v1', v2_highres=False, prompt=prompt, scale=100.0, precision='autocast')
        model = sd_model.make()

        vox_cfg = VoxConfig(
            model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
            blend_bg_texture=True, bg_texture_hw=4,
            bbox_len=1.0)
        vox = vox_cfg.make()

        # Fixed hyper-parameters (same values as the `cfgs` comment above).
        lr = 0.05
        emptiness_scale = 10
        emptiness_weight = 10000
        emptiness_step = 0.5
        emptiness_multiplier = 20.0
        depth_weight = 0
        var_red = True

        assert model.samps_centered()
        _, target_H, target_W = model.data_shape()
        bs = 1
        aabb = vox.aabb.T.cpu().numpy()
        vox = vox.to(device_glb)
        opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

        H, W = poser.H, poser.W
        Ks, poses, prompt_prefixes = poser.sample_train(n_steps)

        ts = model.us[30:-10]

        # Never read afterwards, but kept on purpose: drawing it advances the
        # seeded RNG, so removing it would change the result for a given seed.
        same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)

        with tqdm(total=n_steps) as pbar:
            for i in range(n_steps):
                p = f"{prompt_prefixes[i]} {model.prompt}"
                score_conds = model.prompts_emb([p])

                y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)

                # Only non-StableDiffusion models need the render resized to
                # the model's data shape.
                if not isinstance(model, StableDiffusion):
                    y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

                opt.zero_grad()

                # The denoiser is only queried for a gradient direction; no
                # autograd graph is needed through it.
                with torch.no_grad():
                    chosen_sigmas = np.random.choice(ts, bs, replace=False)
                    chosen_sigmas = chosen_sigmas.reshape(-1, 1, 1, 1)
                    chosen_sigmas = torch.as_tensor(chosen_sigmas, device=model.device, dtype=torch.float32)
                    # chosen_sigmas = us[i]

                    noise = torch.randn(bs, *y.shape[1:], device=model.device)

                    zs = y + chosen_sigmas * noise
                    Ds = model.denoise(zs, chosen_sigmas, **score_conds)

                    if var_red:
                        grad = (Ds - y) / chosen_sigmas
                    else:
                        grad = (Ds - zs) / chosen_sigmas

                    grad = grad.mean(0, keepdim=True)

                y.backward(-grad, retain_graph=True)

                if depth_weight > 0:
                    # Centre crop vs. border of the depth map. The constant
                    # 64*64-50*50 is the border pixel count for a 64x64 render
                    # (rend_hw=64) with a 7-pixel margin.
                    center_depth = depth[7:-7, 7:-7]
                    border_depth_mean = (depth.sum() - center_depth.sum()) / (64*64-50*50)
                    center_depth_mean = center_depth.mean()
                    depth_diff = center_depth_mean - border_depth_mean
                    depth_loss = - torch.log(depth_diff + 1e-12)
                    depth_loss = depth_weight * depth_loss
                    depth_loss.backward(retain_graph=True)

                emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
                emptiness_loss = emptiness_weight * emptiness_loss
                # Strengthen the emptiness term once `emptiness_step` of the
                # schedule has elapsed.
                if emptiness_step * n_steps <= i:
                    emptiness_loss *= emptiness_multiplier
                emptiness_loss.backward()

                opt.step()

                # metric.put_scalars()

                if every(pbar, percent=1):
                    with torch.no_grad():
                        if isinstance(model, StableDiffusion):
                            y = model.decode(y)
                        _pane, img, depth = vis_routine(y, depth)
                        stats = str(tsr_stats(y))
                    # Yield outside `no_grad`: grad mode is thread-local, so
                    # suspending inside the context would leave gradients
                    # disabled in the thread that ran this step if the
                    # generator is resumed on another thread.
                    yield {
                        image: gr.update(value=img, visible=True),
                        video: gr.update(visible=False),
                        logs: f"Steps: {i}/{n_steps}: \n" + stats,
                    }
                    # TODO: Output pane, img and depth to Gradio

                pbar.update()
                pbar.set_description(p)

        # TODO: Save Checkpoint
        with torch.no_grad():
            H, W = poser.H, poser.W
            vox.eval()
            K, test_poses = poser.sample_test(100)

            aabb = vox.aabb.T.cpu().numpy()
            vox = vox.to(device_glb)

        all_images = []
        for i in tqdm(range(len(test_poses))):
            # Render under `no_grad`, but leave the context before yielding
            # (same thread-local grad-mode concern as in the training loop).
            with torch.no_grad():
                y, depth = render_one_view(vox, aabb, H, W, K, test_poses[i])
                if isinstance(model, StableDiffusion):
                    y = model.decode(y)
                _pane, img, depth = vis_routine(y, depth)
                stats = str(tsr_stats(y))

            # Save img to output
            all_images.append(img)

            yield {
                image: gr.update(value=img, visible=True),
                video: gr.update(visible=False),
                logs: stats,
            }

        output_video = "/tmp/tmp.mp4"
        imageio.mimwrite(output_video, all_images, quality=8, fps=10)

        end_t = time.time()

        yield {
            image: gr.update(value=img, visible=False),
            video: gr.update(value=output_video, visible=True),
            logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
        }

    # Event listeners must be registered while the Blocks context is open.
    button.click(
        submit,
        [prompt, iters, seed],
        [image, video, logs]
    )
# Enable the request queue with concurrency_count=1: only ONE generation may
# run at a time, otherwise the GPU will run out of memory (OOM).
demo.queue(concurrency_count=1)
# Start serving the app; this runs as a module-level side effect.
demo.launch()