# et-PartCrafter / handler.py
# Commit df89953 by staswrs: "handler new"
import base64
import os
import tempfile

import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image

from inference_partcrafter import run_triposg
from src.models.briarmbg import BriaRMBG
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
class EndpointHandler:
    """Inference-endpoint handler that turns one image into a base64-encoded GLB mesh.

    Request payload: ``{"inputs": [...]}`` where the list holds, in order, a
    base64-encoded image (an optional ``data:...;base64,`` prefix is stripped)
    followed by optional ``num_parts`` (default 6), ``guidance_scale`` (7.0),
    ``num_steps`` (50), ``seed`` (42) and ``rmbg`` (False). ``inputs`` may also
    be a bare base64 string, in which case every optional value uses its default.
    """

    def __init__(self, path=""):
        """Download the PartCrafter and RMBG-1.4 weights and load both models.

        Args:
            path: Not used by this handler; weights are fetched from the Hub
                instead. (Presumably the model directory passed in by the
                endpoint runtime — kept for interface compatibility.)
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision only on GPU; CPU inference stays in float32.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        # Fetch model weights from the Hub into fixed local directories under /tmp.
        self.partcrafter_path = snapshot_download("wgsxm/PartCrafter", local_dir="/tmp/partcrafter")
        self.rmbg_path = snapshot_download("briaai/RMBG-1.4", local_dir="/tmp/rmbg")
        self.pipe = PartCrafterPipeline.from_pretrained(self.partcrafter_path).to(self.device, self.dtype)
        # Passed to run_triposg together with the `rmbg` flag (presumably background removal).
        self.rmbg_net = BriaRMBG.from_pretrained(self.rmbg_path).to(self.device).eval()

    @staticmethod
    def _parse_bool(value):
        """Interpret a JSON-style flag.

        Strings are matched case-insensitively against common truthy spellings,
        because ``bool("false")`` would otherwise be True. Non-strings use ``bool``.
        """
        if isinstance(value, str):
            return value.strip().lower() in ("1", "true", "yes", "y", "on")
        return bool(value)

    @staticmethod
    def _normalize_mesh(mesh):
        """Center ``mesh`` and scale it so its farthest vertex lies at distance 1.

        Modifies ``mesh`` in place and also returns it.
        """
        center = mesh.center_mass
        # center_mass is derived from the enclosed volume; for a mesh that encloses
        # no volume it can come out non-finite, so fall back to the area-weighted centroid.
        if not np.all(np.isfinite(center)):
            center = mesh.centroid
        mesh.apply_translation(-center)
        if len(mesh.vertices):
            radius = float(np.max(np.linalg.norm(mesh.vertices, axis=1)))
            # Skip scaling for degenerate meshes (all vertices at the center) to
            # avoid dividing by zero and producing inf/NaN coordinates.
            if radius > 0 and np.isfinite(radius):
                mesh.apply_scale(1.0 / radius)
        return mesh

    def __call__(self, data):
        """Generate a part-based mesh from one image and return it as base64 GLB.

        Args:
            data: Request payload; ``data["inputs"]`` is either a list
                ``[image_b64, num_parts, guidance_scale, num_steps, seed, rmbg]``
                (all but the first optional) or a bare ``image_b64`` string.

        Returns:
            ``{"glb": <base64-encoded GLB bytes>}`` on success, or
            ``{"error": <message>}`` when the request is malformed or
            generation returns no meshes. Errors raised by ``run_triposg``
            itself still propagate.
        """
        inputs = data.get("inputs", [])
        if not inputs:
            return {"error": "Missing inputs"}
        # A bare string would otherwise be indexed character by character.
        if isinstance(inputs, str):
            inputs = [inputs]

        image_b64 = inputs[0]
        if not isinstance(image_b64, str):
            return {"error": "First input must be a base64-encoded image string"}
        try:
            num_parts = int(inputs[1]) if len(inputs) > 1 else 6
            guidance_scale = float(inputs[2]) if len(inputs) > 2 else 7.0
            num_steps = int(inputs[3]) if len(inputs) > 3 else 50
            seed = int(inputs[4]) if len(inputs) > 4 else 42
        except (TypeError, ValueError) as err:
            return {"error": f"Invalid numeric input: {err}"}
        rmbg = self._parse_bool(inputs[5]) if len(inputs) > 5 else False
        if num_parts < 1 or num_steps < 1:
            return {"error": "num_parts and num_steps must be positive"}

        try:
            # Keep only the payload after an optional "data:image/...;base64," prefix.
            image_bytes = base64.b64decode(image_b64.split(",")[-1])
        except ValueError as err:  # binascii.Error is a ValueError subclass
            return {"error": f"Invalid base64 image: {err}"}
        if not image_bytes:
            return {"error": "Empty image"}

        # delete=False lets run_triposg reopen the file by name after it is closed
        # here; that makes this handler responsible for removing it afterwards.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            f.write(image_bytes)
            image_path = f.name
        try:
            meshes, _ = run_triposg(
                pipe=self.pipe,
                image_input=image_path,
                num_parts=num_parts,
                rmbg_net=self.rmbg_net,
                seed=seed,
                num_tokens=1024,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                rmbg=rmbg,
                dtype=self.dtype,
                device=self.device,
            )
        finally:
            try:
                os.remove(image_path)
            except OSError:
                # Best-effort cleanup: never let it mask the inference result or error.
                pass

        if not meshes:
            return {"error": "Generation produced no meshes"}
        merged = self._normalize_mesh(trimesh.util.concatenate(meshes))
        glb_data = merged.export(file_type="glb")
        glb_b64 = base64.b64encode(glb_data).decode("utf-8")
        return {"glb": glb_b64}