Trellis.2.multiview / app_texturing.py
opsiclear-admin's picture
Upload folder using huggingface_hub
648cdec verified
import gradio as gr
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from datetime import datetime
import shutil
from typing import *
import torch
import numpy as np
import trimesh
from PIL import Image
from trellis2.pipelines import Trellis2TexturingPipeline
# Largest seed value offered by the UI: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Root folder for per-session scratch directories, placed next to this file.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
def start_session(req: gr.Request):
    """
    Create the scratch directory for the connecting session.

    Args:
        req (gr.Request): The Gradio request; its ``session_hash`` is used
            as the directory name under ``TMP_DIR``.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # exist_ok: calling this again for the same session is harmless.
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request):
    """
    Remove the scratch directory that belongs to the disconnecting session.

    Args:
        req (gr.Request): The Gradio request; its ``session_hash`` names the
            directory under ``TMP_DIR`` to delete.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # The directory may never have been created or may already be gone;
        # there is nothing left to clean up, so do not fail the handler.
        pass
def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Run the texturing pipeline's image preprocessing on an input image.

    Args:
        image (Image.Image): The image to preprocess.

    Returns:
        Image.Image: Whatever ``pipeline.preprocess_image`` returns for it.
    """
    # `pipeline` is the module-level pipeline object created at startup.
    return pipeline.preprocess_image(image)
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Pick the seed to use for the next generation.

    Args:
        randomize_seed (bool): When true, ignore ``seed`` and draw a new one.
        seed (int): The seed to use when randomization is off.

    Returns:
        int: A random value from ``np.random.randint(0, MAX_SEED)`` when
        ``randomize_seed`` is true, otherwise ``seed`` unchanged.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
def shapeimage_to_tex(
    mesh_file: str,
    image: Image.Image,
    seed: int,
    resolution: str,
    texture_size: int,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """
    Texture an uploaded mesh from a reference image and export it as a GLB.

    Args:
        mesh_file (str): Path of the uploaded mesh file.
        image (Image.Image): The reference image; passed to the pipeline
            with ``preprocess_image=False``.
        seed (int): The random seed passed to the pipeline.
        resolution (str): Resolution as a string; converted with ``int()``.
        texture_size (int): Texture size passed to the pipeline.
        tex_slat_guidance_strength (float): Sampler ``guidance_strength``.
        tex_slat_guidance_rescale (float): Sampler ``guidance_rescale``.
        tex_slat_sampling_steps (int): Sampler ``steps``.
        tex_slat_rescale_t (float): Sampler ``rescale_t``.
        req (gr.Request): The Gradio request; its ``session_hash`` selects
            the output directory under ``TMP_DIR``.
        progress: Gradio progress tracker (tqdm tracking enabled). It is not
            referenced in the body; Gradio picks it up from the signature.

    Returns:
        Tuple[str, str]: The exported GLB path, twice, one entry per output
        component wired to this handler.

    Raises:
        gr.Error: If no mesh file or no image was provided.
    """
    # Fail early with a user-facing message instead of an opaque loader or
    # pipeline error when a required input is missing.
    if mesh_file is None:
        raise gr.Error("Please upload a mesh before generating.")
    if image is None:
        raise gr.Error("Please provide an image prompt before generating.")

    mesh = trimesh.load(mesh_file)
    if isinstance(mesh, trimesh.Scene):
        # The loader returned a scene; convert it to a single mesh object.
        mesh = mesh.to_mesh()

    try:
        output = pipeline.run(
            mesh,
            image,
            seed=seed,
            preprocess_image=False,
            tex_slat_sampler_params={
                "steps": tex_slat_sampling_steps,
                "guidance_strength": tex_slat_guidance_strength,
                "guidance_rescale": tex_slat_guidance_rescale,
                "rescale_t": tex_slat_rescale_t,
            },
            resolution=int(resolution),
            texture_size=texture_size,
        )

        # Millisecond-resolution timestamp keeps successive exports distinct.
        now = datetime.now()
        timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"

        user_dir = os.path.join(TMP_DIR, str(req.session_hash))
        # Re-create the session directory in case it no longer exists.
        os.makedirs(user_dir, exist_ok=True)
        glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
        output.export(glb_path, extension_webp=True)
    finally:
        # Release cached GPU memory even when inference or export fails, so a
        # failed request does not leave the cache held for the next one.
        torch.cuda.empty_cache()

    return glb_path, glb_path
# UI definition: inputs and settings in a narrow left column, the 3D viewer
# and download button in a wide right column, followed by the event wiring.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
## Texturing a mesh with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
* Upload a mesh and corresponding reference image (preferably with an alpha-masked foreground object) and click Generate to create a textured 3D asset.
""")

    with gr.Row():
        # Left column: inputs and generation settings.
        with gr.Column(scale=1, min_width=360):
            mesh_file = gr.File(
                label="Upload Mesh",
                file_types=[".ply", ".obj", ".glb", ".gltf"],
                file_count="single",
            )
            image_prompt = gr.Image(
                label="Image Prompt",
                format="png",
                image_mode="RGBA",
                type="pil",
                height=400,
            )

            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(minimum=0, maximum=MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            texture_size = gr.Slider(
                minimum=1024,
                maximum=4096,
                label="Texture Size",
                value=2048,
                step=1024,
            )

            generate_btn = gr.Button("Generate")

            # Sampler parameters, collapsed by default.
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(
                        minimum=1.0, maximum=10.0, label="Guidance Strength", value=1.0, step=0.1,
                    )
                    tex_slat_guidance_rescale = gr.Slider(
                        minimum=0.0, maximum=1.0, label="Guidance Rescale", value=0.0, step=0.01,
                    )
                    tex_slat_sampling_steps = gr.Slider(
                        minimum=1, maximum=50, label="Sampling Steps", value=12, step=1,
                    )
                    tex_slat_rescale_t = gr.Slider(
                        minimum=1.0, maximum=6.0, label="Rescale T", value=3.0, step=0.1,
                    )

        # Right column: result viewer and download.
        with gr.Column(scale=10):
            glb_output = gr.Model3D(
                label="Extracted GLB",
                height=724,
                show_label=True,
                display_mode="solid",
                clear_color=(0.25, 0.25, 0.25, 1.0),
            )
            download_btn = gr.DownloadButton(label="Download GLB")

    # Session lifecycle: create the scratch directory on page load and
    # remove it when the client disconnects.
    demo.load(start_session)
    demo.unload(end_session)

    # Replace the uploaded image with its preprocessed version in place.
    image_prompt.upload(
        fn=preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    # Inputs for the texturing handler, in its positional parameter order.
    texturing_inputs = [
        mesh_file,
        image_prompt,
        seed,
        resolution,
        texture_size,
        tex_slat_guidance_strength,
        tex_slat_guidance_rescale,
        tex_slat_sampling_steps,
        tex_slat_rescale_t,
    ]

    # Resolve the seed first (writing it back to the slider), then texture.
    generate_btn.click(
        fn=get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        fn=shapeimage_to_tex,
        inputs=texturing_inputs,
        outputs=[glb_output, download_btn],
    )
# Script entry point: prepare the scratch root, load the model, start the app.
if __name__ == "__main__":
    # The scratch root must exist before any per-session directory is made.
    os.makedirs(TMP_DIR, exist_ok=True)

    # `pipeline` is bound at module level because the handlers read it as a
    # global.
    pipeline = Trellis2TexturingPipeline.from_pretrained(
        'microsoft/TRELLIS.2-4B',
        config_file="texturing_pipeline.json",
    )
    pipeline.cuda()

    demo.launch()