import os
import cv2
import time
import json
import torch
import mcubes
import trimesh
import datetime
import argparse
import subprocess
import numpy as np
import gradio as gr
from tqdm import tqdm
import imageio.v2 as imageio
import pytorch_lightning as pl
from omegaconf import OmegaConf

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

from utility.initialize import instantiate_from_config, get_obj_from_str
from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
from utility.triplane_renderer.renderer import get_rays, to8b
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
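
# Gradio demo for 3DTopia. Stage 1 samples triplane latents from a
# text-conditioned diffusion model and renders preview videos; stage 2
# extracts a mesh with marching cubes and refines it via the external
# `threefiner` and `kire` command-line tools.
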
def add_text(rgb, caption):
    # Overlay the caption in the top-left corner, wrapped to 30 characters
    # per line.
    font = cv2.FONT_HERSHEY_SIMPLEX
    gap = 10
    fontScale = 0.3
    color = (255, 0, 0)
    thickness = 1
    break_caption = []
    for i in range(len(caption) // 30 + 1):
        break_caption_i = caption[i*30:(i+1)*30]
        break_caption.append(break_caption_i)
    for i, bci in enumerate(break_caption):
        cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
    return rgb
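
# Load the model config and fetch the pretrained 3DTopia weights from the
# Hugging Face Hub. Weights ship as .safetensors; the .ckpt branch is kept
# as a fallback for local PyTorch Lightning checkpoints.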
config = "configs/default.yaml"
ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
configs = OmegaConf.load(config)
os.makedirs("tmp", exist_ok=True)

if ckpt.endswith(".ckpt"):
    model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
elif ckpt.endswith(".safetensors"):
    model = get_obj_from_str(configs.model["target"])(**configs.model.params)
    model_ckpt = load_file(ckpt)
    model.load_state_dict(model_ckpt)
else:
    raise NotImplementedError

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
sampler = DDIMSampler(model)
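
# The UNet treats the latent as one wide image: three planes of size
# img_size x img_size laid side by side along the width, hence the
# img_size * 3 below (this layout is inferred from the sampling shape).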
img_size = configs.model.params.unet_config.params.image_size
channels = configs.model.params.unet_config.params.in_channels
shape = [channels, img_size, img_size * 3]
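
# Precompute one ray bundle per sample camera pose. Poses are 4x4
# camera-to-world matrices read from text files; the fixed rotation below
# appears to remap the pose axis convention to the renderer's, and the
# intrinsics are scaled down from the original 512-pixel resolution to H.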
pose_folder = 'assets/sample_data/pose'
poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
batch_rays_list = []
H = 128
ratio = 512 // H
for p in poses_fname:
    c2w = np.loadtxt(p).reshape(4, 4)
    c2w[:3, 3] *= 2.2  # scale the camera distance
    c2w = np.array([
        [1, 0, 0, 0],
        [0, 0, -1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]
    ]) @ c2w  # remap pose axes

    k = np.array([
        [560 / ratio, 0, H * 0.5],
        [0, 560 / ratio, H * 0.5],
        [0, 0, 1]
    ])

    rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
    coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
    coords = torch.reshape(coords, [-1, 2]).long()
    rays_o = rays_o[coords[:, 0], coords[:, 1]]
    rays_d = rays_d[coords[:, 0], coords[:, 1]]
    batch_rays = torch.stack([rays_o, rays_d], 0)
    batch_rays_list.append(batch_rays)
batch_rays_list = torch.stack(batch_rays_list, 0)
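
# Mesh extraction: decode a density grid from the selected triplane, run
# marching cubes, then color each vertex by ray-casting toward it from six
# axis-aligned camera origins, keeping the view whose rendered depth best
# matches the vertex distance (i.e., the view that actually sees the vertex).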
def marching_cube(b, text, global_info):
    res = 128
    assert 'decode_res' in global_info
    decode_res = global_info['decode_res']

    # Query the triplane decoder for density on a res^3 grid over
    # [-1.2, 1.2]^3. Density is view-independent, so a dummy direction
    # is passed in.
    c_list = torch.linspace(-1.2, 1.2, steps=res)
    grid_x, grid_y, grid_z = torch.meshgrid(
        c_list, c_list, c_list, indexing='ij'
    )
    coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device)
    plane_axes = generate_planes()
    feats = sample_from_planes(
        plane_axes, decode_res[b:b+1].reshape(1, 3, -1, 256, 256), coords.reshape(1, -1, 3), padding_mode='zeros', box_warp=2.4
    )
    fake_dirs = torch.zeros_like(coords)
    fake_dirs[..., 0] = 1
    out = model.first_stage_model.triplane_decoder.decoder(feats, fake_dirs)
    u = out['sigma'].reshape(res, res, res).detach().cpu().numpy()
    del out

    # Marching cubes at a density threshold of 10, then map vertices back
    # from grid indices to world coordinates.
    vertices, triangles = mcubes.marching_cubes(u, 10)
    min_bound = np.array([-1.2, -1.2, -1.2])
    max_bound = np.array([1.2, 1.2, 1.2])
    vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :]
    pt_vertices = torch.from_numpy(vertices).to(device)

    # Render the vertices from six axis-aligned viewpoints; per vertex,
    # keep the color from the view with the smallest depth discrepancy.
    res_triplane = 256
    render_kwargs = {
        'depth_resolution': 128,
        'disparity_space_sampling': False,
        'box_warp': 2.4,
        'depth_resolution_importance': 128,
        'clamp_mode': 'softplus',
        'white_back': True,
        'det': True
    }
    rays_o_list = [
        np.array([0, 0, 2]),
        np.array([0, 0, -2]),
        np.array([0, 2, 0]),
        np.array([0, -2, 0]),
        np.array([2, 0, 0]),
        np.array([-2, 0, 0]),
    ]
    rgb_final = None
    diff_final = None
    for rays_o in tqdm(rays_o_list):
        rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
        rays_d = pt_vertices.reshape(-1, 3) - rays_o
        rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
        dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1).cpu().numpy().reshape(-1)

        render_out = model.first_stage_model.triplane_decoder(
            decode_res[b:b+1].reshape(1, 3, -1, res_triplane, res_triplane),
            rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs,
            whole_img=False, tvloss=False
        )
        rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
        depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy()
        depth_diff = np.abs(dist - depth)

        if rgb_final is None:
            rgb_final = rgb.copy()
            diff_final = depth_diff.copy()
        else:
            ind = diff_final > depth_diff
            rgb_final[ind] = rgb[ind]
            diff_final[ind] = depth_diff[ind]

    # BGR -> RGB for trimesh vertex colors.
    rgb_final = np.stack([
        rgb_final[:, 2], rgb_final[:, 1], rgb_final[:, 0]
    ], -1)

    mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8))
    path = os.path.join('tmp', f"{text.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.ply")
    trimesh.exchange.export.export_mesh(mesh, path, file_type='ply')

    del vertices, triangles, rgb_final
    torch.cuda.empty_cache()

    return path
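
# Stage 1: sample `samples` triplanes for the prompt with DDIM and
# classifier-free guidance (the unconditional embedding is all zeros),
# then render a short orbiting preview video for each candidate.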
def infer(prompt, samples, steps, scale, seed, global_info):
    prompt = prompt.replace('/', '')
    pl.seed_everything(seed)
    batch_size = samples
    with torch.no_grad():
        noise = None  # x_T=None lets the sampler draw fresh Gaussian noise
        c = model.get_learned_conditioning([prompt])
        unconditional_c = torch.zeros_like(c)
        sample, _ = sampler.sample(
            S=steps,
            batch_size=batch_size,
            shape=shape,
            verbose=False,
            x_T=noise,
            conditioning=c.repeat(batch_size, 1, 1),
            unconditional_guidance_scale=scale,
            unconditional_conditioning=unconditional_c.repeat(batch_size, 1, 1)
        )
        decode_res = model.decode_first_stage(sample)

        big_video_list = []

        # Stash the decoded triplanes so stage 2 can pick one up later.
        global_info['decode_res'] = decode_res

        for b in range(batch_size):
            def render_img(v):
                rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder(
                    decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device),
                )
                rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
                rgb_sample = np.stack(
                    [rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1
                )
                rgb_sample = add_text(rgb_sample, str(b))
                return rgb_sample

            view_num = len(batch_rays_list)
            video_list = []
            for v in tqdm(range(view_num//8*3, view_num//8*5, 2)):
                rgb_sample = render_img(v)
                video_list.append(rgb_sample)
            big_video_list.append(video_list)

        # Pad to four candidates with white frames, then tile the videos
        # into a 2x2 grid.
        for _ in range(4 - batch_size):
            big_video_list.append(
                [np.zeros_like(f) + 255 for f in big_video_list[0]]
            )
        cat_video_list = [
            np.concatenate([
                np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
                np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
            ], 0)
            for i in range(len(big_video_list[0]))
        ]

        path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
        imageio.mimwrite(path, np.stack(cat_video_list, 0))

    return global_info, path
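
# Stage 2: export the selected candidate as a .ply mesh, refine it with
# `threefiner` (if2 mode), and render the refined .glb with `kire`.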
def infer_stage2(prompt, selection, seed, global_info):
    prompt = prompt.replace('/', '')
    mesh_path = marching_cube(int(selection), prompt, global_info)
    mesh_name = mesh_path.split('/')[-1][:-4]

    # Refine the coarse mesh with threefiner.
    if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
    print(if2_cmd)
    subprocess.Popen(if2_cmd, shell=True).wait()
    torch.cuda.empty_cache()

    # Render a preview video of the refined mesh.
    video_path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
    render_cmd = f"kire {os.path.join('tmp', mesh_name + '_if2.glb')} --save_video {video_path} --wogui --force_cuda_rast --H 256 --W 256"
    print(render_cmd)
    subprocess.Popen(render_cmd, shell=True).wait()
    torch.cuda.empty_cache()

    return video_path, os.path.join('tmp', mesh_name + '_if2.glb')
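
# Gradio UI: the left column drives stage 1 (prompt, sampling options,
# preview video); the right column picks one of the four candidates for
# stage-2 refinement and offers the refined mesh for download.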
block = gr.Blocks()

with block:
    global_info = gr.State(dict())
    with gr.Row():
        with gr.Column():
            with gr.Row():
                text = gr.Textbox(
                    label="Enter your prompt",
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                btn = gr.Button("Generate 3D")
            gallery = gr.Video(height=512)
            advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
            with gr.Row(elem_id="advanced-options"):
                samples = gr.Slider(label="Number of Samples", minimum=1, maximum=4, value=4, step=1)
                steps = gr.Slider(label="Steps", minimum=1, maximum=500, value=50, step=1)
                scale = gr.Slider(
                    label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
            gr.on([text.submit, btn.click], infer, inputs=[text, samples, steps, scale, seed, global_info], outputs=[global_info, gallery])
            advanced_button.click(
                None,
                [],
                text,
            )
        with gr.Column():
            with gr.Row():
                dropdown = gr.Dropdown(
                    ['0', '1', '2', '3'], label="Choose a Candidate For Stage2", value='0'
                )
                btn_stage2 = gr.Button("Start Refinement")
            gallery_stage2 = gr.Video(height=512)
            download = gr.File(label="Download Mesh", file_count="single", height=100)
            gr.on([btn_stage2.click], infer_stage2, inputs=[text, dropdown, seed, global_info], outputs=[gallery_stage2, download])

block.launch(share=True)