Daankular committed
Commit 8f1bcd9 · 1 Parent(s): 5d73995

Port MeshForge features to ZeroGPU Space: FireRed, PSHuman, Motion Search


New tabs and features from the MeshForge server version, adapted for ZeroGPU:
- Edit tab: FireRed GGUF-quantized image editing (QwenImageEditPlusPipeline)
- Animate tab: HumanML3D motion search + GLB animation via Retarget/
- PSHuman Face tab: HD face transplant via PSHuman service + face_transplant.py
- Settings tab: VRAM management (preload/unload/refresh)
- Generate tab: remove-BG preview controls + FireRed→Generate flow
- Enhancement tab: unload button for VRAM management

New pipeline modules: face_transplant, pshuman_client, render_glb, tpose,
face_inswap_bake, face_project, face_swap_render, head_replace
New Retarget/ directory: motion search, animate, skeleton, SMPL retargeting
New utils/pytorch3d_minimal.py
Updated requirements.txt: bitsandbytes, gradio_client, filterpy, pytorch-lightning,
lightning-utilities, webdataset, hydra-core, matplotlib

Retarget/README.md ADDED
@@ -0,0 +1,127 @@
+ # rig_retarget
+
+ Pure-Python rig retargeting library. No Blender required.
+
+ Based on **[KeeMap Blender Rig Retargeting Addon](https://github.com/nkeeline/Keemap-Blender-Rig-ReTargeting-Addon)** by [Nick Keeline](https://github.com/nkeeline) (GPL v2).
+ All core retargeting math is a direct port of his work. Mapping JSON files are fully compatible with KeeMap.
+
+ ---
+
+ ## File layout
+
+ ```
+ rig_retarget/
+ ├── math3d.py      # Quaternion / matrix math (numpy + scipy), replaces mathutils
+ ├── skeleton.py    # Armature + PoseBone with FK, replaces bpy armature objects
+ ├── retarget.py    # Core retargeting logic — faithful port of KeeMapBoneOperators.py
+ ├── cli.py         # CLI entry point
+ └── io/
+     ├── bvh.py     # BVH mocap reader (source animation)
+     ├── gltf_io.py # glTF/GLB reader + animation writer (UniRig destination)
+     └── mapping.py # JSON bone mapping — same format as KeeMap
+ ```
+
+ ---
+
+ ## Install
+
+ ```bash
+ pip install numpy scipy pygltflib
+ ```
+
+ ---
+
+ ## CLI
+
+ ```bash
+ # Retarget BVH onto UniRig GLB
+ python -m rig_retarget.cli \
+     --source motion.bvh \
+     --dest unirig_character.glb \
+     --mapping radical2unirig.json \
+     --output animated_character.glb \
+     --fps 30 --start 0 --frames 200 --step 1
+
+ # Auto-calculate bone correction factors and save back to the mapping file
+ python -m rig_retarget.cli --calc-corrections \
+     --source motion.bvh \
+     --dest unirig_character.glb \
+     --mapping mymap.json
+ ```
+
+ ---
+
+ ## Python API
+
+ ```python
+ from rig_retarget.io.bvh import load_bvh
+ from rig_retarget.io.gltf_io import load_gltf, write_gltf_animation
+ from rig_retarget.io.mapping import load_mapping
+ from rig_retarget.retarget import transfer_animation, calc_all_corrections
+
+ # Load
+ settings, bone_items = load_mapping("my_map.json")
+ src_anim = load_bvh("motion.bvh")
+ dst_arm = load_gltf("unirig_char.glb")
+
+ # Optional: auto-calc corrections at first frame
+ src_anim.apply_frame(0)
+ calc_all_corrections(bone_items, src_anim.armature, dst_arm, settings)
+
+ # Transfer
+ settings.number_of_frames_to_apply = src_anim.num_frames
+ keyframes = transfer_animation(src_anim, dst_arm, bone_items, settings)
+
+ # Write output GLB
+ write_gltf_animation("unirig_char.glb", dst_arm, keyframes, "output.glb")
+ ```
+
+ ---
+
+ ## Mapping JSON format
+
+ 100% compatible with KeeMap's `.json` files. Use KeeMap in Blender to create
+ and tune mappings, then use this library offline for batch processing.
+
+ Key fields per bone:
+
+ | Field | Description |
+ |---|---|
+ | `SourceBoneName` | Bone name in the source rig (BVH joint name) |
+ | `DestinationBoneName` | Bone name in the UniRig skeleton (glTF node name) |
+ | `set_bone_rotation` | Drive rotation from source |
+ | `set_bone_position` | Drive position from source |
+ | `bone_rotation_application_axis` | Mask axes: `X` `Y` `Z` `XY` `XZ` `YZ` `XYZ` |
+ | `bone_transpose_axis` | Swap axes: `NONE` `ZYX` `ZXY` `XZY` `YZX` `YXZ` |
+ | `CorrectionFactorX/Y/Z` | Euler correction (radians) |
+ | `postion_type` | `SINGLE_BONE_OFFSET` or `POLE` |
+ | `position_pole_distance` | IK pole distance |
+
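For orientation, a single bone entry using the fields above might look roughly like this. This is an illustrative sketch only — the values and the top-level wrapper structure are assumptions; export a mapping from KeeMap for the authoritative shape (note `postion_type` is spelled as in KeeMap):

```json
{
  "bones": [
    {
      "SourceBoneName": "LeftArm",
      "DestinationBoneName": "upperarm_l",
      "set_bone_rotation": true,
      "set_bone_position": false,
      "bone_rotation_application_axis": "XYZ",
      "bone_transpose_axis": "NONE",
      "CorrectionFactorX": 0.0,
      "CorrectionFactorY": 0.0,
      "CorrectionFactorZ": 1.5708,
      "postion_type": "SINGLE_BONE_OFFSET",
      "position_pole_distance": 0.0
    }
  ]
}
```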
+ ---
+
+ ## Blender → pure-Python mapping
+
+ | Blender | rig_retarget |
+ |---|---|
+ | `bpy.data.objects[name]` | `Armature` + `load_gltf()` / `load_bvh()` |
+ | `arm.pose.bones[name]` | `arm.get_bone(name)` → `PoseBone` |
+ | `bone.matrix` (pose space) | `bone.matrix_armature` |
+ | `arm.matrix_world` | `arm.world_matrix` |
+ | `arm.convert_space(...)` | `arm.world_matrix @ bone.matrix_armature` |
+ | `bone.rotation_quaternion` | `bone.pose_rotation_quat` |
+ | `bone.location` | `bone.pose_location` |
+ | `bone.keyframe_insert(...)` | returned in `keyframes` list from `transfer_frame()` |
+ | `bpy.context.scene.frame_set(i)` | `src_anim.apply_frame(i)` |
+ | `mathutils.Quaternion` | `np.ndarray [w,x,y,z]` + `math3d.*` |
+ | `mathutils.Matrix` | `np.ndarray (4,4)` |
+
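The scalar-first `[w, x, y, z]` layout in the table can be exercised with plain numpy. This sketch is not part of the library (`math3d`'s actual helper names may differ); `q_wxyz_mul` is a hypothetical stand-in for `mathutils.Quaternion` multiplication:

```python
import numpy as np

def q_wxyz_mul(a, b):
    # Hamilton product for scalar-first [w, x, y, z] quaternions,
    # the layout rig_retarget uses in place of mathutils.Quaternion.
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return np.array([
        aw*bw - ax*bx - ay*by - az*bz,
        aw*bx + ax*bw + ay*bz - az*by,
        aw*by - ax*bz + ay*bw + az*bx,
        aw*bz + ax*by - ay*bx + az*bw,
    ])

# Composing two 90° rotations about Z yields a 180° rotation about Z:
qz90 = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])
qz180 = q_wxyz_mul(qz90, qz90)   # → [0, 0, 0, 1] up to rounding
```

Beware that SciPy's `Rotation.as_quat()` returns the scalar-last `[x, y, z, w]` order, so conversions between the two layouts are needed at the boundaries (as `animate.py` does in `_mat_to_quat`).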
+ ---
+
+ ## Limitations / TODO
+
+ - **glTF source animation** reading is not yet implemented (BVH only for now).
+   Add an `io/gltf_anim_reader.py` that reads `gltf.animations[0]` sampler data.
+ - FBX source support: use `pyassimp`, or run `bpy` offline with `--background`.
+ - IK solving: pole bone positioning is FK-only; a full IK solver (FABRIK/CCD)
+   would improve accuracy for limb targets.
+ - Quaternion-mode twist bones: parity with Blender is not guaranteed for complex twist rigs.
Retarget/__init__.py ADDED
@@ -0,0 +1,49 @@
+ """
+ rig_retarget
+ ============
+ Pure-Python rig retargeting library.
+ No Blender dependency. Targets TripoSG meshes auto-rigged by UniRig (SIGGRAPH 2025).
+
+ Quick start
+ -----------
+     from rig_retarget.io.bvh import load_bvh
+     from rig_retarget.io.gltf_io import load_gltf, write_gltf_animation
+     from rig_retarget.io.mapping import load_mapping
+     from rig_retarget.retarget import transfer_animation
+
+     settings, bone_items = load_mapping("my_map.json")
+     src_anim = load_bvh("motion.bvh")
+     dst_arm = load_gltf("unirig_char.glb")
+
+     keyframes = transfer_animation(src_anim, dst_arm, bone_items, settings)
+     write_gltf_animation("unirig_char.glb", dst_arm, keyframes, "output.glb")
+
+ CLI
+ ---
+     python -m rig_retarget.cli --source motion.bvh --dest char.glb \\
+         --mapping map.json --output char_animated.glb
+ """
+
+ from .skeleton import Armature, PoseBone
+ from .retarget import (
+     get_bone_position_ws,
+     get_bone_ws_quat,
+     set_bone_position_ws,
+     set_bone_rotation,
+     set_bone_position,
+     set_bone_position_pole,
+     set_bone_scale,
+     calc_rotation_offset,
+     calc_location_offset,
+     calc_all_corrections,
+     transfer_frame,
+     transfer_animation,
+ )
+
+ __all__ = [
+     "Armature", "PoseBone",
+     "get_bone_position_ws", "get_bone_ws_quat", "set_bone_position_ws",
+     "set_bone_rotation", "set_bone_position", "set_bone_position_pole",
+     "set_bone_scale", "calc_rotation_offset", "calc_location_offset",
+     "calc_all_corrections", "transfer_frame", "transfer_animation",
+ ]
Retarget/animate.py ADDED
@@ -0,0 +1,611 @@
+ """
+ animate.py
+ ──────────────────────────────────────────────────────────────────────────────
+ Bake SMPL motion (from HumanML3D [T, 263] features) onto a UniRig-rigged GLB.
+
+ Retargeting method: world-direction matching
+ ────────────────────────────────────────────
+ Commercial retargeters (Mixamo, Rokoko, MotionBuilder) avoid rest-pose
+ convention mismatches by matching WORLD BONE DIRECTIONS, not local rotations.
+
+ Algorithm (per frame, per bone):
+   1. Run t2m FK with HumanML3D 6D rotations → world bone direction d_t2m
+   2. Flip X axis: t2m +X = character's LEFT; SMPL/UniRig +X = character's RIGHT
+      So d_desired = (-d_t2m_x, d_t2m_y, d_t2m_z) in SMPL/UniRig world frame
+   3. d_rest = normalize(ur_pos[bone] - ur_pos[parent]) from GLB inverse bind matrices
+   4. R_world = R_between(d_rest, d_desired) -- minimal rotation in world space
+   5. local_rot = inv(R_world[parent]) @ R_world[bone]
+   6. pose_rot_delta = inv(rest_r) @ local_rot -- composing with glTF rest rotation
+
+ This avoids all rest-pose convention issues:
+   - t2m canonical arms point DOWN: handled automatically
+   - t2m canonical hips/shoulders have inverted X: handled by the X-flip
+   - UniRig non-identity rest rotations: handled by inv(rest_r) composition
+
+ Key bugs fixed vs previous version:
+   - IBM column-major: glTF IBMs are column-major; was using inv(ibm)[:3,3] (zeros).
+     Fixed to inv(ibm.T)[:3,3] which gives correct world-space bone positions.
+   - Normalisation: was mixing ur/smpl Y ranges, causing wrong height alignment.
+     Fixed with independent per-skeleton Y normalisation.
+   - Rotation convention: was applying t2m rotations directly without X-flip.
+     Fixed by world-direction matching with coordinate-frame conversion.
+ """
+ from __future__ import annotations
+ import os
+ import re
+ import numpy as np
+ from typing import Union
+
+ from .smpl import SMPLMotion, hml3d_to_smpl_motion
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # T2M (HumanML3D) skeleton constants
+ # Source: HumanML3D/common/paramUtil.py
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ T2M_RAW_OFFSETS = np.array([
+     [ 0, 0, 0],   # 0  Hips (root)
+     [ 1, 0, 0],   # 1  LeftUpLeg    +X = character LEFT in t2m convention
+     [-1, 0, 0],   # 2  RightUpLeg
+     [ 0, 1, 0],   # 3  Spine
+     [ 0,-1, 0],   # 4  LeftLeg
+     [ 0,-1, 0],   # 5  RightLeg
+     [ 0, 1, 0],   # 6  Spine1
+     [ 0,-1, 0],   # 7  LeftFoot
+     [ 0,-1, 0],   # 8  RightFoot
+     [ 0, 1, 0],   # 9  Spine2
+     [ 0, 0, 1],   # 10 LeftToeBase
+     [ 0, 0, 1],   # 11 RightToeBase
+     [ 0, 1, 0],   # 12 Neck
+     [ 1, 0, 0],   # 13 LeftShoulder  +X = character LEFT
+     [-1, 0, 0],   # 14 RightShoulder
+     [ 0, 0, 1],   # 15 Head
+     [ 0,-1, 0],   # 16 LeftArm       arms hang DOWN in t2m canonical
+     [ 0,-1, 0],   # 17 RightArm
+     [ 0,-1, 0],   # 18 LeftForeArm
+     [ 0,-1, 0],   # 19 RightForeArm
+     [ 0,-1, 0],   # 20 LeftHand
+     [ 0,-1, 0],   # 21 RightHand
+ ], dtype=np.float64)
+
+ T2M_KINEMATIC_CHAIN = [
+     [0, 2, 5, 8, 11],      # Hips -> RightUpLeg -> RightLeg -> RightFoot -> RightToe
+     [0, 1, 4, 7, 10],      # Hips -> LeftUpLeg -> LeftLeg -> LeftFoot -> LeftToe
+     [0, 3, 6, 9, 12, 15],  # Hips -> Spine -> Spine1 -> Spine2 -> Neck -> Head
+     [9, 14, 17, 19, 21],   # Spine2 -> RightShoulder -> RightArm -> RightForeArm -> RightHand
+     [9, 13, 16, 18, 20],   # Spine2 -> LeftShoulder -> LeftArm -> LeftForeArm -> LeftHand
+ ]
+
+ # Parent joint index for each of the 22 t2m joints
+ T2M_PARENTS = [-1] * 22
+ for _chain in T2M_KINEMATIC_CHAIN:
+     for _k in range(1, len(_chain)):
+         T2M_PARENTS[_chain[_k]] = _chain[_k - 1]
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # SMPL joint names / T-pose (for bone mapping reference)
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ SMPL_NAMES = [
+     "Hips", "LeftUpLeg", "RightUpLeg", "Spine",
+     "LeftLeg", "RightLeg", "Spine1", "LeftFoot",
+     "RightFoot", "Spine2", "LeftToeBase", "RightToeBase",
+     "Neck", "LeftShoulder", "RightShoulder", "Head",
+     "LeftArm", "RightArm", "LeftForeArm", "RightForeArm",
+     "LeftHand", "RightHand",
+ ]
+
+ # Approximate T-pose joint world positions in metres (Y-up, facing +Z)
+ # +X = character's RIGHT (standard SMPL/UniRig convention)
+ SMPL_TPOSE = np.array([
+     [ 0.000, 0.920,  0.000],  # 0  Hips
+     [-0.095, 0.920,  0.000],  # 1  LeftUpLeg (character's left = -X)
+     [ 0.095, 0.920,  0.000],  # 2  RightUpLeg
+     [ 0.000, 0.980,  0.000],  # 3  Spine
+     [-0.095, 0.495,  0.000],  # 4  LeftLeg
+     [ 0.095, 0.495,  0.000],  # 5  RightLeg
+     [ 0.000, 1.050,  0.000],  # 6  Spine1
+     [-0.095, 0.075,  0.000],  # 7  LeftFoot
+     [ 0.095, 0.075,  0.000],  # 8  RightFoot
+     [ 0.000, 1.120,  0.000],  # 9  Spine2
+     [-0.095, 0.000, -0.020],  # 10 LeftToeBase
+     [ 0.095, 0.000, -0.020],  # 11 RightToeBase
+     [ 0.000, 1.370,  0.000],  # 12 Neck
+     [-0.130, 1.290,  0.000],  # 13 LeftShoulder
+     [ 0.130, 1.290,  0.000],  # 14 RightShoulder
+     [ 0.000, 1.500,  0.000],  # 15 Head
+     [-0.330, 1.290,  0.000],  # 16 LeftArm
+     [ 0.330, 1.290,  0.000],  # 17 RightArm
+     [-0.630, 1.290,  0.000],  # 18 LeftForeArm
+     [ 0.630, 1.290,  0.000],  # 19 RightForeArm
+     [-0.910, 1.290,  0.000],  # 20 LeftHand
+     [ 0.910, 1.290,  0.000],  # 21 RightHand
+ ], dtype=np.float32)
+
+ # Name hint table: lowercase substrings -> SMPL joint index
+ _NAME_HINTS: list[tuple[list[str], int]] = [
+     (["hips","pelvis","root"], 0),
+     (["leftupleg","l_upleg","leftthigh","lefthip","thigh_l"], 1),
+     (["rightupleg","r_upleg","rightthigh","righthip","thigh_r"], 2),
+     (["spine","spine0","spine_01"], 3),
+     (["leftleg","leftknee","lowerleg_l","knee_l"], 4),
+     (["rightleg","rightknee","lowerleg_r","knee_r"], 5),
+     (["spine1","spine_02"], 6),
+     (["leftfoot","l_foot","foot_l"], 7),
+     (["rightfoot","r_foot","foot_r"], 8),
+     (["spine2","spine_03","chest"], 9),
+     (["lefttoebase","lefttoe","l_toe","toe_l"], 10),
+     (["righttoebase","righttoe","r_toe","toe_r"], 11),
+     (["neck"], 12),
+     (["leftshoulder","leftcollar","clavicle_l"], 13),
+     (["rightshoulder","rightcollar","clavicle_r"], 14),
+     (["head"], 15),
+     (["leftarm","upperarm_l","l_arm"], 16),
+     (["rightarm","upperarm_r","r_arm"], 17),
+     (["leftforearm","lowerarm_l","l_forearm"], 18),
+     (["rightforearm","lowerarm_r","r_forearm"], 19),
+     (["lefthand","hand_l","l_hand"], 20),
+     (["righthand","hand_r","r_hand"], 21),
+ ]
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Quaternion helpers (scalar-first WXYZ convention throughout)
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ _ID_QUAT = np.array([1., 0., 0., 0.], dtype=np.float32)
+ _ID_MAT3 = np.eye(3, dtype=np.float64)
+
+ def _qmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+     aw, ax, ay, az = a
+     bw, bx, by, bz = b
+     return np.array([
+         aw*bw - ax*bx - ay*by - az*bz,
+         aw*bx + ax*bw + ay*bz - az*by,
+         aw*by - ax*bz + ay*bw + az*bx,
+         aw*bz + ax*by - ay*bx + az*bw,
+     ], dtype=np.float32)
+
+ def _qnorm(q: np.ndarray) -> np.ndarray:
+     n = np.linalg.norm(q)
+     return (q / n) if n > 1e-12 else _ID_QUAT.copy()
+
+ def _qinv(q: np.ndarray) -> np.ndarray:
+     """Conjugate = inverse for unit quaternion."""
+     return q * np.array([1., -1., -1., -1.], dtype=np.float32)
+
+ def _quat_to_mat(q: np.ndarray) -> np.ndarray:
+     """WXYZ quaternion -> 3x3 rotation matrix (float64)."""
+     w, x, y, z = q.astype(np.float64)
+     return np.array([
+         [1-2*(y*y+z*z),   2*(x*y-w*z),   2*(x*z+w*y)],
+         [  2*(x*y+w*z), 1-2*(x*x+z*z),   2*(y*z-w*x)],
+         [  2*(x*z-w*y),   2*(y*z+w*x), 1-2*(x*x+y*y)],
+     ], dtype=np.float64)
+
+ def _mat_to_quat(m: np.ndarray) -> np.ndarray:
+     """3x3 rotation matrix -> WXYZ quaternion (float32, positive-W)."""
+     from scipy.spatial.transform import Rotation
+     xyzw = Rotation.from_matrix(m.astype(np.float64)).as_quat()
+     wxyz = np.array([xyzw[3], xyzw[0], xyzw[1], xyzw[2]], dtype=np.float32)
+     if wxyz[0] < 0:
+         wxyz = -wxyz
+     return wxyz
+
+ def _r_between(u: np.ndarray, v: np.ndarray) -> np.ndarray:
+     """
+     Minimal rotation matrix (3x3) that maps unit vector u to unit vector v.
+     Uses the Rodrigues formula; handles parallel/antiparallel cases.
+     """
+     u = u / (np.linalg.norm(u) + 1e-12)
+     v = v / (np.linalg.norm(v) + 1e-12)
+     c = float(np.dot(u, v))
+     if c >= 1.0 - 1e-7:
+         return _ID_MAT3.copy()
+     if c <= -1.0 + 1e-7:
+         # 180 degree rotation: pick any perpendicular axis
+         perp = np.array([1., 0., 0.]) if abs(u[0]) < 0.9 else np.array([0., 1., 0.])
+         ax = np.cross(u, perp)
+         ax /= np.linalg.norm(ax)
+         return 2.0 * np.outer(ax, ax) - _ID_MAT3
+     ax = np.cross(u, v)  # sin(theta) * rotation axis
+     s = np.linalg.norm(ax)
+     K = np.array([[     0, -ax[2],  ax[1]],
+                   [ ax[2],      0, -ax[0]],
+                   [-ax[1],  ax[0],      0]], dtype=np.float64)
+     return _ID_MAT3 + K + K @ K * ((1.0 - c) / (s * s + 1e-12))
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # GLB skin reader
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def _read_glb_skin(rigged_glb: str):
+     """
+     Return (gltf, skin, ibm[n,4,4], node_trs{name->(t,r_wxyz,s)},
+     bone_names[], bone_parent_map{name->parent_name_or_None}).
+
+     ibm is stored as-read from the binary blob (column-major from glTF spec).
+     Callers must use inv(ibm[i].T)[:3,3] to get correct world positions.
+     """
+     import base64
+     import pygltflib
+
+     gltf = pygltflib.GLTF2().load(rigged_glb)
+     if not gltf.skins:
+         raise ValueError(f"No skin found in {rigged_glb}")
+     skin = gltf.skins[0]
+
+     def _raw_bytes(buf):
+         if buf.uri is None:
+             return bytes(gltf.binary_blob())
+         if buf.uri.startswith("data:"):
+             return base64.b64decode(buf.uri.split(",", 1)[1])
+         from pathlib import Path
+         return (Path(rigged_glb).parent / buf.uri).read_bytes()
+
+     acc = gltf.accessors[skin.inverseBindMatrices]
+     bv = gltf.bufferViews[acc.bufferView]
+     raw = _raw_bytes(gltf.buffers[bv.buffer])
+     start = (bv.byteOffset or 0) + (acc.byteOffset or 0)
+     n = acc.count
+     ibm = np.frombuffer(raw[start: start + n * 64], dtype=np.float32).reshape(n, 4, 4)
+
+     # Build node parent map (node_index -> parent_node_index)
+     node_parent: dict[int, int] = {}
+     for ni, node in enumerate(gltf.nodes):
+         for child_idx in (node.children or []):
+             node_parent[child_idx] = ni
+
+     joint_set = set(skin.joints)
+     bone_names = []
+     node_trs: dict[str, tuple] = {}
+     bone_parent_map: dict[str, str | None] = {}
+
+     for i, j_idx in enumerate(skin.joints):
+         node = gltf.nodes[j_idx]
+         name = node.name or f"bone_{i}"
+         bone_names.append(name)
+
+         t = np.array(node.translation or [0., 0., 0.], dtype=np.float32)
+         r_xyzw = np.array(node.rotation or [0., 0., 0., 1.], dtype=np.float32)
+         s = np.array(node.scale or [1., 1., 1.], dtype=np.float32)
+         r_wxyz = np.array([r_xyzw[3], r_xyzw[0], r_xyzw[1], r_xyzw[2]], dtype=np.float32)
+         node_trs[name] = (t, r_wxyz, s)
+
+         # Find parent bone (walk up node hierarchy to nearest joint)
+         parent_node = node_parent.get(j_idx)
+         parent_name: str | None = None
+         while parent_node is not None:
+             if parent_node in joint_set:
+                 pnode = gltf.nodes[parent_node]
+                 parent_name = pnode.name or f"bone_{skin.joints.index(parent_node)}"
+                 break
+             parent_node = node_parent.get(parent_node)
+         bone_parent_map[name] = parent_name
+
+     print(f"[GLB] {len(bone_names)} bones from skin '{skin.name or 'Armature'}'")
+     return gltf, skin, ibm, node_trs, bone_names, bone_parent_map
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Bone mapping
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def _strip_name(name: str) -> str:
+     name = re.sub(r'^(mixamorig:|j_bip_[lcr]_|cc_base_|bip01_|rig:|chr:)',
+                   "", name, flags=re.IGNORECASE)
+     return re.sub(r'[_\-\s.]', "", name).lower()
+
+
+ def build_bone_map(
+     rigged_glb: str,
+     verbose: bool = True,
+ ) -> tuple[dict, dict, float, dict, dict]:
+     """
+     Map UniRig bone names -> SMPL joint index by spatial proximity + name hints.
+
+     Returns
+     -------
+     bone_to_smpl    : {bone_name: smpl_joint_index}
+     node_trs        : {bone_name: (t[3], r_wxyz[4], s[3])}
+     height_scale    : float (UniRig height / SMPL reference height)
+     bone_parent_map : {bone_name: parent_bone_name_or_None}
+     ur_pos_by_name  : {bone_name: world_pos[3]}
+     """
+     _gltf, _skin, ibm, node_trs, bone_names, bone_parent_map = _read_glb_skin(rigged_glb)
+
+     # FIX: glTF IBMs are stored column-major.
+     # numpy reads as row-major, so the stored data is the TRANSPOSE of the actual matrix.
+     # Correct world position = inv(actual_IBM)[:3,3] = inv(ibm[i].T)[:3,3]
+     ur_pos = np.array([
+         np.linalg.inv(ibm[i].T)[:3, 3] for i in range(len(bone_names))
+     ], dtype=np.float32)
+
+     ur_pos_by_name = {name: ur_pos[i] for i, name in enumerate(bone_names)}
+
+     # Scale SMPL T-pose to match character height
+     ur_h = ur_pos[:, 1].max() - ur_pos[:, 1].min()
+     sm_h = SMPL_TPOSE[:, 1].max() - SMPL_TPOSE[:, 1].min()
+     h_sc = (ur_h / sm_h) if sm_h > 1e-6 else 1.0
+     sm_pos = SMPL_TPOSE * h_sc
+
+     # FIX: Normalise ur and smpl Y ranges independently (floor=0, top=1 for each).
+     # The old code used a shared reference which caused floor offsets to misalign.
+     def _norm_independent(pos, own_range_min, own_range_max, x_range, z_range):
+         p = pos.copy().astype(np.float64)
+         y_range = (own_range_max - own_range_min) or 1.0
+         p[:, 0] /= (x_range or 1.0)
+         p[:, 1] = (p[:, 1] - own_range_min) / y_range
+         p[:, 2] /= (z_range or 1.0)
+         return p
+
+     # Common X/Z scale (use both skeletons' width for reference)
+     x_range = max(
+         abs(ur_pos[:, 0].max() - ur_pos[:, 0].min()),
+         abs(sm_pos[:, 0].max() - sm_pos[:, 0].min()),
+     ) or 1.0
+     z_range = max(
+         abs(ur_pos[:, 2].max() - ur_pos[:, 2].min()),
+         abs(sm_pos[:, 2].max() - sm_pos[:, 2].min()),
+     ) or 1.0
+
+     ur_n = _norm_independent(ur_pos, ur_pos[:, 1].min(), ur_pos[:, 1].max(), x_range, z_range)
+     sm_n = _norm_independent(sm_pos, sm_pos[:, 1].min(), sm_pos[:, 1].max(), x_range, z_range)
+
+     dist = np.linalg.norm(ur_n[:, None] - sm_n[None], axis=-1)  # [M, 22]
+     d_sc = 1.0 - np.clip(dist / (dist.max() + 1e-9), 0, 1)
+
+     # Name hint score
+     n_sc = np.zeros((len(bone_names), 22), dtype=np.float32)
+     for mi, bname in enumerate(bone_names):
+         stripped = _strip_name(bname)
+         for kws, ji in _NAME_HINTS:
+             if any(kw in stripped for kw in kws):
+                 n_sc[mi, ji] = 1.0
+
+     combined = 0.6 * d_sc + 0.4 * n_sc  # [M, 22]
+
+     # Greedy assignment
+     THRESHOLD = 0.35
+     pairs = sorted(
+         ((mi, ji, combined[mi, ji])
+          for mi in range(len(bone_names))
+          for ji in range(22)),
+         key=lambda x: -x[2],
+     )
+     bone_to_smpl: dict[str, int] = {}
+     taken: set[int] = set()
+     for mi, ji, score in pairs:
+         if score < THRESHOLD:
+             break
+         bname = bone_names[mi]
+         if bname in bone_to_smpl or ji in taken:
+             continue
+         bone_to_smpl[bname] = ji
+         taken.add(ji)
+
+     if verbose:
+         n_mapped = len(bone_to_smpl)
+         print(f"\n[MAP] {n_mapped}/{len(bone_names)} bones mapped to SMPL joints:")
+         for bname, ji in sorted(bone_to_smpl.items(), key=lambda x: x[1]):
+             print(f"  {bname:<40} -> {SMPL_NAMES[ji]}")
+         unmapped = [n for n in bone_names if n not in bone_to_smpl]
+         if unmapped:
+             preview = ", ".join(unmapped[:8])
+             print(f"[MAP] {len(unmapped)} unmapped (identity): {preview}"
+                   + (" ..." if len(unmapped) > 8 else ""))
+         print()
+
+     return bone_to_smpl, node_trs, h_sc, bone_parent_map, ur_pos_by_name
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # T2M forward kinematics (world rotation matrices)
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def _compute_t2m_world_rots(
+     root_rot_wxyz: np.ndarray,    # [4] WXYZ
+     local_rots_wxyz: np.ndarray,  # [21, 4] WXYZ (joints 1-21)
+ ) -> np.ndarray:
+     """
+     Compute accumulated world rotation matrices for all 22 t2m joints at one frame.
+     Matches skeleton.py's forward_kinematics_cont6d_np: each chain RESETS to R_root.
+
+     Returns [22, 3, 3] world rotation matrices.
+     """
+     R_root = _quat_to_mat(root_rot_wxyz)
+     world_rots = np.zeros((22, 3, 3), dtype=np.float64)
+     world_rots[0] = R_root
+
+     for chain in T2M_KINEMATIC_CHAIN:
+         R = R_root.copy()  # always start from R_root (matches skeleton.py)
+         for i in range(1, len(chain)):
+             j = chain[i]
+             R_local = _quat_to_mat(local_rots_wxyz[j - 1])  # j-1: joints 1-21
+             R = R @ R_local
+             world_rots[j] = R
+
+     return world_rots
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Keyframe builder — world-direction matching
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def build_keyframes(
+     motion: SMPLMotion,
+     bone_to_smpl: dict[str, int],
+     node_trs: dict[str, tuple],
+     height_scale: float,
+     bone_parent_map: dict[str, str | None],
+     ur_pos_by_name: dict[str, np.ndarray],
+ ) -> list[dict]:
+     """
+     Convert SMPLMotion -> List[Dict[bone_name -> (loc, rot_delta, scale)]]
+     using world-direction matching retargeting.
+     """
+     T = motion.num_frames
+     zeros3 = np.zeros(3, dtype=np.float32)
+     ones3 = np.ones(3, dtype=np.float32)
+
+     # Topological order: root joints (si==0) first, then by SMPL joint index
+     # (parents always have lower SMPL indices in the kinematic chain)
+     sorted_bones = sorted(bone_to_smpl.keys(), key=lambda b: bone_to_smpl[b])
+
+     keyframes: list[dict] = []
+
+     for ti in range(T):
+         frame: dict = {}
+
+         # T2M world rotation matrices for this frame
+         world_rots_t2m = _compute_t2m_world_rots(
+             motion.root_rot[ti].astype(np.float64),
+             motion.local_rot[ti].astype(np.float64),
+         )
+
+         # Track UniRig world rotations per bone (needed for child local rotations)
+         world_rot_ur: dict[str, np.ndarray] = {}
+
+         for bname in sorted_bones:
+             si = bone_to_smpl[bname]
+             rest_t, rest_r, _rest_s = node_trs[bname]
+             rest_t = rest_t.astype(np.float32)
+             rest_r_mat = _quat_to_mat(rest_r)
+
+             # ── Root bone (si == 0): drive world translation + facing rotation ──
+             if si == 0:
+                 world_pos = motion.root_pos[ti].astype(np.float64) * height_scale
+                 pose_loc = (world_pos - rest_t.astype(np.float64)).astype(np.float32)
+
+                 # Root world rotation = t2m root rotation (Y-axis only)
+                 R_world_root = _quat_to_mat(motion.root_rot[ti])
+                 world_rot_ur[bname] = R_world_root
+
+                 # pose_rot_delta = inv(rest_r) @ target_world_rot
+                 pose_rot_mat = rest_r_mat.T @ R_world_root
+                 pose_rot = _mat_to_quat(pose_rot_mat)
+                 frame[bname] = (pose_loc, pose_rot, ones3)
+                 continue
+
+             # ── Non-root bone: world-direction matching ──────────────────────
+
+             # T2M world bone direction (in t2m coordinate frame)
+             raw_dir_t2m = world_rots_t2m[si] @ T2M_RAW_OFFSETS[si]  # [3]
+
+             # COORDINATE FRAME CONVERSION: t2m +X = character LEFT; SMPL +X = character RIGHT
+             # Flip X to convert t2m world directions -> SMPL/UniRig world directions
+             d_desired = np.array([-raw_dir_t2m[0], raw_dir_t2m[1], raw_dir_t2m[2]])
+             d_desired_norm = d_desired / (np.linalg.norm(d_desired) + 1e-12)
+
+             # UniRig rest bone direction (from inverse bind matrices, world space)
+             parent_b = bone_parent_map.get(bname)
+             if parent_b and parent_b in ur_pos_by_name:
+                 d_rest = (ur_pos_by_name[bname] - ur_pos_by_name[parent_b]).astype(np.float64)
+             else:
+                 d_rest = ur_pos_by_name[bname].astype(np.float64)
+             d_rest_norm = d_rest / (np.linalg.norm(d_rest) + 1e-12)
+
+             # Minimal world-space rotation: rest direction -> desired direction
+             R_world_desired = _r_between(d_rest_norm, d_desired_norm)  # [3, 3]
+             world_rot_ur[bname] = R_world_desired
+
+             # Local rotation = inv(parent_world) @ R_world_desired
+             if parent_b and parent_b in world_rot_ur:
+                 R_parent = world_rot_ur[parent_b]
+             else:
+                 R_parent = _ID_MAT3
+
+             local_rot_mat = R_parent.T @ R_world_desired  # R_parent^-1 @ R_world
+
+             # pose_rot_delta = inv(rest_r) @ local_rot
+             # (glTF applies: final = rest_r @ pose_rot_delta = local_rot)
+             pose_rot_mat = rest_r_mat.T @ local_rot_mat
+             pose_rot = _mat_to_quat(pose_rot_mat)
+
+             frame[bname] = (zeros3, pose_rot, ones3)
+
+         keyframes.append(frame)
+
+     return keyframes
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
535
+ # Public API
536
+ # ��─────────────────────────────────────────────────────────────────────────────
537
+
538
+ def animate_glb(
539
+ motion: Union[np.ndarray, list, SMPLMotion],
540
+ rigged_glb: str,
541
+ output_glb: str,
542
+ fps: float = 20.0,
543
+ start_frame: int = 0,
544
+ num_frames: int = -1,
545
+ ) -> str:
546
+ """
547
+ Bake a HumanML3D motion clip onto a UniRig-rigged GLB.
548
+
549
+ Parameters
550
+ ----------
551
+ motion : [T, 263] ndarray, list, or pre-parsed SMPLMotion
552
+ rigged_glb : path to UniRig merge output (.glb with a skin)
553
+ output_glb : destination path for animated GLB
554
+ fps : frame rate embedded in the animation track
555
+ start_frame / num_frames : optional clip range (-1 = all frames)
556
+
557
+ Returns str absolute path to output_glb.
558
+ """
559
+ from .io.gltf_io import write_gltf_animation
560
+
561
+ # 1. Parse motion
562
+ if isinstance(motion, SMPLMotion):
563
+ smpl = motion
564
+ else:
565
+ data = np.asarray(motion, dtype=np.float32)
566
+ if data.ndim != 2 or data.shape[1] < 193:
567
+ raise ValueError(f"Expected [T, 263] HumanML3D features, got {data.shape}")
568
+ smpl = hml3d_to_smpl_motion(data, fps=fps)
569
+
570
+ # 2. Slice
571
+ end = (start_frame + num_frames) if num_frames > 0 else smpl.num_frames
572
+ smpl = smpl.slice(start_frame, end)
573
+ print(f"[animate] {smpl.num_frames} frames @ {fps:.0f} fps -> {output_glb}")
574
+
575
+ # 3. Build bone map (now returns parent map and world positions too)
576
+ bone_to_smpl, node_trs, h_sc, bone_parent_map, ur_pos_by_name = \
577
+ build_bone_map(rigged_glb, verbose=True)
578
+ if not bone_to_smpl:
579
+ raise RuntimeError(
580
+ "build_bone_map returned 0 matches. "
581
+ "Ensure the GLB has a valid skin with readable inverse bind matrices."
582
+ )
583
+
584
+ # 4. Build keyframes using world-direction matching
585
+ keyframes = build_keyframes(smpl, bone_to_smpl, node_trs, h_sc,
586
+ bone_parent_map, ur_pos_by_name)
587
+
588
+ # 5. Write GLB
589
+ out_dir = os.path.dirname(os.path.abspath(output_glb))
590
+ if out_dir:
591
+ os.makedirs(out_dir, exist_ok=True)
592
+
593
+ write_gltf_animation(
594
+ source_filepath=rigged_glb,
595
+ dest_armature=None,
596
+ keyframes=keyframes,
597
+ output_filepath=output_glb,
598
+ fps=float(fps),
599
+ )
600
+
601
+ return output_glb
602
+
603
+
604
+ # Backwards-compatibility alias
605
+ def animate_glb_from_hml3d(
606
+ motion, rigged_glb, output_glb, fps=20, start_frame=0, num_frames=-1
607
+ ):
608
+ return animate_glb(
609
+ motion, rigged_glb, output_glb,
610
+ fps=fps, start_frame=start_frame, num_frames=num_frames,
611
+ )
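The rotation chain used in `build_keyframes` above (world-space swing from rest direction to desired direction, converted to a local rotation, then to a pose delta relative to the rest rotation) can be exercised in isolation. A minimal numpy sketch, where `r_between` is a stand-in for the module's `_r_between` helper (Rodrigues' formula), not the actual implementation:

```python
import numpy as np

def r_between(a, b):
    """Minimal rotation matrix taking unit vector a onto unit vector b (Rodrigues)."""
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    v = np.cross(a, b)
    c = float(np.dot(a, b))
    if c < -0.999999:  # antiparallel: rotate 180 degrees about any perpendicular axis
        axis = np.cross(a, [1.0, 0.0, 0.0])
        if np.linalg.norm(axis) < 1e-6:
            axis = np.cross(a, [0.0, 1.0, 0.0])
        axis /= np.linalg.norm(axis)
        return 2.0 * np.outer(axis, axis) - np.eye(3)
    K = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    return np.eye(3) + K + K @ K / (1.0 + c)

# World-space swing: rest bone direction -> desired bone direction
d_rest = np.array([0.0, 1.0, 0.0])
d_desired = np.array([1.0, 0.0, 0.0])
R_world = r_between(d_rest, d_desired)

# Local rotation = inv(parent_world) @ R_world; pose delta = inv(rest_r) @ local
R_parent = np.eye(3)  # toy case: root bone, identity parent
rest_r = np.eye(3)    # toy case: identity rest rotation
local_rot = R_parent.T @ R_world
pose_rot = rest_r.T @ local_rot

# glTF applies final = rest_r @ pose_rot, which must reproduce local_rot
assert np.allclose(rest_r @ pose_rot, local_rot)
assert np.allclose(R_world @ d_rest, d_desired, atol=1e-6)
```

With identity parent and rest rotations the delta collapses to `R_world` itself; in the real code both matrices come from the GLB node hierarchy, so the two transposed multiplications do the actual work.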
Retarget/cli.py ADDED
@@ -0,0 +1,129 @@
1
+ """
2
+ cli.py
3
+ Command-line interface for rig_retarget.
4
+
5
+ Usage:
6
+ python -m rig_retarget.cli \\
7
+ --source walk.bvh \\
8
+ --dest unirig_character.glb \\
9
+ --mapping radical2unirig.json \\
10
+ --output animated_character.glb \\
11
+ [--fps 30] [--start 0] [--frames 100] [--step 1]
12
+
13
+ # Calculate corrections only (no transfer):
14
+ python -m rig_retarget.cli --calc-corrections \\
15
+ --source walk.bvh --dest unirig_character.glb \\
16
+ --mapping mymap.json
17
+ """
18
+ from __future__ import annotations
19
+ import argparse
20
+ import sys
21
+ from pathlib import Path
22
+
23
+
24
+ def _parse_args(argv=None):
25
+ p = argparse.ArgumentParser(
26
+ prog="rig_retarget",
27
+ description="Retarget animation from BVH/glTF source onto UniRig/glTF destination.",
28
+ )
29
+ p.add_argument("--source", required=True, help="Source animation file (.bvh or .glb/.gltf)")
30
+ p.add_argument("--dest", required=True, help="Destination skeleton file (.glb/.gltf, UniRig output)")
31
+ p.add_argument("--mapping", required=True, help="KeeMap-compatible JSON bone mapping file")
32
+ p.add_argument("--output", default=None, help="Output animated .glb (default: dest_retargeted.glb)")
33
+ p.add_argument("--fps", type=float, default=30.0)
34
+ p.add_argument("--start", type=int, default=0, help="Start frame index (0-based)")
35
+ p.add_argument("--frames", type=int, default=None, help="Number of frames to transfer (default: all)")
36
+ p.add_argument("--step", type=int, default=1, help="Keyframe every N source frames")
37
+ p.add_argument("--skin", type=int, default=0, help="Skin index in destination glTF")
38
+ p.add_argument("--calc-corrections", action="store_true",
39
+ help="Auto-calculate bone corrections and update the mapping JSON, then exit.")
40
+ p.add_argument("--verbose", action="store_true")
41
+ return p.parse_args(argv)
42
+
43
+
44
+ def main(argv=None) -> None:
45
+ args = _parse_args(argv)
46
+
47
+ from .io.mapping import load_mapping, save_mapping, KeeMapSettings
48
+ from .io.gltf_io import load_gltf, write_gltf_animation
49
+ from .retarget import (
50
+ calc_all_corrections, transfer_animation,
51
+ )
52
+
53
+ # -----------------------------------------------------------------------
54
+ # Load mapping
55
+ # -----------------------------------------------------------------------
56
+ print(f"[*] Loading mapping : {args.mapping}")
57
+ settings, bone_items = load_mapping(args.mapping)
58
+
59
+ # Override settings from CLI args
60
+ settings.start_frame_to_apply = args.start
61
+ settings.keyframe_every_n_frames = args.step
62
+
63
+ # -----------------------------------------------------------------------
64
+ # Load source animation
65
+ # -----------------------------------------------------------------------
66
+ src_path = Path(args.source)
67
+ print(f"[*] Loading source : {src_path}")
68
+
69
+ if src_path.suffix.lower() == ".bvh":
70
+ from .io.bvh import load_bvh
71
+ src_anim = load_bvh(str(src_path))
72
+ if args.verbose:
73
+ print(f" BVH: {src_anim.num_frames} frames, "
74
+ f"{src_anim.frame_time*1000:.1f} ms/frame, "
75
+ f"{len(src_anim.armature.pose_bones)} joints")
76
+ elif src_path.suffix.lower() in (".glb", ".gltf"):
77
+ # glTF source — load skeleton only; animation reading is TODO
78
+ raise NotImplementedError(
79
+ "glTF source animation reading is not yet implemented. "
80
+ "Use a BVH file for the source animation."
81
+ )
82
+ else:
83
+ print(f"[!] Unsupported source format: {src_path.suffix}", file=sys.stderr)
84
+ sys.exit(1)
85
+
86
+ if args.frames is not None:
87
+ settings.number_of_frames_to_apply = args.frames
88
+ else:
89
+ settings.number_of_frames_to_apply = src_anim.num_frames - args.start
90
+
91
+ # -----------------------------------------------------------------------
92
+ # Load destination skeleton
93
+ # -----------------------------------------------------------------------
94
+ dst_path = Path(args.dest)
95
+ print(f"[*] Loading dest : {dst_path}")
96
+ dst_arm = load_gltf(str(dst_path), skin_index=args.skin)
97
+ if args.verbose:
98
+ print(f" Skeleton: {len(dst_arm.pose_bones)} bones")
99
+
100
+ # -----------------------------------------------------------------------
101
+ # Auto-correct pass (optional)
102
+ # -----------------------------------------------------------------------
103
+ if args.calc_corrections:
104
+ print("[*] Calculating bone corrections ...")
105
+ src_anim.apply_frame(args.start)
106
+ calc_all_corrections(bone_items, src_anim.armature, dst_arm, settings)
107
+ save_mapping(args.mapping, settings, bone_items)
108
+ print(f"[*] Updated mapping saved → {args.mapping}")
109
+ return
110
+
111
+ # -----------------------------------------------------------------------
112
+ # Transfer
113
+ # -----------------------------------------------------------------------
114
+ print(f"[*] Transferring {settings.number_of_frames_to_apply} frames "
115
+ f"(start={settings.start_frame_to_apply}, step={settings.keyframe_every_n_frames}) ...")
116
+ keyframes = transfer_animation(src_anim, dst_arm, bone_items, settings)
117
+ print(f"[*] Generated {len(keyframes)} keyframes")
118
+
119
+ # -----------------------------------------------------------------------
120
+ # Write output
121
+ # -----------------------------------------------------------------------
122
+ out_path = args.output or str(dst_path.with_name(dst_path.stem + "_retargeted.glb"))
123
+ print(f"[*] Writing output : {out_path}")
124
+ write_gltf_animation(str(dst_path), dst_arm, keyframes, out_path, fps=args.fps, skin_index=args.skin)
125
+ print("[✓] Done")
126
+
127
+
128
+ if __name__ == "__main__":
129
+ main()
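When `--output` is omitted, the CLI derives the destination from `--dest` (`dst_path.with_name(dst_path.stem + "_retargeted.glb")`). A small self-contained sketch of that fallback, with `default_output` as a hypothetical helper name:

```python
from pathlib import Path

def default_output(dest: str) -> str:
    """Mirror the CLI fallback: <dest stem>_retargeted.glb next to the dest file."""
    dst_path = Path(dest)
    return str(dst_path.with_name(dst_path.stem + "_retargeted.glb"))

print(default_output("models/unirig_character.glb"))
```

Note this keeps the output in the same directory as the destination skeleton, so a `.gltf` destination still yields a `.glb` output name.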
Retarget/generate.py ADDED
@@ -0,0 +1,131 @@
1
+ """
2
+ generate.py
3
+ ───────────────────────────────────────────────────────────────────────────────
4
+ Text-to-motion generation.
5
+
6
+ Primary backend: MoMask inference server running on the Vast.ai instance.
7
+ Returns [T, 263] HumanML3D features directly — no SMPL
8
+ body mesh required.
9
+
10
+ Fallback backend: HumanML3D dataset keyword search (offline / no GPU needed).
11
+
12
+ Usage
13
+ ─────
14
+ from Retarget.generate import generate_motion
15
+
16
+ # Use MoMask on instance
17
+ motion = generate_motion("a person walks forward",
18
+ backend_url="http://ssh4.vast.ai:8765")
19
+
20
+ # Local fallback (streams HuggingFace dataset)
21
+ motion = generate_motion("a person walks forward")
22
+
23
+ # Returned motion: np.ndarray [T, 263]
24
+ # Feed directly to animate_glb()
25
+ """
26
+ from __future__ import annotations
27
+ import json
28
+ import numpy as np
29
+
30
+
31
+ # ──────────────────────────────────────────────────────────────────────────────
32
+ # Public API
33
+ # ──────────────────────────────────────────────────────────────────────────────
34
+
35
+ def generate_motion(
36
+ prompt: str,
37
+ backend_url: str | None = None,
38
+ num_frames: int = 196,
39
+ fps: float = 20.0,
40
+ seed: int = -1,
41
+ ) -> np.ndarray:
42
+ """
43
+ Generate a HumanML3D [T, 263] motion array from a text prompt.
44
+
45
+ Parameters
46
+ ----------
47
+ prompt
48
+ Natural language description of the desired motion.
49
+ Examples: "a person walks forward", "someone does a jumping jack",
50
+ "a man waves hello with his right hand"
51
+ backend_url
52
+ URL of the MoMask inference server. E.g. "http://ssh4.vast.ai:8765".
53
+ If None or if the server is unreachable, falls back to dataset search.
54
+ num_frames
55
+ Desired clip length in frames (at 20 fps; max ~196 ≈ 9.8 s).
56
+ fps
57
+ Target fps (MoMask natively produces 20 fps).
58
+ seed
59
+ Random seed for reproducibility (-1 = random).
60
+
61
+ Returns
62
+ -------
63
+ np.ndarray shape [T, 263] HumanML3D feature vector.
64
+ """
65
+ if backend_url:
66
+ try:
67
+ return _call_momask(prompt, backend_url, num_frames, seed)
68
+ except Exception as exc:
69
+ print(f"[generate] MoMask unreachable ({exc}) — falling back to dataset search")
70
+
71
+ return _dataset_search_fallback(prompt)
72
+
73
+
74
+ # ──────────────────────────────────────────────────────────────────────────────
75
+ # MoMask backend
76
+ # ──────────────────────────────────────────────────────────────────────────────
77
+
78
+ def _call_momask(
79
+ prompt: str,
80
+ url: str,
81
+ num_frames: int,
82
+ seed: int,
83
+ ) -> np.ndarray:
84
+ """POST to the MoMask inference server; return [T, 263] array."""
85
+ import urllib.request
86
+
87
+ payload = json.dumps({
88
+ "prompt": prompt,
89
+ "num_frames": num_frames,
90
+ "seed": seed,
91
+ }).encode("utf-8")
92
+
93
+ req = urllib.request.Request(
94
+ f"{url.rstrip('/')}/generate",
95
+ data=payload,
96
+ headers={"Content-Type": "application/json"},
97
+ method="POST",
98
+ )
99
+ with urllib.request.urlopen(req, timeout=180) as resp:
100
+ result = json.loads(resp.read())
101
+
102
+ motion = np.array(result["motion"], dtype=np.float32)
103
+ if motion.ndim != 2 or motion.shape[1] < 193:
104
+ raise ValueError(f"Server returned unexpected shape {motion.shape}")
105
+
106
+ print(f"[generate] MoMask: {motion.shape[0]} frames for '{prompt}'")
107
+ return motion
108
+
109
+
110
+ # ──────────────────────────────────────────────────────────────────────────────
111
+ # Dataset search fallback
112
+ # ──────────────────────────────────────────────────────────────────────────────
113
+
114
+ def _dataset_search_fallback(prompt: str) -> np.ndarray:
115
+ """
116
+ Keyword search in TeoGchx/HumanML3D dataset (streaming, HuggingFace).
117
+ Used when no MoMask server is available.
118
+ """
119
+ from .search import search_motions, format_choice_label
120
+
121
+ print(f"[generate] Searching HumanML3D dataset for: '{prompt}'")
122
+ results = search_motions(prompt, top_k=5, split="test", max_scan=500)
123
+ if not results:
124
+ raise RuntimeError(
125
+ f"No motion found in dataset for prompt: {prompt!r}\n"
126
+ "Check your internet connection or deploy MoMask on the instance."
127
+ )
128
+
129
+ best = results[0]
130
+ print(f"[generate] Best match: {format_choice_label(best)}")
131
+ return np.array(best["motion"], dtype=np.float32)
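The request body `_call_momask` POSTs and the `[T, 263]` sanity check it applies to the response can both be exercised offline. A hedged sketch (`build_payload` and `validate_motion` are illustrative names, not part of the module):

```python
import json
import numpy as np

def build_payload(prompt: str, num_frames: int, seed: int) -> bytes:
    # Same JSON body _call_momask POSTs to <backend_url>/generate
    return json.dumps(
        {"prompt": prompt, "num_frames": num_frames, "seed": seed}
    ).encode("utf-8")

def validate_motion(result: dict) -> np.ndarray:
    # Same shape check applied to the server response: 2-D, feature dim >= 193
    motion = np.array(result["motion"], dtype=np.float32)
    if motion.ndim != 2 or motion.shape[1] < 193:
        raise ValueError(f"Server returned unexpected shape {motion.shape}")
    return motion

payload = build_payload("a person walks forward", 196, -1)
fake_response = {"motion": np.zeros((196, 263)).tolist()}
motion = validate_motion(fake_response)
print(motion.shape)  # (196, 263)
```

The `>= 193` threshold (rather than exactly 263) matches the check in `animate_glb`: only the first 193 feature columns are consumed downstream, so truncated feature sets still pass.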
Retarget/humanml3d_to_bvh.py ADDED
@@ -0,0 +1,813 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ humanml3d_to_bvh.py
4
+ Convert HumanML3D .npy motion files → BVH animation.
5
+ When a UniRig-rigged GLB (or ASCII FBX) is supplied via --rig, the BVH is
6
+ built using the UniRig skeleton's own bone names and hierarchy, with
7
+ automatic bone-to-SMPL-joint mapping — no Blender required.
8
+
9
+ Dependencies
10
+ numpy (always required)
11
+ pygltflib pip install pygltflib (required for --rig GLB files)
12
+
13
+ Usage
14
+ # SMPL-named BVH (no rig needed)
15
+ python humanml3d_to_bvh.py 000001.npy
16
+
17
+ # Retargeted to UniRig skeleton
18
+ python humanml3d_to_bvh.py 000001.npy --rig rigged_mesh.glb
19
+
20
+ # Explicit output + fps
21
+ python humanml3d_to_bvh.py 000001.npy --rig rigged_mesh.glb -o anim.bvh --fps 20
22
+ """
23
+
24
+ from __future__ import annotations
25
+ import argparse, re, sys
26
+ from dataclasses import dataclass, field
27
+ from pathlib import Path
28
+ from typing import Optional
29
+
30
+ import numpy as np
31
+
32
+ # ══════════════════════════════════════════════════════════════════════════════
33
+ # SMPL 22-joint skeleton definition
34
+ # ══════════════════════════════════════════════════════════════════════════════
35
+
36
+ SMPL_NAMES = [
37
+ "Hips", # 0 pelvis / root
38
+ "LeftUpLeg", # 1 left_hip
39
+ "RightUpLeg", # 2 right_hip
40
+ "Spine", # 3 spine1
41
+ "LeftLeg", # 4 left_knee
42
+ "RightLeg", # 5 right_knee
43
+ "Spine1", # 6 spine2
44
+ "LeftFoot", # 7 left_ankle
45
+ "RightFoot", # 8 right_ankle
46
+ "Spine2", # 9 spine3
47
+ "LeftToeBase", # 10 left_foot
48
+ "RightToeBase", # 11 right_foot
49
+ "Neck", # 12 neck
50
+ "LeftShoulder", # 13 left_collar
51
+ "RightShoulder", # 14 right_collar
52
+ "Head", # 15 head
53
+ "LeftArm", # 16 left_shoulder
54
+ "RightArm", # 17 right_shoulder
55
+ "LeftForeArm", # 18 left_elbow
56
+ "RightForeArm", # 19 right_elbow
57
+ "LeftHand", # 20 left_wrist
58
+ "RightHand", # 21 right_wrist
59
+ ]
60
+
61
+ SMPL_PARENT = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
62
+ NUM_SMPL = 22
63
+
64
+ SMPL_TPOSE = np.array([
65
+ [ 0.000, 0.920, 0.000], # 0 Hips
66
+ [-0.095, 0.920, 0.000], # 1 LeftUpLeg
67
+ [ 0.095, 0.920, 0.000], # 2 RightUpLeg
68
+ [ 0.000, 0.980, 0.000], # 3 Spine
69
+ [-0.095, 0.495, 0.000], # 4 LeftLeg
70
+ [ 0.095, 0.495, 0.000], # 5 RightLeg
71
+ [ 0.000, 1.050, 0.000], # 6 Spine1
72
+ [-0.095, 0.075, 0.000], # 7 LeftFoot
73
+ [ 0.095, 0.075, 0.000], # 8 RightFoot
74
+ [ 0.000, 1.120, 0.000], # 9 Spine2
75
+ [-0.095, 0.000, -0.020], # 10 LeftToeBase
76
+ [ 0.095, 0.000, -0.020], # 11 RightToeBase
77
+ [ 0.000, 1.370, 0.000], # 12 Neck
78
+ [-0.130, 1.290, 0.000], # 13 LeftShoulder
79
+ [ 0.130, 1.290, 0.000], # 14 RightShoulder
80
+ [ 0.000, 1.500, 0.000], # 15 Head
81
+ [-0.330, 1.290, 0.000], # 16 LeftArm
82
+ [ 0.330, 1.290, 0.000], # 17 RightArm
83
+ [-0.630, 1.290, 0.000], # 18 LeftForeArm
84
+ [ 0.630, 1.290, 0.000], # 19 RightForeArm
85
+ [-0.910, 1.290, 0.000], # 20 LeftHand
86
+ [ 0.910, 1.290, 0.000], # 21 RightHand
87
+ ], dtype=np.float32)
88
+
89
+ _SMPL_CHILDREN: list[list[int]] = [[] for _ in range(NUM_SMPL)]
90
+ for _j, _p in enumerate(SMPL_PARENT):
91
+ if _p >= 0:
92
+ _SMPL_CHILDREN[_p].append(_j)
93
+
94
+
95
+ def _smpl_dfs() -> list[int]:
96
+ order, stack = [], [0]
97
+ while stack:
98
+ j = stack.pop()
99
+ order.append(j)
100
+ for c in reversed(_SMPL_CHILDREN[j]):
101
+ stack.append(c)
102
+ return order
103
+
104
+
105
+ SMPL_DFS = _smpl_dfs()
106
+
107
+ # ══════════════════════════════════════════════════════════════════════════════
108
+ # Quaternion helpers (numpy, WXYZ)
109
+ # ══════════════════════════════════════════════════════════════════════════════
110
+
111
+ def qnorm(q: np.ndarray) -> np.ndarray:
112
+ return q / (np.linalg.norm(q, axis=-1, keepdims=True) + 1e-9)
113
+
114
+
115
+ def qmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
116
+ aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
117
+ bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
118
+ return np.stack([
119
+ aw*bw - ax*bx - ay*by - az*bz,
120
+ aw*bx + ax*bw + ay*bz - az*by,
121
+ aw*by - ax*bz + ay*bw + az*bx,
122
+ aw*bz + ax*by - ay*bx + az*bw,
123
+ ], axis=-1)
124
+
125
+
126
+ def qinv(q: np.ndarray) -> np.ndarray:
127
+ return q * np.array([1, -1, -1, -1], dtype=np.float32)
128
+
129
+
130
+ def qrot(q: np.ndarray, v: np.ndarray) -> np.ndarray:
131
+ vq = np.concatenate([np.zeros((*v.shape[:-1], 1), dtype=v.dtype), v], axis=-1)
132
+ return qmul(qmul(q, vq), qinv(q))[..., 1:]
133
+
134
+
135
+ def qbetween(a: np.ndarray, b: np.ndarray) -> np.ndarray:
136
+ """Swing quaternion rotating unit-vectors a to b. [..., 3] to [..., 4]."""
137
+ a = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-9)
138
+ b = b / (np.linalg.norm(b, axis=-1, keepdims=True) + 1e-9)
139
+ dot = np.clip((a * b).sum(axis=-1, keepdims=True), -1.0, 1.0)
140
+ cross = np.cross(a, b)
141
+ w = np.sqrt(np.maximum((1.0 + dot) * 0.5, 0.0))
142
+ xyz = cross / (2.0 * w + 1e-9)
143
+ anti = (dot[..., 0] < -0.9999)
144
+ if anti.any():
145
+ perp = np.where(
146
+ np.abs(a[anti, 0:1]) < 0.9,
147
+ np.tile([1, 0, 0], (anti.sum(), 1)),
148
+ np.tile([0, 1, 0], (anti.sum(), 1)),
149
+ ).astype(np.float32)
150
+ ax_f = np.cross(a[anti], perp)
151
+ ax_f = ax_f / (np.linalg.norm(ax_f, axis=-1, keepdims=True) + 1e-9)
152
+ w[anti] = 0.0
153
+ xyz[anti] = ax_f
154
+ return qnorm(np.concatenate([w, xyz], axis=-1))
155
+
156
+
157
+ def quat_to_euler_ZXY(q: np.ndarray) -> np.ndarray:
158
+ """WXYZ quaternions to ZXY Euler degrees (rz, rx, ry) for BVH."""
159
+ w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
160
+ sin_x = np.clip(2.0*(w*x - y*z), -1.0, 1.0)
161
+ return np.stack([
162
+ np.degrees(np.arctan2(2.0*(w*z + x*y), 1.0 - 2.0*(x*x + z*z))),
163
+ np.degrees(np.arcsin(sin_x)),
164
+ np.degrees(np.arctan2(2.0*(w*y + x*z), 1.0 - 2.0*(x*x + y*y))),
165
+ ], axis=-1)
166
+
167
+ # ══════════════════════════════════════════════════════════════════════════════
168
+ # HumanML3D 263-dim recovery
169
+ #
170
+ # Layout per frame:
171
+ # [0] root Y-axis angular velocity (rad/frame)
172
+ # [1] root height Y (m)
173
+ # [2:4] root XZ velocity in local frame
174
+ # [4:67] local positions of joints 1-21 (21 x 3 = 63)
175
+ # [67:193] 6-D rotations for joints 1-21 (21 x 6 = 126, unused here)
176
+ # [193:259] joint velocities (22 x 3 = 66, unused here)
177
+ # [259:263] foot contact (4, unused here)
178
+ # ══════════════════════════════════════════════════════════════════════════════
179
+
180
+ def _recover_root(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
181
+ T = data.shape[0]
182
+ theta = np.cumsum(data[:, 0])
183
+ half = theta * 0.5
184
+ r_rot = np.zeros((T, 4), dtype=np.float32)
185
+ r_rot[:, 0] = np.cos(half) # W
186
+ r_rot[:, 2] = np.sin(half) # Y
187
+ vel_local = np.stack([data[:, 2], np.zeros(T, dtype=np.float32), data[:, 3]], -1)
188
+ vel_world = qrot(r_rot, vel_local)
189
+ r_pos = np.zeros((T, 3), dtype=np.float32)
190
+ r_pos[:, 0] = np.cumsum(vel_world[:, 0])
191
+ r_pos[:, 1] = data[:, 1]
192
+ r_pos[:, 2] = np.cumsum(vel_world[:, 2])
193
+ return r_rot, r_pos
194
+
195
+
196
+ def recover_from_ric(data: np.ndarray, joints_num: int = 22) -> np.ndarray:
197
+ """263-dim features to world-space positions [T, joints_num, 3]."""
198
+ data = data.astype(np.float32)
199
+ r_rot, r_pos = _recover_root(data)
200
+ loc = data[:, 4:4 + (joints_num-1)*3].reshape(-1, joints_num-1, 3)
201
+ rinv = np.broadcast_to(qinv(r_rot)[:, None], (*loc.shape[:2], 4)).copy()
202
+ wloc = qrot(rinv, loc) + r_pos[:, None]
203
+ return np.concatenate([r_pos[:, None], wloc], axis=1)
204
+
205
+ # ══════════════════════════════════════════════════════════════════════════════
206
+ # SMPL geometry helpers
207
+ # ══════════════════════════════════════════════════════════════════════════════
208
+
209
+ def _scale_smpl_tpose(positions: np.ndarray) -> np.ndarray:
210
+ data_h = positions[:, :, 1].max() - positions[:, :, 1].min()
211
+ ref_h = SMPL_TPOSE[:, 1].max() - SMPL_TPOSE[:, 1].min()
212
+ scale = (data_h / ref_h) if (ref_h > 1e-6 and data_h > 1e-6) else 1.0
213
+ return SMPL_TPOSE * scale
214
+
215
+
216
+ def _rest_dirs(tpose: np.ndarray, children: list[list[int]],
217
+ parent: list[int]) -> np.ndarray:
218
+ N = tpose.shape[0]
219
+ dirs = np.zeros((N, 3), dtype=np.float32)
220
+ for j in range(N):
221
+ ch = children[j]
222
+ if ch:
223
+ avg = np.stack([tpose[c] - tpose[j] for c in ch]).mean(0)
224
+ dirs[j] = avg / (np.linalg.norm(avg) + 1e-9)
225
+ else:
226
+ v = tpose[j] - tpose[parent[j]]
227
+ dirs[j] = v / (np.linalg.norm(v) + 1e-9)
228
+ return dirs
229
+
230
+
231
+ def positions_to_local_quats(positions: np.ndarray,
232
+ tpose: np.ndarray) -> np.ndarray:
233
+ """World-space joint positions [T, 22, 3] to local quaternions [T, 22, 4]."""
234
+ T = positions.shape[0]
235
+ rd = _rest_dirs(tpose, _SMPL_CHILDREN, SMPL_PARENT)
236
+
237
+ world_q = np.zeros((T, NUM_SMPL, 4), dtype=np.float32)
238
+ world_q[:, :, 0] = 1.0
239
+
240
+ for j in range(NUM_SMPL):
241
+ ch = _SMPL_CHILDREN[j]
242
+ if ch:
243
+ vecs = np.stack([positions[:, c] - positions[:, j] for c in ch], 1).mean(1)
244
+ else:
245
+ vecs = positions[:, j] - positions[:, SMPL_PARENT[j]]
246
+ cur = vecs / (np.linalg.norm(vecs, axis=-1, keepdims=True) + 1e-9)
247
+ rd_b = np.broadcast_to(rd[j], cur.shape).copy()
248
+ world_q[:, j] = qbetween(rd_b, cur)
249
+
250
+ local_q = np.zeros_like(world_q)
251
+ local_q[:, :, 0] = 1.0
252
+ for j in SMPL_DFS:
253
+ p = SMPL_PARENT[j]
254
+ if p < 0:
255
+ local_q[:, j] = world_q[:, j]
256
+ else:
257
+ local_q[:, j] = qmul(qinv(world_q[:, p]), world_q[:, j])
258
+
259
+ return qnorm(local_q)
260
+
261
+ # ══════════════════════════════════════════════════════════════════════════════
262
+ # UniRig skeleton data structure
263
+ # ══════════════════════════════════════════════════════════════════════════════
264
+
265
+ @dataclass
266
+ class Bone:
267
+ name: str
268
+ parent: Optional[str]
269
+ world_rest_pos: np.ndarray
270
+ children: list[str] = field(default_factory=list)
271
+ smpl_idx: Optional[int] = None
272
+
273
+
274
+ class UnirigSkeleton:
275
+ def __init__(self, bones: dict[str, Bone]):
276
+ self.bones = bones
277
+ self.root = next(b for b in bones.values() if b.parent is None)
278
+
279
+ def dfs_order(self) -> list[str]:
280
+ order, stack = [], [self.root.name]
281
+ while stack:
282
+ n = stack.pop()
283
+ order.append(n)
284
+ for c in reversed(self.bones[n].children):
285
+ stack.append(c)
286
+ return order
287
+
288
+ def local_offsets(self) -> dict[str, np.ndarray]:
289
+ offsets = {}
290
+ for name, bone in self.bones.items():
291
+ if bone.parent is None:
292
+ offsets[name] = bone.world_rest_pos.copy()
293
+ else:
294
+ offsets[name] = bone.world_rest_pos - self.bones[bone.parent].world_rest_pos
295
+ return offsets
296
+
297
+ def rest_direction(self, name: str) -> np.ndarray:
298
+ bone = self.bones[name]
299
+ if bone.children:
300
+ vecs = np.stack([self.bones[c].world_rest_pos - bone.world_rest_pos
301
+ for c in bone.children])
302
+ avg = vecs.mean(0)
303
+ return avg / (np.linalg.norm(avg) + 1e-9)
304
+ if bone.parent is None:
305
+ return np.array([0, 1, 0], dtype=np.float32)
306
+ v = bone.world_rest_pos - self.bones[bone.parent].world_rest_pos
307
+ return v / (np.linalg.norm(v) + 1e-9)
308
+
309
+ # ══════════════════════════════════════════════════════════════════════════════
310
+ # GLB skeleton parser
311
+ # ══════════════════════════════════════════════════════════════════════════════
312
+
313
+ def parse_glb_skeleton(path: str) -> UnirigSkeleton:
314
+ """Extract skeleton from a UniRig-rigged GLB (uses pygltflib)."""
315
+ try:
316
+ import pygltflib
317
+ except ImportError:
318
+ sys.exit("[ERROR] pygltflib not installed. pip install pygltflib")
319
+
320
+ import base64
321
+
322
+ gltf = pygltflib.GLTF2().load(path)
323
+ if not gltf.skins:
324
+ sys.exit(f"[ERROR] No skin found in {path}")
325
+
326
+ skin = gltf.skins[0]
327
+ joint_indices = skin.joints
328
+
329
+ def get_buffer_bytes(buf_idx: int) -> bytes:
330
+ buf = gltf.buffers[buf_idx]
331
+ if buf.uri is None:
332
+ return bytes(gltf.binary_blob())
333
+ if buf.uri.startswith("data:"):
334
+ return base64.b64decode(buf.uri.split(",", 1)[1])
335
+ return (Path(path).parent / buf.uri).read_bytes()
336
+
337
+ def read_accessor(acc_idx: int) -> np.ndarray:
338
+ acc = gltf.accessors[acc_idx]
339
+ bv = gltf.bufferViews[acc.bufferView]
340
+ raw = get_buffer_bytes(bv.buffer)
341
+ COMP = {5120: ('b',1), 5121: ('B',1), 5122: ('h',2),
342
+ 5123: ('H',2), 5125: ('I',4), 5126: ('f',4)}
343
+ DIMS = {"SCALAR":1,"VEC2":2,"VEC3":3,"VEC4":4,"MAT2":4,"MAT3":9,"MAT4":16}
344
+ fmt, sz = COMP[acc.componentType]
345
+ dim = DIMS[acc.type]
346
+ start = (bv.byteOffset or 0) + (acc.byteOffset or 0)
347
+ stride = bv.byteStride
348
+ if stride is None or stride == 0 or stride == sz * dim:
349
+ chunk = raw[start: start + acc.count * sz * dim]
350
+ return np.frombuffer(chunk, dtype=fmt).reshape(acc.count, dim).astype(np.float32)
351
+ rows = []
352
+ for i in range(acc.count):
353
+ off = start + i * stride
354
+ rows.append(np.frombuffer(raw[off: off + sz * dim], dtype=fmt))
355
+ return np.stack(rows).astype(np.float32)
356
+
357
+ ibm = read_accessor(skin.inverseBindMatrices).reshape(-1, 4, 4)
358
+ joint_set = set(joint_indices)
359
+ ni_name = {ni: (gltf.nodes[ni].name or f"bone_{ni}") for ni in joint_indices}
360
+
361
+ bones: dict[str, Bone] = {}
362
+ for i, ni in enumerate(joint_indices):
363
+ name = ni_name[ni]
364
+ world_mat = np.linalg.inv(ibm[i])
365
+ bones[name] = Bone(name=name, parent=None,
366
+ world_rest_pos=world_mat[:3, 3].astype(np.float32))
367
+
368
+ for ni in joint_indices:
369
+ for ci in (gltf.nodes[ni].children or []):
370
+ if ci in joint_set:
371
+ p, c = ni_name[ni], ni_name[ci]
372
+ bones[c].parent = p
373
+ bones[p].children.append(c)
374
+
375
+ print(f"[GLB] {len(bones)} bones from skin '{gltf.skins[0].name or 'Armature'}'")
376
+ return UnirigSkeleton(bones)
377
+
378
+ # ══════════════════════════════════════════════════════════════════════════════
379
+ # ASCII FBX skeleton parser
380
+ # ══════════════════════════════════════════════════════════════════════════════
381
+
382
+ def parse_fbx_ascii_skeleton(path: str) -> UnirigSkeleton:
383
+ """Parse ASCII-format FBX for LimbNode / Root bones."""
384
+ raw = Path(path).read_bytes()
385
+ if raw[:4] == b"Kayd":
386
+ sys.exit(
387
+ f"[ERROR] {path} is binary FBX.\n"
388
+ "Convert to GLB first, e.g.:\n"
389
+ " gltf-pipeline -i rigged.fbx -o rigged.glb"
390
+ )
391
+ text = raw.decode("utf-8", errors="replace")
392
+
393
+ model_pat = re.compile(
394
+ r'Model:\s*(\d+),\s*"Model::([^"]+)",\s*"(LimbNode|Root|Null)"'
395
+ r'.*?Properties70:\s*\{(.*?)\}',
396
+ re.DOTALL
397
+ )
398
+ trans_pat = re.compile(
399
+ r'P:\s*"Lcl Translation".*?(-?[\d.e+\-]+),\s*(-?[\d.e+\-]+),\s*(-?[\d.e+\-]+)'
400
+ )
401
+
402
+ uid_name: dict[str, str] = {}
403
+ uid_local: dict[str, np.ndarray] = {}
404
+
405
+ for m in model_pat.finditer(text):
406
+ uid, name = m.group(1), m.group(2)
407
+ uid_name[uid] = name
408
+ tm = trans_pat.search(m.group(4))
409
+ uid_local[uid] = (np.array([float(tm.group(i)) for i in (1,2,3)], dtype=np.float32)
410
+ if tm else np.zeros(3, dtype=np.float32))
411
+
412
+ if not uid_name:
413
+ sys.exit("[ERROR] No LimbNode/Root bones found in FBX")
414
+
415
+ conn_pat = re.compile(r'C:\s*"OO",\s*(\d+),\s*(\d+)')
416
+ uid_parent: dict[str, str] = {}
417
+ for m in conn_pat.finditer(text):
418
+ child, par = m.group(1), m.group(2)
419
+ if child in uid_name and par in uid_name:
420
+ uid_parent[child] = par
421
+
422
+ # Detect cm vs m
423
+ all_y = np.array([t[1] for t in uid_local.values()])
424
+ scale = 0.01 if all_y.max() > 10.0 else 1.0
425
+ if scale != 1.0:
426
+ print(f"[FBX] Centimetre units detected — scaling by {scale}")
427
+ for uid in uid_local:
428
+ uid_local[uid] *= scale
429
+
430
+ # Accumulate world translations (topological order)
431
+ def topo(uid_to_par):
432
+ visited, order = set(), []
433
+ def visit(u):
434
+ if u in visited: return
435
+ visited.add(u)
436
+ if u in uid_to_par: visit(uid_to_par[u])
437
+ order.append(u)
438
+ for u in uid_to_par: visit(u)
439
+ for u in uid_name:
440
+ if u not in visited: order.append(u)
441
+ return order
442
+
443
+ world: dict[str, np.ndarray] = {}
444
+ for uid in topo(uid_parent):
445
+ loc = uid_local.get(uid, np.zeros(3, dtype=np.float32))
446
+ world[uid] = (world.get(uid_parent[uid], np.zeros(3, dtype=np.float32)) + loc
447
+ if uid in uid_parent else loc.copy())
448
+
449
+ bones: dict[str, Bone] = {}
450
+ for uid, name in uid_name.items():
451
+ bones[name] = Bone(name=name, parent=None, world_rest_pos=world[uid])
452
+
453
+ for uid, p_uid in uid_parent.items():
454
+ c, p = uid_name[uid], uid_name[p_uid]
455
+ bones[c].parent = p
456
+ if c not in bones[p].children:
457
+ bones[p].children.append(c)
458
+
459
+ print(f"[FBX] {len(bones)} bones parsed from ASCII FBX")
460
+ return UnirigSkeleton(bones)
461
+
462
+ # ══════════════════════════════════════════════════════════════════════════════
463
+ # Auto bone mapping: UniRig bones to SMPL joints
464
+ # ══════════════════════════════════════════════════════════════════════════════
465
+
466
+ # Keyword table: normalised name fragments -> SMPL joint index
467
+ _NAME_HINTS: list[tuple[list[str], int]] = [
468
+ (["hips","pelvis","root","hip"], 0),
469
+ (["leftupleg","l_upleg","lupleg","leftthigh","lefthip",
470
+ "left_upper_leg","l_thigh","thigh_l","upperleg_l","j_bip_l_upperleg"], 1),
471
+ (["rightupleg","r_upleg","rupleg","rightthigh","righthip",
472
+ "right_upper_leg","r_thigh","thigh_r","upperleg_r","j_bip_r_upperleg"], 2),
473
+ (["spine","spine0","spine_01","j_bip_c_spine"], 3),
474
+ (["leftleg","leftknee","l_leg","lleg","leftlowerleg",
475
+ "left_lower_leg","lowerleg_l","knee_l","j_bip_l_lowerleg"], 4),
476
+ (["rightleg","rightknee","r_leg","rleg","rightlowerleg",
477
+ "right_lower_leg","lowerleg_r","knee_r","j_bip_r_lowerleg"], 5),
478
+ (["spine1","spine_02","j_bip_c_spine1"], 6),
479
+ (["leftfoot","left_foot","l_foot","lfoot","foot_l","j_bip_l_foot"], 7),
480
+ (["rightfoot","right_foot","r_foot","rfoot","foot_r","j_bip_r_foot"], 8),
481
+ (["spine2","spine_03","j_bip_c_spine2","chest"], 9),
482
+ (["lefttoebase","lefttoe","l_toe","ltoe","toe_l"], 10),
483
+ (["righttoebase","righttoe","r_toe","rtoe","toe_r"], 11),
484
+ (["neck","j_bip_c_neck"], 12),
485
+ (["leftshoulder","leftcollar","l_shoulder","leftclavicle",
486
+ "clavicle_l","j_bip_l_shoulder"], 13),
487
+ (["rightshoulder","rightcollar","r_shoulder","rightclavicle",
488
+ "clavicle_r","j_bip_r_shoulder"], 14),
489
+ (["head","j_bip_c_head"], 15),
490
+ (["leftarm","leftupper","l_arm","larm","leftupperarm",
491
+ "upperarm_l","j_bip_l_upperarm"], 16),
492
+ (["rightarm","rightupper","r_arm","rarm","rightupperarm",
493
+ "upperarm_r","j_bip_r_upperarm"], 17),
494
+ (["leftforearm","leftlower","l_forearm","lforearm",
495
+ "lowerarm_l","j_bip_l_lowerarm"], 18),
496
+ (["rightforearm","rightlower","r_forearm","rforearm",
497
+ "lowerarm_r","j_bip_r_lowerarm"], 19),
498
+ (["lefthand","l_hand","lhand","hand_l","j_bip_l_hand"], 20),
499
+ (["righthand","r_hand","rhand","hand_r","j_bip_r_hand"], 21),
500
+ ]
501
+
502
+
503
+ def _strip_name(name: str) -> str:
504
+ """Remove common rig namespace prefixes, then lower-case, remove separators."""
505
+ name = re.sub(r'^(mixamorig:|j_bip_[lcr]_|cc_base_|bip01_|rig:|chr:)',
506
+ "", name, flags=re.IGNORECASE)
507
+ return re.sub(r'[_\-\s.]', "", name).lower()
508
+
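To illustrate the normalisation: the two regex passes collapse rig-specific prefixes and separators onto a single lookup key. A standalone copy of the same logic:

```python
import re

def strip_name(name: str) -> str:
    # Same two-step normalisation as _strip_name above:
    # 1) drop known rig namespace prefixes, 2) drop separators and lower-case.
    name = re.sub(r'^(mixamorig:|j_bip_[lcr]_|cc_base_|bip01_|rig:|chr:)',
                  "", name, flags=re.IGNORECASE)
    return re.sub(r'[_\-\s.]', "", name).lower()

print(strip_name("mixamorig:LeftUpLeg"))  # leftupleg
print(strip_name("Upper Leg.L"))          # upperlegl
```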
509
+
510
+ def _normalise_positions(pos: np.ndarray) -> np.ndarray:
511
+ """Normalise [N, 3] to [0,1] in Y, [-1,1] in X and Z."""
512
+ y_min, y_max = pos[:, 1].min(), pos[:, 1].max()
513
+ h = (y_max - y_min) or 1.0
514
+ xr = (pos[:, 0].max() - pos[:, 0].min()) or 1.0
515
+ zr = (pos[:, 2].max() - pos[:, 2].min()) or 1.0
516
+ out = pos.copy()
517
+ out[:, 0] /= xr
518
+ out[:, 1] = (out[:, 1] - y_min) / h
519
+ out[:, 2] /= zr
520
+ return out
521
+
522
+
523
+ def auto_map(skel: UnirigSkeleton, verbose: bool = True) -> None:
524
+ """
525
+ Assign skel.bones[name].smpl_idx for each UniRig bone that best matches
526
+ an SMPL joint. Score = 0.6 * position_proximity + 0.4 * name_hint.
527
+ Greedy: each SMPL joint taken by at most one UniRig bone.
528
+ Bones with combined score < 0.35 are left unmapped (identity in BVH).
529
+ """
530
+ names = list(skel.bones.keys())
531
+ ur_pos = np.stack([skel.bones[n].world_rest_pos for n in names]) # [M, 3]
532
+
533
+ # Scale SMPL T-pose to match UniRig height
534
+ ur_h = ur_pos[:, 1].max() - ur_pos[:, 1].min()
535
+ sm_h = SMPL_TPOSE[:, 1].max() - SMPL_TPOSE[:, 1].min()
536
+ sm_pos = SMPL_TPOSE * ((ur_h / sm_h) if sm_h > 1e-6 else 1.0)
537
+
538
+ all_norm = _normalise_positions(np.concatenate([ur_pos, sm_pos]))
539
+ ur_norm = all_norm[:len(names)]
540
+ sm_norm = all_norm[len(names):]
541
+
542
+ # Distance score [M, 22]
543
+ dist = np.linalg.norm(ur_norm[:, None] - sm_norm[None], axis=-1)
544
+ dist_sc = 1.0 - np.clip(dist / (dist.max() + 1e-9), 0, 1)
545
+
546
+ # Name score [M, 22]
547
+ norm_names = [_strip_name(n) for n in names]
548
+ name_sc = np.array(
549
+ [[1.0 if norm in kws else 0.0
550
+ for kws, _ in _NAME_HINTS]
551
+ for norm in norm_names],
552
+ dtype=np.float32,
553
+ ) # [M, 22]
554
+
555
+ combined = 0.6 * dist_sc + 0.4 * name_sc # [M, 22]
556
+
557
+ # Greedy assignment
558
+ THRESHOLD = 0.35
559
+ taken_smpl: set[int] = set()
560
+ pairs = sorted(
561
+ ((i, j, combined[i, j])
562
+ for i in range(len(names)) for j in range(NUM_SMPL)),
563
+ key=lambda x: -x[2],
564
+ )
565
+ for bi, si, score in pairs:
566
+ if score < THRESHOLD:
567
+ break
568
+ name = names[bi]
569
+ if skel.bones[name].smpl_idx is not None or si in taken_smpl:
570
+ continue
571
+ skel.bones[name].smpl_idx = si
572
+ taken_smpl.add(si)
573
+
574
+ if verbose:
575
+ mapped = [(n, b.smpl_idx) for n, b in skel.bones.items() if b.smpl_idx is not None]
576
+ unmapped = [n for n, b in skel.bones.items() if b.smpl_idx is None]
577
+ print(f"\n[MAP] {len(mapped)}/{len(skel.bones)} bones mapped to SMPL joints:")
578
+ for ur_name, si in sorted(mapped, key=lambda x: x[1]):
579
+ sc = combined[names.index(ur_name), si]
580
+ print(f" {ur_name:40s} -> {SMPL_NAMES[si]:16s} score={sc:.2f}")
581
+ if unmapped:
582
+ print(f"[MAP] {len(unmapped)} unmapped (identity rotation): "
583
+ + ", ".join(unmapped[:8])
584
+ + (" ..." if len(unmapped) > 8 else ""))
585
+ print()
586
+
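The greedy pass in `auto_map` can be demonstrated on a toy score matrix (values invented; in the real function the matrix is `0.6 * dist_sc + 0.4 * name_sc`):

```python
import numpy as np

# Toy 3x3 score matrix (bones x joints). Greedy: sort all (bone, joint)
# pairs by score descending, take each bone and joint at most once,
# stop once scores fall below the threshold.
scores = np.array([
    [0.9, 0.2, 0.1],
    [0.8, 0.7, 0.3],
    [0.1, 0.6, 0.2],
])
THRESHOLD = 0.35
assigned = {}   # bone index -> joint index
taken = set()
pairs = sorted(((i, j, scores[i, j]) for i in range(3) for j in range(3)),
               key=lambda x: -x[2])
for bi, si, sc in pairs:
    if sc < THRESHOLD:
        break
    if bi in assigned or si in taken:
        continue
    assigned[bi] = si
    taken.add(si)

print(assigned)  # {0: 0, 1: 1}
# Bone 2 stays unmapped: its best joint (1) was already taken and its
# remaining scores are below the threshold, mirroring the "identity
# rotation" fallback in the BVH writer.
```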
587
+ # ══════════════════════════════════════════════════════════════════════════════
588
+ # BVH writers
589
+ # ══════════════════════════════════════════════════════════════════════════════
590
+
591
+ def _smpl_offsets(tpose: np.ndarray) -> np.ndarray:
592
+ offsets = np.zeros_like(tpose)
593
+ for j, p in enumerate(SMPL_PARENT):
594
+ offsets[j] = tpose[j] if p < 0 else tpose[j] - tpose[p]
595
+ return offsets
596
+
597
+
598
+ def write_bvh_smpl(output_path: str, positions: np.ndarray, fps: int = 20) -> None:
599
+ """BVH with standard SMPL bone names (no rig file needed)."""
600
+ T = positions.shape[0]
601
+ tpose = _scale_smpl_tpose(positions)
602
+ offsets = _smpl_offsets(tpose)
603
+ tp_w = tpose + (positions[0, 0] - tpose[0])
604
+ local_q = positions_to_local_quats(positions, tp_w)
605
+ euler = quat_to_euler_ZXY(local_q)
606
+
607
+ with open(output_path, "w") as f:
608
+ f.write("HIERARCHY\n")
609
+
610
+ def wj(j, ind):
611
+ off = offsets[j]
612
+ f.write(f"{'ROOT' if SMPL_PARENT[j]<0 else ind+'JOINT'} {SMPL_NAMES[j]}\n")
613
+ f.write(f"{ind}{{\n")
614
+ f.write(f"{ind}\tOFFSET {off[0]:.6f} {off[1]:.6f} {off[2]:.6f}\n")
615
+ if SMPL_PARENT[j] < 0:
616
+ f.write(f"{ind}\tCHANNELS 6 Xposition Yposition Zposition "
617
+ "Zrotation Xrotation Yrotation\n")
618
+ else:
619
+ f.write(f"{ind}\tCHANNELS 3 Zrotation Xrotation Yrotation\n")
620
+ for c in _SMPL_CHILDREN[j]:
621
+ wj(c, ind + "\t")
622
+ if not _SMPL_CHILDREN[j]:
623
+ f.write(f"{ind}\tEnd Site\n{ind}\t{{\n"
624
+ f"{ind}\t\tOFFSET 0.000000 0.050000 0.000000\n{ind}\t}}\n")
625
+ f.write(f"{ind}}}\n")
626
+
627
+ wj(0, "")
628
+ f.write(f"MOTION\nFrames: {T}\nFrame Time: {1.0/fps:.8f}\n")
629
+ for t in range(T):
630
+ rp = positions[t, 0]
631
+ row = [f"{rp[0]:.6f}", f"{rp[1]:.6f}", f"{rp[2]:.6f}"]
632
+ for j in SMPL_DFS:
633
+ rz, rx, ry = euler[t, j]
634
+ row += [f"{rz:.6f}", f"{rx:.6f}", f"{ry:.6f}"]
635
+ f.write(" ".join(row) + "\n")
636
+
637
+ print(f"[OK] {T} frames @ {fps} fps -> {output_path} (SMPL skeleton)")
638
+
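For reference, the BVH layout `write_bvh_smpl` emits can be sketched with a single-root skeleton. This is a minimal hand-rolled example of the format (not the function's actual output): a HIERARCHY block with one 6-channel root, then one MOTION row per frame.

```python
import io

frames = [(0.0, 1.0, 0.0), (0.0, 1.1, 0.0)]   # root translations per frame
buf = io.StringIO()
buf.write("HIERARCHY\n")
buf.write("ROOT Pelvis\n{\n")
buf.write("\tOFFSET 0.000000 0.000000 0.000000\n")
buf.write("\tCHANNELS 6 Xposition Yposition Zposition "
          "Zrotation Xrotation Yrotation\n")
buf.write("\tEnd Site\n\t{\n\t\tOFFSET 0.000000 0.050000 0.000000\n\t}\n")
buf.write("}\n")
buf.write(f"MOTION\nFrames: {len(frames)}\nFrame Time: {1.0/20:.8f}\n")
for x, y, z in frames:
    # 3 translation values + 3 rotation values, matching CHANNELS 6
    buf.write(f"{x:.6f} {y:.6f} {z:.6f} 0.000000 0.000000 0.000000\n")

text = buf.getvalue()
```

Every joint contributes exactly as many values per MOTION row as its CHANNELS declaration, which is why the writer and the reader below must agree on channel counts.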
639
+
640
+ def write_bvh_unirig(output_path: str,
641
+ positions: np.ndarray,
642
+ skel: UnirigSkeleton,
643
+ fps: int = 20) -> None:
644
+ """
645
+ BVH using UniRig bone names and hierarchy.
646
+ Mapped bones receive SMPL-derived local rotations with rest-pose correction.
647
+ Unmapped bones (fingers, face bones, etc.) are set to identity.
648
+ """
649
+ T = positions.shape[0]
650
+
651
+ # Compute SMPL local quaternions
652
+ tpose = _scale_smpl_tpose(positions)
653
+ tp_w = tpose + (positions[0, 0] - tpose[0])
654
+ smpl_q = positions_to_local_quats(positions, tp_w) # [T, 22, 4]
655
+ smpl_rd = _rest_dirs(tp_w, _SMPL_CHILDREN, SMPL_PARENT) # [22, 3]
656
+
657
+ # Rest-pose correction quaternions per bone:
658
+ # q_corr = qbetween(unirig_rest_dir, smpl_rest_dir)
659
+ # unirig_local_q = smpl_local_q @ q_corr
660
+ # This ensures: when applied to unirig_rest_dir, the result matches
661
+ # the SMPL animated direction — accounting for any difference in
662
+ # rest-pose bone orientations between the two skeletons.
663
+ corrections: dict[str, np.ndarray] = {}
664
+ for name, bone in skel.bones.items():
665
+ if bone.smpl_idx is None:
666
+ continue
667
+ ur_rd = skel.rest_direction(name).astype(np.float32)
668
+ sm_rd = smpl_rd[bone.smpl_idx].astype(np.float32)
669
+ corrections[name] = qbetween(ur_rd[None], sm_rd[None])[0] # [4]
670
+
671
+ # Scale root translation from SMPL proportions to UniRig proportions
672
+ ur_h = (max(b.world_rest_pos[1] for b in skel.bones.values())
673
+ - min(b.world_rest_pos[1] for b in skel.bones.values()))
674
+ sm_h = tp_w[:, 1].max() - tp_w[:, 1].min()
675
+ pos_sc = (ur_h / sm_h) if sm_h > 1e-6 else 1.0
676
+
677
+ dfs = skel.dfs_order()
678
+ offsets = skel.local_offsets()
679
+
680
+ # Pre-compute euler per bone [T, 3]
681
+ ID_EUL = np.zeros((T, 3), dtype=np.float32)
682
+ bone_euler: dict[str, np.ndarray] = {}
683
+ for name, bone in skel.bones.items():
684
+ if bone.smpl_idx is not None:
685
+ q = smpl_q[:, bone.smpl_idx].copy() # [T, 4]
686
+ c = corrections.get(name)
687
+ if c is not None:
688
+ q = qnorm(qmul(q, np.broadcast_to(c[None], q.shape).copy()))
689
+ bone_euler[name] = quat_to_euler_ZXY(q) # [T, 3]
690
+ else:
691
+ bone_euler[name] = ID_EUL
692
+
693
+ with open(output_path, "w") as f:
694
+ f.write("HIERARCHY\n")
695
+
696
+ def wj(name, ind):
697
+ off = offsets[name]
698
+ bone = skel.bones[name]
699
+ f.write(f"{'ROOT' if bone.parent is None else ind+'JOINT'} {name}\n")
700
+ f.write(f"{ind}{{\n")
701
+ f.write(f"{ind}\tOFFSET {off[0]:.6f} {off[1]:.6f} {off[2]:.6f}\n")
702
+ if bone.parent is None:
703
+ f.write(f"{ind}\tCHANNELS 6 Xposition Yposition Zposition "
704
+ "Zrotation Xrotation Yrotation\n")
705
+ else:
706
+ f.write(f"{ind}\tCHANNELS 3 Zrotation Xrotation Yrotation\n")
707
+ for c in bone.children:
708
+ wj(c, ind + "\t")
709
+ if not bone.children:
710
+ f.write(f"{ind}\tEnd Site\n{ind}\t{{\n"
711
+ f"{ind}\t\tOFFSET 0.000000 0.050000 0.000000\n{ind}\t}}\n")
712
+ f.write(f"{ind}}}\n")
713
+
714
+ wj(skel.root.name, "")
715
+ f.write(f"MOTION\nFrames: {T}\nFrame Time: {1.0/fps:.8f}\n")
716
+
717
+ for t in range(T):
718
+ rp = positions[t, 0] * pos_sc
719
+ row = [f"{rp[0]:.6f}", f"{rp[1]:.6f}", f"{rp[2]:.6f}"]
720
+ for name in dfs:
721
+ rz, rx, ry = bone_euler[name][t]
722
+ row += [f"{rz:.6f}", f"{rx:.6f}", f"{ry:.6f}"]
723
+ f.write(" ".join(row) + "\n")
724
+
725
+ n_mapped = sum(1 for b in skel.bones.values() if b.smpl_idx is not None)
726
+ print(f"[OK] {T} frames @ {fps} fps -> {output_path} "
727
+ f"(UniRig: {n_mapped} driven, {len(skel.bones)-n_mapped} identity)")
728
+
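The rest-pose correction relies on `qbetween` producing the quaternion that rotates one unit vector onto another. A minimal stand-in in (w, x, y, z) convention (the real `qbetween` lives in the motion-math module and may differ in edge-case handling; this sketch is undefined for exactly antiparallel inputs):

```python
import numpy as np

def qbetween(a, b):
    """Quaternion (w, x, y, z) rotating unit vector a onto unit vector b."""
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    w = 1.0 + float(np.dot(a, b))       # breaks down when a == -b
    xyz = np.cross(a, b)
    q = np.array([w, *xyz])
    return q / np.linalg.norm(q)

ur_rest = np.array([1.0, 0.0, 0.0])   # hypothetical UniRig bone rest direction
sm_rest = np.array([0.0, 1.0, 0.0])   # SMPL rest direction for the same joint
q = qbetween(ur_rest, sm_rest)        # 90 degrees about +Z
```

Composing this correction onto the SMPL local quaternion is what lets the same animated direction be reproduced on a skeleton whose rest bones point elsewhere.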
729
+ # ══════════════════════════════════════════════════════════════════════════════
730
+ # Motion loader
731
+ # ══════════════════════════════════════════════════════════════════════════════
732
+
733
+ def load_motion(npy_path: str) -> tuple[np.ndarray, int]:
734
+ """Return (positions [T, 22, 3], fps). Auto-detects HumanML3D format."""
735
+ data = np.load(npy_path).astype(np.float32)
736
+ print(f"[INFO] {npy_path} shape={data.shape}")
737
+
738
+ if data.ndim == 3 and data.shape[1] == 22 and data.shape[2] == 3:
739
+ print("[INFO] Format: new_joints [T, 22, 3]")
740
+ return data, 20
741
+
742
+ if data.ndim == 2 and data.shape[1] == 263:
743
+ print("[INFO] Format: new_joint_vecs [T, 263]")
744
+ pos = recover_from_ric(data, 22)
745
+ print(f"[INFO] Recovered positions {pos.shape}")
746
+ return pos, 20
747
+
748
+ if data.ndim == 2 and data.shape[1] == 272:
749
+ print("[INFO] Format: 272-dim (30 fps)")
750
+ return recover_from_ric(data[:, :263], 22), 30
751
+
752
+ if (data.ndim == 2 and data.shape[1] == 251) or \
753
+ (data.ndim == 3 and data.shape[1] == 21):
754
+ sys.exit("[ERROR] KIT-ML (21-joint) format not yet supported.")
755
+
756
+ sys.exit(f"[ERROR] Unrecognised shape {data.shape}. "
757
+ "Expected [T,22,3] or [T,263].")
758
+
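The shape dispatch in `load_motion` reduces to a small pure function, sketched here for clarity (format names and default fps values taken from the branches above):

```python
def detect_format(shape: tuple) -> tuple:
    """Mirror of load_motion's shape dispatch: (format name, default fps)."""
    if len(shape) == 3 and shape[1:] == (22, 3):
        return "new_joints", 20
    if len(shape) == 2 and shape[1] == 263:
        return "new_joint_vecs", 20
    if len(shape) == 2 and shape[1] == 272:
        return "272-dim", 30
    raise ValueError(f"unrecognised shape {shape}")

print(detect_format((40, 22, 3)))  # ('new_joints', 20)
print(detect_format((40, 263)))    # ('new_joint_vecs', 20)
print(detect_format((40, 272)))    # ('272-dim', 30)
```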
759
+ # ══════════════════════════════════════════════════════════════════════════════
760
+ # CLI
761
+ # ══════════════════════════════════════════════════════════════════════════════
762
+
763
+ def main() -> None:
764
+ ap = argparse.ArgumentParser(
765
+ description="HumanML3D .npy -> BVH, optionally retargeted to UniRig skeleton",
766
+ formatter_class=argparse.RawDescriptionHelpFormatter,
767
+ epilog="""
768
+ Examples
769
+ python humanml3d_to_bvh.py 000001.npy
770
+ Standard SMPL-named BVH (no rig file needed)
771
+
772
+ python humanml3d_to_bvh.py 000001.npy --rig rigged_mesh.glb
773
+ BVH retargeted to UniRig bone names, auto-mapped by position + name
774
+
775
+ python humanml3d_to_bvh.py 000001.npy --rig rigged_mesh.glb -o anim.bvh --fps 20
776
+
777
+ Supported --rig formats
778
+ .glb / .gltf UniRig merge.sh output (requires: pip install pygltflib)
779
+ .fbx ASCII FBX only (binary FBX: convert to GLB first)
780
+ """)
781
+ ap.add_argument("input", help="HumanML3D .npy motion file")
782
+ ap.add_argument("--rig", default=None,
783
+ help="UniRig-rigged mesh .glb or ASCII .fbx for auto-mapping")
784
+ ap.add_argument("-o", "--output", default=None, help="Output .bvh path")
785
+ ap.add_argument("--fps", type=int, default=0,
786
+ help="Override FPS (default: auto from format)")
787
+ ap.add_argument("--quiet", action="store_true",
788
+ help="Suppress mapping table")
789
+ args = ap.parse_args()
790
+
791
+ inp = Path(args.input)
792
+ out = Path(args.output) if args.output else inp.with_suffix(".bvh")
793
+
794
+ positions, auto_fps = load_motion(str(inp))
795
+ fps = args.fps if args.fps > 0 else auto_fps
796
+
797
+ if args.rig:
798
+ ext = Path(args.rig).suffix.lower()
799
+ if ext in (".glb", ".gltf"):
800
+ skel = parse_glb_skeleton(args.rig)
801
+ elif ext == ".fbx":
802
+ skel = parse_fbx_ascii_skeleton(args.rig)
803
+ else:
804
+ sys.exit(f"[ERROR] Unsupported rig format: {ext} (use .glb or .fbx)")
805
+
806
+ auto_map(skel, verbose=not args.quiet)
807
+ write_bvh_unirig(str(out), positions, skel, fps=fps)
808
+ else:
809
+ write_bvh_smpl(str(out), positions, fps=fps)
810
+
811
+
812
+ if __name__ == "__main__":
813
+ main()
Retarget/io/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """rig_retarget.io — file format readers / writers."""
Retarget/io/bvh.py ADDED
@@ -0,0 +1,216 @@
1
+ """
2
+ io/bvh.py
3
+ BVH (Biovision Hierarchy) reader.
4
+
5
+ Returns an Armature in rest pose plus an iterator / list of frame states.
6
+ Each frame state sets the bone pose_rotation_quat / pose_location on the
7
+ source armature so that retarget.get_bone_ws_quat / get_bone_position_ws
8
+ return the correct world-space values.
9
+ """
10
+ from __future__ import annotations
11
+ import math
12
+ import re
13
+ from typing import Dict, List, Optional, Tuple
14
+ import numpy as np
15
+
16
+ from ..skeleton import Armature, PoseBone
17
+ from ..math3d import translation_matrix, euler_to_quat, quat_identity, vec3
18
+
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Internal BVH data structures
22
+ # ---------------------------------------------------------------------------
23
+
24
+ class _BVHJoint:
25
+ def __init__(self, name: str):
26
+ self.name = name
27
+ self.offset: np.ndarray = vec3()
28
+ self.channels: List[str] = []
29
+ self.children: List["_BVHJoint"] = []
30
+ self.is_end_site: bool = False
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Parser
35
+ # ---------------------------------------------------------------------------
36
+
37
+ def _tokenize(text: str) -> List[str]:
38
+ return re.split(r"[\s]+", text.strip())
39
+
40
+
41
+ def _parse_hierarchy(tokens: List[str], idx: int) -> Tuple[_BVHJoint, int]:
42
+ """Parse one joint block. idx should point to joint name."""
43
+ name = tokens[idx]; idx += 1
44
+ joint = _BVHJoint(name)
45
+ assert tokens[idx] == "{", f"Expected '{{' got '{tokens[idx]}'"
46
+ idx += 1
47
+ while tokens[idx] != "}":
48
+ kw = tokens[idx].upper()
49
+ if kw == "OFFSET":
50
+ joint.offset = np.array([float(tokens[idx+1]), float(tokens[idx+2]), float(tokens[idx+3])])
51
+ idx += 4
52
+ elif kw == "CHANNELS":
53
+ n = int(tokens[idx+1]); idx += 2
54
+ joint.channels = [tokens[idx+i].upper() for i in range(n)]
55
+ idx += n
56
+ elif kw == "JOINT":
57
+ idx += 1
58
+ child, idx = _parse_hierarchy(tokens, idx)
59
+ joint.children.append(child)
60
+ elif kw == "END" and tokens[idx+1].upper() == "SITE":
61
+ # End Site block — just parse and discard
62
+ idx += 2
63
+ assert tokens[idx] == "{"; idx += 1
64
+ while tokens[idx] != "}":
65
+ idx += 1
66
+ idx += 1 # skip '}'
67
+ else:
68
+ idx += 1 # unknown token, skip
69
+ idx += 1 # skip '}'
70
+ return joint, idx
71
+
72
+
73
+ def _collect_joints(joint: _BVHJoint) -> List[_BVHJoint]:
74
+ result = [joint]
75
+ for c in joint.children:
76
+ result.extend(_collect_joints(c))
77
+ return result
78
+
79
+
80
+ # ---------------------------------------------------------------------------
81
+ # Build Armature from BVH hierarchy
82
+ # ---------------------------------------------------------------------------
83
+
84
+ def _build_armature(root_joint: _BVHJoint) -> Armature:
85
+ arm = Armature("BVH_Source")
86
+
87
+ def add_recursive(j: _BVHJoint, parent_name: Optional[str], parent_world: np.ndarray):
88
+ # rest_matrix_local = T(offset) relative to parent
89
+ rest_local = translation_matrix(j.offset)
90
+ bone = PoseBone(j.name, rest_local)
91
+ arm.add_bone(bone, parent_name)
92
+ world = parent_world @ rest_local
93
+ for child in j.children:
94
+ add_recursive(child, j.name, world)
95
+
96
+ add_recursive(root_joint, None, np.eye(4))
97
+ arm.update_fk()
98
+ return arm
99
+
100
+
101
+ # ---------------------------------------------------------------------------
102
+ # Frame application
103
+ # ---------------------------------------------------------------------------
104
+
105
+ _CHANNEL_MAP = {
106
+ "XROTATION": ("rx",), "YROTATION": ("ry",), "ZROTATION": ("rz",),
107
+ "XPOSITION": ("tx",), "YPOSITION": ("ty",), "ZPOSITION": ("tz",),
108
+ }
109
+
110
+
111
+ def _apply_frame(arm: Armature, all_joints: List[_BVHJoint], values: List[float]) -> None:
112
+ """Set bone poses for one BVH frame."""
113
+ vi = 0
114
+ for j in all_joints:
115
+ tx = ty = tz = 0.0
116
+ rx = ry = rz = 0.0
117
+ for ch in j.channels:
118
+ key = _CHANNEL_MAP.get(ch, None)
119
+ if key:
120
+ val = values[vi]
121
+ k = key[0]
122
+ if k == "tx": tx = val
123
+ elif k == "ty": ty = val
124
+ elif k == "tz": tz = val
125
+ elif k == "rx": rx = math.radians(val)
126
+ elif k == "ry": ry = math.radians(val)
127
+ elif k == "rz": rz = math.radians(val)
128
+ vi += 1
129
+
130
+ if j.name not in arm.pose_bones:
131
+ continue
132
+ bone = arm.pose_bones[j.name]
133
+
134
+ # BVH rotation order is specified per channel list; rebuild from order
135
+ rot_channels = [c for c in j.channels if "ROTATION" in c]
136
+ order = "".join(c[0] for c in rot_channels) # e.g. "ZXY"
137
+ angles = {"X": rx, "Y": ry, "Z": rz}
138
+ angle_seq = [angles[a] for a in order]
139
+ bone.pose_rotation_quat = euler_to_quat(*angle_seq, order=order)
140
+
141
+ # Translation — only root joints typically have it
142
+ if tx or ty or tz:
143
+ bone.pose_location = np.array([tx, ty, tz])
144
+
145
+ arm.update_fk()
146
+
147
+
148
+ # ---------------------------------------------------------------------------
149
+ # Public API
150
+ # ---------------------------------------------------------------------------
151
+
152
+ class BVHAnimation:
153
+ """Loaded BVH file. Iterate frames by calling advance(frame_index)."""
154
+
155
+ def __init__(
156
+ self,
157
+ armature: Armature,
158
+ all_joints: List[_BVHJoint],
159
+ frame_data: List[List[float]],
160
+ frame_time: float,
161
+ ):
162
+ self.armature = armature
163
+ self._all_joints = all_joints
164
+ self._frame_data = frame_data
165
+ self.frame_time = frame_time
166
+ self.num_frames = len(frame_data)
167
+
168
+ def apply_frame(self, frame_index: int) -> None:
169
+ """Advance armature to frame_index and update FK."""
170
+ if frame_index < 0 or frame_index >= self.num_frames:
171
+ raise IndexError(f"Frame {frame_index} out of range [0, {self.num_frames})")
172
+ _apply_frame(self.armature, self._all_joints, self._frame_data[frame_index])
173
+
174
+
175
+ def load_bvh(filepath: str) -> BVHAnimation:
176
+ """
177
+ Parse a BVH file.
178
+ Returns BVHAnimation with an Armature ready for retargeting.
179
+ """
180
+ with open(filepath, "r") as f:
181
+ text = f.read()
182
+
183
+ tokens = _tokenize(text)
184
+ idx = 0
185
+
186
+ # Expect HIERARCHY keyword
187
+ while tokens[idx].upper() != "HIERARCHY":
188
+ idx += 1
189
+ idx += 1
190
+
191
+ root_kw = tokens[idx].upper()
192
+ assert root_kw in ("ROOT", "JOINT"), f"Expected ROOT/JOINT, got '{tokens[idx]}'"
193
+ idx += 1
194
+ root_joint, idx = _parse_hierarchy(tokens, idx)
195
+
196
+ # MOTION section
197
+ while tokens[idx].upper() != "MOTION":
198
+ idx += 1
199
+ idx += 1
200
+
201
+ assert tokens[idx].upper() == "FRAMES:"; idx += 1
202
+ num_frames = int(tokens[idx]); idx += 1
203
+ assert tokens[idx].upper() == "FRAME"; assert tokens[idx+1].upper() == "TIME:"; idx += 2
204
+ frame_time = float(tokens[idx]); idx += 1
205
+
206
+ all_joints = _collect_joints(root_joint)
207
+ total_channels = sum(len(j.channels) for j in all_joints)
208
+
209
+ frame_data: List[List[float]] = []
210
+ for _ in range(num_frames):
211
+ row = [float(tokens[idx + k]) for k in range(total_channels)]
212
+ idx += total_channels
213
+ frame_data.append(row)
214
+
215
+ arm = _build_armature(root_joint)
216
+ return BVHAnimation(arm, all_joints, frame_data, frame_time)
Retarget/io/gltf_io.py ADDED
@@ -0,0 +1,316 @@
1
+ """
2
+ io/gltf_io.py
3
+ Load a glTF/GLB skeleton (e.g. UniRig output) into an Armature.
4
+ Write retargeted animation back into a glTF/GLB file.
5
+
6
+ Requires: pip install pygltflib
7
+ """
8
+ from __future__ import annotations
9
+ import base64
10
+ import json
11
+ import struct
12
+ from pathlib import Path
13
+ from typing import Dict, List, Optional, Tuple
14
+ import numpy as np
15
+
16
+ try:
17
+ import pygltflib
18
+ except ImportError:
19
+ raise ImportError("pip install pygltflib")
20
+
21
+ from ..skeleton import Armature, PoseBone
22
+ from ..math3d import (
23
+ quat_identity, quat_normalize, matrix4_to_quat, matrix4_to_trs,
24
+ trs_to_matrix4, vec3,
25
+ )
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Helpers
30
+ # ---------------------------------------------------------------------------
31
+
32
+ def _node_local_trs(node: "pygltflib.Node"):
33
+ """Extract TRS from a glTF node. Returns (t[3], r_wxyz[4], s[3])."""
34
+ t = np.array(node.translation or [0.0, 0.0, 0.0])
35
+ r_xyzw = np.array(node.rotation or [0.0, 0.0, 0.0, 1.0])
36
+ s = np.array(node.scale or [1.0, 1.0, 1.0])
37
+ # Convert glTF (x,y,z,w) → our (w,x,y,z)
38
+ r_wxyz = np.array([r_xyzw[3], r_xyzw[0], r_xyzw[1], r_xyzw[2]])
39
+ return t, r_wxyz, s
40
+
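The quaternion reindexing is easy to get backwards, so here it is round-tripped in isolation (glTF stores (x, y, z, w); the math3d helpers use (w, x, y, z)):

```python
import numpy as np

r_xyzw = np.array([0.0, 0.7071, 0.0, 0.7071])   # ~90 deg about Y, glTF order
# glTF (x, y, z, w) -> internal (w, x, y, z)
r_wxyz = np.array([r_xyzw[3], r_xyzw[0], r_xyzw[1], r_xyzw[2]])
# and back again for writing animation channels
back = np.array([r_wxyz[1], r_wxyz[2], r_wxyz[3], r_wxyz[0]])
```

The swap is a pure reindex in both directions; no sign flips or renormalisation are involved.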
41
+
42
+ def _node_local_matrix(node: "pygltflib.Node") -> np.ndarray:
43
+ if node.matrix:
44
+ # glTF stores column-major; convert to row-major
45
+ m = np.array(node.matrix, dtype=float).reshape(4, 4).T
46
+ return m
47
+ t, r, s = _node_local_trs(node)
48
+ return trs_to_matrix4(t, r, s)
49
+
50
+
51
+ def _read_accessor(gltf: "pygltflib.GLTF2", accessor_idx: int) -> np.ndarray:
52
+ """Read a glTF accessor into a numpy array."""
53
+ acc = gltf.accessors[accessor_idx]
54
+ bv = gltf.bufferViews[acc.bufferView]
55
+ buf = gltf.buffers[bv.buffer]
56
+
57
+ # Inline base64 data URI
58
+ if buf.uri and buf.uri.startswith("data:"):
59
+ _, b64 = buf.uri.split(",", 1)
60
+ raw = base64.b64decode(b64)
61
+ elif buf.uri:
62
+ base_dir = Path(gltf._path).parent if hasattr(gltf, "_path") and gltf._path else Path(".")
63
+ raw = (base_dir / buf.uri).read_bytes()
64
+ else:
65
+ # Binary GLB — data stored in gltf.binary_blob
66
+ raw = bytes(gltf.binary_blob())
67
+
68
+ start = bv.byteOffset + (acc.byteOffset or 0)
69
+ count = acc.count
70
+
71
+ type_to_components = {
72
+ "SCALAR": 1, "VEC2": 2, "VEC3": 3, "VEC4": 4,
73
+ "MAT2": 4, "MAT3": 9, "MAT4": 16,
74
+ }
75
+ component_type_to_fmt = {
76
+ 5120: "b", 5121: "B", 5122: "h", 5123: "H",
77
+ 5125: "I", 5126: "f",
78
+ }
79
+ n_comp = type_to_components[acc.type]
80
+ fmt = component_type_to_fmt[acc.componentType]
81
+ item_size = struct.calcsize(fmt) * n_comp
82
+ stride = bv.byteStride or item_size
83
+
84
+ items = []
85
+ for i in range(count):
86
+ offset = start + i * stride
87
+ vals = struct.unpack_from(f"{n_comp}{fmt}", raw, offset)
88
+ items.append(vals)
89
+
90
+ return np.array(items, dtype=float).squeeze()
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Load skeleton from glTF
95
+ # ---------------------------------------------------------------------------
96
+
97
+ def load_gltf(filepath: str, skin_index: int = 0) -> Armature:
98
+ """
99
+ Load the first (or specified) skin from a glTF/GLB file into an Armature.
100
+
101
+ The armature world_matrix is set to identity (typical for UniRig output).
102
+ """
103
+ gltf = pygltflib.GLTF2().load(filepath)
104
+ gltf._path = filepath
105
+
106
+ if not gltf.skins:
107
+ raise ValueError(f"No skins found in '{filepath}'")
108
+ skin = gltf.skins[skin_index]
109
+
110
+ # Read inverse bind matrices
111
+ n_joints = len(skin.joints)
112
+ ibm_array: Optional[np.ndarray] = None
113
+ if skin.inverseBindMatrices is not None:
114
+ raw = _read_accessor(gltf, skin.inverseBindMatrices)
115
+ ibm_array = raw.reshape(n_joints, 4, 4)
116
+
117
+ # Compute bind-pose world matrices: world_bind = inv(ibm)
118
+ joint_world_bind: Dict[int, np.ndarray] = {}
119
+ for i, j_idx in enumerate(skin.joints):
120
+ if ibm_array is not None:
121
+ ibm = ibm_array[i].T # glTF column-major → numpy row-major
122
+ joint_world_bind[j_idx] = np.linalg.inv(ibm)
123
+ else:
124
+ # Fallback: compute from FK over node local matrices
125
+ joint_world_bind[j_idx] = np.eye(4)
126
+
127
+ # Build parent map for nodes
128
+ parent_of: Dict[int, Optional[int]] = {}
129
+ for ni, node in enumerate(gltf.nodes):
130
+ for child_idx in (node.children or []):
131
+ parent_of[child_idx] = ni
132
+
133
+ arm = Armature(skin.name or f"Skin_{skin_index}")
134
+
135
+ # Process joints in order (parent always before child in glTF spec)
136
+ joint_set = set(skin.joints)
137
+ processed: Dict[int, str] = {}
138
+
139
+ for i, j_idx in enumerate(skin.joints):
140
+ node = gltf.nodes[j_idx]
141
+ bone_name = node.name or f"joint_{i}"
142
+
143
+ # Find parent joint node
144
+ parent_node_idx = parent_of.get(j_idx)
145
+ parent_bone_name: Optional[str] = None
146
+ while parent_node_idx is not None:
147
+ if parent_node_idx in joint_set:
148
+ parent_bone_name = processed.get(parent_node_idx)
149
+ break
150
+ parent_node_idx = parent_of.get(parent_node_idx)
151
+
152
+ # rest_matrix_local in parent space
153
+ if parent_bone_name and parent_bone_name in processed.values():
154
+ parent_world = joint_world_bind.get(
155
+ next(k for k, v in processed.items() if v == parent_bone_name),
156
+ np.eye(4)
157
+ )
158
+ rest_local = np.linalg.inv(parent_world) @ joint_world_bind[j_idx]
159
+ else:
160
+ rest_local = joint_world_bind[j_idx]
161
+
162
+ bone = PoseBone(bone_name, rest_local)
163
+ arm.add_bone(bone, parent_bone_name)
164
+ processed[j_idx] = bone_name
165
+
166
+ arm.update_fk()
167
+ return arm
168
+
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # Write animation to glTF
172
+ # ---------------------------------------------------------------------------
173
+
174
+ def write_gltf_animation(
175
+ source_filepath: str,
176
+ dest_armature: Armature,
177
+ keyframes: List[Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]],
178
+ output_filepath: str,
179
+ fps: float = 30.0,
180
+ skin_index: int = 0,
181
+ ) -> None:
182
+ """
183
+ Embed animation keyframes into a copy of source_filepath (the UniRig GLB).
184
+
185
+ keyframes: list of dicts, one per frame.
186
+ Each dict maps bone_name → (pose_location, pose_rotation_quat, pose_scale)
187
+ These are LOCAL values (relative to rest pose local matrix).
188
+
189
+ The function adds one glTF Animation with channels for each bone that has data.
190
+ """
191
+ gltf = pygltflib.GLTF2().load(source_filepath)
192
+ gltf._path = source_filepath
193
+
194
+ if not gltf.skins:
195
+ raise ValueError("No skins in source file")
196
+ skin = gltf.skins[skin_index]
197
+
198
+ # Build node_name → node_index map for skin joints
199
+ joint_name_to_node: Dict[str, int] = {}
200
+ for j_idx in skin.joints:
201
+ node = gltf.nodes[j_idx]
202
+ name = node.name or f"joint_{j_idx}"
203
+ joint_name_to_node[name] = j_idx
204
+
205
+ n_frames = len(keyframes)
206
+ times = np.array([i / fps for i in range(n_frames)], dtype=np.float32)
207
+
208
+ # Gather binary data
209
+ binary_chunks: List[bytes] = []
210
212
+
213
+ def _add_data(data: np.ndarray, acc_type: str) -> int:
214
+ """Append numpy array to binary, return accessor index."""
215
+ raw = data.astype(np.float32).tobytes()
216
+ bv_offset = sum(len(c) for c in binary_chunks)
217
+ binary_chunks.append(raw)
218
+ bv_idx = len(gltf.bufferViews)
219
+ gltf.bufferViews.append(pygltflib.BufferView(
220
+ buffer=0,
221
+ byteOffset=bv_offset,
222
+ byteLength=len(raw),
223
+ ))
224
+ acc_idx = len(gltf.accessors)
225
+ gltf.accessors.append(pygltflib.Accessor(
226
+ bufferView=bv_idx,
227
+ componentType=pygltflib.FLOAT,
228
+ count=len(data),
229
+ type=acc_type,
230
+ max=data.max(axis=0).tolist() if data.ndim > 1 else [float(data.max())],
231
+ min=data.min(axis=0).tolist() if data.ndim > 1 else [float(data.min())],
232
+ ))
233
+ return acc_idx
234
+
235
+ time_acc_idx = _add_data(times, "SCALAR")
236
+
237
+ channels: List[pygltflib.AnimationChannel] = []
238
+ samplers: List[pygltflib.AnimationSampler] = []
239
+
240
+ bone_names = set()
241
+ for frame in keyframes:
242
+ bone_names |= frame.keys()
243
+
244
+ for bone_name in sorted(bone_names):
245
+ if bone_name not in joint_name_to_node:
246
+ continue
247
+ node_idx = joint_name_to_node[bone_name]
248
+ node = gltf.nodes[node_idx]
249
+
250
+ # Collect TRS arrays across frames
251
+ rot_data = np.zeros((n_frames, 4), dtype=np.float32) # (x,y,z,w)
252
+ trans_data = np.zeros((n_frames, 3), dtype=np.float32)
253
+ scale_data = np.ones((n_frames, 3), dtype=np.float32)
254
+
255
+ rest_t, rest_r, rest_s = _node_local_trs(node)
256
+
257
+ for fi, frame in enumerate(keyframes):
258
+ if bone_name in frame:
259
+ pose_loc, pose_rot, pose_scale = frame[bone_name]
260
+ else:
261
+ pose_loc = vec3()
262
+ pose_rot = quat_identity()
263
+ pose_scale = np.ones(3)
264
+
265
+ # Final local = rest + delta (simple addition for translation, multiply for rotation)
266
+ from ..math3d import quat_mul
267
+ final_t = rest_t + pose_loc
268
+ final_r = quat_mul(rest_r, pose_rot) # (w,x,y,z)
269
+ final_s = rest_s * pose_scale
270
+
271
+ # Convert rotation to glTF (x,y,z,w)
272
+ w, x, y, z = final_r
273
+ rot_data[fi] = [x, y, z, w]
274
+ trans_data[fi] = final_t
275
+ scale_data[fi] = final_s
276
+
277
+ s_idx = len(samplers)
278
+ rot_acc = _add_data(rot_data, "VEC4")
279
+ samplers.append(pygltflib.AnimationSampler(input=time_acc_idx, output=rot_acc, interpolation="LINEAR"))
280
+ channels.append(pygltflib.AnimationChannel(
281
+ sampler=s_idx,
282
+ target=pygltflib.AnimationChannelTarget(node=node_idx, path="rotation"),
283
+ ))
284
+
285
+ s_idx = len(samplers)
286
+ trans_acc = _add_data(trans_data, "VEC3")
287
+ samplers.append(pygltflib.AnimationSampler(input=time_acc_idx, output=trans_acc, interpolation="LINEAR"))
288
+ channels.append(pygltflib.AnimationChannel(
289
+ sampler=s_idx,
290
+ target=pygltflib.AnimationChannelTarget(node=node_idx, path="translation"),
291
+ ))
292
+
293
+ if not channels:
294
+ print("[gltf_io] Warning: no channels written — check bone name mapping.")
295
+ return
296
+
297
+ gltf.animations.append(pygltflib.Animation(
298
+ name="RetargetedAnimation",
299
+ samplers=samplers,
300
+ channels=channels,
301
+ ))
302
+
303
+ # Patch buffer 0 size with our new data
304
+ new_blob = b"".join(binary_chunks)
305
+ existing_blob = bytes(gltf.binary_blob()) if gltf.binary_blob() else b""
306
+ full_blob = existing_blob + new_blob
307
+
308
+ # Update buffer 0 byteOffset of new views
309
+ for bv in gltf.bufferViews[-len(binary_chunks):]:
310
+ bv.byteOffset += len(existing_blob)
311
+
312
+ gltf.set_binary_blob(full_blob)
313
+ gltf.buffers[0].byteLength = len(full_blob)
314
+
315
+ gltf.save(output_filepath)
316
+ print(f"[gltf_io] Saved animated GLB -> {output_filepath}")
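The buffer-patching step above follows a two-phase bookkeeping scheme: new bufferView offsets are first computed relative to the start of the new data, then shifted by the length of the pre-existing blob once everything is spliced together. A minimal pure-Python sketch of that scheme (no pygltflib; the dicts stand in for BufferView objects):

```python
existing_blob = b"\x00" * 10                 # pretend: mesh data already in buffer 0
binary_chunks = []                           # new animation data chunks
buffer_views = []                            # each: {"byteOffset": ..., "byteLength": ...}

def add_chunk(raw: bytes) -> int:
    # Offset is relative to the start of the *new* data for now.
    buffer_views.append({"byteOffset": sum(len(c) for c in binary_chunks),
                         "byteLength": len(raw)})
    binary_chunks.append(raw)
    return len(buffer_views) - 1

add_chunk(b"\x01" * 8)   # e.g. keyframe times
add_chunk(b"\x02" * 4)   # e.g. rotations

# Splice after the existing blob, then shift the new views accordingly.
full_blob = existing_blob + b"".join(binary_chunks)
for bv in buffer_views[-len(binary_chunks):]:
    bv["byteOffset"] += len(existing_blob)
```

After the shift, the first chunk lands at offset 10 and the second at offset 18, so every view still addresses its own bytes inside the combined blob.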
Retarget/io/mapping.py ADDED
@@ -0,0 +1,189 @@
+ """
+ io/mapping.py
+ Load / save bone mapping JSON in the exact same format as KeeMap.
+ """
+ from __future__ import annotations
+ import json
+ from dataclasses import dataclass, field
+ from typing import List
+ import numpy as np
+ from ..math3d import quat_identity, vec3
+
+
+ @dataclass
+ class BoneMappingItem:
+     name: str = ""
+     label: str = ""
+     description: str = ""
+
+     source_bone_name: str = ""
+     destination_bone_name: str = ""
+
+     keyframe_this_bone: bool = True
+
+     # Rotation correction (Euler, radians)
+     correction_factor: np.ndarray = field(default_factory=lambda: vec3())
+
+     # Quaternion correction
+     quat_correction_factor: np.ndarray = field(default_factory=quat_identity)
+
+     has_twist_bone: bool = False
+     twist_bone_name: str = ""
+
+     set_bone_position: bool = False
+     set_bone_rotation: bool = True
+     set_bone_scale: bool = False
+
+     # Rotation options
+     bone_rotation_application_axis: str = "XYZ"  # X Y Z XY XZ YZ XYZ
+     bone_transpose_axis: str = "NONE"  # NONE ZXY ZYX XZY YZX YXZ
+
+     # Position options
+     postion_type: str = "SINGLE_BONE_OFFSET"  # (sic; matches KeeMap's JSON key) SINGLE_BONE_OFFSET | POLE
+     position_correction_factor: np.ndarray = field(default_factory=lambda: vec3())
+     position_gain: float = 1.0
+     position_pole_distance: float = 0.3
+
+     # Scale options
+     scale_secondary_bone_name: str = ""
+     bone_scale_application_axis: str = "Y"
+     scale_gain: float = 1.0
+     scale_max: float = 1.0
+     scale_min: float = 0.5
+
+
+ @dataclass
+ class KeeMapSettings:
+     source_rig_name: str = ""
+     destination_rig_name: str = ""
+     bone_mapping_file: str = ""
+     bone_rotation_mode: str = "EULER"  # EULER | QUATERNION
+     start_frame_to_apply: int = 0
+     number_of_frames_to_apply: int = 100
+     keyframe_every_n_frames: int = 1
+     keyframe_test: bool = False
+
+
+ # ---------------------------------------------------------------------------
+ # Load
+ # ---------------------------------------------------------------------------
+
+ def load_mapping(filepath: str):
+     """
+     Returns (KeeMapSettings, List[BoneMappingItem]).
+     Reads the exact same JSON that KeeMap writes.
+     """
+     with open(filepath, "r") as f:
+         data = json.load(f)
+
+     settings = KeeMapSettings(
+         source_rig_name=data.get("source_rig_name", ""),
+         destination_rig_name=data.get("destination_rig_name", ""),
+         bone_mapping_file=data.get("bone_mapping_file", ""),
+         bone_rotation_mode=data.get("bone_rotation_mode", "EULER"),
+         start_frame_to_apply=data.get("start_frame_to_apply", 0),
+         number_of_frames_to_apply=data.get("number_of_frames_to_apply", 100),
+         keyframe_every_n_frames=data.get("keyframe_every_n_frames", 1),
+     )
+
+     bones: List[BoneMappingItem] = []
+     for p in data.get("bones", []):
+         item = BoneMappingItem()
+         item.name = p.get("name", "")
+         item.label = p.get("label", "")
+         item.description = p.get("description", "")
+         item.source_bone_name = p.get("SourceBoneName", "")
+         item.destination_bone_name = p.get("DestinationBoneName", "")
+         item.keyframe_this_bone = p.get("keyframe_this_bone", True)
+
+         item.correction_factor = np.array([
+             p.get("CorrectionFactorX", 0.0),
+             p.get("CorrectionFactorY", 0.0),
+             p.get("CorrectionFactorZ", 0.0),
+         ])
+
+         item.quat_correction_factor = np.array([
+             p.get("QuatCorrectionFactorw", 1.0),
+             p.get("QuatCorrectionFactorx", 0.0),
+             p.get("QuatCorrectionFactory", 0.0),
+             p.get("QuatCorrectionFactorz", 0.0),
+         ])
+
+         item.has_twist_bone = p.get("has_twist_bone", False)
+         item.twist_bone_name = p.get("TwistBoneName", "")
+         item.set_bone_position = p.get("set_bone_position", False)
+         item.set_bone_rotation = p.get("set_bone_rotation", True)
+         item.set_bone_scale = p.get("set_bone_scale", False)
+         item.bone_rotation_application_axis = p.get("bone_rotation_application_axis", "XYZ")
+         item.bone_transpose_axis = p.get("bone_transpose_axis", "NONE")
+         item.postion_type = p.get("postion_type", "SINGLE_BONE_OFFSET")
+
+         item.position_correction_factor = np.array([
+             p.get("position_correction_factorX", 0.0),
+             p.get("position_correction_factorY", 0.0),
+             p.get("position_correction_factorZ", 0.0),
+         ])
+         item.position_gain = p.get("position_gain", 1.0)
+         item.position_pole_distance = p.get("position_pole_distance", 0.3)
+         item.scale_secondary_bone_name = p.get("scale_secondary_bone_name", "")
+         item.bone_scale_application_axis = p.get("bone_scale_application_axis", "Y")
+         item.scale_gain = p.get("scale_gain", 1.0)
+         item.scale_max = p.get("scale_max", 1.0)
+         item.scale_min = p.get("scale_min", 0.5)
+         bones.append(item)
+
+     return settings, bones
+
+
+ # ---------------------------------------------------------------------------
+ # Save
+ # ---------------------------------------------------------------------------
+
+ def save_mapping(filepath: str, settings: KeeMapSettings, bones: List[BoneMappingItem]) -> None:
+     """Write mapping JSON readable by KeeMap."""
+     root = {
+         "source_rig_name": settings.source_rig_name,
+         "destination_rig_name": settings.destination_rig_name,
+         "bone_mapping_file": settings.bone_mapping_file,
+         "bone_rotation_mode": settings.bone_rotation_mode,
+         "start_frame_to_apply": settings.start_frame_to_apply,
+         "number_of_frames_to_apply": settings.number_of_frames_to_apply,
+         "keyframe_every_n_frames": settings.keyframe_every_n_frames,
+         "bones": [],
+     }
+     for b in bones:
+         root["bones"].append({
+             "name": b.name,
+             "label": b.label,
+             "description": b.description,
+             "SourceBoneName": b.source_bone_name,
+             "DestinationBoneName": b.destination_bone_name,
+             "keyframe_this_bone": b.keyframe_this_bone,
+             "CorrectionFactorX": float(b.correction_factor[0]),
+             "CorrectionFactorY": float(b.correction_factor[1]),
+             "CorrectionFactorZ": float(b.correction_factor[2]),
+             "QuatCorrectionFactorw": float(b.quat_correction_factor[0]),
+             "QuatCorrectionFactorx": float(b.quat_correction_factor[1]),
+             "QuatCorrectionFactory": float(b.quat_correction_factor[2]),
+             "QuatCorrectionFactorz": float(b.quat_correction_factor[3]),
+             "has_twist_bone": b.has_twist_bone,
+             "TwistBoneName": b.twist_bone_name,
+             "set_bone_position": b.set_bone_position,
+             "set_bone_rotation": b.set_bone_rotation,
+             "set_bone_scale": b.set_bone_scale,
+             "bone_rotation_application_axis": b.bone_rotation_application_axis,
+             "bone_transpose_axis": b.bone_transpose_axis,
+             "postion_type": b.postion_type,
+             "position_correction_factorX": float(b.position_correction_factor[0]),
+             "position_correction_factorY": float(b.position_correction_factor[1]),
+             "position_correction_factorZ": float(b.position_correction_factor[2]),
+             "position_gain": b.position_gain,
+             "position_pole_distance": b.position_pole_distance,
+             "scale_secondary_bone_name": b.scale_secondary_bone_name,
+             "bone_scale_application_axis": b.bone_scale_application_axis,
+             "scale_gain": b.scale_gain,
+             "scale_max": b.scale_max,
+             "scale_min": b.scale_min,
+         })
+     with open(filepath, "w") as f:
+         json.dump(root, f, indent=2)
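The mixed CamelCase/snake_case key names are KeeMap's own, and `load_mapping` tolerates missing keys via `.get(...)` defaults. A self-contained round-trip check of that pattern (the bone names here are illustrative, not from the repo):

```python
import json
import os
import tempfile

bone = {
    "SourceBoneName": "mixamorig:Hips",   # illustrative source bone
    "DestinationBoneName": "Hips",
    "CorrectionFactorX": 0.0,
    # "keyframe_this_bone" deliberately omitted: falls back to its default
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"bone_rotation_mode": "EULER", "bones": [bone]}, f)
    path = f.name

with open(path) as f:
    data = json.load(f)
os.unlink(path)

p = data["bones"][0]
source = p.get("SourceBoneName", "")
keyframe = p.get("keyframe_this_bone", True)   # missing key -> default True
mode = data.get("bone_rotation_mode", "EULER")
```

This mirrors why a partially hand-written mapping file still loads: any field KeeMap did not write simply takes the dataclass default.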
Retarget/math3d.py ADDED
@@ -0,0 +1,167 @@
+ """
+ math3d.py
+ Pure numpy / scipy replacement for Blender's mathutils.
+ Quaternion convention throughout: (w, x, y, z)
+ Matrix convention: 4×4, row-major storage, column-vector (points transform as M @ v)
+ """
+ from __future__ import annotations
+ import numpy as np
+ from scipy.spatial.transform import Rotation
+
+ # ---------------------------------------------------------------------------
+ # Quaternion helpers (w, x, y, z)
+ # ---------------------------------------------------------------------------
+
+ def quat_identity() -> np.ndarray:
+     return np.array([1.0, 0.0, 0.0, 0.0])
+
+
+ def quat_normalize(q: np.ndarray) -> np.ndarray:
+     n = np.linalg.norm(q)
+     return q / n if n > 1e-12 else quat_identity()
+
+
+ def quat_conjugate(q: np.ndarray) -> np.ndarray:
+     """Conjugate == inverse for unit quaternion."""
+     return np.array([q[0], -q[1], -q[2], -q[3]])
+
+
+ def quat_mul(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
+     """Quaternion multiplication (Blender @ operator)."""
+     w1, x1, y1, z1 = q1
+     w2, x2, y2, z2 = q2
+     return np.array([
+         w1*w2 - x1*x2 - y1*y2 - z1*z2,
+         w1*x2 + x1*w2 + y1*z2 - z1*y2,
+         w1*y2 - x1*z2 + y1*w2 + z1*x2,
+         w1*z2 + x1*y2 - y1*x2 + z1*w2,
+     ])
+
+
+ def quat_rotation_difference(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
+     """
+     Rotation that takes q1 to q2.
+     r such that q1 @ r == q2
+     r = conj(q1) @ q2
+     Matches Blender's Quaternion.rotation_difference()
+     """
+     return quat_normalize(quat_mul(quat_conjugate(q1), q2))
+
+
+ def quat_dot(q1: np.ndarray, q2: np.ndarray) -> float:
+     """Dot product of two quaternions (used for scale retargeting)."""
+     return float(np.dot(q1, q2))
+
+
+ def quat_to_matrix4(q: np.ndarray) -> np.ndarray:
+     """Unit quaternion (w,x,y,z) → 4×4 rotation matrix."""
+     w, x, y, z = q
+     m = np.array([
+         [1 - 2*(y*y + z*z), 2*(x*y - z*w), 2*(x*z + y*w), 0],
+         [ 2*(x*y + z*w), 1 - 2*(x*x + z*z), 2*(y*z - x*w), 0],
+         [ 2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x*x + y*y), 0],
+         [ 0, 0, 0, 1],
+     ], dtype=float)
+     return m
+
+
+ def matrix4_to_quat(m: np.ndarray) -> np.ndarray:
+     """4×4 matrix → unit quaternion (w,x,y,z)."""
+     r = Rotation.from_matrix(m[:3, :3])
+     x, y, z, w = r.as_quat()  # scipy uses (x,y,z,w)
+     q = np.array([w, x, y, z])
+     # Ensure positive w to match Blender convention
+     if q[0] < 0:
+         q = -q
+     return quat_normalize(q)
+
+
+ # ---------------------------------------------------------------------------
+ # Euler ↔ Quaternion
+ # ---------------------------------------------------------------------------
+
+ def euler_to_quat(rx: float, ry: float, rz: float, order: str = "XYZ") -> np.ndarray:
+     """Euler angles (radians) to quaternion (w,x,y,z)."""
+     r = Rotation.from_euler(order, [rx, ry, rz])
+     x, y, z, w = r.as_quat()
+     return quat_normalize(np.array([w, x, y, z]))
+
+
+ def quat_to_euler(q: np.ndarray, order: str = "XYZ") -> np.ndarray:
+     """Quaternion (w,x,y,z) to Euler angles (radians)."""
+     w, x, y, z = q
+     r = Rotation.from_quat([x, y, z, w])
+     return r.as_euler(order)
+
+
+ # ---------------------------------------------------------------------------
+ # Matrix constructors
+ # ---------------------------------------------------------------------------
+
+ def translation_matrix(v) -> np.ndarray:
+     m = np.eye(4)
+     m[0, 3] = v[0]
+     m[1, 3] = v[1]
+     m[2, 3] = v[2]
+     return m
+
+
+ def scale_matrix(s) -> np.ndarray:
+     m = np.eye(4)
+     m[0, 0] = s[0]
+     m[1, 1] = s[1]
+     m[2, 2] = s[2]
+     return m
+
+
+ def trs_to_matrix4(t, r_quat, s) -> np.ndarray:
+     """Combine translation, rotation (w,x,y,z quat), scale into 4×4."""
+     T = translation_matrix(t)
+     R = quat_to_matrix4(r_quat)
+     S = scale_matrix(s)
+     return T @ R @ S
+
+
+ def matrix4_to_trs(m: np.ndarray):
+     """Decompose 4×4 into (translation[3], rotation_quat[4], scale[3])."""
+     t = m[:3, 3].copy()
+     sx = np.linalg.norm(m[:3, 0])
+     sy = np.linalg.norm(m[:3, 1])
+     sz = np.linalg.norm(m[:3, 2])
+     s = np.array([sx, sy, sz])
+     rot_m = m[:3, :3].copy()
+     if sx > 1e-12: rot_m[:, 0] /= sx
+     if sy > 1e-12: rot_m[:, 1] /= sy
+     if sz > 1e-12: rot_m[:, 2] /= sz
+     r = Rotation.from_matrix(rot_m)
+     x, y, z, w = r.as_quat()
+     q = np.array([w, x, y, z])
+     if q[0] < 0:
+         q = -q
+     return t, quat_normalize(q), s
+
+
+ # ---------------------------------------------------------------------------
+ # Vector helpers
+ # ---------------------------------------------------------------------------
+
+ def vec3(x=0.0, y=0.0, z=0.0) -> np.ndarray:
+     return np.array([x, y, z], dtype=float)
+
+
+ def get_point_on_vector(initial_pt: np.ndarray, terminal_pt: np.ndarray, distance: float) -> np.ndarray:
+     """
+     Point at 'distance' from initial_pt along (initial_pt → terminal_pt).
+     Matches Blender's get_point_on_vector helper in KeeMapBoneOperators.
+     """
+     n = initial_pt - terminal_pt
+     norm = np.linalg.norm(n)
+     if norm < 1e-12:
+         return initial_pt.copy()
+     n = n / norm
+     return initial_pt - distance * n
+
+
+ def apply_rotation_matrix4(m: np.ndarray, v: np.ndarray) -> np.ndarray:
+     """Apply only the rotation part of a 4×4 matrix to a 3-vector."""
+     return m[:3, :3] @ v
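The Hamilton product in `quat_mul` composes the same way rotation matrices do, which is what the retargeting math relies on. A small numpy-only sanity check (the helpers are re-declared inline so the snippet stands alone; they mirror `quat_mul` and the rotation block of `quat_to_matrix4`):

```python
import numpy as np

def quat_mul(q1, q2):
    # Hamilton product, (w, x, y, z) convention.
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    return np.array([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2,
    ])

def quat_to_matrix3(q):
    # Rotation part of quat_to_matrix4 for a unit (w, x, y, z) quaternion.
    w, x, y, z = q
    return np.array([
        [1 - 2*(y*y + z*z), 2*(x*y - z*w),     2*(x*z + y*w)],
        [2*(x*y + z*w),     1 - 2*(x*x + z*z), 2*(y*z - x*w)],
        [2*(x*z - y*w),     2*(y*z + x*w),     1 - 2*(x*x + y*y)],
    ])

def axis_angle(axis, angle):
    # Unit quaternion for a rotation of 'angle' radians about 'axis'.
    axis = np.asarray(axis, float) / np.linalg.norm(axis)
    return np.concatenate([[np.cos(angle / 2)], np.sin(angle / 2) * axis])

qa = axis_angle([0.0, 0.0, 1.0], 0.7)
qb = axis_angle([1.0, 0.0, 0.0], -0.4)

# Composition identity: R(qa * qb) == R(qa) @ R(qb)
lhs = quat_to_matrix3(quat_mul(qa, qb))
rhs = quat_to_matrix3(qa) @ quat_to_matrix3(qb)
ok = np.allclose(lhs, rhs)
```

Note the (w, x, y, z) ordering everywhere: SciPy's `Rotation.as_quat()` returns (x, y, z, w), which is why the module swizzles components at every scipy boundary.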
Retarget/retarget.py ADDED
@@ -0,0 +1,586 @@
+ """
+ retarget.py
+ Pure-Python port of KeeMapBoneOperators.py core math.
+
+ Replaces bpy / mathutils with numpy. No Blender dependency.
+ Public API mirrors the Blender operator flow:
+
+     get_bone_position_ws(bone, arm) → np.ndarray(3)
+     get_bone_ws_quat(bone, arm) → np.ndarray(4) w,x,y,z
+     set_bone_position_ws(bone, arm, pos)
+     set_bone_rotation(...)
+     set_bone_position(...)
+     set_bone_position_pole(...)
+     set_bone_scale(...)
+     calc_rotation_offset(bone_item, src_arm, dst_arm, settings)
+     calc_location_offset(bone_item, src_arm, dst_arm)
+     transfer_frame(src_arm, dst_arm, bone_items, settings)
+         → Dict[bone_name → (pose_loc, pose_rot, pose_scale)]
+     transfer_animation(src_anim, dst_arm, bone_items, settings)
+         → List[Dict[bone_name → (pose_loc, pose_rot, pose_scale)]]
+ """
+ from __future__ import annotations
+ import math
+ import sys
+ from typing import Dict, List, Optional, Tuple
+ import numpy as np
+
+ from .skeleton import Armature, PoseBone
+ from .math3d import (
+     quat_identity, quat_normalize, quat_mul, quat_conjugate,
+     quat_rotation_difference, quat_dot,
+     quat_to_matrix4, matrix4_to_quat,
+     euler_to_quat, quat_to_euler,
+     translation_matrix, vec3, get_point_on_vector,
+ )
+ from .io.mapping import BoneMappingItem, KeeMapSettings
+
+
+ # ---------------------------------------------------------------------------
+ # Progress bar (console)
+ # ---------------------------------------------------------------------------
+
+ def _update_progress(job: str, progress: float) -> None:
+     length = 40
+     block = int(round(length * progress))
+     msg = f"\r{job}: [{'#'*block}{'-'*(length-block)}] {round(progress*100, 1)}%"
+     if progress >= 1:
+         msg += " DONE\r\n"
+     sys.stdout.write(msg)
+     sys.stdout.flush()
+
+
+ # ---------------------------------------------------------------------------
+ # World-space position / quaternion getters
+ # ---------------------------------------------------------------------------
+
+ def get_bone_position_ws(bone: PoseBone, arm: Armature) -> np.ndarray:
+     """
+     Return world-space position of bone head.
+     Equivalent to Blender's GetBonePositionWS().
+     """
+     ws_matrix = arm.world_matrix @ bone.matrix_armature
+     return ws_matrix[:3, 3].copy()
+
+
+ def get_bone_ws_quat(bone: PoseBone, arm: Armature) -> np.ndarray:
+     """
+     Return world-space rotation as quaternion (w,x,y,z).
+     Equivalent to Blender's GetBoneWSQuat().
+     """
+     ws_matrix = arm.world_matrix @ bone.matrix_armature
+     return matrix4_to_quat(ws_matrix)
+
+
+ # ---------------------------------------------------------------------------
+ # World-space position setter
+ # ---------------------------------------------------------------------------
+
+ def set_bone_position_ws(bone: PoseBone, arm: Armature, position: np.ndarray) -> None:
+     """
+     Move bone so its world-space head = position.
+     Equivalent to Blender's SetBonePositionWS().
+
+     Strategy:
+       1. Build new armature-space matrix = old rotation + new translation
+       2. Strip parent transform to get new local translation
+       3. Update pose_location so FK matches
+     """
+     # Current armature-space matrix (rotation/scale part preserved)
+     arm_mat = bone.matrix_armature.copy()
+
+     # Target armature-space position
+     arm_world_inv = np.linalg.inv(arm.world_matrix)
+     target_arm_pos = (arm_world_inv @ np.append(position, 1.0))[:3]
+
+     # New armature-space matrix with replaced translation
+     new_arm_mat = arm_mat.copy()
+     new_arm_mat[:3, 3] = target_arm_pos
+
+     # Convert to local (parent-relative) space
+     if bone.parent is not None:
+         parent_arm_mat = bone.parent.matrix_armature
+         new_local = np.linalg.inv(parent_arm_mat) @ new_arm_mat
+     else:
+         new_local = new_arm_mat
+
+     # Extract translation from new_local = rest_local @ T(pose_loc) @ ...
+     # Approximate: strip rest_local rotation contribution to isolate pose_location
+     rest_inv = np.linalg.inv(bone.rest_matrix_local)
+     pose_delta = rest_inv @ new_local
+     bone.pose_location = pose_delta[:3, 3].copy()
+
+     # Recompute FK for this bone and its subtree
+     if bone.parent is not None:
+         bone._fk(bone.parent.matrix_armature)
+     else:
+         bone._fk(np.eye(4))
+
+
+ # ---------------------------------------------------------------------------
+ # Rotation setter (core retargeting math)
+ # ---------------------------------------------------------------------------
+
+ def set_bone_rotation(
+     src_arm: Armature, src_name: str,
+     dst_arm: Armature, dst_name: str,
+     dst_twist_name: str,
+     correction_quat: np.ndarray,
+     has_twist: bool,
+     xfer_axis: str,
+     transpose: str,
+     mode: str,
+ ) -> None:
+     """
+     Port of Blender's SetBoneRotation().
+     Drives dst bone rotation to match src bone world-space rotation.
+
+     mode: "EULER" | "QUATERNION"
+     xfer_axis: "X" "Y" "Z" "XY" "XZ" "YZ" "XYZ"
+     transpose: "NONE" "ZYX" "ZXY" "XZY" "YZX" "YXZ"
+     """
+     src_bone = src_arm.get_bone(src_name)
+     dst_bone = dst_arm.get_bone(dst_name)
+
+     # ------------------------------------------------------------------
+     # Get source and destination world-space quaternions (current pose)
+     # ------------------------------------------------------------------
+     src_ws_quat = get_bone_ws_quat(src_bone, src_arm)
+     dst_ws_quat = get_bone_ws_quat(dst_bone, dst_arm)
+
+     # Rotation difference: r such that dst_ws @ r ≈ src_ws
+     diff = quat_rotation_difference(dst_ws_quat, src_ws_quat)
+
+     # FinalQuat = dst_local_pose_delta @ diff @ correction
+     final_quat = quat_normalize(
+         quat_mul(quat_mul(dst_bone.pose_rotation_quat, diff), correction_quat)
+     )
+
+     # ------------------------------------------------------------------
+     # Apply axis masking / transpose (EULER mode)
+     # ------------------------------------------------------------------
+     if mode == "EULER":
+         euler = quat_to_euler(final_quat, order="XYZ")
+
+         # Transpose axes
+         if transpose == "ZYX":
+             euler = np.array([euler[2], euler[1], euler[0]])
+         elif transpose == "ZXY":
+             euler = np.array([euler[2], euler[0], euler[1]])
+         elif transpose == "XZY":
+             euler = np.array([euler[0], euler[2], euler[1]])
+         elif transpose == "YZX":
+             euler = np.array([euler[1], euler[2], euler[0]])
+         elif transpose == "YXZ":
+             euler = np.array([euler[1], euler[0], euler[2]])
+         # else NONE — no change
+
+         # Mask axes
+         if xfer_axis == "X":
+             euler[1] = 0.0; euler[2] = 0.0
+         elif xfer_axis == "Y":
+             euler[0] = 0.0; euler[2] = 0.0
+         elif xfer_axis == "Z":
+             euler[0] = 0.0; euler[1] = 0.0
+         elif xfer_axis == "XY":
+             euler[2] = 0.0
+         elif xfer_axis == "XZ":
+             euler[1] = 0.0
+         elif xfer_axis == "YZ":
+             euler[0] = 0.0
+         # XYZ → no masking
+
+         final_quat = euler_to_quat(euler[0], euler[1], euler[2], order="XYZ")
+
+         # Twist bone: peel the Y rotation off onto the twist bone
+         if has_twist and dst_twist_name:
+             twist_bone = dst_arm.get_bone(dst_twist_name)
+             euler_no_y = quat_to_euler(final_quat, order="XYZ")
+             y_euler = euler_no_y[1]
+             # Remove Y from the main bone
+             euler_no_y[1] = 0.0
+             final_quat = euler_to_quat(*euler_no_y, order="XYZ")
+             # Apply Y to the twist bone (all Euler values stay in radians)
+             twist_euler = quat_to_euler(twist_bone.pose_rotation_quat, order="XYZ")
+             twist_euler[1] = y_euler
+             twist_bone.pose_rotation_quat = euler_to_quat(*twist_euler, order="XYZ")
+
+     else:  # QUATERNION
+         if final_quat[0] < 0:
+             final_quat = -final_quat
+         final_quat = quat_normalize(final_quat)
+
+     dst_bone.pose_rotation_quat = final_quat
+
+     # Recompute FK
+     parent = dst_bone.parent
+     dst_bone._fk(parent.matrix_armature if parent else np.eye(4))
+
+
+ # ---------------------------------------------------------------------------
+ # Position setter
+ # ---------------------------------------------------------------------------
+
+ def set_bone_position(
+     src_arm: Armature, src_name: str,
+     dst_arm: Armature, dst_name: str,
+     dst_twist_name: str,
+     correction: np.ndarray,
+     gain: float,
+ ) -> None:
+     """
+     Port of Blender's SetBonePosition().
+     Moves dst bone to match src bone world-space position, with offset/gain.
+     """
+     src_bone = src_arm.get_bone(src_name)
+     dst_bone = dst_arm.get_bone(dst_name)
+
+     target_ws = get_bone_position_ws(src_bone, src_arm)
+     set_bone_position_ws(dst_bone, dst_arm, target_ws)
+
+     # Apply correction and gain to pose_location
+     dst_bone.pose_location[0] = (dst_bone.pose_location[0] + correction[0]) * gain
+     dst_bone.pose_location[1] = (dst_bone.pose_location[1] + correction[1]) * gain
+     dst_bone.pose_location[2] = (dst_bone.pose_location[2] + correction[2]) * gain
+
+     parent = dst_bone.parent
+     dst_bone._fk(parent.matrix_armature if parent else np.eye(4))
+
+
+ # ---------------------------------------------------------------------------
+ # Pole bone position setter
+ # ---------------------------------------------------------------------------
+
+ def set_bone_position_pole(
+     src_arm: Armature, src_name: str,
+     dst_arm: Armature, dst_name: str,
+     dst_twist_name: str,
+     pole_distance: float,
+ ) -> None:
+     """
+     Port of Blender's SetBonePositionPole().
+     Positions an IK pole target relative to source limb geometry.
+     """
+     src_bone = src_arm.get_bone(src_name)
+     dst_bone = dst_arm.get_bone(dst_name)
+
+     parent_src = src_bone.parent_recursive[0] if src_bone.parent_recursive else src_bone
+
+     base_parent_ws = get_bone_position_ws(parent_src, src_arm)
+     base_child_ws = get_bone_position_ws(src_bone, src_arm)
+
+     # Tail = head + Y-axis direction of bone in world space
+     src_ws_mat = src_arm.world_matrix @ src_bone.matrix_armature
+     tail_ws = src_ws_mat[:3, 3] + src_ws_mat[:3, :3] @ np.array([0.0, 1.0, 0.0])
+
+     length_parent = np.linalg.norm(base_child_ws - base_parent_ws)
+     length_child = np.linalg.norm(tail_ws - base_child_ws)
+     total = length_parent + length_child
+
+     c_p_ratio = length_parent / total if total > 1e-12 else 0.5
+
+     length_pp_to_tail = np.linalg.norm(base_parent_ws - tail_ws)
+     average_location = get_point_on_vector(base_parent_ws, tail_ws, length_pp_to_tail * c_p_ratio)
+
+     distance = np.linalg.norm(base_child_ws - average_location)
+
+     if distance > 0.001:
+         pole_pos = get_point_on_vector(base_child_ws, average_location, pole_distance)
+         set_bone_position_ws(dst_bone, dst_arm, pole_pos)
+         parent = dst_bone.parent
+         dst_bone._fk(parent.matrix_armature if parent else np.eye(4))
+
293
+
294
+ # ---------------------------------------------------------------------------
295
+ # Scale setter
296
+ # ---------------------------------------------------------------------------
297
+
298
+ def set_bone_scale(
299
+ src_arm: Armature, src_name: str,
300
+ dst_arm: Armature, dst_name: str,
301
+ src_scale_bone_name: str,
302
+ gain: float,
303
+ axis: str,
304
+ max_scale: float,
305
+ min_scale: float,
306
+ ) -> None:
307
+ """
308
+ Port of Blender's SetBoneScale().
309
+ Scales dst bone based on dot product between two source bone quaternions.
310
+ """
311
+ src_bone = src_arm.get_bone(src_name)
312
+ dst_bone = dst_arm.get_bone(dst_name)
313
+ secondary = src_arm.get_bone(src_scale_bone_name)
314
+
315
+ q1 = get_bone_ws_quat(src_bone, src_arm)
316
+ q2 = get_bone_ws_quat(secondary, src_arm)
317
+ amount = quat_dot(q1, q2) * gain
318
+
319
+ if amount < 0:
320
+ amount = -amount
321
+ amount = max(min_scale, min(max_scale, amount))
322
+
323
+ s = dst_bone.pose_scale
324
+ if axis == "X":
325
+ s[0] = amount
326
+ elif axis == "Y":
327
+ s[1] = amount
328
+ elif axis == "Z":
329
+ s[2] = amount
330
+ elif axis == "XY":
331
+ s[0] = s[1] = amount
332
+ elif axis == "XZ":
333
+ s[0] = s[2] = amount
334
+ elif axis == "YZ":
335
+ s[1] = s[2] = amount
336
+ else: # XYZ
337
+ s[:] = amount
338
+
339
+ parent = dst_bone.parent
340
+ dst_bone._fk(parent.matrix_armature if parent else np.eye(4))
341
+
342
+
343
+ # ---------------------------------------------------------------------------
344
+ # Correction calculators
345
+ # ---------------------------------------------------------------------------
346
+
347
+ def calc_rotation_offset(
348
+ item: BoneMappingItem,
349
+ src_arm: Armature,
350
+ dst_arm: Armature,
351
+ settings: KeeMapSettings,
352
+ ) -> None:
353
+ """
354
+ Auto-compute the rotation correction factor for one bone mapping.
355
+ Port of Blender's CalcRotationOffset().
356
+ Modifies item.correction_factor and item.quat_correction_factor in-place.
357
+ """
358
+ if not item.source_bone_name or not item.destination_bone_name:
359
+ return
360
+ if not src_arm.has_bone(item.source_bone_name):
361
+ return
362
+ if not dst_arm.has_bone(item.destination_bone_name):
363
+ return
364
+
365
+ dst_bone = dst_arm.get_bone(item.destination_bone_name)
366
+
367
+ # Snapshot destination bone state
368
+ snap_r = dst_bone.pose_rotation_quat.copy()
369
+ snap_t = dst_bone.pose_location.copy()
370
+
371
+ starting_ws_quat = get_bone_ws_quat(dst_bone, dst_arm)
372
+
373
+ # Apply with identity correction
374
+ set_bone_rotation(
375
+ src_arm, item.source_bone_name,
376
+ dst_arm, item.destination_bone_name,
377
+ item.twist_bone_name,
378
+ quat_identity(),
379
+ False,
380
+ item.bone_rotation_application_axis,
381
+ item.bone_transpose_axis,
382
+ settings.bone_rotation_mode,
383
+ )
384
+ dst_arm.update_fk()
385
+
386
+ modified_ws_quat = get_bone_ws_quat(dst_bone, dst_arm)
387
+
388
+ # Correction = rotation that takes modified_ws back to starting_ws
389
+ q_diff = quat_rotation_difference(modified_ws_quat, starting_ws_quat)
390
+ euler = quat_to_euler(q_diff, order="XYZ")
391
+ item.correction_factor = euler.copy()
392
+ item.quat_correction_factor = q_diff.copy()
393
+
394
+ # Restore
395
+ dst_bone.pose_rotation_quat = snap_r
396
+ dst_bone.pose_location = snap_t
397
+ parent = dst_bone.parent
398
+ dst_bone._fk(parent.matrix_armature if parent else np.eye(4))
399
+
400
+
401
+ def calc_location_offset(
402
+ item: BoneMappingItem,
403
+ src_arm: Armature,
404
+ dst_arm: Armature,
405
+ ) -> None:
406
+ """
407
+ Auto-compute position correction for one bone mapping.
408
+ Port of Blender's CalcLocationOffset().
409
+ """
410
+ if not item.source_bone_name or not item.destination_bone_name:
411
+ return
412
+ if not src_arm.has_bone(item.source_bone_name):
413
+ return
414
+ if not dst_arm.has_bone(item.destination_bone_name):
415
+ return
416
+
417
+ src_bone = src_arm.get_bone(item.source_bone_name)
418
+ dst_bone = dst_arm.get_bone(item.destination_bone_name)
419
+
420
+ source_ws_pos = get_bone_position_ws(src_bone, src_arm)
421
+ dest_ws_pos = get_bone_position_ws(dst_bone, dst_arm)
422
+
423
+ # Snapshot
424
+ snap_loc = dst_bone.pose_location.copy()
425
+
426
+ # Move dest to source position
427
+ set_bone_position_ws(dst_bone, dst_arm, source_ws_pos)
428
+ dst_arm.update_fk()
429
+ moved_pose_loc = dst_bone.pose_location.copy()
430
+
431
+ # Restore
432
+ set_bone_position_ws(dst_bone, dst_arm, dest_ws_pos)
433
+ dst_arm.update_fk()
434
+
435
+ delta = snap_loc - moved_pose_loc
436
+ item.position_correction_factor = delta.copy()
437
+
438
+
439
+ def calc_all_corrections(
440
+ bone_items: List[BoneMappingItem],
441
+ src_arm: Armature,
442
+ dst_arm: Armature,
443
+ settings: KeeMapSettings,
444
+ ) -> None:
445
+ """Auto-calculate rotation and position corrections for all mapped bones."""
446
+ for item in bone_items:
447
+ calc_rotation_offset(item, src_arm, dst_arm, settings)
448
+ if "pole" not in item.name.lower():
449
+ calc_location_offset(item, src_arm, dst_arm)
450
+
451
+
452
+ # ---------------------------------------------------------------------------
453
+ # Single-frame transfer
454
+ # ---------------------------------------------------------------------------
455
+
456
+ def transfer_frame(
457
+ src_arm: Armature,
458
+ dst_arm: Armature,
459
+ bone_items: List[BoneMappingItem],
460
+ settings: KeeMapSettings,
461
+ ) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
462
+ """
463
+ Apply retargeting for all bone mappings at the current source frame.
464
+ src_arm must already have FK updated for the current frame.
465
+
466
+ Returns a dict of bone_name → (pose_location, pose_rotation_quat, pose_scale)
467
+ suitable for writing into a keyframe list.
468
+ """
469
+ for item in bone_items:
470
+ if not item.source_bone_name or not item.destination_bone_name:
471
+ continue
472
+ if not src_arm.has_bone(item.source_bone_name):
473
+ continue
474
+ if not dst_arm.has_bone(item.destination_bone_name):
475
+ continue
476
+
477
+ # Build correction quaternion
478
+ if settings.bone_rotation_mode == "EULER":
479
+ cf = item.correction_factor
480
+ correction_quat = euler_to_quat(cf[0], cf[1], cf[2], order="XYZ")
481
+ else:
482
+ correction_quat = quat_normalize(item.quat_correction_factor)
483
+
484
+ # Rotation
485
+ if item.set_bone_rotation:
486
+ set_bone_rotation(
487
+ src_arm, item.source_bone_name,
488
+ dst_arm, item.destination_bone_name,
489
+ item.twist_bone_name,
490
+ correction_quat,
491
+ item.has_twist_bone,
492
+ item.bone_rotation_application_axis,
493
+ item.bone_transpose_axis,
494
+ settings.bone_rotation_mode,
495
+ )
496
+ dst_arm.update_fk()
497
+
498
+ # Position
499
+ if item.set_bone_position:
500
+ if item.postion_type == "SINGLE_BONE_OFFSET":
501
+ set_bone_position(
502
+ src_arm, item.source_bone_name,
503
+ dst_arm, item.destination_bone_name,
504
+ item.twist_bone_name,
505
+ item.position_correction_factor,
506
+ item.position_gain,
507
+ )
508
+ else:
509
+ set_bone_position_pole(
510
+ src_arm, item.source_bone_name,
511
+ dst_arm, item.destination_bone_name,
512
+ item.twist_bone_name,
513
+ -item.position_pole_distance,
514
+ )
515
+ dst_arm.update_fk()
516
+
517
+ # Scale
518
+ if item.set_bone_scale and item.scale_secondary_bone_name:
519
+ if src_arm.has_bone(item.scale_secondary_bone_name):
520
+ set_bone_scale(
521
+ src_arm, item.source_bone_name,
522
+ dst_arm, item.destination_bone_name,
523
+ item.scale_secondary_bone_name,
524
+ item.scale_gain,
525
+ item.bone_scale_application_axis,
526
+ item.scale_max,
527
+ item.scale_min,
528
+ )
529
+ dst_arm.update_fk()
530
+
531
+ # Snapshot destination bone state for this frame
532
+ result: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
533
+ for item in bone_items:
534
+ if not item.destination_bone_name:
535
+ continue
536
+ if not dst_arm.has_bone(item.destination_bone_name):
537
+ continue
538
+ dst_bone = dst_arm.get_bone(item.destination_bone_name)
539
+ result[item.destination_bone_name] = (
540
+ dst_bone.pose_location.copy(),
541
+ dst_bone.pose_rotation_quat.copy(),
542
+ dst_bone.pose_scale.copy(),
543
+ )
544
+ return result
545
+
546
+
547
+ # ---------------------------------------------------------------------------
548
+ # Full animation transfer
549
+ # ---------------------------------------------------------------------------
550
+
551
+ def transfer_animation(
552
+ src_anim, # BVHAnimation or any object with .armature + .apply_frame(i) + .num_frames
553
+ dst_arm: Armature,
554
+ bone_items: List[BoneMappingItem],
555
+ settings: KeeMapSettings,
556
+ ) -> List[Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]]:
557
+ """
558
+ Transfer all frames from src_anim to dst_arm.
559
+ Returns list of keyframe dicts (one per frame sampled).
560
+
561
+ Equivalent to Blender's PerformAnimationTransfer operator.
562
+ """
563
+ keyframes: List[Dict] = []
564
+ step = max(1, settings.keyframe_every_n_frames)
565
+ start = settings.start_frame_to_apply
566
+ total = settings.number_of_frames_to_apply
567
+ end = start + total
568
+
569
+ src_arm = src_anim.armature
570
+
571
+ i = start
572
+ n_steps = len(range(start, end, step))
573
+ step_i = 0
574
+ while i < end and i < src_anim.num_frames:
575
+ src_anim.apply_frame(i) # updates src_arm FK
576
+ dst_arm.update_fk()
577
+
578
+ frame_data = transfer_frame(src_arm, dst_arm, bone_items, settings)
579
+ keyframes.append(frame_data)
580
+
581
+ step_i += 1
582
+ _update_progress("Retargeting", step_i / n_steps)
583
+ i += step
584
+
585
+ _update_progress("Retargeting", 1.0)
586
+ return keyframes
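The auto-correction pass above hinges on `quat_rotation_difference`, which lives in the library's `math3d` module (not shown in this diff). A standalone sketch of the same operation using scipy — the function name, argument order, and WXYZ layout here are illustrative assumptions, not the module's actual API:

```python
import numpy as np
from scipy.spatial.transform import Rotation

def rotation_difference(q_from_wxyz, q_to_wxyz):
    # q_diff such that q_diff * q_from == q_to, mirroring mathutils'
    # Quaternion.rotation_difference() used by the original KeeMap code.
    r_from = Rotation.from_quat(np.roll(q_from_wxyz, -1))  # WXYZ -> XYZW
    r_to = Rotation.from_quat(np.roll(q_to_wxyz, -1))
    diff = r_to * r_from.inv()
    x, y, z, w = diff.as_quat()
    return np.array([w, x, y, z])

q_a = np.array([1.0, 0.0, 0.0, 0.0])                              # identity
q_b = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])  # 90° about Z
d = rotation_difference(q_a, q_b)
# Composing d on top of q_a reproduces q_b's rotation.
```

In `calc_rotation_offset` the same difference is taken between the destination bone's world-space orientation before and after applying the source rotation with an identity correction.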
Retarget/search.py ADDED
@@ -0,0 +1,159 @@
+ """
+ search.py
+ Stream TeoGchx/HumanML3D from HuggingFace and match motions by keyword.
+
+ Dataset: https://huggingface.co/datasets/TeoGchx/HumanML3D
+ Format:  motion column is [T, 263] inline in parquet (standard HumanML3D)
+ Splits:  train (23 384), val (1 460), test (4 384)
+
+ Usage
+ -----
+     from Retarget.search import search_motions
+
+     results = search_motions("a person walks forward", top_k=5)
+     for r in results:
+         print(r["caption"], r["frames"], "frames")
+         # r["motion"] → np.ndarray [T, 263]
+ """
+ from __future__ import annotations
+ import re
+ from typing import List, Optional
+ import numpy as np
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Caption cleaning
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ _SEP = re.compile(r'#|\|')
+ _POS_TAG = re.compile(r'^(?:[A-Z]{1,4}\s*)+$')  # lines that look like POS tags
+
+
+ def _clean_caption(raw: str) -> str:
+    """
+    HumanML3D captions are stored as multiple sentences joined by '#',
+    sometimes followed by POS tag strings. Return the first human-readable
+    sentence.
+    """
+    parts = _SEP.split(raw)
+    for part in parts:
+        part = part.strip()
+        if not part:
+            continue
+        words = part.split()
+        # Skip if >50 % of tokens look like POS tags (all-caps, ≤4 chars)
+        pos_count = sum(1 for w in words if w.isupper() and len(w) <= 4)
+        if len(words) > 0 and pos_count / len(words) < 0.5:
+            return part
+    return parts[0].strip() if parts else raw.strip()
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Search
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ def search_motions(
+    query: str,
+    top_k: int = 8,
+    split: str = "test",
+    max_scan: int = 4384,
+    cached: bool = False,
+ ) -> List[dict]:
+    """
+    Stream TeoGchx/HumanML3D and return up to top_k motions matching query.
+
+    Parameters
+    ----------
+    query     Natural-language description, e.g. "a person walks forward"
+    top_k     Maximum number of results to return
+    split     Dataset split — "test" (4 384 rows) is fastest to stream
+    max_scan  Hard cap on rows examined before returning
+
+    Returns
+    -------
+    List of dicts, sorted by relevance score (descending):
+        caption   str         clean human-readable description
+        motion    np.ndarray  shape [T, 263], standard HumanML3D features
+        frames    int         number of frames (T)
+        duration  float       duration in seconds (at 20 fps)
+        name      str         original clip ID from dataset
+        score     int         keyword match score
+    """
+    try:
+        from datasets import load_dataset
+    except ImportError:
+        raise ImportError(
+            "pip install datasets (HuggingFace datasets library required)"
+        )
+
+    if cached:
+        # Downloads the split once (~400MB) and caches to ~/.cache/huggingface.
+        # Subsequent calls are instant. Use for local dev / testing.
+        ds = load_dataset("TeoGchx/HumanML3D", split=split)
+    else:
+        # Streaming: no disk cache, re-downloads each run. Good for server use.
+        ds = load_dataset("TeoGchx/HumanML3D", split=split, streaming=True)
+
+    # Tokenise query; remove punctuation
+    query_words = re.sub(r"[^\w\s]", "", query.lower()).split()
+    if not query_words:
+        return []
+
+    results: List[dict] = []
+    scanned = 0
+
+    for row in ds:
+        if scanned >= max_scan:
+            break
+        scanned += 1
+
+        caption_raw = row.get("caption", "") or ""
+        caption_clean = _clean_caption(caption_raw)
+        caption_lower = caption_clean.lower()
+
+        # Score: word-boundary matches count 2, substring matches count 1
+        score = 0
+        for kw in query_words:
+            if kw in caption_lower:
+                if re.search(r"\b" + re.escape(kw) + r"\b", caption_lower):
+                    score += 2
+                else:
+                    score += 1
+
+        if score == 0:
+            continue
+
+        motion_raw = row.get("motion")
+        if motion_raw is None:
+            continue
+
+        motion = np.array(motion_raw, dtype=np.float32)  # [T, 263]
+        meta = row.get("meta_data") or {}
+
+        T = motion.shape[0]
+        frames = int(meta.get("num_frames", T))
+        duration = float(meta.get("duration", T / 20.0))
+
+        results.append({
+            "caption": caption_clean,
+            "motion": motion,
+            "frames": frames,
+            "duration": duration,
+            "name": str(meta.get("name", "")),
+            "score": score,
+        })
+
+        # Stop as soon as we have top_k results
+        if len(results) >= top_k:
+            break
+
+    results.sort(key=lambda x: -x["score"])
+    return results[:top_k]
+
+
+ def format_choice_label(result: dict) -> str:
+    """Short label for Gradio Radio component."""
+    caption = result["caption"]
+    if len(caption) > 72:
+        caption = caption[:72] + "…"
+    return f"{caption} ({result['frames']} frames, {result['duration']:.1f}s)"
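The scoring rule inside `search_motions` (word-boundary hit counts 2, bare substring hit counts 1) can be isolated into a self-contained sketch for quick offline testing — `keyword_score` is a hypothetical helper, not part of the module:

```python
import re

def keyword_score(query: str, caption: str) -> int:
    # Same rule as search_motions: a keyword that matches on a word
    # boundary scores 2; a keyword found only as a substring scores 1.
    caption = caption.lower()
    score = 0
    for kw in re.sub(r"[^\w\s]", "", query.lower()).split():
        if kw in caption:
            score += 2 if re.search(r"\b" + re.escape(kw) + r"\b", caption) else 1
    return score

# "walk" is only a substring of "walks" (1); "forward" matches whole (2).
print(keyword_score("walk forward", "a person walks forward slowly"))  # → 3
```

Because streaming stops as soon as `top_k` rows have a non-zero score, early rows with weak substring matches can crowd out later, stronger matches; raising `max_scan` trades latency for recall.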
Retarget/skeleton.py ADDED
@@ -0,0 +1,165 @@
+ """
+ skeleton.py
+ Pure-Python armature / pose-bone system.
+
+ Design matches Blender's pose-mode semantics:
+   - bone.rest_matrix_local  = 4×4 rest pose in parent space (edit-mode)
+   - bone.pose_rotation_quat = local rotation DELTA from rest (≡ bone.rotation_quaternion)
+   - bone.pose_location      = local translation DELTA from rest (≡ bone.location)
+   - bone.pose_scale         = local scale (≡ bone.scale)
+   - bone.matrix_armature    = FK-computed 4×4 in armature space (≡ bone.matrix in pose mode)
+
+ Armature.world_matrix corresponds to arm.matrix_world.
+ """
+ from __future__ import annotations
+ import numpy as np
+ from typing import Dict, List, Optional, Tuple
+ from .math3d import (
+    quat_identity, quat_normalize, quat_mul,
+    quat_to_matrix4, matrix4_to_quat,
+    translation_matrix, scale_matrix, trs_to_matrix4, matrix4_to_trs,
+    vec3,
+ )
+
+
+ class PoseBone:
+    def __init__(
+        self,
+        name: str,
+        rest_matrix_local: np.ndarray,  # 4×4, in parent local space
+        parent: Optional["PoseBone"] = None,
+    ):
+        self.name = name
+        self.parent: Optional[PoseBone] = parent
+        self.children: List[PoseBone] = []
+        self.rest_matrix_local: np.ndarray = rest_matrix_local.copy()
+
+        # Pose state — start at rest (delta = identity)
+        self.pose_rotation_quat: np.ndarray = quat_identity()
+        self.pose_location: np.ndarray = vec3()
+        self.pose_scale: np.ndarray = np.ones(3)
+
+        # Cached FK result — call armature.update_fk() to refresh
+        self._matrix_armature: np.ndarray = np.eye(4)
+
+    # -----------------------------------------------------------------------
+    # Properties
+    # -----------------------------------------------------------------------
+
+    @property
+    def matrix_armature(self) -> np.ndarray:
+        """4×4 FK result in armature space. Refresh with armature.update_fk()."""
+        return self._matrix_armature
+
+    @property
+    def head(self) -> np.ndarray:
+        """Bone head position in armature space."""
+        return self._matrix_armature[:3, 3].copy()
+
+    @property
+    def tail(self) -> np.ndarray:
+        """
+        Approximate tail position (Y-axis in bone space, length 1).
+        Works for Y-along-bone convention (Blender / BVH default).
+        """
+        y_axis = self._matrix_armature[:3, :3] @ np.array([0.0, 1.0, 0.0])
+        return self._matrix_armature[:3, 3] + y_axis
+
+    # -----------------------------------------------------------------------
+    # FK
+    # -----------------------------------------------------------------------
+
+    def _compute_local_matrix(self) -> np.ndarray:
+        """rest_local @ T(pose_loc) @ R(pose_rot) @ S(pose_scale)."""
+        T = translation_matrix(self.pose_location)
+        R = quat_to_matrix4(self.pose_rotation_quat)
+        S = scale_matrix(self.pose_scale)
+        return self.rest_matrix_local @ T @ R @ S
+
+    def _fk(self, parent_matrix: np.ndarray) -> None:
+        self._matrix_armature = parent_matrix @ self._compute_local_matrix()
+        for child in self.children:
+            child._fk(self._matrix_armature)
+
+    # -----------------------------------------------------------------------
+    # Parent-chain helpers (Blender: bone.parent_recursive)
+    # -----------------------------------------------------------------------
+
+    @property
+    def parent_recursive(self) -> List["PoseBone"]:
+        chain: List[PoseBone] = []
+        cur = self.parent
+        while cur is not None:
+            chain.append(cur)
+            cur = cur.parent
+        return chain
+
+
+ class Armature:
+    """
+    Collection of PoseBones with a world transform.
+    Corresponds to a Blender armature object.
+    """
+
+    def __init__(self, name: str = "Armature"):
+        self.name = name
+        self.world_matrix: np.ndarray = np.eye(4)  # arm.matrix_world
+        self._bones: Dict[str, PoseBone] = {}
+        self._roots: List[PoseBone] = []
+
+    # -----------------------------------------------------------------------
+    # Construction helpers
+    # -----------------------------------------------------------------------
+
+    def add_bone(self, bone: PoseBone, parent_name: Optional[str] = None) -> PoseBone:
+        self._bones[bone.name] = bone
+        if parent_name and parent_name in self._bones:
+            parent = self._bones[parent_name]
+            bone.parent = parent
+            parent.children.append(bone)
+        elif bone.parent is None:
+            self._roots.append(bone)
+        return bone
+
+    @property
+    def pose_bones(self) -> Dict[str, PoseBone]:
+        return self._bones
+
+    def get_bone(self, name: str) -> PoseBone:
+        if name not in self._bones:
+            raise KeyError(f"Bone '{name}' not found in armature '{self.name}'")
+        return self._bones[name]
+
+    def has_bone(self, name: str) -> bool:
+        return name in self._bones
+
+    # -----------------------------------------------------------------------
+    # FK update
+    # -----------------------------------------------------------------------
+
+    def update_fk(self) -> None:
+        """Recompute all bone armature-space matrices via FK."""
+        for root in self._roots:
+            root._fk(np.eye(4))
+
+    # -----------------------------------------------------------------------
+    # Snapshot / restore (for calc-correction passes)
+    # -----------------------------------------------------------------------
+
+    def snapshot(self) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+        return {
+            name: (
+                bone.pose_rotation_quat.copy(),
+                bone.pose_location.copy(),
+                bone.pose_scale.copy(),
+            )
+            for name, bone in self._bones.items()
+        }
+
+    def restore(self, snap: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]) -> None:
+        for name, (r, t, s) in snap.items():
+            if name in self._bones:
+                self._bones[name].pose_rotation_quat = r.copy()
+                self._bones[name].pose_location = t.copy()
+                self._bones[name].pose_scale = s.copy()
+        self.update_fk()
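The FK convention above (`world = parent_world @ rest_local @ T @ R @ S`) can be sanity-checked with a minimal two-bone sketch in plain numpy — rotation and scale deltas are left at identity here for brevity, so only the translation terms matter:

```python
import numpy as np

def translation(v):
    # 4×4 homogeneous translation matrix, as math3d.translation_matrix would build.
    m = np.eye(4)
    m[:3, 3] = v
    return m

# Two-bone chain: bone B rests 1 unit up (+Y) from bone A's head.
rest_a = translation([0.0, 0.0, 0.0])
rest_b = translation([0.0, 1.0, 0.0])

# Pose delta: bone B translated a further +0.5 along its local Y.
pose_b = translation([0.0, 0.5, 0.0])

# FK exactly as PoseBone._fk composes it (R and S omitted = identity):
world_a = np.eye(4) @ rest_a
world_b = world_a @ (rest_b @ pose_b)

head_b = world_b[:3, 3]  # bone B's head ends up at y = 1.5
```

Because pose values are deltas from rest, a freshly constructed armature with all-identity deltas reproduces the rest pose, matching Blender's pose-mode behaviour.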
Retarget/smpl.py ADDED
@@ -0,0 +1,184 @@
+ """
+ smpl.py
+ ───────────────────────────────────────────────────────────────────────────────
+ Parse HumanML3D [T, 263] feature vectors into structured SMPL motion data.
+
+ HumanML3D 263-dim layout per frame
+     [0]       root angular-velocity (Y-axis, rad/frame)
+     [1]       root height Y (metres)
+     [2:4]     root XZ velocity (local-frame, metres/frame)
+     [4:67]    joint local positions   joints 1-21 relative to root, 21×3 (unused here)
+     [67:193]  6D joint rotations      joints 1-21, 21×6
+     [193:259] joint velocities        joints 0-21, 22×3 (unused here)
+     [259:263] foot contact flags      (unused here)
+
+ Root rotation  = cumulative integral of dim[0] → Y-axis quaternion.
+ Root position  = dim[1] (height) + integrated XZ velocity.
+ Joint 1-21 rot = dims 67:193 as 6D continuous rotation representation
+     [Zhou et al. 2019] → Gram-Schmidt → 3×3 rotation matrix → quaternion.
+     These are LOCAL rotations relative to the SMPL parent joint's rest
+     frame, where the canonical T-pose is the zero (identity) rotation.
+ """
+ from __future__ import annotations
+ import numpy as np
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # 6D rotation helpers
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def rot6d_to_matrix(r6d: np.ndarray) -> np.ndarray:
+    """
+    [..., 6] → [..., 3, 3]
+    Reconstructs a rotation matrix from two columns using Gram-Schmidt.
+    The two columns are [a1 = r6d[..., 0:3], a2 = r6d[..., 3:6]].
+    """
+    a1 = r6d[..., 0:3].astype(np.float64)
+    a2 = r6d[..., 3:6].astype(np.float64)
+    b1 = a1 / (np.linalg.norm(a1, axis=-1, keepdims=True) + 1e-12)
+    b2 = a2 - (b1 * a2).sum(axis=-1, keepdims=True) * b1
+    b2 = b2 / (np.linalg.norm(b2, axis=-1, keepdims=True) + 1e-12)
+    b3 = np.cross(b1, b2)
+    return np.stack([b1, b2, b3], axis=-1)  # columns → [..., 3, 3]
+
+
+ def matrix_to_quat(mat: np.ndarray) -> np.ndarray:
+    """
+    [..., 3, 3] → [..., 4] WXYZ quaternion, positive-W convention.
+    Uses scipy for numerical stability.
+    """
+    from scipy.spatial.transform import Rotation
+    shape = mat.shape[:-2]
+    flat = mat.reshape(-1, 3, 3).astype(np.float64)
+    xyzw = Rotation.from_matrix(flat).as_quat()  # scipy → XYZW
+    wxyz = xyzw[:, [3, 0, 1, 2]].astype(np.float32)
+    wxyz[wxyz[:, 0] < 0] *= -1  # positive-W
+    return wxyz.reshape(*shape, 4)
+
+
+ def rot6d_to_quat(r6d: np.ndarray) -> np.ndarray:
+    """[..., 6] → [..., 4] WXYZ. Convenience: 6D → matrix → quaternion."""
+    return matrix_to_quat(rot6d_to_matrix(r6d))
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Root motion recovery
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def _qrot_vec(q: np.ndarray, v: np.ndarray) -> np.ndarray:
+    """Rotate [N, 3] vectors by [N, 4] WXYZ quaternions (batch)."""
+    w, x, y, z = q[:, 0:1], q[:, 1:2], q[:, 2:3], q[:, 3:4]
+    vx, vy, vz = v[:, 0:1], v[:, 1:2], v[:, 2:3]
+    # Rodrigues-style: v + 2w*(q.xyz × v) + 2*(q.xyz × (q.xyz × v))
+    tx = 2 * (y * vz - z * vy)
+    ty = 2 * (z * vx - x * vz)
+    tz = 2 * (x * vy - y * vx)
+    return np.concatenate([
+        vx + w * tx + y * tz - z * ty,
+        vy + w * ty + z * tx - x * tz,
+        vz + w * tz + x * ty - y * tx,
+    ], axis=-1)
+
+
+ def recover_root_motion(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Recover root world-space position and rotation from [T, 263] features.
+
+    Returns
+    -------
+    root_pos : [T, 3] world-space root position (Y = height above ground)
+    root_rot : [T, 4] WXYZ quaternion — Y-axis only (global facing direction)
+    """
+    T = data.shape[0]
+
+    # Facing direction: integrate Y-axis angular velocity
+    theta = np.cumsum(data[:, 0].astype(np.float32))
+    half = theta * 0.5
+    root_rot = np.zeros((T, 4), dtype=np.float32)
+    root_rot[:, 0] = np.cos(half)
+    root_rot[:, 2] = np.sin(half)
+
+    # XZ velocity encoded in root-local frame → world frame
+    vel_local = np.stack([
+        data[:, 2].astype(np.float32),
+        np.zeros(T, dtype=np.float32),
+        data[:, 3].astype(np.float32),
+    ], axis=-1)
+    vel_world = _qrot_vec(root_rot, vel_local)
+
+    root_pos = np.zeros((T, 3), dtype=np.float32)
+    root_pos[:, 0] = np.cumsum(vel_world[:, 0])
+    root_pos[:, 1] = data[:, 1]
+    root_pos[:, 2] = np.cumsum(vel_world[:, 2])
+
+    return root_pos, root_rot
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # SMPLMotion container
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ class SMPLMotion:
+    """
+    Structured SMPL motion data parsed from a single HumanML3D clip.
+
+    Attributes
+    ----------
+    root_pos  : [T, 3]     world-space root position (metres)
+    root_rot  : [T, 4]     WXYZ root Y-axis rotation (global facing)
+    local_rot : [T, 21, 4] WXYZ local quaternions for joints 1-21
+                           T-pose = identity; relative to SMPL parent frame
+    fps       : float      capture frame rate (20 for HumanML3D)
+    """
+
+    def __init__(
+        self,
+        root_pos: np.ndarray,
+        root_rot: np.ndarray,
+        local_rot: np.ndarray,
+        fps: float = 20.0,
+    ):
+        self.root_pos = np.asarray(root_pos, dtype=np.float32)
+        self.root_rot = np.asarray(root_rot, dtype=np.float32)
+        self.local_rot = np.asarray(local_rot, dtype=np.float32)
+        self.fps = float(fps)
+
+    @property
+    def num_frames(self) -> int:
+        return self.root_pos.shape[0]
+
+    def slice(self, start: int = 0, end: int = -1) -> "SMPLMotion":
+        e = end if end > 0 else self.num_frames
+        return SMPLMotion(
+            self.root_pos[start:e],
+            self.root_rot[start:e],
+            self.local_rot[start:e],
+            self.fps,
+        )
+
+
+ def hml3d_to_smpl_motion(data: np.ndarray, fps: float = 20.0) -> SMPLMotion:
+    """
+    Convert HumanML3D [T, 263] feature array to a SMPLMotion.
+
+    Uses the actual 6D rotation data (dims 67:193) — NOT position-derived
+    rotations. This preserves twist and gives physically correct limb poses.
+
+    Parameters
+    ----------
+    data : [T, 263] raw HumanML3D features (e.g. from MoMask or dataset row)
+    fps  : float    frame rate (default 20 = HumanML3D native)
+    """
+    data = np.asarray(data, dtype=np.float32)
+    if data.ndim != 2 or data.shape[1] < 193:
+        raise ValueError(f"Expected [T, >=193] but got {data.shape}")
+
+    T = data.shape[0]
+
+    root_pos, root_rot = recover_root_motion(data)
+
+    # 6D rotations for joints 1-21: dims [67:193] → [T, 21, 6]
+    r6d = data[:, 67:193].reshape(T, 21, 6)
+    local_rot = rot6d_to_quat(r6d)  # [T, 21, 4] WXYZ
+
+    return SMPLMotion(root_pos, root_rot, local_rot, fps)
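A quick single-rotation sanity check of the Gram-Schmidt reconstruction used above — this standalone sketch mirrors the batched `rot6d_to_matrix` for one 6-vector, and the identity rotation (first two basis vectors) should round-trip to the identity matrix:

```python
import numpy as np

def rot6d_to_matrix_single(r6d):
    # Gram-Schmidt the two 3-vectors into an orthonormal frame
    # (the 6D continuous representation of Zhou et al. 2019).
    a1, a2 = r6d[0:3].astype(float), r6d[3:6].astype(float)
    b1 = a1 / np.linalg.norm(a1)
    b2 = a2 - (b1 @ a2) * b1       # remove a2's component along b1
    b2 = b2 / np.linalg.norm(b2)
    b3 = np.cross(b1, b2)          # right-handed third column
    return np.stack([b1, b2, b3], axis=-1)

# Identity rotation encodes as the first two basis vectors: (1,0,0, 0,1,0).
R = rot6d_to_matrix_single(np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
# R equals the 3×3 identity matrix.
```

The representation is continuous (no quaternion sign ambiguity), which is why HumanML3D and most motion models predict it instead of raw quaternions or Euler angles.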
app.py CHANGED
@@ -6,6 +6,7 @@ import shutil
 import traceback
 import json
 import random

 from pathlib import Path

 # ── ZeroGPU: install packages that can't be built at Docker build time ─────────
@@ -130,8 +131,13 @@ _triposg_pipe = None
 _rmbg_net = None
 _rmbg_version = None
 _last_glb_path = None

 _init_seed = random.randint(0, 2**31 - 1)

 ARCFACE_256 = (np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
                          [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
                * (256 / 112) + (256 - 112 * (256 / 112)) / 2)
@@ -167,6 +173,104 @@ def _ensure_ckpts():

 # ── Model loaders ─────────────────────────────────────────────────────────────

 def load_triposg():
     global _triposg_pipe, _rmbg_net, _rmbg_version
     if _triposg_pipe is not None:
@@ -309,6 +413,197 @@ def load_triposg():
     return _triposg_pipe, _rmbg_net


 # ── Background removal helper ─────────────────────────────────────────────────

 def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):
@@ -343,8 +638,8 @@ def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):

     rgb = np.array(img_pil.convert("RGB"), dtype=np.float32) / 255.0
     alpha = mask[:, :, np.newaxis]
-    comp = (rgb * alpha + 0.5 * (1.0 - alpha) * 255).clip(0, 255).astype(np.uint8)
-    return Image.fromarray(comp)


 def preview_rembg(input_image, do_remove_bg, threshold, erode_px):
@@ -357,6 +652,188 @@ def preview_rembg(input_image, do_remove_bg, threshold, erode_px):
     return input_image


 # ── Stage 1: Shape generation ─────────────────────────────────────────────────

 @spaces.GPU(duration=180)
@@ -365,6 +842,28 @@ def generate_shape(input_image, remove_background, num_steps, guidance_scale,
     if input_image is None:
         return None, "Please upload an image."
     try:

         progress(0.1, desc="Loading TripoSG...")
         pipe, rmbg_net = load_triposg()

@@ -373,16 +872,17 @@ def generate_shape(input_image, remove_background, num_steps, guidance_scale,
         img.save(img_path)

         progress(0.5, desc="Generating shape (SDF diffusion)...")
-        from scripts.inference_triposg import run_triposg
-        mesh = run_triposg(
-            pipe=pipe,
-            image_input=img_path,
-            rmbg_net=rmbg_net if remove_background else None,
-            seed=int(seed),
-            num_inference_steps=int(num_steps),
-            guidance_scale=float(guidance_scale),
-            faces=int(face_count) if int(face_count) > 0 else -1,
-        )

         out_path = "/tmp/triposg_shape.glb"
         mesh.export(out_path)
@@ -609,7 +1109,7 @@ def gradio_rig(glb_state_path, export_fbx_flag, mdm_prompt, mdm_n_frames,
         animated = mdm_result.get("animated_glb")

         parts = ["Rigged: " + os.path.basename(rigged)]
-        if fbx: parts.append("FBX: " + os.path.basename(fbx))
         if animated: parts.append("Animation: " + os.path.basename(animated))

         torch.cuda.empty_cache()
@@ -633,6 +1133,7 @@ def gradio_enhance(glb_path, ref_img_np, do_normal, norm_res, norm_strength,
         from pipeline.enhance_surface import (
             run_stable_normal, run_depth_anything,
             bake_normal_into_glb, bake_depth_as_occlusion,

         )
         import pipeline.enhance_surface as _enh_mod

@@ -704,18 +1205,271 @@ def render_views(glb_file):
         return []


 # ── Full pipeline ─────────────────────────────────────────────────────────────

-def run_full_pipeline(input_image, num_steps, guidance, seed, face_count,
-                      variant, tex_seed, enhance_face,
                       export_fbx, mdm_prompt, mdm_n_frames, progress=gr.Progress()):
     progress(0.0, desc="Stage 1/3: Generating shape...")
-    glb, status = generate_shape(input_image, True, num_steps, guidance, seed, face_count)
     if not glb:
         return None, None, None, None, None, None, status

     progress(0.33, desc="Stage 2/3: Applying texture...")
-    glb, mv_img, status = apply_texture(glb, input_image, True, variant, tex_seed, enhance_face)

     if not glb:
         return None, None, None, None, None, None, status
@@ -727,17 +1481,61 @@ def run_full_pipeline(input_image, num_steps, guidance, seed, face_count,


 # ── UI ────────────────────────────────────────────────────────────────────────
- with gr.Blocks(title="Image2Model") as demo:
 gr.Markdown("# Image2Model — Portrait to Rigged 3D Mesh")
- glb_state = gr.State(None)

- with gr.Tabs():

 # ════════════════════════════════════════════════════════════════════
- with gr.Tab("Generate"):
 with gr.Row():
 with gr.Column(scale=1):
- input_image = gr.Image(label="Input Image", type="numpy")

 with gr.Accordion("Shape Settings", open=True):
 num_steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
@@ -756,9 +1554,11 @@ with gr.Blocks(title="Image2Model") as demo:
 shape_btn = gr.Button("Generate Shape", variant="primary", scale=2, interactive=False)
 texture_btn = gr.Button("Apply Texture", variant="secondary", scale=2)
 render_btn = gr.Button("Render Views", variant="secondary", scale=1)
- run_all_btn = gr.Button("▶ Run Full Pipeline", variant="primary", interactive=False)

 with gr.Column(scale=1):

 status = gr.Textbox(label="Status", lines=3, interactive=False)
 model_3d = gr.Model3D(label="3D Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
 download_file = gr.File(label="Download GLB")
@@ -766,6 +1566,8 @@ with gr.Blocks(title="Image2Model") as demo:
766
 
767
  render_gallery = gr.Gallery(label="Rendered Views", columns=5, height=300)
768
 
 
 
769
  _pipeline_btns = [shape_btn, run_all_btn]
770
 
771
  input_image.upload(
@@ -777,9 +1579,14 @@ with gr.Blocks(title="Image2Model") as demo:
777
  inputs=[], outputs=_pipeline_btns,
778
  )
779
 
 
 
 
 
 
780
  shape_btn.click(
781
- fn=lambda img, ns, gs, sd, fc: generate_shape(img, True, ns, gs, sd, fc),
782
- inputs=[input_image, num_steps, guidance, seed, face_count],
783
  outputs=[glb_state, status],
784
  ).then(
785
  fn=lambda p: (p, p) if p else (None, None),
@@ -787,8 +1594,9 @@ with gr.Blocks(title="Image2Model") as demo:
787
  )
788
 
789
  texture_btn.click(
790
- fn=lambda glb, img, v, ts, ef: apply_texture(glb, img, True, v, ts, ef),
791
- inputs=[glb_state, input_image, variant, tex_seed, enhance_face_check],
 
792
  outputs=[glb_state, multiview_img, status],
793
  ).then(
794
  fn=lambda p: (p, p) if p else (None, None),
@@ -797,6 +1605,29 @@ with gr.Blocks(title="Image2Model") as demo:

             render_btn.click(fn=render_views, inputs=[download_file], outputs=[render_gallery])

         # ════════════════════════════════════════════════════════════════════
         with gr.Tab("Rig & Export"):
             with gr.Row():
@@ -844,6 +1675,9 @@ with gr.Blocks(title="Image2Model") as demo:
                 inputs=[glb_state, export_fbx_check, mdm_prompt_box, mdm_frames_slider],
                 outputs=[rig_glb_dl, rig_animated_dl, rig_fbx_dl, rig_status,
                          rig_model_3d, rigged_base_state, skel_glb_state],
             )

             show_skel_check.change(
@@ -852,6 +1686,103 @@ with gr.Blocks(title="Image2Model") as demo:
                 outputs=[rig_model_3d],
             )

         # ════════════════════════════════════════════════════════════════════
         with gr.Tab("Enhancement"):
             gr.Markdown("**Surface Enhancement** — bakes normal + depth maps into the GLB as PBR textures.")
@@ -868,6 +1799,7 @@ with gr.Blocks(title="Image2Model") as demo:
                     displacement_scale = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Displacement Scale")

                     enhance_btn = gr.Button("Run Enhancement", variant="primary")

                 with gr.Column(scale=2):
                     enhance_status = gr.Textbox(label="Status", lines=5, interactive=False)
@@ -886,12 +1818,111 @@ with gr.Blocks(title="Image2Model") as demo:
                          enhanced_glb_dl, enhanced_model_3d, enhance_status],
             )

-        # ── Run All wiring ────────────────────────────────────────────────
         run_all_btn.click(
             fn=run_full_pipeline,
             inputs=[
-                input_image, num_steps, guidance, seed, face_count,
-                variant, tex_seed, enhance_face_check,
                 export_fbx_check, mdm_prompt_box, mdm_frames_slider,
             ],
             outputs=[glb_state, download_file, multiview_img,
@@ -901,6 +1932,23 @@ with gr.Blocks(title="Image2Model") as demo:
             inputs=[glb_state], outputs=[model_3d, download_file],
         )

 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft())

 import traceback
 import json
 import random
+import threading
 from pathlib import Path

 # ── ZeroGPU: install packages that can't be built at Docker build time ─────────

 _rmbg_net = None
 _rmbg_version = None
 _last_glb_path = None
+_hyperswap_sess = None
+_gfpgan_restorer = None
+_firered_pipe = None
 _init_seed = random.randint(0, 2**31 - 1)

+_model_load_lock = threading.Lock()
+
 ARCFACE_256 = (np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
                          [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
                * (256 / 112) + (256 - 112 * (256 / 112)) / 2)
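Worth noting about the template above: the centering term `(256 - 112 * (256 / 112)) / 2` evaluates to exactly zero (112 · 256/112 = 256), so `ARCFACE_256` is a pure 256/112 scale of the canonical ArcFace 112×112 five-point landmarks. A quick standalone check:

```python
import numpy as np

# Canonical ArcFace 112x112 five-point landmarks (eyes, nose tip, mouth corners).
ARCFACE_112 = np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
                        [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)

scale = 256 / 112
offset = (256 - 112 * scale) / 2        # 112 * (256/112) == 256, so this is 0.0
pts = ARCFACE_112 * scale + offset

print(offset)                            # the centering term contributes nothing
print(pts.min(), pts.max())              # all points stay inside the 256x256 crop
```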
 

 # ── Model loaders ─────────────────────────────────────────────────────────────

+def _load_rmbg():
+    """Load RMBG-2.0 from the 1038lab mirror."""
+    global _rmbg_net, _rmbg_version
+    if _rmbg_net is not None:
+        return
+    try:
+        from transformers import AutoModelForImageSegmentation
+        from torch.overrides import TorchFunctionMode
+
+        class _NoMetaMode(TorchFunctionMode):
+            """Intercept device='meta' tensor construction and redirect to CPU.
+
+            init_empty_weights() inside from_pretrained pushes a meta DeviceContext
+            ON TOP of any torch.device("cpu") wrapper, so meta wins. This mode is
+            pushed BELOW it; when the meta DeviceContext adds device='meta' and chains
+            down the stack, we see it here and flip it back to 'cpu'.
+            """
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                dev = kwargs.get("device")
+                if dev is not None:
+                    dev_str = dev.type if isinstance(dev, torch.device) else str(dev)
+                    if dev_str == "meta":
+                        kwargs["device"] = "cpu"
+                return func(*args, **kwargs)
+
+        # transformers 5.x _finalize_model_loading calls mark_tied_weights_as_initialized,
+        # which accesses all_tied_weights_keys. BiRefNetConfig inherits from the old
+        # PretrainedConfig alias, which skips the new PreTrainedModel.__init__ section
+        # that sets this attribute. Patch the method to be safe.
+        from transformers import PreTrainedModel as _PTM
+        _orig_mark_tied = _PTM.mark_tied_weights_as_initialized
+        def _safe_mark_tied(self, loading_info):
+            if not hasattr(self, "all_tied_weights_keys"):
+                self.all_tied_weights_keys = {}
+            return _orig_mark_tied(self, loading_info)
+        _PTM.mark_tied_weights_as_initialized = _safe_mark_tied
+        try:
+            with _NoMetaMode():
+                _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
+                    "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False,
+                )
+        finally:
+            _PTM.mark_tied_weights_as_initialized = _orig_mark_tied
+        _rmbg_net.to(DEVICE).eval()
+        _rmbg_version = "2.0"
+        print("RMBG-2.0 loaded.")
+    except Exception as e:
+        _rmbg_net = None
+        _rmbg_version = None
+        print(f"RMBG-2.0 failed: {e} — background removal disabled.")
+
+
+def load_rmbg_only():
+    """Load RMBG standalone without loading TripoSG."""
+    _load_rmbg()
+    return _rmbg_net
+
+
+def load_gfpgan():
+    global _gfpgan_restorer
+    if _gfpgan_restorer is not None:
+        return _gfpgan_restorer
+    try:
+        from gfpgan import GFPGANer
+        from basicsr.archs.rrdbnet_arch import RRDBNet
+        from realesrgan import RealESRGANer
+
+        model_path = str(CKPT_DIR / "GFPGANv1.4.pth")
+        if not os.path.exists(model_path):
+            print(f"[GFPGAN] Not found at {model_path}")
+            return None
+
+        realesrgan_path = str(CKPT_DIR / "RealESRGAN_x2plus.pth")
+        bg_upsampler = None
+        if os.path.exists(realesrgan_path):
+            bg_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
+                               num_block=23, num_grow_ch=32, scale=2)
+            bg_upsampler = RealESRGANer(
+                scale=2, model_path=realesrgan_path, model=bg_model,
+                tile=400, tile_pad=10, pre_pad=0, half=True,
+            )
+            print("[GFPGAN] RealESRGAN x2plus bg_upsampler loaded")
+        else:
+            print("[GFPGAN] RealESRGAN_x2plus.pth not found, running without upsampler")
+
+        _gfpgan_restorer = GFPGANer(
+            model_path=model_path, upscale=2, arch="clean",
+            channel_multiplier=2, bg_upsampler=bg_upsampler,
+        )
+        print("[GFPGAN] Loaded GFPGANv1.4 (upscale=2 + RealESRGAN bg_upsampler)")
+        return _gfpgan_restorer
+    except Exception as e:
+        print(f"[GFPGAN] Load failed: {e}")
+        return None
+
+
 def load_triposg():
     global _triposg_pipe, _rmbg_net, _rmbg_version
     if _triposg_pipe is not None:
 
     return _triposg_pipe, _rmbg_net


+def load_firered():
+    """Lazy-load the FireRed image-edit pipeline using a GGUF-quantized transformer.
+
+    Transformer: loaded from GGUF via from_single_file (Q4_K_M, ~12 GB on disk).
+    Tries Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF first (fine-tuned, merged model);
+    falls back to unsloth/Qwen-Image-Edit-2511-GGUF (base model) if key mapping fails.
+
+    text_encoder: 4-bit NF4 on GPU (~5.6 GB).
+    GGUF transformer: dequantized on the fly, dispatched with an 18 GiB GPU budget.
+    Lightning scheduler: 4 steps, CFG 1.0 → ~1-2 min per inference.
+
+    GPU budget: ~18 GB transformer + ~5.6 GB text_encoder + ~0.3 GB VAE ≈ 24 GB.
+    """
+    global _firered_pipe
+    if _firered_pipe is not None:
+        return _firered_pipe
+
+    import math as _math
+    from diffusers import QwenImageEditPlusPipeline, FlowMatchEulerDiscreteScheduler, GGUFQuantizationConfig
+    from diffusers.models import QwenImageTransformer2DModel
+    from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
+    from accelerate import dispatch_model, infer_auto_device_map
+    from huggingface_hub import hf_hub_download
+
+    # Patch SDPA to cast K/V to match the Q dtype.
+    import torch.nn.functional as _F
+    _orig_sdpa = _F.scaled_dot_product_attention
+    def _dtype_safe_sdpa(query, key, value, *a, **kw):
+        if key.dtype != query.dtype:
+            key = key.to(query.dtype)
+        if value.dtype != query.dtype:
+            value = value.to(query.dtype)
+        return _orig_sdpa(query, key, value, *a, **kw)
+    _F.scaled_dot_product_attention = _dtype_safe_sdpa
+
+    torch.cuda.empty_cache()
+
+    # Load RMBG NOW — before dispatch_model creates meta tensors that poison later loads
+    _load_rmbg()
+
+    gguf_config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)
+
+    # ── Transformer: GGUF Q4_K_M — try the fine-tuned Rapid-AIO first, fall back to base ──
+    transformer = None
+
+    # Attempt 1: Arunk25 Rapid-AIO GGUF (fine-tuned, fully merged, ~12.4 GB)
+    try:
+        print("[FireRed] Downloading Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF Q4_K_M (~12 GB)...")
+        gguf_path = hf_hub_download(
+            repo_id="Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF",
+            filename="v23/Qwen-Rapid-AIO-NSFW-v23-Q4_K_M.gguf",
+        )
+        print("[FireRed] Loading Rapid-AIO transformer from GGUF...")
+        transformer = QwenImageTransformer2DModel.from_single_file(
+            gguf_path,
+            quantization_config=gguf_config,
+            torch_dtype=torch.bfloat16,
+            config="Qwen/Qwen-Image-Edit-2511",
+            subfolder="transformer",
+        )
+        print("[FireRed] Rapid-AIO GGUF transformer loaded OK.")
+    except Exception as e:
+        print(f"[FireRed] Rapid-AIO GGUF failed ({e}), falling back to unsloth base GGUF...")
+        transformer = None
+
+    # Attempt 2: unsloth base GGUF Q4_K_M (~12.3 GB)
+    if transformer is None:
+        print("[FireRed] Downloading unsloth/Qwen-Image-Edit-2511-GGUF Q4_K_M (~12 GB)...")
+        gguf_path = hf_hub_download(
+            repo_id="unsloth/Qwen-Image-Edit-2511-GGUF",
+            filename="qwen-image-edit-2511-Q4_K_M.gguf",
+        )
+        print("[FireRed] Loading base transformer from GGUF...")
+        transformer = QwenImageTransformer2DModel.from_single_file(
+            gguf_path,
+            quantization_config=gguf_config,
+            torch_dtype=torch.bfloat16,
+            config="Qwen/Qwen-Image-Edit-2511",
+            subfolder="transformer",
+        )
+        print("[FireRed] Base GGUF transformer loaded OK.")
+
+    print("[FireRed] Dispatching transformer (18 GiB GPU, rest CPU)...")
+    device_map = infer_auto_device_map(
+        transformer,
+        max_memory={0: "18GiB", "cpu": "90GiB"},
+        dtype=torch.bfloat16,
+    )
+    n_gpu = sum(1 for d in device_map.values() if str(d) in ("0", "cuda", "cuda:0"))
+    n_cpu = sum(1 for d in device_map.values() if str(d) == "cpu")
+    print(f"[FireRed] Device map: {n_gpu} modules on GPU, {n_cpu} on CPU")
+    transformer = dispatch_model(transformer, device_map=device_map)
+    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
+    print(f"[FireRed] Transformer dispatched — VRAM: {used_mb} MB")
+
+    # ── text_encoder: 4-bit NF4 on GPU (~5.6 GB) ──────────────────────────────
+    bnb_enc = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_use_double_quant=True,
+    )
+    print("[FireRed] Loading text_encoder (4-bit NF4)...")
+    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen-Image-Edit-2511",
+        subfolder="text_encoder",
+        quantization_config=bnb_enc,
+        device_map="auto",
+    )
+    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
+    print(f"[FireRed] Text encoder loaded — VRAM: {used_mb} MB")
+
+    # ── Pipeline: VAE + scheduler + processor + tokenizer ─────────────────────
+    print("[FireRed] Loading pipeline...")
+    _firered_pipe = QwenImageEditPlusPipeline.from_pretrained(
+        "Qwen/Qwen-Image-Edit-2511",
+        transformer=transformer,
+        text_encoder=text_encoder,
+        torch_dtype=torch.bfloat16,
+    )
+    _firered_pipe.vae.to(DEVICE)
+
+    # Lightning scheduler — 4 steps, use_dynamic_shifting, matches the reference Space config
+    _firered_pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config({
+        "base_image_seq_len": 256,
+        "base_shift": _math.log(3),
+        "max_image_seq_len": 8192,
+        "max_shift": _math.log(3),
+        "num_train_timesteps": 1000,
+        "shift": 1.0,
+        "time_shift_type": "exponential",
+        "use_dynamic_shifting": True,
+    })
+
+    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
+    print(f"[FireRed] Pipeline ready — total VRAM: {used_mb} MB")
+    return _firered_pipe
+
+
+
+def _gallery_to_pil_list(gallery_value):
+    """Convert a Gradio Gallery value (items in several possible formats) to a list of PIL Images."""
+    pil_images = []
+    if not gallery_value:
+        return pil_images
+    for item in gallery_value:
+        try:
+            if isinstance(item, np.ndarray):
+                pil_images.append(Image.fromarray(item).convert("RGB"))
+                continue
+            if isinstance(item, Image.Image):
+                pil_images.append(item.convert("RGB"))
+                continue
+            # Gradio 6 Gallery returns dicts: {"image": FileData, "caption": ...}
+            if isinstance(item, dict):
+                img_data = item.get("image") or item
+                if isinstance(img_data, dict):
+                    path = img_data.get("path") or img_data.get("url") or img_data.get("name")
+                else:
+                    path = img_data
+            elif isinstance(item, (list, tuple)):
+                path = item[0]
+            else:
+                path = item
+            if path and os.path.exists(str(path)):
+                pil_images.append(Image.open(str(path)).convert("RGB"))
+        except Exception as e:
+            print(f"[FireRed] Could not load gallery image: {e}")
+    return pil_images
+
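The helper above has to cope with the several shapes a Gradio Gallery value can take (ndarray, PIL image, FileData dict, tuple, plain path). The dict/tuple unwrapping can be exercised in isolation — the payloads below are invented for illustration:

```python
def gallery_item_path(item):
    """Resolve a filesystem path from the non-image shapes a Gradio Gallery may return."""
    if isinstance(item, dict):
        # Gradio 6 style: {"image": {"path": ...}, "caption": ...}; older styles are flat
        img_data = item.get("image") or item
        if isinstance(img_data, dict):
            return img_data.get("path") or img_data.get("url") or img_data.get("name")
        return img_data
    if isinstance(item, (list, tuple)):
        return item[0]          # (path, caption) pair
    return item                 # already a plain path

print(gallery_item_path({"image": {"path": "/tmp/a.png"}, "caption": None}))  # /tmp/a.png
print(gallery_item_path(("/tmp/b.png", "cap")))                               # /tmp/b.png
print(gallery_item_path("/tmp/c.png"))                                        # /tmp/c.png
```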
+
+def _firered_resize(img):
+    """Resize to max 1024 px on the long side, preserving aspect ratio; align both dims to a multiple of 8."""
+    w, h = img.size
+    if max(w, h) > 1024:
+        if w > h:
+            nw, nh = 1024, int(1024 * h / w)
+        else:
+            nw, nh = int(1024 * w / h), 1024
+    else:
+        nw, nh = w, h
+    nw, nh = max(8, (nw // 8) * 8), max(8, (nh // 8) * 8)
+    if (nw, nh) != (w, h):
+        img = img.resize((nw, nh), Image.LANCZOS)
+    return img
+
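The sizing rule above (cap the long side at 1024, then snap both dimensions down to a multiple of 8) can be checked without PIL by extracting just the arithmetic:

```python
def firered_dims(w, h, max_side=1024):
    """Compute the (width, height) that _firered_resize would target."""
    if max(w, h) > max_side:
        if w > h:
            nw, nh = max_side, int(max_side * h / w)
        else:
            nw, nh = int(max_side * w / h), max_side
    else:
        nw, nh = w, h
    # Snap down to a multiple of 8 (never below 8)
    return max(8, (nw // 8) * 8), max(8, (nh // 8) * 8)

print(firered_dims(1920, 1080))  # (1024, 576)
print(firered_dims(1000, 700))   # (1000, 696) — only the 8-alignment applies
```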
+
+
+_FIRERED_NEGATIVE = (
+    "worst quality, low quality, bad anatomy, bad hands, text, error, "
+    "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
+    "signature, watermark, username, blurry"
+)
+
+
 # ── Background removal helper ─────────────────────────────────────────────────

 def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):
 

     rgb = np.array(img_pil.convert("RGB"), dtype=np.float32) / 255.0
     alpha = mask[:, :, np.newaxis]
+    comp = (rgb * alpha + 0.5 * (1.0 - alpha)) * 255
+    return Image.fromarray(comp.clip(0, 255).astype(np.uint8))
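The two added lines composite the cut-out over a mid-gray background instead of keeping an alpha channel: per pixel, out = rgb·α + 0.5·(1−α), scaled back to 0-255. A scalar sanity check of the formula:

```python
import numpy as np

rgb = np.array([[[1.0, 1.0, 1.0]]], dtype=np.float32)    # one white foreground pixel
alpha = np.array([[[0.0]]], dtype=np.float32)            # alpha 0 = pure background
comp = ((rgb * alpha + 0.5 * (1.0 - alpha)) * 255).clip(0, 255).astype(np.uint8)
print(comp.ravel())   # background pixels become mid-gray (255 * 0.5, truncated to 127)
```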


 def preview_rembg(input_image, do_remove_bg, threshold, erode_px):

         return input_image


+# ── RealESRGAN helpers ─────────────────────────────────────────────────────────
+
+def _load_realesrgan(scale: int = 4):
+    """Load a RealESRGAN upsampler. Returns a RealESRGANer or None."""
+    try:
+        from basicsr.archs.rrdbnet_arch import RRDBNet
+        from realesrgan import RealESRGANer
+        if scale == 4:
+            model_path = str(CKPT_DIR / "RealESRGAN_x4plus.pth")
+            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        else:
+            model_path = str(CKPT_DIR / "RealESRGAN_x2plus.pth")
+            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+        if not os.path.exists(model_path):
+            print(f"[RealESRGAN] {model_path} not found")
+            return None
+        upsampler = RealESRGANer(
+            scale=scale, model_path=model_path, model=model,
+            tile=512, tile_pad=32, pre_pad=0, half=True,
+        )
+        print(f"[RealESRGAN] Loaded x{scale}plus")
+        return upsampler
+    except Exception as e:
+        print(f"[RealESRGAN] Load failed: {e}")
+        return None
+
+
+def _enhance_glb_texture(glb_path: str) -> bool:
+    """
+    Extract the base-color UV texture atlas from a GLB, upscale it with RealESRGAN x4,
+    downscale back to the original resolution (sharper detail), then repack in place.
+    Returns True if enhancement was applied.
+    """
+    import pygltflib
+
+    upsampler = _load_realesrgan(scale=4)
+    if upsampler is None:
+        upsampler = _load_realesrgan(scale=2)
+    if upsampler is None:
+        print("[enhance_glb] No RealESRGAN checkpoint available")
+        return False
+
+    glb = pygltflib.GLTF2().load(glb_path)
+    blob = bytearray(glb.binary_blob() or b"")
+
+    for mat in glb.materials:
+        bct = getattr(mat.pbrMetallicRoughness, "baseColorTexture", None)
+        if bct is None:
+            continue
+        tex = glb.textures[bct.index]
+        if tex.source is None:
+            continue
+        img_obj = glb.images[tex.source]
+        if img_obj.bufferView is None:
+            continue
+        bv = glb.bufferViews[img_obj.bufferView]
+        offset, length = bv.byteOffset or 0, bv.byteLength
+
+        img_arr = np.frombuffer(blob[offset:offset + length], dtype=np.uint8)
+        atlas_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
+        if atlas_bgr is None:
+            continue
+
+        orig_h, orig_w = atlas_bgr.shape[:2]
+        print(f"[enhance_glb] atlas {orig_w}x{orig_h}, upscaling with RealESRGAN…")
+
+        try:
+            upscaled, _ = upsampler.enhance(atlas_bgr, outscale=4)
+        except Exception as e:
+            print(f"[enhance_glb] RealESRGAN enhance failed: {e}")
+            continue
+
+        restored = cv2.resize(upscaled, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
+
+        ok, new_bytes = cv2.imencode(".png", restored)
+        if not ok:
+            continue
+        new_bytes = new_bytes.tobytes()
+        new_len = len(new_bytes)
+
+        if new_len > length:
+            # New PNG is larger: splice it in and shift every later bufferView by the delta
+            before = bytes(blob[:offset])
+            after = bytes(blob[offset + length:])
+            blob = bytearray(before + new_bytes + after)
+            delta = new_len - length
+            bv.byteLength = new_len
+            for other_bv in glb.bufferViews:
+                if (other_bv.byteOffset or 0) > offset:
+                    other_bv.byteOffset += delta
+            glb.buffers[0].byteLength += delta
+        else:
+            # New PNG fits in the old slot: overwrite in place, shrink the view
+            blob[offset:offset + new_len] = new_bytes
+            bv.byteLength = new_len
+
+        glb.set_binary_blob(bytes(blob))
+        glb.save(glb_path)
+        print(f"[enhance_glb] GLB texture enhanced OK (was {length}B → {new_len}B)")
+        return True
+
+    print("[enhance_glb] No base-color texture found in GLB")
+    return False
+
+
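When the re-encoded PNG outgrows its slot, the function above splices the new bytes in and shifts every later bufferView by the growth delta. The offset bookkeeping can be modelled without pygltflib — the view list below is invented:

```python
def splice_view(views, idx, new_len):
    """views: list of [offset, length]. Grow view idx and shift later offsets, as the GLB repack does."""
    offset, length = views[idx]
    delta = new_len - length
    views[idx][1] = new_len
    if delta > 0:
        # Only views that start after the edited one move; earlier views are untouched
        for v in views:
            if v[0] > offset:
                v[0] += delta
    return views

views = [[0, 100], [100, 50], [150, 30]]
print(splice_view(views, 1, 80))   # [[0, 100], [100, 80], [180, 30]]
```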
+# ── FireRed GPU functions ──────────────────────────────────────────────────────
+
+@spaces.GPU(duration=600)
+def firered_generate(gallery_images, prompt, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
+    """Run FireRed image-edit inference on one or more reference images (max 3 natively)."""
+    pil_images = _gallery_to_pil_list(gallery_images)
+    if not pil_images:
+        return None, int(seed), "Please upload at least one image."
+    if not prompt or not prompt.strip():
+        return None, int(seed), "Please enter an edit prompt."
+    try:
+        import gc
+        progress(0.05, desc="Loading FireRed pipeline...")
+        pipe = load_firered()
+
+        if randomize_seed:
+            seed = random.randint(0, 2**31 - 1)
+
+        # FireRed natively handles 1-3 images; cap silently and warn
+        orig_n = len(pil_images)
+        if orig_n > 3:
+            print(f"[FireRed] {orig_n} images given, truncating to 3 (native limit).")
+            pil_images = pil_images[:3]
+
+        # Resize to max 1024px and align to a multiple of 8 (prevents padding bars)
+        pil_images = [_firered_resize(img) for img in pil_images]
+        height, width = pil_images[0].height, pil_images[0].width
+        print(f"[FireRed] Input size after resize: {width}x{height}")
+
+        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
+
+        progress(0.4, desc=f"Running FireRed edit ({len(pil_images)} image(s))...")
+        with torch.inference_mode():
+            result = pipe(
+                image=pil_images,
+                prompt=prompt.strip(),
+                negative_prompt=_FIRERED_NEGATIVE,
+                num_inference_steps=int(steps),
+                generator=generator,
+                true_cfg_scale=float(guidance_scale),
+                num_images_per_prompt=1,
+                height=height,
+                width=width,
+            ).images[0]
+
+        gc.collect()
+        torch.cuda.empty_cache()
+        progress(1.0, desc="Done!")
+        n = len(pil_images)
+        note = " (truncated to 3)" if orig_n > 3 else ""
+        return np.array(result), int(seed), f"Preview ready — {n} image(s) used{note}."
+    except Exception:
+        return None, int(seed), f"FireRed error:\n{traceback.format_exc()}"
+
+
+@spaces.GPU(duration=60)
+def firered_load_into_pipeline(firered_output, threshold, erode_px, progress=gr.Progress()):
+    """Load a FireRed output into the main pipeline with automatic background removal."""
+    if firered_output is None:
+        return None, None, "No FireRed output — generate an image first."
+    try:
+        progress(0.1, desc="Loading RMBG model...")
+        load_rmbg_only()
+
+        img = Image.fromarray(firered_output).convert("RGB")
+        if _rmbg_net is not None:
+            progress(0.5, desc="Removing background...")
+            composited = _remove_bg_rmbg(img, threshold=float(threshold), erode_px=int(erode_px))
+            result = np.array(composited)
+            msg = "Loaded into pipeline — background removed."
+        else:
+            result = firered_output
+            msg = "Loaded into pipeline (RMBG unavailable — background not removed)."
+
+        progress(1.0, desc="Done!")
+        return result, result, msg
+    except Exception:
+        return None, None, f"Error:\n{traceback.format_exc()}"
+
+
837
  # ── Stage 1: Shape generation ─────────────────────────────────────────────────
838
 
839
  @spaces.GPU(duration=180)
 
842
  if input_image is None:
843
  return None, "Please upload an image."
844
  try:
845
+ progress(0.05, desc="Freeing VRAM from FireRed (if loaded)...")
846
+ global _firered_pipe
847
+ if _firered_pipe is not None:
848
+ # dispatch_model attaches accelerate hooks — remove them before .to("cpu")
849
+ try:
850
+ from accelerate.hooks import remove_hook_from_submodules
851
+ remove_hook_from_submodules(_firered_pipe.transformer)
852
+ _firered_pipe.transformer.to("cpu")
853
+ except Exception as _e:
854
+ print(f"[TripoSG] Transformer CPU offload: {_e}")
855
+ try:
856
+ _firered_pipe.text_encoder.to("cpu")
857
+ except Exception:
858
+ pass
859
+ try:
860
+ _firered_pipe.vae.to("cpu")
861
+ except Exception:
862
+ pass
863
+ _firered_pipe = None
864
+ torch.cuda.empty_cache()
865
+ print("[TripoSG] FireRed offloaded — VRAM freed for shape generation.")
866
+
867
  progress(0.1, desc="Loading TripoSG...")
868
  pipe, rmbg_net = load_triposg()
869
 
 
         img.save(img_path)

         progress(0.5, desc="Generating shape (SDF diffusion)...")
+        with torch.autocast(device_type="cuda", dtype=torch.float16):
+            from scripts.inference_triposg import run_triposg
+            mesh = run_triposg(
+                pipe=pipe,
+                image_input=img_path,
+                rmbg_net=rmbg_net if remove_background else None,
+                seed=int(seed),
+                num_inference_steps=int(num_steps),
+                guidance_scale=float(guidance_scale),
+                faces=int(face_count) if int(face_count) > 0 else -1,
+            )

         out_path = "/tmp/triposg_shape.glb"
         mesh.export(out_path)

         animated = mdm_result.get("animated_glb")

         parts = ["Rigged: " + os.path.basename(rigged)]
+        if fbx: parts.append("FBX: " + os.path.basename(fbx))
         if animated: parts.append("Animation: " + os.path.basename(animated))

         torch.cuda.empty_cache()

     from pipeline.enhance_surface import (
         run_stable_normal, run_depth_anything,
         bake_normal_into_glb, bake_depth_as_occlusion,
+        unload_models,
     )
     import pipeline.enhance_surface as _enh_mod

         return []

 
+# ── HyperSwap views ───────────────────────────────────────────────────────────
+
+@spaces.GPU(duration=120)
+def hyperswap_views(embedding_json: str):
+    """
+    Stage 6 — run HyperSwap on the last rendered views.
+    embedding_json: JSON string of the 512-d ArcFace embedding list.
+    Returns a gallery of (swapped_image_path, view_name) tuples.
+    """
+    global _hyperswap_sess
+    try:
+        import onnxruntime as ort
+        from insightface.app import FaceAnalysis
+
+        embedding = np.array(json.loads(embedding_json), dtype=np.float32)
+        embedding /= np.linalg.norm(embedding)
+
+        # Load HyperSwap once
+        if _hyperswap_sess is None:
+            hs_path = str(CKPT_DIR / "hyperswap_1a_256.onnx")
+            _hyperswap_sess = ort.InferenceSession(hs_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+            print(f"[hyperswap_views] Loaded {hs_path}")
+
+        app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
+        app.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.1)
+
+        results = []
+        for view_path, name in zip(VIEW_PATHS, VIEW_NAMES):
+            if not os.path.exists(view_path):
+                print(f"[hyperswap_views] Missing {view_path}, skipping")
+                continue
+
+            bgr = cv2.imread(view_path)
+            faces = app.get(bgr)
+            if not faces:
+                print(f"[hyperswap_views] {name}: no face detected")
+                out_path = view_path  # return the original view
+            else:
+                face = faces[0]
+                M, _ = cv2.estimateAffinePartial2D(face.kps, ARCFACE_256,
+                                                   method=cv2.RANSAC, ransacReprojThreshold=100)
+                H, W = bgr.shape[:2]
+                aligned = cv2.warpAffine(bgr, M, (256, 256), flags=cv2.INTER_LINEAR)
+                t = ((aligned.astype(np.float32) / 255 - 0.5) / 0.5)[:, :, ::-1].copy().transpose(2, 0, 1)[None]
+                out, mask = _hyperswap_sess.run(None, {
+                    "source": embedding.reshape(1, -1),
+                    "target": t,
+                })
+                out_bgr = (((out[0].transpose(1, 2, 0) + 1) / 2 * 255)
+                           .clip(0, 255).astype(np.uint8))[:, :, ::-1].copy()
+                m = (mask[0, 0] * 255).clip(0, 255).astype(np.uint8)
+                Mi = cv2.invertAffineTransform(M)
+                of = cv2.warpAffine(out_bgr, Mi, (W, H), flags=cv2.INTER_LINEAR)
+                mf = cv2.warpAffine(m, Mi, (W, H), flags=cv2.INTER_LINEAR).astype(np.float32)[:, :, None] / 255
+                swapped = (of * mf + bgr * (1 - mf)).clip(0, 255).astype(np.uint8)
+
+                # GFPGAN face restoration
+                restorer = load_gfpgan()
+                if restorer is not None:
+                    b = face.bbox.astype(int)
+                    h2, w2 = swapped.shape[:2]
+                    pad = 0.35
+                    bw2, bh2 = b[2] - b[0], b[3] - b[1]
+                    cx1 = max(0, b[0] - int(bw2 * pad)); cy1 = max(0, b[1] - int(bh2 * pad))
+                    cx2 = min(w2, b[2] + int(bw2 * pad)); cy2 = min(h2, b[3] + int(bh2 * pad))
+                    crop = swapped[cy1:cy2, cx1:cx2]
+                    try:
+                        _, _, rest = restorer.enhance(
+                            crop, has_aligned=False, only_center_face=True,
+                            paste_back=True, weight=0.5)
+                        if rest is not None:
+                            ch, cw = cy2 - cy1, cx2 - cx1
+                            if rest.shape[:2] != (ch, cw):
+                                rest = cv2.resize(rest, (cw, ch), interpolation=cv2.INTER_LANCZOS4)
+                            swapped[cy1:cy2, cx1:cx2] = rest
+                    except Exception as _ge:
+                        print(f"[hyperswap_views] GFPGAN failed: {_ge}")
+
+                out_path = view_path.replace("render_", "swapped_")
+                cv2.imwrite(out_path, swapped)
+                print(f"[hyperswap_views] {name}: swapped+restored OK -> {out_path}")
+
+            results.append((out_path, name))
+
+        return results
+    except Exception:
+        err = traceback.format_exc()
+        print(f"hyperswap_views FAILED:\n{err}")
+        return []
+
+
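The paste-back above relies on `cv2.invertAffineTransform`; the same inverse can be computed with plain numpy by promoting the 2×3 matrix to 3×3, which makes the warp round-trip easy to sanity-check:

```python
import numpy as np

def invert_affine(M):
    """Invert a 2x3 affine matrix (the shape cv2.warpAffine expects)."""
    A = np.vstack([M, [0.0, 0.0, 1.0]])   # promote to a full 3x3 homogeneous matrix
    return np.linalg.inv(A)[:2]           # drop the homogeneous row again

M = np.array([[2.0, 0.0, 10.0],
              [0.0, 2.0, -4.0]])          # scale 2x, translate (10, -4)
Mi = invert_affine(M)
p = np.array([3.0, 5.0, 1.0])             # point (3, 5) in homogeneous form
q = M @ p                                 # forward warp -> (16, 6)
print(Mi @ np.array([q[0], q[1], 1.0]))   # inverse warp recovers (3, 5)
```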
+# ── Animate tab functions ─────────────────────────────────────────────────────
+
+def gradio_search_motions(query: str, progress=gr.Progress()):
+    """Stream TeoGchx/HumanML3D and return matching motions as radio choices."""
+    if not query.strip():
+        return (
+            gr.update(choices=[], visible=False),
+            [],
+            "Enter a motion description and click Search.",
+        )
+    try:
+        progress(0.1, desc="Connecting to HumanML3D dataset…")
+        sys.path.insert(0, str(HERE))
+        from Retarget.search import search_motions, format_choice_label
+        progress(0.3, desc="Streaming dataset…")
+        results = search_motions(query, top_k=8)
+        progress(1.0)
+        if not results:
+            return (
+                gr.update(choices=["No matches — try different keywords"], visible=True),
+                [],
+                f"No motions matched '{query}'. Try broader terms.",
+            )
+        choices = [format_choice_label(r) for r in results]
+        status = f"Found {len(results)} motions matching '{query}'"
+        return (
+            gr.update(choices=choices, value=choices[0], visible=True),
+            results,
+            status,
+        )
+    except Exception:
+        return (
+            gr.update(choices=[], visible=False),
+            [],
+            f"Search error:\n{traceback.format_exc()}",
+        )
+
+
+@spaces.GPU(duration=180)
+def gradio_animate(
+    rigged_glb_path,
+    selected_label: str,
+    motion_results: list,
+    fps: int,
+    max_frames: int,
+    progress=gr.Progress(),
+):
+    """Bake the selected HumanML3D motion onto the rigged GLB."""
+    try:
+        glb = rigged_glb_path or "/tmp/rig_out/rigged.glb"
+        if not os.path.exists(glb):
+            return None, "No rigged GLB — run the Rig step first.", None
+
+        if not motion_results or not selected_label:
+            return None, "No motion selected — run Search first.", None
+
+        # Resolve which result was selected
+        sys.path.insert(0, str(HERE))
+        from Retarget.search import format_choice_label
+        idx = 0
+        for i, r in enumerate(motion_results):
+            if format_choice_label(r) == selected_label:
+                idx = i
+                break
+
+        chosen = motion_results[idx]
+        motion = chosen["motion"]  # np.ndarray [T, 263]
+        caption = chosen["caption"]
+        T_total = motion.shape[0]
+        n_frames = min(max_frames, T_total) if max_frames > 0 else T_total
+
+        progress(0.2, desc="Parsing skeleton…")
+        from Retarget.animate import animate_glb_from_hml3d
+
+        out_path = "/tmp/animated_out/animated.glb"
+        os.makedirs("/tmp/animated_out", exist_ok=True)
+
+        progress(0.4, desc="Mapping bones to SMPL joints…")
+        animated = animate_glb_from_hml3d(
+            motion=motion,
+            rigged_glb=glb,
+            output_glb=out_path,
+            fps=int(fps),
+            num_frames=int(n_frames),
+        )
+        progress(1.0, desc="Done!")
+        status = (
+            f"Animated: {n_frames} frames @ {fps} fps\n"
+            f"Motion: {caption[:120]}"
+        )
+        return animated, status, animated
+
+    except Exception:
+        return None, f"Error:\n{traceback.format_exc()}", None
+
1394
+
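The label-to-result resolution in `gradio_animate` above (match the radio selection against formatted labels, fall back to the first hit) can be factored into a small pure helper. The name `select_motion` is illustrative, not part of the codebase; `label_fn` stands in for `Retarget.search.format_choice_label`:

```python
def select_motion(results, selected_label, label_fn):
    """Return the search result whose formatted label matches the radio choice.

    Mirrors the loop in gradio_animate: if nothing matches, fall back to the
    first result rather than failing.
    """
    for result in results:
        if label_fn(result) == selected_label:
            return result
    return results[0]
```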
1395
+ # ── PSHuman Face Transplant ────────────────────────────────────────────────────
1396
+
1397
+ def gradio_pshuman_face(
1398
+ input_image,
1399
+ rigged_glb_path,
1400
+ weight_threshold: float,
1401
+ retract_mm: float,
1402
+ pshuman_url: str,
1403
+ progress=gr.Progress(),
1404
+ ):
1405
+ """
1406
+ Full PSHuman face transplant pipeline:
1407
+ 1. Run PSHuman on input_image → colored OBJ face mesh
1408
+ 2. Run face_transplant.py → stitch face into rigged GLB
1409
+ 3. Return the combined GLB
1410
+
1411
+ PSHuman runs as a remote service (pshuman_url). On ZeroGPU, the service_url
1412
+ must point to an externally deployed PSHuman endpoint (the PSHUMAN_URL env var
1413
+ or a user-provided URL in the UI). A localhost URL will not work on ZeroGPU.
1414
+ """
1415
+ try:
1416
+ if input_image is None:
1417
+ return None, "Upload a portrait image first.", None
1418
+ rigged = rigged_glb_path
1419
+ if not rigged or not os.path.exists(str(rigged)):
1420
+ return None, "No rigged GLB found — run the Rig step first.", None
1421
+
1422
+ work_dir = tempfile.mkdtemp(prefix="pshuman_transplant_")
1423
+ img_path = os.path.join(work_dir, "portrait.png")
1424
+ if isinstance(input_image, np.ndarray):
1425
+ Image.fromarray(input_image).save(img_path)
1426
+ else:
1427
+ input_image.save(img_path)
1428
+
1429
+ # pipeline/ is already in sys.path via PIPELINE_DIR insertion at startup
1430
+ # ── Step 1: PSHuman inference ──────────────────────────────────────────
1431
+ progress(0.05, desc="Step 1/2: Running PSHuman (generates multi-view face)...")
1432
+ from pipeline.pshuman_client import generate_pshuman_mesh
1433
+ face_obj = os.path.join(work_dir, "pshuman_face.obj")
1434
+ generate_pshuman_mesh(
1435
+ image_path=img_path,
1436
+ output_path=face_obj,
1437
+ service_url=pshuman_url.strip() or "http://localhost:7862",
1438
+ )
1439
+
1440
+ # ── Step 2: Face transplant ────────────────────────────────────────────
1441
+ progress(0.7, desc="Step 2/2: Stitching PSHuman face into rigged GLB...")
1442
+ out_glb = os.path.join(work_dir, "rigged_pshuman_face.glb")
1443
+
1444
+ from pipeline.face_transplant import transplant_face
1445
+ transplant_face(
1446
+ body_glb_path=str(rigged),
1447
+ pshuman_mesh_path=face_obj,
1448
+ output_path=out_glb,
1449
+ weight_threshold=float(weight_threshold),
1450
+ retract_amount=float(retract_mm) / 1000.0,  # mm → metres
1451
+ )
1452
+
1453
+ progress(1.0, desc="Done!")
1454
+ return out_glb, "PSHuman face transplant complete.", out_glb
1455
+
1456
+ except Exception:
1457
+ return None, f"Error:\n{traceback.format_exc()}", None
1458
+
1459
+
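The service-URL fallback used in `gradio_pshuman_face` (UI textbox value, else a localhost default) can be sketched as a standalone helper. This is a hypothetical refactor: the name `resolve_service_url` does not exist in the codebase, and adding `PSHUMAN_URL` as a middle fallback is an assumption based on the UI default shown below:

```python
import os

def resolve_service_url(user_url, default="http://localhost:7862"):
    """Pick the PSHuman endpoint: UI value, then PSHUMAN_URL env var, then default.

    Hypothetical helper mirroring `pshuman_url.strip() or <default>` in the
    handler above; trailing slashes are trimmed for consistent URL joining.
    """
    for candidate in (user_url, os.environ.get("PSHUMAN_URL")):
        if candidate and candidate.strip():
            return candidate.strip().rstrip("/")
    return default
```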
1460
  # ── Full pipeline ─────────────────────────────────────────────────────────────
1461
 
1462
+ def run_full_pipeline(input_image, remove_background, num_steps, guidance, seed, face_count,
1463
+ variant, tex_seed, enhance_face, rembg_threshold, rembg_erode,
1464
  export_fbx, mdm_prompt, mdm_n_frames, progress=gr.Progress()):
1465
  progress(0.0, desc="Stage 1/3: Generating shape...")
1466
+ glb, status = generate_shape(input_image, remove_background, num_steps, guidance, seed, face_count)
1467
  if not glb:
1468
  return None, None, None, None, None, None, status
1469
 
1470
  progress(0.33, desc="Stage 2/3: Applying texture...")
1471
+ glb, mv_img, status = apply_texture(glb, input_image, remove_background, variant, tex_seed,
1472
+ enhance_face, rembg_threshold, rembg_erode)
1473
  if not glb:
1474
  return None, None, None, None, None, None, status
1475
 
 
1481
 
1482
 
1483
  # ── UI ────────────────────────────────────────────────────────────────────────
1484
+ with gr.Blocks(title="Image2Model", theme=gr.themes.Soft()) as demo:
1485
  gr.Markdown("# Image2Model — Portrait to Rigged 3D Mesh")
1486
+ glb_state = gr.State(None)
1487
+ rigged_glb_state = gr.State(None) # persists rigged GLB for Animate + PSHuman tabs
1488
+
1489
+ with gr.Tabs() as tabs:
1490
+
1491
+ # ════════════════════════════════════════════════════════════════════
1492
+ with gr.Tab("Edit", id=0):
1493
+ gr.Markdown(
1494
+ "### Image Edit — FireRed\n"
1495
+ "Upload one or more reference images, write an edit prompt, preview the result, "
1496
+ "then click **Load to Generate** to send it to the 3D pipeline."
1497
+ )
1498
+ with gr.Row():
1499
+ with gr.Column(scale=1):
1500
+ firered_gallery = gr.Gallery(
1501
+ label="Reference Images (1–3 images, drag & drop)",
1502
+ interactive=True,
1503
+ columns=3,
1504
+ height=220,
1505
+ object_fit="contain",
1506
+ )
1507
+ firered_prompt = gr.Textbox(
1508
+ label="Edit Prompt",
1509
+ placeholder="make the person wear a red jacket",
1510
+ lines=2,
1511
+ )
1512
+ with gr.Row():
1513
+ firered_seed = gr.Number(value=_init_seed, label="Seed", precision=0)
1514
+ firered_rand = gr.Checkbox(label="Random Seed", value=True)
1515
+ with gr.Row():
1516
+ firered_guidance = gr.Slider(1.0, 10.0, value=1.0, step=0.5,
1517
+ label="Guidance Scale")
1518
+ firered_steps = gr.Slider(1, 40, value=4, step=1,
1519
+ label="Inference Steps")
1520
+ firered_btn = gr.Button("Generate Preview", variant="secondary")
1521
+ firered_status = gr.Textbox(label="Status", lines=2, interactive=False)
1522
 
1523
+ with gr.Column(scale=1):
1524
+ firered_output_img = gr.Image(label="FireRed Output", type="numpy",
1525
+ interactive=False)
1526
+ load_to_generate_btn = gr.Button("Load to Generate", variant="primary")
1527
 
1528
  # ════════════════════════════════════════════════════════════════════
1529
+ with gr.Tab("Generate", id=1):
1530
  with gr.Row():
1531
  with gr.Column(scale=1):
1532
+ input_image = gr.Image(label="Input Image", type="numpy")
1533
+ remove_bg_check = gr.Checkbox(label="Remove Background", value=True)
1534
+ with gr.Row():
1535
+ rembg_threshold = gr.Slider(0.1, 0.95, value=0.5, step=0.05,
1536
+ label="BG Threshold (higher = stricter)")
1537
+ rembg_erode = gr.Slider(0, 8, value=2, step=1,
1538
+ label="Edge Erode (px)")
1539
 
1540
  with gr.Accordion("Shape Settings", open=True):
1541
  num_steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
 
1554
  shape_btn = gr.Button("Generate Shape", variant="primary", scale=2, interactive=False)
1555
  texture_btn = gr.Button("Apply Texture", variant="secondary", scale=2)
1556
  render_btn = gr.Button("Render Views", variant="secondary", scale=1)
1557
+ run_all_btn = gr.Button("▶ Run Full Pipeline (Shape + Texture + Rig)", variant="primary", interactive=False)
1558
 
1559
  with gr.Column(scale=1):
1560
+ rembg_preview = gr.Image(label="BG Removed Preview", type="numpy",
1561
+ interactive=False)
1562
  status = gr.Textbox(label="Status", lines=3, interactive=False)
1563
  model_3d = gr.Model3D(label="3D Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
1564
  download_file = gr.File(label="Download GLB")
 
1566
 
1567
  render_gallery = gr.Gallery(label="Rendered Views", columns=5, height=300)
1568
 
1569
+ # ── wiring: Generate tab ──────────────────────────────────────────
1570
+ _rembg_inputs = [input_image, remove_bg_check, rembg_threshold, rembg_erode]
1571
  _pipeline_btns = [shape_btn, run_all_btn]
1572
 
1573
  input_image.upload(
 
1579
  inputs=[], outputs=_pipeline_btns,
1580
  )
1581
 
1582
+ input_image.upload(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
1583
+ remove_bg_check.change(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
1584
+ rembg_threshold.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
1585
+ rembg_erode.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
1586
+
1587
  shape_btn.click(
1588
+ fn=generate_shape,
1589
+ inputs=[input_image, remove_bg_check, num_steps, guidance, seed, face_count],
1590
  outputs=[glb_state, status],
1591
  ).then(
1592
  fn=lambda p: (p, p) if p else (None, None),
 
1594
  )
1595
 
1596
  texture_btn.click(
1597
+ fn=apply_texture,
1598
+ inputs=[glb_state, input_image, remove_bg_check, variant, tex_seed,
1599
+ enhance_face_check, rembg_threshold, rembg_erode],
1600
  outputs=[glb_state, multiview_img, status],
1601
  ).then(
1602
  fn=lambda p: (p, p) if p else (None, None),
 
1605
 
1606
  render_btn.click(fn=render_views, inputs=[download_file], outputs=[render_gallery])
1607
 
1608
+ # ── Edit tab wiring (after Generate so all components are defined) ──
1609
+ firered_btn.click(
1610
+ fn=firered_generate,
1611
+ inputs=[firered_gallery, firered_prompt, firered_seed, firered_rand,
1612
+ firered_guidance, firered_steps],
1613
+ outputs=[firered_output_img, firered_seed, firered_status],
1614
+ api_name="firered_generate",
1615
+ )
1616
+
1617
+ load_to_generate_btn.click(
1618
+ fn=firered_load_into_pipeline,
1619
+ inputs=[firered_output_img, rembg_threshold, rembg_erode],
1620
+ outputs=[input_image, rembg_preview, firered_status],
1621
+ ).then(
1622
+ fn=lambda img: (
1623
+ gr.update(interactive=img is not None),
1624
+ gr.update(interactive=img is not None),
1625
+ gr.update(selected=1),
1626
+ ),
1627
+ inputs=[input_image],
1628
+ outputs=[shape_btn, run_all_btn, tabs],
1629
+ )
1630
+
1631
  # ════════════════════════════════════════════════════════════════════
1632
  with gr.Tab("Rig & Export"):
1633
  with gr.Row():
 
1675
  inputs=[glb_state, export_fbx_check, mdm_prompt_box, mdm_frames_slider],
1676
  outputs=[rig_glb_dl, rig_animated_dl, rig_fbx_dl, rig_status,
1677
  rig_model_3d, rigged_base_state, skel_glb_state],
1678
+ ).then(
1679
+ fn=lambda p: p,
1680
+ inputs=[rigged_base_state], outputs=[rigged_glb_state],
1681
  )
1682
 
1683
  show_skel_check.change(
 
1686
  outputs=[rig_model_3d],
1687
  )
1688
 
1689
+ # ════════════════════════════════════════════════════════════════════
1690
+ with gr.Tab("Animate"):
1691
+ gr.Markdown(
1692
+ "### Motion Search & Animate\n"
1693
+ "Search the HumanML3D dataset for motions matching a description, "
1694
+ "then bake the selected motion onto your rigged GLB."
1695
+ )
1696
+ with gr.Row():
1697
+ with gr.Column(scale=1):
1698
+ motion_query = gr.Textbox(
1699
+ label="Motion Description",
1700
+ placeholder="a person walks forward slowly",
1701
+ lines=2,
1702
+ )
1703
+ search_btn = gr.Button("Search Motions", variant="secondary")
1704
+ motion_radio = gr.Radio(
1705
+ label="Select Motion", choices=[], visible=False,
1706
+ )
1707
+ motion_results_state = gr.State([])
1708
+
1709
+ gr.Markdown("### Animate Settings")
1710
+ animate_fps = gr.Slider(10, 60, value=30, step=5, label="FPS")
1711
+ animate_frames = gr.Slider(0, 600, value=0, step=30,
1712
+ label="Max Frames (0 = full motion)")
1713
+ animate_btn = gr.Button("Animate", variant="primary")
1714
+
1715
+ with gr.Column(scale=2):
1716
+ animate_status = gr.Textbox(label="Status", lines=4, interactive=False)
1717
+ animate_model_3d = gr.Model3D(label="Animated Preview",
1718
+ clear_color=[0.9, 0.9, 0.9, 1.0])
1719
+ animate_dl = gr.File(label="Download Animated GLB")
1720
+
1721
+ search_btn.click(
1722
+ fn=gradio_search_motions,
1723
+ inputs=[motion_query],
1724
+ outputs=[motion_radio, motion_results_state, animate_status],
1725
+ )
1726
+
1727
+ animate_btn.click(
1728
+ fn=gradio_animate,
1729
+ inputs=[rigged_glb_state, motion_radio, motion_results_state,
1730
+ animate_fps, animate_frames],
1731
+ outputs=[animate_dl, animate_status, animate_model_3d],
1732
+ )
1733
+
1734
+ # ════════════════════════════════════════════════════════════════════
1735
+ with gr.Tab("PSHuman Face"):
1736
+ gr.Markdown(
1737
+ "### PSHuman Face Transplant\n"
1738
+ "Generates a high-detail face mesh via PSHuman (multi-view diffusion), "
1739
+ "then transplants it into the rigged GLB.\n\n"
1740
+ "**Pipeline:** portrait → PSHuman (remote service) → colored OBJ → face_transplant → rigged GLB with HD face\n\n"
1741
+ "**Note:** On ZeroGPU, PSHuman must run as a remote service. "
1742
+ "Set the `PSHUMAN_URL` environment variable or enter the URL below."
1743
+ )
1744
+ with gr.Row():
1745
+ with gr.Column(scale=1):
1746
+ pshuman_img_input = gr.Image(
1747
+ label="Portrait image (same as used for Generate)",
1748
+ type="pil",
1749
+ )
1750
+ with gr.Accordion("Advanced settings", open=False):
1751
+ pshuman_weight_thresh = gr.Slider(
1752
+ minimum=0.1, maximum=0.9, value=0.35, step=0.05,
1753
+ label="Head bone weight threshold",
1754
+ info="Vertices with head-bone weight above this get replaced",
1755
+ )
1756
+ pshuman_retract_mm = gr.Slider(
1757
+ minimum=0.0, maximum=20.0, value=4.0, step=0.5,
1758
+ label="Face retract (mm)",
1759
+ info="How far to push original face verts inward to avoid z-fighting",
1760
+ )
1761
+ pshuman_service_url = gr.Textbox(
1762
+ label="PSHuman service URL",
1763
+ value=os.environ.get("PSHUMAN_URL", "http://localhost:7862"),
1764
+ info="pshuman_app.py Gradio endpoint (deployed separately)",
1765
+ )
1766
+ pshuman_btn = gr.Button("Generate HD Face", variant="primary")
1767
+
1768
+ with gr.Column(scale=2):
1769
+ pshuman_status = gr.Textbox(label="Status", lines=4, interactive=False)
1770
+ pshuman_model_3d = gr.Model3D(
1771
+ label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
1772
+ pshuman_glb_dl = gr.File(label="Download GLB (with PSHuman face)")
1773
+
1774
+ pshuman_btn.click(
1775
+ fn=gradio_pshuman_face,
1776
+ inputs=[
1777
+ pshuman_img_input,
1778
+ rigged_glb_state,
1779
+ pshuman_weight_thresh,
1780
+ pshuman_retract_mm,
1781
+ pshuman_service_url,
1782
+ ],
1783
+ outputs=[pshuman_glb_dl, pshuman_status, pshuman_model_3d],
1784
+ )
1785
+
1786
  # ════════════════════════════════════════════════════════════════════
1787
  with gr.Tab("Enhancement"):
1788
  gr.Markdown("**Surface Enhancement** — bakes normal + depth maps into the GLB as PBR textures.")
 
1799
  displacement_scale = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Displacement Scale")
1800
 
1801
  enhance_btn = gr.Button("Run Enhancement", variant="primary")
1802
+ unload_btn = gr.Button("Unload Models (free VRAM)", variant="secondary")
1803
 
1804
  with gr.Column(scale=2):
1805
  enhance_status = gr.Textbox(label="Status", lines=5, interactive=False)
 
1818
  enhanced_glb_dl, enhanced_model_3d, enhance_status],
1819
  )
1820
 
1821
+ def _unload_enhancement_models():
1822
+ try:
1823
+ from pipeline.enhance_surface import unload_models
1824
+ unload_models()
1825
+ return "Enhancement models unloaded — VRAM freed."
1826
+ except Exception as e:
1827
+ return f"Unload failed: {e}"
1828
+
1829
+ unload_btn.click(
1830
+ fn=_unload_enhancement_models,
1831
+ inputs=[], outputs=[enhance_status],
1832
+ )
1833
+
1834
+ # ════════════════════════════════════════════════════════════════════
1835
+ with gr.Tab("Settings"):
1836
+
1837
+ def get_vram_status():
1838
+ lines = []
1839
+ if torch.cuda.is_available():
1840
+ alloc = torch.cuda.memory_allocated() / 1024**3
1841
+ reserv = torch.cuda.memory_reserved() / 1024**3
1842
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
1843
+ free = total - reserv
1844
+ lines.append(f"GPU: {torch.cuda.get_device_name(0)}")
1845
+ lines.append(f"VRAM total: {total:.1f} GB")
1846
+ lines.append(f"VRAM allocated: {alloc:.1f} GB")
1847
+ lines.append(f"VRAM reserved: {reserv:.1f} GB")
1848
+ lines.append(f"VRAM free: {free:.1f} GB")
1849
+ else:
1850
+ lines.append("No CUDA device available.")
1851
+ lines.append("")
1852
+ lines.append("Loaded models:")
1853
+ lines.append(f" TripoSG pipeline: {'loaded' if _triposg_pipe is not None else 'not loaded'}")
1854
+ lines.append(f" RMBG-{_rmbg_version or '?'}: {'loaded' if _rmbg_net is not None else 'not loaded'}")
1855
+ lines.append(f" FireRed: {'loaded' if _firered_pipe is not None else 'not loaded'}")
1856
+ try:
1857
+ import pipeline.enhance_surface as _enh_mod
1858
+ lines.append(f" StableNormal: {'loaded' if _enh_mod._normal_pipe is not None else 'not loaded'}")
1859
+ lines.append(f" Depth-Anything: {'loaded' if _enh_mod._depth_pipe is not None else 'not loaded'}")
1860
+ except Exception:
1861
+ lines.append(" StableNormal / Depth-Anything: (status unavailable)")
1862
+ return "\n".join(lines)
1863
+
1864
+ def _preload_triposg():
1865
+ try:
1866
+ load_triposg()
1867
+ return get_vram_status()
1868
+ except Exception:
1869
+ return f"Preload failed:\n{traceback.format_exc()}"
1870
+
1871
+ def _unload_triposg():
1872
+ global _triposg_pipe, _rmbg_net
1873
+ with _model_load_lock:
1874
+ if _triposg_pipe is not None:
1875
+ _triposg_pipe.to("cpu")
1876
+ del _triposg_pipe
1877
+ _triposg_pipe = None
1878
+ if _rmbg_net is not None:
1879
+ _rmbg_net.to("cpu")
1880
+ del _rmbg_net
1881
+ _rmbg_net = None
1882
+ torch.cuda.empty_cache()
1883
+ return get_vram_status()
1884
+
1885
+ def _unload_enhancement():
1886
+ try:
1887
+ from pipeline.enhance_surface import unload_models
1888
+ unload_models()
1889
+ except Exception:
1890
+ pass
1891
+ return get_vram_status()
1892
+
1893
+ def _unload_all():
1894
+ _unload_triposg()
1895
+ _unload_enhancement()
1896
+ return get_vram_status()
1897
+
1898
+ with gr.Row():
1899
+ with gr.Column(scale=1):
1900
+ gr.Markdown("### VRAM Management")
1901
+ preload_btn = gr.Button("Preload TripoSG + RMBG to VRAM", variant="primary")
1902
+ unload_triposg_btn = gr.Button("Unload TripoSG / RMBG")
1903
+ unload_enh_btn = gr.Button("Unload Enhancement Models (StableNormal / Depth)")
1904
+ unload_all_btn = gr.Button("Unload All Models", variant="stop")
1905
+ refresh_btn = gr.Button("Refresh Status")
1906
+
1907
+ with gr.Column(scale=1):
1908
+ gr.Markdown("### GPU Status")
1909
+ vram_status = gr.Textbox(
1910
+ label="", lines=12, interactive=False,
1911
+ value="Click Refresh to check VRAM status.",
1912
+ )
1913
+
1914
+ preload_btn.click(fn=_preload_triposg, inputs=[], outputs=[vram_status])
1915
+ unload_triposg_btn.click(fn=_unload_triposg, inputs=[], outputs=[vram_status])
1916
+ unload_enh_btn.click(fn=_unload_enhancement, inputs=[], outputs=[vram_status])
1917
+ unload_all_btn.click(fn=_unload_all, inputs=[], outputs=[vram_status])
1918
+ refresh_btn.click(fn=get_vram_status, inputs=[], outputs=[vram_status])
1919
+
1920
+ # ── Run All wiring (after all tabs so components are defined) ────────
1921
  run_all_btn.click(
1922
  fn=run_full_pipeline,
1923
  inputs=[
1924
+ input_image, remove_bg_check, num_steps, guidance, seed, face_count,
1925
+ variant, tex_seed, enhance_face_check, rembg_threshold, rembg_erode,
1926
  export_fbx_check, mdm_prompt_box, mdm_frames_slider,
1927
  ],
1928
  outputs=[glb_state, download_file, multiview_img,
 
1932
  inputs=[glb_state], outputs=[model_3d, download_file],
1933
  )
1934
 
1935
+ # ── Hidden API endpoints ──────────────────────────────────────────────────
1936
+ _api_render_gallery = gr.Gallery(visible=False)
1937
+ _api_swap_gallery = gr.Gallery(visible=False)
1938
+
1939
+ def _render_last():
1940
+ path = _last_glb_path or "/tmp/triposg_textured.glb"
1941
+ return render_views(path)
1942
+
1943
+ _hs_emb_input = gr.Textbox(visible=False)
1944
+
1945
+ gr.Button(visible=False).click(
1946
+ fn=_render_last, inputs=[], outputs=[_api_render_gallery], api_name="render_last")
1947
+ gr.Button(visible=False).click(
1948
+ fn=hyperswap_views, inputs=[_hs_emb_input], outputs=[_api_swap_gallery],
1949
+ api_name="hyperswap_views")
1950
+
1951
 
1952
  if __name__ == "__main__":
1953
+ demo.launch(server_name="0.0.0.0", server_port=7860,
1954
+ show_error=True, allowed_paths=["/tmp"])
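The GiB arithmetic in the Settings tab's `get_vram_status` can be exercised without a GPU by swapping the `torch.cuda.*` queries for plain byte counts. A pure-Python sketch (the function name `format_vram_status` is illustrative):

```python
def format_vram_status(alloc_bytes, reserved_bytes, total_bytes):
    """Render VRAM numbers the way the Settings tab does (GiB, one decimal).

    free = total - reserved, matching get_vram_status above; inputs are raw
    byte counts in place of torch.cuda.memory_allocated() and friends.
    """
    gib = 1024 ** 3
    free = total_bytes - reserved_bytes
    return "\n".join([
        f"VRAM total: {total_bytes / gib:.1f} GB",
        f"VRAM allocated: {alloc_bytes / gib:.1f} GB",
        f"VRAM reserved: {reserved_bytes / gib:.1f} GB",
        f"VRAM free: {free / gib:.1f} GB",
    ])
```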
pipeline/face_inswap_bake.py ADDED
@@ -0,0 +1,302 @@
1
+ """
2
+ face_inswap_bake.py — Proper face swap on rendered views, then UV-bake.
3
+
4
+ Pipeline:
5
+ 1. Render the mesh from multiple views (front + L/R 3-quarter)
6
+ 2. Run inswapper_128 to swap reference face onto each rendered view
7
+ 3. uv_render_attr() bakes each swapped render directly into UV texture
8
+ (render-space coords shared with UV lookup — no coordinate transforms)
9
+ 4. Composite multiple views (front takes priority, sides fill gaps)
10
+ 5. Save updated GLB
11
+
12
+ Usage:
13
+ python face_inswap_bake.py \
14
+ --body /tmp/triposg_textured.glb \
15
+ --face /tmp/triposg_face_ref.png \
16
+ --out /tmp/face_swapped.glb \
17
+ [--uv_size 4096] [--debug_dir /tmp]
18
+ """
19
+
20
+ import os, sys, argparse, warnings
21
+ warnings.filterwarnings('ignore')
22
+
23
+ import numpy as np
24
+ import cv2
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from PIL import Image
28
+ import trimesh
29
+ from trimesh.visual.texture import TextureVisuals
30
+ from trimesh.visual.material import PBRMaterial
31
+
32
+ sys.path.insert(0, '/root/MV-Adapter')
33
+ from mvadapter.utils.mesh_utils import (
34
+ NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera, render,
35
+ )
36
+ from mvadapter.utils.mesh_utils.uv import (
37
+ uv_precompute, uv_render_geometry, uv_render_attr,
38
+ )
39
+ from insightface.app import FaceAnalysis
40
+ import insightface
41
+ from gfpgan import GFPGANer
42
+
43
+
44
+ GFPGAN_PATH = '/root/MV-Adapter/checkpoints/GFPGANv1.4.pth'
45
+
46
+
47
+ # ── helpers ───────────────────────────────────────────────────────────────────
48
+
49
+ def _build_front_face_uv_mask(mesh_t, tex_H, tex_W, neck_frac=0.76):
50
+ """UV-space mask covering only front-facing head triangles (no back-of-head)."""
51
+ verts = np.array(mesh_t.vertices, dtype=np.float64)
52
+ faces = np.array(mesh_t.faces, dtype=np.int32)
53
+ uvs = np.array(mesh_t.visual.uv, dtype=np.float64)
54
+
55
+ y_min, y_max = verts[:, 1].min(), verts[:, 1].max()
56
+ neck_y = float(y_min + (y_max - y_min) * neck_frac)
57
+ head_idx = np.where(verts[:, 1] > neck_y)[0]
58
+ hv = verts[head_idx]
59
+
60
+ z_thresh = float(np.percentile(hv[:, 2], 40))
61
+ front = hv[:, 2] >= z_thresh
62
+ if front.sum() < 30:
63
+ front = np.ones(len(hv), bool)
64
+
65
+ face_vert_idx = head_idx[front]
66
+ face_vert_mask = np.zeros(len(verts), bool)
67
+ face_vert_mask[face_vert_idx] = True
68
+ face_tri_mask = face_vert_mask[faces].all(axis=1)
69
+ face_tris = faces[face_tri_mask]
70
+ print(f' Geometry mask: {face_tri_mask.sum()} front-face triangles '
71
+ f'(neck_y={neck_y:.3f}, z_thresh={z_thresh:.3f})')
72
+
73
+ geom_mask = np.zeros((tex_H, tex_W), dtype=np.float32)
74
+ pts_list = []
75
+ for tri in face_tris:
76
+ uv = uvs[tri]
77
+ px = uv[:, 0] * tex_W
78
+ py = (1.0 - uv[:, 1]) * tex_H
79
+ pts_list.append(np.column_stack([px, py]).astype(np.int32))
80
+ if pts_list:
81
+ cv2.fillPoly(geom_mask, pts_list, 1.0)
82
+
83
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
84
+ geom_mask = cv2.dilate(geom_mask, kernel, iterations=2)
85
+ geom_mask = cv2.erode(geom_mask, kernel, iterations=1)
86
+ geom_mask = cv2.GaussianBlur(geom_mask, (31, 31), 8)
87
+ return geom_mask
88
+
89
+
90
+ def _detect_largest_face(img_bgr, app):
91
+ faces = app.get(img_bgr)
92
+ if not faces:
93
+ return None
94
+ return max(faces, key=lambda f: (f.bbox[2]-f.bbox[0])*(f.bbox[3]-f.bbox[1]))
95
+
96
+
97
+ def _render_view(ctx, mesh_mv, uv_pre, azimuth_deg, H, W, device):
98
+ """Render the mesh from a given azimuth; return (camera, uv_geom)."""
99
+ camera = get_orthogonal_camera(
100
+ elevation_deg=[0], distance=[1.8],
101
+ left=-0.55, right=0.55, bottom=-0.55, top=0.55,
102
+ azimuth_deg=[azimuth_deg], device=device,
103
+ )
104
+ uv_geom = uv_render_geometry(
105
+ ctx, mesh_mv, camera,
106
+ view_height=H, view_width=W,
107
+ uv_precompute_output=uv_pre,
108
+ compute_depth_grad=False,
109
+ )
110
+ return camera, uv_geom
111
+
112
+
113
+ def face_inswap_bake(body_glb, face_img_path, out_glb,
114
+ uv_size=4096, debug_dir=None):
115
+
116
+ device = 'cuda'
117
+ INSWAPPER_PATH = '/root/MV-Adapter/checkpoints/inswapper_128.onnx'
118
+
119
+ # ── Load GFPGAN enhancer ──────────────────────────────────────────────────
120
+ print('[fib] Loading GFPGANv1.4 ...')
121
+ enhancer = GFPGANer(
122
+ model_path=GFPGAN_PATH,
123
+ upscale=1,
124
+ arch='clean',
125
+ channel_multiplier=2,
126
+ bg_upsampler=None,
127
+ )
128
+
129
+ # ── Load mesh ─────────────────────────────────────────────────────────────
130
+ print(f'[fib] Loading mesh: {body_glb}')
131
+ ctx = NVDiffRastContextWrapper(device=device, context_type='cuda')
132
+ mesh_mv = load_mesh(body_glb, rescale=True, device=device)
133
+
134
+ scene_t = trimesh.load(body_glb)
135
+ if isinstance(scene_t, trimesh.Scene):
136
+ geom_name = list(scene_t.geometry.keys())[0]
137
+ mesh_t = scene_t.geometry[geom_name]
138
+ else:
139
+ mesh_t = scene_t
+ geom_name = None
140
+
141
+ orig_tex_np = np.array(mesh_t.visual.material.baseColorTexture, dtype=np.float32) / 255.0
142
+ uvs = np.array(mesh_t.visual.uv, dtype=np.float64)
143
+ tex_H, tex_W = orig_tex_np.shape[:2]
144
+ print(f' Texture: {tex_W}×{tex_H}')
145
+
146
+ # Build geometry mask (front-face head triangles only) at UV resolution
147
+ print('[fib] Building front-face geometry UV mask ...')
148
+ geom_uv_mask = _build_front_face_uv_mask(mesh_t, uv_size, uv_size)
149
+
150
+ # Render dimensions (match triposg_app.py)
151
+ H_r, W_r = 1024, 768
152
+
153
+ # ── Precompute UV geometry ─────────────────────────────────────────────────
154
+ print(f'[fib] Precomputing UV geometry ({uv_size}×{uv_size}) ...')
155
+ uv_pre = uv_precompute(ctx, mesh_mv, height=uv_size, width=uv_size)
156
+
157
+ # ── Load face swap model + face detector ──────────────────────────────────
158
+ print('[fib] Loading inswapper_128 ...')
159
+ swapper = insightface.model_zoo.get_model(
160
+ INSWAPPER_PATH, download=False,
161
+ providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
162
+ )
163
+
164
+ app = FaceAnalysis(providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
165
+ app.prepare(ctx_id=0, det_size=(640, 640))
166
+
167
+ ref_bgr = cv2.imread(face_img_path)
168
+ ref_face = _detect_largest_face(ref_bgr, app)
169
+ if ref_face is None:
170
+ raise RuntimeError(f'No face detected in reference: {face_img_path}')
171
+ print(f' Reference face detected: bbox={ref_face.bbox.astype(int).tolist()}')
172
+
173
+ # ── Process each view ─────────────────────────────────────────────────────
174
+ # Views: front (azimuth=-90), slight left (-60), slight right (-120)
175
+ # Azimuth convention from MV-Adapter: -90 = front-facing
176
+ views = [
177
+ ('front', -90, 1.0), # (name, azimuth_deg, priority_weight)
178
+ ('threequarter_r', -60, 0.7),
179
+ ('threequarter_l', -120, 0.7),
180
+ ]
181
+
182
+ # Accumulators for weighted UV compositing
183
+ uv_colour_acc = np.zeros((uv_size, uv_size, 3), dtype=np.float32)
184
+ uv_weight_acc = np.zeros((uv_size, uv_size), dtype=np.float32)
185
+
186
+ for view_name, azimuth, weight in views:
187
+ print(f'\n[fib] View: {view_name} (azimuth={azimuth}°)')
188
+
189
+ # Create camera + UV geometry for this view
190
+ camera, uv_geom = _render_view(ctx, mesh_mv, uv_pre, azimuth, H_r, W_r, device)
191
+
192
+ # Render textured mesh from this view
193
+ render_out = render(ctx, mesh_mv, camera, height=H_r, width=W_r,
194
+ render_attr=True, render_depth=False, render_normal=False,
195
+ attr_background=0.0)
196
+ # render_out.attr: (1, H, W, 3) float in [0,1]
197
+ rendered_np = (render_out.attr[0].cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
198
+ rendered_bgr = cv2.cvtColor(rendered_np, cv2.COLOR_RGB2BGR)
199
+
200
+ if debug_dir:
201
+ cv2.imwrite(os.path.join(debug_dir, f'fib_render_{view_name}.png'), rendered_bgr)
202
+
203
+ # Detect face in this rendered view
204
+ tgt_face = _detect_largest_face(rendered_bgr, app)
205
+ if tgt_face is None:
206
+ print(f' No face in {view_name} render — skipping')
207
+ continue
208
+ print(f' Target face: bbox={tgt_face.bbox.astype(int).tolist()}')
209
+
210
+ # Swap face
211
+ swapped_bgr = swapper.get(rendered_bgr.copy(), tgt_face, ref_face, paste_back=True)
212
+
213
+ # Enhance face detail with GFPGAN
214
+ _, _, enhanced_bgr = enhancer.enhance(
215
+ swapped_bgr, has_aligned=False, only_center_face=False, paste_back=True)
216
+ if enhanced_bgr is not None:
217
+ swapped_bgr = enhanced_bgr
218
+ print(' GFPGAN enhanced')
219
+
220
+ if debug_dir:
221
+ cv2.imwrite(os.path.join(debug_dir, f'fib_swapped_{view_name}.png'), swapped_bgr)
222
+
223
+ swapped_rgb = cv2.cvtColor(swapped_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
224
+
225
+ # Build render-space face hull mask
226
+ kps = tgt_face.kps
227
+ hull_pts = cv2.convexHull(kps.astype(np.float32)).squeeze(1)
228
+ hull_cx, hull_cy = hull_pts.mean(axis=0)
229
+ hull_exp = (hull_pts - [hull_cx, hull_cy]) * 3.5 + [hull_cx, hull_cy]
230
+ face_mask = np.zeros((H_r, W_r), dtype=np.float32)
231
+ cv2.fillPoly(face_mask, [hull_exp.astype(np.int32)], 1.0)
232
+ face_mask = cv2.GaussianBlur(face_mask, (61, 61), 20)
233
+
234
+ # Bake swapped render into UV space
235
+ swapped_t = torch.tensor(swapped_rgb, device=device).unsqueeze(0) # (1,H,W,3)
236
+ mask_t = torch.tensor(face_mask[None], device=device)
237
+
238
+ uv_out = uv_render_attr(
239
+ images=swapped_t,
240
+ masks=mask_t,
241
+ uv_render_geometry_output=uv_geom,
242
+ )
243
+ uv_img = uv_out.uv_attr_proj[0].cpu().numpy() # (uv, uv, 3)
244
+ uv_mask = uv_out.uv_mask_proj[0].cpu().numpy() # (uv, uv)
245
+
246
+ # Kill back-of-head UV islands
247
+ uv_mask = uv_mask * geom_uv_mask
248
+
249
+ # Weighted accumulate
250
+ w = uv_mask * weight
251
+ uv_colour_acc += uv_img * w[..., None]
252
+ uv_weight_acc += w
253
+ print(f' Painted texels: {(uv_mask > 0.05).sum()}')
254
+
255
+ # ── Composite ──────────────────────────────────────────────────────────────
256
+ print('\n[fib] Compositing views ...')
257
+ valid = uv_weight_acc > 0.01
258
+ # resample the original texture to UV resolution so both np.where
+ # branches share the same (uv_size, uv_size, 3) shape
+ fallback_tex = cv2.resize(orig_tex_np, (uv_size, uv_size),
+ interpolation=cv2.INTER_LINEAR)
+ uv_final = np.where(valid[..., None],
+ uv_colour_acc / np.maximum(uv_weight_acc[..., None], 1e-6),
+ fallback_tex)
261
+
262
+ # Resize to texture resolution if needed
263
+ if uv_size != tex_H or uv_size != tex_W:
264
+ uv_final_rs = cv2.resize(uv_final, (tex_W, tex_H), interpolation=cv2.INTER_LINEAR)
265
+ weight_rs = cv2.resize(uv_weight_acc, (tex_W, tex_H), interpolation=cv2.INTER_LINEAR)
266
+ else:
267
+ uv_final_rs = uv_final
268
+ weight_rs = uv_weight_acc
269
+
270
+ # Blend with original texture: use face-swap result where painted, orig elsewhere
271
+ alpha = np.clip(weight_rs, 0, 1)[..., None]
272
+ new_tex = uv_final_rs * alpha + orig_tex_np * (1.0 - alpha)
273
+ print(f' Total painted texels (tex res): {(weight_rs > 0.05).sum()}')
274
+
275
+ if debug_dir:
276
+ Image.fromarray((uv_final_rs * 255).clip(0,255).astype(np.uint8)).save(
277
+ os.path.join(debug_dir, 'fib_uv_composite.png'))
278
+
279
+ # ── Save GLB ──────────────────────────────────────────────────────────────
280
+ new_pil = Image.fromarray((new_tex * 255).clip(0, 255).astype(np.uint8))
281
+ mesh_t.visual = TextureVisuals(uv=uvs, material=PBRMaterial(baseColorTexture=new_pil))
282
+
283
+ if geom_name and isinstance(scene_t, trimesh.Scene):
284
+ scene_t.geometry[geom_name] = mesh_t
285
+ scene_t.export(out_glb)
286
+ else:
287
+ mesh_t.export(out_glb)
288
+
289
+ print(f'[fib] Saved: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
290
+ return out_glb
291
+
292
+
293
+ if __name__ == '__main__':
294
+ ap = argparse.ArgumentParser()
295
+ ap.add_argument('--body', required=True)
296
+ ap.add_argument('--face', required=True)
297
+ ap.add_argument('--out', required=True)
298
+ ap.add_argument('--uv_size', type=int, default=4096)
299
+ ap.add_argument('--debug_dir', default=None)
300
+ args = ap.parse_args()
301
+ face_inswap_bake(args.body, args.face, args.out,
302
+ uv_size=args.uv_size, debug_dir=args.debug_dir)
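The per-view loop and composite step above reduce to a masked weighted average of UV projections, with a fallback where no view painted anything. A minimal numpy sketch of just that composite logic (toy 2×2 texture, no renderer; `composite_views` is an illustrative helper, not part of this commit):

```python
import numpy as np

def composite_views(uv_imgs, uv_masks, weights, fallback):
    """Weighted average of per-view UV projections; fallback where nothing painted."""
    colour_acc = np.zeros_like(uv_imgs[0])
    weight_acc = np.zeros(uv_imgs[0].shape[:2], dtype=np.float32)
    for img, mask, wgt in zip(uv_imgs, uv_masks, weights):
        w = mask * wgt                      # per-texel confidence for this view
        colour_acc += img * w[..., None]
        weight_acc += w
    valid = weight_acc > 0.01
    out = np.where(valid[..., None],
                   colour_acc / np.maximum(weight_acc[..., None], 1e-6),
                   fallback)                # untouched texels keep the original
    return out, weight_acc

# two toy views: view 0 paints white, view 1 paints black
imgs = [np.ones((2, 2, 3)), np.zeros((2, 2, 3))]
masks = [np.array([[1., 0.], [1., 1.]]), np.array([[1., 0.], [0., 1.]])]
out, acc = composite_views(imgs, masks, [1.0, 1.0], fallback=np.full((2, 2, 3), 0.5))
# texel (0,0): seen by both views → average; texel (0,1): unseen → fallback
```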
pipeline/face_project.py ADDED
@@ -0,0 +1,305 @@
+ """
+ face_project.py — Project reference face image onto TripoSG mesh UV texture.
+
+ Keeps geometry 100% intact. Paints the face-region UV triangles using
+ barycentric rasterization — never interpolates across UV island boundaries.
+
+ Usage:
+     python face_project.py --body /tmp/triposg_textured.glb \
+                            --face /tmp/triposg_face_ref.png \
+                            --out /tmp/face_projected.glb \
+                            [--blend 0.9] [--neck_frac 0.84] [--debug_tex /tmp/tex.png]
+ """
+
+ import os, argparse, warnings
+ warnings.filterwarnings('ignore')
+
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import trimesh
+ from trimesh.visual.texture import TextureVisuals
+ from trimesh.visual.material import PBRMaterial
+
+
+ # ── Face alignment ─────────────────────────────────────────────────────────────
+
+ def _aligned_face_bgr(face_img_bgr, target_size=512):
+     """Detect + align face via InsightFace 5-pt warp; falls back to square crop."""
+     try:
+         from insightface.app import FaceAnalysis
+         from insightface.utils import face_align
+         app = FaceAnalysis(providers=['CPUExecutionProvider'])
+         app.prepare(ctx_id=0, det_size=(640, 640))
+         faces = app.get(face_img_bgr)
+         if faces:
+             faces.sort(
+                 key=lambda f: (f.bbox[2]-f.bbox[0]) * (f.bbox[3]-f.bbox[1]),
+                 reverse=True)
+             aligned = face_align.norm_crop(face_img_bgr, faces[0].kps,
+                                            image_size=target_size)
+             print(f'  InsightFace aligned: {aligned.shape}')
+             return aligned
+     except Exception as e:
+         print(f'  InsightFace unavailable ({e}), using centre-crop')
+     h, w = face_img_bgr.shape[:2]
+     side = min(h, w)
+     y0, x0 = (h - side) // 2, (w - side) // 2
+     return cv2.resize(face_img_bgr[y0:y0+side, x0:x0+side], (target_size, target_size))
+
+
+ # ── Triangle rasterizer ────────────────────────────────────────────────────────
+
+ def _rasterize_triangles(face_tri_uvs_px, face_tri_img_xy,
+                          face_img_rgb, tex, blend,
+                          max_uv_span=300):
+     """
+     Paint face_img_rgb colour into tex at UV locations, triangle by triangle.
+
+     face_tri_uvs_px : (M, 3, 2) UV pixel coords of M triangles
+     face_tri_img_xy : (M, 3, 2) projected image coords of M triangles
+     face_img_rgb    : (H, W, 3) reference face image
+     tex             : (texH, texW, 3) float32 texture (modified in-place)
+     blend           : float 0–1
+     max_uv_span     : skip triangles whose UV bounding box exceeds this (UV seams)
+     """
+     H_f, W_f = face_img_rgb.shape[:2]
+     tex_H, tex_W = tex.shape[:2]
+     painted = 0
+
+     for fi in range(len(face_tri_uvs_px)):
+         uv = face_tri_uvs_px[fi]   # (3, 2) in texture pixel coords
+         img = face_tri_img_xy[fi]  # (3, 2) in face-image pixel coords
+
+         # Skip UV-seam triangles (vertices far apart in UV space)
+         if (uv[:, 0].max() - uv[:, 0].min() > max_uv_span or
+                 uv[:, 1].max() - uv[:, 1].min() > max_uv_span):
+             continue
+
+         # Bounding box in texture space
+         u_lo = max(0, int(uv[:, 0].min()))
+         u_hi = min(tex_W, int(uv[:, 0].max()) + 2)
+         v_lo = max(0, int(uv[:, 1].min()))
+         v_hi = min(tex_H, int(uv[:, 1].max()) + 2)
+         if u_hi <= u_lo or v_hi <= v_lo:
+             continue
+
+         # Grid of texel centres in this bounding box
+         gu, gv = np.meshgrid(np.arange(u_lo, u_hi), np.arange(v_lo, v_hi))
+         pts = np.column_stack([gu.ravel().astype(np.float32),
+                                gv.ravel().astype(np.float32)])  # (K, 2)
+
+         # Barycentric coordinates (in UV pixel space)
+         A = uv[0].astype(np.float64)
+         AB = (uv[1] - uv[0]).astype(np.float64)
+         AC = (uv[2] - uv[0]).astype(np.float64)
+         denom = AB[0] * AC[1] - AB[1] * AC[0]
+         if abs(denom) < 0.5:
+             continue
+         P = pts.astype(np.float64) - A
+         b1 = (P[:, 0] * AC[1] - P[:, 1] * AC[0]) / denom
+         b2 = (P[:, 1] * AB[0] - P[:, 0] * AB[1]) / denom
+         b0 = 1.0 - b1 - b2
+
+         inside = (b0 >= 0) & (b1 >= 0) & (b2 >= 0)
+         if not inside.any():
+             continue
+
+         # Interpolate reference-face image coordinates
+         ix_f = (b0[inside] * img[0, 0] +
+                 b1[inside] * img[1, 0] +
+                 b2[inside] * img[2, 0])
+         iy_f = (b0[inside] * img[0, 1] +
+                 b1[inside] * img[1, 1] +
+                 b2[inside] * img[2, 1])
+
+         valid = ((ix_f >= 0) & (ix_f < W_f) & (iy_f >= 0) & (iy_f < H_f))
+         if not valid.any():
+             continue
+
+         ix = np.clip(ix_f[valid].astype(int), 0, W_f - 1)
+         iy = np.clip(iy_f[valid].astype(int), 0, H_f - 1)
+         colours = face_img_rgb[iy, ix].astype(np.float32)  # (P, 3)
+
+         tu = pts[inside][valid, 0].astype(int)
+         tv = pts[inside][valid, 1].astype(int)
+         in_tex = (tu >= 0) & (tu < tex_W) & (tv >= 0) & (tv < tex_H)
+
+         tex[tv[in_tex], tu[in_tex]] = (
+             blend * colours[in_tex] +
+             (1.0 - blend) * tex[tv[in_tex], tu[in_tex]]
+         )
+         painted += int(in_tex.sum())
+
+     return painted
+
+
+ # ── Main ───────────────────────────────────────────────────────────────────────
+
+ def project_face(body_glb, face_img_path, out_glb,
+                  blend=0.90, neck_frac=0.84, debug_tex=None):
+     """
+     Project reference face onto TripoSG UV texture via per-triangle rasterization.
+     """
+
+     # ── Load mesh ─────────────────────────────────────────────────────────────
+     print(f'[face_project] Loading {body_glb}')
+     scene = trimesh.load(body_glb)
+     if isinstance(scene, trimesh.Scene):
+         geom_name = list(scene.geometry.keys())[0]
+         mesh = scene.geometry[geom_name]
+     else:
+         mesh = scene
+         geom_name = None
+
+     verts = np.array(mesh.vertices, dtype=np.float64)            # (N, 3)
+     faces = np.array(mesh.faces, dtype=np.int32)                 # (F, 3)
+     uvs = np.array(mesh.visual.uv, dtype=np.float64)             # (N, 2)
+     mat = mesh.visual.material
+     orig_tex = np.array(mat.baseColorTexture, dtype=np.float32)  # (H, W, 3) RGB
+     tex_H, tex_W = orig_tex.shape[:2]
+     print(f'  {len(verts)} verts | {len(faces)} faces | texture {orig_tex.shape}')
+
+     # ── Identify face-region vertices ─────────────────────────────────────────
+     y_min, y_max = verts[:, 1].min(), verts[:, 1].max()
+     neck_y = float(y_min + (y_max - y_min) * neck_frac)
+
+     head_mask = verts[:, 1] > neck_y
+     head_idx = np.where(head_mask)[0]
+     hv = verts[head_idx]
+
+     # Front half only (z >= median — face faces +Z)
+     z_med = float(np.median(hv[:, 2]))
+     front = hv[:, 2] >= z_med
+     if front.sum() < 30:
+         front = np.ones(len(hv), bool)
+
+     face_vert_idx = head_idx[front]  # indices into the full vertex array
+
+     # Build boolean mask for fast triangle selection
+     face_vert_mask = np.zeros(len(verts), bool)
+     face_vert_mask[face_vert_idx] = True
+
+     # Select triangles where ALL 3 vertices are in the face region
+     face_tri_mask = face_vert_mask[faces].all(axis=1)
+     face_tris = faces[face_tri_mask]  # (M, 3)
+     print(f'  neck_y={neck_y:.4f} | head={len(head_idx)} '
+           f'| face-front={front.sum()} | face triangles={len(face_tris)}')
+
+     # ── Load and align reference face ─────────────────────────────────────────
+     print(f'[face_project] Reference face: {face_img_path}')
+     raw_bgr = cv2.imread(face_img_path)
+     aligned_bgr = _aligned_face_bgr(raw_bgr, target_size=512)
+     aligned_rgb = cv2.cvtColor(aligned_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
+     H_f, W_f = aligned_rgb.shape[:2]
+
+     # ── Compute face projection axes from actual face normal ─────────────────
+     fv = verts[face_vert_idx]
+
+     # Average normal of the front-facing face triangles defines projection dir
+     face_tri_normals = np.array(mesh.face_normals)[face_tri_mask]
+     face_fwd = face_tri_normals.mean(axis=0)
+     face_fwd /= np.linalg.norm(face_fwd)
+
+     # Build orthonormal right/up axes in the face plane
+     world_up = np.array([0., 1., 0.])
+     face_right = np.cross(face_fwd, world_up)
+     face_right /= np.linalg.norm(face_right)
+     face_up = np.cross(face_right, face_fwd)
+     face_up /= np.linalg.norm(face_up)
+     print(f'  Face normal: {face_fwd.round(3)}')
+
+     # Project face vertices onto local (right, up) plane
+     fv_centroid = fv.mean(axis=0)
+     fv_c = fv - fv_centroid
+     lx = fv_c @ face_right
+     ly = fv_c @ face_up
+     x_span = float(lx.max() - lx.min())
+     y_span = float(ly.max() - ly.min())
+
+     # InsightFace norm_crop places eyes at ~37% from top of the 512px image.
+     # In 3D the eyes are ~78% up from neck → 28% above centroid.
+     # Shift the vertical origin up by 0.112*y_span so eye level → 37% in image.
+     cy_shift = 0.112 * y_span
+     pad = 0.10  # tighter crop so face features fill more of the image
+
+     def vert_to_img(v):
+         """Project 3D vertex to reference-face image using the face normal."""
+         c = v - fv_centroid  # (N, 3)
+         lx = c @ face_right
+         ly = c @ face_up
+         pu = lx / (x_span * (1 + 2*pad)) + 0.5
+         pv = -(ly - cy_shift) / (y_span * (1 + 2*pad)) + 0.5
+         return np.column_stack([pu * W_f, pv * H_f])  # (N, 2)
+
+     def vert_to_uv_px(v_idx):
+         """Convert vertex UV coords to texture pixel coordinates."""
+         uv = uvs[v_idx]
+         # trimesh loads GLB UV with (0,0)=bottom-left; flip V for image row
+         col = uv[:, 0] * tex_W
+         row = (1.0 - uv[:, 1]) * tex_H
+         return np.column_stack([col, row])  # (N, 2)
+
+     # Pre-compute image + UV pixel coords for every vertex
+     all_img_px = vert_to_img(verts)                   # (N, 2)
+     all_uv_px = vert_to_uv_px(np.arange(len(verts)))  # (N, 2)
+
+     # Gather per-triangle arrays
+     face_tri_uvs_px = all_uv_px[face_tris]   # (M, 3, 2)
+     face_tri_img_xy = all_img_px[face_tris]  # (M, 3, 2)
+
+     print(f'  UV pixel range: u={face_tri_uvs_px[:,:,0].min():.0f}→'
+           f'{face_tri_uvs_px[:,:,0].max():.0f} '
+           f'v={face_tri_uvs_px[:,:,1].min():.0f}→'
+           f'{face_tri_uvs_px[:,:,1].max():.0f}')
+     print(f'  Image coord range: x={face_tri_img_xy[:,:,0].min():.1f}→'
+           f'{face_tri_img_xy[:,:,0].max():.1f} '
+           f'y={face_tri_img_xy[:,:,1].min():.1f}→'
+           f'{face_tri_img_xy[:,:,1].max():.1f}')
+
+     # ── Rasterize face triangles into UV texture ──────────────────────────────
+     print(f'[face_project] Rasterizing {len(face_tris)} triangles into texture...')
+     new_tex = orig_tex.copy()
+     painted = _rasterize_triangles(
+         face_tri_uvs_px, face_tri_img_xy,
+         aligned_rgb, new_tex, blend,
+         max_uv_span=300
+     )
+     print(f'  Painted {painted} texels across {len(face_tris)} triangles')
+
+     # ── Save debug texture if requested ──────────────────────────────────────
+     if debug_tex:
+         dbg = np.clip(new_tex, 0, 255).astype(np.uint8)
+         Image.fromarray(dbg).save(debug_tex)
+         print(f'  Debug texture: {debug_tex}')
+
+     # ── Write modified texture back to mesh ───────────────────────────────────
+     new_pil = Image.fromarray(np.clip(new_tex, 0, 255).astype(np.uint8))
+     new_mat = PBRMaterial(baseColorTexture=new_pil)
+     mesh.visual = TextureVisuals(uv=uvs, material=new_mat)
+
+     os.makedirs(os.path.dirname(os.path.abspath(out_glb)), exist_ok=True)
+     if geom_name and isinstance(scene, trimesh.Scene):
+         scene.geometry[geom_name] = mesh
+         scene.export(out_glb)
+     else:
+         mesh.export(out_glb)
+
+     print(f'[face_project] Saved: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
+     return out_glb
+
+
+ # ── CLI ────────────────────────────────────────────────────────────────────────
+
+ if __name__ == '__main__':
+     ap = argparse.ArgumentParser()
+     ap.add_argument('--body', required=True)
+     ap.add_argument('--face', required=True)
+     ap.add_argument('--out', required=True)
+     ap.add_argument('--blend', type=float, default=0.90)
+     ap.add_argument('--neck_frac', type=float, default=0.84)
+     ap.add_argument('--debug_tex', default=None)
+     args = ap.parse_args()
+     project_face(args.body, args.face, args.out,
+                  blend=args.blend, neck_frac=args.neck_frac,
+                  debug_tex=args.debug_tex)
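The rasterizer in face_project.py decides texel membership with edge-vector barycentric coordinates. A standalone sketch of that inside-test, using the same `denom`/`b0`/`b1`/`b2` math as `_rasterize_triangles` (the `barycentric` helper name and toy triangle are illustrative, not part of this commit):

```python
import numpy as np

def barycentric(tri, pts):
    """Barycentric coords of 2-D points w.r.t. triangle tri (3, 2),
    same edge-vector formulation as _rasterize_triangles."""
    A = tri[0].astype(np.float64)
    AB = (tri[1] - tri[0]).astype(np.float64)
    AC = (tri[2] - tri[0]).astype(np.float64)
    denom = AB[0] * AC[1] - AB[1] * AC[0]   # 2x signed triangle area
    P = pts.astype(np.float64) - A
    b1 = (P[:, 0] * AC[1] - P[:, 1] * AC[0]) / denom
    b2 = (P[:, 1] * AB[0] - P[:, 0] * AB[1]) / denom
    b0 = 1.0 - b1 - b2
    return np.column_stack([b0, b1, b2])

# right triangle with legs of length 4
tri = np.array([[0., 0.], [4., 0.], [0., 4.]])
pts = np.array([[1., 1.], [5., 5.]])
b = barycentric(tri, pts)
inside = (b >= 0).all(axis=1)   # all three weights non-negative → point in triangle
```

The same weights also drive the sampling step: interpolating the three image-space corner coordinates with `b0, b1, b2` gives the source pixel for each texel.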
pipeline/face_swap_render.py ADDED
@@ -0,0 +1,293 @@
+ """
+ face_swap_render.py — Paint reference face onto TripoSG UV texture using
+ MV-Adapter's UV-baking pipeline.
+
+ Pipeline:
+   1. Load mesh with same params as triposg_app.py render stage
+   2. Create orthographic camera matching render_front.png (azimuth=-90)
+   3. Detect face landmarks in render_front.png + reference photo via InsightFace
+   4. norm_crop reference → canonical 512×512 frontal face
+   5. Estimate 4-DOF similarity (canonical → render) and warpAffine
+      → produces face_on_render.png: reference face at correct render-space coords
+   6. uv_render_attr(images=face_on_render) → projects render image into UV space
+      No inverse transform, no scale mismatch — the render-space coordinate system
+      is shared between the camera projection and the UV lookup.
+   7. Blend projected face into original texture with geometry mask guard.
+   8. Save updated GLB
+
+ Usage:
+     python face_swap_render.py \
+         --body /tmp/triposg_textured.glb \
+         --face /tmp/triposg_face_ref.png \
+         --render /tmp/render_front.png \
+         --out /tmp/face_swapped.glb \
+         [--blend 0.93] [--uv_size 4096] [--debug_dir /tmp]
+ """
+
+ import os, sys, argparse, warnings
+ warnings.filterwarnings('ignore')
+
+ import numpy as np
+ import cv2
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ import trimesh
+ from trimesh.visual.texture import TextureVisuals
+ from trimesh.visual.material import PBRMaterial
+ from insightface.utils import face_align as insightface_align
+
+ sys.path.insert(0, '/root/MV-Adapter')
+ from mvadapter.utils.mesh_utils import (
+     NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera,
+ )
+ from mvadapter.utils.mesh_utils.uv import (
+     uv_precompute, uv_render_geometry, uv_render_attr,
+ )
+ from insightface.app import FaceAnalysis
+
+
+ def _detect_largest_face(img_bgr, app):
+     faces = app.get(img_bgr)
+     if not faces:
+         return None
+     faces.sort(key=lambda f: (f.bbox[2]-f.bbox[0])*(f.bbox[3]-f.bbox[1]), reverse=True)
+     return faces[0]
+
+
+ def _build_front_face_uv_mask(mesh_t, tex_H, tex_W, neck_frac=0.84):
+     """
+     Build a UV-space mask covering only the front-facing face triangles.
+     Excludes back-of-head, hair, and ears (lateral vertices).
+     """
+     verts = np.array(mesh_t.vertices, dtype=np.float64)
+     faces = np.array(mesh_t.faces, dtype=np.int32)
+     uvs = np.array(mesh_t.visual.uv, dtype=np.float64)
+
+     # Head vertices above neck
+     y_min, y_max = verts[:, 1].min(), verts[:, 1].max()
+     neck_y = float(y_min + (y_max - y_min) * neck_frac)
+     head_idx = np.where(verts[:, 1] > neck_y)[0]
+     hv = verts[head_idx]
+
+     # Front half: z >= 40th percentile — generous to include jaw/cheek toward ears
+     # No lateral exclusion — it splits UV islands through the eyes/mouth → duplicates
+     z_thresh = float(np.percentile(hv[:, 2], 40))
+     front = hv[:, 2] >= z_thresh
+     if front.sum() < 30:
+         front = np.ones(len(hv), bool)
+
+     face_vert_idx = head_idx[front]
+     face_vert_mask = np.zeros(len(verts), bool)
+     face_vert_mask[face_vert_idx] = True
+
+     face_tri_mask = face_vert_mask[faces].all(axis=1)
+     face_tris = faces[face_tri_mask]
+     print(f'  Geometry mask: {face_tri_mask.sum()} front-face triangles selected '
+           f'(neck_y={neck_y:.3f}, z_thresh={z_thresh:.3f})')
+
+     # Rasterize into UV-space mask (trimesh UV: y=0 is bottom-left → flip V)
+     geom_mask = np.zeros((tex_H, tex_W), dtype=np.float32)
+     pts_list = []
+     for tri in face_tris:
+         uv = uvs[tri]  # (3, 2)
+         px = uv[:, 0] * tex_W
+         py = (1.0 - uv[:, 1]) * tex_H
+         pts_list.append(np.column_stack([px, py]).astype(np.int32))
+     if pts_list:
+         cv2.fillPoly(geom_mask, pts_list, 1.0)
+
+     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
+     geom_mask = cv2.dilate(geom_mask, kernel, iterations=2)  # close intra-tri gaps
+     geom_mask = cv2.erode(geom_mask, kernel, iterations=1)   # retreat from island edges
+     geom_mask = cv2.GaussianBlur(geom_mask, (31, 31), 8)     # soft transition
+     return geom_mask
+
+
+ def face_swap_render(body_glb, face_img_path, render_img_path, out_glb,
+                      blend=0.93, uv_size=4096, neck_frac=0.76, debug_dir=None):
+
+     device = 'cuda'
+
+     # ── Step 1: Load mesh ─────────────────────────────────────────────────────
+     print(f'[fsr] Loading mesh: {body_glb}')
+     ctx = NVDiffRastContextWrapper(device=device, context_type='cuda')
+     mesh_mv = load_mesh(body_glb, rescale=True, device=device)
+
+     scene_t = trimesh.load(body_glb)
+     if isinstance(scene_t, trimesh.Scene):
+         geom_name = list(scene_t.geometry.keys())[0]
+         mesh_t = scene_t.geometry[geom_name]
+     else:
+         mesh_t = scene_t; geom_name = None
+
+     orig_tex = np.array(mesh_t.visual.material.baseColorTexture, dtype=np.float32) / 255.0
+     uvs = np.array(mesh_t.visual.uv, dtype=np.float64)
+     tex_H, tex_W = orig_tex.shape[:2]
+     print(f'  UV size: {tex_W}×{tex_H}')
+
+     # ── Step 1b: Geometry mask (front-face UV islands only) ───────────────────
+     print('[fsr] Building geometry front-face UV mask ...')
+     geom_uv_mask = _build_front_face_uv_mask(mesh_t, tex_H, tex_W, neck_frac)
+
+     # ── Step 2: Orthographic camera matching render_front.png ─────────────────
+     render_img = cv2.imread(render_img_path)
+     H_r, W_r = render_img.shape[:2]
+     print(f'  Render size: {W_r}×{H_r}')
+
+     camera = get_orthogonal_camera(
+         elevation_deg=[0], distance=[1.8],
+         left=-0.55, right=0.55, bottom=-0.55, top=0.55,
+         azimuth_deg=[-90], device=device,
+     )
+
+     print(f'[fsr] Precomputing UV geometry ({uv_size}×{uv_size}) ...')
+     uv_pre = uv_precompute(ctx, mesh_mv, height=uv_size, width=uv_size)
+     uv_geom = uv_render_geometry(
+         ctx, mesh_mv, camera,
+         view_height=H_r, view_width=W_r,
+         uv_precompute_output=uv_pre,
+         compute_depth_grad=False,
+     )
+
+     # ── Step 3: Face landmark detection ───────────────────────────────────────
+     print('[fsr] Detecting face landmarks ...')
+     app = FaceAnalysis(providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+     app.prepare(ctx_id=0, det_size=(640, 640))
+
+     ref_bgr = cv2.imread(face_img_path)
+     render_face = _detect_largest_face(render_img, app)
+     if render_face is None:
+         raise RuntimeError(f'No face detected in render: {render_img_path}')
+     ref_face = _detect_largest_face(ref_bgr, app)
+     if ref_face is None:
+         raise RuntimeError(f'No face detected in reference: {face_img_path}')
+
+     render_kps = render_face.kps  # (5, 2)
+     ref_kps = ref_face.kps
+     print(f'  render kps: x={render_kps[:,0].min():.0f}-{render_kps[:,0].max():.0f}'
+           f' y={render_kps[:,1].min():.0f}-{render_kps[:,1].max():.0f}')
+
+     # ── Step 4: norm_crop → canonical 512×512 frontal face ───────────────────
+     CANONICAL_SIZE = 512
+     aligned_bgr = insightface_align.norm_crop(ref_bgr, ref_kps, image_size=CANONICAL_SIZE)
+
+     # Fixed ARCFACE 5-point positions scaled to CANONICAL_SIZE
+     ARCFACE_112 = np.array([
+         [38.2946, 51.6963],
+         [73.5318, 51.5014],
+         [56.0252, 71.7366],
+         [41.5493, 92.3655],
+         [70.7299, 92.2041],
+     ], dtype=np.float32)
+     canonical_kps = ARCFACE_112 * (CANONICAL_SIZE / 112.0)
+
+     # ── Step 5: Forward warp: canonical → render space ────────────────────────
+     # 4-DOF similarity (scale + rotation + translation) with all 5 kps.
+     # FORWARD direction: canonical_kps → render_kps so that warpAffine places
+     # the face at exactly the render-space coordinates, downsampling cleanly.
+     fwd_M, inliers = cv2.estimateAffinePartial2D(
+         canonical_kps.astype(np.float32),
+         render_kps.astype(np.float32),
+         method=cv2.LMEDS,
+     )
+     print(f'  Forward warp M:\n{fwd_M}')
+
+     face_on_render_bgr = cv2.warpAffine(
+         aligned_bgr, fwd_M, (W_r, H_r),
+         flags=cv2.INTER_LANCZOS4,
+         borderMode=cv2.BORDER_CONSTANT, borderValue=0,
+     )
+     face_on_render_rgb = cv2.cvtColor(face_on_render_bgr,
+                                       cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
+
+     # ── Step 6: Render-space face hull mask ───────────────────────────────────
+     # Only paint UV texels that correspond to pixels inside the face region.
+     hull_pts = cv2.convexHull(render_kps.astype(np.float32)).squeeze(1)
+     hull_cx, hull_cy = hull_pts.mean(axis=0)
+     hull_expanded = (hull_pts - [hull_cx, hull_cy]) * 4.0 + [hull_cx, hull_cy]
+     face_mask_render = np.zeros((H_r, W_r), dtype=np.float32)
+     cv2.fillPoly(face_mask_render, [hull_expanded.astype(np.int32)], 1.0)
+     # Restrict to where the warped face actually has content
+     face_content = (face_on_render_bgr.mean(axis=2) > 3.0 / 255.0).astype(np.float32)
+     face_mask_render = face_mask_render * face_content
+     face_mask_render = cv2.GaussianBlur(face_mask_render, (51, 51), 15)
+
+     # ── Step 7: Project face-on-render into UV space ──────────────────────────
+     # uv_render_attr uses uv_pos_ndc as a lookup: for each UV texel, sample the
+     # render-space image at that texel's render NDC position.
+     # Since face_on_render is already in render-space coords, this is exact.
+     print('[fsr] Projecting face into UV space via uv_render_attr ...')
+     face_t = torch.tensor(face_on_render_rgb, device=device).unsqueeze(0)  # (1, H, W, 3)
+     mask_t = torch.tensor(face_mask_render[None], device=device)
+
+     uv_attr_out = uv_render_attr(
+         images=face_t,
+         masks=mask_t,
+         uv_render_geometry_output=uv_geom,
+     )
+     uv_face_img = uv_attr_out.uv_attr_proj[0].cpu().numpy()   # (uv, uv, 3)
+     uv_face_mask = uv_attr_out.uv_mask_proj[0].cpu().numpy()  # (uv, uv)
+
+     # Rescale to tex resolution if needed
+     if uv_size != tex_H or uv_size != tex_W:
+         uv_face_img_rs = cv2.resize(uv_face_img, (tex_W, tex_H), interpolation=cv2.INTER_LINEAR)
+         uv_face_mask_rs = cv2.resize(uv_face_mask, (tex_W, tex_H), interpolation=cv2.INTER_LINEAR)
+     else:
+         uv_face_img_rs = uv_face_img
+         uv_face_mask_rs = uv_face_mask
+
+     # ── Step 7b: Apply geometry mask — kill back-of-head / ear UV islands ────
+     uv_face_mask_rs = uv_face_mask_rs * geom_uv_mask
+
+     # Final blend alpha — use full blend=1.0 inside the face region so no
+     # original texture leaks through and creates duplicate features
+     alpha = np.clip(uv_face_mask_rs, 0, 1)[..., None]
+     painted_px = int((alpha[..., 0] > 0.01).sum())
+     print(f'  Painted texels: {painted_px}')
+
+     if debug_dir:
+         cv2.imwrite(os.path.join(debug_dir, 'fsr_aligned_ref.png'), aligned_bgr)
+         cv2.imwrite(os.path.join(debug_dir, 'fsr_face_on_render.png'), face_on_render_bgr)
+         cv2.imwrite(os.path.join(debug_dir, 'fsr_face_mask_render.png'),
+                     (face_mask_render * 255).astype(np.uint8))
+         cv2.imwrite(os.path.join(debug_dir, 'fsr_geom_mask.png'),
+                     (geom_uv_mask * 255).astype(np.uint8))
+         cv2.imwrite(os.path.join(debug_dir, 'fsr_uv_mask.png'),
+                     (uv_face_mask_rs * 255).astype(np.uint8))
+         Image.fromarray((uv_face_img_rs * 255).clip(0, 255).astype(np.uint8)).save(
+             os.path.join(debug_dir, 'fsr_uv_face.png'))
+         print(f'  Debug files saved to {debug_dir}')
+
+     # ── Step 8: Blend into original texture ───────────────────────────────────
+     print(f'[fsr] Blending (blend={blend}) ...')
+     new_tex = uv_face_img_rs * alpha + orig_tex * (1.0 - alpha)
+
+     # ── Step 9: Save GLB ──────────────────────────────────────────────────────
+     new_pil = Image.fromarray((new_tex * 255).clip(0, 255).astype(np.uint8))
+     mesh_t.visual = TextureVisuals(uv=uvs, material=PBRMaterial(baseColorTexture=new_pil))
+
+     if geom_name and isinstance(scene_t, trimesh.Scene):
+         scene_t.geometry[geom_name] = mesh_t
+         scene_t.export(out_glb)
+     else:
+         mesh_t.export(out_glb)
+
+     print(f'[fsr] Saved: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
+     return out_glb
+
+
+ if __name__ == '__main__':
+     ap = argparse.ArgumentParser()
+     ap.add_argument('--body', required=True)
+     ap.add_argument('--face', required=True)
+     ap.add_argument('--render', required=True, help='Front render (e.g. render_front.png)')
+     ap.add_argument('--out', required=True)
+     ap.add_argument('--blend', type=float, default=0.93)
+     ap.add_argument('--uv_size', type=int, default=4096)
+     ap.add_argument('--neck_frac', type=float, default=0.76)
+     ap.add_argument('--debug_dir', default=None)
+     args = ap.parse_args()
+     face_swap_render(args.body, args.face, args.render, args.out,
+                      blend=args.blend, uv_size=args.uv_size,
+                      neck_frac=args.neck_frac, debug_dir=args.debug_dir)
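Step 5 above fits a 4-DOF similarity (uniform scale, rotation, translation) between two 5-point landmark sets via `cv2.estimateAffinePartial2D`. The same fit can be sketched in pure numpy using the complex-number least-squares form (the `similarity_2d` helper below is an illustrative stand-in, without OpenCV's robust LMEDS loop):

```python
import numpy as np

def similarity_2d(src, dst):
    """Least-squares 2-D similarity mapping src → dst as a 2×3 affine,
    analogous in shape to cv2.estimateAffinePartial2D's output."""
    src_mu, dst_mu = src.mean(axis=0), dst.mean(axis=0)
    # Treat (x, y) as x + iy; a similarity is multiplication by one complex scalar
    zs = (src - src_mu) @ np.array([1, 1j])
    zd = (dst - dst_mu) @ np.array([1, 1j])
    s = (np.conj(zs) @ zd) / (np.conj(zs) @ zs)   # scale · e^{iθ}
    a, b = s.real, s.imag
    M = np.array([[a, -b], [b, a]])               # rotation+scale block
    t = dst_mu - M @ src_mu                       # translation column
    return np.hstack([M, t[:, None]])             # 2×3, rows [a -b tx; b a ty]

# toy correspondence: pure scale 2 plus translation (3, -1)
src = np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.]])
dst = src * 2.0 + np.array([3.0, -1.0])
A = similarity_2d(src, dst)
```

With noise-free inputs the recovered matrix is exact; with real landmarks it is the least-squares fit over all five points, which is why the pipeline uses all kps rather than a minimal pair.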
pipeline/face_transplant.py ADDED
@@ -0,0 +1,667 @@
1
+ """
2
+ face_transplant.py
3
+ ==================
4
+ Replace the face/head region of a rigged UniRig GLB with a higher-detail
5
+ PSHuman mesh, while preserving the skeleton, rig, and skinning weights.
6
+
7
+ Algorithm
8
+ ---------
9
+ 1. Parse rigged GLB → vertices, faces, UVs, JOINTS_0, WEIGHTS_0, bone list
10
+ 2. Identify head vertices → any vert whose dominant bone is in HEAD_BONES
11
+ 3. Load PSHuman mesh (OBJ or GLB, no rig)
12
+ 4. Align PSHuman head to UniRig head bounding box (scale + translate)
13
+ 5. Transfer skinning weights to PSHuman verts via K-nearest-neighbour from
14
+ UniRig head verts (scipy KDTree, weighted average)
15
+ 6. Retract UniRig face verts slightly inward so PSHuman sits on top cleanly
16
+ 7. Rebuild the GLB with two mesh primitives:
17
+ - Primitive 0 : UniRig body (face verts retracted)
18
+ - Primitive 1 : PSHuman face (new, with transferred weights)
19
+ 8. Write output GLB
20
+
21
+ Usage
22
+ -----
23
+ python -m pipeline.face_transplant \\
24
+ --body rigged_body.glb \\
25
+ --face pshuman_output.obj \\
26
+ --output rigged_body_with_pshuman_face.glb
27
+
28
+ Optionally supply --head-bones as comma-separated bone-name substrings
29
+ (default: head,Head,skull). Any bone whose name contains one of these
30
+ substrings is treated as a head bone.
31
+
32
+ Requires: pygltflib numpy scipy trimesh (pip install each)
33
+ """
34
+ from __future__ import annotations
35
+
36
+ import argparse
37
+ import base64
38
+ import struct
39
+ import json
40
+ from pathlib import Path
41
+ from typing import Dict, List, Optional, Tuple
42
+ import numpy as np
43
+ from scipy.spatial import KDTree
44
+ import trimesh
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # GLB low-level helpers (subset of Retarget/io/gltf_io.py re-used here)
49
+ # ---------------------------------------------------------------------------
50
+
51
+ try:
52
+ import pygltflib
53
+ except ImportError:
54
+ raise ImportError("pip install pygltflib")
55
+
56
+
+ def _read_accessor_raw(gltf: pygltflib.GLTF2, accessor_idx: int) -> np.ndarray:
+     acc = gltf.accessors[accessor_idx]
+     bv = gltf.bufferViews[acc.bufferView]
+     buf = gltf.buffers[bv.buffer]
+
+     if buf.uri and buf.uri.startswith("data:"):
+         _, b64 = buf.uri.split(",", 1)
+         raw = base64.b64decode(b64)
+     elif buf.uri:
+         base_dir = Path(gltf._path).parent if getattr(gltf, "_path", None) else Path(".")
+         raw = (base_dir / buf.uri).read_bytes()
+     else:
+         raw = bytes(gltf.binary_blob())
+
+     type_nc = {"SCALAR": 1, "VEC2": 2, "VEC3": 3, "VEC4": 4, "MAT4": 16}
+     fmt_map = {5120: "b", 5121: "B", 5122: "h", 5123: "H", 5125: "I", 5126: "f"}
+
+     n_comp = type_nc[acc.type]
+     fmt = fmt_map[acc.componentType]
+     item_sz = struct.calcsize(fmt) * n_comp
+     stride = bv.byteStride or item_sz
+     start = bv.byteOffset + (acc.byteOffset or 0)
+
+     items = []
+     for i in range(acc.count):
+         offset = start + i * stride
+         vals = struct.unpack_from(f"{n_comp}{fmt}", raw, offset)
+         items.append(vals)
+
+     arr = np.array(items)
+     if arr.ndim == 2 and arr.shape[1] == 1:
+         arr = arr[:, 0]
+     return arr
+
+
+ def _accessor_dtype(gltf: pygltflib.GLTF2, accessor_idx: int):
+     fmt_map = {5120: np.int8, 5121: np.uint8, 5122: np.int16, 5123: np.uint16,
+                5125: np.uint32, 5126: np.float32}
+     return fmt_map[gltf.accessors[accessor_idx].componentType]
+
+
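The accessor reader above steps through the buffer one element at a time with `struct.unpack_from`, using the buffer view's `byteStride` to handle interleaved vertex data. A minimal self-contained sketch of that pattern, with a made-up interleaved layout (3-float position followed by 2-float UV, i.e. a 20-byte stride):

```python
import struct

import numpy as np

# Hypothetical interleaved buffer: per vertex [position VEC3 float32][uv VEC2 float32].
# Positions start at byte offset 0, UVs at byte offset 12, stride is 20 bytes.
verts = [(0.0, 1.0, 2.0), (3.0, 4.0, 5.0)]
uvs = [(0.1, 0.2), (0.3, 0.4)]
raw = b"".join(struct.pack("3f2f", *p, *t) for p, t in zip(verts, uvs))

def read_interleaved(raw, count, n_comp, fmt, stride, start):
    # Mirrors _read_accessor_raw: one struct.unpack_from per element, stepping by stride.
    return np.array([struct.unpack_from(f"{n_comp}{fmt}", raw, start + i * stride)
                     for i in range(count)])

positions = read_interleaved(raw, count=2, n_comp=3, fmt="f", stride=20, start=0)
texcoords = read_interleaved(raw, count=2, n_comp=2, fmt="f", stride=20, start=12)
```

With a tightly packed (non-interleaved) buffer, the stride simply equals the element size, which is why `_read_accessor_raw` falls back to `item_sz` when `byteStride` is unset.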
+ # ---------------------------------------------------------------------------
+ # Mesh extraction
+ # ---------------------------------------------------------------------------
+
+ class GLBMesh:
+     """
+     All data from the first skin's first mesh primitive in a GLB.
+     """
+     def __init__(self, path: str):
+         self.path = path
+         gltf = pygltflib.GLTF2().load(path)
+         gltf._path = path
+         self.gltf = gltf
+
+         if not gltf.skins:
+             raise ValueError("No skin found in GLB — is this a rigged file?")
+         self.skin = gltf.skins[0]
+         self.joint_names: List[str] = [gltf.nodes[j].name or f"joint_{k}"
+                                        for k, j in enumerate(self.skin.joints)]
+
+         # Find a mesh node that uses this skin
+         self.mesh_prim, self.mesh_node_idx = self._find_skinned_prim()
+
+         attrs = self.mesh_prim.attributes
+         self.verts = _read_accessor_raw(gltf, attrs.POSITION).astype(np.float32)
+         self.normals = (_read_accessor_raw(gltf, attrs.NORMAL).astype(np.float32)
+                         if attrs.NORMAL is not None else None)
+         self.uvs = (_read_accessor_raw(gltf, attrs.TEXCOORD_0).astype(np.float32)
+                     if attrs.TEXCOORD_0 is not None else None)
+         self.faces = _read_accessor_raw(gltf, self.mesh_prim.indices).astype(np.int32).reshape(-1, 3)
+
+         # Skinning — may be JOINTS_0 / WEIGHTS_0 (uint8/uint16 + float)
+         self.joints4 = None
+         self.weights4 = None
+         if attrs.JOINTS_0 is not None:
+             self.joints4 = _read_accessor_raw(gltf, attrs.JOINTS_0).astype(np.int32)
+             self.weights4 = _read_accessor_raw(gltf, attrs.WEIGHTS_0).astype(np.float32)
+
+         # Material index (carry over to output)
+         self.material_idx = self.mesh_prim.material
+
+     def _find_skinned_prim(self):
+         # Find a mesh node that references this skin
+         for ni, node in enumerate(self.gltf.nodes):
+             if node.skin == 0 and node.mesh is not None:
+                 mesh = self.gltf.meshes[node.mesh]
+                 return mesh.primitives[0], ni
+         # Fallback: first mesh node
+         for ni, node in enumerate(self.gltf.nodes):
+             if node.mesh is not None:
+                 mesh = self.gltf.meshes[node.mesh]
+                 return mesh.primitives[0], ni
+         raise ValueError("No mesh primitive found")
+
+     def head_bone_indices(self, substrings=("head", "Head", "skull", "Skull", "neck", "Neck")) -> List[int]:
+         """Return joint indices (into self.joint_names) matching any substring.
+
+         Falls back to a positional heuristic (highest-Y dominant bone) when no
+         bone names match (e.g. generic bone_0/bone_1 naming from UniRig)."""
+         result = []
+         for i, name in enumerate(self.joint_names):
+             if any(s in name for s in substrings):
+                 result.append(i)
+         if not result and self.joints4 is not None and self.weights4 is not None:
+             # Positional fallback: pick the bone whose dominant vertices have the highest avg Y.
+             n_bones = len(self.joint_names)
+             bone_y_sum = np.zeros(n_bones)
+             bone_y_cnt = np.zeros(n_bones, dtype=np.int32)
+             for vi in range(len(self.verts)):
+                 dom = int(self.joints4[vi, np.argmax(self.weights4[vi])])
+                 bone_y_sum[dom] += self.verts[vi, 1]
+                 bone_y_cnt[dom] += 1
+             with np.errstate(invalid='ignore'):
+                 bone_y_avg = np.where(bone_y_cnt > 0, bone_y_sum / bone_y_cnt, -np.inf)
+             top = int(np.argmax(bone_y_avg))
+             print(f"[face_transplant] No named head bones; positional fallback: "
+                   f"bone {top} ({self.joint_names[top]}, avg_y={bone_y_avg[top]:.3f})")
+             result = [top]
+         return result
+
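The positional fallback in `head_bone_indices` can be exercised on toy data: assign each vertex to its dominant bone (the influence with the largest weight), average vertex Y per bone, and pick the highest bone. A small sketch with hypothetical made-up joints and weights, using `np.bincount` in place of the per-vertex loop:

```python
import numpy as np

# Toy rig: 5 vertices, 3 bones; joints4/weights4 follow the glTF 4-influence layout.
verts = np.array([[0, 0.1, 0], [0, 0.2, 0], [0, 1.8, 0], [0, 1.9, 0], [0, 0.9, 0]],
                 dtype=np.float32)
joints4 = np.array([[0, 1, 0, 0], [0, 1, 0, 0], [2, 1, 0, 0], [2, 0, 0, 0], [1, 0, 0, 0]])
weights4 = np.array([[0.9, 0.1, 0, 0], [0.8, 0.2, 0, 0], [0.7, 0.3, 0, 0],
                     [1.0, 0, 0, 0], [0.6, 0.4, 0, 0]], dtype=np.float32)

n_bones = 3
# Dominant bone per vertex = the joint of the largest influence weight.
dom = joints4[np.arange(len(verts)), np.argmax(weights4, axis=1)]
bone_y_sum = np.bincount(dom, weights=verts[:, 1], minlength=n_bones)
bone_y_cnt = np.bincount(dom, minlength=n_bones)
bone_y_avg = np.where(bone_y_cnt > 0, bone_y_sum / np.maximum(bone_y_cnt, 1), -np.inf)
head_guess = int(np.argmax(bone_y_avg))  # bone 2: its vertices sit highest (avg y = 1.85)
```

Here bone 2 dominates the two highest vertices, so it wins the heuristic, which is the behaviour the method relies on when UniRig emits unnamed bones.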
+
+ # ---------------------------------------------------------------------------
+ # Face-region identification
+ # ---------------------------------------------------------------------------
+
+ def find_face_verts(glb_mesh: GLBMesh, head_joint_indices: List[int],
+                     weight_threshold: float = 0.35) -> np.ndarray:
+     """
+     Return a boolean mask of face/head vertices: any vertex whose total
+     weight on the head joints exceeds weight_threshold.
+     """
+     if glb_mesh.joints4 is None:
+         raise ValueError("Mesh has no skinning weights — cannot identify face region")
+
+     n = len(glb_mesh.verts)
+     mask = np.zeros(n, dtype=bool)
+     head_set = set(head_joint_indices)
+
+     for vi in range(n):
+         total_head_w = 0.0
+         for c in range(4):
+             j = glb_mesh.joints4[vi, c]
+             w = glb_mesh.weights4[vi, c]
+             if j in head_set:
+                 total_head_w += w
+         if total_head_w >= weight_threshold:
+             mask[vi] = True
+
+     return mask
+
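The per-vertex loop above is easy to vectorise: mark which of the four influences point at head joints, zero out the rest, and sum per row. A sketch of that equivalent formulation (the helper name and toy arrays are illustrative, not part of the module):

```python
import numpy as np

def face_mask_vectorised(joints4, weights4, head_joint_indices, weight_threshold=0.35):
    # (N, 4) bool: which influence slots reference a head joint.
    # Note: slot padding typically uses joint 0 with weight 0, so a zero weight
    # contributes nothing even if joint 0 happens to be a head joint.
    is_head = np.isin(joints4, list(head_joint_indices))
    head_w = np.where(is_head, weights4, 0.0).sum(axis=1)  # (N,) total head weight
    return head_w >= weight_threshold

joints4 = np.array([[3, 0, 0, 0], [3, 5, 0, 0], [1, 2, 0, 0]])
weights4 = np.array([[0.2, 0.8, 0, 0], [0.3, 0.3, 0.4, 0], [0.9, 0.1, 0, 0]])
mask = face_mask_vectorised(joints4, weights4, {3, 5})  # [False, True, False]
```

Only the middle vertex accumulates 0.3 + 0.3 = 0.6 of head weight and clears the 0.35 threshold; the same cut-off drives `find_face_verts`.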
+
+ # ---------------------------------------------------------------------------
+ # PSHuman mesh loading + alignment
+ # ---------------------------------------------------------------------------
+
+ def _crop_to_head(mesh: trimesh.Trimesh, head_fraction: float = 0.22) -> trimesh.Trimesh:
+     """
+     Keep only the top head_fraction of the PSHuman body mesh by Y coordinate.
+     PSHuman produces a full-body mesh; we only want the head/face portion.
+     """
+     y = mesh.vertices[:, 1]
+     threshold = y.max() - (y.max() - y.min()) * head_fraction
+     vert_keep = y >= threshold
+     face_keep = vert_keep[mesh.faces].all(axis=1)
+     kept_faces = mesh.faces[face_keep]
+     used = np.unique(kept_faces)
+     remap = np.full(len(mesh.vertices), -1, dtype=np.int32)
+     remap[used] = np.arange(len(used))
+     new_verts = mesh.vertices[used].astype(np.float32)
+     new_faces = remap[kept_faces]
+     result = trimesh.Trimesh(vertices=new_verts, faces=new_faces, process=False)
+     if hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
+         result.visual = trimesh.visual.TextureVisuals(uv=np.array(mesh.visual.uv)[used])
+     print(f"[face_transplant] PSHuman head crop ({head_fraction*100:.0f}%): "
+           f"{len(mesh.vertices)} → {len(new_verts)} verts (Y ≥ {threshold:.3f})")
+     return result
+
+
+ def load_and_align_pshuman(pshuman_path: str, target_verts: np.ndarray) -> trimesh.Trimesh:
+     """
+     Load a PSHuman mesh (OBJ/GLB/PLY), crop it to the head region, then scale
+     and translate it to fit the bounding box of target_verts (UniRig head verts).
+     """
+     mesh: trimesh.Trimesh = trimesh.load(pshuman_path, force="mesh", process=False)
+     print(f"[face_transplant] PSHuman mesh: {len(mesh.vertices)} verts, {len(mesh.faces)} faces")
+
+     # PSHuman is full-body — crop to just the head before aligning
+     mesh = _crop_to_head(mesh)
+
+     # Target bbox from the UniRig head region
+     tgt_min = target_verts.min(axis=0)
+     tgt_max = target_verts.max(axis=0)
+     tgt_ctr = (tgt_min + tgt_max) * 0.5
+     tgt_ext = tgt_max - tgt_min
+
+     src_min = mesh.vertices.min(axis=0).astype(np.float32)
+     src_max = mesh.vertices.max(axis=0).astype(np.float32)
+     src_ctr = (src_min + src_max) * 0.5
+     src_ext = src_max - src_min
+
+     # Uniform scale: match the largest axis of the target
+     dominant = np.argmax(tgt_ext)
+     scale = float(tgt_ext[dominant]) / float(src_ext[dominant] + 1e-9)
+
+     verts = mesh.vertices.astype(np.float32).copy()
+     verts = (verts - src_ctr) * scale + tgt_ctr
+
+     mesh.vertices = verts
+     print(f"[face_transplant] PSHuman aligned: scale={scale:.4f}, center={tgt_ctr}")
+     return mesh
+
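The alignment step is plain bounding-box fitting: one uniform scale chosen on the target's dominant axis, then a recentre onto the target's bbox centre. A standalone sketch of that transform on made-up point sets (`bbox_align` is an illustrative helper, not part of the module):

```python
import numpy as np

def bbox_align(src: np.ndarray, tgt: np.ndarray) -> np.ndarray:
    # Uniform scale matched on the target's largest axis, then recentre,
    # mirroring the math in load_and_align_pshuman.
    src_min, src_max = src.min(axis=0), src.max(axis=0)
    tgt_min, tgt_max = tgt.min(axis=0), tgt.max(axis=0)
    src_ctr, tgt_ctr = (src_min + src_max) * 0.5, (tgt_min + tgt_max) * 0.5
    dominant = np.argmax(tgt_max - tgt_min)
    scale = (tgt_max - tgt_min)[dominant] / ((src_max - src_min)[dominant] + 1e-9)
    return (src - src_ctr) * scale + tgt_ctr

src = np.array([[0.0, 0.0, 0.0], [2.0, 4.0, 2.0]])        # 2x4x2 box at the origin
tgt = np.array([[10.0, 10.0, 10.0], [11.0, 12.0, 11.0]])  # 1x2x1 box near (10.5, 11, 10.5)
out = bbox_align(src, tgt)  # the corners land exactly on the target box
```

Because the scale is uniform, proportions of the source head are preserved; only its size and position change, which is exactly what you want before weight transfer.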
+
+ # ---------------------------------------------------------------------------
+ # Weight transfer via KDTree
+ # ---------------------------------------------------------------------------
+
+ def transfer_weights(
+     donor_verts: np.ndarray,      # (M, 3) UniRig face verts
+     donor_joints: np.ndarray,     # (M, 4) joint indices
+     donor_weights: np.ndarray,    # (M, 4) float32
+     recipient_verts: np.ndarray,  # (N, 3) PSHuman face verts
+     k: int = 5,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+     """
+     K-nearest-neighbour weight transfer with inverse-distance weighting.
+     Returns (joints4, weights4) for recipient_verts.
+     """
+     tree = KDTree(donor_verts)
+     dists, idxs = tree.query(recipient_verts, k=k)  # (N, k)
+
+     N = len(recipient_verts)
+     n_joints_total = int(donor_joints.max()) + 1
+
+     # Build a dense per-recipient weight vector
+     dense = np.zeros((N, n_joints_total), dtype=np.float64)
+     for ki in range(k):
+         w_dist = 1.0 / (dists[:, ki] + 1e-8)  # inverse-distance
+         for vi in range(N):
+             di = idxs[vi, ki]
+             for c in range(4):
+                 j = donor_joints[di, c]
+                 w = donor_weights[di, c]
+                 dense[vi, j] += w * w_dist[vi]
+
+     # Re-normalise rows
+     row_sum = dense.sum(axis=1, keepdims=True) + 1e-12
+     dense /= row_sum
+
+     # Pack back into 4-bone format (top-4 by weight)
+     out_joints = np.zeros((N, 4), dtype=np.uint16)
+     out_weights = np.zeros((N, 4), dtype=np.float32)
+     for vi in range(N):
+         top4 = np.argsort(dense[vi])[-4:][::-1]
+         total = dense[vi, top4].sum() + 1e-12
+         for c, j in enumerate(top4):
+             out_joints[vi, c] = j
+             out_weights[vi, c] = dense[vi, j] / total
+
+     return out_joints, out_weights
+
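The core of `transfer_weights` is a `scipy.spatial.KDTree` lookup combined with inverse-distance weighting. A toy sketch of the same scheme, with a hypothetical four-vertex donor cloud where each donor is rigidly bound to a single bone:

```python
import numpy as np
from scipy.spatial import KDTree

# Toy donors on a line, each with weight 1.0 on its own bone (bone i at x = i).
donor_verts = np.array([[0.0, 0, 0], [1.0, 0, 0], [2.0, 0, 0], [3.0, 0, 0]])
donor_bone = np.array([0, 1, 2, 3])            # first (and only) influence per donor
recipient = np.array([[1.1, 0, 0]])            # closest to donor 1, so bone 1 should win

tree = KDTree(donor_verts)
dists, idxs = tree.query(recipient, k=2)       # two nearest donors: verts 1 and 2
w = 1.0 / (dists + 1e-8)                       # inverse-distance weighting, as above

dense = np.zeros((1, 4))
for ki in range(2):
    di = idxs[0, ki]
    dense[0, donor_bone[di]] += 1.0 * w[0, ki] # donor weight is 1.0 on its bone
dense /= dense.sum()                           # re-normalise, as in transfer_weights

top_bone = int(dense[0].argmax())              # bone 1 dominates (dist 0.1 vs 0.9)
```

The 1e-8 epsilon matters: a recipient sitting exactly on a donor would otherwise divide by zero, and with the epsilon it simply inherits that donor's weights almost verbatim.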
+
+ # ---------------------------------------------------------------------------
+ # GLB rebuild
+ # ---------------------------------------------------------------------------
+
+ def _pack_buffer_view(data_bytes: bytes, target: list, byte_offset: int,
+                       byte_stride: Optional[int] = None) -> Tuple[pygltflib.BufferView, int]:
+     """
+     Build a buffer view for data_bytes at byte_offset; return (buffer_view, new_offset).
+     """
+     bv = pygltflib.BufferView(
+         buffer=0,
+         byteOffset=byte_offset,
+         byteLength=len(data_bytes),
+     )
+     if byte_stride:
+         bv.byteStride = byte_stride
+     return bv, byte_offset + len(data_bytes)
+
+
+ def _make_accessor(component_type: int, type_str: str, count: int,
+                    bv_idx: int, min_vals=None, max_vals=None) -> pygltflib.Accessor:
+     acc = pygltflib.Accessor(
+         bufferView=bv_idx,
+         byteOffset=0,
+         componentType=component_type,
+         count=count,
+         type=type_str,
+     )
+     if min_vals is not None:
+         acc.min = [float(v) for v in min_vals]
+     if max_vals is not None:
+         acc.max = [float(v) for v in max_vals]
+     return acc
+
+
+ FLOAT32 = pygltflib.FLOAT           # 5126
+ UINT16 = pygltflib.UNSIGNED_SHORT   # 5123
+ UINT32 = pygltflib.UNSIGNED_INT     # 5125
+ UBYTE = pygltflib.UNSIGNED_BYTE     # 5121
+
+
+ def transplant_face(
+     body_glb_path: str,
+     pshuman_mesh_path: str,
+     output_path: str,
+     head_bone_substrings: Tuple[str, ...] = ("head", "Head", "skull", "Skull"),
+     weight_threshold: float = 0.35,
+     retract_amount: float = 0.004,  # metres — how far to push face verts inward
+     knn: int = 5,
+ ):
+     """
+     Main entry point.
+
+     Parameters
+     ----------
+     body_glb_path : rigged UniRig GLB
+     pshuman_mesh_path : PSHuman output mesh (OBJ / GLB / PLY)
+     output_path : result GLB path
+     head_bone_substrings : bone name fragments that identify head joints
+     weight_threshold : head-weight sum above which a vertex is "face"
+     retract_amount : metres to push face verts inward to avoid z-fighting
+     knn : neighbours for weight transfer
+     """
+     print(f"[face_transplant] Loading rigged GLB: {body_glb_path}")
+     glb = GLBMesh(body_glb_path)
+     print(f"  Verts: {len(glb.verts)}  Faces: {len(glb.faces)}")
+     print(f"  Bones ({len(glb.joint_names)}): {', '.join(glb.joint_names[:8])} ...")
+
+     # 1. Identify head joints
+     head_ji = glb.head_bone_indices(substrings=head_bone_substrings)
+     if not head_ji:
+         raise RuntimeError(
+             f"No head bones found with substrings {head_bone_substrings}.\n"
+             f"Available bones: {glb.joint_names}"
+         )
+     print(f"  Head joints ({len(head_ji)}): {[glb.joint_names[i] for i in head_ji]}")
+
+     # 2. Find face/head vertices
+     face_mask = find_face_verts(glb, head_ji, weight_threshold=weight_threshold)
+     print(f"  Face verts: {face_mask.sum()} / {len(glb.verts)}")
+
+     min_face_verts = max(3, min(10, len(glb.verts) // 4))
+     if face_mask.sum() < min_face_verts:
+         raise RuntimeError(
+             f"Only {face_mask.sum()} face vertices found (need >= {min_face_verts}) — "
+             f"try lowering --weight-threshold (current: {weight_threshold})"
+         )
+
+     # 3. Load + align PSHuman mesh
+     face_verts_ur = glb.verts[face_mask]
+     ps_mesh = load_and_align_pshuman(pshuman_mesh_path, face_verts_ur)
+     ps_verts = np.array(ps_mesh.vertices, dtype=np.float32)
+     ps_faces = np.array(ps_mesh.faces, dtype=np.int32)
+     ps_uvs = None
+     if hasattr(ps_mesh.visual, "uv") and ps_mesh.visual.uv is not None:
+         ps_uvs = np.array(ps_mesh.visual.uv, dtype=np.float32)
+
+     # 4. Transfer weights: donor = UniRig face verts, recipient = PSHuman verts
+     print("[face_transplant] Transferring skinning weights via KNN ...")
+     ps_joints, ps_weights = transfer_weights(
+         donor_verts=glb.verts[face_mask].astype(np.float64),
+         donor_joints=glb.joints4[face_mask],
+         donor_weights=glb.weights4[face_mask],
+         recipient_verts=ps_verts.astype(np.float64),
+         k=knn,
+     )
+     print(f"  Done. Head joint coverage in PSHuman: "
+           f"{(np.isin(ps_joints[:, 0], head_ji)).mean() * 100:.1f}% primary bone is head")
+
+     # 5. Retract UniRig face verts inward (push along −normal)
+     body_verts = glb.verts.copy()
+     if glb.normals is not None:
+         body_verts[face_mask] -= glb.normals[face_mask] * retract_amount
+     else:
+         # No normals: push toward the face centroid instead
+         centroid = body_verts[face_mask].mean(axis=0)
+         dirs = centroid - body_verts[face_mask]
+         norms = np.linalg.norm(dirs, axis=1, keepdims=True) + 1e-9
+         body_verts[face_mask] += (dirs / norms) * retract_amount
+
+     # 6. Rebuild GLB
+     print("[face_transplant] Rebuilding GLB ...")
+     _write_transplanted_glb(
+         source_gltf=glb,
+         body_verts=body_verts,
+         ps_verts=ps_verts,
+         ps_faces=ps_faces,
+         ps_uvs=ps_uvs,
+         ps_joints=ps_joints,
+         ps_weights=ps_weights,
+         output_path=output_path,
+     )
+     print(f"[face_transplant] Saved -> {output_path}")
+
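Step 5 exists purely to avoid z-fighting: the original UniRig face stays in the file underneath the transplanted PSHuman face, so it is pushed a few millimetres inward along the inverted vertex normal. A minimal numpy sketch of that retraction on made-up vertices:

```python
import numpy as np

verts = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
normals = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, -1.0]])  # unit normals
face_mask = np.array([True, False])                      # only vert 0 is "face"
retract_amount = 0.004                                   # metres, same default as above

out = verts.copy()
out[face_mask] -= normals[face_mask] * retract_amount    # masked verts move inward only
```

The masked vertex slides 4 mm opposite its normal while the unmasked one is untouched; skinning weights are unchanged, so the retracted shell still deforms with the head bone.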
+
+ # ---------------------------------------------------------------------------
+ # GLB writer
+ # ---------------------------------------------------------------------------
+
+ def _write_transplanted_glb(
+     source_gltf: GLBMesh,
+     body_verts: np.ndarray,
+     ps_verts: np.ndarray,
+     ps_faces: np.ndarray,
+     ps_uvs: Optional[np.ndarray],
+     ps_joints: np.ndarray,
+     ps_weights: np.ndarray,
+     output_path: str,
+ ):
+     """
+     Copy the source GLB structure, replace mesh primitive 0 vertex data,
+     and append a new primitive for the PSHuman face.
+     """
+     gltf = pygltflib.GLTF2().load(source_gltf.path)
+     gltf._path = source_gltf.path
+
+     # ------------------------------------------------------------------
+     # Preserve embedded images as data URIs BEFORE we wipe buffer views.
+     # The binary blob rebuild below only contains geometry; any image data
+     # referenced via bufferView would otherwise be lost.
+     # ------------------------------------------------------------------
+     try:
+         blob = bytes(gltf.binary_blob())
+     except Exception:
+         blob = b""
+     for img in gltf.images:
+         if img.bufferView is not None and img.uri is None and blob:
+             bv = gltf.bufferViews[img.bufferView]
+             img_bytes = blob[bv.byteOffset: bv.byteOffset + bv.byteLength]
+             mime = img.mimeType or "image/png"
+             img.uri = "data:{};base64,{}".format(mime, base64.b64encode(img_bytes).decode())
+             img.bufferView = None
+
+     # ------------------------------------------------------------------
+     # We will rebuild the entire binary buffer from scratch.
+     # Collect all data chunks; track buffer views + accessors.
+     # ------------------------------------------------------------------
+     chunks: List[bytes] = []
+     bviews: List[pygltflib.BufferView] = []
+     accors: List[pygltflib.Accessor] = []
+     byte_offset = 0
+
+     def add_chunk(data: bytes, component_type: int, type_str: str, count: int,
+                   min_v=None, max_v=None, stride: int = None) -> int:
+         """Append data, create a buffer view + accessor, return the accessor index."""
+         nonlocal byte_offset
+         bv = pygltflib.BufferView(buffer=0, byteOffset=byte_offset, byteLength=len(data))
+         if stride:
+             bv.byteStride = stride
+         bviews.append(bv)
+         bv_idx = len(bviews) - 1
+
+         acc = pygltflib.Accessor(
+             bufferView=bv_idx,
+             byteOffset=0,
+             componentType=component_type,
+             count=count,
+             type=type_str,
+         )
+         if min_v is not None:
+             acc.min = [float(x) for x in np.atleast_1d(min_v)]
+         if max_v is not None:
+             acc.max = [float(x) for x in np.atleast_1d(max_v)]
+         accors.append(acc)
+         acc_idx = len(accors) - 1
+
+         chunks.append(data)
+         byte_offset += len(data)
+         return acc_idx
+
+     # ------------------------------------------------------------------
+     # Primitive 0 — UniRig body (retracted face verts)
+     # ------------------------------------------------------------------
+     body_v = body_verts.astype(np.float32)
+     body_i = source_gltf.faces.astype(np.uint32).flatten()
+     body_n = (source_gltf.normals.astype(np.float32)
+               if source_gltf.normals is not None else None)
+     body_uv = (source_gltf.uvs.astype(np.float32)
+                if source_gltf.uvs is not None else None)
+     body_j = source_gltf.joints4.astype(np.uint16)
+     body_w = source_gltf.weights4.astype(np.float32)
+
+     # indices
+     bi_idx = add_chunk(body_i.tobytes(), UINT32, "SCALAR", len(body_i),
+                        min_v=[int(body_i.min())], max_v=[int(body_i.max())])
+     # positions
+     bv_idx = add_chunk(body_v.tobytes(), FLOAT32, "VEC3", len(body_v),
+                        min_v=body_v.min(axis=0), max_v=body_v.max(axis=0))
+     body_attrs = pygltflib.Attributes(POSITION=bv_idx)
+     if body_n is not None:
+         body_attrs.NORMAL = add_chunk(body_n.tobytes(), FLOAT32, "VEC3", len(body_n))
+     if body_uv is not None:
+         body_attrs.TEXCOORD_0 = add_chunk(body_uv.tobytes(), FLOAT32, "VEC2", len(body_uv))
+     if body_j is not None:
+         body_attrs.JOINTS_0 = add_chunk(body_j.tobytes(), UINT16, "VEC4", len(body_j))
+         body_attrs.WEIGHTS_0 = add_chunk(body_w.tobytes(), FLOAT32, "VEC4", len(body_w))
+
+     prim0 = pygltflib.Primitive(
+         attributes=body_attrs,
+         indices=bi_idx,
+         material=source_gltf.material_idx,
+         mode=4,  # TRIANGLES
+     )
+
+     # ------------------------------------------------------------------
+     # Primitive 1 — PSHuman face
+     # ------------------------------------------------------------------
+     ps_v = ps_verts.astype(np.float32)
+     ps_i = ps_faces.astype(np.uint32).flatten()
+     ps_j4 = ps_joints.astype(np.uint16)
+     ps_w4 = ps_weights.astype(np.float32)
+
+     # PSHuman material — reuse the body material for now (same texture look).
+     # If PSHuman has its own texture, you'd add a new material here.
+     face_mat_idx = source_gltf.material_idx
+
+     fi_idx = add_chunk(ps_i.tobytes(), UINT32, "SCALAR", len(ps_i),
+                        min_v=[int(ps_i.min())], max_v=[int(ps_i.max())])
+     fv_idx = add_chunk(ps_v.tobytes(), FLOAT32, "VEC3", len(ps_v),
+                        min_v=ps_v.min(axis=0), max_v=ps_v.max(axis=0))
+     face_attrs = pygltflib.Attributes(POSITION=fv_idx)
+     if ps_uvs is not None:
+         face_attrs.TEXCOORD_0 = add_chunk(ps_uvs.tobytes(), FLOAT32, "VEC2", len(ps_uvs))
+     face_attrs.JOINTS_0 = add_chunk(ps_j4.tobytes(), UINT16, "VEC4", len(ps_j4))
+     face_attrs.WEIGHTS_0 = add_chunk(ps_w4.tobytes(), FLOAT32, "VEC4", len(ps_w4))
+
+     prim1 = pygltflib.Primitive(
+         attributes=face_attrs,
+         indices=fi_idx,
+         material=face_mat_idx,
+         mode=4,
+     )
+
+     # ------------------------------------------------------------------
+     # Patch the glTF structure
+     # ------------------------------------------------------------------
+     # Find or create the mesh that uses our skin
+     mesh_node = gltf.nodes[source_gltf.mesh_node_idx]
+     old_mesh_idx = mesh_node.mesh
+
+     new_mesh = pygltflib.Mesh(
+         name="body_with_pshuman_face",
+         primitives=[prim0, prim1],
+     )
+
+     # Replace or append
+     if old_mesh_idx is not None and old_mesh_idx < len(gltf.meshes):
+         gltf.meshes[old_mesh_idx] = new_mesh
+         target_mesh_idx = old_mesh_idx
+     else:
+         gltf.meshes.append(new_mesh)
+         target_mesh_idx = len(gltf.meshes) - 1
+
+     mesh_node.mesh = target_mesh_idx
+
+     # Replace buffer views and accessors
+     gltf.bufferViews = bviews
+     gltf.accessors = accors
+
+     # Rewrite the buffer, padded to 4-byte alignment
+     combined = b"".join(chunks)
+     if len(combined) % 4:
+         combined += b"\x00" * (4 - len(combined) % 4)
+
+     gltf.buffers = [pygltflib.Buffer(byteLength=len(combined))]
+     gltf.set_binary_blob(combined)
+
+     # Drop stale animations (they referenced the old accessor indices);
+     # the user can re-add animation later if needed.
+     gltf.animations = []
+
+     gltf.save(output_path)
+
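The zero-padding before `set_binary_blob` is there because the GLB container expects 4-byte alignment of the binary chunk. The padding rule in isolation, as a tiny stdlib sketch (`pad4` is an illustrative helper name):

```python
import struct

def pad4(data: bytes) -> bytes:
    # Pad with zero bytes up to the next 4-byte boundary, as the writer does.
    if len(data) % 4:
        data += b"\x00" * (4 - len(data) % 4)
    return data

indices = struct.pack("3H", 0, 1, 2)  # 6 bytes of uint16 triangle indices
aligned = pad4(indices)               # grows to 8 bytes
already = pad4(b"abcd")               # 4 bytes: unchanged
```

Skipping this step can produce a GLB that some loaders reject, since the spec-level chunk length must be a multiple of 4.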
+
+ # ---------------------------------------------------------------------------
+ # CLI
+ # ---------------------------------------------------------------------------
+
+ def main():
+     parser = argparse.ArgumentParser(description="Transplant PSHuman face into UniRig GLB")
+     parser.add_argument("--body", required=True, help="Rigged UniRig GLB")
+     parser.add_argument("--face", required=True, help="PSHuman mesh (OBJ/GLB/PLY)")
+     parser.add_argument("--output", required=True, help="Output GLB path")
+     parser.add_argument("--head-bones", default="head,Head,skull,Skull",
+                         help="Comma-separated bone name substrings for head detection")
+     parser.add_argument("--weight-threshold", type=float, default=0.35,
+                         help="Minimum head-bone weight sum to classify a vert as face")
+     parser.add_argument("--retract", type=float, default=0.004,
+                         help="Metres to retract UniRig face verts inward (default 0.004)")
+     parser.add_argument("--knn", type=int, default=5,
+                         help="K nearest neighbours for weight transfer")
+     args = parser.parse_args()
+
+     subs = tuple(s.strip() for s in args.head_bones.split(","))
+     transplant_face(
+         body_glb_path=args.body,
+         pshuman_mesh_path=args.face,
+         output_path=args.output,
+         head_bone_substrings=subs,
+         weight_threshold=args.weight_threshold,
+         retract_amount=args.retract,
+         knn=args.knn,
+     )
+
+
+ if __name__ == "__main__":
+     main()
pipeline/head_replace.py ADDED
@@ -0,0 +1,762 @@
+ """
+ head_replace.py — Replace the TripoSG head with a DECA-reconstructed head at mesh level.
+
+ Requires: trimesh, numpy, scipy, cv2, torch (+ face-alignment via DECA deps)
+ Optional: pymeshlab (for mesh clean-up)
+
+ Usage (standalone):
+     python head_replace.py --body /tmp/triposg_textured.glb \
+                            --face /path/to/face.jpg \
+                            --out /tmp/head_replaced.glb
+
+ Produces a combined GLB with DECA head geometry + TripoSG body.
+ """
+
+ import os, sys, argparse, warnings
+ warnings.filterwarnings('ignore')
+
+ import numpy as np
+ import cv2
+ from PIL import Image
+
+ # ──────────────────────────────────────────────────────────────────
+ # Patch DECA before importing it to avoid the pytorch3d dependency
+ # ──────────────────────────────────────────────────────────────────
+ DECA_ROOT = '/root/DECA'
+ sys.path.insert(0, DECA_ROOT)
+
+ # Stub out the rasterizer so DECA doesn't try to import pytorch3d
+ import types
+ _fake_renderer = types.ModuleType('decalib.utils.renderer')
+ _fake_renderer.set_rasterizer = lambda t='pytorch3d': None
+
+ class _FakeRender:
+     """No-op renderer — we only need the mesh, not rendered images."""
+     def __init__(self, *a, **kw): pass
+     def to(self, *a, **kw): return self
+     def __call__(self, *a, **kw): return {'images': None, 'alpha_images': None,
+                                           'normal_images': None, 'grid': None,
+                                           'transformed_normals': None, 'normals': None}
+     def render_shape(self, *a, **kw): return None, None, None, None
+     def world2uv(self, *a, **kw): return None
+     def add_SHlight(self, *a, **kw): return None
+
+ _fake_renderer.SRenderY = _FakeRender
+ sys.modules['decalib.utils.renderer'] = _fake_renderer
+
+ # Patch deca.py: fall back to the fake renderer when the real one is unavailable
+ from decalib import deca as _deca_mod
+ _orig_setup = _deca_mod.DECA._setup_renderer
+
+ def _patched_setup(self, model_cfg):
+     try:
+         _orig_setup(self, model_cfg)
+     except Exception as e:
+         print(f'[head_replace] Renderer disabled ({e})')
+         self.render = _FakeRender()
+         # Still load the mask / displacement data we need for UV baking
+         from skimage.io import imread
+         import torch, torch.nn.functional as F
+         try:
+             mask = imread(model_cfg.face_eye_mask_path).astype(np.float32) / 255.
+             mask = torch.from_numpy(mask[:, :, 0])[None, None, :, :].contiguous()
+             self.uv_face_eye_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size])
+             mask2 = imread(model_cfg.face_mask_path).astype(np.float32) / 255.
+             mask2 = torch.from_numpy(mask2[:, :, 0])[None, None, :, :].contiguous()
+             self.uv_face_mask = F.interpolate(mask2, [model_cfg.uv_size, model_cfg.uv_size])
+         except Exception:
+             pass
+         try:
+             fixed_dis = np.load(model_cfg.fixed_displacement_path)
+             self.fixed_uv_dis = torch.tensor(fixed_dis).float()
+         except Exception:
+             pass
+         try:
+             mean_tex_np = imread(model_cfg.mean_tex_path).astype(np.float32) / 255.
+             mean_tex = torch.from_numpy(mean_tex_np.transpose(2, 0, 1))[None]
+             self.mean_texture = F.interpolate(mean_tex, [model_cfg.uv_size, model_cfg.uv_size])
+         except Exception:
+             pass
+         try:
+             self.dense_template = np.load(model_cfg.dense_template_path,
+                                           allow_pickle=True, encoding='latin1').item()
+         except Exception:
+             pass
+
+ _deca_mod.DECA._setup_renderer = _patched_setup
+
+
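The pytorch3d workaround hinges on one standard trick: registering a `types.ModuleType` stub in `sys.modules` before anything imports the real name, so later `import` statements resolve to the stub. A minimal, self-contained sketch with a hypothetical module name:

```python
import sys
import types

# Create a fake module and register it under a name nothing has imported yet.
# "heavy_rasteriser" is hypothetical, standing in for decalib.utils.renderer.
fake = types.ModuleType("heavy_rasteriser")
fake.set_rasterizer = lambda backend="pytorch3d": None  # no-op replacement
sys.modules["heavy_rasteriser"] = fake

import heavy_rasteriser  # resolves to the stub; no real package is needed
result = heavy_rasteriser.set_rasterizer()
```

The ordering is the whole point: the `sys.modules` assignment must run before DECA's own `from decalib.utils.renderer import ...`, which is why the patch block sits at the top of the module.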
+ # ──────────────────────────────────────────────────────────────────
+ # FLAME mesh: parse head_template.obj for the UV map
+ # ──────────────────────────────────────────────────────────────────
+ def _load_flame_template(obj_path=os.path.join(DECA_ROOT, 'data', 'head_template.obj')):
+     """Return (verts, faces, uv_verts, uv_faces) from head_template.obj."""
+     verts, uv_verts = [], []
+     faces_v, faces_uv = [], []
+     with open(obj_path) as fh:
+         for line in fh:
+             t = line.split()
+             if not t:
+                 continue
+             if t[0] == 'v':
+                 verts.append([float(t[1]), float(t[2]), float(t[3])])
+             elif t[0] == 'vt':
+                 uv_verts.append([float(t[1]), float(t[2])])
+             elif t[0] == 'f':
+                 vi, uvi = [], []
+                 for tok in t[1:]:
+                     parts = tok.split('/')
+                     vi.append(int(parts[0]) - 1)
+                     uvi.append(int(parts[1]) - 1 if len(parts) > 1 and parts[1] else 0)
+                 if len(vi) == 3:
+                     faces_v.append(vi)
+                     faces_uv.append(uvi)
+     return (np.array(verts, dtype=np.float32),
+             np.array(faces_v, dtype=np.int32),
+             np.array(uv_verts, dtype=np.float32),
+             np.array(faces_uv, dtype=np.int32))
+
+
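OBJ face tokens come in four shapes (`v`, `v/vt`, `v//vn`, `v/vt/vn`) with 1-based indices, and the loader above converts them to 0-based vertex and UV indices. The token logic in isolation (`parse_face_token` is an illustrative helper name):

```python
def parse_face_token(tok: str):
    # OBJ face corner token: v, v/vt, v//vn or v/vt/vn, all indices 1-based.
    parts = tok.split('/')
    vi = int(parts[0]) - 1
    # Missing or empty vt slot (the v and v//vn forms) falls back to UV index 0,
    # matching the behaviour of _load_flame_template.
    uvi = int(parts[1]) - 1 if len(parts) > 1 and parts[1] else 0
    return vi, uvi

a = parse_face_token("5/3")   # → (4, 2)
b = parse_face_token("7")     # → (6, 0)
c = parse_face_token("2//4")  # → (1, 0): empty vt between the slashes
```

The fallback to UV 0 is lossy but harmless for head_template.obj, where every face carries explicit `vt` indices.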
119
+ # ──────────────────────────────────────────────────────────────────
120
+ # UV texture baking (software rasteriser, no pytorch3d needed)
121
+ # ──────────────────────────────────────────────────────────────────
122
+ def _bake_uv_texture(verts3d, faces_v, uv_verts, faces_uv, cam, face_img_bgr, tex_size=256):
123
+ """
124
+ Project face_img_bgr onto the FLAME UV map using orthographic camera.
125
+ verts3d : (N,3) FLAME vertices in world space
126
+ cam : (3,) = [scale, tx, ty] orthographic camera
127
+ Returns : (tex_size, tex_size, 3) uint8 texture (BGR)
128
+ """
129
+ H, W = face_img_bgr.shape[:2]
130
+     scale, tx, ty = float(cam[0]), float(cam[1]), float(cam[2])
+
+     # Orthographic project: DECA formula = (vert_2D + [tx,ty]) * scale, then flip y
+     proj = np.zeros((len(verts3d), 2), dtype=np.float32)
+     proj[:, 0] = (verts3d[:, 0] + tx) * scale
+     proj[:, 1] = -((verts3d[:, 1] + ty) * scale)  # y-flip matches DECA convention
+
+     # Map to pixel coords: image spans proj ∈ [-1,1] → pixel [0, WH]
+     img_pts = (proj + 1.0) * 0.5 * np.array([W, H], dtype=np.float32)  # (N, 2)
+
+     # UV pixel coords
+     uv_px = uv_verts * tex_size  # (K, 2)
+
+     # Vectorised rasteriser in UV space:
+     # for each face, scatter samples from img_pts into uv_px coords,
+     # using scipy.interpolate.griddata as a fast splat.
+     from scipy.interpolate import griddata
+
+     # Front-facing mask (z > threshold) — only bake visible faces
+     z_face = verts3d[faces_v, 2].mean(axis=1)  # (M,) mean z per face
+     front_mask = z_face >= -0.02  # keep front and side faces
+
+     # For each face corner, record a (uv_px, img_pts) sample
+     corners_uv = uv_px[faces_uv[front_mask]]    # (K, 3, 2)
+     corners_img = img_pts[faces_v[front_mask]]  # (K, 3, 2)
+
+     # Flatten to (K*3, 2)
+     src_uv = corners_uv.reshape(-1, 2)    # UV pixel destination
+     src_img = corners_img.reshape(-1, 2)  # image pixel source
+
+     # Remove out-of-bounds image samples
+     valid = ((src_img[:, 0] >= 0) & (src_img[:, 0] < W) &
+              (src_img[:, 1] >= 0) & (src_img[:, 1] < H))
+     src_uv = src_uv[valid]
+     src_img = src_img[valid]
+
+     # Sample the face image at src_img positions
+     ix = np.clip(src_img[:, 0].astype(int), 0, W - 1)
+     iy = np.clip(src_img[:, 1].astype(int), 0, H - 1)
+     colours = face_img_bgr[iy, ix].astype(np.float32)  # (P, 3)
+
+     # Clip UV destinations to texture bounds
+     uv_dest = np.clip(src_uv, 0, tex_size - 1 - 1e-6).astype(np.float32)
+
+     # Build query grid for griddata interpolation
+     grid_u, grid_v = np.meshgrid(np.arange(tex_size), np.arange(tex_size))
+     grid_pts = np.column_stack([grid_u.ravel(), grid_v.ravel()])
+
+     # Interpolate each colour channel
+     tex_baked = np.zeros((tex_size * tex_size, 3), dtype=np.float32)
+     for ch in range(3):
+         ch_vals = griddata(uv_dest, colours[:, ch], grid_pts,
+                            method='linear', fill_value=np.nan)
+         tex_baked[:, ch] = ch_vals
+     tex_baked = tex_baked.reshape(tex_size, tex_size, 3)
+     face_baked_mask = ~np.isnan(tex_baked[:, :, 0])
+
+     # Base texture: mean_texture (skin-tone fallback for unsampled regions)
+     mean_tex_path = os.path.join(DECA_ROOT, 'data', 'mean_texture.jpg')
+     if os.path.exists(mean_tex_path):
+         mt = cv2.resize(cv2.imread(mean_tex_path), (tex_size, tex_size)).astype(np.float32)
+     else:
+         mt = np.full((tex_size, tex_size, 3), 180.0, dtype=np.float32)
+
+     # Blend: baked face over mean texture
+     result = mt.copy()
+     result[face_baked_mask] = np.clip(tex_baked[face_baked_mask], 0, 255)
+     return result.astype(np.uint8)
+
+
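The weak-perspective projection above is compact enough to sanity-check in isolation. A minimal standalone sketch of the same NDC-to-pixel mapping (the camera values here are made up for illustration, not actual DECA output):

```python
import numpy as np

def ortho_project(verts3d, cam, W, H):
    """Weak-perspective projection: (coord + translation) * scale, y flipped,
    then NDC [-1, 1] mapped to pixel coordinates [0, W] x [0, H]."""
    scale, tx, ty = float(cam[0]), float(cam[1]), float(cam[2])
    proj = np.empty((len(verts3d), 2), dtype=np.float32)
    proj[:, 0] = (verts3d[:, 0] + tx) * scale
    proj[:, 1] = -((verts3d[:, 1] + ty) * scale)  # y-down image convention
    return (proj + 1.0) * 0.5 * np.array([W, H], dtype=np.float32)

# A vertex at the camera centre lands in the middle of the image.
pts = ortho_project(np.zeros((1, 3)), cam=[1.0, 0.0, 0.0], W=224, H=224)
```

With `scale=1` and zero translation, the origin maps to the image centre, which is a quick way to catch sign or flip mistakes.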
+ # ──────────────────────────────────────────────────────────────────
+ # DECA inference
+ # ──────────────────────────────────────────────────────────────────
+ def run_deca(face_img_path, device='cuda'):
+     """
+     Run DECA on face_img_path.
+     Returns (verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_img_bgr)
+     """
+     import torch
+     from decalib.deca import DECA
+     from decalib.utils import config as cfg_module
+     from decalib.datasets import datasets
+
+     cfg = cfg_module.get_cfg_defaults()
+     cfg.model.use_tex = False
+
+     print('[DECA] Loading model...')
+     deca = DECA(config=cfg, device=device)
+     deca.eval()
+
+     print('[DECA] Preprocessing image...')
+     testdata = datasets.TestData(face_img_path)
+     img_tensor = testdata[0]['image'].to(device)[None, ...]
+
+     print('[DECA] Encoding...')
+     with torch.no_grad():
+         codedict = deca.encode(img_tensor, use_detail=False)
+         verts, _, _ = deca.flame(
+             shape_params=codedict['shape'],
+             expression_params=codedict['exp'],
+             pose_params=codedict['pose']
+         )
+
+     verts_np = verts[0].cpu().numpy()          # (5023, 3)
+     cam_np = codedict['cam'][0].cpu().numpy()  # (3,)
+     print(f'[DECA] Mesh: {verts_np.shape}, cam={cam_np}')
+
+     # Load FLAME UV map
+     _, faces_v, uv_verts, faces_uv = _load_flame_template()
+
+     # Get the face image for texture baking (use the cropped/aligned 224x224)
+     img_np = (img_tensor[0].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
+     img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+
+     print('[DECA] Baking UV texture...')
+     tex_bgr = _bake_uv_texture(verts_np, faces_v, uv_verts, faces_uv, cam_np, img_bgr, tex_size=256)
+
+     return verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_bgr
+
+
+ # ──────────────────────────────────────────────────────────────────
+ # Mesh helpers
+ # ──────────────────────────────────────────────────────────────────
+ def _find_neck_height(mesh):
+     """
+     Find the best neck cut height in a body mesh.
+     Strategy: scan a band near the top of the mesh and find the local
+     minimum of the cross-section's 10th-percentile radius (the neck is
+     narrower than the head).
+     Returns the y-value of the cut plane.
+     """
+     verts = mesh.vertices
+     y_min, y_max = verts[:, 1].min(), verts[:, 1].max()
+     y_range = y_max - y_min
+
+     # Scan [80%, 87%] to find the neck-base narrowing below the face.
+     # The range [83%, 91%] was picking the crown taper instead of the neck.
+     y_start = y_min + y_range * 0.80
+     y_end = y_min + y_range * 0.87
+     steps = 20
+     ys = np.linspace(y_start, y_end, steps)
+     band = y_range * 0.015
+
+     r10_vals = []
+     for y in ys:
+         pts = verts[(verts[:, 1] >= y - band) & (verts[:, 1] <= y + band)]
+         if len(pts) < 6:
+             r10_vals.append(1.0)
+             continue
+         xz = pts[:, [0, 2]]
+         cx, cz = xz.mean(0)
+         radii = np.sqrt((xz[:, 0] - cx)**2 + (xz[:, 1] - cz)**2)
+         r10_vals.append(float(np.percentile(radii, 10)))
+
+     from scipy.ndimage import uniform_filter1d
+     r10 = uniform_filter1d(np.array(r10_vals), size=3)
+     neck_idx = int(np.argmin(r10[2:-2])) + 2
+     neck_y = float(ys[neck_idx])
+     frac = (neck_y - y_min) / y_range
+     print(f'[neck] Cut height: {neck_y:.4f} (y_range {y_min:.3f}–{y_max:.3f}, frac={frac:.2f})')
+     return neck_y
+
+
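The selection rule — smooth the radius profile, then take the argmin away from the band edges — can be exercised on a synthetic profile. A sketch with made-up values and a plain 3-tap moving average standing in for `scipy.ndimage.uniform_filter1d`:

```python
import numpy as np

# Synthetic 10th-percentile-radius profile along the scan band: a Gaussian
# "neck dip" centred at index 12 (hypothetical values, not a real mesh).
ys = np.linspace(0.80, 0.87, 20)
r10 = 0.3 - 0.2 * np.exp(-0.5 * ((np.arange(20) - 12) / 2.0) ** 2)

# 3-tap moving average, then argmin excluding two samples at each edge —
# the same selection rule used above.
smooth = np.convolve(r10, np.ones(3) / 3, mode='same')
neck_idx = int(np.argmin(smooth[2:-2])) + 2
neck_y = float(ys[neck_idx])
```

Excluding the edge samples matters here: `mode='same'` zero-pads, so the two outermost smoothed values are artificially low.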
+ def _weld_mesh(mesh):
+     """
+     Merge duplicate vertices (UV-split mesh → geometric mesh).
+     Returns a new trimesh with welded vertices.
+     """
+     import trimesh
+     from scipy.spatial import cKDTree
+     verts = mesh.vertices
+     tree = cKDTree(verts)
+     # Build mapping: each vertex → canonical representative (naive union-find)
+     N = len(verts)
+     mapping = np.arange(N, dtype=np.int64)
+     pairs = tree.query_pairs(r=1e-5)
+     for a, b in pairs:
+         root_a = int(mapping[a])
+         root_b = int(mapping[b])
+         while mapping[root_a] != root_a:
+             root_a = int(mapping[root_a])
+         while mapping[root_b] != root_b:
+             root_b = int(mapping[root_b])
+         if root_a != root_b:
+             mapping[root_b] = root_a
+     # Flatten chains
+     for i in range(N):
+         root = int(mapping[i])
+         while mapping[root] != root:
+             root = int(mapping[root])
+         mapping[i] = root
+     # Compact the mapping
+     unique_ids = np.unique(mapping)
+     compact = np.full(N, -1, dtype=np.int64)
+     compact[unique_ids] = np.arange(len(unique_ids))
+     new_faces = compact[mapping[mesh.faces]]
+     new_verts = verts[unique_ids]
+     return trimesh.Trimesh(vertices=new_verts, faces=new_faces, process=False)
+
+
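The weld is a union-find over proximity pairs followed by reindexing. A self-contained toy version of the same idea, using a brute-force O(N²) pair search instead of the cKDTree (toy triangles, not real mesh data):

```python
import numpy as np

def weld(verts, faces, tol=1e-5):
    """Merge vertices closer than tol via naive union-find, reindex faces."""
    N = len(verts)
    mapping = np.arange(N)
    for a in range(N):
        for b in range(a + 1, N):
            if np.linalg.norm(verts[a] - verts[b]) < tol:
                ra, rb = mapping[a], mapping[b]
                while mapping[ra] != ra:
                    ra = mapping[ra]
                while mapping[rb] != rb:
                    rb = mapping[rb]
                if ra != rb:
                    mapping[rb] = ra   # union the two roots
    for i in range(N):                 # flatten chains
        r = mapping[i]
        while mapping[r] != r:
            r = mapping[r]
        mapping[i] = r
    uniq = np.unique(mapping)
    compact = np.full(N, -1)
    compact[uniq] = np.arange(len(uniq))
    return verts[uniq], compact[mapping[faces]]

# Two triangles sharing an edge, stored UV-split: 6 verts, 2 of them duplicated.
v = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0],
              [1, 0, 0], [0, 1, 0], [1, 1, 0]], float)
f = np.array([[0, 1, 2], [3, 5, 4]])
wv, wf = weld(v, f)
```

The welded result has 4 geometric vertices, and both faces reference the shared edge through the same indices.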
+ def _cut_mesh_below(mesh, y_cut):
+     """Keep only faces where all vertices are at or below y_cut. Preserves UV/texture."""
+     import trimesh
+     from trimesh.visual.texture import TextureVisuals
+     v_mask = mesh.vertices[:, 1] <= y_cut
+     f_keep = np.all(v_mask[mesh.faces], axis=1)
+     faces_kept = mesh.faces[f_keep]
+     used_verts = np.unique(faces_kept)
+     old_to_new = np.full(len(mesh.vertices), -1, dtype=np.int64)
+     old_to_new[used_verts] = np.arange(len(used_verts))
+     new_faces = old_to_new[faces_kept]
+     new_verts = mesh.vertices[used_verts]
+     new_mesh = trimesh.Trimesh(vertices=new_verts, faces=new_faces, process=False)
+     # Preserve UV + texture if present
+     if hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
+         new_mesh.visual = TextureVisuals(
+             uv=mesh.visual.uv[used_verts],
+             material=mesh.visual.material)
+     return new_mesh
+
+
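The face-masking and reindexing pattern used above is worth seeing on minimal data. A numpy-only sketch (toy strip of two triangles, hypothetical cut height):

```python
import numpy as np

def cut_below(verts, faces, y_cut):
    """Keep faces whose three vertices all satisfy y <= y_cut, then drop
    unused vertices and reindex the faces — the same masking as above."""
    v_mask = verts[:, 1] <= y_cut
    faces_kept = faces[np.all(v_mask[faces], axis=1)]
    used = np.unique(faces_kept)
    remap = np.full(len(verts), -1, dtype=np.int64)
    remap[used] = np.arange(len(used))
    return verts[used], remap[faces_kept]

# A strip of two triangles; only the lower one survives a cut at y = 0.6.
v = np.array([[0, 0, 0], [1, 0, 0], [0, 0.5, 0], [0, 1, 0]], float)
f = np.array([[0, 1, 2], [1, 3, 2]])
nv, nf = cut_below(v, f, 0.6)
```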
+ def _extract_neck_ring_geometric(mesh, neck_y, n_pts=64, band_frac=0.02):
+     """
+     Extract a neck ring using topological boundary edges near neck_y.
+     Falls back to angle-sorted vertices if the topology is non-manifold.
+     Works on welded (geometric) meshes.
+     """
+     verts = mesh.vertices
+     y_range = verts[:, 1].max() - verts[:, 1].min()
+     band = y_range * band_frac
+
+     # --- Try topological boundary near neck_y first ---
+     edges = np.sort(mesh.edges, axis=1)
+     u, c2 = np.unique(edges, axis=0, return_counts=True)
+     be = u[c2 == 1]  # boundary edges
+
+     # Keep boundary edges where BOTH endpoints are near neck_y
+     v_near = np.abs(verts[:, 1] - neck_y) <= band * 2
+     neck_be = be[v_near[be[:, 0]] & v_near[be[:, 1]]]
+
+     if len(neck_be) >= 8:
+         # Build adjacency and walk each loop
+         adj = {}
+         for e in neck_be:
+             adj.setdefault(int(e[0]), []).append(int(e[1]))
+             adj.setdefault(int(e[1]), []).append(int(e[0]))
+         # Find the largest connected loop
+         visited = set()
+         loops = []
+         for start in adj:
+             if start in visited:
+                 continue
+             loop = [start]
+             visited.add(start)
+             prev, cur = -1, start
+             for _ in range(len(neck_be) + 1):
+                 nbrs = [v for v in adj.get(cur, []) if v != prev]
+                 if not nbrs:
+                     break
+                 nxt = nbrs[0]
+                 if nxt == start or nxt in visited:
+                     break
+                 visited.add(nxt)
+                 prev, cur = cur, nxt
+                 loop.append(cur)
+             loops.append(loop)
+         if loops:
+             best = max(loops, key=len)
+             if len(best) >= 8:
+                 # Snap all ring points to neck_y (smooth the cut plane)
+                 ring_pts = verts[best].copy()
+                 ring_pts[:, 1] = neck_y
+                 return _resample_loop(ring_pts, n_pts)
+
+     # --- Fallback: use inner-cluster (neck column) vertices in the band ---
+     mask = (verts[:, 1] >= neck_y - band) & (verts[:, 1] <= neck_y + band)
+     pts = verts[mask]
+     if len(pts) < 8:
+         raise ValueError(f'Too few vertices near neck_y={neck_y:.4f}: {len(pts)}')
+
+     # Keep only inner-ring vertices (below the 35th-percentile radius from the
+     # centroid). This excludes the outer face/head surface and keeps only the
+     # neck column.
+     xz = pts[:, [0, 2]]
+     cx, cz = xz.mean(0)
+     radii = np.sqrt((xz[:, 0] - cx)**2 + (xz[:, 1] - cz)**2)
+     thresh = np.percentile(radii, 35)
+     inner_mask = radii <= thresh
+     if inner_mask.sum() >= 8:
+         pts = pts[inner_mask]
+         # Recompute the centroid on inner pts
+         cx, cz = pts[:, [0, 2]].mean(0)
+
+     # Sort by angle in the XZ plane
+     angles = np.arctan2(pts[:, 2] - cz, pts[:, 0] - cx)
+     pts_sorted = pts[np.argsort(angles)].copy()
+     pts_sorted[:, 1] = neck_y  # snap to cut plane
+     return _resample_loop(pts_sorted, n_pts)
+
+
+ def _extract_boundary_loop(mesh):
+     """
+     Extract the boundary edge loop (ordered) from a welded mesh.
+     Returns (N, 3) ordered vertex positions.
+     """
+     # Find boundary edges (edges used by exactly one face)
+     edges = np.sort(mesh.edges, axis=1)
+     unique, counts = np.unique(edges, axis=0, return_counts=True)
+     boundary_edges = unique[counts == 1]
+
+     if len(boundary_edges) == 0:
+         raise ValueError('No boundary edges found — mesh may be closed')
+
+     # Build adjacency for boundary edges
+     adj = {}
+     for e in boundary_edges:
+         adj.setdefault(int(e[0]), []).append(int(e[1]))
+         adj.setdefault(int(e[1]), []).append(int(e[0]))
+
+     # Walk every connected loop, then keep the longest
+     visited = set()
+     loops = []
+     for start_v in adj:
+         if start_v in visited:
+             continue
+         loop = [start_v]
+         visited.add(start_v)
+         prev = -1
+         cur = start_v
+         for _ in range(len(boundary_edges) + 1):
+             nbrs = [v for v in adj.get(cur, []) if v != prev]
+             if not nbrs:
+                 break
+             nxt = nbrs[0]
+             if nxt == start_v:
+                 break
+             if nxt in visited:
+                 break
+             visited.add(nxt)
+             prev = cur
+             cur = nxt
+             loop.append(cur)
+         loops.append(loop)
+
+     # Use the longest loop
+     best = max(loops, key=len)
+     return mesh.vertices[best]
+
+
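The "edge used by exactly one face" test that drives both ring extractors is easy to verify on a toy mesh. A numpy-only sketch (two triangles sharing one edge):

```python
import numpy as np

def boundary_edges(faces):
    """Edges referenced by exactly one face — the openness test used above."""
    e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
    e = np.sort(e, axis=1)
    uniq, counts = np.unique(e, axis=0, return_counts=True)
    return uniq[counts == 1]

# Two triangles sharing edge (1, 2): the shared edge is interior,
# the four outer edges form the boundary.
f = np.array([[0, 1, 2], [1, 3, 2]])
be = boundary_edges(f)
```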
+ def _resample_loop(loop_pts, N):
+     """Resample an ordered set of 3D points to exactly N evenly-spaced points."""
+     from scipy.interpolate import interp1d
+     # Arc-length parameterisation over the closed loop: t[i] is the cumulative
+     # length up to loop_closed[i], so parameter and point arrays stay aligned.
+     loop_closed = np.vstack([loop_pts, loop_pts[0]])
+     seg_lens = np.linalg.norm(np.diff(loop_closed, axis=0), axis=1)
+     t = np.insert(np.cumsum(seg_lens), 0, 0.0)
+     t /= t[-1]
+     interp = interp1d(t, loop_closed, axis=0)
+     t_new = np.linspace(0, 1, N, endpoint=False)
+     return interp(t_new)
+
+
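Arc-length resampling of a closed loop can be sketched without scipy, using per-axis `np.interp` in place of `interp1d` (toy square loop, illustrative only):

```python
import numpy as np

def resample_loop(pts, n):
    """Resample a closed polyline to n points evenly spaced by arc length."""
    closed = np.vstack([pts, pts[0]])
    seg = np.linalg.norm(np.diff(closed, axis=0), axis=1)
    t = np.concatenate([[0.0], np.cumsum(seg)])
    t /= t[-1]
    t_new = np.linspace(0.0, 1.0, n, endpoint=False)
    return np.column_stack([np.interp(t_new, t, closed[:, k])
                            for k in range(pts.shape[1])])

# Resample a unit square (4 corners) to 8 points: corners plus edge midpoints.
square = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], float)
loop8 = resample_loop(square, 8)
```

Point 1 of the output lands at the midpoint of the first edge, confirming the spacing is uniform in arc length rather than in index.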
+ def _bridge_loops(loop_a, loop_b):
+     """
+     Create a triangle strip bridging two ordered loops of equal length N.
+     loop_a, loop_b: (N, 3) vertex positions
+     Returns (verts, faces) — just the bridge strip as a trimesh-ready array.
+     """
+     N = len(loop_a)
+     verts = np.vstack([loop_a, loop_b])  # (2N, 3) — a: 0..N-1, b: N..2N-1
+     faces = []
+     for i in range(N):
+         j = (i + 1) % N
+         ai, aj = i, j
+         bi, bj = i + N, j + N
+         faces.append([ai, aj, bi])
+         faces.append([aj, bj, bi])
+     return verts, np.array(faces, dtype=np.int32)
+
+
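Each quad between corresponding ring segments is split into two triangles, so bridging two N-point rings always yields 2N vertices and 2N faces. A self-contained check (toy 4-point rings):

```python
import numpy as np

def bridge_loops(loop_a, loop_b):
    """Triangle strip between two equal-length rings — same quad split as above."""
    n = len(loop_a)
    verts = np.vstack([loop_a, loop_b])
    faces = []
    for i in range(n):
        j = (i + 1) % n
        faces.append([i, j, i + n])      # lower triangle of the quad
        faces.append([j, j + n, i + n])  # upper triangle
    return verts, np.array(faces, dtype=np.int32)

# Bridging two 4-point rings yields 8 vertices and 8 triangles (2 per quad).
ring = np.array([[1, 0, 0], [0, 0, 1], [-1, 0, 0], [0, 0, -1]], float)
bv, bf = bridge_loops(ring, ring + [0, 1, 0])
```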
+ # ──────────────────────────────────────────────────────────────────
+ # DECA head → trimesh
+ # ──────────────────────────────────────────────────────────────────
+ def deca_to_trimesh(verts_np, faces_v, uv_verts, faces_uv, tex_bgr):
+     """
+     Assemble a trimesh.Trimesh from DECA outputs with UV texture.
+     Uses per-vertex UV (averaged over face corners sharing each vertex).
+     """
+     import trimesh
+     from trimesh.visual.texture import TextureVisuals
+     from trimesh.visual.material import PBRMaterial
+
+     # Average face-corner UVs per vertex
+     N = len(verts_np)
+     uv_sum = np.zeros((N, 2), dtype=np.float64)
+     uv_cnt = np.zeros(N, dtype=np.int32)
+     for fi in range(len(faces_v)):
+         for ci in range(3):
+             vi = faces_v[fi, ci]
+             uvi = faces_uv[fi, ci]
+             uv_sum[vi] += uv_verts[uvi]
+             uv_cnt[vi] += 1
+     uv_cnt = np.maximum(uv_cnt, 1)
+     uv_per_vert = (uv_sum / uv_cnt[:, None]).astype(np.float32)
+
+     mesh = trimesh.Trimesh(vertices=verts_np, faces=faces_v, process=False)
+
+     tex_rgb = cv2.cvtColor(tex_bgr, cv2.COLOR_BGR2RGB)
+     tex_pil = Image.fromarray(tex_rgb)
+
+     try:
+         mat = PBRMaterial(baseColorTexture=tex_pil)
+         mesh.visual = TextureVisuals(uv=uv_per_vert, material=mat)
+         print(f'[deca_to_trimesh] UV attached: {uv_per_vert.shape}, tex={tex_rgb.shape}')
+     except Exception as e:
+         print(f'[deca_to_trimesh] UV attach failed ({e}) — using vertex colours')
+         mesh.visual.vertex_colors = np.tile([200, 175, 155, 255], (len(verts_np), 1))
+
+     return mesh
+
+
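The corner-averaging double loop above can be expressed with `np.add.at`, which scatter-adds repeated indices correctly; a vectorised sketch on toy data (two faces, four vertices — not FLAME's real tables):

```python
import numpy as np

# Scatter-add each face corner's UV into its vertex slot, then divide by
# the per-vertex corner count — same result as the explicit loop.
faces_v = np.array([[0, 1, 2], [0, 2, 3]])   # two faces over 4 vertices
faces_uv = np.array([[0, 1, 2], [0, 2, 3]])  # per-corner UV indices
uv_verts = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], float)

uv_sum = np.zeros((4, 2))
uv_cnt = np.zeros(4)
np.add.at(uv_sum, faces_v.ravel(), uv_verts[faces_uv.ravel()])
np.add.at(uv_cnt, faces_v.ravel(), 1)
uv_per_vert = uv_sum / np.maximum(uv_cnt, 1)[:, None]
```

In this toy layout each vertex always sees the same UV index, so the averaged result equals `uv_verts` exactly; on a UV-split mesh the average blends the per-corner values.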
+ # ──────────────────────────────────────────────────────────────────
+ # Main head-replacement function
+ # ──────────────────────────────────────────────────────────────────
+ def replace_head(body_glb: str, face_img_path: str, out_glb: str,
+                  device: str = 'cuda', bridge_n: int = 64):
+     """
+     Main entry point.
+     body_glb      : path to TripoSG textured GLB
+     face_img_path : path to reference face image
+     out_glb       : output path for combined GLB
+     bridge_n      : number of vertices in the stitching ring
+     """
+     import trimesh
+
+     # ── 1. Load body GLB ──────────────────────────────────────────
+     print('[replace_head] Loading body GLB...')
+     scene = trimesh.load(body_glb)
+     if isinstance(scene, trimesh.Scene):
+         body_mesh = trimesh.util.concatenate(
+             [g for g in scene.geometry.values() if isinstance(g, trimesh.Trimesh)]
+         )
+     else:
+         body_mesh = scene
+
+     print(f'  Body: {len(body_mesh.vertices)} verts, {len(body_mesh.faces)} faces')
+
+     # ── 1b. Weld body mesh (UV-split → geometric) ─────────────────
+     print('[replace_head] Welding mesh vertices...')
+     body_welded = _weld_mesh(body_mesh)
+     print(f'  Welded: {len(body_welded.vertices)} verts (was {len(body_mesh.vertices)})')
+
+     # ── 2. Find neck cut height ───────────────────────────────────
+     neck_y = _find_neck_height(body_welded)
+
+     # ── 3. Cut body at neck ───────────────────────────────────────
+     print('[replace_head] Cutting body at neck...')
+     # Work on the welded mesh for topology; keep the original mesh for geometry export
+     body_lower_welded = _cut_mesh_below(body_welded, neck_y)
+     body_lower = _cut_mesh_below(body_mesh, neck_y)  # keeps original UV/texture
+     print(f'  Body lower: {len(body_lower.vertices)} verts')
+
+     # Extract the neck ring geometrically (robust for non-manifold UV-split meshes)
+     body_neck_loop = _extract_neck_ring_geometric(body_welded, neck_y, n_pts=bridge_n)
+     print(f'  Body neck ring: {len(body_neck_loop)} pts (geometric)')
+
+     # ── 4. Run DECA ───────────────────────────────────────────────
+     print('[replace_head] Running DECA...')
+     verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_bgr = run_deca(face_img_path, device=device)
+
+     # ── 5. Align DECA head to body coordinate system ─────────────
+     # TripoSG body is roughly in [-1,1] world space (y-up).
+     # DECA/FLAME space: head centered around origin, scale ≈ 1.5-2.5 units for a full head.
+     # We need to:
+     #   a) scale the FLAME head to match body scale,
+     #   b) position the FLAME head so its neck base aligns with the body neck ring.
+
+     # Get the bottom of the FLAME head (neck area).
+     # FLAME template: the bottom vertices are the neck boundary ring.
+     flame_mesh_tmp = trimesh.Trimesh(vertices=verts_np, faces=faces_v, process=False)
+     try:
+         flame_neck_loop = _extract_boundary_loop(flame_mesh_tmp)
+         print(f'  FLAME neck ring (topology): {len(flame_neck_loop)} verts')
+     except Exception as e:
+         print(f'  FLAME boundary loop failed ({e}), using geometric extraction')
+         # Geometric fallback: bottom 8% of head vertices
+         flame_neck_y = verts_np[:, 1].min() + (verts_np[:, 1].max() - verts_np[:, 1].min()) * 0.08
+         flame_neck_loop = _extract_neck_ring_geometric(flame_mesh_tmp, flame_neck_y, n_pts=bridge_n)
+         print(f'  FLAME neck ring (geometric): {len(flame_neck_loop)} pts')
+
+     # ── 5b. Compute head position using the NECK RING centroid ───────────
+     # Directly align the FLAME neck ring center → body neck ring center in all
+     # 3 axes. This is robust regardless of body pose or tilt.
+     body_neck_center = body_neck_loop.mean(axis=0)
+
+     # Estimate head height from the WELDED mesh crown (more reliable than the UV-split mesh)
+     welded_y_max = float(body_welded.vertices[:, 1].max())
+     body_head_height = welded_y_max - neck_y
+
+     flame_neck_center_unscaled = flame_neck_loop.mean(axis=0)
+     flame_y_min = verts_np[:, 1].min()
+     flame_y_max = verts_np[:, 1].max()
+     flame_head_height = flame_y_max - flame_y_min
+
+     print(f'  Body neck center: {body_neck_center.round(4)}')
+     print(f'  Body head space: {body_head_height:.4f} (neck_y={neck_y:.4f}, crown_y={welded_y_max:.4f})')
+     print(f'  FLAME head height (unscaled): {flame_head_height:.4f}')
+     print(f'  FLAME neck center (unscaled): {flame_neck_center_unscaled.round(4)}')
+
+     # Scale the FLAME head to match the body head height
+     if flame_head_height > 1e-5:
+         head_scale = body_head_height / flame_head_height
+     else:
+         head_scale = 1.0
+     print(f'  Head scale: {head_scale:.4f}')
+
+     # Translate: FLAME neck ring center → body neck ring center in XZ,
+     # FLAME mesh bottom (flame_y_min) → neck_y in Y.
+     # This ensures the head fills the full space from neck_y to the body crown.
+     translate = np.array([
+         body_neck_center[0] - flame_neck_center_unscaled[0] * head_scale,
+         neck_y - flame_y_min * head_scale,
+         body_neck_center[2] - flame_neck_center_unscaled[2] * head_scale,
+     ])
+     print(f'  Translate: {translate.round(4)}')
+     verts_aligned = verts_np * head_scale + translate
+     print(f'  FLAME aligned y={verts_aligned[:,1].min():.4f}→{verts_aligned[:,1].max():.4f}'
+           f' x={verts_aligned[:,0].min():.4f}→{verts_aligned[:,0].max():.4f}'
+           f' z={verts_aligned[:,2].min():.4f}→{verts_aligned[:,2].max():.4f}')
+
+     # Extract the FLAME neck loop after alignment (at the cut plane y=neck_y)
+     flame_verts_aligned = verts_aligned
+     flame_mesh_aligned = trimesh.Trimesh(
+         vertices=flame_verts_aligned, faces=faces_v, process=False)
+     try:
+         flame_neck_loop_aligned = _extract_boundary_loop(flame_mesh_aligned)
+         print(f'  FLAME neck ring (topology): {len(flame_neck_loop_aligned)} verts')
+     except Exception:
+         flame_neck_y_aligned = flame_verts_aligned[:, 1].min() + (
+             flame_verts_aligned[:, 1].max() - flame_verts_aligned[:, 1].min()) * 0.05
+         flame_neck_loop_aligned = _extract_neck_ring_geometric(
+             flame_mesh_aligned, flame_neck_y_aligned, n_pts=bridge_n)
+         print(f'  FLAME neck ring (geometric): {len(flame_neck_loop_aligned)} pts')
+
+     flame_neck_r = np.linalg.norm(flame_neck_loop_aligned - flame_neck_loop_aligned.mean(0), axis=1).mean()
+     body_neck_r = np.linalg.norm(body_neck_loop - body_neck_loop.mean(0), axis=1).mean()
+     print(f'  Body neck radius: {body_neck_r:.4f}  FLAME neck radius (scaled): {flame_neck_r:.4f}')
+
+     # ── 6. Resample both neck loops to bridge_n points ────────────
+     body_loop_r = _resample_loop(body_neck_loop, bridge_n)
+     flame_loop_r = _resample_loop(flame_neck_loop_aligned, bridge_n)
+
+     # Ensure the loops are oriented consistently (both CW or both CCW):
+     # compute the signed area to check orientation.
+     def _loop_orientation(loop):
+         c = loop.mean(0)
+         t = loop - c
+         cross = np.cross(t[:-1], t[1:])
+         return float(np.sum(cross[:, 1]))  # y-component
+
+     o_body = _loop_orientation(body_loop_r)
+     o_flame = _loop_orientation(flame_loop_r)
+     if (o_body > 0) != (o_flame > 0):
+         flame_loop_r = flame_loop_r[::-1]
+
+     # ── 7. Align loop starting points (minimise bridge twist) ─────
+     # Match starting vertices: find the flame-loop point closest to the body-loop start
+     dists = np.linalg.norm(flame_loop_r - body_loop_r[0], axis=1)
+     best_offset = int(np.argmin(dists))
+     flame_loop_r = np.roll(flame_loop_r, -best_offset, axis=0)
+
+     # ── 8. Build bridge strip ─────────────────────────────────────
+     bridge_verts, bridge_faces = _bridge_loops(body_loop_r, flame_loop_r)
+     bridge_mesh = trimesh.Trimesh(vertices=bridge_verts, faces=bridge_faces, process=False)
+
+     # ── 9. Combine: body_lower + bridge + FLAME head ──────────────
+     # Build the FLAME head mesh with texture
+     head_mesh = deca_to_trimesh(flame_verts_aligned, faces_v, uv_verts, faces_uv, tex_bgr)
+
+     # Combine all parts into a single-mesh fallback; per-mesh materials are
+     # kept by exporting as a scene below.
+     combined = trimesh.util.concatenate([body_lower, bridge_mesh, head_mesh])
+     combined = trimesh.Trimesh(
+         vertices=combined.vertices,
+         faces=combined.faces,
+         process=False
+     )
+
+     # ── 10. Export ────────────────────────────────────────────────
+     print(f'[replace_head] Exporting combined mesh: {len(combined.vertices)} verts...')
+     os.makedirs(os.path.dirname(out_glb) or '.', exist_ok=True)
+
+     # Export as a GLB scene with separate submeshes (preserves textures)
+     try:
+         scene_out = trimesh.Scene()
+         scene_out.add_geometry(body_lower, geom_name='body')
+         scene_out.add_geometry(bridge_mesh, geom_name='bridge')
+         scene_out.add_geometry(head_mesh, geom_name='head')
+         scene_out.export(out_glb)
+         print(f'[replace_head] Saved scene GLB: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
+     except Exception as e:
+         print(f'[replace_head] Scene export failed ({e}), trying single mesh...')
+         combined.export(out_glb)
+         print(f'[replace_head] Saved GLB: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
+
+     return out_glb
+
+
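The orientation check used before bridging — the sign of the summed y-component of consecutive edge cross products — can be verified on a toy ring. A self-contained sketch (a square in the XZ plane, traversed both ways):

```python
import numpy as np

def loop_orientation_y(loop):
    """Signed y-component sum of consecutive edge cross products: the
    CW/CCW test used to make the two neck rings wind the same way."""
    t = loop - loop.mean(axis=0)
    return float(np.sum(np.cross(t[:-1], t[1:])[:, 1]))

# A square in the XZ plane: reversing the traversal flips the sign.
sq = np.array([[1, 0, 1], [-1, 0, 1], [-1, 0, -1], [1, 0, -1]], float)
o_fwd = loop_orientation_y(sq)
o_rev = loop_orientation_y(sq[::-1])
```

When the two signs differ, one loop is simply reversed — exactly the `flame_loop_r[::-1]` step above.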
+ # ──────────────────────────────────────────────────────────────────
+ # CLI
+ # ──────────────────────────────────────────────────────────────────
+ if __name__ == '__main__':
+     ap = argparse.ArgumentParser()
+     ap.add_argument('--body', required=True, help='TripoSG body GLB path')
+     ap.add_argument('--face', required=True, help='Reference face image path')
+     ap.add_argument('--out', required=True, help='Output GLB path')
+     ap.add_argument('--bridge', type=int, default=64, help='Bridge ring vertex count')
+     ap.add_argument('--cpu', action='store_true', help='Use CPU instead of CUDA')
+     args = ap.parse_args()
+
+     import torch
+     device = 'cpu' if args.cpu else ('cuda' if torch.cuda.is_available() else 'cpu')
+     replace_head(args.body, args.face, args.out, device=device, bridge_n=args.bridge)
pipeline/pshuman_client.py ADDED
@@ -0,0 +1,283 @@
+ """
+ pshuman_client.py
+ =================
+ Call PSHuman to generate a high-detail 3D face mesh from a portrait image.
+
+ Two modes:
+ - Direct (default when service_url is localhost): runs PSHuman inference.py
+   as a subprocess without going through Gradio HTTP. Avoids the gradio_client
+   API-info bug that affects the pshuman Gradio env.
+ - Remote: uses gradio_client to call a running pshuman_app.py service.
+
+ Usage (standalone)
+ ------------------
+ python -m pipeline.pshuman_client \\
+     --image /path/to/portrait.png \\
+     --output /tmp/pshuman_face.obj \\
+     [--url http://remote-host:7862]   # omit for direct/local mode
+
+ Requires: gradio-client (remote mode only)
+ """
+ from __future__ import annotations
+
+ import argparse
+ import glob
+ import os
+ import shutil
+ import subprocess
+ import time
+ from pathlib import Path
+
+ # Default: assume running on the same instance (local)
+ _DEFAULT_URL = os.environ.get("PSHUMAN_URL", "http://localhost:7862")
+
+ # ── Paths (on the Vast instance) ──────────────────────────────────────────────
+ PSHUMAN_DIR = "/root/PSHuman"
+ CONDA_PYTHON = "/root/miniconda/envs/pshuman/bin/python"
+ CONFIG = f"{PSHUMAN_DIR}/configs/inference-768-6view.yaml"
+ HF_MODEL_DIR = f"{PSHUMAN_DIR}/checkpoints/PSHuman_Unclip_768_6views"
+ HF_MODEL_HUB = "pengHTYX/PSHuman_Unclip_768_6views"
+
+
+ def _run_pshuman_direct(image_path: str, work_dir: str) -> str:
+     """
+     Run PSHuman inference.py directly as a subprocess.
+     Returns the path to the colored OBJ mesh.
+     """
+     img_dir = os.path.join(work_dir, "input")
+     out_dir = os.path.join(work_dir, "out")
+     os.makedirs(img_dir, exist_ok=True)
+     os.makedirs(out_dir, exist_ok=True)
+
+     scene = "face"
+     dst = os.path.join(img_dir, f"{scene}.png")
+     shutil.copy(image_path, dst)
+
+     hf_model = HF_MODEL_DIR if Path(HF_MODEL_DIR).exists() else HF_MODEL_HUB
+
+     cmd = [
+         CONDA_PYTHON, f"{PSHUMAN_DIR}/inference.py",
+         "--config", CONFIG,
+         f"pretrained_model_name_or_path={hf_model}",
+         f"validation_dataset.root_dir={img_dir}",
+         f"save_dir={out_dir}",
+         "validation_dataset.crop_size=740",
+         "with_smpl=false",
+         "num_views=7",
+         "save_mode=rgb",
+         "seed=42",
+     ]
+
+     print(f"[pshuman] Running direct inference: {' '.join(cmd[:4])} ...")
+     t0 = time.time()
+
+     # Set CUDA_HOME + extra include dirs so the nvdiffrast/torch JIT can compile.
+     # On Vast.ai, the triposg conda env ships nvcc at bin/nvcc and CUDA headers
+     # scattered across site-packages/nvidia/{pkg}/include/ directories.
+     env = os.environ.copy()
+     if "CUDA_HOME" not in env:
+         _triposg = "/root/miniconda/envs/triposg"
+         _targets = os.path.join(_triposg, "targets", "x86_64-linux")
+         _nvcc_bin = os.path.join(_triposg, "bin")
+         _cuda_home = _targets  # has include/cuda_runtime_api.h
+
+         _nvvm_bin = os.path.join(_triposg, "nvvm", "bin")  # contains cicc
+         _nvcc_real = os.path.join(_targets, "bin")         # contains the real nvcc
+
+         if (os.path.exists(os.path.join(_cuda_home, "include", "cuda_runtime_api.h"))
+                 and (os.path.exists(os.path.join(_nvcc_bin, "nvcc"))
+                      or os.path.exists(os.path.join(_nvcc_real, "nvcc")))):
+             env["CUDA_HOME"] = _cuda_home
+             # Build PATH: nvvm/bin (cicc) + targets/.../bin (real nvcc) + conda bin (nvcc wrapper)
+             path_parts = []
+             if os.path.isdir(_nvvm_bin):
+                 path_parts.append(_nvvm_bin)
+             if os.path.isdir(_nvcc_real):
+                 path_parts.append(_nvcc_real)
+             path_parts.append(_nvcc_bin)
+             env["PATH"] = ":".join(path_parts) + ":" + env.get("PATH", "")
+
+             # Collect all nvidia sub-package include dirs (cusparse, cublas, etc.)
+             _nvidia_site = os.path.join(_triposg, "lib", "python3.10",
+                                         "site-packages", "nvidia")
+             _extra_incs = []
+             if os.path.isdir(_nvidia_site):
+                 for _inc in glob.glob(os.path.join(_nvidia_site, "*/include")):
+                     if os.path.isdir(_inc):
+                         _extra_incs.append(_inc)
+             if _extra_incs:
+                 _sep = ":"
+                 _existing = env.get("CPATH", "")
+                 env["CPATH"] = _sep.join(_extra_incs) + (_sep + _existing if _existing else "")
+             print(f"[pshuman] CUDA_HOME={_cuda_home}, {len(_extra_incs)} nvidia include dirs added")
+
+     proc = subprocess.run(
+         cmd, cwd=PSHUMAN_DIR,
+         capture_output=False,
+         text=True,
+         timeout=600,
+         env=env,
+     )
+     elapsed = time.time() - t0
+     print(f"[pshuman] Inference done in {elapsed:.1f}s (exit={proc.returncode})")
+
+     if proc.returncode != 0:
+         raise RuntimeError(f"PSHuman inference failed (exit {proc.returncode})")
+
+     # Locate the output OBJ — PSHuman may save relative to its CWD (/root/PSHuman/out/)
+     # rather than to the specified save_dir, so check both locations.
+     cwd_out_dir = os.path.join(PSHUMAN_DIR, "out", scene)
+     patterns = [
+         f"{out_dir}/{scene}/result_clr_scale4_{scene}.obj",
+         f"{out_dir}/{scene}/result_clr_scale*_{scene}.obj",
+         f"{out_dir}/**/*.obj",
+         f"{cwd_out_dir}/result_clr_scale*_{scene}.obj",
+         f"{cwd_out_dir}/*.obj",
+         f"{PSHUMAN_DIR}/out/**/*.obj",
+     ]
+     obj_path = None
+     for pat in patterns:
+         hits = sorted(glob.glob(pat, recursive=True))
+         if hits:
+             colored = [h for h in hits if "clr" in h]
+             obj_path = (colored or hits)[-1]
+             break
+
+     if not obj_path:
+         all_files = list(Path(out_dir).rglob("*"))
+         objs = [str(f) for f in all_files if f.suffix in (".obj", ".ply", ".glb")]
+         if objs:
+             obj_path = objs[-1]
+         if not obj_path and Path(cwd_out_dir).exists():
+             for f in Path(cwd_out_dir).rglob("*.obj"):
+                 obj_path = str(f)
+                 break
+         if not obj_path:
+             raise FileNotFoundError(
+                 f"No mesh output found in {out_dir}. "
+                 f"Files: {[str(f) for f in all_files[:20]]}"
+             )
+
+     print(f"[pshuman] Output mesh: {obj_path}")
+     return obj_path
+
+
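The output-discovery step above tries glob patterns in priority order and prefers colored (`clr`) meshes over raw ones. A minimal reproduction with stand-in files in a temp directory (filenames here mimic PSHuman's, for illustration only):

```python
import glob
import os
import tempfile

# Create two candidate outputs; the colored one should win.
work = tempfile.mkdtemp()
open(os.path.join(work, "result_raw_face.obj"), "w").close()
open(os.path.join(work, "result_clr_scale4_face.obj"), "w").close()

hits = sorted(glob.glob(os.path.join(work, "*.obj")))
colored = [h for h in hits if "clr" in h]
obj_path = (colored or hits)[-1]
```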
+ def generate_pshuman_mesh(
+     image_path: str,
+     output_path: str,
+     service_url: str = _DEFAULT_URL,
+     timeout: float = 600.0,
+ ) -> str:
+     """
+     Generate a PSHuman face mesh and save it to *output_path*.
+
+     When service_url points to localhost, PSHuman inference.py is run directly
+     (no Gradio HTTP, avoids the gradio_client API-info bug).
+     For remote URLs, gradio_client is used.
+
+     Parameters
+     ----------
+     image_path  : local PNG/JPG path of the portrait
+     output_path : where to save the downloaded OBJ
+     service_url : base URL of pshuman_app.py, or "direct" to skip HTTP
+     timeout     : seconds to wait for inference (used in remote mode)
+
+     Returns
+     -------
+     output_path (convenience)
+     """
+     import tempfile
+
+     output_path = str(output_path)
+     os.makedirs(Path(output_path).parent, exist_ok=True)
+
+     is_local = (
+         "localhost" in service_url
+         or "127.0.0.1" in service_url
+         or service_url.strip().lower() == "direct"
+         or not service_url.strip()
+     )
+
+     if is_local:
+         # ── Direct mode: run subprocess ───────────────────────────────────────
+         print(f"[pshuman] Direct mode (no HTTP) — running inference on {image_path}")
+         work_dir = tempfile.mkdtemp(prefix="pshuman_direct_")
+         obj_tmp = _run_pshuman_direct(image_path, work_dir)
+     else:
+         # ── Remote mode: call Gradio service ──────────────────────────────────
+         try:
+             from gradio_client import Client
+         except ImportError:
+             raise ImportError("pip install gradio-client")
+
+         print(f"[pshuman] Connecting to {service_url}")
+         client = Client(service_url)
+
+         print(f"[pshuman] Submitting: {image_path}")
+         result = client.predict(
+             image=image_path,
+             api_name="/gradio_generate_face",
+         )
+
+         if isinstance(result, (list, tuple)):
+             obj_tmp = result[0]
+             status = result[1] if len(result) > 1 else "ok"
+         elif isinstance(result, dict):
+             obj_tmp = result.get("obj_path") or result.get("value")
+             status = result.get("status", "ok")
+         else:
+             obj_tmp = result
+             status = "ok"
+
+         if not obj_tmp or "Error" in str(status):
+             raise RuntimeError(f"PSHuman service error: {status}")
+
+         if isinstance(obj_tmp, dict):
+             obj_tmp = obj_tmp.get("path") or obj_tmp.get("name") or str(obj_tmp)
+
+         work_dir = str(Path(str(obj_tmp)).parent)
+
+     # ── Copy OBJ + companions to output location ───────────────────────────
+     shutil.copy(str(obj_tmp), output_path)
+     print(f"[pshuman] Saved OBJ -> {output_path}")
+
+     src_dir = Path(str(obj_tmp)).parent
+     out_dir = Path(output_path).parent
+     for ext in ("*.mtl", "*.png", "*.jpg"):
248
+ for f in src_dir.glob(ext):
249
+ dest = out_dir / f.name
250
+ if not dest.exists():
251
+ shutil.copy(str(f), str(dest))
252
+
253
+ return output_path
254
+
255
+
256
+ # ---------------------------------------------------------------------------
257
+ # CLI
258
+ # ---------------------------------------------------------------------------
259
+
260
+ def main():
261
+ parser = argparse.ArgumentParser(
262
+ description="Generate PSHuman face mesh from portrait image"
263
+ )
264
+ parser.add_argument("--image", required=True, help="Portrait image path")
265
+ parser.add_argument("--output", required=True, help="Output OBJ path")
266
+ parser.add_argument(
267
+ "--url", default=_DEFAULT_URL,
268
+ help="PSHuman service URL, or 'direct' to run inference locally "
269
+ "(default: http://localhost:7862 → auto-selects direct mode)",
270
+ )
271
+ parser.add_argument("--timeout", type=float, default=600.0)
272
+ args = parser.parse_args()
273
+
274
+ generate_pshuman_mesh(
275
+         image_path=args.image,
276
+         output_path=args.output,
277
+         service_url=args.url,
278
+         timeout=args.timeout,
279
+ )
280
+
281
+
282
+ if __name__ == "__main__":
283
+ main()
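The fallback search in `_find_output_mesh` above boils down to: try glob patterns from most to least specific, and within a hit list prefer colored (`clr`) meshes. A standalone sketch of that selection logic (`find_mesh` is a hypothetical helper name, for illustration only, exercised against a throwaway directory):

```python
import glob
import os
import tempfile

def find_mesh(patterns):
    """First pattern with hits wins; within the hits, prefer files whose
    name contains 'clr' (PSHuman's colored meshes), taking the last sorted."""
    for pat in patterns:
        hits = sorted(glob.glob(pat, recursive=True))
        if hits:
            colored = [h for h in hits if "clr" in os.path.basename(h)]
            return (colored or hits)[-1]
    return None

# Demo against a throwaway directory
root = tempfile.mkdtemp(prefix="mesh_demo_")
for name in ("result_raw_scene.obj", "result_clr_scale4_scene.obj"):
    open(os.path.join(root, name), "w").close()

best = find_mesh([f"{root}/missing_*.obj", f"{root}/*.obj"])
print(os.path.basename(best))  # → result_clr_scale4_scene.obj
```

Returning the *last* sorted colored hit matches the module's behavior of picking the highest `scale` suffix when several resolutions are present.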
pipeline/render_glb.py ADDED
@@ -0,0 +1,25 @@
1
+ """render_glb.py: render a GLB to a 4-view grid (front/right/back/left) using MV-Adapter's nvdiffrast utilities."""
+ import sys, cv2
2
+ sys.path.insert(0, '/root/MV-Adapter')
3
+ import numpy as np, torch
4
+ from mvadapter.utils.mesh_utils import NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera, render
5
+
6
+ glb = sys.argv[1]
7
+ out = sys.argv[2]
8
+ device = 'cuda'
9
+ ctx = NVDiffRastContextWrapper(device=device, context_type='cuda')
10
+ mesh = load_mesh(glb, rescale=True, device=device)
11
+
12
+ views = [('front', -90), ('right', -180), ('back', -270), ('left', 0)]  # MV-Adapter azimuth convention: front = -90
13
+ imgs = []
14
+ for name, az in views:
15
+ cam = get_orthogonal_camera(elevation_deg=[0], distance=[1.8],
16
+ left=-0.55, right=0.55, bottom=-0.55, top=0.55,
17
+ azimuth_deg=[az], device=device)
18
+ r = render(ctx, mesh, cam, height=512, width=384,
19
+ render_attr=True, render_depth=False, render_normal=False, attr_background=0.15)
20
+ img = (r.attr[0].cpu().numpy()*255).clip(0,255).astype('uint8')
21
+ imgs.append(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
22
+
23
+ grid = np.concatenate(imgs, axis=1)
24
+ cv2.imwrite(out, grid)
25
+ print(f'Saved {grid.shape[1]}x{grid.shape[0]} grid to {out}')
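The grid assembly at the end of the script is plain horizontal concatenation of same-height renders. A GPU-free sketch with dummy images standing in for the nvdiffrast output:

```python
import numpy as np

H, W = 512, 384
views = ["front", "right", "back", "left"]

# Dummy renders: one flat gray level per view in place of GPU output
imgs = [np.full((H, W, 3), 40 * i, dtype=np.uint8) for i in range(len(views))]

grid = np.concatenate(imgs, axis=1)  # side by side; heights must match
print(grid.shape)  # → (512, 1536, 3)
```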
pipeline/tpose.py ADDED
@@ -0,0 +1,332 @@
1
+ """
2
+ tpose.py — T-pose a humanoid GLB using YOLO pose estimation.
3
+
4
+ Pipeline:
5
+ 1. Render the mesh from front view (azimuth=-90)
6
+ 2. Run YOLOv8-pose to get 17 COCO keypoints in render-space
7
+ 3. Unproject keypoints through the orthographic camera to 3D
8
+ 4. Build Blender armature with bones at detected 3D joint positions (current pose)
9
+ 5. Auto-weight skin the mesh to this armature
10
+ 6. Rotate arm/leg bones to T-pose, apply deformation, export
11
+
12
+ Usage:
13
+ blender --background --python tpose.py -- <input.glb> <output.glb>
14
+ """
15
+
16
+ import bpy, sys, math, mathutils, os, tempfile
17
+ import numpy as np
18
+
19
+ # ── Args ─────────────────────────────────────────────────────────────────────
20
+ argv = sys.argv
21
+ argv = argv[argv.index("--") + 1:] if "--" in argv else []
22
+ if len(argv) < 2:
23
+ print("Usage: blender --background --python tpose.py -- input.glb output.glb")
24
+ sys.exit(1)
25
+ input_glb = argv[0]
26
+ output_glb = argv[1]
27
+
28
+ # ── Step 1: Render front view using nvdiffrast (outside Blender) ───────────────
29
+ # We do this via a subprocess call before Blender scene setup,
30
+ # using the triposg Python env which has MV-Adapter + nvdiffrast.
31
+ import subprocess, json
32
+
33
+ TRIPOSG_PYTHON = '/root/miniconda/envs/triposg/bin/python'
34
+ RENDER_SCRIPT = '/tmp/_tpose_render.py'
35
+ RENDER_OUT = '/tmp/_tpose_front.png'
36
+ KP_OUT = '/tmp/_tpose_kp.json'
37
+
38
+ render_code = r"""
39
+ import sys, json
40
+ sys.path.insert(0, '/root/MV-Adapter')
41
+ import numpy as np, cv2, torch
42
+ from mvadapter.utils.mesh_utils import (
43
+ NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera, render,
44
+ )
45
+
46
+ body_glb = sys.argv[1]
47
+ out_png = sys.argv[2]
48
+
49
+ device = 'cuda'
50
+ ctx = NVDiffRastContextWrapper(device=device, context_type='cuda')
51
+ mesh_mv = load_mesh(body_glb, rescale=True, device=device)
52
+ camera = get_orthogonal_camera(
53
+ elevation_deg=[0], distance=[1.8],
54
+ left=-0.55, right=0.55, bottom=-0.55, top=0.55,
55
+ azimuth_deg=[-90], device=device,
56
+ )
57
+ out = render(ctx, mesh_mv, camera, height=1024, width=768,
58
+ render_attr=True, render_depth=False, render_normal=False,
59
+ attr_background=0.5)
60
+ img_np = (out.attr[0].cpu().numpy() * 255).clip(0,255).astype('uint8')
61
+ cv2.imwrite(out_png, cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR))
62
+ print(f"Rendered to {out_png}")
63
+ """
64
+
65
+ with open(RENDER_SCRIPT, 'w') as f:
66
+ f.write(render_code)
67
+
68
+ print("[tpose] Rendering front view ...")
69
+ r = subprocess.run([TRIPOSG_PYTHON, RENDER_SCRIPT, input_glb, RENDER_OUT],
70
+ capture_output=True, text=True)
71
+ print(r.stdout.strip()); print(r.stderr[-500:] if r.stderr else '')
72
+
73
+ # ── Step 2: YOLO pose estimation ──────────────────────────────────────────────
74
+ YOLO_SCRIPT = '/tmp/_tpose_yolo.py'
75
+ yolo_code = r"""
76
+ import sys, json
77
+ import cv2
78
+ from ultralytics import YOLO
79
+ import numpy as np
80
+
81
+ img_path = sys.argv[1]
82
+ kp_path = sys.argv[2]
83
+
84
+ model = YOLO('yolov8n-pose.pt')
85
+ img = cv2.imread(img_path)
86
+ H, W = img.shape[:2]
87
+
88
+ results = model(img, verbose=False)
89
+ if not results or results[0].keypoints is None:
90
+ print("ERROR: no person detected"); sys.exit(1)
91
+
92
+ # Pick detection with highest confidence
93
+ kps_all = results[0].keypoints.data.cpu().numpy() # (N, 17, 3)
94
+ confs = kps_all[:, :, 2].mean(axis=1)
95
+ best = kps_all[confs.argmax()] # (17, 3): x, y, conf
96
+
97
+ # COCO 17 keypoints:
98
+ # 0=nose 1=left_eye 2=right_eye 3=left_ear 4=right_ear
99
+ # 5=left_shoulder 6=right_shoulder 7=left_elbow 8=right_elbow
100
+ # 9=left_wrist 10=right_wrist 11=left_hip 12=right_hip
101
+ # 13=left_knee 14=right_knee 15=left_ankle 16=right_ankle
102
+
103
+ names = ['nose','left_eye','right_eye','left_ear','right_ear',
104
+ 'left_shoulder','right_shoulder','left_elbow','right_elbow',
105
+ 'left_wrist','right_wrist','left_hip','right_hip',
106
+ 'left_knee','right_knee','left_ankle','right_ankle']
107
+
108
+ kp_dict = {}
109
+ for i, name in enumerate(names):
110
+ x, y, c = best[i]
111
+ kp_dict[name] = {'x': float(x)/W, 'y': float(y)/H, 'conf': float(c)}
112
+ print(f" {name}: ({x:.1f},{y:.1f}) conf={c:.2f}")
113
+
114
+ kp_dict['img_hw'] = [int(H), int(W)]
115
+ with open(kp_path, 'w') as f:
116
+ json.dump(kp_dict, f)
117
+ print(f"Keypoints saved to {kp_path}")
118
+ """
119
+
120
+ with open(YOLO_SCRIPT, 'w') as f:
121
+ f.write(yolo_code)
122
+
123
+ print("[tpose] Running YOLO pose estimation ...")
124
+ r2 = subprocess.run([TRIPOSG_PYTHON, YOLO_SCRIPT, RENDER_OUT, KP_OUT],
125
+ capture_output=True, text=True)
126
+ print(r2.stdout.strip()); print(r2.stderr[-300:] if r2.stderr else '')
127
+
128
+ if not os.path.exists(KP_OUT):
129
+ print("ERROR: YOLO failed — falling back to heuristic")
130
+ kp_data = None
131
+ else:
132
+ with open(KP_OUT) as f:
133
+ kp_data = json.load(f)
134
+
135
+ # ── Step 3: Unproject render-space keypoints to 3D ────────────────────────────
136
+ # Orthographic camera: left=-0.55, right=0.55, bottom=-0.55, top=0.55
137
+ # Render: 768×1024. NDC x = 2*(px/W) - 1, NDC y = 1 - 2*(py/H)
138
+ # World X = ndc_x * 0.55, World Y (mesh up) = ndc_y * 0.55
139
+ # We need 3D positions in the ORIGINAL mesh coordinate space.
140
+ # After Blender GLB import, original mesh Y → Blender Z, original Z → Blender -Y
141
+
142
+ def kp_to_3d(name, z_default=0.0):
143
+ """Convert YOLO keypoint (image fraction) → Blender 3D coords."""
144
+ if kp_data is None or name not in kp_data:
145
+ return None
146
+ k = kp_data[name]
147
+ if k['conf'] < 0.3:
148
+ return None
149
+ # Image coords (fractions) → NDC
150
+ ndc_x = 2 * k['x'] - 1.0 # left→right = mesh X
151
+ ndc_y = -(2 * k['y'] - 1.0) # top→bottom = mesh Y (up)
152
+ # Orthographic: frustum ±0.55
153
+ mesh_x = ndc_x * 0.55
154
+ mesh_y = ndc_y * 0.55 # this is mesh-space Y (vertical)
155
+ # After GLB import: mesh Y → Blender Z, mesh Z → Blender -Y
156
+ bl_x = mesh_x
157
+ bl_z = mesh_y # height
158
+ bl_y = z_default # depth (not observable from front view)
159
+ return (bl_x, bl_y, bl_z)
160
+
161
+ # Key joint positions in Blender space
162
+ J = {}
163
+ for name in ['nose','left_shoulder','right_shoulder','left_elbow','right_elbow',
164
+ 'left_wrist','right_wrist','left_hip','right_hip',
165
+ 'left_knee','right_knee','left_ankle','right_ankle']:
166
+ p = kp_to_3d(name)
167
+ if p: J[name] = p
168
+
169
+ print(f"[tpose] Detected joints: {list(J.keys())}")
170
+
171
+ # ── Step 4: Set up Blender scene ──────────────────────────────────────────────
172
+ bpy.ops.wm.read_factory_settings(use_empty=True)
173
+ bpy.ops.import_scene.gltf(filepath=input_glb)
174
+ bpy.context.view_layer.update()
175
+
176
+ mesh_obj = next((o for o in bpy.data.objects if o.type == 'MESH'), None)
177
+ if not mesh_obj:
178
+ print("ERROR: no mesh"); sys.exit(1)
179
+
180
+ verts_w = np.array([mesh_obj.matrix_world @ v.co for v in mesh_obj.data.vertices])
181
+ z_min, z_max = verts_w[:,2].min(), verts_w[:,2].max()
182
+ x_c = (verts_w[:,0].min() + verts_w[:,0].max()) / 2
183
+ y_c = (verts_w[:,1].min() + verts_w[:,1].max()) / 2
184
+ H_mesh = z_max - z_min
185
+
186
+ def zh(frac): return z_min + frac * H_mesh
187
+ def jv(name, fallback_frac=None, fallback_x=0.0):
188
+ """Get joint position from YOLO or use fallback."""
189
+ if name in J:
190
+ x, y, z = J[name]
191
+ return (x, y_c, z) # use mesh y_c for depth
192
+ if fallback_frac is not None:
193
+ return (x_c + fallback_x, y_c, zh(fallback_frac))
194
+ return None
195
+
196
+ # ── Step 5: Build armature in CURRENT pose ────────────────────────────────────
197
+ bpy.ops.object.armature_add(location=(x_c, y_c, zh(0.5)))
198
+ arm_obj = bpy.context.object
199
+ arm_obj.name = 'PoseRig'
200
+ arm = arm_obj.data
201
+
202
+ bpy.ops.object.mode_set(mode='EDIT')
203
+ eb = arm.edit_bones
204
+
205
+ def V(xyz): return mathutils.Vector(xyz)
206
+
207
+ def add_bone(name, head, tail, parent=None, connect=False):
208
+ b = eb.new(name)
209
+ b.head = V(head)
210
+ b.tail = V(tail)
211
+ if parent and parent in eb:
212
+ b.parent = eb[parent]
213
+ b.use_connect = connect
214
+ return b
215
+
216
+ # Helper: midpoint
217
+ def mid(a, b): return tuple((a[i]+b[i])/2 for i in range(3))
218
+ def offset(p, dx=0, dy=0, dz=0): return (p[0]+dx, p[1]+dy, p[2]+dz)
219
+
220
+ # ── Spine / hips ─────────────────────────────────────────────────────────────
221
+ hip_L = jv('left_hip', 0.48, -0.07)
222
+ hip_R = jv('right_hip', 0.48, 0.07)
223
+ sh_L = jv('left_shoulder', 0.77, -0.20)
224
+ sh_R = jv('right_shoulder', 0.77, 0.20)
225
+ nose = jv('nose', 0.92)
226
+
227
+ hips_c = mid(hip_L, hip_R) if (hip_L and hip_R) else (x_c, y_c, zh(0.48))
228
+ sh_c = mid(sh_L, sh_R) if (sh_L and sh_R) else (x_c, y_c, zh(0.77))
229
+
230
+ add_bone('Hips', hips_c, offset(hips_c, dz=H_mesh*0.08))
231
+ add_bone('Spine', hips_c, offset(hips_c, dz=(sh_c[2]-hips_c[2])*0.5), 'Hips')
232
+ add_bone('Chest', offset(hips_c, dz=(sh_c[2]-hips_c[2])*0.5), sh_c, 'Spine', True)
233
+ if nose:
234
+ neck_z = sh_c[2] + (nose[2]-sh_c[2])*0.35
235
+ head_z = sh_c[2] + (nose[2]-sh_c[2])*0.65
236
+ add_bone('Neck', (x_c, y_c, neck_z), (x_c, y_c, head_z), 'Chest')
237
+ add_bone('Head', (x_c, y_c, head_z), (x_c, y_c, nose[2]+H_mesh*0.05), 'Neck', True)
238
+ else:
239
+ add_bone('Neck', sh_c, offset(sh_c, dz=H_mesh*0.06), 'Chest')
240
+ add_bone('Head', offset(sh_c, dz=H_mesh*0.06), offset(sh_c, dz=H_mesh*0.14), 'Neck', True)
241
+
242
+ # ── Arms (placed at DETECTED current pose positions) ─────────────────────────
243
+ el_L = jv('left_elbow', 0.60, -0.30)
244
+ el_R = jv('right_elbow', 0.60, 0.30)
245
+ wr_L = jv('left_wrist', 0.45, -0.25)
246
+ wr_R = jv('right_wrist', 0.45, 0.25)
247
+
248
+ for side, sh, el, wr in (('L', sh_L, el_L, wr_L), ('R', sh_R, el_R, wr_R)):
249
+ if not sh: continue
250
+ el_pos = el if el else offset(sh, dz=-H_mesh*0.15)
251
+ wr_pos = wr if wr else offset(el_pos, dz=-H_mesh*0.15)
252
+ hand = offset(wr_pos, dz=(wr_pos[2]-el_pos[2])*0.4)
253
+ add_bone(f'UpperArm.{side}', sh, el_pos, 'Chest')
254
+ add_bone(f'ForeArm.{side}', el_pos, wr_pos, f'UpperArm.{side}', True)
255
+ add_bone(f'Hand.{side}', wr_pos, hand, f'ForeArm.{side}', True)
256
+
257
+ # ── Legs ─────────────────────────────────────────────────────────────────────
258
+ kn_L = jv('left_knee', 0.25, -0.07)
259
+ kn_R = jv('right_knee', 0.25, 0.07)
260
+ an_L = jv('left_ankle', 0.04, -0.06)
261
+ an_R = jv('right_ankle', 0.04, 0.06)
262
+
263
+ for side, hp, kn, an in (('L', hip_L, kn_L, an_L), ('R', hip_R, kn_R, an_R)):
264
+ if not hp: continue
265
+ kn_pos = kn if kn else offset(hp, dz=-H_mesh*0.23)
266
+ an_pos = an if an else offset(kn_pos, dz=-H_mesh*0.22)
267
+ toe = offset(an_pos, dy=-H_mesh*0.06, dz=-H_mesh*0.02)
268
+ add_bone(f'UpperLeg.{side}', hp, kn_pos, 'Hips')
269
+ add_bone(f'LowerLeg.{side}', kn_pos, an_pos, f'UpperLeg.{side}', True)
270
+ add_bone(f'Foot.{side}', an_pos, toe, f'LowerLeg.{side}', True)
271
+
272
+ bpy.ops.object.mode_set(mode='OBJECT')
273
+
274
+ # ── Step 6: Skin mesh to armature ────────────────────────────────────────────
275
+ bpy.context.view_layer.objects.active = arm_obj
276
+ mesh_obj.select_set(True)
277
+ arm_obj.select_set(True)
278
+ bpy.ops.object.parent_set(type='ARMATURE_AUTO')
279
+ print("[tpose] Auto-weights applied")
280
+
281
+ # ── Step 7: Pose arms to T-pose ───────────────────────────────────────────────
282
+ # Compute per-arm rotation: from (current elbow - shoulder) direction → horizontal ±X
283
+ bpy.context.view_layer.objects.active = arm_obj
284
+ bpy.ops.object.mode_set(mode='POSE')
285
+
286
+ pb = arm_obj.pose.bones
287
+
288
+ def set_tpose_arm(side, sh_pos, el_pos):
289
+ if not sh_pos or not el_pos:
290
+ return
291
+ if f'UpperArm.{side}' not in pb:
292
+ return
293
+ # Current upper-arm direction in armature local space
294
+ sx = -1 if side == 'L' else 1
295
+ # T-pose direction: ±X horizontal
296
+ tpose_dir = mathutils.Vector((sx, 0, 0))
297
+ # Current bone direction (head→tail) in world space
298
+ bone = arm_obj.data.bones[f'UpperArm.{side}']
299
+ cur_dir = (bone.tail_local - bone.head_local).normalized()
300
+     # Rotation to the T-pose direction (computed in world space and applied as the local pose quaternion; approximate for rolled rest bones)
301
+ rot_quat = cur_dir.rotation_difference(tpose_dir)
302
+ pb[f'UpperArm.{side}'].rotation_mode = 'QUATERNION'
303
+ pb[f'UpperArm.{side}'].rotation_quaternion = rot_quat
304
+
305
+ # Straighten forearm along the same axis
306
+ if f'ForeArm.{side}' in pb:
307
+ pb[f'ForeArm.{side}'].rotation_mode = 'QUATERNION'
308
+ pb[f'ForeArm.{side}'].rotation_quaternion = mathutils.Quaternion((1,0,0,0))
309
+
310
+ set_tpose_arm('L', sh_L, el_L)
311
+ set_tpose_arm('R', sh_R, el_R)
312
+
313
+ bpy.context.view_layer.update()
314
+ bpy.ops.object.mode_set(mode='OBJECT')
315
+
316
+ # ── Step 8: Apply armature modifier ──────────────────────────────────────────
317
+ bpy.context.view_layer.objects.active = mesh_obj
318
+ mesh_obj.select_set(True)
319
+ for mod in mesh_obj.modifiers:
320
+ if mod.type == 'ARMATURE':
321
+ bpy.ops.object.modifier_apply(modifier=mod.name)
322
+ print(f"[tpose] Applied modifier: {mod.name}")
323
+ break
324
+
325
+ bpy.data.objects.remove(arm_obj, do_unlink=True)
326
+
327
+ # ── Step 9: Export ────────────────────────────────────────────────────────────
328
+ bpy.ops.export_scene.gltf(
329
+ filepath=output_glb, export_format='GLB',
330
+ export_texcoords=True, export_normals=True,
331
+ export_materials='EXPORT', use_selection=False)
332
+ print(f"[tpose] Done → {output_glb}")
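The unprojection performed by `kp_to_3d` can be exercised on its own: image-fraction keypoints map through the ±0.55 orthographic frustum to mesh coordinates, and mesh Y (up) becomes Blender Z after GLB import. A minimal standalone version (`keypoint_to_blender` is an illustrative name, not part of the script):

```python
# Mirrors the kp_to_3d logic above: YOLO keypoints arrive as image fractions
# (x, y in [0, 1]); the front render used an orthographic frustum of +/-0.55.
ORTHO_HALF = 0.55  # matches left/right/bottom/top = +/-0.55 in the render

def keypoint_to_blender(xf, yf, depth=0.0):
    ndc_x = 2.0 * xf - 1.0       # image left -> right is mesh +X
    ndc_y = 1.0 - 2.0 * yf       # image top is mesh up
    return (ndc_x * ORTHO_HALF,  # Blender X
            depth,               # Blender Y: depth, unobservable from front
            ndc_y * ORTHO_HALF)  # Blender Z: height

print(keypoint_to_blender(0.5, 0.5))  # image center → (0.0, 0.0, 0.0)
```

A keypoint at the top-center of the frame lands at Blender Z = 0.55, the top of the frustum, which is why the script later clamps joints against the mesh's actual z range.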
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- # HuggingFace ZeroGPU Space — Gradio SDK [cache-bust: 2]
2
  spaces
3
  numpy>=2
4
 
@@ -79,3 +79,15 @@ typeguard
79
  sentencepiece
80
  spandrel
81
  imageio
1
+ # HuggingFace ZeroGPU Space — Gradio SDK [cache-bust: 3]
2
  spaces
3
  numpy>=2
4
 
 
79
  sentencepiece
80
  spandrel
81
  imageio
82
+ gradio_client
83
+
84
+ # FireRed / GGUF quantization
85
+ bitsandbytes
86
+
87
+ # Motion search + retargeting
88
+ filterpy
89
+ pytorch-lightning
90
+ lightning-utilities
91
+ webdataset
92
+ hydra-core
93
+ matplotlib
utils/pytorch3d_minimal.py ADDED
@@ -0,0 +1,242 @@
1
+ """
2
+ pytorch3d_minimal.py
3
+ ====================
4
+ Drop-in replacement for the pytorch3d subset used by PSHuman's project_mesh.py
5
+ and mesh_utils.py. Uses nvdiffrast for GPU rasterization.
6
+
7
+ Implements:
8
+ - Meshes / TexturesVertex
9
+ - look_at_view_transform
10
+ - FoVOrthographicCameras / OrthographicCameras (orthographic projection only)
11
+ - RasterizationSettings / MeshRasterizer (via nvdiffrast)
12
+ - render_pix2faces_py3d (compatibility shim)
13
+ """
14
+ from __future__ import annotations
15
+ import math
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import numpy as np
19
+
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Texture / Mesh containers
23
+ # ---------------------------------------------------------------------------
24
+
25
+ class TexturesVertex:
26
+ def __init__(self, verts_features):
27
+ # verts_features: list of [N, C] tensors (one per mesh in batch)
28
+ self._feats = verts_features
29
+
30
+ def verts_features_packed(self):
31
+ return self._feats[0]
32
+
33
+ def clone(self):
34
+ return TexturesVertex([f.clone() for f in self._feats])
35
+
36
+ def detach(self):
37
+ return TexturesVertex([f.detach() for f in self._feats])
38
+
39
+ def to(self, device):
40
+ self._feats = [f.to(device) for f in self._feats]
41
+ return self
42
+
43
+
44
+ class Meshes:
45
+ def __init__(self, verts, faces, textures=None):
46
+ self._verts = verts # list of [N,3] float tensors
47
+ self._faces = faces # list of [F,3] long tensors
48
+ self.textures = textures
49
+
50
+ # ---- accessors --------------------------------------------------------
51
+ def verts_padded(self): return torch.stack(self._verts)
52
+ def faces_padded(self): return torch.stack(self._faces)
53
+ def verts_packed(self): return self._verts[0]
54
+ def faces_packed(self): return self._faces[0]
55
+ def verts_list(self): return self._verts
56
+ def faces_list(self): return self._faces
57
+
58
+ def verts_normals_packed(self):
59
+ v, f = self._verts[0], self._faces[0]
60
+ v0, v1, v2 = v[f[:, 0]], v[f[:, 1]], v[f[:, 2]]
61
+ fn = torch.cross(v1 - v0, v2 - v0, dim=1)
62
+ fn = F.normalize(fn, dim=1)
63
+ vn = torch.zeros_like(v)
64
+ for k in range(3):
65
+ vn.scatter_add_(0, f[:, k:k+1].expand(-1, 3), fn)
66
+ return F.normalize(vn, dim=1)
67
+
68
+ # ---- device / copy ----------------------------------------------------
69
+ def to(self, device):
70
+ self._verts = [v.to(device) for v in self._verts]
71
+ self._faces = [f.to(device) for f in self._faces]
72
+ if self.textures is not None:
73
+ self.textures.to(device)
74
+ return self
75
+
76
+ def clone(self):
77
+ m = Meshes([v.clone() for v in self._verts],
78
+ [f.clone() for f in self._faces])
79
+ if self.textures is not None:
80
+ m.textures = self.textures.clone()
81
+ return m
82
+
83
+ def detach(self):
84
+ m = Meshes([v.detach() for v in self._verts],
85
+ [f.detach() for f in self._faces])
86
+ if self.textures is not None:
87
+ m.textures = self.textures.detach()
88
+ return m
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # Camera math (mirrors pytorch3d look_at_view_transform + Orthographic)
93
+ # ---------------------------------------------------------------------------
94
+
95
+ def _look_at_rotation(camera_pos: torch.Tensor,
96
+ at: torch.Tensor,
97
+ up: torch.Tensor) -> torch.Tensor:
98
+ """Return (3,3) rotation matrix: world → camera."""
99
+ z = F.normalize(camera_pos - at, dim=-1) # cam looks along -Z
100
+ x = F.normalize(torch.cross(up, z, dim=-1), dim=-1)
101
+ y = torch.cross(z, x, dim=-1)
102
+ R = torch.stack([x, y, z], dim=-1) # columns = cam axes
103
+ return R # shape (3,3)
104
+
105
+
106
+ def look_at_view_transform(dist=1.0, elev=0.0, azim=0.0,
107
+ degrees=True, device="cpu"):
108
+ """Matches pytorch3d convention exactly."""
109
+ if degrees:
110
+ elev = math.radians(float(elev))
111
+ azim = math.radians(float(azim))
112
+
113
+ # camera position in world
114
+ cx = dist * math.cos(elev) * math.sin(azim)
115
+ cy = dist * math.sin(elev)
116
+ cz = dist * math.cos(elev) * math.cos(azim)
117
+ eye = torch.tensor([[cx, cy, cz]], dtype=torch.float32, device=device)
118
+ at = torch.zeros(1, 3, device=device)
119
+ up = torch.tensor([[0, 1, 0]], dtype=torch.float32, device=device)
120
+
121
+ # pytorch3d stores R transposed (row = cam axis in world space)
122
+ R = _look_at_rotation(eye[0], at[0], up[0]).T.unsqueeze(0) # (1,3,3)
123
+
124
+ # T = camera position expressed in camera space
125
+ T = torch.bmm(-R, eye.unsqueeze(-1)).squeeze(-1) # (1,3)
126
+ return R, T
127
+
128
+
129
+ class _OrthoCamera:
130
+ """Minimal orthographic camera, matches FoVOrthographicCameras API."""
131
+ def __init__(self, R, T, focal_length=1.0, device="cpu"):
132
+ self.R = R.to(device) # (B,3,3)
133
+ self.T = T.to(device) # (B,3)
134
+ self.focal = float(focal_length)
135
+ self.device = device
136
+
137
+ def to(self, device):
138
+ self.R = self.R.to(device)
139
+ self.T = self.T.to(device)
140
+ self.device = device
141
+ return self
142
+
143
+ def get_znear(self):
144
+ return torch.tensor(0.01, device=self.device)
145
+
146
+ def is_perspective(self):
147
+ return False
148
+
149
+ def transform_points_ndc(self, points):
150
+ """
151
+ points: (B, N, 3) world coords
152
+ returns: (B, N, 3) NDC coords (X,Y in [-1,1], Z = depth)
153
+ """
154
+ # world → camera
155
+ pts_cam = torch.bmm(points, self.R) + self.T.unsqueeze(1) # (B,N,3)
156
+ # orthographic NDC: scale by focal, flip Y to match image convention
157
+ ndc_x = pts_cam[..., 0] * self.focal
158
+ ndc_y = -pts_cam[..., 1] * self.focal # pytorch3d flips Y
159
+ ndc_z = pts_cam[..., 2]
160
+ return torch.stack([ndc_x, ndc_y, ndc_z], dim=-1)
161
+
162
+ def _world_to_clip(self, verts: torch.Tensor) -> torch.Tensor:
163
+ """verts: (N,3) → clip (N,4) for nvdiffrast."""
164
+ pts_cam = (verts @ self.R[0].T) + self.T[0] # (N,3)
165
+ cx = pts_cam[:, 0] * self.focal
166
+ cy = -pts_cam[:, 1] * self.focal # flip Y
167
+ cz = pts_cam[:, 2]
168
+ w = torch.ones_like(cz)
169
+ return torch.stack([cx, cy, cz, w], dim=1) # (N,4)
170
+
171
+
172
+ # Aliases used in project_mesh.py
173
+ def FoVOrthographicCameras(device="cpu", R=None, T=None,
174
+ min_x=-1, max_x=1, min_y=-1, max_y=1,
175
+ focal_length=None, **kwargs):
176
+ fl = focal_length if focal_length is not None else 1.0 / (max_x + 1e-9)
177
+ return _OrthoCamera(R, T, focal_length=fl, device=device)
178
+
179
+
180
+ def FoVPerspectiveCameras(device="cpu", R=None, T=None, fov=60, degrees=True, **kwargs):
181
+ # Fallback: treat as orthographic at fov-derived scale (good enough for PSHuman)
182
+ fl = 1.0 / math.tan(math.radians(fov / 2)) if degrees else 1.0 / math.tan(fov / 2)
183
+ return _OrthoCamera(R, T, focal_length=fl, device=device)
184
+
185
+
186
+ OrthographicCameras = FoVOrthographicCameras
187
+
188
+
189
+ # ---------------------------------------------------------------------------
190
+ # Rasterizer (nvdiffrast-based)
191
+ # ---------------------------------------------------------------------------
192
+
193
+ class RasterizationSettings:
194
+ def __init__(self, image_size=512, blur_radius=0.0, faces_per_pixel=1):
195
+ if isinstance(image_size, (list, tuple)):
196
+ self.H, self.W = image_size[0], image_size[1]
197
+ else:
198
+ self.H = self.W = int(image_size)
199
+
200
+
201
+ class _Fragments:
202
+ def __init__(self, pix_to_face):
203
+ self.pix_to_face = pix_to_face.unsqueeze(-1) # (1,H,W,1)
204
+
205
+
206
+ class MeshRasterizer:
207
+ def __init__(self, cameras=None, raster_settings=None):
208
+ self.cameras = cameras
209
+ self.settings = raster_settings
210
+ self._glctx = None
211
+
212
+ def _get_ctx(self, device):
213
+ if self._glctx is None:
214
+ import nvdiffrast.torch as dr
215
+ self._glctx = dr.RasterizeCudaContext(device=device)
216
+ return self._glctx
217
+
218
+ def __call__(self, meshes: Meshes, cameras=None):
219
+ cam = cameras or self.cameras
220
+ H, W = self.settings.H, self.settings.W
221
+ device = meshes.verts_packed().device
222
+ import nvdiffrast.torch as dr
223
+ glctx = self._get_ctx(str(device))
224
+
225
+ verts = meshes.verts_packed().to(device)
226
+ faces = meshes.faces_packed().to(torch.int32).to(device)
227
+ clip = cam._world_to_clip(verts).unsqueeze(0) # (1,N,4)
228
+ rast, _ = dr.rasterize(glctx, clip, faces, resolution=(H, W))
229
+ pix_to_face = rast[0, :, :, -1].to(torch.int32) - 1 # -1 = background
230
+ return _Fragments(pix_to_face.unsqueeze(0))
231
+
232
+
233
+ # ---------------------------------------------------------------------------
234
+ # render_pix2faces_py3d shim (used in get_visible_faces)
235
+ # ---------------------------------------------------------------------------
236
+
237
+ def render_pix2faces_py3d(meshes, cameras, H=512, W=512, **kwargs):
238
+ """Returns {'pix_to_face': (1,H,W)} integer tensor of face indices (-1=bg)."""
239
+ settings = RasterizationSettings(image_size=(H, W))
240
+ rasterizer = MeshRasterizer(cameras=cameras, raster_settings=settings)
241
+ frags = rasterizer(meshes)
242
+ return {"pix_to_face": frags.pix_to_face[..., 0]} # (1,H,W)
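The orthographic world→NDC path in `_OrthoCamera.transform_points_ndc` (rotate into camera space with the row-vector convention, scale by focal length, flip Y) can be sanity-checked with plain numpy. This is an illustrative reimplementation, not the module's API:

```python
import numpy as np

def ortho_ndc(points, R, T, focal=1.0):
    """points: (N, 3) world coords; R, T in the row-vector convention used
    above (p_cam = p @ R + T). Returns (N, 3) NDC with Y flipped."""
    p_cam = points @ R + T
    return np.stack(
        [p_cam[:, 0] * focal, -p_cam[:, 1] * focal, p_cam[:, 2]], axis=1)

# Identity view at distance 2 along +Z: R = I, T = (0, 0, 2)
R = np.eye(3)
T = np.array([0.0, 0.0, 2.0])
pts = np.array([[0.0, 0.0, 0.0],    # origin: NDC center at depth 2
                [0.5, 0.25, 0.0]])  # off-center point: Y gets flipped
ndc = ortho_ndc(pts, R, T)
print(ndc[1])
```

The Y flip converts the Y-up camera frame to the Y-down image convention that nvdiffrast's rasterizer expects, which is why `_world_to_clip` repeats the same sign change.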