"""Endpoint handler that turns a base64-encoded image into a GLB mesh via PartCrafter."""

import base64
import contextlib
import os
import tempfile

import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image

from inference_partcrafter import run_triposg
from src.models.briarmbg import BriaRMBG
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline

# String spellings that must be read as "false" when a flag arrives as text.
_FALSE_STRINGS = frozenset({"", "0", "false", "no", "off", "none", "null"})


def _as_bool(value):
    """Coerce *value* to bool, treating strings such as ``"false"`` as False.

    Plain ``bool("false")`` is True because the string is non-empty, which
    would make it impossible to disable a flag sent as text.
    """
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_STRINGS
    return bool(value)


def _positional(inputs, index, cast, default):
    """Return ``cast(inputs[index])``, or *default* if the slot is absent or None."""
    if index >= len(inputs) or inputs[index] is None:
        return default
    return cast(inputs[index])


class EndpointHandler:
    """Callable handler: base64 image in, base64-encoded GLB out.

    The models are loaded once at construction time and reused for every
    call.
    """

    def __init__(self, path=""):
        """Download the model snapshots and load both networks.

        Args:
            path: Accepted for interface compatibility; not used by this
                handler (weights are fetched with ``snapshot_download``).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision only on GPU; CPU runs in float32.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Load model weights from snapshot
        self.partcrafter_path = snapshot_download(
            "wgsxm/PartCrafter", local_dir="/tmp/partcrafter"
        )
        self.rmbg_path = snapshot_download("briaai/RMBG-1.4", local_dir="/tmp/rmbg")

        self.pipe = PartCrafterPipeline.from_pretrained(self.partcrafter_path).to(
            self.device, self.dtype
        )
        self.rmbg_net = (
            BriaRMBG.from_pretrained(self.rmbg_path).to(self.device).eval()
        )

    def __call__(self, data):
        """Generate a merged, normalised GLB mesh from one input image.

        Args:
            data: Mapping with an ``"inputs"`` entry. ``inputs`` is either a
                single base64 image string or a positional list::

                    [image_b64, num_parts=6, guidance_scale=7.0,
                     num_steps=50, seed=42, rmbg=False]

                The image may carry a ``data:...;base64,`` prefix; everything
                up to the last comma is discarded. Trailing items may be
                omitted or ``None`` to take the default.

        Returns:
            ``{"glb": <base64 str>}`` on success, or ``{"error": <message>}``
            when the inputs are missing or no mesh was produced.
        """
        inputs = data.get("inputs", [])
        if not inputs:
            return {"error": "Missing inputs"}
        if isinstance(inputs, str):
            # A bare string is the image itself; indexing it would otherwise
            # pick out its first character.
            inputs = [inputs]

        image_b64 = inputs[0]
        num_parts = _positional(inputs, 1, int, 6)
        guidance_scale = _positional(inputs, 2, float, 7.0)
        num_steps = _positional(inputs, 3, int, 50)
        seed = _positional(inputs, 4, int, 42)
        rmbg = _positional(inputs, 5, _as_bool, False)

        image_bytes = base64.b64decode(image_b64.split(",")[-1])

        # delete=False so the file can be reopened by path after it is
        # closed; that makes removing it our responsibility.
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        image_path = tmp.name
        try:
            with tmp:
                tmp.write(image_bytes)
            meshes, _ = run_triposg(
                pipe=self.pipe,
                image_input=image_path,
                num_parts=num_parts,
                rmbg_net=self.rmbg_net,
                seed=seed,
                num_tokens=1024,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                rmbg=rmbg,
                dtype=self.dtype,
                device=self.device,
            )
        finally:
            # Always remove the temp image; never let cleanup mask the
            # original exception.
            with contextlib.suppress(OSError):
                os.unlink(image_path)

        # NOTE(review): run_triposg may yield None for parts that failed to
        # decode -- confirm against its implementation. Dropping them is
        # harmless either way.
        meshes = [mesh for mesh in (meshes or []) if mesh is not None]
        if not meshes:
            return {"error": "No meshes generated"}

        merged = trimesh.util.concatenate(meshes)
        merged.apply_translation(-merged.center_mass)

        # Scale so the farthest vertex lies at unit distance from the origin.
        # Skip when the radius is zero or non-finite to avoid a bad scale.
        radius = float(np.max(np.linalg.norm(merged.vertices, axis=1)))
        if np.isfinite(radius) and radius > 0.0:
            merged.apply_scale(1.0 / radius)

        glb_data = merged.export(file_type="glb")
        glb_b64 = base64.b64encode(glb_data).decode("utf-8")
        return {"glb": glb_b64}