|
|
import torch |
|
|
import base64 |
|
|
import os |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
from trellis.pipelines import TrellisImageTo3DPipeline |
|
|
from trellis.utils import postprocessing_utils |
|
|
|
|
|
class EndpointHandler: |
|
|
def __init__(self, model_dir): |
|
|
|
|
|
self.pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large") |
|
|
self.pipeline.cuda() |
|
|
|
|
|
def __call__(self, data): |
|
|
""" |
|
|
Args: |
|
|
data (:obj:`dict`): |
|
|
- "inputs": The base64 encoded image or URL. |
|
|
- "params": Dictionary of generation parameters (optional). |
|
|
""" |
|
|
inputs = data.pop("inputs", data) |
|
|
params = data.pop("parameters", {}) |
|
|
|
|
|
|
|
|
image = Image.open(BytesIO(base64.b64decode(inputs))) |
|
|
|
|
|
|
|
|
|
|
|
outputs = self.pipeline( |
|
|
image, |
|
|
num_samples=1, |
|
|
return_flags=["mesh"], |
|
|
**params |
|
|
) |
|
|
|
|
|
|
|
|
mesh = outputs['mesh'][0] |
|
|
glb_io = BytesIO() |
|
|
mesh.export(glb_io, file_type='glb') |
|
|
glb_io.seek(0) |
|
|
|
|
|
|
|
|
return { |
|
|
"mesh_base64": base64.b64encode(glb_io.getvalue()).decode("utf-8"), |
|
|
"format": "glb" |
|
|
} |
|
|
|