import argparse
import os
import sys
import tempfile
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import requests
import torch
import trimesh

# Make the P3-SAM and XPart submodules importable.
sys.path.append('P3-SAM')
from demo.auto_mask import AutoMask
sys.path.append('XPart')
from partgen.partformer_pipeline import PartFormerPipeline
from partgen.utils.misc import get_config_from_file

# --- Global model initialization ---
# The P3-SAM segmenter is created eagerly at import time; the script exits
# if it cannot be constructed.
try:
    automask = AutoMask()
except Exception as e:
    print(f"Error initializing AutoMask (P3-SAM): {e}")
    print("Please ensure the P3-SAM submodule and its dependencies are installed correctly.")
    sys.exit(1)

# The XPart pipeline is loaded lazily, on first use, by _load_pipeline().
_PIPELINE = None

# File extensions accepted as input models.
_SUPPORTED_EXTENSIONS = ('.glb', '.ply', '.obj')


def _load_pipeline():
    """Return the shared XPart pipeline, loading it on the first call.

    The pipeline is cached in the module-level ``_PIPELINE`` only after it
    has been fully constructed *and* moved to the target device, so a failure
    part-way through does not leave a half-initialised pipeline cached for
    later calls.

    Returns:
        The ``PartFormerPipeline`` instance (on CUDA when available,
        otherwise on CPU), in float32.
    """
    global _PIPELINE
    if _PIPELINE is None:
        print("Loading XPart generation pipeline... This may take a moment.")
        pl.seed_everything(2026, workers=True)
        cfg_path = str(Path(__file__).parent / "XPart/partgen/config" / "infer.yaml")
        config = get_config_from_file(cfg_path)
        # NOTE(review): the config is only used for this sanity check; the
        # model itself is loaded from the pretrained hub id below.
        assert hasattr(config, "ckpt") or hasattr(
            config, "ckpt_path"
        ), "ckpt or ckpt_path must be specified in config"
        pipeline = PartFormerPipeline.from_pretrained(
            model_path="tencent/Hunyuan3D-Part",
            verbose=True,
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        pipeline.to(device=device, dtype=torch.float32)
        # Publish only once the pipeline is ready to use.
        _PIPELINE = pipeline
    return _PIPELINE


# --- Core functions ---
def is_supported_3d_file(filename):
    """Return True if *filename* has a supported extension (.glb, .ply, .obj).

    The comparison is case-insensitive; *filename* may be any object whose
    ``str()`` is a path.
    """
    ext = os.path.splitext(str(filename))[1].lower()
    return ext in _SUPPORTED_EXTENSIONS


def segment(mesh_path, output_dir, postprocess=True, postprocess_threshold=0.95, seed=42):
    """Segment a 3D model with P3-SAM and save the results.

    Writes ``output_segmented.glb`` (the mesh with one random colour per
    part) and ``output_face_ids.npy`` (the raw face ids) into *output_dir*,
    creating the directory if needed.  Exits the process with status 1 when
    the input file is missing or has an unsupported extension.

    Args:
        mesh_path: Path to the input mesh file.
        output_dir: Directory that receives the output files.
        postprocess: Forwarded to ``automask.predict_aabb`` as ``post_process``.
        postprocess_threshold: Forwarded to ``automask.predict_aabb`` as
            ``threshold``.
        seed: Forwarded to ``automask.predict_aabb`` as ``seed``.

    Returns:
        A tuple ``(aabb, mesh_path)``: the bounding boxes returned by the
        segmenter and the unchanged input path, as needed by ``generate``.
    """
    if not os.path.exists(mesh_path):
        print(f"Error: Input file not found at {mesh_path}")
        sys.exit(1)
    if not is_supported_3d_file(mesh_path):
        print(f"Error: Unsupported file type. Only .glb, .ply, .obj are supported.")
        sys.exit(1)

    # Fail early (before the expensive segmentation) if the output directory
    # cannot be created.  An empty string means "current directory".
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Segmenting mesh: {mesh_path}")
    mesh = trimesh.load(mesh_path, force='mesh', process=False)

    # Core segmentation step.
    aabb, face_ids, mesh = automask.predict_aabb(
        mesh,
        seed=seed,
        is_parallel=False,
        post_process=postprocess,
        threshold=postprocess_threshold
    )

    # Build one colour per distinct part id, then map every face to its
    # colour with a single vectorised lookup instead of a per-face loop.
    # Faces with id -1 keep the palette's initial black.
    ids = np.asarray(face_ids).reshape(-1)
    unique_ids, inverse = np.unique(ids, return_inverse=True)
    palette = np.zeros((len(unique_ids), 3), dtype=np.uint8)
    for row, part_id in enumerate(unique_ids):
        if part_id == -1:
            continue
        palette[row] = np.random.randint(0, 256, size=3)
    face_colors = palette[inverse.reshape(-1)]

    mesh_save = mesh.copy()
    mesh_save.visual.face_colors = face_colors

    # Save the results under fixed file names.
    seg_mesh_path = os.path.join(output_dir, 'output_segmented.glb')
    mesh_save.export(seg_mesh_path)
    print(f" -> Saved segmented mesh to: {seg_mesh_path}")
    face_id_path = os.path.join(output_dir, 'output_face_ids.npy')
    np.save(face_id_path, face_ids)
    print(f" -> Saved face IDs to: {face_id_path}")

    # Data needed by the subsequent generation step.
    return aabb, mesh_path


def generate(mesh_path, aabb, output_dir, seed=42):
    """Generate parts with XPart and save the results.

    Writes ``output_gen_parts.glb``, ``output_gen_bbox.glb`` and
    ``output_gen_exploded.glb`` into *output_dir*, creating the directory if
    needed.

    Args:
        mesh_path: Path to the input mesh file, passed to the pipeline.
        aabb: Bounding boxes from ``segment``, passed to the pipeline.
        output_dir: Directory that receives the output files.
        seed: Random seed for this run.  If it cannot be used as an integer
            seed, a warning is printed and 2026 is used instead.
    """
    print(f"Generating parts for mesh: {mesh_path}")
    pipeline = _load_pipeline()

    # Re-seed on every call so each request behaves deterministically.
    try:
        pl.seed_everything(int(seed), workers=True)
    except (TypeError, ValueError, OverflowError) as exc:
        print(f"Warning: invalid seed {seed!r} ({exc}); falling back to 2026.")
        pl.seed_everything(2026, workers=True)

    additional_params = {"output_type": "trimesh"}
    obj_mesh, (out_bbox, _mesh_gt_bbox, explode_object) = pipeline(
        mesh_path=mesh_path,
        aabb=aabb,
        octree_resolution=512,
        **additional_params,
    )

    # An empty string means "current directory".
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Export all results under fixed file names.
    obj_path = os.path.join(output_dir, 'output_gen_parts.glb')
    out_bbox_path = os.path.join(output_dir, 'output_gen_bbox.glb')
    explode_path = os.path.join(output_dir, 'output_gen_exploded.glb')
    obj_mesh.export(obj_path)
    print(f" -> Saved generated parts to: {obj_path}")
    out_bbox.export(out_bbox_path)
    print(f" -> Saved parts with bounding box to: {out_bbox_path}")
    explode_object.export(explode_path)
    print(f" -> Saved exploded view to: {explode_path}")


def main():
    """Parse command-line arguments and run segmentation, then optional generation.

    The input model may be a local path or an http(s) URL.  A URL is
    downloaded to a temporary file, which is removed again before the
    function returns or the process exits.  Exits with status 1 when no
    input is given, the download fails, the local file is missing, or
    generation is requested but segmentation returned no bounding boxes.
    """
    # Imported locally so this function is self-contained with respect to
    # the file-level import block.
    from urllib.parse import urlparse

    # (connect, read) timeouts in seconds, so a dead server cannot hang the
    # tool indefinitely.
    download_timeout = (10, 60)

    parser = argparse.ArgumentParser(
        description="Command-line tool for 3D model segmentation and part generation using P3-SAM and XPart."
    )
    # --- Main arguments (input_model and prompt are equivalent) ---
    parser.add_argument(
        '-i', '--input_model',
        type=str,
        help="Path or URL to the 3D model file (.glb, .obj, .ply)."
    )
    parser.add_argument(
        '-p', '--prompt',
        type=str,
        help="Path or URL to the 3D model file (alias for --input_model)."
    )
    # --- Other arguments ---
    parser.add_argument(
        '-o', '--output_dir',
        type=str,
        default='results',
        help="Directory to save the output files. (default: 'results')"
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help="Random seed for segmentation. (default: 42)"
    )
    parser.add_argument(
        '--no-postprocess',
        action='store_true',
        help="Disable post-processing for segmentation."
    )
    parser.add_argument(
        '--postprocess_threshold',
        type=float,
        default=0.95,
        help="Threshold for merging small parts during post-processing. (default: 0.95)"
    )
    parser.add_argument(
        '--generate',
        action='store_true',
        help="If set, run the part generation step (XPart) after segmentation."
    )
    parser.add_argument(
        '--gen_seed',
        type=int,
        default=42,
        help="Random seed for part generation. (default: 42)"
    )
    args = parser.parse_args()

    # Either --input_model or --prompt may carry the input; --input_model wins.
    input_path = args.input_model or args.prompt
    if not input_path:
        print("Error: You must provide an input model using either --input_model or --prompt.")
        sys.exit(1)

    local_model_path = None
    # Set as soon as the temporary file exists, so the cleanup below works
    # even when the download is interrupted part-way through.
    temp_model_path = None
    try:
        # Decide whether the input is a URL or a local file.
        if input_path.startswith(('http://', 'https://')):
            print(f"Downloading model from URL: {input_path}")
            try:
                # Take the suffix from the URL *path* only, so a query string
                # or fragment (e.g. "model.glb?token=...") does not end up in
                # the file extension.  Default to .glb when there is none.
                file_suffix = Path(urlparse(input_path).path).suffix or '.glb'
                with requests.get(input_path, stream=True, timeout=download_timeout) as response:
                    # Raise on HTTP errors such as 404.
                    response.raise_for_status()
                    # Temporary file with the right suffix; kept on close so
                    # the later steps can open it by name.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=file_suffix) as f:
                        temp_model_path = f.name
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                local_model_path = temp_model_path
                print(f"Model saved temporarily to: {local_model_path}")
            except requests.exceptions.RequestException as e:
                print(f"Error downloading file: {e}")
                sys.exit(1)
        else:
            # The input is a local file path.
            if not os.path.exists(input_path):
                print(f"Error: Local file not found at '{input_path}'")
                sys.exit(1)
            local_model_path = input_path

        # Create the output directory.
        os.makedirs(args.output_dir, exist_ok=True)

        # Step 1: segmentation.
        print("-" * 50)
        aabb, original_mesh_path = segment(
            mesh_path=local_model_path,
            output_dir=args.output_dir,
            postprocess=not args.no_postprocess,
            postprocess_threshold=args.postprocess_threshold,
            seed=args.seed
        )
        print("Segmentation finished.")
        print("-" * 50)

        # Step 2: generation, only when requested.
        if args.generate:
            if aabb is None:
                print("Segmentation failed, cannot proceed to generation.")
                sys.exit(1)
            print("\nStarting part generation...")
            generate(
                mesh_path=original_mesh_path,
                aabb=aabb,
                output_dir=args.output_dir,
                seed=args.gen_seed
            )
            print("Generation finished.")
            print("-" * 50)
    finally:
        # Remove the downloaded temporary file, if any.  A failure here must
        # not mask the real error or exit status.
        if temp_model_path is not None:
            print(f"Cleaning up temporary file: {temp_model_path}")
            try:
                os.remove(temp_model_path)
            except OSError as e:
                print(f"Warning: could not remove temporary file: {e}")


if __name__ == '__main__':
    main()