|
import base64
import os
import tempfile

import torch
import numpy as np
import trimesh
from PIL import Image
from huggingface_hub import snapshot_download

from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.models.briarmbg import BriaRMBG
from inference_partcrafter import run_triposg
|
|
|
|
|
class EndpointHandler:
    """Inference Endpoint handler that turns one image into a multi-part mesh.

    At start-up it downloads and loads the PartCrafter pipeline and the
    BriaRMBG network (``briaai/RMBG-1.4``). Each request runs ``run_triposg``,
    concatenates the generated part meshes into one mesh, centres and rescales
    it, and returns it as a base64-encoded GLB.
    """

    def __init__(self, path=""):
        """Download weights and load both models onto the available device.

        Args:
            path: Path supplied by the endpoint runtime. Unused; weights are
                fetched from the Hub into fixed directories under ``/tmp``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is used only on CUDA; CPU runs fall back to float32.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.partcrafter_path = snapshot_download(
            "wgsxm/PartCrafter", local_dir="/tmp/partcrafter"
        )
        self.rmbg_path = snapshot_download("briaai/RMBG-1.4", local_dir="/tmp/rmbg")

        self.pipe = PartCrafterPipeline.from_pretrained(self.partcrafter_path).to(
            self.device, self.dtype
        )
        self.rmbg_net = BriaRMBG.from_pretrained(self.rmbg_path).to(self.device).eval()

    def __call__(self, data):
        """Generate a GLB mesh from a base64-encoded image.

        Args:
            data: Request payload. ``data["inputs"]`` is either the base64 image
                string on its own, or a list
                ``[image_b64, num_parts=6, guidance_scale=7.0, num_steps=50,
                seed=42, rmbg=False]`` whose trailing entries are optional.
                The image string may carry a ``data:...;base64,`` prefix
                (everything up to the last comma is dropped).

        Returns:
            ``{"glb": <base64 GLB>}`` on success. ``{"error": <message>}`` when
            inputs are missing or malformed, or when no geometry was produced.
        """
        inputs = data.get("inputs", [])
        if not inputs:
            return {"error": "Missing inputs"}
        if isinstance(inputs, str):
            # Accept {"inputs": "<b64>"}; indexing a str would yield one character.
            inputs = [inputs]

        try:
            num_parts, guidance_scale, num_steps, seed, rmbg = self._parse_params(inputs)
        except (TypeError, ValueError) as exc:
            return {"error": f"Invalid parameters: {exc}"}

        try:
            image_bytes = self._decode_image(inputs[0])
        except (TypeError, ValueError) as exc:
            return {"error": f"Invalid image: {exc}"}

        # run_triposg is given a file path, so the image is staged on disk.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            f.write(image_bytes)
            image_path = f.name

        try:
            meshes, _ = run_triposg(
                pipe=self.pipe,
                image_input=image_path,
                num_parts=num_parts,
                rmbg_net=self.rmbg_net,
                seed=seed,
                num_tokens=1024,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                rmbg=rmbg,
                dtype=self.dtype,
                device=self.device,
            )
        finally:
            # delete=False means nothing else removes the file; without this
            # every request leaks one image into the temp directory.
            try:
                os.remove(image_path)
            except OSError:
                pass  # best-effort cleanup; never mask the generation outcome

        merged = self._merge_and_normalize(meshes)
        if merged is None:
            return {"error": "No mesh was generated"}

        glb_data = merged.export(file_type="glb")
        return {"glb": base64.b64encode(glb_data).decode("utf-8")}

    @staticmethod
    def _parse_bool(value):
        """Interpret a flag value, treating common "false" strings as False.

        ``bool("false")`` is True in Python, so string flags from JSON clients
        would otherwise always switch the option on.
        """
        if isinstance(value, str):
            return value.strip().lower() not in {"", "0", "false", "no", "off", "none"}
        return bool(value)

    @classmethod
    def _parse_params(cls, inputs):
        """Read the optional positional generation parameters from ``inputs[1:]``.

        Returns:
            ``(num_parts, guidance_scale, num_steps, seed, rmbg)``.

        Raises:
            TypeError / ValueError: if a value cannot be converted, or
                ``num_parts`` / ``num_steps`` is less than 1.
        """

        def arg(index, cast, default):
            return cast(inputs[index]) if len(inputs) > index else default

        num_parts = arg(1, int, 6)
        guidance_scale = arg(2, float, 7.0)
        num_steps = arg(3, int, 50)
        seed = arg(4, int, 42)
        rmbg = arg(5, cls._parse_bool, False)

        if num_parts < 1:
            raise ValueError("num_parts must be >= 1")
        if num_steps < 1:
            raise ValueError("num_steps must be >= 1")
        return num_parts, guidance_scale, num_steps, seed, rmbg

    @staticmethod
    def _decode_image(image_b64):
        """Decode a base64 image string, dropping any ``data:...,`` prefix.

        Raises:
            TypeError: if ``image_b64`` is not a string.
            ValueError: if the base64 is malformed (``binascii.Error`` is a
                ``ValueError`` subclass) or decodes to zero bytes.
        """
        if not isinstance(image_b64, str):
            raise TypeError("image must be a base64-encoded string")
        image_bytes = base64.b64decode(image_b64.split(",")[-1])
        if not image_bytes:
            raise ValueError("image payload is empty")
        return image_bytes

    @staticmethod
    def _merge_and_normalize(meshes):
        """Concatenate part meshes, centre them at the origin, and fit the unit sphere.

        Returns:
            The merged ``trimesh.Trimesh``, or ``None`` if there is no geometry.
        """
        # NOTE(review): assumes run_triposg returns a list of meshes; entries that
        # are None (e.g. parts that failed to decode) are skipped defensively.
        parts = [mesh for mesh in meshes if mesh is not None]
        if not parts:
            return None
        merged = trimesh.util.concatenate(parts)
        if len(merged.vertices) == 0:
            return None

        center = merged.center_mass
        if not np.all(np.isfinite(center)):
            # center_mass may be non-finite for open or degenerate meshes;
            # fall back to the mean vertex position.
            center = merged.vertices.mean(axis=0)
        merged.apply_translation(-center)

        radius = np.max(np.linalg.norm(merged.vertices, axis=1))
        # A mesh collapsed onto the origin has radius 0; skip scaling rather
        # than dividing by zero.
        if np.isfinite(radius) and radius > 0:
            merged.apply_scale(1.0 / radius)
        return merged
|
|