|
import argparse
|
|
import os
|
|
from contextlib import nullcontext
|
|
|
|
import torch
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
from transparent_background import Remover
|
|
|
|
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
|
|
from spar3d.system import SPAR3D
|
|
from spar3d.utils import foreground_crop, get_device, remove_background
|
|
|
|
|
|
def check_positive(value):
|
|
ivalue = int(value)
|
|
if ivalue <= 0:
|
|
raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
|
|
return ivalue
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"image", type=str, nargs="+", help="Path to input image(s) or folder."
|
|
)
|
|
parser.add_argument(
|
|
"--device",
|
|
default=get_device(),
|
|
type=str,
|
|
help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'",
|
|
)
|
|
parser.add_argument(
|
|
"--pretrained-model",
|
|
default="stabilityai/stable-point-aware-3d",
|
|
type=str,
|
|
help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/stable-point-aware-3d'",
|
|
)
|
|
parser.add_argument(
|
|
"--foreground-ratio",
|
|
default=1.3,
|
|
type=float,
|
|
help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
|
|
)
|
|
parser.add_argument(
|
|
"--output-dir",
|
|
default="output/",
|
|
type=str,
|
|
help="Output directory to save the results. Default: 'output/'",
|
|
)
|
|
parser.add_argument(
|
|
"--texture-resolution",
|
|
default=1024,
|
|
type=int,
|
|
help="Texture atlas resolution. Default: 1024",
|
|
)
|
|
parser.add_argument(
|
|
"--low-vram-mode",
|
|
action="store_true",
|
|
help=(
|
|
"Use low VRAM mode. SPAR3D consumes 10.5GB of VRAM by default. "
|
|
"This mode will reduce the VRAM consumption to roughly 7GB but in exchange "
|
|
"the model will be slower. Default: False"
|
|
),
|
|
)
|
|
|
|
remesh_choices = ["none"]
|
|
if TRIANGLE_REMESH_AVAILABLE:
|
|
remesh_choices.append("triangle")
|
|
if QUAD_REMESH_AVAILABLE:
|
|
remesh_choices.append("quad")
|
|
parser.add_argument(
|
|
"--remesh_option",
|
|
choices=remesh_choices,
|
|
default="none",
|
|
help="Remeshing option",
|
|
)
|
|
if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
|
|
parser.add_argument(
|
|
"--reduction_count_type",
|
|
choices=["keep", "vertex", "faces"],
|
|
default="keep",
|
|
help="Vertex count type",
|
|
)
|
|
parser.add_argument(
|
|
"--target_count",
|
|
type=check_positive,
|
|
help="Selected target count.",
|
|
default=2000,
|
|
)
|
|
parser.add_argument(
|
|
"--batch_size", default=1, type=int, help="Batch size for inference"
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
|
|
devices = ["cuda", "mps", "cpu"]
|
|
if not any(args.device in device for device in devices):
|
|
raise ValueError("Invalid device. Use cuda, mps or cpu")
|
|
|
|
output_dir = args.output_dir
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
device = args.device
|
|
if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
|
|
device = "cpu"
|
|
|
|
print("Device used: ", device)
|
|
|
|
model = SPAR3D.from_pretrained(
|
|
args.pretrained_model,
|
|
config_name="config.yaml",
|
|
weight_name="model.safetensors",
|
|
low_vram_mode=args.low_vram_mode,
|
|
)
|
|
model.to(device)
|
|
model.eval()
|
|
|
|
bg_remover = Remover(device=device)
|
|
images = []
|
|
idx = 0
|
|
for image_path in args.image:
|
|
|
|
def handle_image(image_path, idx):
|
|
image = remove_background(
|
|
Image.open(image_path).convert("RGBA"), bg_remover
|
|
)
|
|
image = foreground_crop(image, args.foreground_ratio)
|
|
os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
|
|
image.save(os.path.join(output_dir, str(idx), "input.png"))
|
|
images.append(image)
|
|
|
|
if os.path.isdir(image_path):
|
|
image_paths = [
|
|
os.path.join(image_path, f)
|
|
for f in os.listdir(image_path)
|
|
if f.endswith((".png", ".jpg", ".jpeg"))
|
|
]
|
|
for image_path in image_paths:
|
|
handle_image(image_path, idx)
|
|
idx += 1
|
|
else:
|
|
handle_image(image_path, idx)
|
|
idx += 1
|
|
|
|
vertex_count = (
|
|
-1
|
|
if args.reduction_count_type == "keep"
|
|
else (
|
|
args.target_count
|
|
if args.reduction_count_type == "vertex"
|
|
else args.target_count // 2
|
|
)
|
|
)
|
|
|
|
for i in tqdm(range(0, len(images), args.batch_size)):
|
|
image = images[i : i + args.batch_size]
|
|
if torch.cuda.is_available():
|
|
torch.cuda.reset_peak_memory_stats()
|
|
with torch.no_grad():
|
|
with (
|
|
torch.autocast(device_type=device, dtype=torch.bfloat16)
|
|
if "cuda" in device
|
|
else nullcontext()
|
|
):
|
|
mesh, glob_dict = model.run_image(
|
|
image,
|
|
bake_resolution=args.texture_resolution,
|
|
remesh=args.remesh_option,
|
|
vertex_count=vertex_count,
|
|
return_points=True,
|
|
)
|
|
if torch.cuda.is_available():
|
|
print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
|
|
elif torch.backends.mps.is_available():
|
|
print(
|
|
"Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
|
|
)
|
|
|
|
if len(image) == 1:
|
|
out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
|
|
mesh.export(out_mesh_path, include_normals=True)
|
|
out_points_path = os.path.join(output_dir, str(i), "points.ply")
|
|
glob_dict["point_clouds"][0].export(out_points_path)
|
|
else:
|
|
for j in range(len(mesh)):
|
|
out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
|
|
mesh[j].export(out_mesh_path, include_normals=True)
|
|
out_points_path = os.path.join(output_dir, str(i + j), "points.ply")
|
|
glob_dict["point_clouds"][j].export(out_points_path)
|
|
|