|
import os |
|
import cv2 |
|
import json |
|
import torch |
|
import mcubes |
|
import trimesh |
|
import argparse |
|
import numpy as np |
|
from tqdm import tqdm |
|
import imageio.v2 as imageio |
|
import pytorch_lightning as pl |
|
from omegaconf import OmegaConf |
|
|
|
from ldm.models.diffusion.ddim import DDIMSampler |
|
from ldm.models.diffusion.plms import PLMSSampler |
|
from ldm.models.diffusion.dpm_solver import DPMSolverSampler |
|
|
|
from utility.initialize import instantiate_from_config, get_obj_from_str |
|
from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes |
|
from utility.triplane_renderer.renderer import get_rays, to8b |
|
from safetensors.torch import load_file |
|
from huggingface_hub import hf_hub_download |
|
|
|
import warnings

# Ignore every UserWarning and DeprecationWarning for the rest of the process.
# NOTE(review): presumably meant to keep console output readable while the
# third-party stack loads; it also hides warnings raised by this script itself.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
|
def add_text(rgb, caption):
    """Draw ``caption`` on ``rgb`` as wrapped lines and return ``rgb``.

    The caption is cut into consecutive chunks of 30 characters (a plain
    character cut, not word-aware). Chunk ``i`` (1-based) is drawn with its
    text origin at ``(30, 30 * i)`` pixels.

    Args:
        rgb: Image array handed to ``cv2.putText``; it is drawn on directly.
        caption: Text to render.

    Returns:
        The same ``rgb`` object that was passed in.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    gap = 30  # left margin and vertical pitch between lines, in pixels
    chars_per_line = 30
    font_scale = 0.6
    # First channel at full intensity; which colour that is depends on the
    # channel order of ``rgb`` (blue for BGR, red for RGB).
    color = (255, 0, 0)
    thickness = 1

    # A stepped range avoids producing a trailing empty chunk when the caption
    # length is an exact multiple of the line width (or the caption is empty).
    lines = [
        caption[start:start + chars_per_line]
        for start in range(0, len(caption), chars_per_line)
    ]
    for row, line in enumerate(lines, start=1):
        cv2.putText(rgb, line, (gap, gap * row), font, font_scale, color, thickness, cv2.LINE_AA)
    return rgb
|
|
|
# Constants shared by mesh extraction and rendering (values unchanged from the
# original inline literals).
_SCENE_BOUND = 1.2        # marching-cubes grid spans [-1.2, 1.2] on every axis
_BOX_WARP = 2.4           # box_warp passed to the tri-plane sampler / renderer
_TRIPLANE_RES = 256       # spatial size the decoded latent is reshaped to
_DENSITY_THRESHOLD = 10   # iso-value handed to marching cubes
_DEFAULT_PROMPT = 'a robot'
_STILL_VIEW_INDEX = 104   # pose index rendered when --no_video is set


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='configs/default.yaml')
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--test_folder", type=str, default="stage1")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--sampler", type=str, default="ddpm")
    parser.add_argument("--samples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--steps", type=int, default=1000)
    # The default must be None, not a string: with nargs='+' a string default
    # is joined character by character ("a   r o b o t") and, never being
    # None, it also made --text_file unreachable.
    parser.add_argument("--text", nargs='+', default=None,
                        help="prompt words; falls back to '{}' when neither --text "
                             "nor --text_file is given".format(_DEFAULT_PROMPT))
    parser.add_argument("--text_file", type=str, default=None)
    parser.add_argument("--no_video", action='store_true', default=False)
    parser.add_argument("--render_res", type=int, default=128)
    parser.add_argument("--no_mcubes", action='store_true', default=False)
    parser.add_argument("--mcubes_res", type=int, default=128)
    parser.add_argument("--cfg_scale", type=float, default=1)
    return parser.parse_args()


def _load_prompts(args):
    """Return the list of prompts: --text wins, then --text_file, then the default."""
    if args.text is not None:
        return [' '.join(args.text)]
    if args.text_file is not None:
        if args.text_file.endswith('.json'):
            with open(args.text_file, 'r') as f:
                prompts = json.load(f)
            return [p.strip('.') for p in prompts]
        with open(args.text_file, 'r') as f:
            # Skip blank lines so they do not turn into empty prompts.
            return [line.strip() for line in f if line.strip()]
    return [_DEFAULT_PROMPT]


def _load_model(configs, ckpt):
    """Instantiate ``configs.model.target`` and load weights from ``ckpt``.

    Raises:
        NotImplementedError: if ``ckpt`` is neither ``.ckpt`` nor ``.safetensors``.
    """
    if not ckpt.endswith((".ckpt", ".safetensors")):
        raise NotImplementedError("unsupported checkpoint format: {}".format(ckpt))
    model_cls = get_obj_from_str(configs.model["target"])
    if ckpt.endswith(".ckpt"):
        return model_cls.load_from_checkpoint(
            ckpt, map_location='cpu', strict=False, **configs.model.params
        )
    model = model_cls(**configs.model.params)
    model.load_state_dict(load_file(ckpt))
    return model


class _DummySampler:
    """Adapter exposing ``model.sample`` through the ``sampler.sample`` call shape."""

    def __init__(self, model):
        self.model = model

    def sample(self, S, batch_size, shape, verbose, conditioning=None, *args, **kwargs):
        # S and verbose are accepted for signature compatibility and ignored.
        return self.model.sample(
            conditioning, batch_size, shape=[batch_size, ] + shape, *args, **kwargs
        ), None


def _build_sampler(name, model):
    """Return the sampler for ``name``; only 'ddim' and 'ddpm' are supported."""
    if name == 'ddim':
        return DDIMSampler(model)
    if name == 'ddpm':
        return _DummySampler(model)
    # 'dpm' and 'plms' are recognised names but not wired up.
    raise NotImplementedError("sampler '{}' is not supported".format(name))


def _load_camera_rays(pose_folder, H):
    """Load every pose file in ``pose_folder`` and return stacked camera rays.

    Each file is read with ``np.loadtxt`` and reshaped to a 4x4 matrix. The
    result stacks, per pose, ``[rays_o, rays_d]`` flattened over the H*H pixels.
    """
    pose_paths = sorted(os.path.join(pose_folder, f) for f in os.listdir(pose_folder))

    # Focal length 560 is defined for a 512-pixel image. Scale it with true
    # division: the former ``512 // H`` floored the ratio (wrong focal for
    # sizes that do not divide 512) and divided by zero for H > 512.
    focal = 560 * H / 512
    intrinsics = torch.Tensor(np.array([
        [focal, 0, H * 0.5],
        [0, focal, H * 0.5],
        [0, 0, 1]
    ]))
    # Left-multiplied onto each pose; swaps the y/z axes (with a sign flip).
    axis_swap = np.array([
        [1, 0, 0, 0],
        [0, 0, -1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]
    ])
    # Pixel index grid used to flatten the per-pixel rays; identical for every pose.
    pix = torch.stack(
        torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, H - 1, H), indexing='ij'), -1
    )
    pix = torch.reshape(pix, [-1, 2]).long()

    rays = []
    for path in pose_paths:
        c2w = np.loadtxt(path).reshape(4, 4)
        c2w[:3, 3] *= 2.2  # scale the translation column
        c2w = axis_swap @ c2w
        rays_o, rays_d = get_rays(H, H, intrinsics, torch.Tensor(c2w[:3, :4]))
        rays_o = rays_o[pix[:, 0], pix[:, 1]]
        rays_d = rays_d[pix[:, 0], pix[:, 1]]
        rays.append(torch.stack([rays_o, rays_d], 0))
    return torch.stack(rays, 0)


def _render_view(model, triplane, view_rays, H, device):
    """Render one view of ``triplane`` and return it as an 8-bit image.

    The first and last channels of the rendered image are swapped before returning.
    """
    rendered, _ = model.first_stage_model.render_triplane_eg3d_decoder(
        triplane, view_rays.to(device), torch.zeros(1, H, H, 3).to(device),
    )
    img = to8b(rendered.detach().cpu().numpy())[0]
    return np.stack([img[..., 2], img[..., 1], img[..., 0]], -1)


def _export_colored_mesh(model, triplane, plane_axes, res, device, out_path):
    """Extract a vertex-coloured mesh from ``triplane`` and write it as PLY.

    Steps: query the decoder's ``sigma`` on a ``res``^3 grid, run marching
    cubes, then colour each vertex from whichever of six axis-aligned
    viewpoints renders a depth closest to the true viewpoint-to-vertex distance.
    """
    # Density volume on a regular grid.
    axis = torch.linspace(-_SCENE_BOUND, _SCENE_BOUND, steps=res)
    grid_x, grid_y, grid_z = torch.meshgrid(axis, axis, axis, indexing='ij')
    coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device)
    feats = sample_from_planes(
        plane_axes,
        triplane.reshape(1, 3, -1, _TRIPLANE_RES, _TRIPLANE_RES),
        coords.reshape(1, -1, 3),
        padding_mode='zeros',
        box_warp=_BOX_WARP,
    )
    # The decoder takes a direction argument; a constant +x direction is fed.
    fake_dirs = torch.zeros_like(coords)
    fake_dirs[..., 0] = 1
    out = model.first_stage_model.triplane_decoder.decoder(feats, fake_dirs)
    sigma = out['sigma'].reshape(res, res, res).detach().cpu().numpy()
    del out  # drop the decoder output before the rendering passes below

    # Marching cubes, then map grid indices back to scene coordinates.
    vertices, triangles = mcubes.marching_cubes(sigma, _DENSITY_THRESHOLD)
    min_bound = np.array([-_SCENE_BOUND, -_SCENE_BOUND, -_SCENE_BOUND])
    max_bound = np.array([_SCENE_BOUND, _SCENE_BOUND, _SCENE_BOUND])
    vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :]
    pt_vertices = torch.from_numpy(vertices).to(device)

    render_kwargs = {
        'depth_resolution': 128,
        'disparity_space_sampling': False,
        'box_warp': _BOX_WARP,
        'depth_resolution_importance': 128,
        'clamp_mode': 'softplus',
        'white_back': True,
        'det': True
    }
    view_origins = [
        np.array([0, 0, 2]),
        np.array([0, 0, -2]),
        np.array([0, 2, 0]),
        np.array([0, -2, 0]),
        np.array([2, 0, 0]),
        np.array([-2, 0, 0]),
    ]

    best_rgb = None
    best_err = None
    for origin in tqdm(view_origins):
        rays_o = torch.from_numpy(origin.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
        offsets = pt_vertices.reshape(-1, 3) - rays_o
        dist_t = torch.norm(offsets, dim=-1)
        rays_d = offsets / dist_t.reshape(-1, 1)
        dist = dist_t.cpu().numpy().reshape(-1)

        render_out = model.first_stage_model.triplane_decoder(
            triplane.reshape(1, 3, -1, _TRIPLANE_RES, _TRIPLANE_RES),
            rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs,
            whole_img=False, tvloss=False
        )
        rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
        depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy()
        depth_err = np.abs(dist - depth)

        if best_rgb is None:
            best_rgb = rgb.copy()
            best_err = depth_err.copy()
        else:
            # Strict comparison: on ties the earlier viewpoint keeps its colour.
            better = best_err > depth_err
            best_rgb[better] = rgb[better]
            best_err[better] = depth_err[better]

    # Swap the first and last colour channels before export.
    best_rgb = np.stack([best_rgb[:, 2], best_rgb[:, 1], best_rgb[:, 0]], -1)

    mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(best_rgb * 255).astype(np.uint8))
    trimesh.exchange.export.export_mesh(mesh, out_path, file_type='ply')


def main():
    """Sample 3D assets from text prompts and write meshes / videos to disk.

    Outputs go to ``results/<config name>/<test_folder>/`` and are named
    ``<prompt with underscores>_<sample>_<batch index>`` with a ``.ply``
    (unless ``--no_mcubes``) plus ``.mp4`` (or ``.jpg`` with ``--no_video``).

    Raises:
        NotImplementedError: unsupported sampler name or checkpoint extension.
        ValueError: ``--cfg_scale`` other than 1 with a sampler other than 'ddim'.
    """
    args = _parse_args()
    text = _load_prompts(args)
    print(text)

    configs = OmegaConf.load(args.config)
    if args.seed is not None:
        pl.seed_everything(args.seed)

    log_dir = os.path.join('results', args.config.split('/')[-1].split('.')[0], args.test_folder)
    os.makedirs(log_dir, exist_ok=True)

    if args.ckpt is None:
        ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
    else:
        ckpt = args.ckpt

    model = _load_model(configs, ckpt)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = _build_sampler(args.sampler, model)
    use_cfg = args.cfg_scale != 1
    if use_cfg and args.sampler != 'ddim':
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise ValueError("--cfg_scale other than 1 requires --sampler ddim")

    img_size = configs.model.params.unet_config.params.image_size
    channels = configs.model.params.unet_config.params.in_channels
    shape = [channels, img_size, img_size * 3]
    plane_axes = generate_planes()

    H = args.render_res
    batch_rays_list = _load_camera_rays('assets/sample_data/pose', H)

    for text_i in text:
        text_connect = '_'.join(text_i.split(' '))
        for s in range(args.samples):
            batch_size = args.batch_size
            with torch.no_grad():
                c = model.get_learned_conditioning([text_i])
                sample_kwargs = dict(
                    S=args.steps,
                    batch_size=batch_size,
                    shape=shape,
                    verbose=False,
                    x_T=None,
                    conditioning=c.repeat(batch_size, 1, 1),
                )
                if use_cfg:
                    # The unconditional branch uses an all-zero conditioning.
                    sample_kwargs.update(
                        unconditional_guidance_scale=args.cfg_scale,
                        unconditional_conditioning=torch.zeros_like(c).repeat(batch_size, 1, 1),
                    )
                sample, _ = sampler.sample(**sample_kwargs)
                decode_res = model.decode_first_stage(sample)

                for b in range(batch_size):
                    triplane = decode_res[b:b + 1]
                    out_stem = os.path.join(log_dir, f"{text_connect}_{s}_{b}")

                    if not args.no_mcubes:
                        _export_colored_mesh(
                            model, triplane, plane_axes, args.mcubes_res, device, out_stem + ".ply"
                        )

                    if not args.no_video:
                        # Every second pose from the middle half of the pose list.
                        view_num = len(batch_rays_list)
                        frames = [
                            _render_view(model, triplane, batch_rays_list[v:v + 1], H, device)
                            for v in tqdm(range(view_num // 4, view_num // 4 * 3, 2))
                        ]
                        imageio.mimwrite(out_stem + ".mp4", np.stack(frames, 0))
                    else:
                        # NOTE(review): fixed view index; a pose folder with
                        # fewer files yields an empty slice here - confirm.
                        still = _render_view(
                            model, triplane,
                            batch_rays_list[_STILL_VIEW_INDEX:_STILL_VIEW_INDEX + 1], H, device
                        )
                        imageio.imwrite(out_stem + ".jpg", still)
|
|
|
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|