3DTopia-XL / inference.py
FrozenBurning
update inference
670f57e
raw
history blame
16.2 kB
import os
import sys
import io
import torch
import numpy as np
from omegaconf import OmegaConf
import PIL.Image
from PIL import Image
import rembg
from dva.ray_marcher import RayMarcher
from dva.io import load_from_config
from dva.utils import to_device
from dva.visualize import visualize_primvolume, visualize_video_primvolume
from models.diffusion import create_diffusion
import logging
from tqdm import tqdm
import mcubes
import xatlas
import nvdiffrast.torch as dr
import cv2
from scipy.ndimage import binary_dilation, binary_erosion
from sklearn.neighbors import NearestNeighbors
from utils.meshutils import clean_mesh, decimate_mesh
from utils.mesh import Mesh
from utils.uv_unwrap import box_projection_uv_unwrap, compute_vertex_normal
# Module-level logger used by main() and the script entry point.
logger = logging.getLogger("inference.py")
# nvdiffrast CUDA rasterization context, created once at import time and
# passed to dr.rasterize() inside extract_texmesh().
glctx = dr.RasterizeCudaContext()
def remove_background(image: PIL.Image.Image,
    rembg_session = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    """Strip the background from ``image`` with rembg unless it is already cut out.

    An image counts as already cut out when its mode is ``"RGBA"`` and the
    minimum of its alpha band is below 255 (at least one pixel is not fully
    opaque). Such an image is returned unchanged unless ``force`` is true.

    Args:
        image: Input PIL image.
        rembg_session: Session object forwarded to ``rembg.remove``.
        force: Run rembg even when the image already has transparency.
        **rembg_kwargs: Extra keyword arguments forwarded to ``rembg.remove``.

    Returns:
        The original image, or the result of ``rembg.remove``.
    """
    # getextrema() returns one (min, max) pair per band; band 3 is alpha.
    already_cut_out = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if already_cut_out and not force:
        return image
    return rembg.remove(image, session=rembg_session, **rembg_kwargs)
def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    """Crop an RGBA image to its foreground, square it and add a transparent border.

    The foreground is the tight bounding box of all pixels whose alpha is
    greater than zero. The crop is zero-padded to a square of side ``size``,
    then zero-padded again to a square of side ``int(size / ratio)``, so the
    foreground spans roughly ``ratio`` of the output side length.

    Args:
        image: RGBA image (the last axis of its array form must have 4 channels).
        ratio: Fraction of the output side covered by the foreground square.
            Values above 1 would need negative padding, which ``np.pad``
            rejects.

    Returns:
        A new square RGBA ``PIL.Image.Image``.

    Raises:
        ValueError: If the image has no pixel with alpha > 0.
    """
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    if alpha[0].size == 0:
        # Previously this surfaced as an opaque "zero-size array to reduction
        # operation" ValueError from .min(); keep the type, improve the message.
        raise ValueError("resize_foreground: image has no foreground (alpha is 0 everywhere)")
    # np.where yields inclusive pixel indices while slice ends are exclusive,
    # so the upper bounds need +1. Without it the last foreground row/column is
    # cut off and a one-pixel-wide foreground yields an empty crop.
    y1, y2 = alpha[0].min(), alpha[0].max() + 1
    x1, x2 = alpha[1].min(), alpha[1].max() + 1
    # crop the foreground
    fg = image[y1:y2, x1:x2]
    # pad to square, splitting the slack as evenly as possible between both sides
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side (same amounts for height and width: input is square)
    border0 = (new_size - size) // 2
    border1 = new_size - size - border0
    new_image = np.pad(
        new_image,
        ((border0, border1), (border0, border1), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    return PIL.Image.fromarray(new_image)
def extract_texmesh(args, model, output_path, device):
    """Extract a UV-textured PBR mesh from a primitive model and write ``pbr_mesh.glb``.

    Steps performed:
      1. Temporarily replace ``model.srt_param`` / ``model.feat_param`` with a
         subset that keeps only primitives whose nearest neighbour lies within
         the sum of both primitives' scales (column 0 of ``srt_param`` is
         treated as scale, columns 1:4 as position).
      2. Query ``model(pts)['sdf']`` on a regular ``mc_resolution``^3 grid over
         [-1, 1]^3 and run marching cubes at level 0.
      3. Clean and optionally decimate the mesh, then UV-unwrap it (box
         projection when ``args.fast_unwrap`` is set, otherwise xatlas).
      4. Rasterize the UV layout at 1024x1024, query ``model(...)['tex']`` and
         ``['mat']`` at the surface points, and pad the texture outwards by
         nearest-neighbour lookup.
      5. Write the mesh to ``<output_path>/pbr_mesh.glb``.

    The model's original parameters are restored before returning, including
    when an exception is raised part-way through.

    Args:
        args: Config object providing ``mc_resolution``, ``batch_size``,
            ``decimate``, ``remesh`` and ``fast_unwrap``.
        model: Model exposing ``srt_param`` / ``feat_param`` and returning a
            dict with ``'sdf'``, ``'tex'`` and ``'mat'`` entries when called on
            an ``[N, 3]`` point tensor.
        output_path: Existing directory that receives ``pbr_mesh.glb``.
        device: Torch device used for grid queries and rasterization.

    Returns:
        None.
    """
    # Prepare directory
    ins_dir = output_path
    # Keep untouched copies so the model can be restored in ``finally``.
    raw_srt_param = model.srt_param.clone()
    raw_feat_param = model.feat_param.clone()
    try:
        # Noise Filter
        prim_position = raw_srt_param[:, 1:4]
        prim_scale = raw_srt_param[:, 0:1]
        dist = torch.sqrt(torch.sum((prim_position[:, None, :] - prim_position[None, :, :]) ** 2, dim=-1))
        # Add 1 on the diagonal so the zero self-distance does not win the row minimum.
        dist += torch.eye(prim_position.shape[0]).to(raw_srt_param)
        min_dist, min_indices = dist.min(1)
        dst_prim_scale = prim_scale[min_indices, :]
        min_scale_coverage = prim_scale * 1. + dst_prim_scale * 1.
        prim_mask = min_dist < min_scale_coverage[:, 0]
        model.srt_param.data = raw_srt_param[prim_mask, :]
        model.feat_param.data = raw_feat_param[prim_mask, ...]
        print(f'[INFO] Mesh Extraction on PrimX: srt={model.srt_param.shape} feat={model.feat_param.shape}')

        # Get SDFs
        with torch.no_grad():
            xx = torch.linspace(-1, 1, args.mc_resolution, device=device)
            pts = torch.stack(torch.meshgrid(xx, xx, xx, indexing='ij'), dim=-1).reshape(-1,3)
            chunks = torch.split(pts, args.batch_size)
            dists = []
            for chunk_pts in tqdm(chunks):
                preds = model(chunk_pts)
                dists.append(preds['sdf'].detach())
        dists = torch.cat(dists, dim=0)
        grid = dists.reshape(args.mc_resolution, args.mc_resolution, args.mc_resolution)

        # Meshify
        vertices, triangles = mcubes.marching_cubes(grid.cpu().numpy(), 0.0)
        # Resize + recenter: map grid indices [0, res-1] back to [-1, 1]
        b_min_np = np.array([-1., -1., -1.])
        b_max_np = np.array([ 1., 1., 1.])
        vertices = vertices / (args.mc_resolution - 1.0) * (b_max_np - b_min_np) + b_min_np
        vertices, triangles = clean_mesh(vertices, triangles, min_f=8, min_d=5, repair=True, remesh=False)
        if args.decimate > 0 and triangles.shape[0] > args.decimate:
            vertices, triangles = decimate_mesh(vertices, triangles, args.decimate, remesh=args.remesh)

        # Texture resolution, supersampling factor and autocast switch.
        h0 = 1024
        w0 = 1024
        ssaa = 1
        fp16 = True
        v_np = vertices.astype(np.float32)
        f_np = triangles.astype(np.int64)
        v = torch.from_numpy(vertices).float().contiguous().to(device)
        f = torch.from_numpy(triangles.astype(np.int64)).to(torch.int64).contiguous().to(device)

        if args.fast_unwrap:
            print(f'[INFO] running box-based fast unwrapping to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
            v_normal = compute_vertex_normal(v, f)
            uv, indices = box_projection_uv_unwrap(v, v_normal, f, 0.02)
            # Give every face its own three vertices so position and UV
            # attributes share one index buffer during rasterization.
            indv_v = v[f].reshape(-1, 3)
            indv_faces = torch.arange(indv_v.shape[0], device=device, dtype=f.dtype).reshape(-1, 3)
            uv_flat = uv[indices].reshape((-1, 2))
            v = indv_v.contiguous()
            f = indv_faces.contiguous()
            ft_np = f.cpu().numpy()
            vt_np = uv_flat.cpu().numpy()
        else:
            print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
            # unwrap uv in contracted space
            atlas = xatlas.Atlas()
            atlas.add_mesh(v_np, f_np)
            chart_options = xatlas.ChartOptions()
            chart_options.max_iterations = 0 # disable merge_chart for faster unwrap...
            pack_options = xatlas.PackOptions()
            atlas.generate(chart_options=chart_options, pack_options=pack_options)
            _, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]

        vt = torch.from_numpy(vt_np.astype(np.float32)).float().contiguous().to(device)
        ft = torch.from_numpy(ft_np.astype(np.int64)).int().contiguous().to(device)
        uv = vt * 2.0 - 1.0 # uvs to range [-1, 1]
        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
        if ssaa > 1:
            h = int(h0 * ssaa)
            w = int(w0 * ssaa)
        else:
            h, w = h0, w0

        # Rasterize in UV space, then interpolate 3D positions per texel.
        rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f.int()) # [1, h, w, 3]
        mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f.int()) # [1, h, w, 1]

        # masked query
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)
        feats = torch.zeros(h * w, 6, device=device, dtype=torch.float32)
        if mask.any():
            xyzs = xyzs[mask] # [M, 3]
            # batched inference to avoid OOM
            all_feats = []
            head = 0
            chunk_size = args.batch_size
            while head < xyzs.shape[0]:
                tail = min(head + chunk_size, xyzs.shape[0])
                with torch.cuda.amp.autocast(enabled=fp16):
                    preds = model(xyzs[head:tail])
                # [R, G, B, NA, roughness, metallic]
                all_feats.append(torch.concat([preds['tex'].float(), torch.zeros_like(preds['tex'])[..., 0:1].float(), preds['mat'].float()], dim=-1))
                head += chunk_size
            feats[mask] = torch.cat(all_feats, dim=0)
        feats = feats.view(h, w, -1) # 6 channels
        mask = mask.view(h, w)

        # quantize [0.0, 1.0] to [0, 255]
        feats = feats.cpu().numpy()
        feats = (feats * 255)

        ### NN search as a queer antialiasing ...
        mask = mask.cpu().numpy()
        inpaint_region = binary_dilation(mask, iterations=32) # pad width
        inpaint_region[mask] = 0
        # Only texels within 3 erosion steps of the chart border are used as sources.
        search_region = mask.copy()
        not_search_region = binary_erosion(search_region, iterations=3)
        search_region[not_search_region] = 0
        search_coords = np.stack(np.nonzero(search_region), axis=-1)
        inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
        # NearestNeighbors rejects empty inputs; skip padding when there is
        # nothing to copy from (empty mask) or nothing to fill (full mask).
        if len(search_coords) > 0 and len(inpaint_coords) > 0:
            knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
            _, indices = knn.kneighbors(inpaint_coords)
            feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]

        target_mesh = Mesh(
            v=torch.from_numpy(v_np).contiguous(),
            f=torch.from_numpy(f_np).contiguous(),
            ft=ft.contiguous(),
            vt=torch.from_numpy(vt_np).contiguous(),
            albedo=torch.from_numpy(feats[..., :3]) / 255,
            metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255,
        )
        target_mesh.write(os.path.join(ins_dir, 'pbr_mesh.glb'))
    finally:
        # Always hand the caller back a model with its full parameter set.
        model.srt_param.data = raw_srt_param
        model.feat_param.data = raw_feat_param
def main(config):
    """Run image-conditioned diffusion sampling and optional GLB mesh export.

    For every file in ``config.inference.input_dir`` the image is passed
    through ``remove_background`` and ``resize_foreground``, encoded by the
    conditioner, and used to drive the diffusion sampler. Every 10th sampling
    step (and the last one) the current sample is decoded through the VAE and
    rendered with ``visualize_primvolume``. After sampling,
    ``visualize_video_primvolume`` is called and the final parameters are saved
    to ``<output_dir>/inference_folder/<image name>/denoised.pt``. When
    ``config.inference.export_glb`` is set, each saved result is then turned
    into a mesh with ``extract_texmesh``.

    Args:
        config: OmegaConf config. Keys read here include ``output_dir``,
            ``checkpoint_path``, ``image_height``, ``image_width``, ``rm``,
            ``diffusion``, ``model.*`` and ``inference.{ddim, cfg, precision,
            seed, input_dir, export_glb}``.

    Raises:
        NotImplementedError: If ``inference.precision`` is neither ``'tf32'``
            nor ``'fp16'``.

    Note:
        ``config`` is modified in place: ``diffusion.timestep_respacing`` is
        popped, and with ``export_glb`` several ``model.*`` entries are popped
        before the remaining model config is instantiated.
    """
    logging.basicConfig(level=logging.INFO)
    ddim_steps = config.inference.ddim
    use_ddim = ddim_steps > 0
    cfg_scale = config.inference.get("cfg", 0.0)
    inference_dir = f"{config.output_dir}/inference_folder"
    os.makedirs(inference_dir, exist_ok=True)

    amp = False
    precision = config.inference.get("precision", 'fp16')
    if precision == 'tf32':
        precision_dtype = torch.float32
    elif precision == 'fp16':
        amp = True
        precision_dtype = torch.float16
    else:
        raise NotImplementedError("{} precision is not supported".format(precision))

    device = torch.device("cuda:0")
    seed = config.inference.seed
    torch.manual_seed(seed)
    torch.cuda.set_device(device)

    model = load_from_config(config.model.generator)
    vae = load_from_config(config.model.vae)
    conditioner = load_from_config(config.model.conditioner)
    # NOTE: torch.load unpickles the file; only point these paths at trusted checkpoints.
    vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu')
    vae.load_state_dict(vae_state_dict['model_state_dict'])
    if config.checkpoint_path:
        state_dict = torch.load(config.checkpoint_path, map_location='cpu')
        model.load_state_dict(state_dict['ema'])
    vae = vae.to(device)
    conditioner = conditioner.to(device)
    model = model.to(device)

    # timestep_respacing is passed explicitly below, so drop it from the kwargs.
    config.diffusion.pop("timestep_respacing")
    respacing = "ddim{}".format(ddim_steps) if use_ddim else ""
    diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion) # default: 1000 steps, linear noise schedule
    if use_ddim:
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive
    if cfg_scale > 0:
        fwd_fn = model.forward_with_cfg
    else:
        fwd_fn = model.forward

    rm = RayMarcher(
        config.image_height,
        config.image_width,
        **config.rm,
    ).to(device)

    perchannel_norm = False
    if "latent_mean" in config.model:
        latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device)
        latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
        assert latent_mean.shape[-1] == config.model.generator.in_channels
        perchannel_norm = True

    model.eval()
    examples_dir = config.inference.input_dir
    # os.listdir order is arbitrary; sort so that, with a fixed seed, each
    # image is paired with the same noise on every run.
    img_list = sorted(os.listdir(examples_dir))
    rembg_session = rembg.new_session()
    logger.info("Starting Inference...")
    for img_path in img_list:
        full_img_path = os.path.join(examples_dir, img_path)
        # splitext handles extensions of any length (".jpeg", ".webp"), unlike
        # slicing off the last four characters.
        img_name = os.path.splitext(img_path)[0]
        current_output_dir = os.path.join(inference_dir, img_name)
        os.makedirs(current_output_dir, exist_ok=True)

        input_image = Image.open(full_img_path)
        input_image = remove_background(input_image, rembg_session)
        input_image = resize_foreground(input_image, 0.85)
        raw_image = np.array(input_image)
        # Zero out RGB wherever alpha is 0.
        mask = (raw_image[..., -1][..., None] > 0) * 1
        raw_image = raw_image[..., :3] * mask
        input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)

        with torch.no_grad():
            # Only the trailing shape of ``latent`` is used below, but the draw
            # also advances the RNG, so it is kept to leave seeded results unchanged.
            latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
            batch = {}
            inf_bs = 1
            inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device)
            y = conditioner.encoder(input_cond)
            model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp)
            if cfg_scale > 0:
                model_kwargs['cfg_scale'] = cfg_scale
            sampled_count = -1
            for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device):
                sampled_count += 1
                # Decode/visualize only every 10th step and the final step.
                if not (sampled_count % 10 == 0 or sampled_count == diffusion.num_timesteps - 1):
                    continue
                # reshape() may return a view of the sampler's tensor; clone so
                # the in-place de-normalization below cannot write back into it.
                recon_param = samples["sample"].reshape(inf_bs, config.model.num_prims, -1).clone()
                if perchannel_norm:
                    recon_param = recon_param / config.model.latent_nf * latent_std + latent_mean
                recon_srt_param = recon_param[:, :, 0:4]
                recon_feat_param = recon_param[:, :, 4:] # [8, 2048, 64]
                recon_feat_param_list = []
                # one-by-one to avoid oom
                for inf_bidx in range(inf_bs):
                    if not perchannel_norm:
                        decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / config.model.latent_nf)
                    else:
                        decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
                    recon_feat_param_list.append(decoded.detach())
                recon_feat_param = torch.concat(recon_feat_param_list, dim=0)
                # invert normalization
                if not perchannel_norm:
                    recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05
                recon_feat_param[:, 0:1, ...] /= 5.
                recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2.
                recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1)
                recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
                visualize_primvolume("{}/dstep{:04d}_recon.jpg".format(current_output_dir, sampled_count), batch, recon_param, rm, device)
            visualize_video_primvolume(current_output_dir, batch, recon_param, 60, rm, device)
            prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
            torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(current_output_dir))

    if config.inference.export_glb:
        logger.info("Starting GLB Mesh Extraction...")
        # Strip the sampling-only entries; what is left in config.model is
        # handed to load_from_config below.
        config.model.pop("vae")
        config.model.pop("vae_checkpoint_path")
        config.model.pop("conditioner")
        config.model.pop("generator")
        config.model.pop("latent_nf")
        # latent_mean / latent_std are optional (see perchannel_norm above), so
        # popping them must not fail when they are absent.
        config.model.pop("latent_mean", None)
        config.model.pop("latent_std", None)
        model_primx = load_from_config(config.model)
        for img_path in img_list:
            img_name = os.path.splitext(img_path)[0]
            output_path = os.path.join(inference_dir, img_name)
            denoise_param_path = os.path.join(inference_dir, img_name, 'denoised.pt')
            ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
            model_primx.load_state_dict(ckpt_weight)
            model_primx.to(device)
            model_primx.eval()
            with torch.no_grad():
                # Scale positions by the same 0.85 factor passed to resize_foreground.
                model_primx.srt_param[:, 1:4] *= 0.85
                extract_texmesh(config.inference, model_primx, output_path, device)
if __name__ == "__main__":
    # Configure logging before the first logger call. Otherwise the INFO
    # messages about CLI overrides below are dropped, because main() only
    # calls basicConfig later.
    logging.basicConfig(level=logging.INFO)
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <config.yml> [key=value ...]")
    torch.backends.cudnn.benchmark = True
    # manually enable tf32 to get speedup on A100 GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # set config: first argument is the YAML file, the rest are dotlist overrides
    config = OmegaConf.load(str(sys.argv[1]))
    config_cli = OmegaConf.from_cli(args_list=sys.argv[2:])
    if config_cli:
        logger.info("overriding with following values from args:")
        logger.info(OmegaConf.to_yaml(config_cli))
        config = OmegaConf.merge(config, config_cli)
    main(config)