| from __future__ import annotations |
|
|
| import traceback |
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| import runtime_env |
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
| from schemas import ImageToGlbRequest |
| from trellis2.pipelines import Trellis2ImageTo3DPipeline |
|
|
|
|
| PIPELINE_ID = "microsoft/TRELLIS.2-4B" |
|
|
|
|
class ServiceError(Exception):
    """Structured failure raised by the generation service.

    Carries everything an API layer needs to build an error response:
    the pipeline stage that failed, a machine-readable error code, an
    HTTP-ish status code, whether a retry may succeed, and free-form
    diagnostic details.
    """

    def __init__(
        self,
        *,
        stage: str,
        error_code: str,
        message: str,
        retryable: bool,
        status_code: int = 500,
        details: dict[str, Any] | None = None,
    ):
        super().__init__(message)
        self.stage = stage
        self.error_code = error_code
        self.message = message
        self.retryable = retryable
        self.status_code = status_code
        self.details = details or {}

    def to_dict(self, job_id: str) -> dict[str, Any]:
        """Serialize this error as a JSON-ready mapping for the given job."""
        return dict(
            job_id=job_id,
            stage=self.stage,
            error_code=self.error_code,
            retryable=self.retryable,
            message=self.message,
            details=self.details,
        )
|
|
|
|
def is_fatal_cuda_error(error: BaseException) -> bool:
    """Return True when *error* looks like an unrecoverable CUDA failure.

    Performs a case-insensitive substring match against the error text.
    Fatal CUDA errors (illegal memory access, device-side asserts, generic
    "CUDA error" reports) typically leave the CUDA context unusable, so
    callers treat them as non-retryable in-process.
    """
    text = str(error).lower()
    # NOTE: the plain "cuda error" needle also covers the "[cumesh] cuda
    # error" variant that was previously listed separately — the extra
    # entry was redundant and has been removed without changing behavior.
    needles = (
        "illegal memory access",
        "device-side assert",
        "cuda error",
    )
    return any(needle in text for needle in needles)
|
|
|
|
def classify_runtime_error(stage: str, error: BaseException) -> ServiceError:
    """Wrap an arbitrary exception in a :class:`ServiceError` for *stage*.

    Already-classified ``ServiceError`` instances pass through unchanged.
    Fatal CUDA errors get a dedicated ``<stage>_cuda_fatal`` code and are
    marked non-retryable — except for the "export" stage, which is always
    treated as retryable. The current traceback is captured into details,
    so this must be called from within the handling ``except`` block.
    """
    if isinstance(error, ServiceError):
        return error
    # Evaluate the (string-scanning) fatality check once; the original
    # called it twice for the two derived fields below.
    fatal = is_fatal_cuda_error(error)
    retryable = stage == "export" or not fatal
    error_code = f"{stage}_cuda_fatal" if fatal else f"{stage}_failed"
    return ServiceError(
        stage=stage,
        error_code=error_code,
        message=f"{type(error).__name__}: {error}",
        retryable=retryable,
        status_code=500,
        details={"traceback": traceback.format_exc()},
    )
|
|
|
|
class TrellisRuntime:
    """Lazy-loading wrapper around the TRELLIS.2 image-to-3D pipeline.

    Maintains a sticky health flag: once a fatal CUDA error is observed
    during generation, the runtime is marked unhealthy and every later
    request is rejected with a 503 ``ServiceError`` until the process is
    restarted (there is no recovery path within this class).
    """

    def __init__(self) -> None:
        # Loaded lazily on first use (see load()); None until then.
        self.pipeline: Trellis2ImageTo3DPipeline | None = None
        # Human-readable reason the runtime became unusable, or None while healthy.
        self.unhealthy_reason: str | None = None


    @property
    def is_healthy(self) -> bool:
        """True while no fatal error has been recorded."""
        return self.unhealthy_reason is None


    def load(self) -> None:
        """Download/load the pipeline and move it to CUDA; idempotent after the first call."""
        if self.pipeline is not None:
            return
        pipeline = Trellis2ImageTo3DPipeline.from_pretrained(PIPELINE_ID)
        pipeline.low_vram = False
        pipeline.cuda()
        # Assign only after a fully successful setup, so a failed load can be retried.
        self.pipeline = pipeline


    def mark_unhealthy(self, reason: str) -> None:
        """Permanently mark the runtime unusable (cleared only by process restart)."""
        self.unhealthy_reason = reason


    def ensure_ready(self) -> Trellis2ImageTo3DPipeline:
        """Return the loaded pipeline, raising a 503 ServiceError if the runtime is poisoned.

        Raises:
            ServiceError: non-retryable, status 503, when a prior fatal error
                has marked the runtime unhealthy.
        """
        if not self.is_healthy:
            raise ServiceError(
                stage="generate",
                error_code="runtime_unhealthy",
                message=self.unhealthy_reason or "Runtime unavailable",
                retryable=False,
                status_code=503,
            )
        self.load()
        assert self.pipeline is not None
        return self.pipeline


    def preprocess(self, image: Image.Image, request: ImageToGlbRequest) -> Image.Image:
        """Prepare the input image according to the request's background mode.

        With ``background_mode == "none"`` the pipeline's own preprocessing
        (presumably background removal — confirm against the pipeline docs)
        is skipped; any alpha channel is composited over black so the result
        is always plain RGB. Otherwise the image is delegated to the
        pipeline's ``preprocess_image``, with failures re-raised as
        classified ``ServiceError``s.
        """
        pipeline = self.ensure_ready()
        if request.preprocess.background_mode == "none":
            if image.mode == "RGBA":
                image_np = np.array(image).astype(np.float32) / 255.0
                # Premultiply RGB by alpha: transparent pixels go to black.
                rgb = image_np[:, :, :3] * image_np[:, :, 3:4]
                return Image.fromarray((rgb * 255).astype(np.uint8), mode="RGB")
            return image.convert("RGB")
        try:
            return pipeline.preprocess_image(image)
        except Exception as error:
            raise classify_runtime_error("preprocess", error) from error


    def generate_export_payload(
        self, image: Image.Image, request: ImageToGlbRequest
    ) -> dict[str, Any]:
        """Run the full image→3D generation and return a NumPy export payload.

        Raises:
            ServiceError: classified wrapper for any pipeline failure. Fatal
                CUDA errors additionally mark the runtime unhealthy first.
        """
        pipeline = self.ensure_ready()
        generation = request.generation
        # Map requested resolution to the pipeline variant; higher
        # resolutions use the cascade pipelines.
        # NOTE(review): an unexpected resolution value raises a bare KeyError
        # here, outside the try/except below, so it is NOT classified —
        # presumably prevented by request-schema validation; confirm.
        pipeline_type = {
            "512": "512",
            "1024": "1024_cascade",
            "1536": "1536_cascade",
        }[generation.resolution]
        try:
            outputs, latents = pipeline.run(
                image,
                seed=generation.seed,
                preprocess_image=False,
                sparse_structure_sampler_params={
                    "steps": generation.ss_sampling_steps,
                    "guidance_strength": generation.ss_guidance_strength,
                    "guidance_rescale": generation.ss_guidance_rescale,
                    "rescale_t": generation.ss_rescale_t,
                },
                shape_slat_sampler_params={
                    "steps": generation.shape_slat_sampling_steps,
                    "guidance_strength": generation.shape_slat_guidance_strength,
                    "guidance_rescale": generation.shape_slat_guidance_rescale,
                    "rescale_t": generation.shape_slat_rescale_t,
                },
                tex_slat_sampler_params={
                    "steps": generation.tex_slat_sampling_steps,
                    "guidance_strength": generation.tex_slat_guidance_strength,
                    "guidance_rescale": generation.tex_slat_guidance_rescale,
                    "rescale_t": generation.tex_slat_rescale_t,
                },
                pipeline_type=pipeline_type,
                return_latent=True,
            )
            # Force any asynchronous CUDA failure to surface inside this try,
            # so it is classified (and can poison the runtime) here.
            torch.cuda.synchronize()
            mesh = outputs[0]
            # assumes return_latent=True yields a 3-tuple whose last element
            # is the voxel-grid resolution — TODO confirm against pipeline API.
            _, _, resolution = latents
            payload = self._mesh_to_payload(mesh, resolution)
            # Drop GPU-backed references before empty_cache so the allocator
            # can actually release the memory.
            del outputs
            del latents
            del mesh
            torch.cuda.empty_cache()
            return payload
        except Exception as error:
            # Mark unhealthy before classifying so subsequent requests are
            # rejected even if the caller swallows this exception.
            if is_fatal_cuda_error(error):
                self.mark_unhealthy(f"Fatal CUDA error during generation: {error}")
            raise classify_runtime_error("generate", error) from error


    @staticmethod
    def _mesh_to_payload(mesh: Any, resolution: int) -> dict[str, Any]:
        """Copy mesh tensors to CPU NumPy arrays for serialization.

        ``attr_layout`` maps attribute names to their column bounds within
        ``attrs``; assumes ``mesh.layout`` values are slice-like objects
        exposing ``.start``/``.stop`` — TODO confirm against the mesh type.
        """
        return {
            "vertices": mesh.vertices.detach().cpu().numpy().astype(np.float32),
            "faces": mesh.faces.detach().cpu().numpy().astype(np.int32),
            "attrs": mesh.attrs.detach().cpu().numpy().astype(np.float32),
            "coords": mesh.coords.detach().cpu().numpy().astype(np.int32),
            "resolution": int(resolution),
            "attr_layout": {
                key: {"start": value.start, "stop": value.stop}
                for key, value in mesh.layout.items()
            },
        }
|
|
|
|
def save_input_image(image: Image.Image, path: Path) -> None:
    """Persist the input image to *path*.

    Thin wrapper over ``PIL.Image.save``; the output format is presumably
    inferred by PIL from the path's suffix — callers should pass a path
    with an image extension.
    """
    image.save(path)
|
|
|
|
def save_export_payload(job_dir: Path, payload: dict[str, Any]) -> tuple[Path, Path]:
    """Write the mesh payload into *job_dir* as a compressed NPZ plus a JSON sidecar.

    The NPZ holds the four mesh arrays (vertices, faces, attrs, coords);
    the JSON sidecar carries the attribute layout, grid resolution, and a
    fixed unit axis-aligned bounding box. Returns ``(npz_path, meta_path)``.
    """
    npz_path = job_dir / "export_payload.npz"
    meta_path = job_dir / "export_payload.json"

    arrays = {name: payload[name] for name in ("vertices", "faces", "attrs", "coords")}
    np.savez_compressed(npz_path, **arrays)

    metadata = {
        "attr_layout": payload["attr_layout"],
        "resolution": payload["resolution"],
        "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
    }
    meta_path.write_text(
        json.dumps(metadata, indent=2, sort_keys=True),
        encoding="utf-8",
    )
    return npz_path, meta_path
|
|