import base64
import tempfile
import torch
import numpy as np
import trimesh
from PIL import Image
from huggingface_hub import snapshot_download
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.models.briarmbg import BriaRMBG
from inference_partcrafter import run_triposg
class EndpointHandler:
    def __init__(self, path=""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        # Load model weights from snapshot
        self.partcrafter_path = snapshot_download("wgsxm/PartCrafter", local_dir="/tmp/partcrafter")
        self.rmbg_path = snapshot_download("briaai/RMBG-1.4", local_dir="/tmp/rmbg")
        self.pipe = PartCrafterPipeline.from_pretrained(self.partcrafter_path).to(self.device, self.dtype)
        self.rmbg_net = BriaRMBG.from_pretrained(self.rmbg_path).to(self.device).eval()
    def __call__(self, data):
        # Expected payload: {"inputs": [image_b64, num_parts, guidance_scale, num_steps, seed, rmbg]}
        inputs = data.get("inputs", [])
        if not inputs:
            return {"error": "Missing inputs"}
        image_b64 = inputs[0]
        num_parts = int(inputs[1]) if len(inputs) > 1 else 6
        guidance_scale = float(inputs[2]) if len(inputs) > 2 else 7.0
        num_steps = int(inputs[3]) if len(inputs) > 3 else 50
        seed = int(inputs[4]) if len(inputs) > 4 else 42
        rmbg = bool(inputs[5]) if len(inputs) > 5 else False

        # Decode the base64 image (dropping any data-URI prefix) and write it to a temporary file
        image_bytes = base64.b64decode(image_b64.split(",")[-1])
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            f.write(image_bytes)
            image_path = f.name

        meshes, _ = run_triposg(
            pipe=self.pipe,
            image_input=image_path,
            num_parts=num_parts,
            rmbg_net=self.rmbg_net,
            seed=seed,
            num_tokens=1024,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            rmbg=rmbg,
            dtype=self.dtype,
            device=self.device,
        )

        # Merge the part meshes, center on the center of mass, scale to the unit sphere,
        # and return the result as a base64-encoded GLB
        merged = trimesh.util.concatenate(meshes)
        merged.apply_translation(-merged.center_mass)
        merged.apply_scale(1.0 / np.max(np.linalg.norm(merged.vertices, axis=1)))
        glb_data = merged.export(file_type="glb")
        glb_b64 = base64.b64encode(glb_data).decode("utf-8")
        return {"glb": glb_b64}