# TRELLIS.2 / service_runtime.py
# Last change: commit 61287af — "fix remesh export resolution" (choephix)
from __future__ import annotations
import traceback
import json
from pathlib import Path
from typing import Any
import runtime_env # noqa: F401
import numpy as np
import torch
from PIL import Image
from schemas import ImageToGlbRequest
from trellis2.pipelines import Trellis2ImageTo3DPipeline
PIPELINE_ID = "microsoft/TRELLIS.2-4B"
class ServiceError(Exception):
def __init__(
self,
*,
stage: str,
error_code: str,
message: str,
retryable: bool,
status_code: int = 500,
details: dict[str, Any] | None = None,
):
super().__init__(message)
self.stage = stage
self.error_code = error_code
self.message = message
self.retryable = retryable
self.status_code = status_code
self.details = details or {}
def to_dict(self, job_id: str) -> dict[str, Any]:
return {
"job_id": job_id,
"stage": self.stage,
"error_code": self.error_code,
"retryable": self.retryable,
"message": self.message,
"details": self.details,
}
def is_fatal_cuda_error(error: BaseException) -> bool:
    """Heuristically detect CUDA errors that poison the whole process.

    Matches the exception text, case-insensitively, against known fatal
    CUDA failure phrases. Such errors leave the CUDA context unusable,
    so callers treat them as non-retryable (see classify_runtime_error).
    """
    text = str(error).lower()
    # NOTE: "cuda error" already substring-matches the "[cumesh] cuda error"
    # messages emitted by the CuMesh extension, so that needle was redundant
    # and has been dropped.
    needles = (
        "illegal memory access",
        "device-side assert",
        "cuda error",
    )
    return any(needle in text for needle in needles)
def classify_runtime_error(stage: str, error: BaseException) -> ServiceError:
    """Wrap an arbitrary exception into a ServiceError for *stage*.

    ServiceError instances pass through unchanged. Otherwise the error is
    classified: fatal CUDA errors get a "<stage>_cuda_fatal" code and are
    non-retryable (except during "export", which is always retryable);
    everything else gets "<stage>_failed". The current traceback is
    captured into details.
    """
    if isinstance(error, ServiceError):
        return error
    # Compute the fatal check once instead of twice as before.
    fatal = is_fatal_cuda_error(error)
    # Export runs CPU-side, so even a fatal-looking message there is
    # considered safe to retry.
    retryable = stage == "export" or not fatal
    return ServiceError(
        stage=stage,
        error_code=f"{stage}_cuda_fatal" if fatal else f"{stage}_failed",
        message=f"{type(error).__name__}: {error}",
        retryable=retryable,
        status_code=500,
        details={"traceback": traceback.format_exc()},
    )
class TrellisRuntime:
    """Lazy holder for the TRELLIS.2 image-to-3D pipeline.

    The pipeline is loaded on first use and kept resident. Once a fatal
    CUDA error is recorded via mark_unhealthy(), the runtime refuses all
    further work; the process must be restarted to recover.
    """

    def __init__(self) -> None:
        # Loaded lazily by load(); None until the first successful load.
        self.pipeline: Trellis2ImageTo3DPipeline | None = None
        # Sticky failure reason; any non-None value marks the runtime unusable.
        self.unhealthy_reason: str | None = None

    @property
    def is_healthy(self) -> bool:
        # Healthy iff no fatal error has been recorded.
        return self.unhealthy_reason is None

    def load(self) -> None:
        """Load and cache the pipeline on the GPU. Idempotent."""
        if self.pipeline is not None:
            return
        pipeline = Trellis2ImageTo3DPipeline.from_pretrained(PIPELINE_ID)
        pipeline.low_vram = False
        pipeline.cuda()
        self.pipeline = pipeline

    def mark_unhealthy(self, reason: str) -> None:
        """Permanently flag the runtime as broken (e.g. after a fatal CUDA error)."""
        self.unhealthy_reason = reason

    def ensure_ready(self) -> Trellis2ImageTo3DPipeline:
        """Return a loaded pipeline, raising a 503 ServiceError if unhealthy."""
        if not self.is_healthy:
            raise ServiceError(
                stage="generate",
                error_code="runtime_unhealthy",
                message=self.unhealthy_reason or "Runtime unavailable",
                retryable=False,
                status_code=503,
            )
        self.load()
        assert self.pipeline is not None
        return self.pipeline

    def preprocess(self, image: Image.Image, request: ImageToGlbRequest) -> Image.Image:
        """Prepare the input image per the request's preprocessing settings.

        With background_mode == "none" the image is only flattened to RGB
        (alpha composited over black); otherwise the pipeline's own
        preprocess_image() is applied.
        """
        pipeline = self.ensure_ready()
        if request.preprocess.background_mode == "none":
            if image.mode == "RGBA":
                # Premultiply alpha: composites the RGBA image over black.
                image_np = np.array(image).astype(np.float32) / 255.0
                rgb = image_np[:, :, :3] * image_np[:, :, 3:4]
                return Image.fromarray((rgb * 255).astype(np.uint8), mode="RGB")
            return image.convert("RGB")
        try:
            return pipeline.preprocess_image(image)
        except Exception as error:
            raise classify_runtime_error("preprocess", error) from error

    def generate_export_payload(
        self, image: Image.Image, request: ImageToGlbRequest
    ) -> dict[str, Any]:
        """Run image-to-3D generation and return a CPU-side export payload.

        Returns the dict built by _mesh_to_payload (numpy arrays plus
        metadata), suitable for save_export_payload(). On a fatal CUDA
        error the runtime is marked unhealthy before the error is
        re-raised as a classified ServiceError.
        """
        pipeline = self.ensure_ready()
        generation = request.generation
        # Higher resolutions select cascaded pipeline variants.
        pipeline_type = {
            "512": "512",
            "1024": "1024_cascade",
            "1536": "1536_cascade",
        }[generation.resolution]
        try:
            outputs, latents = pipeline.run(
                image,
                seed=generation.seed,
                preprocess_image=False,
                sparse_structure_sampler_params={
                    "steps": generation.ss_sampling_steps,
                    "guidance_strength": generation.ss_guidance_strength,
                    "guidance_rescale": generation.ss_guidance_rescale,
                    "rescale_t": generation.ss_rescale_t,
                },
                shape_slat_sampler_params={
                    "steps": generation.shape_slat_sampling_steps,
                    "guidance_strength": generation.shape_slat_guidance_strength,
                    "guidance_rescale": generation.shape_slat_guidance_rescale,
                    "rescale_t": generation.shape_slat_rescale_t,
                },
                tex_slat_sampler_params={
                    "steps": generation.tex_slat_sampling_steps,
                    "guidance_strength": generation.tex_slat_guidance_strength,
                    "guidance_rescale": generation.tex_slat_guidance_rescale,
                    "rescale_t": generation.tex_slat_rescale_t,
                },
                pipeline_type=pipeline_type,
                return_latent=True,
            )
            # Force any asynchronous CUDA failure to surface inside this try
            # block so it is classified (and health-flagged) here.
            torch.cuda.synchronize()
            mesh = outputs[0]
            # Assumes latents is a 3-tuple whose last element is the grid
            # resolution -- TODO confirm against pipeline.run(return_latent=True).
            _, _, resolution = latents
            payload = self._mesh_to_payload(mesh, resolution)
            # Drop GPU references before empty_cache so memory can be reclaimed.
            del outputs
            del latents
            del mesh
            torch.cuda.empty_cache()
            return payload
        except Exception as error:
            if is_fatal_cuda_error(error):
                # The CUDA context is poisoned; refuse all future jobs.
                self.mark_unhealthy(f"Fatal CUDA error during generation: {error}")
            raise classify_runtime_error("generate", error) from error

    @staticmethod
    def _mesh_to_payload(mesh: Any, resolution: int) -> dict[str, Any]:
        """Convert a GPU mesh object into plain numpy arrays plus metadata."""
        return {
            "vertices": mesh.vertices.detach().cpu().numpy().astype(np.float32),
            "faces": mesh.faces.detach().cpu().numpy().astype(np.int32),
            "attrs": mesh.attrs.detach().cpu().numpy().astype(np.float32),
            "coords": mesh.coords.detach().cpu().numpy().astype(np.int32),
            "resolution": int(resolution),
            # mesh.layout maps attribute name -> slice; serialized as ranges
            # so the layout survives JSON round-tripping.
            "attr_layout": {
                key: {"start": value.start, "stop": value.stop}
                for key, value in mesh.layout.items()
            },
        }
def save_input_image(image: Image.Image, path: Path) -> None:
image.save(path)
def save_export_payload(job_dir: Path, payload: dict[str, Any]) -> tuple[Path, Path]:
    """Write the export payload as a compressed .npz plus a JSON sidecar.

    The .npz holds the mesh arrays (vertices/faces/attrs/coords); the JSON
    sidecar carries the attribute layout, grid resolution, and the fixed
    unit AABB. Returns (npz_path, meta_path).
    """
    npz_path = job_dir / "export_payload.npz"
    meta_path = job_dir / "export_payload.json"
    array_keys = ("vertices", "faces", "attrs", "coords")
    np.savez_compressed(npz_path, **{key: payload[key] for key in array_keys})
    metadata = {
        "attr_layout": payload["attr_layout"],
        "resolution": payload["resolution"],
        "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
    }
    meta_path.write_text(
        json.dumps(metadata, indent=2, sort_keys=True),
        encoding="utf-8",
    )
    return npz_path, meta_path