|
import os |
|
import random |
|
import tempfile |
|
import time |
|
import zipfile |
|
from contextlib import nullcontext |
|
from functools import lru_cache |
|
from typing import Any |
|
|
|
import cv2 |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import trimesh |
|
from gradio_litmodel3d import LitModel3D |
|
from gradio_pointcloudeditor import PointCloudEditor |
|
from PIL import Image |
|
from transparent_background import Remover |
|
|
|
# Build and install the bundled CUDA extensions (texture baker, UV unwrapper)
# at startup so the compiled code matches this runtime environment.
# NOTE(review): os.system pip installs are a Hugging Face Spaces pattern;
# failures here are silently ignored — confirm this is intentional.
os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")

# Install the vendored pynim wheel (built for CPython 3.10 / linux x86_64 only).
os.system("pip install ./deps/pynim-0.0.3-cp310-cp310-linux_x86_64.whl")

# Imported only after the installs above so the freshly built native
# extensions are importable.
import spar3d.utils as spar3d_utils
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D

# Keep Gradio's temp files in a predictable subdirectory of TMPDIR so they
# can be wiped on startup (see the rmtree below).
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")

# Background-removal model (transparent_background); loaded once globally.
bg_remover = Remover()
|
|
|
# Resolution of the conditioning image fed to the model.
COND_WIDTH = 512
COND_HEIGHT = 512
# Camera distance of the fixed conditioning view.
COND_DISTANCE = 2.2
# Vertical field of view of the conditioning camera, in radians (~33.9 deg).
COND_FOVY = 0.591627
# Mid-grey RGB (values in [0, 1]) composited behind the cut-out foreground.
BACKGROUND_COLOR = [0.5, 0.5, 0.5]


# Fixed camera-to-world pose and (normalized) intrinsics for the single
# conditioning camera; shared by every request.
c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
    COND_FOVY, COND_HEIGHT, COND_WIDTH
)

# Paths of every artifact produced during this process's lifetime
# (meshes, point clouds, HDRs, zips); appended to by the handlers below.
generated_files = []
|
|
|
|
|
# Start from a clean Gradio temp dir so stale outputs from a previous run
# don't accumulate on disk.
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
    print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
    import shutil

    shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])

# Best available device as detected by spar3d (e.g. "cuda"/"mps"/"cpu" —
# exact strings depend on spar3d_utils.get_device; confirm there).
device = spar3d_utils.get_device()

# Load the pretrained SPAR3D weights from the Hugging Face Hub.
model = SPAR3D.from_pretrained(
    "stabilityai/stable-point-aware-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval()
model = model.to(device)

# Example images bundled with the demo, shown in the UI.
example_files = [
    os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
]
|
|
|
def create_zip_file(glb_file, pc_file, illumination_file):
    """Bundle the mesh, point cloud and illumination map into a single zip.

    Returns the path of the created archive, or None when any of the three
    input paths is missing/falsy. The archive path is also recorded in the
    module-level ``generated_files`` list.
    """
    if not all([glb_file, pc_file, illumination_file]):
        return None

    archive_dir = tempfile.mkdtemp()
    archive_path = os.path.join(archive_dir, "spar3d_output.zip")

    # Fixed archive-internal names, regardless of the source file names.
    entries = [
        (glb_file, "mesh.glb"),
        (pc_file, "points.ply"),
        (illumination_file, "illumination.hdr"),
    ]
    with zipfile.ZipFile(archive_path, "w") as archive:
        for src_path, arc_name in entries:
            archive.write(src_path, arc_name)

    generated_files.append(archive_path)
    return archive_path
|
|
|
def forward_model(
    batch,
    system,
    guidance_scale=3.0,
    seed=0,
    device="cuda",
    remesh_option="none",
    vertex_count=-1,
    texture_resolution=1024,
):
    """Sample a conditioning point cloud (if absent) and reconstruct a mesh.

    Args:
        batch: Model input dict; must contain "rgb_cond" (batched). If
            "pc_cond" is missing it is sampled here and written back into
            the dict (side effect visible to the caller).
        system: SPAR3D system providing forward_pdiff_cond, sampler and
            generate_mesh.
        guidance_scale: Classifier-free guidance scale for point sampling.
        seed: Seed applied to random/torch/numpy for reproducibility.
        device: Device string used by the progressive sampler.
        remesh_option: Remeshing mode passed to generate_mesh.
        vertex_count: Target vertex count (-1 keeps the original count).
        texture_resolution: Baked texture size in pixels.

    Returns:
        (trimesh mesh, colored trimesh.PointCloud of the first batch element,
        illumination map as a numpy array).
    """
    batch_size = batch["rgb_cond"].shape[0]

    # Seed all three RNG sources so point-cloud sampling is reproducible.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    cond_tokens = system.forward_pdiff_cond(batch)

    if "pc_cond" not in batch:
        sample_iter = system.sampler.sample_batch_progressive(
            batch_size,
            cond_tokens,
            guidance_scale=guidance_scale,
            device=device,
        )
        # Drain the progressive sampler; only the final denoised sample is kept.
        for x in sample_iter:
            samples = x["xstart"]
        batch["pc_cond"] = samples.permute(0, 2, 1).float()
        batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])

        # Randomly subsample the sampled cloud down to 512 points.
        batch["pc_cond"] = batch["pc_cond"][
            :, torch.randperm(batch["pc_cond"].shape[1])[:512]
        ]

    # First batch element's points: columns 0:3 are xyz, 3:6 are RGB in [0, 1].
    xyz = batch["pc_cond"][0, :, :3].cpu().numpy()
    color_rgb = (batch["pc_cond"][0, :, 3:6] * 255).cpu().numpy().astype(np.uint8)
    pc_rgb_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb)

    # BUG FIX: call generate_mesh on the `system` argument instead of the
    # module-level `model` global, so the function honors whichever system
    # the caller passes in (previously the parameter was silently ignored).
    trimesh_mesh, _glob_dict = system.generate_mesh(
        batch,
        texture_resolution,
        remesh=remesh_option,
        vertex_count=vertex_count,
        estimate_illumination=True,
    )
    trimesh_mesh = trimesh_mesh[0]
    illumination = _glob_dict["illumination"]

    return trimesh_mesh, pc_rgb_trimesh, illumination.cpu().detach().numpy()[0]
|
|
|
def process_model_run(
    fr_res,
    guidance_scale,
    random_seed,
    pc_cond,
    remesh_option,
    vertex_count_type,
    vertex_count,
    texture_resolution,
):
    """Run SPAR3D on a preprocessed image and export the results to disk.

    Args:
        fr_res: Foreground-cropped PIL image (RGBA expected by create_batch).
        guidance_scale: Guidance scale for point-cloud sampling.
        random_seed: RNG seed forwarded to forward_model.
        pc_cond: Accepted for interface compatibility; currently unused here
            (the point cloud is always sampled) — NOTE(review): confirm this
            is intentional.
        remesh_option: Remesh mode name; lower-cased before use.
        vertex_count_type: Accepted for interface compatibility; unused.
        vertex_count: Target vertex count for remeshing.
        texture_resolution: Baked texture size in pixels.

    Returns:
        (glb_path, point_cloud_ply_path, illumination_hdr_path, trimesh
        point cloud).
    """
    start = time.time()
    # bfloat16 autocast only on CUDA; CPU/MPS run in full precision.
    with torch.no_grad():
        with (
            torch.autocast(device_type=device, dtype=torch.bfloat16)
            if "cuda" in device
            else nullcontext()
        ):
            model_batch = create_batch(fr_res)
            model_batch = {k: v.to(device) for k, v in model_batch.items()}

            trimesh_mesh, trimesh_pc, illumination_map = forward_model(
                model_batch,
                model,
                guidance_scale=guidance_scale,
                seed=random_seed,
                # BUG FIX: pass the detected global device instead of the
                # hard-coded "cuda" string, so sampling works on CPU/MPS
                # (the autocast guard above already branches on `device`).
                device=device,
                remesh_option=remesh_option.lower(),
                vertex_count=vertex_count,
                texture_resolution=texture_resolution,
            )

    # Export all three artifacts into one fresh temp directory and track them
    # for later cleanup/zipping.
    temp_dir = tempfile.mkdtemp()
    tmp_file = os.path.join(temp_dir, "mesh.glb")

    trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True)
    generated_files.append(tmp_file)

    tmp_file_pc = os.path.join(temp_dir, "points.ply")
    trimesh_pc.export(tmp_file_pc)
    generated_files.append(tmp_file_pc)

    tmp_file_illumination = os.path.join(temp_dir, "illumination.hdr")
    cv2.imwrite(tmp_file_illumination, illumination_map)
    generated_files.append(tmp_file_illumination)

    print("Generation took:", time.time() - start, "s")

    return tmp_file, tmp_file_pc, tmp_file_illumination, trimesh_pc
|
|
|
def create_batch(input_image: Image) -> dict[str, Any]:
    """Build a batched (batch size 1) conditioning dict from a PIL image.

    The image is resized to the conditioning resolution, its alpha channel
    becomes the mask, and the RGB is composited over BACKGROUND_COLOR.
    NOTE(review): assumes an RGBA input — with an RGB image the last channel
    (blue) would be misread as the mask; confirm upstream always supplies RGBA.
    """
    resized = input_image.resize((COND_WIDTH, COND_HEIGHT))
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    img_cond = torch.from_numpy(pixels).float().clamp(0.0, 1.0)

    # Alpha channel as a (H, W, 1) mask; blend foreground over the grey bg.
    mask_cond = img_cond[..., -1:]
    background = torch.tensor(BACKGROUND_COLOR)[None, None, :]
    rgb_cond = torch.lerp(background, img_cond[..., :3], mask_cond)

    single_view = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }

    # Add the leading batch dimension to every tensor.
    return {name: tensor.unsqueeze(0) for name, tensor in single_view.items()}
|
|
|
def remove_background(input_image: Image) -> Image:
    """Remove the background from *input_image* via the global bg_remover.

    The image is converted to RGB before processing; the remover returns
    the segmented result (presumably RGBA with an alpha matte — confirm
    against the transparent_background Remover documentation).
    """
    return bg_remover.process(input_image.convert("RGB"))
|
|
|
def auto_process(input_image):
    """Run the full one-click pipeline with fixed default settings.

    Steps: background removal -> foreground crop/resize -> SPAR3D
    reconstruction -> zip packaging.

    Returns:
        (glb_path, illumination_path, zip_path, point_cloud), or four Nones
        when no image was supplied.
    """
    if input_image is None:
        return None, None, None, None

    # Fixed defaults for the automatic flow (no UI controls exposed).
    guidance = 3.0
    seed = 0
    crop_ratio = 1.3
    remesh_mode = "None"
    count_mode = "Keep Vertex Count"
    target_vertices = 2000
    tex_res = 1024

    # Preprocess: strip the background, then crop/resize around the subject
    # to the conditioning resolution.
    cutout = remove_background(input_image)
    prepared = spar3d_utils.foreground_crop(
        cutout,
        crop_ratio=crop_ratio,
        newsize=(COND_WIDTH, COND_HEIGHT),
        no_crop=False,
    )

    glb_file, pc_file, illumination_file, pc_list = process_model_run(
        prepared,
        guidance,
        seed,
        None,  # no user-provided point cloud
        remesh_mode,
        count_mode,
        target_vertices,
        tex_res,
    )

    zip_file = create_zip_file(glb_file, pc_file, illumination_file)

    return glb_file, illumination_file, zip_file, pc_list
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(
        """
    # SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images

    Upload an image to generate a 3D model.
    """
    )

    with gr.Row():
        with gr.Column():
            # BUG FIX: "click" is not a valid gr.Image source (valid values
            # are "upload", "webcam", "clipboard"), so Gradio raised a
            # ValueError at startup. Only "upload" is kept because the event
            # handler below listens to the .upload event; adding "clipboard"
            # would not trigger processing on paste.
            input_img = gr.Image(
                type="pil",
                label="Upload Image",
                sources=["upload"],
                image_mode="RGBA",
            )

        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            download_all_btn = gr.File(
                label="Download Model (ZIP)",
                file_count="single",
                visible=True,
            )

    # Kick off the full pipeline on every upload. The two inline gr.State()
    # outputs absorb the illumination path and point cloud, which the UI
    # does not display directly.
    input_img.upload(
        auto_process,
        inputs=[input_img],
        outputs=[
            output_3d,
            gr.State(),
            download_all_btn,
            gr.State(),
        ],
    )

demo.queue().launch(share=False)