File size: 2,354 Bytes
e0004b3
d29737d
 
e0004b3
 
 
d29737d
e0004b3
 
 
 
 
 
 
 
 
 
 
d29737d
 
 
e0004b3
 
d29737d
e0004b3
 
 
d29737d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0004b3
d29737d
 
 
e0004b3
d29737d
 
e0004b3
d29737d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import base64
import os
import tempfile

import torch
import numpy as np
import trimesh
from PIL import Image
from huggingface_hub import snapshot_download

from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.models.briarmbg import BriaRMBG
from inference_partcrafter import run_triposg


class EndpointHandler:
    """Request handler that turns one base64 image into a merged GLB mesh.

    Loads the PartCrafter pipeline and the BriaRMBG network once at
    construction time, then for each call runs ``run_triposg`` on the decoded
    image. The resulting part meshes are concatenated, centred and scaled,
    and the result is returned as a base64-encoded GLB.
    """

    @staticmethod
    def _parse_bool(value):
        """Interpret *value* as a boolean flag.

        ``bool("false")`` is True in Python, so strings are matched against
        explicit true/false spellings instead of relying on truthiness.
        Non-strings fall back to ``bool(value)``.

        Raises:
            ValueError: if *value* is a string that is not a recognised spelling.
        """
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in {"1", "true", "yes", "y", "on"}:
                return True
            if normalized in {"0", "false", "no", "n", "off", ""}:
                return False
            raise ValueError(f"Cannot interpret {value!r} as a boolean")
        return bool(value)

    def __init__(self, path=""):
        """Download model weights and load both networks onto one device.

        Args:
            path: Unused. Weights are always fetched with ``snapshot_download``
                into fixed directories under ``/tmp``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # float16 is used only on CUDA; the CPU path stays in float32.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Load model weights from snapshot
        self.partcrafter_path = snapshot_download("wgsxm/PartCrafter", local_dir="/tmp/partcrafter")
        self.rmbg_path = snapshot_download("briaai/RMBG-1.4", local_dir="/tmp/rmbg")

        self.pipe = PartCrafterPipeline.from_pretrained(self.partcrafter_path).to(self.device, self.dtype)
        self.rmbg_net = BriaRMBG.from_pretrained(self.rmbg_path).to(self.device).eval()

    def __call__(self, data):
        """Generate a part-based 3D mesh from a base64-encoded image.

        Args:
            data: Request payload. ``data["inputs"]`` is either the base64
                image string on its own, or a list
                ``[image_b64, num_parts=6, guidance_scale=7.0, num_steps=50,
                seed=42, rmbg=False]`` where trailing entries are optional.
                Anything up to the last comma in ``image_b64`` is discarded,
                so a ``data:...;base64,`` prefix is accepted.

        Returns:
            ``{"glb": <base64 GLB bytes>}`` on success, or
            ``{"error": <message>}`` when inputs are missing or malformed, or
            when the pipeline yields no geometry.
        """
        inputs = data.get("inputs", [])
        if not inputs:
            return {"error": "Missing inputs"}
        # A bare string is shorthand for a one-element list; indexing it
        # directly would take only its first character as the image.
        if isinstance(inputs, str):
            inputs = [inputs]

        try:
            image_b64 = inputs[0]
            num_parts = int(inputs[1]) if len(inputs) > 1 else 6
            guidance_scale = float(inputs[2]) if len(inputs) > 2 else 7.0
            num_steps = int(inputs[3]) if len(inputs) > 3 else 50
            seed = int(inputs[4]) if len(inputs) > 4 else 42
            rmbg = self._parse_bool(inputs[5]) if len(inputs) > 5 else False
            # binascii.Error (bad base64) is a ValueError subclass.
            image_bytes = base64.b64decode(image_b64.split(",")[-1])
        except (AttributeError, KeyError, TypeError, ValueError) as err:
            return {"error": f"Invalid inputs: {err}"}

        if not image_bytes:
            return {"error": "Empty image"}
        if num_parts < 1 or num_steps < 1:
            return {"error": "num_parts and num_steps must be positive"}

        # delete=False keeps the file on disk after the handle closes so it
        # can be reopened by path; it is removed explicitly in `finally`.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            f.write(image_bytes)
            image_path = f.name

        try:
            meshes, _ = run_triposg(
                pipe=self.pipe,
                image_input=image_path,
                num_parts=num_parts,
                rmbg_net=self.rmbg_net,
                seed=seed,
                num_tokens=1024,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                rmbg=rmbg,
                dtype=self.dtype,
                device=self.device,
            )
        finally:
            os.remove(image_path)

        # len() rather than truthiness: presumably a list of meshes, but this
        # also works if run_triposg returns an array — TODO confirm.
        if meshes is None or len(meshes) == 0:
            return {"error": "Pipeline produced no meshes"}

        merged = trimesh.util.concatenate(meshes)
        if len(merged.vertices) == 0:
            return {"error": "Pipeline produced an empty mesh"}

        merged.apply_translation(-merged.center_mass)
        # Scale so the farthest vertex lies on the unit sphere; skip a
        # degenerate (zero or non-finite) radius instead of producing inf/NaN.
        radius = float(np.max(np.linalg.norm(merged.vertices, axis=1)))
        if np.isfinite(radius) and radius > 0:
            merged.apply_scale(1.0 / radius)

        glb_data = merged.export(file_type="glb")
        glb_b64 = base64.b64encode(glb_data).decode("utf-8")

        return {"glb": glb_b64}