import numpy as np
import time
from pathlib import Path
import torch
import imageio

from my.utils import tqdm
from my.utils.seed import seed_everything

from run_img_sampling import SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig

from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis

from run_sjc import render_one_view, tsr_stats

import gradio as gr
import gc
import os

# Device the voxel grid is moved to for optimisation and rendering.
device_glb = torch.device("cuda")

# Shallow-clone taming-transformers and pip-install it (editable) at startup.
os.system("git clone --depth 1 https://github.com/CompVis/taming-transformers.git && pip install -e taming-transformers")


def vis_routine(y, depth):
    """Turn one rendered view into displayable outputs.

    Returns ``(pane, im, depth)``: the ``nerf_vis`` panel (built with
    ``final_H=256``), the first image produced by ``torch_samps_to_imgs(y)``,
    and ``depth`` copied to a NumPy array on the CPU.
    """
    pane = nerf_vis(y, depth, final_H=256)
    im = torch_samps_to_imgs(y)[0]
    depth = depth.cpu().numpy()
    return pane, im, depth


with gr.Blocks(css=".gradio-container {max-width: 512px; margin: auto;}") as demo:
    # title
    gr.Markdown('[Score Jacobian Chaining](https://github.com/pals-ttic/sjc): Lifting Pretrained 2D Diffusion Models for 3D Generation')

    # inputs
    prompt = gr.Textbox(label="Prompt", max_lines=1, value="A high quality photo of a delicious burger")
    iters = gr.Slider(label="Iters", minimum=100, maximum=20000, value=10000, step=100)
    seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
    button = gr.Button('Generate')

    # outputs
    image = gr.Image(label="image", visible=True)
    # depth = gr.Image(label="depth", visible=True)
    video = gr.Video(label="video", visible=False)
    logs = gr.Textbox(label="logging")

    def _generate(prompt, n_steps, seed):
        """Optimise a voxel radiance field for ``prompt`` and stream progress.

        Runs ``n_steps`` SJC optimisation steps, then renders the test views
        and writes them to an MP4. Yields Gradio update dicts keyed by the
        ``image``, ``video`` and ``logs`` components:

        * a preview whenever ``every(pbar, percent=1)`` is true during training;
        * one frame per test view;
        * a final update that hides the image and shows the video.

        Every ``yield`` happens OUTSIDE ``torch.no_grad()``. Grad mode is
        thread-local and Gradio may resume this generator on another worker
        thread, so yielding inside ``no_grad`` could leave a pool thread with
        autograd disabled.
        """
        start_t = time.time()

        seed_everything(seed)

        # Reference config these hard-coded values were taken from:
        # cfgs = {'gddpm': {'model': 'm_lsun_256', 'lsun_cat': 'bedroom', 'imgnet_cat': -1}, 'sd': {'variant': 'v1', 'v2_highres': False, 'prompt': 'A high quality photo of a delicious burger', 'scale': 100.0, 'precision': 'autocast'}, 'lr': 0.05, 'n_steps': 10000, 'emptiness_scale': 10, 'emptiness_weight': 10000, 'emptiness_step': 0.5, 'emptiness_multiplier': 20.0, 'depth_weight': 0, 'var_red': True}
        pose = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)
        poser = pose.make()
        sd_model = SD(variant='v1', v2_highres=False, prompt=prompt, scale=100.0, precision='autocast')
        model = sd_model.make()
        vox = VoxConfig(
            model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
            blend_bg_texture=True, bg_texture_hw=4,
            bbox_len=1.0)
        vox = vox.make()

        lr = 0.05
        emptiness_scale = 10
        emptiness_weight = 10000
        emptiness_step = 0.5
        emptiness_multiplier = 20.0
        depth_weight = 0
        var_red = True

        assert model.samps_centered()
        _, target_H, target_W = model.data_shape()
        bs = 1
        aabb = vox.aabb.T.cpu().numpy()
        vox = vox.to(device_glb)
        opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

        H, W = poser.H, poser.W
        Ks, poses, prompt_prefixes = poser.sample_train(n_steps)

        # Noise levels to sample from: model.us without its first 30 and last
        # 10 entries (presumably the extremes of the schedule -- TODO confirm).
        ts = model.us[30:-10]

        # Not used below, but it draws from the torch RNG; removing it would
        # shift every later random draw for a given seed.
        same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)

        with tqdm(total=n_steps) as pbar:
            for i in range(n_steps):
                # View-dependent prompt, e.g. "<prefix> <user prompt>".
                p = f"{prompt_prefixes[i]} {model.prompt}"
                score_conds = model.prompts_emb([p])

                y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)

                if not isinstance(model, StableDiffusion):
                    y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

                opt.zero_grad()

                with torch.no_grad():
                    chosen_σs = np.random.choice(ts, bs, replace=False)
                    chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
                    chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)
                    # chosen_σs = us[i]

                    noise = torch.randn(bs, *y.shape[1:], device=model.device)

                    zs = y + chosen_σs * noise
                    Ds = model.denoise(zs, chosen_σs, **score_conds)

                    # var_red measures the denoiser output against the clean
                    # render y instead of the noised input zs.
                    if var_red:
                        grad = (Ds - y) / chosen_σs
                    else:
                        grad = (Ds - zs) / chosen_σs

                    grad = grad.mean(0, keepdim=True)

                # Push the denoiser-derived direction into the voxel parameters.
                y.backward(-grad, retain_graph=True)

                if depth_weight > 0:
                    center_depth = depth[7:-7, 7:-7]
                    # Mean over the border ring (all pixels not in the centre crop).
                    border_depth_mean = (depth.sum() - center_depth.sum()) / (depth.numel() - center_depth.numel())
                    center_depth_mean = center_depth.mean()
                    depth_diff = center_depth_mean - border_depth_mean
                    depth_loss = - torch.log(depth_diff + 1e-12)
                    depth_loss = depth_weight * depth_loss
                    depth_loss.backward(retain_graph=True)

                # ws comes from render_one_view(return_w=True); presumably the
                # per-sample rendering weights -- TODO confirm.
                emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
                emptiness_loss = emptiness_weight * emptiness_loss
                # Boost the emptiness penalty for the later part of training.
                if emptiness_step * n_steps <= i:
                    emptiness_loss *= emptiness_multiplier
                emptiness_loss.backward()

                opt.step()

                # metric.put_scalars()

                if every(pbar, percent=1):
                    with torch.no_grad():
                        if isinstance(model, StableDiffusion):
                            y = model.decode(y)
                        pane, img, depth = vis_routine(y, depth)
                        stats = tsr_stats(y)
                    # Yield only after leaving no_grad (see docstring).
                    yield {
                        image: gr.update(value=img, visible=True),
                        video: gr.update(visible=False),
                        logs: f"Steps: {i}/{n_steps}: \n" + str(stats),
                    }

                # TODO: Output pane, img and depth to Gradio

                pbar.update()
                pbar.set_description(p)

        # TODO: Save Checkpoint (e.g. from vox.state_dict())

        # Test-time rendering; H, W, aabb and the device placement of vox are
        # unchanged from training.
        vox.eval()
        K, test_poses = poser.sample_test(100)
        num_imgs = len(test_poses)

        all_images = []
        for i in tqdm(range(num_imgs)):
            with torch.no_grad():
                y, depth = render_one_view(vox, aabb, H, W, K, test_poses[i])
                if isinstance(model, StableDiffusion):
                    y = model.decode(y)
                pane, img, depth = vis_routine(y, depth)
                stats = tsr_stats(y)

            # Save img to output
            all_images.append(img)

            yield {
                image: gr.update(value=img, visible=True),
                video: gr.update(visible=False),
                logs: str(stats),
            }

        output_video = "/tmp/tmp.mp4"
        imageio.mimwrite(output_video, all_images, quality=8, fps=10)

        end_t = time.time()

        yield {
            image: gr.update(value=img, visible=False),
            video: gr.update(value=output_video, visible=True),
            logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
        }

    def submit(prompt, iters, seed):
        """Gradio click handler: stream ``_generate`` and clean up afterwards.

        ``iters`` and ``seed`` come from sliders. They are cast to ``int``
        because they are used as a ``range`` bound and an RNG seed.
        """
        try:
            yield from _generate(prompt, int(iters), int(seed))
        finally:
            # Best-effort release of the per-request SD model / voxel grid,
            # which are rebuilt on every click. This also runs if the client
            # disconnects and Gradio closes the generator.
            gc.collect()
            torch.cuda.empty_cache()

    button.click(
        submit,
        [prompt, iters, seed],
        [image, video, logs]
    )

# concurrency_count: only allow ONE running progress, else GPU will OOM.
demo.queue(concurrency_count=1)
demo.launch()