# TRELLIS.2 / glb_export.py
# Commit 44b4dd3 (choephix): Add safe non-remesh fallback option to GLB export
"""Shared GLB export logic used by both the Gradio app and FastAPI export worker.
This module owns the remesh=True / remesh=False branching and the
SAFE_NONREMESH_GLB_EXPORT env-flag behaviour so that the two entry-points
stay in lock-step.
"""
from __future__ import annotations
import os
from typing import Any, Dict
import cv2
import numpy as np
import torch
from PIL import Image
import o_voxel
# ---------------------------------------------------------------------------
# Env helpers
# ---------------------------------------------------------------------------
def _env_flag(name: str, default: bool) -> bool:
value = os.environ.get(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
# Default for the remesh=False export path: use the safe fallback unless the
# SAFE_NONREMESH_GLB_EXPORT env var explicitly disables it.
SAFE_NONREMESH_GLB_EXPORT: bool = _env_flag("SAFE_NONREMESH_GLB_EXPORT", True)
# ---------------------------------------------------------------------------
# Logging helpers
# ---------------------------------------------------------------------------
def _cumesh_counts(mesh: Any) -> str:
num_vertices = getattr(mesh, "num_vertices", "?")
num_faces = getattr(mesh, "num_faces", "?")
return f"vertices={num_vertices}, faces={num_faces}"
def _log_cumesh_counts(label: str, mesh: Any) -> None:
print(f"{label}: {_cumesh_counts(mesh)}", flush=True)
# ---------------------------------------------------------------------------
# Safe non-remesh fallback (extracted verbatim from app.py)
# ---------------------------------------------------------------------------
def _to_glb_without_risky_nonremesh_cleanup(
    *,
    vertices: torch.Tensor,
    faces: torch.Tensor,
    attr_volume: torch.Tensor,
    coords: torch.Tensor,
    attr_layout: Dict[str, slice],
    aabb: Any,
    voxel_size: Any = None,
    grid_size: Any = None,
    decimation_target: int = 1000000,
    texture_size: int = 2048,
    mesh_cluster_threshold_cone_half_angle_rad=np.radians(90.0),
    mesh_cluster_refine_iterations=0,
    mesh_cluster_global_iterations=1,
    mesh_cluster_smooth_strength=1,
    verbose: bool = False,
    use_tqdm: bool = False,
):
    """Build a textured trimesh for remesh=False export, guarding risky cleanup.

    Safe variant of the non-remesh GLB export path: simplifies and cleans the
    input mesh, UV-unwraps it, bakes PBR attributes sampled from
    ``attr_volume`` into textures, and returns a ``postprocess.trimesh.Trimesh``.
    The one step observed to crash — ``unify_face_orientations`` raising
    "[CuMesh] CUDA error" — is wrapped so it is retried once from a CPU
    readback and then skipped, instead of aborting the whole export.

    Args:
        vertices: Source mesh vertex positions (moved to CUDA below).
        faces: Source mesh face indices (moved to CUDA below).
        attr_volume: Per-voxel attribute channels; channel meaning is given by
            ``attr_layout`` ("base_color", "metallic", "roughness", "alpha").
        coords: Voxel coordinates paired with ``attr_volume`` rows.
            NOTE(review): assumed to be integer grid indices as expected by
            ``postprocess.grid_sample_3d`` — confirm against that helper.
        attr_layout: Maps attribute name -> channel slice into ``attr_volume``.
        aabb: Axis-aligned bounding box, coercible to a (2, 3) tensor
            (row 0 = min corner, row 1 = max corner).
        voxel_size: Per-axis voxel edge length; if given, ``grid_size`` is
            derived from it. Exactly one of voxel_size / grid_size is needed.
        grid_size: Per-axis voxel counts; used when ``voxel_size`` is None.
        decimation_target: Approximate face budget for simplification.
        texture_size: Side length (pixels) of the square baked textures.
        mesh_cluster_threshold_cone_half_angle_rad: Forwarded to
            ``uv_unwrap``'s chart computation.
        mesh_cluster_refine_iterations: Forwarded to chart computation.
        mesh_cluster_global_iterations: Forwarded to chart computation.
        mesh_cluster_smooth_strength: Forwarded to chart computation.
        verbose: Verbosity switch forwarded to CuMesh operations.
        use_tqdm: Show a 6-stage progress bar when True.

    Returns:
        A textured ``trimesh.Trimesh`` ready for GLB serialization.
    """
    postprocess = o_voxel.postprocess

    def _try_unify_face_orientations(current_mesh: Any) -> Any:
        # Guarded wrapper around the crash-prone step: try in place; on the
        # known "[CuMesh] CUDA error" rebuild the mesh from a readback and
        # retry once; on a second such failure return the un-unified mesh
        # rather than raising. Any other RuntimeError propagates.
        _log_cumesh_counts("Before face-orientation unification", current_mesh)
        try:
            current_mesh.unify_face_orientations()
            _log_cumesh_counts("After face-orientation unification", current_mesh)
            return current_mesh
        except RuntimeError as error:
            if "[CuMesh] CUDA error" not in str(error):
                raise
            print(
                "Face-orientation unification failed in remesh=False fallback; "
                f"retrying once from readback. error={error}",
                flush=True,
            )
            try:
                # Rebuild a fresh CuMesh from the current geometry and re-run
                # the cheap cleanup before retrying the unification.
                retry_vertices, retry_faces = current_mesh.read()
                retry_mesh = postprocess.cumesh.CuMesh()
                retry_mesh.init(retry_vertices, retry_faces)
                retry_mesh.remove_duplicate_faces()
                retry_mesh.remove_small_connected_components(1e-5)
                _log_cumesh_counts("Before face-orientation retry", retry_mesh)
                retry_mesh.unify_face_orientations()
                _log_cumesh_counts("After face-orientation retry", retry_mesh)
                return retry_mesh
            except RuntimeError as retry_error:
                if "[CuMesh] CUDA error" not in str(retry_error):
                    raise
                print(
                    "Skipping face-orientation unification in remesh=False fallback after "
                    f"retry failure: {retry_error}",
                    flush=True,
                )
                return current_mesh

    # --- Coerce aabb to a (2, 3) float32 tensor on coords' device. ---
    if isinstance(aabb, (list, tuple)):
        aabb = np.array(aabb)
    if isinstance(aabb, np.ndarray):
        aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device)
    assert isinstance(aabb, torch.Tensor)
    assert aabb.dim() == 2 and aabb.size(0) == 2 and aabb.size(1) == 3
    # --- Derive grid_size from voxel_size, or voxel_size from grid_size. ---
    if voxel_size is not None:
        if isinstance(voxel_size, float):
            voxel_size = [voxel_size, voxel_size, voxel_size]
        if isinstance(voxel_size, (list, tuple)):
            voxel_size = np.array(voxel_size)
        if isinstance(voxel_size, np.ndarray):
            voxel_size = torch.tensor(
                voxel_size, dtype=torch.float32, device=coords.device
            )
        grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int()
    else:
        assert grid_size is not None, "Either voxel_size or grid_size must be provided"
        if isinstance(grid_size, int):
            grid_size = [grid_size, grid_size, grid_size]
        if isinstance(grid_size, (list, tuple)):
            grid_size = np.array(grid_size)
        if isinstance(grid_size, np.ndarray):
            grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device)
        voxel_size = (aabb[1] - aabb[0]) / grid_size
    assert isinstance(voxel_size, torch.Tensor)
    assert voxel_size.dim() == 1 and voxel_size.size(0) == 3
    assert isinstance(grid_size, torch.Tensor)
    assert grid_size.dim() == 1 and grid_size.size(0) == 3
    # Progress bar covers six stages: init, BVH, clean, unwrap, sample, finalize.
    pbar = None
    if use_tqdm:
        pbar = postprocess.tqdm(total=6, desc="Extracting GLB")
    vertices = vertices.cuda()
    faces = faces.cuda()
    mesh = postprocess.cumesh.CuMesh()
    mesh.init(vertices, faces)
    _log_cumesh_counts("Fallback mesh init", mesh)
    if pbar is not None:
        pbar.update(1)
    if pbar is not None:
        pbar.set_description("Building BVH")
    # BVH over the ORIGINAL geometry; used later to snap baked texels back
    # onto the pre-simplification surface.
    bvh = postprocess.cumesh.cuBVH(vertices, faces)
    if pbar is not None:
        pbar.update(1)
    if pbar is not None:
        pbar.set_description("Cleaning mesh")
    # Coarse simplify to 3x the budget, clean, then simplify to the final
    # target and clean again.
    mesh.simplify(decimation_target * 3, verbose=verbose)
    _log_cumesh_counts("After fallback coarse simplification", mesh)
    mesh.remove_duplicate_faces()
    mesh.remove_small_connected_components(1e-5)
    _log_cumesh_counts("After fallback initial cleanup", mesh)
    mesh.simplify(decimation_target, verbose=verbose)
    _log_cumesh_counts("After fallback target simplification", mesh)
    mesh.remove_duplicate_faces()
    mesh.remove_small_connected_components(1e-5)
    _log_cumesh_counts("After fallback final cleanup", mesh)
    mesh = _try_unify_face_orientations(mesh)
    if pbar is not None:
        pbar.update(1)
    if pbar is not None:
        pbar.set_description("Parameterizing new mesh")
    # UV-unwrap the cleaned mesh; vmaps maps unwrapped vertices back to the
    # mesh's vertex indices (used below to gather normals).
    out_vertices, out_faces, out_uvs, out_vmaps = mesh.uv_unwrap(
        compute_charts_kwargs={
            "threshold_cone_half_angle_rad": mesh_cluster_threshold_cone_half_angle_rad,
            "refine_iterations": mesh_cluster_refine_iterations,
            "global_iterations": mesh_cluster_global_iterations,
            "smooth_strength": mesh_cluster_smooth_strength,
        },
        return_vmaps=True,
        verbose=verbose,
    )
    out_vertices = out_vertices.cuda()
    out_faces = out_faces.cuda()
    out_uvs = out_uvs.cuda()
    out_vmaps = out_vmaps.cuda()
    mesh.compute_vertex_normals()
    out_normals = mesh.read_vertex_normals()[out_vmaps]
    if pbar is not None:
        pbar.update(1)
    if pbar is not None:
        pbar.set_description("Sampling attributes")
    # Rasterize the UV layout into texture space: UVs in [0,1] are mapped to
    # clip space ([-1,1], z=0, w=1) and drawn as triangles.
    ctx = postprocess.dr.RasterizeCudaContext()
    uvs_rast = torch.cat(
        [
            out_uvs * 2 - 1,
            torch.zeros_like(out_uvs[:, :1]),
            torch.ones_like(out_uvs[:, :1]),
        ],
        dim=-1,
    ).unsqueeze(0)
    rast = torch.zeros(
        (1, texture_size, texture_size, 4), device="cuda", dtype=torch.float32
    )
    # Rasterize in 100k-face chunks; channel 3 holds the hit-face id
    # (0 = no coverage), so offset each chunk's ids by its start index to
    # keep them global before merging into the accumulated raster.
    for i in range(0, out_faces.shape[0], 100000):
        rast_chunk, _ = postprocess.dr.rasterize(
            ctx,
            uvs_rast,
            out_faces[i : i + 100000],
            resolution=[texture_size, texture_size],
        )
        mask_chunk = rast_chunk[..., 3:4] > 0
        rast_chunk[..., 3:4] += i
        rast = torch.where(mask_chunk, rast_chunk, rast)
    mask = rast[0, ..., 3] > 0
    # Interpolate 3D positions of covered texels on the simplified mesh...
    pos = postprocess.dr.interpolate(out_vertices.unsqueeze(0), rast, out_faces)[0][0]
    valid_pos = pos[mask]
    # ...then snap them onto the ORIGINAL surface: closest-point query gives
    # the original face and barycentrics, from which the position is rebuilt.
    _, face_id, uvw = bvh.unsigned_distance(valid_pos, return_uvw=True)
    orig_tri_verts = vertices[faces[face_id.long()]]
    valid_pos = (orig_tri_verts * uvw.unsqueeze(-1)).sum(dim=1)
    # Trilinearly sample the sparse attribute volume at the snapped positions
    # (converted from world units to voxel units).
    attrs = torch.zeros(texture_size, texture_size, attr_volume.shape[1], device="cuda")
    attrs[mask] = postprocess.grid_sample_3d(
        attr_volume,
        torch.cat([torch.zeros_like(coords[:, :1]), coords], dim=-1),
        shape=torch.Size([1, attr_volume.shape[1], *grid_size.tolist()]),
        grid=((valid_pos - aabb[0]) / voxel_size).reshape(1, -1, 3),
        mode="trilinear",
    )
    if pbar is not None:
        pbar.update(1)
    if pbar is not None:
        pbar.set_description("Finalizing mesh")
    # Convert sampled attributes to 8-bit texture channels.
    mask = mask.cpu().numpy()
    base_color = np.clip(
        attrs[..., attr_layout["base_color"]].cpu().numpy() * 255, 0, 255
    ).astype(np.uint8)
    metallic = np.clip(
        attrs[..., attr_layout["metallic"]].cpu().numpy() * 255, 0, 255
    ).astype(np.uint8)
    roughness = np.clip(
        attrs[..., attr_layout["roughness"]].cpu().numpy() * 255, 0, 255
    ).astype(np.uint8)
    alpha = np.clip(
        attrs[..., attr_layout["alpha"]].cpu().numpy() * 255, 0, 255
    ).astype(np.uint8)
    # Inpaint texels outside the UV coverage so texture filtering does not
    # bleed in background values at chart borders.
    mask_inv = (~mask).astype(np.uint8)
    base_color = cv2.inpaint(base_color, mask_inv, 3, cv2.INPAINT_TELEA)
    metallic = cv2.inpaint(metallic, mask_inv, 1, cv2.INPAINT_TELEA)[..., None]
    roughness = cv2.inpaint(roughness, mask_inv, 1, cv2.INPAINT_TELEA)[..., None]
    alpha = cv2.inpaint(alpha, mask_inv, 1, cv2.INPAINT_TELEA)[..., None]
    # glTF metallicRoughness packing: R unused (zeros), G=roughness, B=metallic.
    material = postprocess.trimesh.visual.material.PBRMaterial(
        baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)),
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
        metallicRoughnessTexture=Image.fromarray(
            np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)
        ),
        metallicFactor=1.0,
        roughnessFactor=1.0,
        alphaMode="OPAQUE",
        doubleSided=True,
    )
    vertices_np = out_vertices.cpu().numpy()
    faces_np = out_faces.cpu().numpy()
    uvs_np = out_uvs.cpu().numpy()
    normals_np = out_normals.cpu().numpy()
    # Axis change (x, y, z) -> (x, z, -y), presumably z-up to glTF's y-up —
    # TODO confirm. Subtle but correct: the RHS is evaluated first, column 1
    # is then written from a live view of column 2 (still untouched), and
    # column 2 gets the freshly negated copy of the old column 1.
    vertices_np[:, 1], vertices_np[:, 2] = vertices_np[:, 2], -vertices_np[:, 1]
    normals_np[:, 1], normals_np[:, 2] = normals_np[:, 2], -normals_np[:, 1]
    # Flip V to image-space convention.
    uvs_np[:, 1] = 1 - uvs_np[:, 1]
    textured_mesh = postprocess.trimesh.Trimesh(
        vertices=vertices_np,
        faces=faces_np,
        vertex_normals=normals_np,
        process=False,  # keep vertex/face order; trimesh must not re-process
        visual=postprocess.trimesh.visual.TextureVisuals(uv=uvs_np, material=material),
    )
    if pbar is not None:
        pbar.update(1)
        pbar.close()
    return textured_mesh
# ---------------------------------------------------------------------------
# Public entry-point -- mirrors the branching in app.py extract_glb()
# ---------------------------------------------------------------------------
def export_glb(
    *,
    vertices: torch.Tensor,
    faces: torch.Tensor,
    attr_volume: torch.Tensor,
    coords: torch.Tensor,
    attr_layout: Dict[str, slice],
    grid_size: Any,
    aabb: Any,
    decimation_target: int,
    texture_size: int,
    remesh: bool,
    safe_nonremesh_fallback: bool | None = None,
    use_tqdm: bool = False,
):
    """Export a trimesh GLB scene from decoded mesh data.

    Args:
        remesh: Whether to rebuild mesh topology during export.
        safe_nonremesh_fallback: Consulted only when ``remesh=False``:
            ``True`` selects the safe fallback path (guarded
            face-orientation handling with retry logic), ``False``
            selects the upstream raw ``to_glb(remesh=False)`` path, and
            ``None`` (the default) defers to the
            ``SAFE_NONREMESH_GLB_EXPORT`` env var, which itself
            defaults to ``True``. Ignored when ``remesh=True``.
    """
    # Keyword set shared by all three export paths.
    shared_kwargs = dict(
        vertices=vertices,
        faces=faces,
        attr_volume=attr_volume,
        coords=coords,
        attr_layout=attr_layout,
        grid_size=grid_size,
        aabb=aabb,
        decimation_target=decimation_target,
        texture_size=texture_size,
        use_tqdm=use_tqdm,
    )
    if remesh:
        # Remeshing always goes through the upstream exporter.
        return o_voxel.postprocess.to_glb(
            **shared_kwargs,
            remesh=True,
            remesh_band=1,
            remesh_project=0,
        )
    # remesh=False: an explicit argument wins; otherwise the env flag decides.
    if safe_nonremesh_fallback is None:
        use_safe = SAFE_NONREMESH_GLB_EXPORT
    else:
        use_safe = safe_nonremesh_fallback
    flag_state = (
        f"(safe_nonremesh_fallback={safe_nonremesh_fallback}, "
        f"SAFE_NONREMESH_GLB_EXPORT={SAFE_NONREMESH_GLB_EXPORT})"
    )
    if use_safe:
        print(
            "Using remesh=False safe GLB export fallback " + flag_state,
            flush=True,
        )
        return _to_glb_without_risky_nonremesh_cleanup(**shared_kwargs)
    print(
        "Using upstream remesh=False GLB export path " + flag_state,
        flush=True,
    )
    return o_voxel.postprocess.to_glb(
        **shared_kwargs,
        remesh=False,
        remesh_band=1,
        remesh_project=0,
    )