# sjc / app.py (Hugging Face Space: MirageML/sjc)
# Last commit c255c40 by amankishore: "Added subpixel rendering!"
# File size: 9.14 kB; Hub scan: no virus.
import numpy as np
import time
from pathlib import Path
import torch
import imageio
from my.utils import tqdm
from my.utils.seed import seed_everything
from run_img_sampling import SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig
from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis
from run_sjc import render_one_view, tsr_stats
from highres_final_vis import highres_render_one_view
import gradio as gr
import gc
import os
# The voxel NeRF is moved to this device before optimisation and rendering.
device_glb = torch.device("cuda")


def vis_routine(y, depth):
    """Turn one rendered view into display artefacts.

    Args:
        y: rendered sample batch, as accepted by ``nerf_vis`` and
            ``torch_samps_to_imgs``.
        depth: depth tensor for the same view.

    Returns:
        ``(pane, image, depth)``: the ``nerf_vis`` panel built with
        ``final_H=256``, the first image produced by
        ``torch_samps_to_imgs(y)``, and ``depth`` copied to the CPU as a
        NumPy array.
    """
    vis_pane = nerf_vis(y, depth, final_H=256)
    first_img = torch_samps_to_imgs(y)[0]
    return vis_pane, first_img, depth.cpu().numpy()
# Extra CSS handed to gr.Blocks. The "#component-N" selectors target Gradio's
# auto-generated element ids. NOTE(review): those ids depend on the order in
# which components are created; confirm they still match the layout below.
css = "\n".join(
    [
        "",
        ".instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}",
        ".arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}",
        "#component-4, #component-3, #component-10{min-height: 0}",
        ".duplicate-button img{margin: 0}",
        "",
    ]
)
with gr.Blocks(css=css) as demo:
    # title
    gr.Markdown('# [Score Jacobian Chaining](https://github.com/pals-ttic/sjc): Lifting Pretrained 2D Diffusion Models for 3D Generation')
    # Plain (non-f) string: the markup contains no placeholders.
    gr.HTML('''
<div class="gr-prose" style="max-width: 80%">
<h2>Attention - This Space takes over 30min to run!</h2>
<p>If the Queue is too long you can run locally or duplicate the Space and run it on your own profile using a (paid) private T4 GPU for training. As each T4 costs US$0.60/h, it should cost < US$1 to train most models using default settings!&nbsp;&nbsp;<a style='display:inline-block' href='https://huggingface.co/spaces/MirageML/sjc?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a></p>
</div>
''')

    # inputs
    prompt = gr.Textbox(label="Prompt", max_lines=1, value="A high quality photo of a delicious burger")
    iters = gr.Slider(label="Iters", minimum=100, maximum=20000, value=10000, step=100)
    seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
    button = gr.Button('Generate')

    # outputs
    image = gr.Image(label="image", visible=True)
    # depth = gr.Image(label="depth", visible=True)
    video = gr.Video(label="video", visible=False)
    logs = gr.Textbox(label="logging")

    def submit(prompt, iters, seed):
        """Run Score Jacobian Chaining for ``prompt`` and stream progress.

        Optimises a voxel NeRF (``VoxConfig``) with Adamax against Stable
        Diffusion denoiser outputs for ``iters`` steps. It then renders the
        views returned by ``poser.sample_test`` (minus the first 60) and
        writes them to an MP4 file.

        Args:
            prompt: text prompt given to the SD model.
            iters: number of optimisation steps (the "Iters" slider value).
            seed: seed passed to ``seed_everything`` (the "Seed" slider value).

        Yields:
            Dicts mapping the ``image`` / ``video`` / ``logs`` components to
            their updates. One is yielded after every training step and every
            rendered frame, plus a final one that shows the video.

        Note:
            Every ``yield`` is outside ``torch.no_grad()``. Grad mode is
            thread-local, and Gradio may resume a generator on a different
            worker thread. Suspending inside ``no_grad`` would leave the
            suspending thread with autograd disabled, and could continue the
            ``no_grad`` block on a thread where it is enabled.
        """
        start_t = time.time()
        # Slider values arrive as JSON numbers; range() and seeding need ints.
        n_steps = int(iters)
        seed_everything(int(seed))

        # Construct poser -> SD model -> NeRF in the original order so a given
        # seed keeps producing the same result.
        poser = PoseConfig(rend_hw=64, FoV=60.0, R=1.5).make()
        model = SD(variant='v1', v2_highres=False, prompt=prompt, scale=100.0, precision='autocast').make()
        vox = VoxConfig(
            model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
            blend_bg_texture=True, bg_texture_hw=4,
            bbox_len=1.0).make()

        # Fixed SJC hyper-parameters; only the step count comes from the UI.
        lr = 0.05
        emptiness_scale = 10
        emptiness_weight = 10000
        emptiness_step = 0.5
        emptiness_multiplier = 20.0
        depth_weight = 0
        var_red = True

        assert model.samps_centered()
        _, target_H, target_W = model.data_shape()
        bs = 1
        aabb = vox.aabb.T.cpu().numpy()
        vox = vox.to(device_glb)
        opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

        H, W = poser.H, poser.W
        Ks, poses, prompt_prefixes = poser.sample_train(n_steps)

        ts = model.us[30:-10]
        # The result is unused. The draw is kept only so the RNG stream, and
        # therefore the output for a given seed, is unchanged.
        torch.randn(1, 4, H, W, device=model.device)

        with tqdm(total=n_steps) as pbar:
            for i in range(n_steps):
                p = f"{prompt_prefixes[i]} {model.prompt}"
                score_conds = model.prompts_emb([p])

                y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)
                if not isinstance(model, StableDiffusion):
                    # Non-SD models get the render resized to their data shape;
                    # SD consumes it as-is (it is decoded with model.decode below).
                    y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

                opt.zero_grad()

                with torch.no_grad():
                    chosen_σs = np.random.choice(ts, bs, replace=False)
                    chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
                    chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)

                    noise = torch.randn(bs, *y.shape[1:], device=model.device)
                    zs = y + chosen_σs * noise
                    Ds = model.denoise(zs, chosen_σs, **score_conds)

                    # (Ds - y) and (Ds - zs) differ by exactly the injected
                    # noise term, because zs = y + σ * noise.
                    if var_red:
                        grad = (Ds - y) / chosen_σs
                    else:
                        grad = (Ds - zs) / chosen_σs
                    grad = grad.mean(0, keepdim=True)

                # Push the denoiser-derived gradient through the render into the NeRF.
                y.backward(-grad, retain_graph=True)

                if depth_weight > 0:
                    # Penalise -log(center_mean - border_mean), i.e. reward a
                    # center depth larger than the 7-pixel border's depth.
                    center_depth = depth[7:-7, 7:-7]
                    border_count = depth.numel() - center_depth.numel()
                    border_depth_mean = (depth.sum() - center_depth.sum()) / border_count
                    center_depth_mean = center_depth.mean()
                    depth_diff = center_depth_mean - border_depth_mean
                    depth_loss = - torch.log(depth_diff + 1e-12)
                    depth_loss = depth_weight * depth_loss
                    depth_loss.backward(retain_graph=True)

                # log(1 + s * w) penalty on the render weights, boosted by
                # emptiness_multiplier once emptiness_step of training has passed.
                emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
                emptiness_loss = emptiness_weight * emptiness_loss
                if emptiness_step * n_steps <= i:
                    emptiness_loss *= emptiness_multiplier
                emptiness_loss.backward()

                opt.step()

                with torch.no_grad():
                    if isinstance(model, StableDiffusion):
                        y = model.decode(y)
                    _, img, _ = vis_routine(y, depth)
                    stats = tsr_stats(y)
                # TODO: Output pane, img and depth to Gradio
                yield {
                    image: gr.update(value=img, visible=True),
                    video: gr.update(visible=False),
                    logs: f"Steps: {i}/{n_steps}: \n" + str(stats),
                }
                pbar.update()
                pbar.set_description(p)

        # TODO: Save Checkpoint
        n_frames = 200
        factor = 4
        with torch.no_grad():
            vox.eval()
            K, test_poses = poser.sample_test(n_frames)
            test_poses = test_poses[60:]  # skip the full overhead view; not interesting
            aabb = vox.aabb.T.cpu().numpy()

        all_images = []
        for test_pose in tqdm(test_poses):
            with torch.no_grad():
                y, depth = highres_render_one_view(vox, aabb, H, W, K, test_pose, f=factor)
                if isinstance(model, StableDiffusion):
                    y = model.decode(y)
                _, img, _ = vis_routine(y, depth)
                stats = tsr_stats(y)
            # Save img to output
            all_images.append(img)
            yield {
                image: gr.update(value=img, visible=True),
                video: gr.update(visible=False),
                logs: str(stats),
            }

        output_video = "/tmp/tmp.mp4"
        imageio.mimwrite(output_video, all_images, quality=8, fps=10)
        end_t = time.time()

        # Best-effort release of this request's model, NeRF and optimizer
        # before the last yield, while the suspended generator would otherwise
        # still hold them.
        del model, vox, opt
        gc.collect()
        torch.cuda.empty_cache()

        yield {
            image: gr.update(value=img, visible=False),
            video: gr.update(value=output_video, visible=True),
            logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
        }

    button.click(
        submit,
        [prompt, iters, seed],
        [image, video, logs]
    )
# Allow a single queued job at a time (concurrency_count=1); running two
# generations at once would exhaust GPU memory. queue() returns the Blocks
# instance, so launch() is chained onto it.
demo.queue(concurrency_count=1).launch()