"""Run the OmniPart inference pipeline on a single image / part-mask pair.

Stages:
  1. PartSynthesis sparse-structure sampling -> voxel coordinates.
  2. BboxGen -> per-part bounding boxes from the voxels, image and mask.
  3. PartSynthesis structured-latent sampling -> per-part outputs, which are
     saved and then merged.

All intermediate and final artifacts are written to
``<output-root>/<image file stem>/``.
"""
import argparse
import os

import numpy as np
import torch
from PIL import Image
from omegaconf import OmegaConf

from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen
from modules.part_synthesis.process_utils import save_parts_outputs
from modules.inference_utils import (
    load_img_mask,
    prepare_bbox_gen_input,
    prepare_part_synthesis_input,
    gen_mesh_from_bounds,
    vis_voxel_coords,
    merge_parts,
)
from modules.part_synthesis.pipelines import OmniPartImageTo3DPipeline


def parse_args(argv=None):
    """Parse command-line arguments (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--image-input", type=str, required=True)
    parser.add_argument("--mask-input", type=str, required=True)
    parser.add_argument("--output-root", type=str, default="./output")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-inference-steps", type=int, default=25)
    parser.add_argument("--guidance-scale", type=float, default=3.5)
    parser.add_argument("--simplify_ratio", type=float, default=0.3)
    parser.add_argument("--partfield_encoder_path", type=str, default="ckpt/model_objaverse.ckpt")
    parser.add_argument("--bbox_gen_ckpt", type=str, default="ckpt/bbox_gen.ckpt")
    parser.add_argument("--part_synthesis_ckpt", type=str, default="ckpt/part_synthesis")
    # Optional knobs; their defaults are the values that used to be hard-coded.
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--bbox_gen_config", type=str, default="configs/bbox_gen.yaml")
    parser.add_argument("--sparse-structure-steps", type=int, default=25,
                        help="Sampler steps for the voxel (sparse structure) stage.")
    parser.add_argument("--sparse-structure-guidance-scale", type=float, default=7.5,
                        help="cfg_strength for the voxel (sparse structure) stage.")
    return parser.parse_args(argv)


def _output_dir_for(output_root, image_path):
    """Return ``<output_root>/<file stem of image_path>``.

    Uses ``basename``/``splitext`` so OS-native separators work and only the
    final extension is stripped (``scene.v2.png`` -> ``scene.v2``).
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(output_root, stem)


def _load_part_synthesis(ckpt, device):
    """Load the PartSynthesis pipeline from ``ckpt`` and move it to ``device``."""
    pipeline = OmniPartImageTo3DPipeline.from_pretrained(ckpt)
    pipeline.to(device)
    print("[INFO] PartSynthesis model loaded")
    return pipeline


def _load_bbox_gen(config_path, partfield_encoder_path, ckpt, device):
    """Build BboxGen from its YAML config, load weights, and prepare it for fp16 inference."""
    config = OmegaConf.load(config_path).model.args
    config.partfield_encoder_path = partfield_encoder_path
    model = BboxGen(config)
    # Materialize the checkpoint on CPU first: avoids a transient copy on the
    # device the weights were saved from (and failure if that device is absent).
    state_dict = torch.load(ckpt, map_location="cpu")
    result = model.load_state_dict(state_dict, strict=False)
    # strict=False is kept, but report mismatches instead of discarding them.
    missing = getattr(result, "missing_keys", None) or []
    unexpected = getattr(result, "unexpected_keys", None) or []
    if missing or unexpected:
        print(f"[INFO] BboxGen checkpoint: {len(missing)} missing / "
              f"{len(unexpected)} unexpected keys (strict=False)")
    model.to(device)
    model.eval().half()
    print("[INFO] BboxGen model loaded")
    return model


def main(argv=None):
    """Run all three stages for one image/mask pair and write results to disk."""
    args = parse_args(argv)
    device = args.device

    os.makedirs(args.output_root, exist_ok=True)
    output_dir = _output_dir_for(args.output_root, args.image_input)
    os.makedirs(output_dir, exist_ok=True)

    torch.manual_seed(args.seed)

    part_synthesis_pipeline = _load_part_synthesis(args.part_synthesis_ckpt, device)
    bbox_gen_model = _load_bbox_gen(
        args.bbox_gen_config, args.partfield_encoder_path, args.bbox_gen_ckpt, device
    )

    img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis = load_img_mask(
        args.image_input, args.mask_input
    )
    img_mask_vis.save(os.path.join(output_dir, "img_mask_vis.png"))

    # Stage 1: voxel coordinates from the black-background image.
    voxel_coords = part_synthesis_pipeline.get_coords(
        img_black_bg,
        num_samples=1,
        seed=args.seed,
        sparse_structure_sampler_params={
            "steps": args.sparse_structure_steps,
            "cfg_strength": args.sparse_structure_guidance_scale,
        },
    )
    voxel_coords = voxel_coords.cpu().numpy()
    voxel_coords_path = os.path.join(output_dir, "voxel_coords.npy")
    np.save(voxel_coords_path, voxel_coords)
    vis_voxel_coords(voxel_coords).export(os.path.join(output_dir, "voxel_coords_vis.ply"))
    print("[INFO] Voxel coordinates saved")

    # Stage 2: part bounding boxes. Pure inference, so autograd is disabled
    # (a no-op if generate() already does this internally).
    bbox_gen_input = prepare_bbox_gen_input(voxel_coords_path, img_white_bg, ordered_mask_input)
    with torch.no_grad():
        bbox_gen_output = bbox_gen_model.generate(bbox_gen_input)
    bboxes = bbox_gen_output['bboxes'][0]
    bboxes_path = os.path.join(output_dir, "bboxes.npy")
    np.save(bboxes_path, bboxes)
    gen_mesh_from_bounds(bboxes).export(os.path.join(output_dir, "bboxes_vis.glb"))
    print("[INFO] BboxGen output saved")

    # Stage 3: per-part synthesis conditioned on the saved voxels and boxes.
    part_synthesis_input = prepare_part_synthesis_input(
        voxel_coords_path, bboxes_path, ordered_mask_input
    )
    part_synthesis_output = part_synthesis_pipeline.get_slat(
        img_black_bg,
        part_synthesis_input['coords'],
        [part_synthesis_input['part_layouts']],
        part_synthesis_input['masks'],
        seed=args.seed,
        slat_sampler_params={"steps": args.num_inference_steps, "cfg_strength": args.guidance_scale},
        formats=['mesh', 'gaussian', 'radiance_field'],
        preprocess_image=False,
    )
    save_parts_outputs(
        part_synthesis_output,
        output_dir=output_dir,
        simplify_ratio=args.simplify_ratio,
        save_video=False,
        save_glb=True,
        textured=False,
    )
    merge_parts(output_dir)
    print("[INFO] PartSynthesis output saved")


if __name__ == "__main__":
    main()