Daankular commited on
Commit
14c3d13
·
0 Parent(s):

Initial local files

Browse files
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ # System deps
4
+ RUN apt-get update && apt-get install -y --no-install-recommends \
5
+ git wget curl build-essential cmake ninja-build pkg-config \
6
+ libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev ffmpeg \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ # HF user setup
10
+ RUN useradd -m -u 1000 user
11
+ USER user
12
+ ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
13
+ WORKDIR $HOME/app
14
+
15
+ # Upgrade pip first
16
+ RUN pip install --user --upgrade pip setuptools wheel
17
+
18
+ # chumpy must be installed with --no-build-isolation BEFORE everything else
19
+ # (its setup.py does `import pip` which fails in pip's default isolated build env)
20
+ RUN pip install --user --no-build-isolation \
21
+ "chumpy @ git+https://github.com/mattloper/chumpy.git@580566eafc9ac68b2614b64d6f7aaa84eebb70da"
22
+
23
+ # Copy app files
24
+ COPY --chown=user . $HOME/app
25
+
26
+ # Install remaining requirements (chumpy already satisfied above)
27
+ RUN pip install --user --no-cache-dir -r requirements.txt \
28
+ "torch<=2.9.1" \
29
+ "gradio[oauth,mcp]==6.11.0" \
30
+ "uvicorn>=0.14.0" \
31
+ "websockets>=10.4" \
32
+ "spaces==0.48.1"
33
+
34
+ EXPOSE 7860
35
+ CMD ["python", "app.py"]
README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Image2Model
3
+ emoji: 🎭
4
+ colorFrom: purple
5
+ colorTo: blue
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ license: apache-2.0
10
+ hardware: zero-a10g
11
+ ---
12
+
13
+ # Image2Model
14
+
15
+ Portrait-to-mesh pipeline on HuggingFace ZeroGPU.
16
+
17
+ Upload a photo → rigged, textured, animation-ready GLB in minutes.
18
+
19
+ **Pipeline stages**
20
+ 1. Background removal — RMBG-2.0
21
+ 2. 3D shape generation — TripoSG (diffusion SDF)
22
+ 3. Multiview texturing — MV-Adapter + SDXL
23
+ 4. Face enhancement — HyperSwap 1A 256 + RealESRGAN x4plus
24
+ 5. Rigging — YOLO-pose → 3D joints → LBS weights
25
+ 6. SKEL anatomy layer — anatomical bone mesh
26
+ 7. MDM animation — text-to-motion
27
+ 8. Surface enhancement — StableNormal normal maps + Depth-Anything displacement
app.py ADDED
@@ -0,0 +1,728 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import tempfile
4
+ import shutil
5
+ import traceback
6
+ import json
7
+ import random
8
+ from pathlib import Path
9
+
10
+ import cv2
11
+ import gradio as gr
12
+ import spaces
13
+ import torch
14
+ import numpy as np
15
+ from PIL import Image
16
+
17
+ # ── Paths ─────────────────────────────────────────────────────────────────────
18
+ HERE = Path(__file__).parent
19
+ PIPELINE_DIR = HERE / "pipeline"
20
+ CKPT_DIR = Path(os.environ.get("CKPT_DIR", "/tmp/checkpoints"))
21
+ CKPT_DIR.mkdir(parents=True, exist_ok=True)
22
+
23
+ # Add pipeline dir so local overrides (patched files) take priority
24
+ sys.path.insert(0, str(HERE))
25
+ sys.path.insert(0, str(PIPELINE_DIR))
26
+
27
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
+
29
+ # Lazy-loaded models (persist between ZeroGPU calls when Space is warm)
30
+ _triposg_pipe = None
31
+ _rmbg_net = None
32
+ _rmbg_version = None
33
+ _last_glb_path = None
34
+ _init_seed = random.randint(0, 2**31 - 1)
35
+
36
+ ARCFACE_256 = (np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
37
+ [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
38
+ * (256 / 112) + (256 - 112 * (256 / 112)) / 2)
39
+
40
+ VIEW_NAMES = ["front", "3q_front", "side", "back", "3q_back"]
41
+ VIEW_PATHS = [f"/tmp/render_{n}.png" for n in VIEW_NAMES]
42
+
43
+
44
+ # ── Weight download helpers ────────────────────────────────────────────────────
45
+
46
+ def _ensure_weight(url: str, dest: Path) -> Path:
47
+ """Download a file if not already cached."""
48
+ if not dest.exists():
49
+ import urllib.request
50
+ dest.parent.mkdir(parents=True, exist_ok=True)
51
+ print(f"[weights] Downloading {dest.name} ...")
52
+ urllib.request.urlretrieve(url, dest)
53
+ print(f"[weights] Saved → {dest}")
54
+ return dest
55
+
56
+
57
+ def _ensure_ckpts():
58
+ """Download all face-enhancement checkpoints to CKPT_DIR."""
59
+ weights = {
60
+ "hyperswap_1a_256.onnx": "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/hyperswap_1a_256.onnx",
61
+ "inswapper_128.onnx": "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx",
62
+ "RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus.pth",
63
+ "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
64
+ }
65
+ for name, url in weights.items():
66
+ _ensure_weight(url, CKPT_DIR / name)
67
+
68
+
69
+ # ── Model loaders ─────────────────────────────────────────────────────────────
70
+
71
+ def load_triposg():
72
+ global _triposg_pipe, _rmbg_net, _rmbg_version
73
+ if _triposg_pipe is not None:
74
+ _triposg_pipe.to(DEVICE)
75
+ if _rmbg_net is not None:
76
+ _rmbg_net.to(DEVICE)
77
+ return _triposg_pipe, _rmbg_net
78
+
79
+ print("[load_triposg] Loading TripoSG pipeline...")
80
+ from huggingface_hub import snapshot_download
81
+ weights_path = snapshot_download("VAST-AI/TripoSG")
82
+
83
+ # TripoSG ships its own pipeline — add to path
84
+ triposg_pkg = Path(weights_path)
85
+ if (triposg_pkg / "triposg").exists():
86
+ sys.path.insert(0, str(triposg_pkg))
87
+ else:
88
+ # Try installed package from the cloned repo (if installed with pip -e)
89
+ import importlib.util
90
+ if importlib.util.find_spec("triposg") is None:
91
+ import subprocess
92
+ subprocess.run([sys.executable, "-m", "pip", "install", "-e", str(triposg_pkg), "-q"], check=False)
93
+
94
+ from triposg.pipelines.pipeline_triposg import TripoSGPipeline
95
+ _triposg_pipe = TripoSGPipeline.from_pretrained(
96
+ weights_path, torch_dtype=torch.float16
97
+ ).to(DEVICE)
98
+
99
+ try:
100
+ from transformers import AutoModelForImageSegmentation
101
+ _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
102
+ "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False
103
+ ).to(DEVICE)
104
+ _rmbg_net.eval()
105
+ _rmbg_version = "2.0"
106
+ print("[load_triposg] TripoSG + RMBG-2.0 loaded.")
107
+ except Exception as e:
108
+ print(f"[load_triposg] RMBG-2.0 failed ({e}). BG removal disabled.")
109
+ _rmbg_net = None
110
+
111
+ return _triposg_pipe, _rmbg_net
112
+
113
+
114
+ # ── Background removal helper ─────────────────────────────────────────────────
115
+
116
+ def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):
117
+ if _rmbg_net is None:
118
+ return img_pil
119
+ import torchvision.transforms.functional as TF
120
+ from torchvision import transforms
121
+
122
+ img_tensor = transforms.ToTensor()(img_pil.resize((1024, 1024)))
123
+ if _rmbg_version == "2.0":
124
+ img_tensor = TF.normalize(img_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
125
+ else:
126
+ img_tensor = TF.normalize(img_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]).unsqueeze(0)
127
+
128
+ with torch.no_grad():
129
+ result = _rmbg_net(img_tensor)
130
+
131
+ if isinstance(result, (list, tuple)):
132
+ candidate = result[-1] if _rmbg_version == "2.0" else result[0]
133
+ if isinstance(candidate, (list, tuple)):
134
+ candidate = candidate[0]
135
+ else:
136
+ candidate = result
137
+
138
+ mask_tensor = candidate.sigmoid()[0, 0].cpu()
139
+ mask = np.array(transforms.ToPILImage()(mask_tensor).resize(img_pil.size, Image.BILINEAR),
140
+ dtype=np.float32) / 255.0
141
+ mask = (mask >= threshold).astype(np.float32) * mask
142
+ if erode_px > 0:
143
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_px * 2 + 1,) * 2)
144
+ mask = cv2.erode((mask * 255).astype(np.uint8), kernel).astype(np.float32) / 255.0
145
+
146
+ rgb = np.array(img_pil.convert("RGB"), dtype=np.float32) / 255.0
147
+ alpha = mask[:, :, np.newaxis]
148
+ comp = (rgb * alpha + 0.5 * (1.0 - alpha) * 255).clip(0, 255).astype(np.uint8)
149
+ return Image.fromarray(comp)
150
+
151
+
152
+ def preview_rembg(input_image, do_remove_bg, threshold, erode_px):
153
+ if input_image is None or not do_remove_bg or _rmbg_net is None:
154
+ return input_image
155
+ try:
156
+ return np.array(_remove_bg_rmbg(Image.fromarray(input_image).convert("RGB"),
157
+ threshold=float(threshold), erode_px=int(erode_px)))
158
+ except Exception:
159
+ return input_image
160
+
161
+
162
+ # ── Stage 1: Shape generation ─────────────────────────────────────────────────
163
+
164
+ @spaces.GPU(duration=180)
165
+ def generate_shape(input_image, remove_background, num_steps, guidance_scale,
166
+ seed, face_count, progress=gr.Progress()):
167
+ if input_image is None:
168
+ return None, "Please upload an image."
169
+ try:
170
+ progress(0.1, desc="Loading TripoSG...")
171
+
172
+ # Add TripoSG scripts to path after model download
173
+ from huggingface_hub import snapshot_download
174
+ weights_path = snapshot_download("VAST-AI/TripoSG")
175
+ sys.path.insert(0, weights_path)
176
+
177
+ pipe, rmbg_net = load_triposg()
178
+
179
+ img = Image.fromarray(input_image).convert("RGB")
180
+ img_path = "/tmp/triposg_input.png"
181
+ img.save(img_path)
182
+
183
+ progress(0.5, desc="Generating shape (SDF diffusion)...")
184
+ from scripts.inference_triposg import run_triposg
185
+ mesh = run_triposg(
186
+ pipe=pipe,
187
+ image_input=img_path,
188
+ rmbg_net=rmbg_net if remove_background else None,
189
+ seed=int(seed),
190
+ num_inference_steps=int(num_steps),
191
+ guidance_scale=float(guidance_scale),
192
+ faces=int(face_count) if int(face_count) > 0 else -1,
193
+ )
194
+
195
+ out_path = "/tmp/triposg_shape.glb"
196
+ mesh.export(out_path)
197
+
198
+ # Offload to CPU before next stage
199
+ _triposg_pipe.to("cpu")
200
+ if _rmbg_net is not None:
201
+ _rmbg_net.to("cpu")
202
+ torch.cuda.empty_cache()
203
+
204
+ return out_path, "Shape generated!"
205
+ except Exception:
206
+ return None, f"Error:\n{traceback.format_exc()}"
207
+
208
+
209
+ # ── Stage 2: Texture ──────────────────────────────────────────────────────────
210
+
211
+ @spaces.GPU(duration=300)
212
+ def apply_texture(glb_path, input_image, remove_background, variant, tex_seed,
213
+ enhance_face, rembg_threshold=0.5, rembg_erode=2,
214
+ progress=gr.Progress()):
215
+ if glb_path is None:
216
+ glb_path = "/tmp/triposg_shape.glb"
217
+ if not os.path.exists(glb_path):
218
+ return None, None, "Generate a shape first."
219
+ if input_image is None:
220
+ return None, None, "Please upload an image."
221
+ try:
222
+ progress(0.1, desc="Preprocessing image...")
223
+ img = Image.fromarray(input_image).convert("RGB")
224
+ face_ref_path = "/tmp/triposg_face_ref.png"
225
+ img.save(face_ref_path)
226
+
227
+ if remove_background and _rmbg_net is not None:
228
+ img = _remove_bg_rmbg(img, threshold=float(rembg_threshold), erode_px=int(rembg_erode))
229
+
230
+ img = img.resize((768, 768), Image.LANCZOS)
231
+ img_path = "/tmp/tex_input_768.png"
232
+ img.save(img_path)
233
+
234
+ out_dir = "/tmp/tex_out"
235
+ os.makedirs(out_dir, exist_ok=True)
236
+
237
+ # ── Run MV-Adapter in-process ─────────────────────────────────────
238
+ progress(0.3, desc="Loading MV-Adapter pipeline...")
239
+ import importlib
240
+ from huggingface_hub import snapshot_download
241
+
242
+ mvadapter_weights = snapshot_download("huanngzh/mv-adapter")
243
+
244
+ # Resolve SD pipeline
245
+ if variant == "sdxl":
246
+ from diffusers import StableDiffusionXLPipeline
247
+ sd_id = "stabilityai/stable-diffusion-xl-base-1.0"
248
+ else:
249
+ from diffusers import StableDiffusionPipeline
250
+ sd_id = "stabilityai/stable-diffusion-2-1-base"
251
+
252
+ from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
253
+ from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
254
+ from mvadapter.utils import get_orthogonal_camera, get_ipadapter_image
255
+ import torchvision.transforms.functional as TF
256
+
257
+ progress(0.4, desc=f"Running MV-Adapter ({variant})...")
258
+
259
+ pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(
260
+ sd_id,
261
+ torch_dtype=torch.float16,
262
+ ).to(DEVICE)
263
+
264
+ pipe.init_adapter(
265
+ image_encoder_path="openai/clip-vit-large-patch14",
266
+ ipa_weight_path=os.path.join(mvadapter_weights, "mvadapter_i2mv_sdxl.safetensors"),
267
+ adapter_tokens=256,
268
+ )
269
+
270
+ ref_pil = Image.open(img_path).convert("RGB")
271
+ cameras = get_orthogonal_camera(
272
+ elevation_deg=[0, 0, 0, 0, 0, 0],
273
+ distance=[1.8] * 6,
274
+ left=-0.55, right=0.55, bottom=-0.55, top=0.55,
275
+ azimuth_deg=[x - 90 for x in [0, 45, 90, 135, 180, 270]],
276
+ device=DEVICE,
277
+ )
278
+
279
+ with torch.autocast(DEVICE):
280
+ out = pipe(
281
+ image=ref_pil,
282
+ height=768, width=768,
283
+ num_images_per_prompt=6,
284
+ guidance_scale=3.0,
285
+ num_inference_steps=30,
286
+ generator=torch.Generator(device=DEVICE).manual_seed(int(tex_seed)),
287
+ cameras=cameras,
288
+ )
289
+
290
+ mv_grid = out.images # list of 6 PIL images
291
+ grid_w = mv_grid[0].width * len(mv_grid)
292
+ mv_pil = Image.new("RGB", (grid_w, mv_grid[0].height))
293
+ for i, v in enumerate(mv_grid):
294
+ mv_pil.paste(v, (i * mv_grid[0].width, 0))
295
+ mv_path = os.path.join(out_dir, "multiview.png")
296
+ mv_pil.save(mv_path)
297
+
298
+ # Offload before face-enhance (saves VRAM)
299
+ del pipe
300
+ torch.cuda.empty_cache()
301
+
302
+ # ── Face enhancement ─────────────────────────────────────────────
303
+ if enhance_face:
304
+ progress(0.75, desc="Running face enhancement...")
305
+ _ensure_ckpts()
306
+ try:
307
+ from pipeline.face_enhance import enhance_multiview
308
+ enh_path = os.path.join(out_dir, "multiview_enhanced.png")
309
+ enhance_multiview(
310
+ multiview_path=mv_path,
311
+ reference_path=face_ref_path,
312
+ output_path=enh_path,
313
+ ckpt_dir=str(CKPT_DIR),
314
+ )
315
+ mv_path = enh_path
316
+ except Exception as _fe:
317
+ print(f"[apply_texture] face enhance failed: {_fe}")
318
+
319
+ # ── Bake textures onto mesh ─────────────────────────────────────
320
+ progress(0.85, desc="Baking UV texture onto mesh...")
321
+ from mvadapter.utils.mesh_utils import (
322
+ NVDiffRastContextWrapper, load_mesh, bake_texture,
323
+ )
324
+
325
+ ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
326
+ mesh = load_mesh(glb_path, rescale=True, device=DEVICE)
327
+ tex_pil = Image.open(mv_path)
328
+
329
+ baked = bake_texture(ctx, mesh, tex_pil, cameras=cameras, height=1024, width=1024)
330
+ out_glb = os.path.join(out_dir, "textured_shaded.glb")
331
+ baked.export(out_glb)
332
+
333
+ final_path = "/tmp/triposg_textured.glb"
334
+ shutil.copy(out_glb, final_path)
335
+
336
+ global _last_glb_path
337
+ _last_glb_path = final_path
338
+
339
+ torch.cuda.empty_cache()
340
+ return final_path, mv_path, "Texture applied!"
341
+ except Exception:
342
+ return None, None, f"Error:\n{traceback.format_exc()}"
343
+
344
+
345
+ # ── Stage 3a: SKEL Anatomy ────────────────────────────────────────────────────
346
+
347
+ @spaces.GPU(duration=90)
348
+ def gradio_tpose(glb_state_path, export_skel_flag, progress=gr.Progress()):
349
+ try:
350
+ glb = glb_state_path or _last_glb_path or "/tmp/triposg_textured.glb"
351
+ if not os.path.exists(glb):
352
+ return None, None, "No GLB found — run Generate + Texture first."
353
+
354
+ progress(0.1, desc="YOLO pose detection + rigging...")
355
+ from pipeline.rig_yolo import rig_yolo
356
+ out_dir = "/tmp/rig_out"
357
+ os.makedirs(out_dir, exist_ok=True)
358
+ rigged, _rigged_skel = rig_yolo(glb, os.path.join(out_dir, "anatomy_rigged.glb"), debug_dir=None)
359
+
360
+ bones = None
361
+ if export_skel_flag:
362
+ progress(0.7, desc="Generating SKEL bone mesh...")
363
+ from pipeline.tpose_smpl import export_skel_bones
364
+ bones = export_skel_bones(torch.zeros(10), "/tmp/tposed_bones.glb", gender="male")
365
+
366
+ status = f"Rigged surface: {os.path.getsize(rigged)//1024} KB"
367
+ if bones:
368
+ status += f"\nSKEL bone mesh: {os.path.getsize(bones)//1024} KB"
369
+ elif export_skel_flag:
370
+ status += "\nSKEL bone mesh: failed (check logs)"
371
+
372
+ torch.cuda.empty_cache()
373
+ return rigged, bones, status
374
+ except Exception:
375
+ return None, None, f"Error:\n{traceback.format_exc()}"
376
+
377
+
378
+ # ── Stage 3b: Rig & Export ────────────────────────────────────────────────────
379
+
380
+ @spaces.GPU(duration=180)
381
+ def gradio_rig(glb_state_path, export_fbx_flag, mdm_prompt, mdm_n_frames,
382
+ progress=gr.Progress()):
383
+ try:
384
+ from pipeline.rig_yolo import rig_yolo
385
+ from pipeline.rig_stage import export_fbx
386
+
387
+ glb = glb_state_path or _last_glb_path or "/tmp/triposg_textured.glb"
388
+ if not os.path.exists(glb):
389
+ return None, None, None, "No GLB found — run Generate + Texture first.", None, None, None
390
+
391
+ out_dir = "/tmp/rig_out"
392
+ os.makedirs(out_dir, exist_ok=True)
393
+
394
+ progress(0.1, desc="YOLO pose detection + rigging...")
395
+ rigged, rigged_skel = rig_yolo(glb, os.path.join(out_dir, "rigged.glb"),
396
+ debug_dir=os.path.join(out_dir, "debug"))
397
+
398
+ fbx = None
399
+ if export_fbx_flag:
400
+ progress(0.7, desc="Exporting FBX...")
401
+ fbx_path = os.path.join(out_dir, "rigged.fbx")
402
+ fbx = fbx_path if export_fbx(rigged, fbx_path) else None
403
+
404
+ animated = None
405
+ if mdm_prompt.strip():
406
+ progress(0.75, desc="Generating MDM animation...")
407
+ from pipeline.rig_stage import run_rig_pipeline
408
+ mdm_result = run_rig_pipeline(
409
+ glb_path=glb,
410
+ reference_image_path="/tmp/triposg_face_ref.png",
411
+ out_dir=out_dir,
412
+ device=DEVICE,
413
+ export_fbx_flag=False,
414
+ mdm_prompt=mdm_prompt.strip(),
415
+ mdm_n_frames=int(mdm_n_frames),
416
+ )
417
+ animated = mdm_result.get("animated_glb")
418
+
419
+ parts = ["Rigged: " + os.path.basename(rigged)]
420
+ if fbx: parts.append("FBX: " + os.path.basename(fbx))
421
+ if animated: parts.append("Animation: " + os.path.basename(animated))
422
+
423
+ torch.cuda.empty_cache()
424
+ return rigged, animated, fbx, " | ".join(parts), rigged, rigged, rigged_skel
425
+ except Exception:
426
+ return None, None, None, f"Error:\n{traceback.format_exc()}", None, None, None
427
+
428
+
429
+ # ── Stage 4: Surface enhancement ─────────────────────────────────────────────
430
+
431
+ @spaces.GPU(duration=120)
432
+ def gradio_enhance(glb_path, ref_img_np, do_normal, norm_res, norm_strength,
433
+ do_depth, dep_res, disp_scale):
434
+ if not glb_path:
435
+ yield None, None, None, None, "No GLB loaded — run Generate first."
436
+ return
437
+ if ref_img_np is None:
438
+ yield None, None, None, None, "No reference image — run Generate first."
439
+ return
440
+ try:
441
+ from pipeline.enhance_surface import (
442
+ run_stable_normal, run_depth_anything,
443
+ bake_normal_into_glb, bake_depth_as_occlusion,
444
+ )
445
+ import pipeline.enhance_surface as _enh_mod
446
+
447
+ ref_pil = Image.fromarray(ref_img_np.astype(np.uint8))
448
+ out_path = glb_path.replace(".glb", "_enhanced.glb")
449
+ shutil.copy2(glb_path, out_path)
450
+ normal_out = depth_out = None
451
+ log = []
452
+
453
+ if do_normal:
454
+ log.append("[StableNormal] Running...")
455
+ yield None, None, None, None, "\n".join(log)
456
+ normal_out = run_stable_normal(ref_pil, resolution=norm_res)
457
+ out_path = bake_normal_into_glb(out_path, normal_out, out_path,
458
+ normal_strength=norm_strength)
459
+ log.append(f"[StableNormal] Done → normalTexture (strength {norm_strength})")
460
+ yield normal_out, depth_out, None, None, "\n".join(log)
461
+
462
+ if do_depth:
463
+ log.append("[Depth-Anything] Running...")
464
+ yield normal_out, depth_out, None, None, "\n".join(log)
465
+ depth_out = run_depth_anything(ref_pil, resolution=dep_res)
466
+ out_path = bake_depth_as_occlusion(out_path, depth_out, out_path,
467
+ displacement_scale=disp_scale)
468
+ log.append(f"[Depth-Anything] Done → occlusionTexture (scale {disp_scale})")
469
+ yield normal_out, depth_out.convert("L").convert("RGB"), None, None, "\n".join(log)
470
+
471
+ torch.cuda.empty_cache()
472
+ log.append("Enhancement complete.")
473
+ yield normal_out, (depth_out.convert("L").convert("RGB") if depth_out else None), out_path, out_path, "\n".join(log)
474
+
475
+ except Exception:
476
+ yield None, None, None, None, f"Error:\n{traceback.format_exc()}"
477
+
478
+
479
+ # ── Render views ──────────────────────────────────────────────────────────────
480
+
481
+ @spaces.GPU(duration=60)
482
+ def render_views(glb_file):
483
+ if not glb_file:
484
+ return []
485
+ glb_path = glb_file if isinstance(glb_file, str) else (glb_file.get("path") if isinstance(glb_file, dict) else str(glb_file))
486
+ if not glb_path or not os.path.exists(glb_path):
487
+ return []
488
+ try:
489
+ from mvadapter.utils.mesh_utils import (
490
+ NVDiffRastContextWrapper, load_mesh, render, get_orthogonal_camera,
491
+ )
492
+ ctx = NVDiffRastContextWrapper(device="cuda", context_type="cuda")
493
+ mesh = load_mesh(glb_path, rescale=True, device="cuda")
494
+ cams = get_orthogonal_camera(
495
+ elevation_deg=[0]*5, distance=[1.8]*5,
496
+ left=-0.55, right=0.55, bottom=-0.55, top=0.55,
497
+ azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 315]],
498
+ device="cuda",
499
+ )
500
+ out = render(ctx, mesh, cams, height=1024, width=768, render_attr=True, normal_background=0.0)
501
+ save_dir = os.path.dirname(glb_path)
502
+ results = []
503
+ for i, name in enumerate(VIEW_NAMES):
504
+ arr = (out.attr[i].cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
505
+ path = os.path.join(save_dir, f"render_{name}.png")
506
+ Image.fromarray(arr).save(path)
507
+ results.append((path, name))
508
+ torch.cuda.empty_cache()
509
+ return results
510
+ except Exception:
511
+ print(f"render_views FAILED:\n{traceback.format_exc()}")
512
+ return []
513
+
514
+
515
+ # ── Full pipeline ─────────────────────────────────────────────────────────────
516
+
517
+ def run_full_pipeline(input_image, remove_background, num_steps, guidance, seed, face_count,
518
+ variant, tex_seed, enhance_face, rembg_threshold, rembg_erode,
519
+ export_fbx, mdm_prompt, mdm_n_frames, progress=gr.Progress()):
520
+ progress(0.0, desc="Stage 1/3: Generating shape...")
521
+ glb, status = generate_shape(input_image, remove_background, num_steps, guidance, seed, face_count)
522
+ if not glb:
523
+ return None, None, None, None, None, None, status
524
+
525
+ progress(0.33, desc="Stage 2/3: Applying texture...")
526
+ glb, mv_img, status = apply_texture(glb, input_image, remove_background, variant, tex_seed,
527
+ enhance_face, rembg_threshold, rembg_erode)
528
+ if not glb:
529
+ return None, None, None, None, None, None, status
530
+
531
+ progress(0.66, desc="Stage 3/3: Rigging + animation...")
532
+ rigged, animated, fbx, rig_status, _, _, _ = gradio_rig(glb, export_fbx, mdm_prompt, mdm_n_frames)
533
+
534
+ progress(1.0, desc="Pipeline complete!")
535
+ return glb, glb, mv_img, rigged, animated, fbx, f"[Texture] {status}\n[Rig] {rig_status}"
536
+
537
+
538
+ # ── UI ────────────────────────────────────────────────────────────────────────
539
+ with gr.Blocks(title="Image2Model", theme=gr.themes.Soft()) as demo:
540
+ gr.Markdown("# Image2Model — Portrait to Rigged 3D Mesh")
541
+ glb_state = gr.State(None)
542
+
543
+ with gr.Tabs():
544
+
545
+ # ════════════════════════════════════════════════════════════════════
546
+ with gr.Tab("Generate"):
547
+ with gr.Row():
548
+ with gr.Column(scale=1):
549
+ input_image = gr.Image(label="Input Image", type="numpy")
550
+ remove_bg_check = gr.Checkbox(label="Remove Background", value=True)
551
+ with gr.Row():
552
+ rembg_threshold = gr.Slider(0.1, 0.95, value=0.5, step=0.05,
553
+ label="BG Threshold")
554
+ rembg_erode = gr.Slider(0, 8, value=2, step=1,
555
+ label="Edge Erode (px)")
556
+
557
+ with gr.Accordion("Shape Settings", open=True):
558
+ num_steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
559
+ guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.5, label="Guidance Scale")
560
+ seed = gr.Number(value=_init_seed, label="Seed", precision=0)
561
+ face_count = gr.Number(value=0, label="Max Faces (0 = unlimited)", precision=0)
562
+
563
+ with gr.Accordion("Texture Settings", open=True):
564
+ variant = gr.Radio(["sdxl", "sd21"], value="sdxl",
565
+ label="Model (sdxl = quality, sd21 = less VRAM)")
566
+ tex_seed = gr.Number(value=_init_seed, label="Texture Seed", precision=0)
567
+ enhance_face_check = gr.Checkbox(
568
+ label="Enhance Face (HyperSwap + RealESRGAN)", value=True)
569
+
570
+ with gr.Row():
571
+ shape_btn = gr.Button("Generate Shape", variant="primary", scale=2, interactive=False)
572
+ texture_btn = gr.Button("Apply Texture", variant="secondary", scale=2)
573
+ render_btn = gr.Button("Render Views", variant="secondary", scale=1)
574
+ run_all_btn = gr.Button("▶ Run Full Pipeline", variant="primary", interactive=False)
575
+
576
+ with gr.Column(scale=1):
577
+ rembg_preview = gr.Image(label="BG Removed Preview", type="numpy", interactive=False)
578
+ status = gr.Textbox(label="Status", lines=3, interactive=False)
579
+ model_3d = gr.Model3D(label="3D Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
580
+ download_file = gr.File(label="Download GLB")
581
+ multiview_img = gr.Image(label="Multiview", type="filepath", interactive=False)
582
+
583
+ render_gallery = gr.Gallery(label="Rendered Views", columns=5, height=300)
584
+
585
+ _rembg_inputs = [input_image, remove_bg_check, rembg_threshold, rembg_erode]
586
+ _pipeline_btns = [shape_btn, run_all_btn]
587
+
588
+ input_image.upload(
589
+ fn=lambda: (gr.update(interactive=True), gr.update(interactive=True)),
590
+ inputs=[], outputs=_pipeline_btns,
591
+ )
592
+ input_image.clear(
593
+ fn=lambda: (gr.update(interactive=False), gr.update(interactive=False)),
594
+ inputs=[], outputs=_pipeline_btns,
595
+ )
596
+ input_image.upload(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
597
+ remove_bg_check.change(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
598
+ rembg_threshold.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
599
+ rembg_erode.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
600
+
601
+ shape_btn.click(
602
+ fn=generate_shape,
603
+ inputs=[input_image, remove_bg_check, num_steps, guidance, seed, face_count],
604
+ outputs=[glb_state, status],
605
+ ).then(
606
+ fn=lambda p: (p, p) if p else (None, None),
607
+ inputs=[glb_state], outputs=[model_3d, download_file],
608
+ )
609
+
610
+ texture_btn.click(
611
+ fn=apply_texture,
612
+ inputs=[glb_state, input_image, remove_bg_check, variant, tex_seed,
613
+ enhance_face_check, rembg_threshold, rembg_erode],
614
+ outputs=[glb_state, multiview_img, status],
615
+ ).then(
616
+ fn=lambda p: (p, p) if p else (None, None),
617
+ inputs=[glb_state], outputs=[model_3d, download_file],
618
+ )
619
+
620
+ render_btn.click(fn=render_views, inputs=[download_file], outputs=[render_gallery])
621
+
622
+ # ════════════════════════════════════════════════════════════════════
623
+ with gr.Tab("Rig & Export"):
624
+ with gr.Row():
625
+ with gr.Column(scale=1):
626
+ gr.Markdown("### Step 1 — SKEL Anatomy Layer")
627
+ tpose_skel_check = gr.Checkbox(label="Export SKEL bone mesh", value=False)
628
+ tpose_btn = gr.Button("Rig + SKEL Anatomy", variant="secondary")
629
+ tpose_status = gr.Textbox(label="Anatomy Status", lines=3, interactive=False)
630
+ with gr.Row():
631
+ tpose_surface_dl = gr.File(label="Rigged Surface GLB")
632
+ tpose_bones_dl = gr.File(label="SKEL Bone Mesh GLB")
633
+
634
+ gr.Markdown("---")
635
+ gr.Markdown("### Step 2 — Rig & Export")
636
+ export_fbx_check = gr.Checkbox(label="Export FBX (requires Blender)", value=True)
637
+ mdm_prompt_box = gr.Textbox(label="Motion Prompt (MDM)",
638
+ placeholder="a person walks forward", value="")
639
+ mdm_frames_slider = gr.Slider(60, 300, value=120, step=30,
640
+ label="Animation Frames (at 20 fps)")
641
+ rig_btn = gr.Button("Rig Mesh", variant="primary")
642
+
643
+ with gr.Column(scale=2):
644
+ rig_status = gr.Textbox(label="Rig Status", lines=4, interactive=False)
645
+ show_skel_check = gr.Checkbox(label="Show Skeleton", value=False)
646
+ rig_model_3d = gr.Model3D(label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
647
+ with gr.Row():
648
+ rig_glb_dl = gr.File(label="Download Rigged GLB")
649
+ rig_animated_dl = gr.File(label="Download Animated GLB")
650
+ rig_fbx_dl = gr.File(label="Download FBX")
651
+
652
+ rigged_base_state = gr.State(None)
653
+ skel_glb_state = gr.State(None)
654
+
655
+ tpose_btn.click(
656
+ fn=gradio_tpose,
657
+ inputs=[glb_state, tpose_skel_check],
658
+ outputs=[tpose_surface_dl, tpose_bones_dl, tpose_status],
659
+ ).then(
660
+ fn=lambda p: (p["path"] if isinstance(p, dict) else p) if p else None,
661
+ inputs=[tpose_surface_dl], outputs=[rig_model_3d],
662
+ )
663
+
664
+ rig_btn.click(
665
+ fn=gradio_rig,
666
+ inputs=[glb_state, export_fbx_check, mdm_prompt_box, mdm_frames_slider],
667
+ outputs=[rig_glb_dl, rig_animated_dl, rig_fbx_dl, rig_status,
668
+ rig_model_3d, rigged_base_state, skel_glb_state],
669
+ )
670
+
671
+ show_skel_check.change(
672
+ fn=lambda show, base, skel: skel if (show and skel) else base,
673
+ inputs=[show_skel_check, rigged_base_state, skel_glb_state],
674
+ outputs=[rig_model_3d],
675
+ )
676
+
677
+ # ════════════════════════════════════════════════════════════════════
678
+ with gr.Tab("Enhancement"):
679
+ gr.Markdown("**Surface Enhancement** — bakes normal + depth maps into the GLB as PBR textures.")
680
+ with gr.Row():
681
+ with gr.Column(scale=1):
682
+ gr.Markdown("### StableNormal")
683
+ run_normal_check = gr.Checkbox(label="Run StableNormal", value=True)
684
+ normal_res = gr.Slider(512, 1024, value=768, step=128, label="Resolution")
685
+ normal_strength = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Normal Strength")
686
+
687
+ gr.Markdown("### Depth-Anything V2")
688
+ run_depth_check = gr.Checkbox(label="Run Depth-Anything V2", value=True)
689
+ depth_res = gr.Slider(512, 1024, value=768, step=128, label="Resolution")
690
+ displacement_scale = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Displacement Scale")
691
+
692
+ enhance_btn = gr.Button("Run Enhancement", variant="primary")
693
+
694
+ with gr.Column(scale=2):
695
+ enhance_status = gr.Textbox(label="Status", lines=5, interactive=False)
696
+ with gr.Row():
697
+ normal_map_img = gr.Image(label="Normal Map", type="pil")
698
+ depth_map_img = gr.Image(label="Depth Map", type="pil")
699
+ enhanced_glb_dl = gr.File(label="Download Enhanced GLB")
700
+ enhanced_model_3d = gr.Model3D(label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
701
+
702
+ enhance_btn.click(
703
+ fn=gradio_enhance,
704
+ inputs=[glb_state, input_image,
705
+ run_normal_check, normal_res, normal_strength,
706
+ run_depth_check, depth_res, displacement_scale],
707
+ outputs=[normal_map_img, depth_map_img,
708
+ enhanced_glb_dl, enhanced_model_3d, enhance_status],
709
+ )
710
+
711
+ # ── Run All wiring ────────────────────────────────────────────────
712
+ run_all_btn.click(
713
+ fn=run_full_pipeline,
714
+ inputs=[
715
+ input_image, remove_bg_check, num_steps, guidance, seed, face_count,
716
+ variant, tex_seed, enhance_face_check, rembg_threshold, rembg_erode,
717
+ export_fbx_check, mdm_prompt_box, mdm_frames_slider,
718
+ ],
719
+ outputs=[glb_state, download_file, multiview_img,
720
+ rig_glb_dl, rig_animated_dl, rig_fbx_dl, status],
721
+ ).then(
722
+ fn=lambda p: (p, p) if p else (None, None),
723
+ inputs=[glb_state], outputs=[model_3d, download_file],
724
+ )
725
+
726
+
727
+ if __name__ == "__main__":
728
+ demo.launch(server_name="0.0.0.0", server_port=7860)
packages.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ libgl1-mesa-glx
2
+ libglib2.0-0
3
+ libsm6
4
+ libxext6
5
+ libxrender-dev
6
+ ffmpeg
7
+ cmake
8
+ ninja-build
9
+ build-essential
10
+ pkg-config
pipeline/__init__.py ADDED
File without changes
pipeline/enhance_surface.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Surface enhancement for TripoSG GLB outputs.
3
+
4
+ StableNormal — high-quality normal map from portrait reference
5
+ Depth-Anything V2 — metric depth map → displacement intensity
6
+
7
+ Both run on the reference portrait, produce calibrated maps that
8
+ are baked as PBR textures (normalTexture + occlusion/displacement)
9
+ into the output GLB.
10
+ """
11
+
12
+ import os
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image
16
+
17
+
18
+ STABLE_NORMAL_PATH = "/root/models/stable-normal"
19
+ DEPTH_ANYTHING_PATH = "/root/models/depth-anything-v2"
20
+
21
+ _normal_pipe = None
22
+ _depth_pipe = None
23
+
24
+
25
+ # ── model loading ──────────────────────────────────────────────────────────────
26
+
27
+ def load_normal_model():
28
+ global _normal_pipe
29
+ if _normal_pipe is not None:
30
+ return _normal_pipe
31
+ from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline
32
+ from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler
33
+ import torch
34
+ x_start_pipeline = YOSONormalsPipeline.from_pretrained(
35
+ STABLE_NORMAL_PATH,
36
+ torch_dtype=torch.float16,
37
+ variant="fp16",
38
+ t_start=int(0.3 * 1000),
39
+ ).to("cuda")
40
+ _normal_pipe = YOSONormalsPipeline.from_pretrained(
41
+ STABLE_NORMAL_PATH,
42
+ torch_dtype=torch.float16,
43
+ variant="fp16",
44
+ scheduler=HEURI_DDIMScheduler.from_pretrained(
45
+ STABLE_NORMAL_PATH, subfolder="scheduler",
46
+ ddim_timestep_respacing="ddim10", x_start_pipeline=x_start_pipeline,
47
+ ),
48
+ ).to("cuda")
49
+ _normal_pipe.set_progress_bar_config(disable=True)
50
+ return _normal_pipe
51
+
52
+
53
+ def load_depth_model():
54
+ global _depth_pipe
55
+ if _depth_pipe is not None:
56
+ return _depth_pipe
57
+ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
58
+ processor = AutoImageProcessor.from_pretrained(DEPTH_ANYTHING_PATH)
59
+ model = AutoModelForDepthEstimation.from_pretrained(
60
+ DEPTH_ANYTHING_PATH, torch_dtype=torch.float16
61
+ ).to("cuda")
62
+ _depth_pipe = (processor, model)
63
+ return _depth_pipe
64
+
65
+
66
+ def unload_models():
67
+ global _normal_pipe, _depth_pipe
68
+ if _normal_pipe is not None:
69
+ del _normal_pipe; _normal_pipe = None
70
+ if _depth_pipe is not None:
71
+ del _depth_pipe; _depth_pipe = None
72
+ torch.cuda.empty_cache()
73
+
74
+
75
+ # ── inference ──────────────────────────────────────────────────────────────────
76
+
77
+ def run_stable_normal(image: Image.Image, resolution: int = 768) -> Image.Image:
78
+ """Returns normal map as RGB PIL image ([-1,1] encoded as [0,255])."""
79
+ pipe = load_normal_model()
80
+ img = image.convert("RGB").resize((resolution, resolution), Image.LANCZOS)
81
+ with torch.inference_mode(), torch.autocast("cuda"):
82
+ result = pipe(img)
83
+ normal_img = result.prediction # numpy [H,W,3] in [-1,1]
84
+ normal_rgb = ((normal_img + 1) / 2 * 255).clip(0, 255).astype(np.uint8)
85
+ return Image.fromarray(normal_rgb)
86
+
87
+
88
+ def run_depth_anything(image: Image.Image, resolution: int = 768) -> Image.Image:
89
+ """Returns depth map as 16-bit grayscale PIL image (normalized 0–65535)."""
90
+ processor, model = load_depth_model()
91
+ img = image.convert("RGB").resize((resolution, resolution), Image.LANCZOS)
92
+ inputs = processor(images=img, return_tensors="pt")
93
+ inputs = {k: v.to("cuda", dtype=torch.float16) for k, v in inputs.items()}
94
+ with torch.inference_mode():
95
+ depth = model(**inputs).predicted_depth[0].float().cpu().numpy()
96
+ # Normalize to 0–1
97
+ depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
98
+ depth_16 = (depth * 65535).astype(np.uint16)
99
+ return Image.fromarray(depth_16, mode="I;16")
100
+
101
+
102
+ # ── GLB baking ─────────────────────────────────────────────────────────────────
103
+
104
+ def bake_normal_into_glb(
105
+ glb_path: str,
106
+ normal_img: Image.Image,
107
+ out_path: str,
108
+ normal_strength: float = 1.0,
109
+ ) -> str:
110
+ """
111
+ Adds normalTexture to the first material of the GLB.
112
+ Normal map is resized to match the existing base color texture resolution.
113
+ """
114
+ import pygltflib, struct, io
115
+
116
+ gltf = pygltflib.GLTF2().load(glb_path)
117
+
118
+ # Find existing base color texture size for matching resolution
119
+ target_size = 1024
120
+ if gltf.materials and gltf.materials[0].pbrMetallicRoughness:
121
+ pbr = gltf.materials[0].pbrMetallicRoughness
122
+ if pbr.baseColorTexture is not None:
123
+ tex_idx = pbr.baseColorTexture.index
124
+ img_idx = gltf.textures[tex_idx].source
125
+ blob = gltf.binary_blob()
126
+ bv = gltf.bufferViews[gltf.images[img_idx].bufferView]
127
+ img_bytes = blob[bv.byteOffset: bv.byteOffset + bv.byteLength]
128
+ existing = Image.open(io.BytesIO(img_bytes))
129
+ target_size = existing.width
130
+
131
+ normal_resized = normal_img.resize((target_size, target_size), Image.LANCZOS)
132
+
133
+ # Encode normal map as PNG and append to binary blob
134
+ buf = io.BytesIO()
135
+ normal_resized.save(buf, format="PNG")
136
+ png_bytes = buf.getvalue()
137
+
138
+ blob = bytearray(gltf.binary_blob() or b"")
139
+ byte_offset = len(blob)
140
+ blob.extend(png_bytes)
141
+
142
+ # Pad to 4-byte alignment
143
+ while len(blob) % 4:
144
+ blob.append(0)
145
+
146
+ # Add bufferView, image, texture
147
+ bv_idx = len(gltf.bufferViews)
148
+ gltf.bufferViews.append(pygltflib.BufferView(
149
+ buffer=0, byteOffset=byte_offset, byteLength=len(png_bytes),
150
+ ))
151
+ img_idx = len(gltf.images)
152
+ gltf.images.append(pygltflib.Image(
153
+ bufferView=bv_idx, mimeType="image/png",
154
+ ))
155
+ tex_idx = len(gltf.textures)
156
+ gltf.textures.append(pygltflib.Texture(source=img_idx))
157
+
158
+ # Update material
159
+ if gltf.materials:
160
+ gltf.materials[0].normalTexture = pygltflib.NormalMaterialTexture(
161
+ index=tex_idx, scale=normal_strength,
162
+ )
163
+
164
+ # Update buffer length
165
+ gltf.buffers[0].byteLength = len(blob)
166
+ gltf.set_binary_blob(bytes(blob))
167
+ gltf.save(out_path)
168
+ return out_path
169
+
170
+
171
+ def bake_depth_as_occlusion(
172
+ glb_path: str,
173
+ depth_img: Image.Image,
174
+ out_path: str,
175
+ displacement_scale: float = 1.0,
176
+ ) -> str:
177
+ """
178
+ Bakes depth map as occlusionTexture (R channel) — approximates displacement
179
+ in PBR renderers. Depth is inverted and normalized for AO-style use.
180
+ """
181
+ import pygltflib, io
182
+
183
+ gltf = pygltflib.GLTF2().load(glb_path)
184
+
185
+ target_size = 1024
186
+ if gltf.materials and gltf.materials[0].pbrMetallicRoughness:
187
+ pbr = gltf.materials[0].pbrMetallicRoughness
188
+ if pbr.baseColorTexture is not None:
189
+ tex_idx = pbr.baseColorTexture.index
190
+ img_idx = gltf.textures[tex_idx].source
191
+ blob = gltf.binary_blob()
192
+ bv = gltf.bufferViews[gltf.images[img_idx].bufferView]
193
+ img_bytes = blob[bv.byteOffset: bv.byteOffset + bv.byteLength]
194
+ existing = Image.open(io.BytesIO(img_bytes))
195
+ target_size = existing.width
196
+
197
+ # Convert 16-bit depth to 8-bit RGB occlusion (inverted, scaled)
198
+ depth_arr = np.array(depth_img).astype(np.float32) / 65535.0
199
+ depth_arr = 1.0 - depth_arr # invert: close = bright
200
+ depth_arr = np.clip(depth_arr * displacement_scale, 0, 1)
201
+ occ_8 = (depth_arr * 255).astype(np.uint8)
202
+ occ_rgb = Image.fromarray(np.stack([occ_8, occ_8, occ_8], axis=-1))
203
+ occ_rgb = occ_rgb.resize((target_size, target_size), Image.LANCZOS)
204
+
205
+ buf = io.BytesIO()
206
+ occ_rgb.save(buf, format="PNG")
207
+ png_bytes = buf.getvalue()
208
+
209
+ blob = bytearray(gltf.binary_blob() or b"")
210
+ byte_offset = len(blob)
211
+ blob.extend(png_bytes)
212
+ while len(blob) % 4:
213
+ blob.append(0)
214
+
215
+ bv_idx = len(gltf.bufferViews)
216
+ gltf.bufferViews.append(pygltflib.BufferView(
217
+ buffer=0, byteOffset=byte_offset, byteLength=len(png_bytes),
218
+ ))
219
+ img_idx = len(gltf.images)
220
+ gltf.images.append(pygltflib.Image(
221
+ bufferView=bv_idx, mimeType="image/png",
222
+ ))
223
+ tex_idx = len(gltf.textures)
224
+ gltf.textures.append(pygltflib.Texture(source=img_idx))
225
+
226
+ if gltf.materials:
227
+ gltf.materials[0].occlusionTexture = pygltflib.OcclusionTextureInfo(
228
+ index=tex_idx, strength=displacement_scale,
229
+ )
230
+
231
+ gltf.buffers[0].byteLength = len(blob)
232
+ gltf.set_binary_blob(bytes(blob))
233
+ gltf.save(out_path)
234
+ return out_path
pipeline/face_enhance.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Face enhancement for MV-Adapter multiview textures.
3
+
4
+ Pipeline per visible-face view:
5
+ 1. InsightFace buffalo_l — detect faces, extract 5-pt landmarks & 512-d embeddings
6
+ 2. HyperSwap 1A 256 — swap reference identity (embedding) onto each view face
7
+ (falls back to inswapper_128 if hyperswap not present)
8
+ 3. RealESRGAN x4plus — upscale face bbox 4x, resize back (real detail,
9
+ identity-preserving). Falls back to GFPGAN v1.4 if weights not present.
10
+
11
+ HyperSwap I/O:
12
+ source [1, 512] — face embedding from recognition model
13
+ target [1, 3, 256, 256] — aligned face crop (float32, RGB, [0,1])
14
+ output [1, 3, 256, 256] — swapped face crop
15
+ mask [1, 1, 256, 256] — alpha mask for seamless paste-back
16
+
17
+ Usage (standalone):
18
+ python -m pipeline.face_enhance \
19
+ --multiview /tmp/user_tex4/result.png \
20
+ --reference /tmp/tex_input_768.png \
21
+ --output /tmp/user_tex4/result_enhanced.png \
22
+ --checkpoints /root/MV-Adapter/checkpoints
23
+ """
24
+
25
+ import argparse
26
+ import os
27
+ import cv2
28
+ import numpy as np
29
+ import onnxruntime as ort
30
+ from PIL import Image
31
+
32
+
33
+ # ── helpers ────────────────────────────────────────────────────────────────────
34
+
35
+ def pil_to_bgr(img: Image.Image) -> np.ndarray:
36
+ return cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2BGR)
37
+
38
+
39
+ def bgr_to_pil(arr: np.ndarray) -> Image.Image:
40
+ return Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))
41
+
42
+
43
+ def split_multiview(mv: Image.Image, n: int = 6):
44
+ w_each = mv.width // n
45
+ return [mv.crop((i * w_each, 0, (i + 1) * w_each, mv.height)) for i in range(n)]
46
+
47
+
48
+ def stitch_views(views):
49
+ total_w = sum(v.width for v in views)
50
+ out = Image.new("RGB", (total_w, views[0].height))
51
+ x = 0
52
+ for v in views:
53
+ out.paste(v, (x, 0))
54
+ x += v.width
55
+ return out
56
+
57
+
58
+ # ── HyperSwap 1A 256 — custom ONNX wrapper ────────────────────────────────────
59
+
60
+ class HyperSwapper:
61
+ """
62
+ Direct ONNX inference for HyperSwap 1A 256.
63
+ source [1,512] × target [1,3,256,256] → output [1,3,256,256], mask [1,1,256,256]
64
+ """
65
+
66
+ # Standard 5-point face alignment template (112×112 base, scaled to crop_size)
67
+ _TEMPLATE_112 = np.array([
68
+ [38.2946, 51.6963],
69
+ [73.5318, 51.5014],
70
+ [56.0252, 71.7366],
71
+ [41.5493, 92.3655],
72
+ [70.7299, 92.2041],
73
+ ], dtype=np.float32)
74
+
75
+ def __init__(self, ckpt_path: str, providers=None):
76
+ self.crop_size = 256
77
+ self.providers = providers or ["CUDAExecutionProvider", "CPUExecutionProvider"]
78
+ self.sess = ort.InferenceSession(ckpt_path, providers=self.providers)
79
+ print(f"[HyperSwapper] Loaded {os.path.basename(ckpt_path)} "
80
+ f"(providers: {self.sess.get_providers()})")
81
+
82
+ def _get_affine(self, kps: np.ndarray) -> np.ndarray:
83
+ """Estimate affine transform from 5 face keypoints to standard template."""
84
+ template = self._TEMPLATE_112 / 112.0 * self.crop_size
85
+ from cv2 import estimateAffinePartial2D
86
+ M, _ = estimateAffinePartial2D(kps, template, method=cv2.RANSAC)
87
+ return M # [2, 3]
88
+
89
+ def _crop_face(self, img_bgr: np.ndarray, kps: np.ndarray):
90
+ """Crop and align face to crop_size × crop_size."""
91
+ M = self._get_affine(kps)
92
+ crop = cv2.warpAffine(img_bgr, M, (self.crop_size, self.crop_size),
93
+ flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
94
+ return crop, M
95
+
96
+ def _paste_back(self, img_bgr: np.ndarray, crop_bgr: np.ndarray,
97
+ mask: np.ndarray, M: np.ndarray) -> np.ndarray:
98
+ """Paste swapped face crop back into the original frame using the mask."""
99
+ h, w = img_bgr.shape[:2]
100
+ IM = cv2.invertAffineTransform(M)
101
+
102
+ warped = cv2.warpAffine(crop_bgr, IM, (w, h),
103
+ flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
104
+ mask_img = (mask * 255).clip(0, 255).astype(np.uint8)
105
+ mask_warped = cv2.warpAffine(mask_img, IM, (w, h), flags=cv2.INTER_LINEAR)
106
+ mask_f = mask_warped.astype(np.float32)[:, :, np.newaxis] / 255.0
107
+
108
+ result = img_bgr.astype(np.float32) * (1.0 - mask_f) + warped.astype(np.float32) * mask_f
109
+ return result.clip(0, 255).astype(np.uint8)
110
+
111
+ def get(self, img_bgr: np.ndarray, target_face, source_face,
112
+ paste_back: bool = True):
113
+ """
114
+ Swap source_face identity onto target_face in img_bgr.
115
+ face objects are InsightFace Face instances with .embedding and .kps.
116
+ """
117
+ # 1. Source embedding [1, 512]
118
+ emb = source_face.embedding.astype(np.float32)
119
+ emb /= np.linalg.norm(emb) # L2-normalise
120
+ source_input = emb.reshape(1, -1) # [1, 512]
121
+
122
+ # 2. Crop and align target face to 256×256
123
+ kps = target_face.kps.astype(np.float32)
124
+ crop_bgr, M = self._crop_face(img_bgr, kps)
125
+
126
+ # Convert BGR→RGB, normalize to [-1, 1], HWC→CHW, add batch dim
127
+ crop_rgb = crop_bgr[:, :, ::-1].astype(np.float32) / 255.0
128
+ crop_rgb = (crop_rgb - 0.5) / 0.5 # [−1, 1]
129
+ target_input = crop_rgb.transpose(2, 0, 1)[np.newaxis] # [1, 3, 256, 256]
130
+
131
+ # 3. Inference
132
+ outputs = self.sess.run(None, {"source": source_input, "target": target_input})
133
+ out_tensor = outputs[0][0] # [3, 256, 256] values in [-1, 1]
134
+ mask_tensor = outputs[1][0, 0] # [256, 256]
135
+
136
+ # 4. Convert output back to BGR uint8 ([-1,1] → [0,255])
137
+ out_rgb = ((out_tensor.transpose(1, 2, 0) + 1) / 2 * 255).clip(0, 255).astype(np.uint8)
138
+ out_bgr = out_rgb[:, :, ::-1]
139
+
140
+ if not paste_back:
141
+ return out_bgr, mask_tensor
142
+
143
+ # 5. Paste back into the original frame
144
+ return self._paste_back(img_bgr, out_bgr, mask_tensor, M)
145
+
146
+
147
+ # ── model loading ─────────────────────────────────────────────────────────────
148
+
149
+ _ORT_PROVIDERS = ["CUDAExecutionProvider", "CPUExecutionProvider"]
150
+
151
+
152
+ def load_face_analyzer():
153
+ from insightface.app import FaceAnalysis
154
+ app = FaceAnalysis(name="buffalo_l", providers=_ORT_PROVIDERS)
155
+ app.prepare(ctx_id=0, det_size=(640, 640))
156
+ return app
157
+
158
+
159
+ def load_swapper(ckpt_dir: str):
160
+ """HyperSwap 1A 256 if present, else fall back to inswapper_128."""
161
+ import insightface.model_zoo as model_zoo
162
+
163
+ hyperswap = os.path.join(ckpt_dir, "hyperswap_1a_256.onnx")
164
+ inswapper = os.path.join(ckpt_dir, "inswapper_128.onnx")
165
+
166
+ if os.path.exists(hyperswap):
167
+ print(f"[face_enhance] Using HyperSwap 1A 256")
168
+ return HyperSwapper(hyperswap, providers=_ORT_PROVIDERS)
169
+
170
+ if os.path.exists(inswapper):
171
+ print(f"[face_enhance] Using inswapper_128 (fallback)")
172
+ return model_zoo.get_model(inswapper, providers=_ORT_PROVIDERS)
173
+
174
+ raise FileNotFoundError(
175
+ f"No swapper model found in {ckpt_dir}. "
176
+ "Add hyperswap_1a_256.onnx or inswapper_128.onnx."
177
+ )
178
+
179
+
180
+ def load_realesrgan(model_path: str, scale: int = 4, half: bool = False):
181
+ """Load RealESRGAN x4plus — full float32 (half=False), no tiling (tile=0)."""
182
+ from basicsr.archs.rrdbnet_arch import RRDBNet
183
+ from realesrgan import RealESRGANer
184
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
185
+ num_block=23, num_grow_ch=32, scale=scale)
186
+ return RealESRGANer(
187
+ scale=scale, model_path=model_path, model=model,
188
+ tile=0, tile_pad=10, pre_pad=0, half=half,
189
+ )
190
+
191
+
192
+ def load_gfpgan(ckpt_dir: str, upscale: int = 1):
193
+ from gfpgan import GFPGANer
194
+ model_path = os.path.join(ckpt_dir, "GFPGANv1.4.pth")
195
+ if not os.path.exists(model_path):
196
+ raise FileNotFoundError(f"GFPGANv1.4.pth not found in {ckpt_dir}")
197
+ return GFPGANer(model_path=model_path, upscale=upscale,
198
+ arch="clean", channel_multiplier=2, bg_upsampler=None)
199
+
200
+
201
+ def load_restorer(ckpt_dir: str):
202
+ """
203
+ Prefer RealESRGAN x4plus (full float32, no tiling, unsharp mask post-pass).
204
+ Falls back to GFPGAN v1.4 if RealESRGAN weights are absent.
205
+ Returns (restorer, 'realesrgan' | 'gfpgan').
206
+ """
207
+ realesr_path = os.path.join(ckpt_dir, "RealESRGAN_x4plus.pth")
208
+ if os.path.exists(realesr_path):
209
+ try:
210
+ r = load_realesrgan(realesr_path, scale=4, half=False)
211
+ print("[face_enhance] Restorer: RealESRGAN x4plus (float32, tile=0)")
212
+ return r, "realesrgan"
213
+ except Exception as e:
214
+ print(f"[face_enhance] RealESRGAN load failed ({e}), falling back to GFPGAN")
215
+ r = load_gfpgan(ckpt_dir, upscale=1)
216
+ print("[face_enhance] Restorer: GFPGAN v1.4 (fallback)")
217
+ return r, "gfpgan"
218
+
219
+
220
+ # ── core enhancement ──────────────────────────────────────────────────────────
221
+
222
+ def get_reference_face(analyzer, ref_bgr: np.ndarray):
223
+ faces = analyzer.get(ref_bgr)
224
+ if not faces:
225
+ raise RuntimeError("No face detected in reference image.")
226
+ faces.sort(key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]), reverse=True)
227
+ return faces[0]
228
+
229
+
230
+ def _enhance_face_bbox(frame_bgr: np.ndarray, faces, restorer, restorer_type: str,
231
+ pad: float = 0.4) -> np.ndarray:
232
+ """
233
+ Crop each face bbox (+ padding), enhance with restorer, blend back.
234
+ RealESRGAN: upscale 4x → resize back → unsharp mask → feathered blend.
235
+ GFPGAN: restore in-place on crop → resize back → hard paste.
236
+ """
237
+ result = frame_bgr.copy()
238
+ h, w = frame_bgr.shape[:2]
239
+
240
+ for face in faces:
241
+ x1, y1, x2, y2 = face.bbox[:4].astype(int)
242
+ bw, bh = x2 - x1, y2 - y1
243
+ px, py = int(bw * pad), int(bh * pad)
244
+ cx1 = max(0, x1 - px); cy1 = max(0, y1 - py)
245
+ cx2 = min(w, x2 + px); cy2 = min(h, y2 + py)
246
+ crop = frame_bgr[cy1:cy2, cx1:cx2].copy()
247
+ if crop.size == 0:
248
+ continue
249
+ cw, ch = cx2 - cx1, cy2 - cy1
250
+
251
+ try:
252
+ if restorer_type == "realesrgan":
253
+ enhanced, _ = restorer.enhance(crop, outscale=4)
254
+ enhanced = cv2.resize(enhanced, (cw, ch), interpolation=cv2.INTER_LANCZOS4)
255
+ # Unsharp mask — strength 1.8
256
+ blur = cv2.GaussianBlur(enhanced, (0, 0), 2)
257
+ enhanced = cv2.addWeighted(enhanced, 1.8, blur, -0.8, 0)
258
+ else:
259
+ _, _, enhanced = restorer.enhance(
260
+ crop, has_aligned=False, only_center_face=True,
261
+ paste_back=True, weight=0.5)
262
+ if enhanced.shape[:2] != (ch, cw):
263
+ enhanced = cv2.resize(enhanced, (cw, ch), interpolation=cv2.INTER_LANCZOS4)
264
+ except Exception as e:
265
+ import traceback as _tb
266
+ print(f"[enhance_view] {restorer_type} failed on face bbox: {e}\n{_tb.format_exc()}")
267
+ continue
268
+
269
+ # Feathered blend at edges
270
+ feather = max(3, int(min(cw, ch) * 0.08))
271
+ mask = np.ones((ch, cw), dtype=np.float32)
272
+ for f in range(feather):
273
+ a = (f + 1) / feather
274
+ mask[f, :] = a; mask[-(f+1), :] = a
275
+ mask[:, f] = np.minimum(mask[:, f], a)
276
+ mask[:, -(f+1)] = np.minimum(mask[:, -(f+1)], a)
277
+ mask = mask[:, :, np.newaxis]
278
+ result[cy1:cy2, cx1:cx2] = (
279
+ result[cy1:cy2, cx1:cx2].astype(np.float32) * (1 - mask) +
280
+ enhanced.astype(np.float32) * mask
281
+ ).clip(0, 255).astype(np.uint8)
282
+
283
+ return result
284
+
285
+
286
+ def enhance_view(view_bgr, analyzer, swapper, restorer, restorer_type,
287
+ source_face) -> np.ndarray:
288
+ target_faces = analyzer.get(view_bgr)
289
+ if not target_faces:
290
+ return view_bgr
291
+
292
+ swapped = view_bgr.copy()
293
+ for face in target_faces:
294
+ swapped = swapper.get(swapped, face, source_face, paste_back=True)
295
+ print(f"[enhance_view] HyperSwap applied to {len(target_faces)} face(s)")
296
+
297
+ # Re-detect in swapped image for accurate bboxes
298
+ swapped_faces = analyzer.get(swapped) or target_faces
299
+ result = _enhance_face_bbox(swapped, swapped_faces, restorer, restorer_type)
300
+ print(f"[enhance_view] {restorer_type} enhanced {len(swapped_faces)} face(s)")
301
+ return result
302
+
303
+
304
+ def enhance_multiview(
305
+ multiview_path: str,
306
+ reference_path: str,
307
+ output_path: str,
308
+ ckpt_dir: str,
309
+ n_views: int = 6,
310
+ gfpgan_upscale: int = 1,
311
+ face_views: tuple = (0, 1, 3, 4),
312
+ ):
313
+ print("[face_enhance] Loading models...")
314
+ analyzer = load_face_analyzer()
315
+ swapper = load_swapper(ckpt_dir)
316
+ restorer, restorer_type = load_restorer(ckpt_dir)
317
+ print("[face_enhance] Models loaded.")
318
+
319
+ ref_pil = Image.open(reference_path).convert("RGB")
320
+ ref_bgr = pil_to_bgr(ref_pil)
321
+ source_face = get_reference_face(analyzer, ref_bgr)
322
+ print(f"[face_enhance] Reference face bbox={source_face.bbox.astype(int)}")
323
+
324
+ mv = Image.open(multiview_path).convert("RGB")
325
+ views = split_multiview(mv, n=n_views)
326
+ enhanced = []
327
+
328
+ for i, view_pil in enumerate(views):
329
+ if i in face_views:
330
+ view_bgr = pil_to_bgr(view_pil)
331
+ result_bgr = enhance_view(view_bgr, analyzer, swapper, restorer,
332
+ restorer_type, source_face)
333
+ enhanced.append(bgr_to_pil(result_bgr))
334
+ n_faces = len(analyzer.get(view_bgr))
335
+ print(f"[face_enhance] View {i}: {n_faces} face(s) processed.")
336
+ else:
337
+ enhanced.append(view_pil)
338
+
339
+ stitch_views(enhanced).save(output_path)
340
+ print(f"[face_enhance] Saved → {output_path}")
341
+ return output_path
342
+
343
+
344
+ # ── CLI ───────────────────────────────────────────────────────────────────────
345
+
346
+ if __name__ == "__main__":
347
+ parser = argparse.ArgumentParser()
348
+ parser.add_argument("--multiview", required=True)
349
+ parser.add_argument("--reference", required=True)
350
+ parser.add_argument("--output", required=True)
351
+ parser.add_argument("--checkpoints", default="./checkpoints")
352
+ parser.add_argument("--n_views", type=int, default=6)
353
+ args = parser.parse_args()
354
+
355
+ enhance_multiview(
356
+ multiview_path=args.multiview,
357
+ reference_path=args.reference,
358
+ output_path=args.output,
359
+ ckpt_dir=args.checkpoints,
360
+ n_views=args.n_views,
361
+ )
pipeline/rig_stage.py ADDED
@@ -0,0 +1,1282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stage 7 — Multi-view pose estimation + mesh rigging
3
+
4
+ Three progressive phases, each feeding the next:
5
+
6
+ Phase 1 (Easy) — Multi-view beta averaging
7
+ Run HMR 2.0 on front / 3q_front / side renders + reference photo
8
+ Average shape betas weighted by detection confidence
9
+
10
+ Phase 2 (Better) — Silhouette fitting
11
+ Project SMPL mesh orthographically into each of the 5 views
12
+ Optimise betas so the SMPL silhouette matches the TripoSG render mask
13
+ Uses known orthographic camera matrices (exact same params as nvdiffrast)
14
+
15
+ Phase 3 (Best) — Multi-view joint triangulation
16
+ For each view where HMR 2.0 fired, project its 2D keypoints back to 3D
17
+ using the known orthographic camera → set up linear system per joint
18
+ Least-squares triangulation gives world-space joint positions used
19
+ directly as the skeleton, overriding the regressed SMPL joints
20
+
21
+ Output: rigged GLB (SMPL 24-joint skeleton + skin weights) + FBX via Blender
22
+ """
23
+
24
+ import os, sys, json, struct, traceback, subprocess, tempfile
25
+ # Must be set before any OpenGL/pyrender import (triggered by hmr2)
26
+ os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
27
+ import numpy as np
28
+
29
+ # ── SMPL constants ────────────────────────────────────────────────────────────
30
+ SMPL_JOINT_NAMES = [
31
+ "pelvis","left_hip","right_hip","spine1",
32
+ "left_knee","right_knee","spine2",
33
+ "left_ankle","right_ankle","spine3",
34
+ "left_foot","right_foot","neck",
35
+ "left_collar","right_collar","head",
36
+ "left_shoulder","right_shoulder",
37
+ "left_elbow","right_elbow",
38
+ "left_wrist","right_wrist",
39
+ "left_hand","right_hand",
40
+ ]
41
+ SMPL_PARENTS = [-1,0,0,0,1,2,3,4,5,6,7,8,9,9,9,
42
+ 12,13,14,16,17,18,19,20,21]
43
+
44
+ # Orthographic camera parameters — must match render_views in triposg_app.py
45
+ ORTHO_LEFT, ORTHO_RIGHT = -0.55, 0.55
46
+ ORTHO_BOT, ORTHO_TOP = -0.55, 0.55
47
+ RENDER_W, RENDER_H = 768, 1024
48
+
49
+ # Azimuths passed to get_orthogonal_camera: [x-90 for x in [0,45,90,180,315]]
50
+ VIEW_AZIMUTHS_DEG = [-90.0, -45.0, 0.0, 90.0, 225.0]
51
+ VIEW_NAMES = ["front", "3q_front", "side", "back", "3q_back"]
52
+ VIEW_PATHS = [f"/tmp/render_{n}.png" for n in VIEW_NAMES]
53
+
54
+ # Views with a clearly visible front body (used for Phase 1 beta averaging)
55
+ FRONT_VIEW_INDICES = [0, 1, 2] # front, 3q_front, side
56
+
57
+
58
+ # ══════════════════════════════════════════════════════════════════════════════
59
+ # Camera utilities
60
+ # ══════════════════════════════════════════════════════════════════════════════
61
+
62
+ def _R_y(deg: float) -> np.ndarray:
63
+ """Rotation matrix around Y axis (right-hand, degrees)."""
64
+ t = np.radians(deg)
65
+ c, s = np.cos(t), np.sin(t)
66
+ return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]], dtype=np.float64)
67
+
68
+
69
+ def world_to_cam(pts: np.ndarray, azimuth_deg: float) -> np.ndarray:
70
+ """
71
+ Orthographic projection: world (N,3) → camera (N,2) in world-unit space.
72
+ Convention: camera right = (cos θ, 0, -sin θ), up = (0,1,0)
73
+ """
74
+ t = np.radians(azimuth_deg)
75
+ right = np.array([np.cos(t), 0.0, -np.sin(t)])
76
+ up = np.array([0.0, 1.0, 0.0 ])
77
+ return np.stack([pts @ right, pts @ up], axis=-1) # (N, 2)
78
+
79
+
80
+ def cam_to_pixel(cam_xy: np.ndarray) -> np.ndarray:
81
+ """Camera world-unit coords → pixel coords (u, v) in 768×1024 image."""
82
+ u = (cam_xy[:, 0] - ORTHO_LEFT) / (ORTHO_RIGHT - ORTHO_LEFT) * RENDER_W
83
+ v = (ORTHO_TOP - cam_xy[:, 1]) / (ORTHO_TOP - ORTHO_BOT ) * RENDER_H
84
+ return np.stack([u, v], axis=-1)
85
+
86
+
87
+ def pixel_to_cam(uv: np.ndarray) -> np.ndarray:
88
+ """Pixel coords → camera world-unit coords."""
89
+ cx = uv[:, 0] / RENDER_W * (ORTHO_RIGHT - ORTHO_LEFT) + ORTHO_LEFT
90
+ cy = ORTHO_TOP - uv[:, 1] / RENDER_H * (ORTHO_TOP - ORTHO_BOT)
91
+ return np.stack([cx, cy], axis=-1)
92
+
93
+
94
+ def triangulate_joint(obs: list[tuple]) -> np.ndarray:
95
+ """
96
+ Triangulate a single joint from multi-view 2D observations.
97
+ obs: list of (azimuth_deg, pixel_u, pixel_v)
98
+ Returns world (x, y, z).
99
+
100
+ For orthographic cameras, Y is directly measured; X and Z satisfy:
101
+ px*cos(θ) - pz*sin(θ) = cx for each view
102
+ → overdetermined linear system solved with lstsq.
103
+ """
104
+ ys, rows_A, rhs = [], [], []
105
+ for az_deg, pu, pv in obs:
106
+ cx, cy = pixel_to_cam(np.array([[pu, pv]]))[0]
107
+ ys.append(cy)
108
+ t = np.radians(az_deg)
109
+ rows_A.append([np.cos(t), -np.sin(t)])
110
+ rhs.append(cx)
111
+
112
+ A = np.array(rows_A, dtype=np.float64)
113
+ b = np.array(rhs, dtype=np.float64)
114
+ wy = float(np.mean(ys))
115
+
116
+ if len(obs) >= 2:
117
+ xz, _, _, _ = np.linalg.lstsq(A, b, rcond=None)
118
+ wx, wz = xz
119
+ else:
120
+ wx, wz = 0.0, 0.0
121
+
122
+ return np.array([wx, wy, wz], dtype=np.float32)
123
+
124
+
125
+ # ══════════════════════════════════════════════════════════════════════════════
126
+ # Phase 1 — Multi-view HMR 2.0 + beta averaging
127
+ # ══════════════════════════════════════════════════════════════════════════════
128
+
129
+ def _load_hmr2(device):
130
+ from hmr2.models import download_models, load_hmr2, DEFAULT_CHECKPOINT
131
+ download_models() # downloads to CACHE_DIR_4DHUMANS (no-op if already done)
132
+ model, cfg = load_hmr2(DEFAULT_CHECKPOINT)
133
+ return model.to(device).eval(), cfg
134
+
135
+
136
+ def _load_detector():
137
+ from detectron2.config import LazyConfig
138
+ from hmr2.utils.utils_detectron2 import DefaultPredictor_Lazy
139
+ import hmr2
140
+ cfg = LazyConfig.load(str(os.path.join(
141
+ os.path.dirname(hmr2.__file__),
142
+ "configs/cascade_mask_rcnn_vitdet_h_75ep.py")))
143
+ cfg.train.init_checkpoint = (
144
+ "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/"
145
+ "cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl")
146
+ for i in range(3):
147
+ cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25
148
+ return DefaultPredictor_Lazy(cfg)
149
+
150
+
151
+ def _run_hmr2_on_image(img_bgr, model, model_cfg, detector, device):
152
+ """
153
+ Run HMR 2.0 on a BGR image. Returns dict or None.
154
+ Keys: betas (10,), body_pose (23,3,3), global_orient (1,3,3),
155
+ kp2d (44,2) in [0,1] normalised, kp3d (44,3), score (float)
156
+ """
157
+ import torch
158
+ from hmr2.utils import recursive_to
159
+ from hmr2.datasets.vitdet_dataset import ViTDetDataset
160
+
161
+ det_out = detector(img_bgr)
162
+ instances = det_out["instances"]
163
+ valid = (instances.pred_classes == 0) & (instances.scores > 0.5)
164
+ if not valid.any():
165
+ return None
166
+
167
+ boxes = instances.pred_boxes.tensor[valid].cpu().numpy()
168
+ score = float(instances.scores[valid].max().cpu())
169
+ best = boxes[np.argmax((boxes[:,2]-boxes[:,0]) * (boxes[:,3]-boxes[:,1]))]
170
+
171
+ ds = ViTDetDataset(model_cfg, img_bgr, [best])
172
+ dl = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False)
173
+ batch = recursive_to(next(iter(dl)), device)
174
+
175
+ with torch.no_grad():
176
+ out = model(batch)
177
+
178
+ p = out["pred_smpl_params"]
179
+ return {
180
+ "betas": p["betas"][0].cpu().numpy(),
181
+ "body_pose": p["body_pose"][0].cpu().numpy(),
182
+ "global_orient": p["global_orient"][0].cpu().numpy(),
183
+ "kp2d": out["pred_keypoints_2d"][0].cpu().numpy(), # (44,2) [-1,1]
184
+ "kp3d": out.get("pred_keypoints_3d", [None]*1)[0],
185
+ "score": score,
186
+ "detected": True,
187
+ }
188
+
189
+
190
+ def estimate_betas_multiview(view_paths: list[str],
191
+ ref_path: str,
192
+ device: str = "cuda") -> tuple[np.ndarray, list]:
193
+ """
194
+ Phase 1: run HMR 2.0 on reference photo + front/3q/side renders.
195
+ Returns (averaged_betas [10,], list_of_all_results).
196
+ Falls back to zero betas (average body shape) if HMR2 is unavailable.
197
+ """
198
+ import cv2
199
+ print("[rig P1] Loading HMR2 + detector...")
200
+ try:
201
+ model, model_cfg = _load_hmr2(device)
202
+ detector = _load_detector()
203
+ except Exception as e:
204
+ print(f"[rig P1] HMR2 unavailable ({e}) — using zero betas (average body shape)")
205
+ return np.zeros(10, dtype=np.float32), []
206
+
207
+ sources = [(ref_path, None)] # (path, azimuth_deg_or_None)
208
+ for idx in FRONT_VIEW_INDICES:
209
+ if idx < len(view_paths) and os.path.exists(view_paths[idx]):
210
+ sources.append((view_paths[idx], VIEW_AZIMUTHS_DEG[idx]))
211
+
212
+ results = []
213
+ weighted_betas, total_w = np.zeros(10, dtype=np.float64), 0.0
214
+
215
+ for path, az in sources:
216
+ img = cv2.imread(path)
217
+ if img is None:
218
+ continue
219
+ r = _run_hmr2_on_image(img, model, model_cfg, detector, device)
220
+ if r is None:
221
+ print(f"[rig P1] {os.path.basename(path)}: no person detected")
222
+ continue
223
+ r["azimuth_deg"] = az
224
+ r["path"] = path
225
+ results.append(r)
226
+ w = r["score"]
227
+ weighted_betas += r["betas"] * w
228
+ total_w += w
229
+ print(f"[rig P1] {os.path.basename(path)}: detected (score={w:.2f}), "
230
+ f"betas[:3]={r['betas'][:3]}")
231
+
232
+ avg_betas = (weighted_betas / total_w).astype(np.float32) if total_w > 0 \
233
+ else np.zeros(10, dtype=np.float32)
234
+ print(f"[rig P1] Averaged betas over {len(results)} detections.")
235
+ return avg_betas, results
236
+
237
+
238
+ # ═════════════════════════════════════════════════��════════════════════════════
239
+ # SMPL helpers
240
+ # ══════════════════════════════════════════════════════════════════════════════
241
+
242
+ def get_smpl_tpose(betas: np.ndarray, smpl_dir: str = "/root/smpl_models"):
243
+ """Returns (verts [N,3], faces [M,3], joints [24,3], lbs_weights [N,24]).
244
+ Uses smplx if SMPL_NEUTRAL.pkl is available, else falls back to a synthetic
245
+ proxy skeleton with proximity-based skinning weights."""
246
+ import torch
247
+
248
+ model_path = os.path.join(smpl_dir, "SMPL_NEUTRAL.pkl")
249
+ if not os.path.exists(model_path) or os.path.getsize(model_path) < 1000:
250
+ # Try download first, silently fall through to synthetic on failure
251
+ try:
252
+ _download_smpl_neutral(smpl_dir)
253
+ except Exception:
254
+ pass
255
+
256
+ if os.path.exists(model_path) and os.path.getsize(model_path) > 100_000:
257
+ import smplx
258
+ smpl = smplx.create(smpl_dir, model_type="smpl", gender="neutral", num_betas=10)
259
+ betas_t = torch.tensor(betas[:10], dtype=torch.float32).unsqueeze(0)
260
+ with torch.no_grad():
261
+ out = smpl(betas=betas_t, return_verts=True)
262
+ verts = out.vertices[0].numpy().astype(np.float32)
263
+ joints = out.joints[0, :24].numpy().astype(np.float32)
264
+ faces = smpl.faces.astype(np.int32)
265
+ weights = smpl.lbs_weights.numpy().astype(np.float32)
266
+ return verts, faces, joints, weights
267
+
268
+ print("[rig] SMPL_NEUTRAL.pkl unavailable — using synthetic proxy skeleton")
269
+ return _synthetic_smpl_tpose()
270
+
271
+
272
+ def _synthetic_smpl_tpose():
273
+ """Synthetic SMPL substitute: hardcoded T-pose joint positions + proximity weights.
274
+ Gives a rough but functional rig for pipeline testing when SMPL is unavailable.
275
+ For production, provide SMPL_NEUTRAL.pkl from https://smpl.is.tue.mpg.de/."""
276
+ # 24 SMPL T-pose joint positions (metres, Y-up, facing +Z)
277
+ joints = np.array([
278
+ [ 0.00, 0.92, 0.00], # 0 pelvis
279
+ [-0.09, 0.86, 0.00], # 1 left_hip
280
+ [ 0.09, 0.86, 0.00], # 2 right_hip
281
+ [ 0.00, 1.05, 0.00], # 3 spine1
282
+ [-0.09, 0.52, 0.00], # 4 left_knee
283
+ [ 0.09, 0.52, 0.00], # 5 right_knee
284
+ [ 0.00, 1.17, 0.00], # 6 spine2
285
+ [-0.09, 0.10, 0.00], # 7 left_ankle
286
+ [ 0.09, 0.10, 0.00], # 8 right_ankle
287
+ [ 0.00, 1.29, 0.00], # 9 spine3
288
+ [-0.09, 0.00, 0.07], # 10 left_foot
289
+ [ 0.09, 0.00, 0.07], # 11 right_foot
290
+ [ 0.00, 1.46, 0.00], # 12 neck
291
+ [-0.07, 1.42, 0.00], # 13 left_collar
292
+ [ 0.07, 1.42, 0.00], # 14 right_collar
293
+ [ 0.00, 1.62, 0.00], # 15 head
294
+ [-0.17, 1.40, 0.00], # 16 left_shoulder
295
+ [ 0.17, 1.40, 0.00], # 17 right_shoulder
296
+ [-0.42, 1.40, 0.00], # 18 left_elbow
297
+ [ 0.42, 1.40, 0.00], # 19 right_elbow
298
+ [-0.65, 1.40, 0.00], # 20 left_wrist
299
+ [ 0.65, 1.40, 0.00], # 21 right_wrist
300
+ [-0.72, 1.40, 0.00], # 22 left_hand
301
+ [ 0.72, 1.40, 0.00], # 23 right_hand
302
+ ], dtype=np.float32)
303
+
304
+ # Build synthetic proxy vertices: ~300 points clustered around each joint
305
+ rng = np.random.default_rng(42)
306
+ n_per_joint = 300
307
+ proxy_v = []
308
+ proxy_w = []
309
+ for ji, jpos in enumerate(joints):
310
+ pts = jpos + rng.normal(0, 0.06, (n_per_joint, 3)).astype(np.float32)
311
+ proxy_v.append(pts)
312
+ w = np.zeros((n_per_joint, 24), np.float32)
313
+ w[:, ji] = 1.0
314
+ proxy_w.append(w)
315
+
316
+ proxy_v = np.concatenate(proxy_v, axis=0) # (7200, 3)
317
+ proxy_w = np.concatenate(proxy_w, axis=0) # (7200, 24)
318
+ proxy_f = np.zeros((0, 3), dtype=np.int32) # no faces needed for KNN transfer
319
+ return proxy_v, proxy_f, joints, proxy_w
320
+
321
+
322
+ def _download_smpl_neutral(out_dir: str):
323
+ os.makedirs(out_dir, exist_ok=True)
324
+ url = ("https://huggingface.co/spaces/TMElyralab/MusePose/resolve/main"
325
+ "/models/smpl/SMPL_NEUTRAL.pkl")
326
+ dest = os.path.join(out_dir, "SMPL_NEUTRAL.pkl")
327
+ print("[rig] Downloading SMPL_NEUTRAL.pkl...")
328
+ subprocess.run(["wget", "-q", url, "-O", dest], check=True)
329
+
330
+
331
+ def _smpl_to_render_space(verts: np.ndarray, joints: np.ndarray):
332
+ """
333
+ Normalise SMPL vertices to fit inside the [-0.55, 0.55] orthographic
334
+ frustum used by the nvdiffrast renders (same as align_mesh_to_smpl).
335
+ Returns (verts_norm, joints_norm, scale, offset).
336
+ """
337
+ ymin, ymax = verts[:, 1].min(), verts[:, 1].max()
338
+ height = ymax - ymin
339
+ scale = (ORTHO_TOP - ORTHO_BOT) / max(height, 1e-6)
340
+
341
+ # Centre on pelvis (joint 0) horizontally, floor-align vertically
342
+ v = verts * scale
343
+ j = joints * scale
344
+ cx = (v[:, 0].max() + v[:, 0].min()) * 0.5
345
+ cz = (v[:, 2].max() + v[:, 2].min()) * 0.5
346
+ v[:, 0] -= cx; j[:, 0] -= cx
347
+ v[:, 2] -= cz; j[:, 2] -= cz
348
+ v[:, 1] -= v[:, 1].min() + ORTHO_BOT # floor at ORTHO_BOT
349
+ j[:, 1] -= (verts[:, 1].min() * scale) - ORTHO_BOT
350
+ return v, j, scale, np.array([-cx, -v[:,1].min() + ORTHO_BOT, -cz])
351
+
352
+
353
+ # ══════════════════════════════════════════════════════════════════════════════
354
+ # Phase 2 — Silhouette fitting
355
+ # ══════════════════════════════════════════════════════════════════════════════
356
+
357
+ def _extract_silhouette(render_path: str, threshold: int = 20) -> np.ndarray:
358
+ """Binary mask (H×W bool) from a render: foreground = any channel > threshold."""
359
+ import cv2
360
+ img = cv2.imread(render_path)
361
+ if img is None:
362
+ return np.zeros((RENDER_H, RENDER_W), dtype=bool)
363
+ return img.max(axis=2) > threshold
364
+
365
+
366
+ def _render_smpl_silhouette(verts_norm: np.ndarray, faces: np.ndarray,
367
+ azimuth_deg: float) -> np.ndarray:
368
+ """
369
+ Rasterise SMPL mesh silhouette for given azimuth (orthographic).
370
+ Returns binary mask (H×W bool).
371
+ """
372
+ from PIL import Image, ImageDraw
373
+
374
+ cam_xy = world_to_cam(verts_norm, azimuth_deg)
375
+ pix = cam_to_pixel(cam_xy) # (N, 2)
376
+
377
+ img = Image.new("L", (RENDER_W, RENDER_H), 0)
378
+ draw = ImageDraw.Draw(img)
379
+ for f in faces:
380
+ pts = [(float(pix[i, 0]), float(pix[i, 1])) for i in f]
381
+ draw.polygon(pts, fill=255)
382
+ return np.array(img) > 0
383
+
384
+
385
+ def _sil_loss(betas: np.ndarray, target_masks: list,
386
+ valid_views: list[int], faces: np.ndarray) -> float:
387
+ """1 - mean IoU between SMPL silhouettes and TripoSG render masks."""
388
+ try:
389
+ verts, _, _, _ = get_smpl_tpose(betas.astype(np.float32))
390
+ verts_n, _, _, _ = _smpl_to_render_space(verts, verts.copy())
391
+ iou_sum = 0.0
392
+ for i in valid_views:
393
+ pred = _render_smpl_silhouette(verts_n, faces, VIEW_AZIMUTHS_DEG[i])
394
+ tgt = target_masks[i]
395
+ inter = (pred & tgt).sum()
396
+ union = (pred | tgt).sum()
397
+ iou_sum += inter / max(union, 1)
398
+ return 1.0 - iou_sum / len(valid_views)
399
+ except Exception:
400
+ return 1.0
401
+
402
+
403
+ def fit_betas_silhouette(betas_init: np.ndarray, view_paths: list[str],
404
+ max_iter: int = 60) -> np.ndarray:
405
+ """
406
+ Phase 2: optimise SMPL betas to match TripoSG render silhouettes.
407
+ Only uses views whose render file exists.
408
+ """
409
+ from scipy.optimize import minimize
410
+
411
+ valid = [i for i, p in enumerate(view_paths) if os.path.exists(p)]
412
+ if not valid:
413
+ print("[rig P2] No render files found — skipping silhouette fit")
414
+ return betas_init
415
+
416
+ print(f"[rig P2] Extracting silhouettes from {len(valid)} views...")
417
+ masks = [_extract_silhouette(view_paths[i]) if i in valid
418
+ else np.zeros((RENDER_H, RENDER_W), bool)
419
+ for i in range(len(VIEW_NAMES))]
420
+
421
+ # Use only back-facing views for shape, not back (which shows less shape info)
422
+ fit_views = [i for i in valid if i in [0, 1, 2]]
423
+ if not fit_views:
424
+ fit_views = valid
425
+
426
+ # Pre-fetch faces (constant across iterations)
427
+ verts0, faces0, _, _ = get_smpl_tpose(betas_init)
428
+
429
+ loss0 = _sil_loss(betas_init, masks, fit_views, faces0)
430
+ print(f"[rig P2] Initial silhouette loss: {loss0:.4f}")
431
+
432
+ result = minimize(
433
+ fun=lambda b: _sil_loss(b, masks, fit_views, faces0),
434
+ x0=betas_init.astype(np.float64),
435
+ method="L-BFGS-B",
436
+ bounds=[(-3.0, 3.0)] * 10,
437
+ options={"maxiter": max_iter, "ftol": 1e-4, "gtol": 1e-3},
438
+ )
439
+
440
+ refined = result.x.astype(np.float32)
441
+ loss1 = _sil_loss(refined, masks, fit_views, faces0)
442
+ print(f"[rig P2] Silhouette fit done: loss {loss0:.4f} → {loss1:.4f} "
443
+ f"({result.nit} iters, {'converged' if result.success else 'stopped'})")
444
+ return refined
445
+
446
+
447
+ # ══════════════════════════════════════════════════════════════════════════════
448
+ # Phase 3 — Multi-view joint triangulation
449
+ # ══════════════════════════════════════════════════════════════════════════════
450
+
451
+ # HMR 2.0 outputs 44 keypoints; first 24 map to SMPL joints
452
+ HMR2_TO_SMPL = list(range(24))
453
+
454
+ def triangulate_joints_multiview(hmr2_results: list) -> np.ndarray | None:
455
+ """
456
+ Phase 3: triangulate world-space SMPL joints from multi-view HMR 2.0 2D keypoints.
457
+
458
+ hmr2_results: list of dicts from _run_hmr2_on_image, each with
459
+ kp2d (44,2) in [-1,1] normalised NDC and azimuth_deg (float or None).
460
+
461
+ Only uses results from rendered views (azimuth_deg is not None).
462
+ Returns (24,3) world joint positions, or None if < 2 valid views.
463
+ """
464
+ view_results = [r for r in hmr2_results
465
+ if r.get("azimuth_deg") is not None and r.get("kp2d") is not None]
466
+
467
+ if len(view_results) < 2:
468
+ print(f"[rig P3] Only {len(view_results)} render views with detections "
469
+ "— need ≥2 for triangulation, skipping")
470
+ return None
471
+
472
+ print(f"[rig P3] Triangulating from {len(view_results)} views: "
473
+ + ", ".join(os.path.basename(r["path"]) for r in view_results))
474
+
475
+ # Convert HMR2 NDC keypoints → pixel coords
476
+ # kp2d is (44,2) in [-1,1]; pixel = (kp+1)/2 * [W, H]
477
+ joints_world = np.zeros((24, 3), dtype=np.float32)
478
+
479
+ for j in range(24):
480
+ obs = []
481
+ for r in view_results:
482
+ kp = r["kp2d"][j] # (2,) in [-1,1]
483
+ pu = (kp[0] + 1.0) / 2.0 * RENDER_W
484
+ pv = (kp[1] + 1.0) / 2.0 * RENDER_H
485
+ obs.append((r["azimuth_deg"], pu, pv))
486
+ joints_world[j] = triangulate_joint(obs)
487
+
488
+ print(f"[rig P3] Triangulated 24 joints. "
489
+ f"Pelvis: {joints_world[0].round(3)}, "
490
+ f"Head: {joints_world[15].round(3)}")
491
+ return joints_world
492
+
493
+
494
+ # ══════════════════════════════════════════════════════════════════════════════
495
+ # Skinning weight transfer
496
+ # ══════════════════════════════════════════════════════════════════════════════
497
+
498
+ def transfer_skinning(smpl_verts: np.ndarray, smpl_weights: np.ndarray,
499
+ target_verts: np.ndarray, k: int = 4) -> np.ndarray:
500
+ from scipy.spatial import cKDTree
501
+ tree = cKDTree(smpl_verts)
502
+ dists, idxs = tree.query(target_verts, k=k, workers=-1)
503
+ dists = np.maximum(dists, 1e-8)
504
+ inv_d = 1.0 / dists
505
+ inv_d /= inv_d.sum(axis=1, keepdims=True)
506
+ transferred = np.einsum("nk,nkj->nj", inv_d, smpl_weights[idxs])
507
+ row_sums = transferred.sum(axis=1, keepdims=True)
508
+ transferred /= np.where(row_sums > 0, row_sums, 1.0)
509
+ return transferred.astype(np.float32)
510
+
511
+
512
+ def align_mesh_to_smpl(mesh_verts: np.ndarray, smpl_verts: np.ndarray,
513
+ smpl_joints: np.ndarray) -> np.ndarray:
514
+ smpl_h = smpl_verts[:, 1].max() - smpl_verts[:, 1].min()
515
+ mesh_h = mesh_verts[:, 1].max() - mesh_verts[:, 1].min()
516
+ scale = smpl_h / max(mesh_h, 1e-6)
517
+ v = mesh_verts * scale
518
+ cx = (v[:, 0].max() + v[:, 0].min()) * 0.5
519
+ cz = (v[:, 2].max() + v[:, 2].min()) * 0.5
520
+ v[:, 0] += smpl_joints[0, 0] - cx
521
+ v[:, 2] += smpl_joints[0, 2] - cz
522
+ v[:, 1] -= v[:, 1].min()
523
+ return v
524
+
525
+
526
+ # ══════════════════════════════════════════════════════════════════════════════
527
+ # GLB export
528
+ # ══════════════════════════════════════════════════════════════════════════════
529
+
530
+ def export_rigged_glb(verts, faces, uv, texture_img, joints, skin_weights, out_path):
531
+ import pygltflib
532
+ from pygltflib import (GLTF2, Scene, Node, Mesh, Primitive, Accessor,
533
+ BufferView, Buffer, Material, Texture,
534
+ Image as GImage, Sampler, Skin, Asset)
535
+ from pygltflib import (ARRAY_BUFFER, ELEMENT_ARRAY_BUFFER, FLOAT,
536
+ UNSIGNED_INT, UNSIGNED_SHORT, LINEAR,
537
+ LINEAR_MIPMAP_LINEAR, REPEAT, SCALAR, VEC2,
538
+ VEC3, VEC4, MAT4)
539
+
540
+ gltf = GLTF2()
541
+ gltf.asset = Asset(version="2.0", generator="rig_stage.py")
542
+ blobs = []
543
+
544
+ def _add(data: np.ndarray, comp, acc_type, target=None):
545
+ b = data.tobytes()
546
+ pad = (4 - len(b) % 4) % 4
547
+ off = sum(len(x) for x in blobs)
548
+ blobs.append(b + b"\x00" * pad)
549
+ bv = len(gltf.bufferViews)
550
+ gltf.bufferViews.append(BufferView(buffer=0, byteOffset=off,
551
+ byteLength=len(b), target=target))
552
+ ac = len(gltf.accessors)
553
+ flat = data.flatten()
554
+ gltf.accessors.append(Accessor(
555
+ bufferView=bv, byteOffset=0, componentType=comp,
556
+ type=acc_type, count=len(data),
557
+ min=[float(flat.min())], max=[float(flat.max())]))
558
+ return ac
559
+
560
+ pos_acc = _add(verts.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
561
+
562
+ v0,v1,v2 = verts[faces[:,0]], verts[faces[:,1]], verts[faces[:,2]]
563
+ fn = np.cross(v1-v0, v2-v0); fn /= (np.linalg.norm(fn,axis=1,keepdims=True)+1e-8)
564
+ vn = np.zeros_like(verts)
565
+ for i in range(3): np.add.at(vn, faces[:,i], fn)
566
+ vn /= (np.linalg.norm(vn,axis=1,keepdims=True)+1e-8)
567
+ nor_acc = _add(vn.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
568
+
569
+ if uv is None: uv = np.zeros((len(verts),2), np.float32)
570
+ uv_acc = _add(uv.astype(np.float32), FLOAT, VEC2, ARRAY_BUFFER)
571
+ idx_acc = _add(faces.astype(np.uint32).flatten(), UNSIGNED_INT, SCALAR, ELEMENT_ARRAY_BUFFER)
572
+
573
+ top4_idx = np.argsort(-skin_weights, axis=1)[:,:4].astype(np.uint16)
574
+ top4_w = np.take_along_axis(skin_weights, top4_idx.astype(np.int64), axis=1).astype(np.float32)
575
+ top4_w /= top4_w.sum(axis=1,keepdims=True).clip(1e-8,None)
576
+ j_acc = _add(top4_idx, UNSIGNED_SHORT, "VEC4", ARRAY_BUFFER)
577
+ w_acc = _add(top4_w, FLOAT, "VEC4", ARRAY_BUFFER)
578
+
579
+ if texture_img is not None:
580
+ import io
581
+ buf = io.BytesIO(); texture_img.save(buf, format="PNG"); ib = buf.getvalue()
582
+ off = sum(len(x) for x in blobs); pad = (4-len(ib)%4)%4
583
+ blobs.append(ib + b"\x00"*pad)
584
+ gltf.bufferViews.append(BufferView(buffer=0,byteOffset=off,byteLength=len(ib)))
585
+ gltf.images.append(GImage(mimeType="image/png",bufferView=len(gltf.bufferViews)-1))
586
+ gltf.samplers.append(Sampler(magFilter=LINEAR,minFilter=LINEAR_MIPMAP_LINEAR,
587
+ wrapS=REPEAT,wrapT=REPEAT))
588
+ gltf.textures.append(Texture(sampler=0,source=0))
589
+ gltf.materials.append(Material(name="body",
590
+ pbrMetallicRoughness={"baseColorTexture":{"index":0},
591
+ "metallicFactor":0.0,"roughnessFactor":0.8},
592
+ doubleSided=True))
593
+ else:
594
+ gltf.materials.append(Material(name="body",doubleSided=True))
595
+
596
+ prim = Primitive(attributes={"POSITION":pos_acc,"NORMAL":nor_acc,
597
+ "TEXCOORD_0":uv_acc,"JOINTS_0":j_acc,"WEIGHTS_0":w_acc},
598
+ indices=idx_acc, material=0)
599
+ gltf.meshes.append(Mesh(name="body",primitives=[prim]))
600
+
601
+ jnodes = []
602
+ for i,(name,parent) in enumerate(zip(SMPL_JOINT_NAMES,SMPL_PARENTS)):
603
+ t = joints[i].tolist() if parent==-1 else (joints[i]-joints[parent]).tolist()
604
+ n = Node(name=name,translation=t,children=[])
605
+ jnodes.append(len(gltf.nodes)); gltf.nodes.append(n)
606
+ for i,p in enumerate(SMPL_PARENTS):
607
+ if p!=-1: gltf.nodes[jnodes[p]].children.append(jnodes[i])
608
+
609
+ ibms = np.stack([np.eye(4,dtype=np.float32) for _ in range(len(joints))])
610
+ for i in range(len(joints)): ibms[i,:3,3] = -joints[i]
611
+ ibm_acc = _add(ibms.astype(np.float32), FLOAT, MAT4)
612
+ skin_idx = len(gltf.skins)
613
+ gltf.skins.append(Skin(name="smpl_skin",skeleton=jnodes[0],
614
+ joints=jnodes,inverseBindMatrices=ibm_acc))
615
+
616
+ mesh_node = len(gltf.nodes)
617
+ gltf.nodes.append(Node(name="body_mesh",mesh=0,skin=skin_idx))
618
+ root_node = len(gltf.nodes)
619
+ gltf.nodes.append(Node(name="root",children=[jnodes[0],mesh_node]))
620
+ gltf.scenes.append(Scene(name="Scene",nodes=[root_node]))
621
+ gltf.scene = 0
622
+
623
+ bin_data = b"".join(blobs)
624
+ gltf.buffers.append(Buffer(byteLength=len(bin_data)))
625
+ gltf.set_binary_blob(bin_data)
626
+ gltf.save_binary(out_path)
627
+ print(f"[rig] Rigged GLB → {out_path} ({os.path.getsize(out_path)//1024} KB)")
628
+
629
+
630
+ # ══════════════════════════════════════════════════════════════════════════════
631
+ # FBX export via Blender headless
632
+ # ══════════════════════════════════════════════════════════════════════════════
633
+
634
+ _BLENDER_SCRIPT = """\
635
+ import bpy, sys
636
+ args = sys.argv[sys.argv.index('--') + 1:]
637
+ glb_in, fbx_out = args[0], args[1]
638
+ bpy.ops.wm.read_factory_settings(use_empty=True)
639
+ bpy.ops.import_scene.gltf(filepath=glb_in)
640
+ bpy.ops.export_scene.fbx(
641
+ filepath=fbx_out, use_selection=False,
642
+ add_leaf_bones=False, bake_anim=False,
643
+ path_mode='COPY', embed_textures=True,
644
+ )
645
+ print('FBX OK:', fbx_out)
646
+ """
647
+
648
+ def export_fbx(rigged_glb: str, out_path: str) -> bool:
649
+ blender = next((c for c in ["/usr/bin/blender","/usr/local/bin/blender"]
650
+ if os.path.exists(c)), None)
651
+ if blender is None:
652
+ r = subprocess.run(["which","blender"],capture_output=True,text=True)
653
+ blender = r.stdout.strip() or None
654
+ if blender is None:
655
+ print("[rig] Blender not found — skipping FBX")
656
+ return False
657
+ try:
658
+ with tempfile.NamedTemporaryFile("w",suffix=".py",delete=False) as f:
659
+ f.write(_BLENDER_SCRIPT); script = f.name
660
+ r = subprocess.run([blender,"--background","--python",script,
661
+ "--",rigged_glb,out_path],
662
+ capture_output=True,text=True,timeout=120)
663
+ ok = os.path.exists(out_path)
664
+ if not ok: print(f"[rig] Blender stderr:\n{r.stderr[-800:]}")
665
+ return ok
666
+ except Exception:
667
+ print(f"[rig] export_fbx:\n{traceback.format_exc()}")
668
+ return False
669
+ finally:
670
+ try: os.unlink(script)
671
+ except: pass
672
+
673
+
674
+ # ══════════════════════════════════════════════════════════════════════════════
675
+ # MDM — Motion Diffusion Model
676
+ # ══════════════════════════════════════════════════════════════════════════════
677
+
678
+ MDM_DIR = "/root/MDM"
679
+ MDM_CKPT = f"{MDM_DIR}/save/humanml_trans_enc_512/model000200000.pt"
680
+
681
+ # HumanML3D 22-joint parent array (matches SMPL joints 0-21)
682
+ _MDM_PARENTS = [-1,0,0,0,1,2,3,4,5,6,7,8,9,9,9,12,13,14,16,17,18,19]
683
+
684
+ def setup_mdm() -> bool:
685
+ """Clone MDM repo, install deps, download checkpoint. Idempotent."""
686
+ if os.path.exists(MDM_CKPT):
687
+ return True
688
+ print("[MDM] First-time setup...")
689
+
690
+ if not os.path.exists(MDM_DIR):
691
+ r = subprocess.run(
692
+ ["git", "clone", "--depth=1",
693
+ "https://github.com/GuyTevet/motion-diffusion-model.git", MDM_DIR],
694
+ capture_output=True, text=True, timeout=120)
695
+ if r.returncode != 0:
696
+ print(f"[MDM] git clone failed:\n{r.stderr}")
697
+ return False
698
+
699
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q",
700
+ "git+https://github.com/openai/CLIP.git",
701
+ "einops", "rotary-embedding-torch", "gdown"], check=False, timeout=300)
702
+
703
+ # HumanML3D normalisation stats (small .npy files needed for inference)
704
+ stats_dir = f"{MDM_DIR}/dataset/HumanML3D"
705
+ os.makedirs(stats_dir, exist_ok=True)
706
+ base = "https://github.com/EricGuo5513/HumanML3D/raw/main/HumanML3D"
707
+ for fn in ["Mean.npy", "Std.npy"]:
708
+ dest = f"{stats_dir}/{fn}"
709
+ if not os.path.exists(dest):
710
+ subprocess.run(["wget", "-q", f"{base}/{fn}", "-O", dest],
711
+ check=False, timeout=60)
712
+
713
+ # Checkpoint (~1.3 GB) — try HuggingFace mirror first, then gdown
714
+ ckpt_dir = os.path.dirname(MDM_CKPT)
715
+ os.makedirs(ckpt_dir, exist_ok=True)
716
+ hf = ("https://huggingface.co/Mathux/motion-diffusion-model/resolve/main/"
717
+ "humanml_trans_enc_512/model000200000.pt")
718
+ r = subprocess.run(["wget", "-q", "--show-progress", hf, "-O", MDM_CKPT],
719
+ capture_output=True, timeout=3600)
720
+ if r.returncode != 0 or not os.path.exists(MDM_CKPT) or \
721
+ os.path.getsize(MDM_CKPT) < 10_000_000:
722
+ print("[MDM] HF download failed — trying gdown (official Google Drive)...")
723
+ subprocess.run([sys.executable, "-m", "gdown",
724
+ "--id", "1PE0PK8e5a5j-7-Xhs5YET5U5pGh0c821",
725
+ "-O", MDM_CKPT], check=False, timeout=3600)
726
+
727
+ ok = os.path.exists(MDM_CKPT) and os.path.getsize(MDM_CKPT) > 10_000_000
728
+ print(f"[MDM] Setup {'OK' if ok else 'FAILED'}")
729
+ return ok
730
+
731
+
732
+ def generate_motion_mdm(text_prompt: str, n_frames: int = 120,
733
+ fps: int = 20, device: str = "cuda") -> dict | None:
734
+ """
735
+ Run MDM text-to-motion. Returns {'positions': (n_frames,22,3), 'fps': fps}
736
+ or None on failure. First call runs setup_mdm() which may take ~10 min.
737
+ """
738
+ if not setup_mdm():
739
+ return None
740
+
741
+ out_dir = tempfile.mkdtemp(prefix="mdm_")
742
+ motion_len = round(n_frames / fps, 2)
743
+
744
+ # Minimal inline driver — avoids MDM's argparse setup entirely
745
+ driver_src = f"""
746
+ import sys, os
747
+ sys.path.insert(0, {repr(MDM_DIR)})
748
+ os.chdir({repr(MDM_DIR)})
749
+ import numpy as np, torch
750
+
751
+ from utils.fixseed import fixseed
752
+ from utils.model_util import create_model_and_diffusion
753
+ from utils import dist_util
754
+ from data_loaders.humanml.utils.paramUtil import t2m_kinematic_chain
755
+ from data_loaders.humanml.scripts.motion_process import recover_from_ric
756
+ import clip as clip_lib
757
+
758
+ fixseed(42)
759
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
760
+ dist_util.dev = lambda: device
761
+
762
+ import argparse
763
+ args = argparse.Namespace(
764
+ arch='trans_enc', emb_trans_dec=False,
765
+ layers=8, latent_dim=512, ff_size=1024, num_heads=4,
766
+ dropout=0.1, activation='gelu', data_rep='rot6d',
767
+ dataset='humanml', cond_mode='text', cond_mask_prob=0.1,
768
+ lambda_rcxyz=0, lambda_vel=0, lambda_fc=0,
769
+ njoints=263, nfeats=1,
770
+ num_actions=1, translation=True, pose_rep='rot6d',
771
+ glob=True, glob_rot=True, npose=315,
772
+ device=0, seed=42, batch_size=1, num_samples=1,
773
+ num_repetitions=1, motion_length={motion_len!r},
774
+ input_text='', text_prompt='', action_file='', action_name='',
775
+ output_dir={repr(out_dir)}, guidance_param=2.5,
776
+ unconstrained=False,
777
+ # additional args required by get_model_args / create_gaussian_diffusion
778
+ text_encoder_type='clip',
779
+ pos_embed_max_len=5000,
780
+ mask_frames=False,
781
+ pred_len=0,
782
+ context_len=0,
783
+ diffusion_steps=1000,
784
+ noise_schedule='cosine',
785
+ sigma_small=True,
786
+ lambda_target_loc=0,
787
+ )
788
+
789
+ class _MockData:
790
+ class dataset:
791
+ pass
792
+ model, diffusion = create_model_and_diffusion(args, _MockData())
793
+ state = torch.load({repr(MDM_CKPT)}, map_location='cpu', weights_only=False)
794
+ missing, unexpected = model.load_state_dict(state, strict=False)
795
+ model.eval().to(device)
796
+
797
+ max_frames = int({n_frames})
798
+ shape = (1, model.njoints, model.nfeats, max_frames)
799
+ clip_model, _ = clip_lib.load('ViT-B/32', device=device, jit=False)
800
+ clip_model.eval()
801
+ tokens = clip_lib.tokenize([{repr(text_prompt)}]).to(device)
802
+ with torch.no_grad():
803
+ text_emb = clip_model.encode_text(tokens).float()
804
+
805
+ model_kwargs = {{
806
+ 'y': {{
807
+ 'mask': torch.ones(1, 1, 1, max_frames).to(device),
808
+ 'lengths': torch.tensor([max_frames]).to(device),
809
+ 'text': [{repr(text_prompt)}],
810
+ 'tokens': [''],
811
+ 'scale': torch.ones(1).to(device) * 2.5,
812
+ }}
813
+ }}
814
+
815
+ with torch.no_grad():
816
+ sample = diffusion.p_sample_loop(
817
+ model, shape, clip_denoised=False,
818
+ model_kwargs=model_kwargs, skip_timesteps=0,
819
+ init_image=None, progress=False, dump_steps=None,
820
+ noise=None, const_noise=False,
821
+ ) # (1, 263, 1, n_frames)
822
+
823
+ # Convert HumanML3D features → joint XYZ using recover_from_ric (no SMPL needed)
824
+ # sample: (1, 263, 1, n_frames) → (1, n_frames, 263)
825
+ sample_ric = sample[:, :, 0, :].permute(0, 2, 1)
826
+ xyz = recover_from_ric(sample_ric, 22) # (1, n_frames, 22, 3)
827
+ positions = xyz[0].cpu().numpy() # (n_frames, 22, 3)
828
+ np.save(os.path.join({repr(out_dir)}, 'positions.npy'), positions)
829
+ print('MDM_DONE')
830
+ """
831
+ driver_f = None
832
+ try:
833
+ with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as f:
834
+ f.write(driver_src)
835
+ driver_f = f.name
836
+
837
+ r = subprocess.run(
838
+ [sys.executable, driver_f],
839
+ capture_output=True, text=True, timeout=600,
840
+ env={**os.environ, "PYTHONPATH": MDM_DIR, "CUDA_VISIBLE_DEVICES": "0"},
841
+ )
842
+ print(f"[MDM] stdout: {r.stdout[-400:]}")
843
+ if r.returncode != 0:
844
+ print(f"[MDM] FAILED:\n{r.stderr[-600:]}")
845
+ return None
846
+
847
+ npy = os.path.join(out_dir, "positions.npy")
848
+ if not os.path.exists(npy):
849
+ print("[MDM] positions.npy not found")
850
+ return None
851
+
852
+ arr = np.load(npy) # (n_frames, 22, 3)
853
+ positions = arr # already (n_frames, 22, 3)
854
+ print(f"[MDM] Motion: {positions.shape}, fps={fps}")
855
+ return {"positions": positions, "fps": fps, "n_frames": positions.shape[0]}
856
+
857
+ except Exception:
858
+ print(f"[MDM] Exception:\n{traceback.format_exc()}")
859
+ return None
860
+ finally:
861
+ if driver_f:
862
+ try: os.unlink(driver_f)
863
+ except: pass
864
+
865
+
866
+ # ══════════════════════════════════════════════════════════════════════════════
867
+ # FK Inversion — joint world-positions → local quaternions per frame
868
+ # ══════════════════════════════════════════════════════════════════════════════
869
+
870
+ def _quat_between(v0: np.ndarray, v1: np.ndarray) -> np.ndarray:
871
+ """Shortest-arc quaternion [x,y,z,w] that rotates unit vector v0 → v1."""
872
+ cross = np.cross(v0, v1)
873
+ dot = float(np.clip(np.dot(v0, v1), -1.0, 1.0))
874
+ cn = np.linalg.norm(cross)
875
+ if cn < 1e-8:
876
+ return np.array([0., 0., 0., 1.], np.float32) if dot > 0 \
877
+ else np.array([1., 0., 0., 0.], np.float32)
878
+ axis = cross / cn
879
+ angle = np.arctan2(cn, dot)
880
+ s = np.sin(angle * 0.5)
881
+ return np.array([axis[0]*s, axis[1]*s, axis[2]*s, np.cos(angle*0.5)], np.float32)
882
+
883
+
884
+ def _quat_mul(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
885
+ """Hamilton product of two [x,y,z,w] quaternions."""
886
+ x1,y1,z1,w1 = q1; x2,y2,z2,w2 = q2
887
+ return np.array([
888
+ w1*x2 + x1*w2 + y1*z2 - z1*y2,
889
+ w1*y2 - x1*z2 + y1*w2 + z1*x2,
890
+ w1*z2 + x1*y2 - y1*x2 + z1*w2,
891
+ w1*w2 - x1*x2 - y1*y2 - z1*z2,
892
+ ], np.float32)
893
+
894
+
895
+ def _quat_inv(q: np.ndarray) -> np.ndarray:
896
+ return np.array([-q[0], -q[1], -q[2], q[3]], np.float32)
897
+
898
+
899
+ def _quat_rotate(q: np.ndarray, v: np.ndarray) -> np.ndarray:
900
+ """Rotate vector v by quaternion q."""
901
+ qv = np.array([v[0], v[1], v[2], 0.], np.float32)
902
+ return _quat_mul(_quat_mul(q, qv), _quat_inv(q))[:3]
903
+
904
+
905
+ def positions_to_local_quats(positions: np.ndarray,
906
+ t_pose_joints: np.ndarray,
907
+ parents: list) -> np.ndarray:
908
+ """
909
+ Derive per-joint local quaternions from world-space joint positions.
910
+ positions : (n_frames, n_joints, 3)
911
+ t_pose_joints : (n_joints, 3) — SMPL T-pose joints in same scale/space
912
+ parents : list of length n_joints, parent index (-1 for root)
913
+ Returns : (n_frames, n_joints, 4) XYZW local quaternions
914
+ """
915
+ n_frames, n_joints, _ = positions.shape
916
+ quats = np.zeros((n_frames, n_joints, 4), np.float32)
917
+ quats[:, :, 3] = 1.0 # default identity
918
+
919
+ # Compute global quats first, then convert to local
920
+ global_quats = np.zeros_like(quats)
921
+ global_quats[:, :, 3] = 1.0
922
+
923
+ for j in range(n_joints):
924
+ p = parents[j]
925
+ if p < 0:
926
+ # Root: no rotation relative to world (translation handles it)
927
+ global_quats[:, j] = [0, 0, 0, 1]
928
+ continue
929
+
930
+ # T-pose parent→child bone direction
931
+ tp_dir = t_pose_joints[j] - t_pose_joints[p]
932
+ tp_len = np.linalg.norm(tp_dir)
933
+ if tp_len < 1e-6:
934
+ continue
935
+ tp_dir /= tp_len
936
+
937
+ for f in range(n_frames):
938
+ an_dir = positions[f, j] - positions[f, p]
939
+ an_len = np.linalg.norm(an_dir)
940
+ if an_len < 1e-6:
941
+ global_quats[f, j] = global_quats[f, p]
942
+ continue
943
+ an_dir /= an_len
944
+ # Global rotation = parent_global ∘ local
945
+ # We want global bone direction to match an_dir
946
+ # global_bone_tpose = rotate(global_parent, tp_dir_in_parent_space)
947
+ # For SMPL T-pose, bone dirs are in world space already
948
+ gq = _quat_between(tp_dir, an_dir)
949
+ global_quats[f, j] = gq
950
+
951
+ # Convert global → local (local = inv_parent_global ∘ global)
952
+ for j in range(n_joints):
953
+ p = parents[j]
954
+ if p < 0:
955
+ quats[:, j] = global_quats[:, j]
956
+ else:
957
+ for f in range(n_frames):
958
+ quats[f, j] = _quat_mul(_quat_inv(global_quats[f, p]),
959
+ global_quats[f, j])
960
+
961
+ return quats
962
+
963
+
964
+ # ══════════════════════════════════════════════════════════════════════════════
965
+ # Animated GLB export
966
+ # ══════════════════════════════════════════════════════════════════════════════
967
+
968
+ def export_animated_glb(verts, faces, uv, texture_img,
969
+ joints, # (24, 3) T-pose joint world positions
970
+ skin_weights, # (N_verts, 24)
971
+ joint_quats, # (n_frames, 24, 4) XYZW local quaternions
972
+ root_trans, # (n_frames, 3) world translation of root
973
+ fps: int,
974
+ out_path: str):
975
+ """
976
+ Export fully animated rigged GLB.
977
+ Skeleton + skin weights identical to export_rigged_glb;
978
+ adds a GLTF animation with per-joint rotation channels + root translation.
979
+ """
980
+ import pygltflib
981
+ from pygltflib import (GLTF2, Scene, Node, Mesh, Primitive, Accessor,
982
+ BufferView, Buffer, Material, Texture,
983
+ Image as GImage, Sampler, Skin, Asset,
984
+ Animation, AnimationChannel, AnimationChannelTarget,
985
+ AnimationSampler)
986
+ from pygltflib import (ARRAY_BUFFER, ELEMENT_ARRAY_BUFFER, FLOAT,
987
+ UNSIGNED_INT, UNSIGNED_SHORT, LINEAR,
988
+ LINEAR_MIPMAP_LINEAR, REPEAT, SCALAR, VEC2,
989
+ VEC3, VEC4, MAT4)
990
+
991
+ n_frames, n_joints_anim, _ = joint_quats.shape
992
+ n_joints = len(joints)
993
+
994
+ gltf = GLTF2()
995
+ gltf.asset = Asset(version="2.0", generator="rig_stage.py/animated")
996
+ blobs = []
997
+
998
+ def _add(data: np.ndarray, comp, acc_type, target=None,
999
+ set_min_max=False):
1000
+ b = data.tobytes()
1001
+ pad = (4 - len(b) % 4) % 4
1002
+ off = sum(len(x) for x in blobs)
1003
+ blobs.append(b + b"\x00" * pad)
1004
+ bv = len(gltf.bufferViews)
1005
+ gltf.bufferViews.append(BufferView(buffer=0, byteOffset=off,
1006
+ byteLength=len(b), target=target))
1007
+ ac = len(gltf.accessors)
1008
+ flat = data.flatten().astype(np.float32)
1009
+ kw = {}
1010
+ if set_min_max:
1011
+ kw = {"min": [float(flat.min())], "max": [float(flat.max())]}
1012
+ gltf.accessors.append(Accessor(
1013
+ bufferView=bv, byteOffset=0, componentType=comp,
1014
+ type=acc_type, count=len(data), **kw))
1015
+ return ac
1016
+
1017
+ # ── Mesh geometry ──────────────────────────────────────────────────────
1018
+ pos_acc = _add(verts.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
1019
+
1020
+ v0,v1,v2 = verts[faces[:,0]], verts[faces[:,1]], verts[faces[:,2]]
1021
+ fn = np.cross(v1-v0, v2-v0)
1022
+ fn /= (np.linalg.norm(fn, axis=1, keepdims=True) + 1e-8)
1023
+ vn = np.zeros_like(verts)
1024
+ for i in range(3): np.add.at(vn, faces[:,i], fn)
1025
+ vn /= (np.linalg.norm(vn, axis=1, keepdims=True) + 1e-8)
1026
+ nor_acc = _add(vn.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
1027
+
1028
+ if uv is None: uv = np.zeros((len(verts), 2), np.float32)
1029
+ uv_acc = _add(uv.astype(np.float32), FLOAT, VEC2, ARRAY_BUFFER)
1030
+ idx_acc = _add(faces.astype(np.uint32).flatten(), UNSIGNED_INT,
1031
+ SCALAR, ELEMENT_ARRAY_BUFFER)
1032
+
1033
+ top4_idx = np.argsort(-skin_weights, axis=1)[:, :4].astype(np.uint16)
1034
+ top4_w = np.take_along_axis(skin_weights, top4_idx.astype(np.int64), axis=1).astype(np.float32)
1035
+ top4_w /= top4_w.sum(axis=1, keepdims=True).clip(1e-8, None)
1036
+ j_acc = _add(top4_idx, UNSIGNED_SHORT, "VEC4", ARRAY_BUFFER)
1037
+ w_acc = _add(top4_w, FLOAT, "VEC4", ARRAY_BUFFER)
1038
+
1039
+ # ── Texture ────────────────────────────────────────────────────────────
1040
+ if texture_img is not None:
1041
+ import io
1042
+ buf = io.BytesIO(); texture_img.save(buf, format="PNG"); ib = buf.getvalue()
1043
+ off = sum(len(x) for x in blobs); pad2 = (4 - len(ib) % 4) % 4
1044
+ blobs.append(ib + b"\x00" * pad2)
1045
+ gltf.bufferViews.append(BufferView(buffer=0, byteOffset=off, byteLength=len(ib)))
1046
+ gltf.images.append(GImage(mimeType="image/png", bufferView=len(gltf.bufferViews)-1))
1047
+ gltf.samplers.append(Sampler(magFilter=LINEAR, minFilter=LINEAR_MIPMAP_LINEAR,
1048
+ wrapS=REPEAT, wrapT=REPEAT))
1049
+ gltf.textures.append(Texture(sampler=0, source=0))
1050
+ gltf.materials.append(Material(name="body",
1051
+ pbrMetallicRoughness={"baseColorTexture": {"index": 0},
1052
+ "metallicFactor": 0.0, "roughnessFactor": 0.8},
1053
+ doubleSided=True))
1054
+ else:
1055
+ gltf.materials.append(Material(name="body", doubleSided=True))
1056
+
1057
+ prim = Primitive(
1058
+ attributes={"POSITION": pos_acc, "NORMAL": nor_acc,
1059
+ "TEXCOORD_0": uv_acc, "JOINTS_0": j_acc, "WEIGHTS_0": w_acc},
1060
+ indices=idx_acc, material=0)
1061
+ gltf.meshes.append(Mesh(name="body", primitives=[prim]))
1062
+
1063
+ # ── Skeleton nodes ─────────────────────────────────────────────────────
1064
+ jnodes = []
1065
+ for i, (name, parent) in enumerate(zip(SMPL_JOINT_NAMES, SMPL_PARENTS)):
1066
+ t = joints[i].tolist() if parent == -1 else (joints[i] - joints[parent]).tolist()
1067
+ n = Node(name=name, translation=t, children=[])
1068
+ jnodes.append(len(gltf.nodes)); gltf.nodes.append(n)
1069
+ for i, p in enumerate(SMPL_PARENTS):
1070
+ if p != -1: gltf.nodes[jnodes[p]].children.append(jnodes[i])
1071
+
1072
+ ibms = np.stack([np.eye(4, dtype=np.float32) for _ in range(n_joints)])
1073
+ for i in range(n_joints): ibms[i, :3, 3] = -joints[i]
1074
+ ibm_acc = _add(ibms.astype(np.float32), FLOAT, MAT4)
1075
+ skin_idx = len(gltf.skins)
1076
+ gltf.skins.append(Skin(name="smpl_skin", skeleton=jnodes[0],
1077
+ joints=jnodes, inverseBindMatrices=ibm_acc))
1078
+
1079
+ mesh_node = len(gltf.nodes)
1080
+ gltf.nodes.append(Node(name="body_mesh", mesh=0, skin=skin_idx))
1081
+ root_node = len(gltf.nodes)
1082
+ gltf.nodes.append(Node(name="root", children=[jnodes[0], mesh_node]))
1083
+ gltf.scenes.append(Scene(name="Scene", nodes=[root_node]))
1084
+ gltf.scene = 0
1085
+
1086
+ # ── Animation ──────────────────────────────────────────────────────────
1087
+ dt = 1.0 / fps
1088
+ times = np.arange(n_frames, dtype=np.float32) * dt # (n_frames,)
1089
+ time_acc = _add(times, FLOAT, SCALAR, set_min_max=True)
1090
+
1091
+ channels, samplers = [], []
1092
+
1093
+ # Per-joint rotation tracks
1094
+ for j in range(min(n_joints_anim, n_joints)):
1095
+ q = joint_quats[:, j, :].astype(np.float32) # (n_frames, 4) XYZW
1096
+ q_acc = _add(q, FLOAT, VEC4)
1097
+ si = len(samplers)
1098
+ samplers.append(AnimationSampler(input=time_acc, output=q_acc,
1099
+ interpolation="LINEAR"))
1100
+ channels.append(AnimationChannel(
1101
+ sampler=si,
1102
+ target=AnimationChannelTarget(node=jnodes[j], path="rotation")))
1103
+
1104
+ # Root translation track
1105
+ if root_trans is not None:
1106
+ tr = root_trans.astype(np.float32) # (n_frames, 3)
1107
+ tr_acc = _add(tr, FLOAT, VEC3)
1108
+ si = len(samplers)
1109
+ samplers.append(AnimationSampler(input=time_acc, output=tr_acc,
1110
+ interpolation="LINEAR"))
1111
+ channels.append(AnimationChannel(
1112
+ sampler=si,
1113
+ target=AnimationChannelTarget(node=jnodes[0], path="translation")))
1114
+
1115
+ gltf.animations.append(Animation(name="mdm_motion",
1116
+ channels=channels, samplers=samplers))
1117
+
1118
+ # ── Finalise ───────────────────────────────────────────────────────────
1119
+ bin_data = b"".join(blobs)
1120
+ gltf.buffers.append(Buffer(byteLength=len(bin_data)))
1121
+ gltf.set_binary_blob(bin_data)
1122
+ gltf.save_binary(out_path)
1123
+ dur = times[-1] if len(times) else 0
1124
+ print(f"[rig] Animated GLB → {out_path} "
1125
+ f"({os.path.getsize(out_path)//1024} KB, {n_frames} frames @ {fps}fps = {dur:.1f}s)")
1126
+
1127
+
1128
+ # ══════════════════════════════════════════════════════════════════════════════
1129
+ # Main pipeline
1130
+ # ══════════════════════════════════════════════════════════════════════════════
1131
+
1132
+ def run_rig_pipeline(glb_path: str, reference_image_path: str,
1133
+ out_dir: str, device: str = "cuda",
1134
+ export_fbx_flag: bool = True,
1135
+ mdm_prompt: str = "",
1136
+ mdm_n_frames: int = 120,
1137
+ mdm_fps: int = 20) -> dict:
1138
+ import trimesh
1139
+ os.makedirs(out_dir, exist_ok=True)
1140
+ result = {"rigged_glb": None, "animated_glb": None, "fbx": None,
1141
+ "smpl_params": None, "status": "", "phases": {}}
1142
+
1143
+ try:
1144
+ # ── load TripoSG mesh ─────────────────────────────────────────────
1145
+ print("[rig] Loading TripoSG mesh...")
1146
+ scene = trimesh.load(glb_path, force="scene")
1147
+ if isinstance(scene, trimesh.Scene):
1148
+ geom = list(scene.geometry.values())
1149
+ mesh = trimesh.util.concatenate(geom) if len(geom)>1 else geom[0]
1150
+ else:
1151
+ mesh = scene
1152
+ verts = np.array(mesh.vertices, dtype=np.float32)
1153
+ faces = np.array(mesh.faces, dtype=np.int32)
1154
+
1155
+ # UV + texture: try source geoms before concatenation (more reliable)
1156
+ uv, tex = None, None
1157
+ src_geoms = list(scene.geometry.values()) if isinstance(scene, trimesh.Scene) else [scene]
1158
+ for g in src_geoms:
1159
+ if not hasattr(g.visual, "uv") or g.visual.uv is None:
1160
+ continue
1161
+ try:
1162
+ candidate_uv = np.array(g.visual.uv, dtype=np.float32)
1163
+ if len(candidate_uv) == len(verts):
1164
+ uv = candidate_uv
1165
+ mat = getattr(g.visual, "material", None)
1166
+ if mat is not None:
1167
+ for attr in ("image", "baseColorTexture", "diffuse"):
1168
+ img = getattr(mat, attr, None)
1169
+ if img is not None:
1170
+ from PIL import Image as _PILImage
1171
+ tex = img if isinstance(img, _PILImage.Image) else None
1172
+ break
1173
+ break
1174
+ except Exception:
1175
+ pass
1176
+ if uv is None:
1177
+ print("[rig] WARNING: UV not found or vertex count mismatch — mesh will be untextured")
1178
+ print(f"[rig] Mesh: {len(verts)} verts, {len(faces)} faces, "
1179
+ f"UV={'yes' if uv is not None else 'no'}, "
1180
+ f"texture={'yes' if tex is not None else 'no'}")
1181
+
1182
+ # ── Phase 1: multi-view beta averaging ───────────────────────────
1183
+ print("\n[rig] ── Phase 1: multi-view beta averaging ──")
1184
+ betas, hmr2_results = estimate_betas_multiview(VIEW_PATHS, reference_image_path, device)
1185
+ result["phases"]["p1_betas"] = betas.tolist()
1186
+
1187
+ # ── Phase 2: silhouette fitting ───────────────────────────────────
1188
+ print("\n[rig] ── Phase 2: silhouette fitting ──")
1189
+ betas = fit_betas_silhouette(betas, VIEW_PATHS)
1190
+ result["phases"]["p2_betas"] = betas.tolist()
1191
+
1192
+ # ── Phase 3: multi-view joint triangulation ───────────────────────
1193
+ print("\n[rig] ── Phase 3: multi-view joint triangulation ──")
1194
+ tri_joints = triangulate_joints_multiview(hmr2_results)
1195
+ result["phases"]["p3_triangulated"] = tri_joints is not None
1196
+
1197
+ # ── build SMPL T-pose with refined betas ──────────────────────────
1198
+ print("\n[rig] Building SMPL T-pose...")
1199
+ smpl_v, smpl_f, smpl_j, smpl_w = get_smpl_tpose(betas)
1200
+
1201
+ # Override with triangulated joints if available
1202
+ if tri_joints is not None:
1203
+ # Triangulated joints are in render-normalised space; convert to SMPL scale
1204
+ _, _, scale, _ = _smpl_to_render_space(smpl_v.copy(), smpl_j.copy())
1205
+ smpl_j = tri_joints / scale # back to SMPL metric space
1206
+ print("[rig] Using triangulated skeleton joints.")
1207
+
1208
+ # ── align TripoSG mesh to SMPL ────────────────────────────────────
1209
+ verts_aligned = align_mesh_to_smpl(verts, smpl_v, smpl_j)
1210
+
1211
+ # ── skinning weight transfer ──────────────────────────────────────
1212
+ print("[rig] Transferring skinning weights...")
1213
+ skin_w = transfer_skinning(smpl_v, smpl_w, verts_aligned)
1214
+
1215
+ # ── export rigged GLB ─────────────────────────────────────────────
1216
+ rigged_glb = os.path.join(out_dir, "rigged.glb")
1217
+ export_rigged_glb(verts_aligned, faces, uv, tex, smpl_j, skin_w, rigged_glb)
1218
+ result["rigged_glb"] = rigged_glb
1219
+
1220
+ # ── export FBX ────────────────────────────────────────────────────
1221
+ if export_fbx_flag:
1222
+ fbx = os.path.join(out_dir, "rigged.fbx")
1223
+ result["fbx"] = fbx if export_fbx(rigged_glb, fbx) else None
1224
+
1225
+ # ── MDM animation ─────────────────────────────────────────────────
1226
+ if mdm_prompt.strip():
1227
+ print(f"\n[rig] ── MDM animation: {mdm_prompt!r} ({mdm_n_frames} frames) ──")
1228
+ mdm_out = generate_motion_mdm(mdm_prompt, n_frames=mdm_n_frames,
1229
+ fps=mdm_fps, device=device)
1230
+ if mdm_out is not None:
1231
+ pos = mdm_out["positions"] # (n_frames, 22, 3)
1232
+ actual_frames = pos.shape[0]
1233
+
1234
+ # Align MDM joint positions to SMPL scale/space
1235
+ # MDM outputs in metres roughly matching SMPL metric
1236
+ # Scale so pelvis height matches our SMPL pelvis
1237
+ mdm_pelvis_h = float(np.median(pos[:, 0, 1]))
1238
+ smpl_pelvis_h = float(smpl_j[0, 1])
1239
+ if abs(mdm_pelvis_h) > 1e-4:
1240
+ pos = pos * (smpl_pelvis_h / mdm_pelvis_h)
1241
+
1242
+ # FK inversion: positions → local quaternions for joints 0-21
1243
+ t_pose_22 = smpl_j[:22]
1244
+ quats_22 = positions_to_local_quats(pos, t_pose_22, _MDM_PARENTS)
1245
+ # Pad to 24 joints (SMPL hands = identity)
1246
+ quats_24 = np.zeros((actual_frames, 24, 4), np.float32)
1247
+ quats_24[:, :, 3] = 1.0
1248
+ quats_24[:, :22, :] = quats_22
1249
+
1250
+ # Root translation: MDM root XZ + SMPL Y offset
1251
+ root_trans = pos[:, 0, :].copy() # (n_frames, 3)
1252
+
1253
+ anim_glb = os.path.join(out_dir, "animated.glb")
1254
+ export_animated_glb(
1255
+ verts_aligned, faces, uv, tex,
1256
+ smpl_j, skin_w,
1257
+ quats_24, root_trans, mdm_fps, anim_glb
1258
+ )
1259
+ result["animated_glb"] = anim_glb
1260
+ print(f"[rig] MDM animation complete → {anim_glb}")
1261
+ else:
1262
+ print("[rig] MDM generation failed — static GLB only")
1263
+
1264
+ result["smpl_params"] = {
1265
+ "betas": betas.tolist(),
1266
+ "p1_sources": len(hmr2_results),
1267
+ "p3_triangulated": tri_joints is not None,
1268
+ }
1269
+ p3_note = " + triangulated skeleton" if tri_joints is not None else ""
1270
+ fbx_note = " + FBX" if result["fbx"] else ""
1271
+ anim_note = f" + MDM({mdm_n_frames}f)" if result.get("animated_glb") else ""
1272
+ result["status"] = (
1273
+ f"Rigged ({len(hmr2_results)} views used{p3_note}{fbx_note}{anim_note}). "
1274
+ f"{len(verts)} verts, 24 joints."
1275
+ )
1276
+
1277
+ except Exception:
1278
+ err = traceback.format_exc()
1279
+ print(f"[rig] FAILED:\n{err}")
1280
+ result["status"] = f"Rigging failed:\n{err[-600:]}"
1281
+
1282
+ return result
pipeline/rig_yolo.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ rig_yolo.py — Rig a humanoid mesh using YOLO-pose joint detection.
3
+
4
+ Instead of estimating T-pose rotations (which failed), detect where joints
5
+ actually ARE in the mesh's current pose and use those positions as the bind pose.
6
+
7
+ Pipeline:
8
+ 1. Render front view (azimuth=-90, same camera as triposg_app.py views)
9
+ 2. YOLOv8x-pose → COCO-17 2D keypoints
10
+ 3. Unproject to 3D in original mesh coordinate space
11
+ 4. Map COCO-17 → SMPL-24 (interpolate spine, collar, hand, foot joints)
12
+ 5. LBS weights: proximity-based (k=4 nearest joints per vertex)
13
+ 6. Export rigged GLB — bind pose = current pose
14
+
15
+ Usage:
16
+ python rig_yolo.py --body /tmp/triposg_textured.glb \
17
+ --out /tmp/rig_out/rigged.glb \
18
+ [--debug_dir /tmp/rig_debug]
19
+ """
20
+
21
+ import os, sys, argparse, warnings
22
+ warnings.filterwarnings('ignore')
23
+
24
+ import numpy as np
25
+ import cv2
26
+ import trimesh
27
+ from scipy.spatial import cKDTree
28
+
29
+ sys.path.insert(0, '/root/MV-Adapter')
30
+
31
+ # ── Camera constants — MUST match triposg_app.py ──────────────────────────────
32
+ ORTHO_LEFT, ORTHO_RIGHT = -0.55, 0.55
33
+ ORTHO_BOT, ORTHO_TOP = -0.55, 0.55
34
+ RENDER_W, RENDER_H = 768, 1024
35
+ FRONT_AZ = -90 # azimuth that gives front view
36
+ # Orthographic proj scale: 2/(right-left) = 1.818...
37
+ PROJ_SCALE = 2.0 / (ORTHO_RIGHT - ORTHO_LEFT)
38
+
39
+ SMPL_PARENTS = [-1,0,0,0,1,2,3,4,5,6,7,8,9,9,9,
40
+ 12,13,14,16,17,18,19,20,21]
41
+ SMPL_JOINT_NAMES = [
42
+ 'pelvis','left_hip','right_hip','spine1',
43
+ 'left_knee','right_knee','spine2',
44
+ 'left_ankle','right_ankle','spine3',
45
+ 'left_foot','right_foot','neck',
46
+ 'left_collar','right_collar','head',
47
+ 'left_shoulder','right_shoulder',
48
+ 'left_elbow','right_elbow',
49
+ 'left_wrist','right_wrist',
50
+ 'left_hand','right_hand',
51
+ ]
52
+
53
+ # COCO-17 order
54
+ COCO_NAMES = ['nose','L_eye','R_eye','L_ear','R_ear',
55
+ 'L_shoulder','R_shoulder','L_elbow','R_elbow','L_wrist','R_wrist',
56
+ 'L_hip','R_hip','L_knee','R_knee','L_ankle','R_ankle']
57
+
58
+
59
+ # ── Step 0: Load mesh directly from GLB (correct UV channel) ─────────────────
60
+
61
+ def load_mesh_from_gltf(body_glb):
62
+ """
63
+ Load mesh from GLB using pygltflib, reading the UV channel the material
64
+ actually references (TEXCOORD_0 or TEXCOORD_1).
65
+ Returns: verts (N,3) float64, faces (F,3) int32,
66
+ uv (N,2) float32 or None, texture_pil PIL.Image or None
67
+ """
68
+ import pygltflib
69
+ from PIL import Image as PILImage
70
+ import io
71
+
72
+ gltf = pygltflib.GLTF2().load(body_glb)
73
+ blob = gltf.binary_blob()
74
+
75
+ # componentType → (numpy dtype, bytes per element)
76
+ _DTYPE = {5120: np.int8, 5121: np.uint8, 5122: np.int16,
77
+ 5123: np.uint16, 5125: np.uint32, 5126: np.float32}
78
+ _NCOMP = {'SCALAR': 1, 'VEC2': 2, 'VEC3': 3, 'VEC4': 4, 'MAT4': 16}
79
+
80
+ def read_accessor(idx):
81
+ if idx is None:
82
+ return None
83
+ acc = gltf.accessors[idx]
84
+ bv = gltf.bufferViews[acc.bufferView]
85
+ dtype = _DTYPE[acc.componentType]
86
+ n_comp = _NCOMP[acc.type]
87
+ bv_off = bv.byteOffset or 0
88
+ acc_off = acc.byteOffset or 0
89
+ elem_bytes = np.dtype(dtype).itemsize * n_comp
90
+ stride = bv.byteStride if (bv.byteStride and bv.byteStride != elem_bytes) else elem_bytes
91
+
92
+ if stride == elem_bytes:
93
+ start = bv_off + acc_off
94
+ size = acc.count * elem_bytes
95
+ arr = np.frombuffer(blob[start:start + size], dtype=dtype)
96
+ else:
97
+ # interleaved buffer
98
+ rows = []
99
+ for i in range(acc.count):
100
+ start = bv_off + acc_off + i * stride
101
+ rows.append(np.frombuffer(blob[start:start + elem_bytes], dtype=dtype))
102
+ arr = np.concatenate(rows)
103
+
104
+ return arr.reshape(acc.count, n_comp) if n_comp > 1 else arr
105
+
106
+ # ── Find which texCoord index the material references ──────────────────────
107
+ texcoord_idx = 0
108
+ if gltf.materials:
109
+ pbr = gltf.materials[0].pbrMetallicRoughness
110
+ if pbr and pbr.baseColorTexture:
111
+ texcoord_idx = getattr(pbr.baseColorTexture, 'texCoord', 0) or 0
112
+ print(f' material uses TEXCOORD_{texcoord_idx}')
113
+
114
+ # ── Read primitive ─────────────────────────────────────────────────────────
115
+ prim = gltf.meshes[0].primitives[0]
116
+ attrs = prim.attributes
117
+
118
+ verts = read_accessor(attrs.POSITION).astype(np.float64)
119
+
120
+ idx_data = read_accessor(prim.indices).flatten()
121
+ faces = idx_data.reshape(-1, 3).astype(np.int32)
122
+
123
+ # Read the correct UV channel; fall back to TEXCOORD_0
124
+ uv_acc_idx = getattr(attrs, f'TEXCOORD_{texcoord_idx}', None)
125
+ if uv_acc_idx is None and texcoord_idx != 0:
126
+ uv_acc_idx = getattr(attrs, 'TEXCOORD_0', None)
127
+ uv_raw = read_accessor(uv_acc_idx)
128
+ uv = uv_raw.astype(np.float32) if uv_raw is not None else None
129
+
130
+ print(f' verts={len(verts)} faces={len(faces)} uv={len(uv) if uv is not None else None}')
131
+
132
+ # ── Extract embedded texture ───────────────────────────────────────────────
133
+ texture_pil = None
134
+ try:
135
+ pbr = gltf.materials[0].pbrMetallicRoughness
136
+ if pbr and pbr.baseColorTexture is not None:
137
+ tex_idx = pbr.baseColorTexture.index
138
+ if tex_idx is not None and tex_idx < len(gltf.textures):
139
+ src_idx = gltf.textures[tex_idx].source
140
+ if src_idx is not None and src_idx < len(gltf.images):
141
+ img_obj = gltf.images[src_idx]
142
+ if img_obj.bufferView is not None:
143
+ bv = gltf.bufferViews[img_obj.bufferView]
144
+ bv_off = bv.byteOffset or 0
145
+ img_bytes = blob[bv_off:bv_off + bv.byteLength]
146
+ texture_pil = PILImage.open(io.BytesIO(img_bytes)).convert('RGBA')
147
+ print(f' texture: {texture_pil.size}')
148
+ except Exception as e:
149
+ print(f' texture extraction failed: {e}')
150
+
151
+ return verts, faces, uv, texture_pil
152
+
153
+
154
+ # ── Step 1: Render front view ─────────────────────────────────────────────────
155
+
156
+ def render_front(body_glb, debug_dir=None):
157
+ """
158
+ Render front view using MV-Adapter.
159
+ Returns (img_bgr, scale_factor) where scale_factor = max_abs / 0.5
160
+ (used to convert std-space back to original mesh space).
161
+ """
162
+ from mvadapter.utils.mesh_utils import (
163
+ NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera, render,
164
+ )
165
+ ctx = NVDiffRastContextWrapper(device='cuda', context_type='cuda')
166
+ mesh_mv, _offset, scale_factor = load_mesh(
167
+ body_glb, rescale=True, return_transform=True, device='cuda')
168
+ camera = get_orthogonal_camera(
169
+ elevation_deg=[0], distance=[1.8],
170
+ left=ORTHO_LEFT, right=ORTHO_RIGHT,
171
+ bottom=ORTHO_BOT, top=ORTHO_TOP,
172
+ azimuth_deg=[FRONT_AZ], device='cuda')
173
+ out = render(ctx, mesh_mv, camera,
174
+ height=RENDER_H, width=RENDER_W,
175
+ render_attr=True, render_depth=False, render_normal=False,
176
+ attr_background=0.5)
177
+ img_np = (out.attr[0].cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
178
+ img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
179
+ if debug_dir:
180
+ cv2.imwrite(os.path.join(debug_dir, 'front_render.png'), img_bgr)
181
+ print(f' render: {RENDER_W}x{RENDER_H}, scale_factor={scale_factor:.4f}')
182
+ return img_bgr, scale_factor
183
+
184
+
185
+ # ── Step 2: YOLO-pose keypoints ───────────────────────────────────────────────
186
+
187
+ def detect_keypoints(img_bgr, debug_dir=None):
188
+ """
189
+ Run YOLOv8x-pose on the rendered image.
190
+ Returns (17, 3) array: [pixel_x, pixel_y, confidence] for COCO-17 joints.
191
+ Picks the largest detected bounding box (the character body).
192
+ """
193
+ from ultralytics import YOLO
194
+ model = YOLO('yolov8x-pose.pt')
195
+ results = model(img_bgr, verbose=False)
196
+
197
+ if not results or results[0].keypoints is None or len(results[0].boxes) == 0:
198
+ raise RuntimeError('YOLO: no person detected in front render')
199
+
200
+ r = results[0]
201
+ boxes = r.boxes.xyxy.cpu().numpy()
202
+ areas = (boxes[:,2]-boxes[:,0]) * (boxes[:,3]-boxes[:,1])
203
+ idx = int(areas.argmax())
204
+
205
+ kp_xy = r.keypoints[idx].xy[0].cpu().numpy() # (17, 2) pixel
206
+ kp_conf = r.keypoints[idx].conf[0].cpu().numpy() # (17,) confidence
207
+ kp = np.concatenate([kp_xy, kp_conf[:,None]], axis=1) # (17, 3)
208
+
209
+ print(' YOLO detections: %d boxes, using largest' % len(boxes))
210
+ for i, name in enumerate(COCO_NAMES):
211
+ if kp_conf[i] > 0.3:
212
+ print(' [%d] %-14s px=(%.0f, %.0f) conf=%.2f' % (
213
+ i, name, kp_xy[i,0], kp_xy[i,1], kp_conf[i]))
214
+
215
+ if debug_dir:
216
+ vis = img_bgr.copy()
217
+ for i in range(17):
218
+ if kp_conf[i] > 0.3:
219
+ x, y = int(kp_xy[i,0]), int(kp_xy[i,1])
220
+ cv2.circle(vis, (x, y), 6, (0, 255, 0), -1)
221
+ cv2.putText(vis, COCO_NAMES[i][:4], (x+4, y-4),
222
+ cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0,255,0), 1)
223
+ cv2.imwrite(os.path.join(debug_dir, 'yolo_keypoints.png'), vis)
224
+
225
+ return kp
226
+
227
+
228
+ # ── Step 3: Unproject 2D → 3D ────────────────────────────────────────────────
229
+
230
+ def unproject_to_3d(kp_2d_conf, scale_factor, mesh_verts_orig):
231
+ """
232
+ Convert COCO-17 pixel positions to 3D positions in original mesh space.
233
+
234
+ MV-Adapter orthographic camera at azimuth=-90 maps:
235
+ pixel_x → orig_x (character lateral axis)
236
+ pixel_y → orig_y (character height axis, flipped from pixel)
237
+ orig_z estimated from k-nearest mesh vertices in image space
238
+
239
+ Forward projection (for reference):
240
+ std_x = orig_x / scale_factor
241
+ NDC_x = PROJ_SCALE * std_x
242
+ pixel_x = (NDC_x + 1) / 2 * W
243
+
244
+ std_z = orig_y / scale_factor (mesh Y ↔ std Z ↔ image vertical)
245
+ NDC_y = -PROJ_SCALE * std_z (Y-flipped by proj matrix)
246
+ pixel_y = (NDC_y + 1) / 2 * H
247
+
248
+ Inverse:
249
+ orig_x = (2*px/W - 1) / PROJ_SCALE * scale_factor
250
+ orig_y = -(2*py/H - 1) / PROJ_SCALE * scale_factor
251
+ """
252
+ W, H = RENDER_W, RENDER_H
253
+
254
+ # Project all mesh vertices to image space (for Z lookup)
255
+ verts_px_x = ((mesh_verts_orig[:,0] / scale_factor * PROJ_SCALE) + 1.0) / 2.0 * W
256
+ verts_px_y = ((-mesh_verts_orig[:,1] / scale_factor * PROJ_SCALE) + 1.0) / 2.0 * H
257
+
258
+ joints_3d = np.full((17, 3), np.nan)
259
+ for i in range(17):
260
+ px, py, conf = kp_2d_conf[i]
261
+ if conf < 0.15 or px < 1 or py < 1:
262
+ continue
263
+
264
+ orig_x = (2.0*px/W - 1.0) / PROJ_SCALE * scale_factor
265
+ orig_y = -(2.0*py/H - 1.0) / PROJ_SCALE * scale_factor
266
+
267
+ # Z: median of k-nearest mesh vertices in image space
268
+ dist_2d = np.hypot(verts_px_x - px, verts_px_y - py)
269
+ k = 30
270
+ near_idx = np.argpartition(dist_2d, k-1)[:k]
271
+ orig_z = float(np.median(mesh_verts_orig[near_idx, 2]))
272
+
273
+ joints_3d[i] = [orig_x, orig_y, orig_z]
274
+
275
+ return joints_3d
276
+
277
+
278
+ # ── Step 4: COCO-17 → SMPL-24 ────────────────────────────────────────────────
279
+
280
+ def coco17_to_smpl24(coco_3d, mesh_verts):
281
+ """
282
+ Build 24 SMPL joint positions from COCO-17 detections.
283
+ Spine / collar / hand / foot joints are interpolated.
284
+ Low-confidence (NaN) COCO joints fall back to mesh geometry.
285
+ """
286
+ def lerp(a, b, t):
287
+ return a + t * (b - a)
288
+
289
+ def valid(i):
290
+ return not np.any(np.isnan(coco_3d[i]))
291
+
292
+ # Fill NaN joints from mesh geometry (centroid fallback)
293
+ c = coco_3d.copy()
294
+ centroid = mesh_verts.mean(axis=0)
295
+ for i in range(17):
296
+ if not valid(i):
297
+ c[i] = centroid
298
+
299
+ # Key anchor points
300
+ L_shoulder = c[5]
301
+ R_shoulder = c[6]
302
+ L_hip = c[11]
303
+ R_hip = c[12]
304
+
305
+ pelvis = lerp(L_hip, R_hip, 0.5)
306
+ mid_shoulder = lerp(L_shoulder, R_shoulder, 0.5)
307
+ # Neck: midpoint of shoulders, raised slightly (~ collar bone level)
308
+ neck = mid_shoulder + np.array([0.0, 0.04 * (mid_shoulder[1] - pelvis[1]), 0.0])
309
+
310
+ J = np.zeros((24, 3), dtype=np.float64)
311
+
312
+ J[0] = pelvis # pelvis
313
+ J[1] = L_hip # left_hip
314
+ J[2] = R_hip # right_hip
315
+ J[3] = lerp(pelvis, neck, 0.25) # spine1
316
+ J[4] = c[13] # left_knee
317
+ J[5] = c[14] # right_knee
318
+ J[6] = lerp(pelvis, neck, 0.5) # spine2
319
+ J[7] = c[15] # left_ankle
320
+ J[8] = c[16] # right_ankle
321
+ J[9] = lerp(pelvis, neck, 0.75) # spine3
322
+ J[12] = neck # neck
323
+
324
+ # Feet: project ankle downward toward mesh floor
325
+ mesh_floor_y = mesh_verts[:,1].min()
326
+ foot_y = mesh_floor_y + 0.02 * (c[15][1] - mesh_floor_y) # 2% above floor
327
+ J[10] = np.array([c[15][0], foot_y, c[15][2]]) # left_foot
328
+ J[11] = np.array([c[16][0], foot_y, c[16][2]]) # right_foot
329
+
330
+ J[13] = lerp(neck, L_shoulder, 0.5) # left_collar
331
+ J[14] = lerp(neck, R_shoulder, 0.5) # right_collar
332
+ J[15] = c[0] # head (nose as proxy)
333
+ J[16] = L_shoulder # left_shoulder
334
+ J[17] = R_shoulder # right_shoulder
335
+ J[18] = c[7] # left_elbow
336
+ J[19] = c[8] # right_elbow
337
+ J[20] = c[9] # left_wrist
338
+ J[21] = c[10] # right_wrist
339
+
340
+ # Hands: extrapolate one step beyond wrist in elbow→wrist direction
341
+ for side, (elbow_i, wrist_i, hand_i) in enumerate([(7,9,22), (8,10,23)]):
342
+ elbow = c[elbow_i]; wrist = c[wrist_i]
343
+ bone = wrist - elbow
344
+ blen = np.linalg.norm(bone)
345
+ if blen > 1e-3:
346
+ J[hand_i] = wrist + bone / blen * 0.05
347
+ else:
348
+ J[hand_i] = wrist
349
+
350
+ print(' SMPL-24 joints:')
351
+ print(' pelvis : (%.3f, %.3f, %.3f)' % tuple(J[0]))
352
+ print(' L_hip : (%.3f, %.3f, %.3f)' % tuple(J[1]))
353
+ print(' R_hip : (%.3f, %.3f, %.3f)' % tuple(J[2]))
354
+ print(' neck : (%.3f, %.3f, %.3f)' % tuple(J[12]))
355
+ print(' L_shoulder: (%.3f, %.3f, %.3f)' % tuple(J[16]))
356
+ print(' R_shoulder: (%.3f, %.3f, %.3f)' % tuple(J[17]))
357
+ print(' head : (%.3f, %.3f, %.3f)' % tuple(J[15]))
358
+
359
+ return J.astype(np.float32)
360
+
361
+
362
+ # ── Step 5: LBS skinning weights ─────────────────────────────────────────────
363
+
364
+ def compute_skinning_weights(mesh_verts, joints, k=4):
365
+ """
366
+ Proximity-based LBS weights: each vertex gets k-nearest joint weights
367
+ via inverse-distance weighting.
368
+ Returns (N, 24) float32 full weight matrix.
369
+ """
370
+ N = len(mesh_verts)
371
+ tree = cKDTree(joints)
372
+ dists, idxs = tree.query(mesh_verts, k=k, workers=-1)
373
+
374
+ # Clamp minimum distance to avoid division by zero
375
+ inv_d = 1.0 / np.maximum(dists, 1e-6)
376
+ inv_d /= inv_d.sum(axis=1, keepdims=True)
377
+
378
+ W_full = np.zeros((N, 24), dtype=np.float32)
379
+ for ki in range(k):
380
+ W_full[np.arange(N), idxs[:, ki]] += inv_d[:, ki].astype(np.float32)
381
+
382
+ # Normalize (should already be normalized, but just in case)
383
+ row_sum = W_full.sum(axis=1, keepdims=True)
384
+ W_full /= np.where(row_sum > 0, row_sum, 1.0)
385
+
386
+ print(' weights: max_joint=%d mean_support=%.2f joints/vert' % (
387
+ W_full.argmax(axis=1).max(),
388
+ (W_full > 0.01).sum(axis=1).mean()))
389
+
390
+ return W_full
391
+
392
+
393
+ # ── Skeleton mesh builder ─────────────────────────────────────────────────────
394
+
395
+ def make_skeleton_mesh(joints, radius=0.008):
396
+ """
397
+ Build a mesh of hexagonal-prism cylinders connecting parent→child joints.
398
+ Returns (verts, faces) as float32 / int32 numpy arrays.
399
+ """
400
+ SEG = 6 # hexagonal cross-section
401
+ angles = np.linspace(0, 2 * np.pi, SEG, endpoint=False)
402
+ circle = np.stack([np.cos(angles), np.sin(angles)], axis=1) # (SEG, 2)
403
+
404
+ all_verts, all_faces = [], []
405
+ vert_offset = 0
406
+
407
+ for i, parent in enumerate(SMPL_PARENTS):
408
+ if parent == -1:
409
+ continue
410
+ p0 = joints[parent].astype(np.float64)
411
+ p1 = joints[i].astype(np.float64)
412
+ bone_vec = p1 - p0
413
+ length = np.linalg.norm(bone_vec)
414
+ if length < 1e-4:
415
+ continue
416
+
417
+ z_axis = bone_vec / length
418
+ ref = np.array([0., 1., 0.]) if abs(z_axis[1]) < 0.9 else np.array([1., 0., 0.])
419
+ x_axis = np.cross(ref, z_axis)
420
+ x_axis /= np.linalg.norm(x_axis)
421
+ y_axis = np.cross(z_axis, x_axis)
422
+
423
+ # Bottom ring at p0, top ring at p1
424
+ offsets = radius * (circle[:, 0:1] * x_axis + circle[:, 1:2] * y_axis)
425
+ bottom = p0 + offsets # (SEG, 3)
426
+ top = p1 + offsets # (SEG, 3)
427
+
428
+ all_verts.append(np.vstack([bottom, top]).astype(np.float32))
429
+
430
+ for j in range(SEG):
431
+ j1 = (j + 1) % SEG
432
+ b0, b1 = vert_offset + j, vert_offset + j1
433
+ t0, t1 = vert_offset + SEG + j, vert_offset + SEG + j1
434
+ all_faces.extend([[b0, b1, t0], [b1, t1, t0]])
435
+
436
+ vert_offset += 2 * SEG
437
+
438
+ if not all_verts:
439
+ return np.zeros((0, 3), np.float32), np.zeros((0, 3), np.int32)
440
+
441
+ return np.vstack(all_verts), np.array(all_faces, dtype=np.int32)
442
+
443
+
444
+ # ── Step 6: Export rigged GLB ─────────────────────────────────────────────────
445
+
446
+ def export_rigged_glb(verts, faces, uv, texture_pil, joints, skin_weights,
447
+ out_path, skel_verts=None, skel_faces=None):
448
+ """
449
+ Export skinned GLB using pygltflib.
450
+ bind pose = current pose (joints at detected positions).
451
+ IBM[j] = Translation(-J_world[j]) (pure offset, no rotation).
452
+
453
+ If skel_verts/skel_faces are provided, a second mesh (bright green skeleton
454
+ sticks) is embedded alongside the body mesh.
455
+ """
456
+ import pygltflib
457
+ from pygltflib import (GLTF2, Scene, Node, Mesh, Primitive, Accessor,
458
+ BufferView, Buffer, Material, Texture,
459
+ Image as GImage, Sampler, Skin, Asset)
460
+ from pygltflib import (ARRAY_BUFFER, ELEMENT_ARRAY_BUFFER, FLOAT,
461
+ UNSIGNED_INT, UNSIGNED_SHORT, LINEAR,
462
+ LINEAR_MIPMAP_LINEAR, REPEAT, SCALAR, VEC2,
463
+ VEC3, VEC4, MAT4)
464
+
465
+ gltf = GLTF2()
466
+ gltf.asset = Asset(version='2.0', generator='rig_yolo.py')
467
+ blobs = []
468
+
469
+ def _add(data, comp, acc_type, target=None):
470
+ b = data.tobytes()
471
+ pad = (4 - len(b) % 4) % 4
472
+ off = sum(len(x) for x in blobs)
473
+ blobs.append(b + b'\x00' * pad)
474
+ bv = len(gltf.bufferViews)
475
+ gltf.bufferViews.append(BufferView(
476
+ buffer=0, byteOffset=off, byteLength=len(b), target=target))
477
+ ac = len(gltf.accessors)
478
+ flat = data.flatten()
479
+ gltf.accessors.append(Accessor(
480
+ bufferView=bv, byteOffset=0, componentType=comp,
481
+ type=acc_type, count=len(data),
482
+ min=[float(flat.min())], max=[float(flat.max())]))
483
+ return ac
484
+
485
+ # Geometry
486
+ pos_acc = _add(verts.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
487
+
488
+ v0, v1, v2 = verts[faces[:,0]], verts[faces[:,1]], verts[faces[:,2]]
489
+ fn = np.cross(v1-v0, v2-v0)
490
+ fn /= (np.linalg.norm(fn, axis=1, keepdims=True) + 1e-8)
491
+ vn = np.zeros_like(verts)
492
+ for i in range(3):
493
+ np.add.at(vn, faces[:,i], fn)
494
+ vn /= (np.linalg.norm(vn, axis=1, keepdims=True) + 1e-8)
495
+ nor_acc = _add(vn.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
496
+
497
+ if uv is None:
498
+ uv = np.zeros((len(verts), 2), np.float32)
499
+ uv_acc = _add(uv.astype(np.float32), FLOAT, VEC2, ARRAY_BUFFER)
500
+ idx_acc = _add(faces.astype(np.uint32).flatten(), UNSIGNED_INT, SCALAR,
501
+ ELEMENT_ARRAY_BUFFER)
502
+
503
+ # Skinning: top-4 joints per vertex
504
+ top4_idx = np.argsort(-skin_weights, axis=1)[:, :4].astype(np.uint16)
505
+ top4_w = np.take_along_axis(skin_weights, top4_idx.astype(np.int64), axis=1)
506
+ top4_w = top4_w.astype(np.float32)
507
+ top4_w /= top4_w.sum(axis=1, keepdims=True).clip(1e-8, None)
508
+ j_acc = _add(top4_idx, UNSIGNED_SHORT, VEC4, ARRAY_BUFFER)
509
+ w_acc = _add(top4_w, FLOAT, VEC4, ARRAY_BUFFER)
510
+
511
+ # Texture
512
+ if texture_pil is not None:
513
+ import io
514
+ buf = io.BytesIO()
515
+ texture_pil.save(buf, format='PNG')
516
+ ib = buf.getvalue()
517
+ off = sum(len(x) for x in blobs)
518
+ pad = (4 - len(ib) % 4) % 4
519
+ blobs.append(ib + b'\x00' * pad)
520
+ gltf.bufferViews.append(
521
+ BufferView(buffer=0, byteOffset=off, byteLength=len(ib)))
522
+ gltf.images.append(
523
+ GImage(mimeType='image/png', bufferView=len(gltf.bufferViews)-1))
524
+ gltf.samplers.append(
525
+ Sampler(magFilter=LINEAR, minFilter=LINEAR_MIPMAP_LINEAR,
526
+ wrapS=REPEAT, wrapT=REPEAT))
527
+ gltf.textures.append(Texture(sampler=0, source=0))
528
+ gltf.materials.append(Material(
529
+ name='body',
530
+ pbrMetallicRoughness={
531
+ 'baseColorTexture': {'index': 0},
532
+ 'metallicFactor': 0.0,
533
+ 'roughnessFactor': 0.8},
534
+ doubleSided=True))
535
+ else:
536
+ gltf.materials.append(Material(name='body', doubleSided=True))
537
+
538
+ body_prim = Primitive(
539
+ attributes={'POSITION': pos_acc, 'NORMAL': nor_acc,
540
+ 'TEXCOORD_0': uv_acc, 'JOINTS_0': j_acc, 'WEIGHTS_0': w_acc},
541
+ indices=idx_acc, material=0)
542
+ gltf.meshes.append(Mesh(name='body', primitives=[body_prim]))
543
+
544
+ # ── Optional skeleton mesh ─────────────────────────────────────────────────
545
+ skel_mesh_idx = None
546
+ if skel_verts is not None and len(skel_verts) > 0:
547
+ sv = skel_verts.astype(np.float32)
548
+ sf = skel_faces.astype(np.int32)
549
+
550
+ sv0, sv1, sv2 = sv[sf[:,0]], sv[sf[:,1]], sv[sf[:,2]]
551
+ sfn = np.cross(sv1-sv0, sv2-sv0)
552
+ sfn /= (np.linalg.norm(sfn, axis=1, keepdims=True) + 1e-8)
553
+ svn = np.zeros_like(sv)
554
+ for i in range(3):
555
+ np.add.at(svn, sf[:,i], sfn)
556
+ svn /= (np.linalg.norm(svn, axis=1, keepdims=True) + 1e-8)
557
+
558
+ s_pos_acc = _add(sv, FLOAT, VEC3, ARRAY_BUFFER)
559
+ s_nor_acc = _add(svn.astype(np.float32), FLOAT, VEC3, ARRAY_BUFFER)
560
+ s_idx_acc = _add(sf.astype(np.uint32).flatten(), UNSIGNED_INT, SCALAR,
561
+ ELEMENT_ARRAY_BUFFER)
562
+
563
+ # Lime-green unlit material for skeleton sticks
564
+ mat_idx = len(gltf.materials)
565
+ gltf.materials.append(Material(
566
+ name='skeleton',
567
+ pbrMetallicRoughness={
568
+ 'baseColorFactor': [0.2, 1.0, 0.3, 1.0],
569
+ 'metallicFactor': 0.0,
570
+ 'roughnessFactor': 0.5},
571
+ doubleSided=True))
572
+
573
+ skel_mesh_idx = len(gltf.meshes)
574
+ skel_prim = Primitive(
575
+ attributes={'POSITION': s_pos_acc, 'NORMAL': s_nor_acc},
576
+ indices=s_idx_acc, material=mat_idx)
577
+ gltf.meshes.append(Mesh(name='skeleton', primitives=[skel_prim]))
578
+
579
+ # ── Skeleton nodes ─────────────────────────────────────────────────────────
580
+ jnodes = []
581
+ for i, (name, parent) in enumerate(zip(SMPL_JOINT_NAMES, SMPL_PARENTS)):
582
+ t = joints[i].tolist() if parent == -1 else (joints[i] - joints[parent]).tolist()
583
+ n = Node(name=name, translation=t, children=[])
584
+ jnodes.append(len(gltf.nodes))
585
+ gltf.nodes.append(n)
586
+ for i, p in enumerate(SMPL_PARENTS):
587
+ if p != -1:
588
+ gltf.nodes[jnodes[p]].children.append(jnodes[i])
589
+
590
+ # Inverse bind matrices: IBM[j] = Translation(-J_world[j])
591
+ # glTF MAT4 is column-major; numpy .tobytes() is row-major.
592
+ # glTF reads the numpy buffer as the TRANSPOSE of what numpy stores.
593
+ # So we set the translation in the last ROW of the numpy matrix — glTF
594
+ # reads that as the last COLUMN (translation column) of a 4x4 mat.
595
+ ibms = np.stack([np.eye(4, dtype=np.float32) for _ in range(len(joints))])
596
+ for i in range(len(joints)):
597
+ ibms[i, 3, :3] = -joints[i]
598
+ ibm_acc = _add(ibms.astype(np.float32), FLOAT, MAT4)
599
+
600
+ skin_idx = len(gltf.skins)
601
+ gltf.skins.append(Skin(
602
+ name='smpl_skin', skeleton=jnodes[0],
603
+ joints=jnodes, inverseBindMatrices=ibm_acc))
604
+
605
+ mesh_node = len(gltf.nodes)
606
+ gltf.nodes.append(Node(name='body_mesh', mesh=0, skin=skin_idx))
607
+
608
+ root_children = [jnodes[0], mesh_node]
609
+
610
+ if skel_mesh_idx is not None:
611
+ skel_node_idx = len(gltf.nodes)
612
+ gltf.nodes.append(Node(name='skeleton_mesh', mesh=skel_mesh_idx))
613
+ root_children.append(skel_node_idx)
614
+
615
+ root_node = len(gltf.nodes)
616
+ gltf.nodes.append(Node(name='root', children=root_children))
617
+ gltf.scenes.append(Scene(name='Scene', nodes=[root_node]))
618
+ gltf.scene = 0
619
+
620
+ bin_data = b''.join(blobs)
621
+ gltf.buffers.append(Buffer(byteLength=len(bin_data)))
622
+ gltf.set_binary_blob(bin_data)
623
+ gltf.save_binary(out_path)
624
+ print(' rigged GLB -> %s (%d KB)' % (out_path, os.path.getsize(out_path) // 1024))
625
+
626
+
627
+ # ── Main ──────────────────────────────────────────────────────────────────────
628
+
629
+ def rig_yolo(body_glb, out_glb, debug_dir=None):
630
+ """
631
+ Rig body_glb and write to out_glb.
632
+ Returns (out_glb, out_skel_glb) where out_skel_glb includes visible
633
+ skeleton bone sticks alongside the body mesh.
634
+ """
635
+ os.makedirs(os.path.dirname(out_glb) or '.', exist_ok=True)
636
+ if debug_dir:
637
+ os.makedirs(debug_dir, exist_ok=True)
638
+
639
+ print('[rig_yolo] Rendering front view ...')
640
+ img_bgr, scale_factor = render_front(body_glb, debug_dir)
641
+
642
+ print('[rig_yolo] Running YOLO-pose ...')
643
+ kp = detect_keypoints(img_bgr, debug_dir)
644
+
645
+ print('[rig_yolo] Loading original mesh (pygltflib, correct UV channel) ...')
646
+ verts, faces, uv, texture_pil = load_mesh_from_gltf(body_glb)
647
+
648
+ print('[rig_yolo] Unprojecting YOLO keypoints to 3D ...')
649
+ coco_3d = unproject_to_3d(kp, scale_factor, verts)
650
+
651
+ print('[rig_yolo] Building SMPL-24 skeleton ...')
652
+ joints = coco17_to_smpl24(coco_3d, verts)
653
+
654
+ print('[rig_yolo] Computing skinning weights ...')
655
+ skin_weights = compute_skinning_weights(verts, joints, k=4)
656
+
657
+ print('[rig_yolo] Exporting rigged GLB (no skeleton) ...')
658
+ export_rigged_glb(verts, faces, uv, texture_pil, joints, skin_weights, out_glb)
659
+
660
+ print('[rig_yolo] Building skeleton mesh ...')
661
+ skel_verts, skel_faces = make_skeleton_mesh(joints)
662
+ out_skel_glb = out_glb.replace('.glb', '_skel.glb')
663
+ print('[rig_yolo] Exporting rigged GLB (with skeleton) ...')
664
+ export_rigged_glb(verts, faces, uv, texture_pil, joints, skin_weights,
665
+ out_skel_glb, skel_verts=skel_verts, skel_faces=skel_faces)
666
+
667
+ print('[rig_yolo] Done.')
668
+ return out_glb, out_skel_glb
669
+
670
+
671
+ if __name__ == '__main__':
672
+ ap = argparse.ArgumentParser()
673
+ ap.add_argument('--body', required=True, help='Input textured GLB')
674
+ ap.add_argument('--out', required=True, help='Output rigged GLB')
675
+ ap.add_argument('--debug_dir', default=None, help='Save debug renders here')
676
+ args = ap.parse_args()
677
+ rigged, rigged_skel = rig_yolo(args.body, args.out, args.debug_dir)
678
+ print('Rigged: ', rigged)
679
+ print('Rigged + skel: ', rigged_skel)
pipeline/tpose_smpl.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tpose_smpl.py -- T-pose a humanoid GLB via inverse Linear Blend Skinning.
3
+
4
+ Pipeline:
5
+ 1. Render front view and run HMR2 -> SMPL body_pose + betas
6
+ 2. Read rigged.glb: mesh verts (rig world space), skinning weights, T-pose joints
7
+ 3. Compute FK transforms in rig world space using HMR2 body_pose
8
+ 4. Apply inverse LBS: v_tpose = (Sum_j W_j * A_j)^-1 * v_posed
9
+ 5. Map T-posed verts back to original mesh coordinate space, preserve UV/texture
10
+ 6. Optionally export SKEL bone mesh in T-pose
11
+
12
+ Usage:
13
+ python tpose_smpl.py --body /tmp/triposg_textured.glb \
14
+ --rig /tmp/rig_out/rigged.glb \
15
+ --out /tmp/tposed_surface.glb \
16
+ [--skel_out /tmp/tposed_bones.glb] \
17
+ [--debug_dir /tmp/tpose_debug]
18
+ """
19
+
20
+ import os, sys, argparse, struct, json, warnings
21
+ warnings.filterwarnings('ignore')
22
+
23
+ import numpy as np
24
+ import cv2
25
+ import torch
26
+ import trimesh
27
+ from trimesh.visual.texture import TextureVisuals
28
+ from trimesh.visual.material import PBRMaterial
29
+ from scipy.spatial.transform import Rotation as R
30
+
31
+ sys.path.insert(0, '/root/MV-Adapter')
32
+ SMPL_NEUTRAL = '/root/body_models/smpl/SMPL_NEUTRAL.pkl'
33
+ SKEL_DIR = '/root/body_models/skel'
34
+
35
+ SMPL_PARENTS = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9,
36
+ 12, 13, 14, 16, 17, 18, 19, 20, 21]
37
+
38
+
39
+ # ---- Step 1: Render front view -----------------------------------------------
40
+
41
+ def render_front(body_glb, H=1024, W=768, device='cuda'):
42
+ from mvadapter.utils.mesh_utils import (
43
+ NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera, render,
44
+ )
45
+ ctx = NVDiffRastContextWrapper(device=device, context_type='cuda')
46
+ mesh_mv = load_mesh(body_glb, rescale=True, device=device)
47
+ camera = get_orthogonal_camera(
48
+ elevation_deg=[0], distance=[1.8],
49
+ left=-0.55, right=0.55, bottom=-0.55, top=0.55,
50
+ azimuth_deg=[-90], device=device,
51
+ )
52
+ out = render(ctx, mesh_mv, camera, height=H, width=W,
53
+ render_attr=True, render_depth=False, render_normal=False,
54
+ attr_background=0.5)
55
+ img_np = (out.attr[0].cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
56
+ return cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
57
+
58
+
59
+ # ---- Step 2: HMR2 pose estimation --------------------------------------------
60
+
61
+ def run_hmr2(img_bgr, device='cuda'):
62
+ from pathlib import Path
63
+ from hmr2.configs import CACHE_DIR_4DHUMANS
64
+ from hmr2.models import load_hmr2, DEFAULT_CHECKPOINT, download_models
65
+ from hmr2.utils import recursive_to
66
+ from hmr2.datasets.vitdet_dataset import ViTDetDataset
67
+ from hmr2.utils.utils_detectron2 import DefaultPredictor_Lazy
68
+ from detectron2.config import LazyConfig
69
+ import hmr2 as hmr2_pkg
70
+
71
+ download_models(CACHE_DIR_4DHUMANS)
72
+ model, model_cfg = load_hmr2(DEFAULT_CHECKPOINT)
73
+ model = model.to(device).eval()
74
+
75
+ cfg_path = Path(hmr2_pkg.__file__).parent / 'configs' / 'cascade_mask_rcnn_vitdet_h_75ep.py'
76
+ det_cfg = LazyConfig.load(str(cfg_path))
77
+ det_cfg.train.init_checkpoint = (
78
+ 'https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h'
79
+ '/f328730692/model_final_f05665.pkl'
80
+ )
81
+ for i in range(3):
82
+ det_cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25
83
+ detector = DefaultPredictor_Lazy(det_cfg)
84
+
85
+ det_out = detector(img_bgr)
86
+ instances = det_out['instances']
87
+ valid = (instances.pred_classes == 0) & (instances.scores > 0.5)
88
+ boxes = instances.pred_boxes.tensor[valid].cpu().numpy()
89
+ if len(boxes) == 0:
90
+ raise RuntimeError('HMR2: no person detected in render')
91
+
92
+ areas = (boxes[:,2]-boxes[:,0]) * (boxes[:,3]-boxes[:,1])
93
+ boxes = boxes[areas.argmax():areas.argmax()+1]
94
+
95
+ dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
96
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
97
+ for batch in dataloader:
98
+ batch = recursive_to(batch, device)
99
+ with torch.no_grad():
100
+ out = model(batch)
101
+ sp = out['pred_smpl_params']
102
+ return {
103
+ 'body_pose': sp['body_pose'][0].cpu(), # (23, 3, 3)
104
+ 'betas': sp['betas'][0].cpu(), # (10,)
105
+ }
106
+
107
+
108
+ # ---- Step 3: Read all data from rigged.glb -----------------------------------
109
+
110
+ def read_rigged_glb(rig_glb):
111
+ """
112
+ Returns dict with:
113
+ verts : (N, 3) mesh vertices in rig world space
114
+ j_idx : (N, 4) joint indices
115
+ w_arr : (N, 4) skinning weights
116
+ J_bind : (24, 3) T-pose joint world positions
117
+ """
118
+ with open(rig_glb, 'rb') as fh:
119
+ raw = fh.read()
120
+ ch_len, _ = struct.unpack_from('<II', raw, 12)
121
+ gltf = json.loads(raw[20:20+ch_len])
122
+ bin_data = raw[20+ch_len+8:]
123
+
124
+ def _read(acc_i):
125
+ acc = gltf['accessors'][acc_i]
126
+ bv = gltf['bufferViews'][acc['bufferView']]
127
+ off = bv.get('byteOffset', 0) + acc.get('byteOffset', 0)
128
+ cnt = acc['count']
129
+ n = {'SCALAR':1,'VEC2':2,'VEC3':3,'VEC4':4,'MAT4':16}[acc['type']]
130
+ fmt = {5121:'B',5123:'H',5125:'I',5126:'f'}[acc['componentType']]
131
+ nb = {'B':1,'H':2,'I':4,'f':4}[fmt]
132
+ return np.frombuffer(bin_data[off:off+cnt*n*nb],
133
+ dtype=np.dtype(fmt)).reshape(cnt, n)
134
+
135
+ prim = gltf['meshes'][0]['primitives'][0]['attributes']
136
+ verts = _read(prim['POSITION']).astype(np.float64) # (N, 3)
137
+ j_idx = _read(prim['JOINTS_0']).astype(int) # (N, 4)
138
+ w_arr = _read(prim['WEIGHTS_0']).astype(np.float64) # (N, 4)
139
+ row_sum = w_arr.sum(axis=1, keepdims=True)
140
+ w_arr /= np.where(row_sum > 0, row_sum, 1.0)
141
+
142
+ # Read T-pose joint world positions by accumulating node translations
143
+ nodes = gltf['nodes']
144
+ skin = gltf['skins'][0]
145
+ j_nodes = skin['joints'] # [0, 1, ..., 23]
146
+ J_bind = np.zeros((24, 3), dtype=np.float64)
147
+ for ji, ni in enumerate(j_nodes):
148
+ t_local = np.array(nodes[ni].get('translation', [0, 0, 0]))
149
+ p = SMPL_PARENTS[ji]
150
+ J_bind[ji] = (J_bind[p] if p >= 0 else np.zeros(3)) + t_local
151
+
152
+ print(' Rig verts: %d Y: [%.3f, %.3f] X: [%.3f, %.3f]' % (
153
+ len(verts),
154
+ verts[:,1].min(), verts[:,1].max(),
155
+ verts[:,0].min(), verts[:,0].max()))
156
+ print(' J_bind pelvis: (%.3f, %.3f, %.3f) L_shoulder: (%.3f, %.3f, %.3f)' % (
157
+ *J_bind[0], *J_bind[16]))
158
+ return {'verts': verts, 'j_idx': j_idx, 'w_arr': w_arr, 'J_bind': J_bind}
159
+
160
+
161
+ # ---- Step 4: FK in rig world space -> A matrices -----------------------------
162
+
163
+ _FLIP_X = np.diag([-1.0, 1.0, 1.0]) # X-axis mirror matrix
164
+
165
+
166
+ def _adapt_rotmat_to_flipped_x(R_smpl):
167
+ """
168
+ Convert an SO(3) rotation matrix from SMPL convention (left=+X)
169
+ to rig convention (left=-X). F @ R @ F where F = diag(-1,1,1).
170
+ """
171
+ return _FLIP_X @ R_smpl @ _FLIP_X
172
+
173
+
174
+ def compute_rig_fk_transforms(J_bind, body_pose_rotmats):
175
+ """
176
+ Compute A_j = G_j_posed * IBM_j in rig world space.
177
+ A_j maps T-pose -> posed, so A_j^{-1} maps posed -> T-pose.
178
+
179
+ HMR2 returns rotations in SMPL convention (left shoulder at +X).
180
+ The rig uses the opposite convention (left shoulder at -X).
181
+ We convert by conjugating with the X-flip matrix before building FK.
182
+
183
+ J_bind : (24, 3) T-pose joint world positions from rig
184
+ body_pose_rotmats: (23, 3, 3) HMR2 body pose rotation matrices (joints 1-23)
185
+ Returns A: (24, 4, 4)
186
+ """
187
+ G = [None] * 24
188
+ for j in range(24):
189
+ p = SMPL_PARENTS[j]
190
+ # Convert rotation from SMPL (+X=left) to rig (-X=left) convention
191
+ R_smpl = body_pose_rotmats[j-1].numpy() if j >= 1 else np.eye(3)
192
+ R_j = _adapt_rotmat_to_flipped_x(R_smpl)
193
+
194
+ if p < 0:
195
+ t_j = J_bind[j] # root: absolute world position
196
+ else:
197
+ t_j = J_bind[j] - J_bind[p]
198
+
199
+ L = np.eye(4, dtype=np.float64)
200
+ L[:3, :3] = R_j
201
+ L[:3, 3] = t_j
202
+
203
+ G[j] = L if p < 0 else G[p] @ L
204
+
205
+ G = np.stack(G)
206
+
207
+ A = np.zeros((24, 4, 4), dtype=np.float64)
208
+ for j in range(24):
209
+ IBM = np.eye(4, dtype=np.float64)
210
+ IBM[:3, 3] = -J_bind[j]
211
+ A[j] = G[j] @ IBM
212
+
213
+ return A
214
+
215
+
216
+ # ---- Step 5: Inverse LBS -----------------------------------------------------
217
+
218
+ def inverse_lbs(verts, j_idx, w_arr, A):
219
+ """
220
+ v_tpose = (Sum_j W_j * A_j)^{-1} * v_posed
221
+ All inputs in rig world space.
222
+ Returns (N, 3) T-posed vertices.
223
+ """
224
+ N = len(verts)
225
+ # Blend forward transforms
226
+ T_fwd = np.zeros((N, 4, 4), dtype=np.float64)
227
+ for k in range(4):
228
+ ji = j_idx[:, k]
229
+ w = w_arr[:, k]
230
+ mask = w > 1e-6
231
+ if mask.any():
232
+ T_fwd[mask] += w[mask, None, None] * A[ji[mask]]
233
+
234
+ T_inv = np.linalg.inv(T_fwd)
235
+ v_h = np.concatenate([verts, np.ones((N, 1))], axis=1)
236
+ v_tp = np.einsum('nij,nj->ni', T_inv, v_h)[:, :3]
237
+
238
+ disp = np.linalg.norm(v_tp - verts, axis=1)
239
+ print(' inverse LBS: mean_disp=%.4f max_disp=%.4f' % (disp.mean(), disp.max()))
240
+ return v_tp
241
+
242
+
243
+ # ---- Step 6: Map T-posed rig verts back to original mesh space ---------------
244
+
245
+ def rig_to_original_space(rig_verts_tposed, rig_verts_original, orig_mesh_verts):
246
+ """
247
+ Rig verts are a scaled + translated version of the original mesh verts.
248
+ Recover the (scale, offset) from the mapping:
249
+ rig_vert = orig_vert * scale + offset
250
+
251
+ Estimates scale from height ratio, offset from floor alignment.
252
+ Returns T-posed vertices in original mesh coordinate space.
253
+ """
254
+ rig_h = rig_verts_original[:, 1].max() - rig_verts_original[:, 1].min()
255
+ orig_h = orig_mesh_verts[:, 1].max() - orig_mesh_verts[:, 1].min()
256
+ scale = rig_h / max(orig_h, 1e-6)
257
+
258
+ # The rig aligns: orig * scale, then v[:,1] -= v[:,1].min() (floor at 0)
259
+ # and v[:,0] += smpl_joints[0,0] - cx; v[:,2] += smpl_joints[0,2] - cz
260
+ # We can recover offset from comparing means/floors
261
+ # offset = rig_floor_Y - (orig_floor_Y * scale)
262
+ rig_floor = rig_verts_original[:, 1].min()
263
+ orig_floor = orig_mesh_verts[:, 1].min()
264
+ y_offset = rig_floor - orig_floor * scale
265
+
266
+ # X, Z: center offset
267
+ rig_cx = (rig_verts_original[:, 0].max() + rig_verts_original[:, 0].min()) * 0.5
268
+ orig_cx = (orig_mesh_verts[:, 0].max() + orig_mesh_verts[:, 0].min()) * 0.5
269
+ x_offset = rig_cx - orig_cx * scale
270
+
271
+ rig_cz = (rig_verts_original[:, 2].max() + rig_verts_original[:, 2].min()) * 0.5
272
+ orig_cz = (orig_mesh_verts[:, 2].max() + orig_mesh_verts[:, 2].min()) * 0.5
273
+ z_offset = rig_cz - orig_cz * scale
274
+
275
+ print(' rig->orig: scale=%.4f offset=[%.3f, %.3f, %.3f]' % (scale, x_offset, y_offset, z_offset))
276
+
277
+ # Invert: orig_vert = (rig_vert - offset) / scale
278
+ # For T-posed verts: they're in rig space but T-posed, so same inversion
279
+ tposed_orig = np.zeros_like(rig_verts_tposed)
280
+ tposed_orig[:, 0] = (rig_verts_tposed[:, 0] - x_offset) / scale
281
+ tposed_orig[:, 1] = (rig_verts_tposed[:, 1] - y_offset) / scale
282
+ tposed_orig[:, 2] = (rig_verts_tposed[:, 2] - z_offset) / scale
283
+ return tposed_orig
284
+
285
+
286
+ # ---- SKEL bone geometry ------------------------------------------------------
287
+
288
+ def export_skel_bones(betas, out_path, gender='male'):
289
+ try:
290
+ from skel.skel_model import SKEL
291
+ except ImportError:
292
+ print(' [skel] Not installed')
293
+ return None
294
+ skel_file = os.path.join(SKEL_DIR, 'skel_%s.pkl' % gender)
295
+ if not os.path.exists(skel_file):
296
+ print(' [skel] Weights not found: %s' % skel_file)
297
+ return None
298
+ try:
299
+ skel_model = SKEL(gender=gender, model_path=SKEL_DIR)
300
+ betas_t = betas.unsqueeze(0)[:, :10]
301
+ poses_zero = torch.zeros(1, 46)
302
+ trans_zero = torch.zeros(1, 3)
303
+ with torch.no_grad():
304
+ out = skel_model(poses=poses_zero, betas=betas_t, trans=trans_zero, skelmesh=True)
305
+ bone_verts = out.skel_verts[0].numpy()
306
+ bone_faces = skel_model.skel_f.numpy()
307
+ mesh = trimesh.Trimesh(vertices=bone_verts, faces=bone_faces, process=False)
308
+ mesh.export(out_path)
309
+ print(' [skel] Bone mesh -> %s (%d verts)' % (out_path, len(bone_verts)))
310
+ return out_path
311
+ except Exception as e:
312
+ print(' [skel] Export failed: %s' % e)
313
+ return None
314
+
315
+
316
+ # ---- Main --------------------------------------------------------------------
317
+
318
+ def tpose_smpl(body_glb, out_glb, rig_glb=None, debug_dir=None, skel_out=None):
319
+ device = 'cuda'
320
+
321
+ if not rig_glb or not os.path.exists(rig_glb):
322
+ raise RuntimeError('--rig is required: provide the rigged.glb from the Rig step.')
323
+
324
+ print('[tpose_smpl] Rendering front view ...')
325
+ img_bgr = render_front(body_glb, device=device)
326
+ if debug_dir:
327
+ cv2.imwrite(os.path.join(debug_dir, 'tpose_render.png'), img_bgr)
328
+
329
+ print('[tpose_smpl] Running HMR2 pose estimation ...')
330
+ hmr2_out = run_hmr2(img_bgr, device=device)
331
+ print(' betas: %s' % hmr2_out['betas'].numpy().round(3))
332
+
333
+ print('[tpose_smpl] Reading rigged GLB (rig world space) ...')
334
+ rig_data = read_rigged_glb(rig_glb)
335
+
336
+ print('[tpose_smpl] Loading original mesh for UV/texture ...')
337
+ scene = trimesh.load(body_glb)
338
+ if isinstance(scene, trimesh.Scene):
339
+ geom_name = list(scene.geometry.keys())[0]
340
+ orig_mesh = scene.geometry[geom_name]
341
+ else:
342
+ orig_mesh = scene; geom_name = None
343
+
344
+ orig_verts = np.array(orig_mesh.vertices, dtype=np.float64)
345
+ uvs = np.array(orig_mesh.visual.uv, dtype=np.float64)
346
+ orig_tex = orig_mesh.visual.material.baseColorTexture
347
+ print(' Orig mesh: %d verts Y: [%.3f, %.3f] X: [%.3f, %.3f]' % (
348
+ len(orig_verts),
349
+ orig_verts[:,1].min(), orig_verts[:,1].max(),
350
+ orig_verts[:,0].min(), orig_verts[:,0].max()))
351
+
352
+ print('[tpose_smpl] Computing FK transforms in rig world space ...')
353
+ body_pose_rotmats = hmr2_out['body_pose'] # (23, 3, 3)
354
+ A = compute_rig_fk_transforms(rig_data['J_bind'], body_pose_rotmats)
355
+
356
+ # Verify zero-pose gives identity (sanity check)
357
+ A_zero = compute_rig_fk_transforms(rig_data['J_bind'],
358
+ torch.zeros(23, 3, 3) + torch.eye(3))
359
+ v_test = rig_data['verts'][:3]
360
+ v_h = np.concatenate([v_test, np.ones((3,1))], axis=1)
361
+ T_fwd_test = np.zeros((3, 4, 4))
362
+ for k in range(4):
363
+ ji = rig_data['j_idx'][:3, k]; w = rig_data['w_arr'][:3, k]
364
+ T_fwd_test += w[:, None, None] * A_zero[ji]
365
+ identity_err = np.abs(T_fwd_test - np.eye(4)).max()
366
+ print(' zero-pose identity check: max_err=%.6f (expect ~0)' % identity_err)
367
+
368
+ print('[tpose_smpl] Applying inverse LBS ...')
369
+ rig_verts_tposed = inverse_lbs(
370
+ rig_data['verts'], rig_data['j_idx'], rig_data['w_arr'], A)
371
+
372
+ print('[tpose_smpl] T-posed rig verts: Y: [%.3f, %.3f] X: [%.3f, %.3f]' % (
373
+ rig_verts_tposed[:,1].min(), rig_verts_tposed[:,1].max(),
374
+ rig_verts_tposed[:,0].min(), rig_verts_tposed[:,0].max()))
375
+
376
+ print('[tpose_smpl] Mapping back to original mesh coordinate space ...')
377
+ tposed_orig = rig_to_original_space(
378
+ rig_verts_tposed, rig_data['verts'], orig_verts)
379
+
380
+ print('[tpose_smpl] T-posed orig: Y: [%.3f, %.3f] X: [%.3f, %.3f]' % (
381
+ tposed_orig[:,1].min(), tposed_orig[:,1].max(),
382
+ tposed_orig[:,0].min(), tposed_orig[:,0].max()))
383
+
384
+ orig_mesh.vertices = tposed_orig
385
+ orig_mesh.visual = TextureVisuals(uv=uvs,
386
+ material=PBRMaterial(baseColorTexture=orig_tex))
387
+
388
+ if geom_name and isinstance(scene, trimesh.Scene):
389
+ scene.geometry[geom_name] = orig_mesh
390
+ scene.export(out_glb)
391
+ else:
392
+ orig_mesh.export(out_glb)
393
+
394
+ print('[tpose_smpl] Saved: %s (%d KB)' % (out_glb, os.path.getsize(out_glb)//1024))
395
+
396
+ if skel_out:
397
+ print('[tpose_smpl] Exporting SKEL bone geometry ...')
398
+ export_skel_bones(hmr2_out['betas'], skel_out)
399
+
400
+ return out_glb
401
+
402
+
403
+ if __name__ == '__main__':
404
+ ap = argparse.ArgumentParser()
405
+ ap.add_argument('--body', required=True)
406
+ ap.add_argument('--out', required=True)
407
+ ap.add_argument('--rig', required=True, help='Rigged GLB from rig step')
408
+ ap.add_argument('--skel_out', default=None, help='SKEL BSM bone mesh output')
409
+ ap.add_argument('--debug_dir', default=None)
410
+ args = ap.parse_args()
411
+ os.makedirs(args.debug_dir, exist_ok=True) if args.debug_dir else None
412
+ tpose_smpl(args.body, args.out, rig_glb=args.rig,
413
+ debug_dir=args.debug_dir, skel_out=args.skel_out)
requirements.txt ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HuggingFace ZeroGPU Space — Docker SDK
2
+ # chumpy is pre-installed in the Dockerfile with --no-build-isolation
3
+ # (its setup.py does `import pip` which breaks in modern pip isolated builds)
4
+ spaces
5
+
6
+ # Git-pinned installs
7
+ hmr2 @ git+https://github.com/shubham-goel/4D-Humans.git@efe18deff163b29dff87ddbd575fa29b716a356c
8
+ clip @ git+https://github.com/openai/CLIP.git@d05afc436d78f1c48dc0dbf8e5980a9d471f35f6
9
+ mvadapter @ git+https://github.com/huanngzh/MV-Adapter.git@4277e0018232bac82bb2c103caf0893cedb711be
10
+ chumpy @ git+https://github.com/mattloper/chumpy.git@580566eafc9ac68b2614b64d6f7aaa84eebb70da
11
+ skel @ git+https://github.com/MarilynKeller/SKEL.git@c32cf16581295bff19399379efe5b776d707cd95
12
+ nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@253ac4fcea7de5f396371124af597e6cc957bfae
13
+ diso @ git+https://github.com/SarahWeiii/diso.git@9792ad928ccb09bdec938779651ee03e395758a6
14
+ detectron2 @ git+https://github.com/facebookresearch/detectron2.git@8a9d885b3d4dcf1bef015f0593b872ed8d32b4ab
15
+
16
+ # Core ML
17
+ accelerate
18
+ diffusers>=0.37.0
19
+ transformers>=5.0.0
20
+ safetensors
21
+ huggingface_hub
22
+ peft
23
+ einops
24
+ timm
25
+ xformers
26
+
27
+ # 3D / Mesh
28
+ trimesh
29
+ open3d
30
+ pymeshlab
31
+ pygltflib
32
+ pyrender
33
+ moderngl
34
+ moderngl-window
35
+
36
+ # Body model
37
+ smplx
38
+ smplpytorch
39
+
40
+ # Pose / Motion
41
+ ultralytics
42
+ pyquaternion
43
+ kornia
44
+
45
+ # Face enhancement
46
+ insightface
47
+ onnxruntime-gpu
48
+ basicsr
49
+ realesrgan
50
+ gfpgan
51
+ facexlib
52
+ face-alignment
53
+
54
+ # Surface enhancement
55
+ stablenormal
56
+ controlnet_aux
57
+
58
+ # CV
59
+ opencv-python-headless
60
+ scikit-image
61
+ albumentations
62
+
63
+ # Scientific
64
+ numpy
65
+ scipy
66
+ scikit-learn
67
+ pandas
68
+
69
+ # Utils
70
+ easydict
71
+ omegaconf
72
+ yacs
73
+ gdown
74
+ pycocotools