Spaces:
Sleeping
Port MeshForge features to ZeroGPU Space: FireRed, PSHuman, Motion Search
New tabs and features from the MeshForge server version, adapted for ZeroGPU:
- Edit tab: FireRed GGUF-quantized image editing (QwenImageEditPlusPipeline)
- Animate tab: HumanML3D motion search + GLB animation via Retarget/
- PSHuman Face tab: HD face transplant via PSHuman service + face_transplant.py
- Settings tab: VRAM management (preload/unload/refresh)
- Generate tab: remove-BG preview controls + FireRed→Generate flow
- Enhancement tab: unload button for VRAM management
New pipeline modules: face_transplant, pshuman_client, render_glb, tpose,
face_inswap_bake, face_project, face_swap_render, head_replace
New Retarget/ directory: motion search, animate, skeleton, SMPL retargeting
New utils/pytorch3d_minimal.py
Updated requirements.txt: bitsandbytes, gradio_client, filterpy, pytorch-lightning,
lightning-utilities, webdataset, hydra-core, matplotlib
- Retarget/README.md +127 -0
- Retarget/__init__.py +49 -0
- Retarget/animate.py +611 -0
- Retarget/cli.py +129 -0
- Retarget/generate.py +131 -0
- Retarget/humanml3d_to_bvh.py +813 -0
- Retarget/io/__init__.py +1 -0
- Retarget/io/bvh.py +216 -0
- Retarget/io/gltf_io.py +316 -0
- Retarget/io/mapping.py +189 -0
- Retarget/math3d.py +167 -0
- Retarget/retarget.py +586 -0
- Retarget/search.py +159 -0
- Retarget/skeleton.py +165 -0
- Retarget/smpl.py +184 -0
- app.py +1079 -31
- pipeline/face_inswap_bake.py +302 -0
- pipeline/face_project.py +305 -0
- pipeline/face_swap_render.py +293 -0
- pipeline/face_transplant.py +667 -0
- pipeline/head_replace.py +762 -0
- pipeline/pshuman_client.py +283 -0
- pipeline/render_glb.py +25 -0
- pipeline/tpose.py +332 -0
- requirements.txt +13 -1
- utils/pytorch3d_minimal.py +242 -0
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# rig_retarget
|
| 2 |
+
|
| 3 |
+
Pure-Python rig retargeting library. No Blender required.
|
| 4 |
+
|
| 5 |
+
Based on **[KeeMap Blender Rig Retargeting Addon](https://github.com/nkeeline/Keemap-Blender-Rig-ReTargeting-Addon)** by [Nick Keeline](https://github.com/nkeeline) (GPL v2).
|
| 6 |
+
All core retargeting math is a direct port of his work. Mapping JSON files are fully compatible with KeeMap.
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## File layout
|
| 11 |
+
|
| 12 |
+
```
|
| 13 |
+
rig_retarget/
|
| 14 |
+
├── math3d.py # Quaternion / matrix math (numpy + scipy), replaces mathutils
|
| 15 |
+
├── skeleton.py # Armature + PoseBone with FK, replaces bpy armature objects
|
| 16 |
+
├── retarget.py # Core retargeting logic — faithful port of KeeMapBoneOperators.py
|
| 17 |
+
├── cli.py # CLI entry point
|
| 18 |
+
└── io/
|
| 19 |
+
├── bvh.py # BVH mocap reader (source animation)
|
| 20 |
+
├── gltf_io.py # glTF/GLB reader + animation writer (UniRig destination)
|
| 21 |
+
└── mapping.py # JSON bone mapping — same format as KeeMap
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
## Install
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
pip install numpy scipy pygltflib
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
---
|
| 33 |
+
|
| 34 |
+
## CLI
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
# Retarget BVH onto UniRig GLB
|
| 38 |
+
python -m rig_retarget.cli \
|
| 39 |
+
--source motion.bvh \
|
| 40 |
+
--dest unirig_character.glb \
|
| 41 |
+
--mapping radical2unirig.json \
|
| 42 |
+
--output animated_character.glb \
|
| 43 |
+
--fps 30 --start 0 --frames 200 --step 1
|
| 44 |
+
|
| 45 |
+
# Auto-calculate bone correction factors and save back to the mapping file
|
| 46 |
+
python -m rig_retarget.cli --calc-corrections \
|
| 47 |
+
--source motion.bvh \
|
| 48 |
+
--dest unirig_character.glb \
|
| 49 |
+
--mapping mymap.json
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
---
|
| 53 |
+
|
| 54 |
+
## Python API
|
| 55 |
+
|
| 56 |
+
```python
|
| 57 |
+
from rig_retarget.io.bvh import load_bvh
|
| 58 |
+
from rig_retarget.io.gltf_io import load_gltf, write_gltf_animation
|
| 59 |
+
from rig_retarget.io.mapping import load_mapping
|
| 60 |
+
from rig_retarget.retarget import transfer_animation, calc_all_corrections
|
| 61 |
+
|
| 62 |
+
# Load
|
| 63 |
+
settings, bone_items = load_mapping("my_map.json")
|
| 64 |
+
src_anim = load_bvh("motion.bvh")
|
| 65 |
+
dst_arm = load_gltf("unirig_char.glb")
|
| 66 |
+
|
| 67 |
+
# Optional: auto-calc corrections at first frame
|
| 68 |
+
src_anim.apply_frame(0)
|
| 69 |
+
calc_all_corrections(bone_items, src_anim.armature, dst_arm, settings)
|
| 70 |
+
|
| 71 |
+
# Transfer
|
| 72 |
+
settings.number_of_frames_to_apply = src_anim.num_frames
|
| 73 |
+
keyframes = transfer_animation(src_anim, dst_arm, bone_items, settings)
|
| 74 |
+
|
| 75 |
+
# Write output GLB
|
| 76 |
+
write_gltf_animation("unirig_char.glb", dst_arm, keyframes, "output.glb")
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
## Mapping JSON format
|
| 82 |
+
|
| 83 |
+
100% compatible with KeeMap's `.json` files. Use KeeMap in Blender to create
|
| 84 |
+
and tune mappings, then use this library offline for batch processing.
|
| 85 |
+
|
| 86 |
+
Key fields per bone:
|
| 87 |
+
|
| 88 |
+
| Field | Description |
|
| 89 |
+
|---|---|
|
| 90 |
+
| `SourceBoneName` | Bone name in the source rig (BVH joint name) |
|
| 91 |
+
| `DestinationBoneName` | Bone name in the UniRig skeleton (glTF node name) |
|
| 92 |
+
| `set_bone_rotation` | Drive rotation from source |
|
| 93 |
+
| `set_bone_position` | Drive position from source |
|
| 94 |
+
| `bone_rotation_application_axis` | Mask axes: `X` `Y` `Z` `XY` `XZ` `YZ` `XYZ` |
|
| 95 |
+
| `bone_transpose_axis` | Swap axes: `NONE` `ZYX` `ZXY` `XZY` `YZX` `YXZ` |
|
| 96 |
+
| `CorrectionFactorX/Y/Z` | Euler correction (radians) |
|
| 97 |
+
| `postion_type` | `SINGLE_BONE_OFFSET` or `POLE` |
|
| 98 |
+
| `position_pole_distance` | IK pole distance |
|
| 99 |
+
|
| 100 |
+
---
|
| 101 |
+
|
| 102 |
+
## Blender → pure-Python mapping
|
| 103 |
+
|
| 104 |
+
| Blender | rig_retarget |
|
| 105 |
+
|---|---|
|
| 106 |
+
| `bpy.data.objects[name]` | `Armature` + `load_gltf()` / `load_bvh()` |
|
| 107 |
+
| `arm.pose.bones[name]` | `arm.get_bone(name)` → `PoseBone` |
|
| 108 |
+
| `bone.matrix` (pose space) | `bone.matrix_armature` |
|
| 109 |
+
| `arm.matrix_world` | `arm.world_matrix` |
|
| 110 |
+
| `arm.convert_space(...)` | `arm.world_matrix @ bone.matrix_armature` |
|
| 111 |
+
| `bone.rotation_quaternion` | `bone.pose_rotation_quat` |
|
| 112 |
+
| `bone.location` | `bone.pose_location` |
|
| 113 |
+
| `bone.keyframe_insert(...)` | returned in `keyframes` list from `transfer_frame()` |
|
| 114 |
+
| `bpy.context.scene.frame_set(i)` | `src_anim.apply_frame(i)` |
|
| 115 |
+
| `mathutils.Quaternion` | `np.ndarray [w,x,y,z]` + `math3d.*` |
|
| 116 |
+
| `mathutils.Matrix` | `np.ndarray (4,4)` |
|
| 117 |
+
|
| 118 |
+
---
|
| 119 |
+
|
| 120 |
+
## Limitations / TODO
|
| 121 |
+
|
| 122 |
+
- **glTF source animation** reading not yet implemented (BVH only for now).
|
| 123 |
+
Add `io/gltf_anim_reader.py` reading `gltf.animations[0]` sampler data.
|
| 124 |
+
- FBX source support: use `pyassimp` or `bpy` offline with `--background`.
|
| 125 |
+
- IK solving: pole bone positioning is FK-only; a full IK solver (FABRIK/CCD)
|
| 126 |
+
would improve accuracy for limb targets.
|
| 127 |
+
- Quaternion mode twist bones: parity with Blender not guaranteed for complex twist rigs.
|
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
rig_retarget
|
| 3 |
+
============
|
| 4 |
+
Pure-Python rig retargeting library.
|
| 5 |
+
No Blender dependency. Targets TripoSG meshes auto-rigged by UniRig (SIGGRAPH 2025).
|
| 6 |
+
|
| 7 |
+
Quick start
|
| 8 |
+
-----------
|
| 9 |
+
from rig_retarget.io.bvh import load_bvh
|
| 10 |
+
from rig_retarget.io.gltf_io import load_gltf, write_gltf_animation
|
| 11 |
+
from rig_retarget.io.mapping import load_mapping
|
| 12 |
+
from rig_retarget.retarget import transfer_animation
|
| 13 |
+
|
| 14 |
+
settings, bone_items = load_mapping("my_map.json")
|
| 15 |
+
src_anim = load_bvh("motion.bvh")
|
| 16 |
+
dst_arm = load_gltf("unirig_char.glb")
|
| 17 |
+
|
| 18 |
+
keyframes = transfer_animation(src_anim, dst_arm, bone_items, settings)
|
| 19 |
+
write_gltf_animation("unirig_char.glb", dst_arm, keyframes, "output.glb")
|
| 20 |
+
|
| 21 |
+
CLI
|
| 22 |
+
---
|
| 23 |
+
python -m rig_retarget.cli --source motion.bvh --dest char.glb \\
|
| 24 |
+
--mapping map.json --output char_animated.glb
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from .skeleton import Armature, PoseBone
|
| 28 |
+
from .retarget import (
|
| 29 |
+
get_bone_position_ws,
|
| 30 |
+
get_bone_ws_quat,
|
| 31 |
+
set_bone_position_ws,
|
| 32 |
+
set_bone_rotation,
|
| 33 |
+
set_bone_position,
|
| 34 |
+
set_bone_position_pole,
|
| 35 |
+
set_bone_scale,
|
| 36 |
+
calc_rotation_offset,
|
| 37 |
+
calc_location_offset,
|
| 38 |
+
calc_all_corrections,
|
| 39 |
+
transfer_frame,
|
| 40 |
+
transfer_animation,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
__all__ = [
|
| 44 |
+
"Armature", "PoseBone",
|
| 45 |
+
"get_bone_position_ws", "get_bone_ws_quat", "set_bone_position_ws",
|
| 46 |
+
"set_bone_rotation", "set_bone_position", "set_bone_position_pole",
|
| 47 |
+
"set_bone_scale", "calc_rotation_offset", "calc_location_offset",
|
| 48 |
+
"calc_all_corrections", "transfer_frame", "transfer_animation",
|
| 49 |
+
]
|
|
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
animate.py
|
| 3 |
+
──────────────────────────────────────────────────────────────────────────────
|
| 4 |
+
Bake SMPL motion (from HumanML3D [T, 263] features) onto a UniRig-rigged GLB.
|
| 5 |
+
|
| 6 |
+
Retargeting method: world-direction matching
|
| 7 |
+
────────────────────────────────────────────
|
| 8 |
+
Commercial retargeters (Mixamo, Rokoko, MotionBuilder) avoid rest-pose
|
| 9 |
+
convention mismatches by matching WORLD BONE DIRECTIONS, not local rotations.
|
| 10 |
+
|
| 11 |
+
Algorithm (per frame, per bone):
|
| 12 |
+
1. Run t2m FK with HumanML3D 6D rotations → world bone direction d_t2m
|
| 13 |
+
2. Flip X axis: t2m +X = character's LEFT; SMPL/UniRig +X = character's RIGHT
|
| 14 |
+
So d_desired = (-d_t2m_x, d_t2m_y, d_t2m_z) in SMPL/UniRig world frame
|
| 15 |
+
3. d_rest = normalize(ur_pos[bone] - ur_pos[parent]) from GLB inverse bind matrices
|
| 16 |
+
4. R_world = R_between(d_rest, d_desired) -- minimal rotation in world space
|
| 17 |
+
5. local_rot = inv(R_world[parent]) @ R_world[bone]
|
| 18 |
+
6. pose_rot_delta = inv(rest_r) @ local_rot -- composing with glTF rest rotation
|
| 19 |
+
|
| 20 |
+
This avoids all rest-pose convention issues:
|
| 21 |
+
- t2m canonical arms point DOWN: handled automatically
|
| 22 |
+
- t2m canonical hips/shoulders have inverted X: handled by the X-flip
|
| 23 |
+
- UniRig non-identity rest rotations: handled by inv(rest_r) composition
|
| 24 |
+
|
| 25 |
+
Key bugs fixed vs previous version:
|
| 26 |
+
- IBM column-major: glTF IBMs are column-major; was using inv(ibm)[:3,3] (zeros).
|
| 27 |
+
Fixed to inv(ibm.T)[:3,3] which gives correct world-space bone positions.
|
| 28 |
+
- Normalisation: was mixing ur/smpl Y ranges, causing wrong height alignment.
|
| 29 |
+
Fixed with independent per-skeleton Y normalisation.
|
| 30 |
+
- Rotation convention: was applying t2m rotations directly without X-flip.
|
| 31 |
+
Fixed by world-direction matching with coordinate-frame conversion.
|
| 32 |
+
"""
|
| 33 |
+
from __future__ import annotations
|
| 34 |
+
import os
|
| 35 |
+
import re
|
| 36 |
+
import numpy as np
|
| 37 |
+
from typing import Union
|
| 38 |
+
|
| 39 |
+
from .smpl import SMPLMotion, hml3d_to_smpl_motion
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 43 |
+
# T2M (HumanML3D) skeleton constants
|
| 44 |
+
# Source: HumanML3D/common/paramUtil.py
|
| 45 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 46 |
+
|
| 47 |
+
# Unit bone-direction offsets for the 22-joint t2m (HumanML3D) skeleton,
# indexed by joint. Each row is the direction from the parent joint to this
# joint in the t2m canonical rest pose.
T2M_RAW_OFFSETS = np.array([
    [ 0, 0, 0],  # 0  Hips (root)
    [ 1, 0, 0],  # 1  LeftUpLeg      +X = character LEFT in t2m convention
    [-1, 0, 0],  # 2  RightUpLeg
    [ 0, 1, 0],  # 3  Spine
    [ 0,-1, 0],  # 4  LeftLeg
    [ 0,-1, 0],  # 5  RightLeg
    [ 0, 1, 0],  # 6  Spine1
    [ 0,-1, 0],  # 7  LeftFoot
    [ 0,-1, 0],  # 8  RightFoot
    [ 0, 1, 0],  # 9  Spine2
    [ 0, 0, 1],  # 10 LeftToeBase
    [ 0, 0, 1],  # 11 RightToeBase
    [ 0, 1, 0],  # 12 Neck
    [ 1, 0, 0],  # 13 LeftShoulder   +X = character LEFT
    [-1, 0, 0],  # 14 RightShoulder
    [ 0, 0, 1],  # 15 Head
    [ 0,-1, 0],  # 16 LeftArm        arms hang DOWN in t2m canonical
    [ 0,-1, 0],  # 17 RightArm
    [ 0,-1, 0],  # 18 LeftForeArm
    [ 0,-1, 0],  # 19 RightForeArm
    [ 0,-1, 0],  # 20 LeftHand
    [ 0,-1, 0],  # 21 RightHand
], dtype=np.float64)

# Five kinematic chains that together cover all 22 joints. Each chain starts
# at its attachment joint (Hips or Spine2) and walks outward.
T2M_KINEMATIC_CHAIN = [
    [0, 2, 5, 8, 11],      # Hips -> RightUpLeg -> RightLeg -> RightFoot -> RightToe
    [0, 1, 4, 7, 10],      # Hips -> LeftUpLeg -> LeftLeg -> LeftFoot -> LeftToe
    [0, 3, 6, 9, 12, 15],  # Hips -> Spine -> Spine1 -> Spine2 -> Neck -> Head
    [9, 14, 17, 19, 21],   # Spine2 -> RightShoulder -> RightArm -> RightForeArm -> RightHand
    [9, 13, 16, 18, 20],   # Spine2 -> LeftShoulder -> LeftArm -> LeftForeArm -> LeftHand
]

# Parent joint index for each of the 22 t2m joints (-1 for the root),
# derived from the kinematic chains above.
T2M_PARENTS = [-1] * 22
for _chain in T2M_KINEMATIC_CHAIN:
    for _k in range(1, len(_chain)):
        T2M_PARENTS[_chain[_k]] = _chain[_k - 1]

# ──────────────────────────────────────────────────────────────────────────────
# SMPL joint names / T-pose (for bone mapping reference)
# ──────────────────────────────────────────────────────────────────────────────

# Joint names in SMPL/standard-humanoid order; index i here corresponds to
# row i of SMPL_TPOSE below.
SMPL_NAMES = [
    "Hips",         "LeftUpLeg",    "RightUpLeg",    "Spine",
    "LeftLeg",      "RightLeg",     "Spine1",        "LeftFoot",
    "RightFoot",    "Spine2",       "LeftToeBase",   "RightToeBase",
    "Neck",         "LeftShoulder", "RightShoulder", "Head",
    "LeftArm",      "RightArm",     "LeftForeArm",   "RightForeArm",
    "LeftHand",     "RightHand",
]

# Approximate T-pose joint world positions in metres (Y-up, facing +Z)
# +X = character's RIGHT (standard SMPL/UniRig convention)
SMPL_TPOSE = np.array([
    [ 0.000, 0.920,  0.000],  # 0  Hips
    [-0.095, 0.920,  0.000],  # 1  LeftUpLeg (character's left = -X)
    [ 0.095, 0.920,  0.000],  # 2  RightUpLeg
    [ 0.000, 0.980,  0.000],  # 3  Spine
    [-0.095, 0.495,  0.000],  # 4  LeftLeg
    [ 0.095, 0.495,  0.000],  # 5  RightLeg
    [ 0.000, 1.050,  0.000],  # 6  Spine1
    [-0.095, 0.075,  0.000],  # 7  LeftFoot
    [ 0.095, 0.075,  0.000],  # 8  RightFoot
    [ 0.000, 1.120,  0.000],  # 9  Spine2
    [-0.095, 0.000, -0.020],  # 10 LeftToeBase
    [ 0.095, 0.000, -0.020],  # 11 RightToeBase
    [ 0.000, 1.370,  0.000],  # 12 Neck
    [-0.130, 1.290,  0.000],  # 13 LeftShoulder
    [ 0.130, 1.290,  0.000],  # 14 RightShoulder
    [ 0.000, 1.500,  0.000],  # 15 Head
    [-0.330, 1.290,  0.000],  # 16 LeftArm
    [ 0.330, 1.290,  0.000],  # 17 RightArm
    [-0.630, 1.290,  0.000],  # 18 LeftForeArm
    [ 0.630, 1.290,  0.000],  # 19 RightForeArm
    [-0.910, 1.290,  0.000],  # 20 LeftHand
    [ 0.910, 1.290,  0.000],  # 21 RightHand
], dtype=np.float32)

# Name hint table: lowercase substrings -> SMPL joint index.
# Used by build_bone_map() to score candidate bone-name matches; a bone name
# (after _strip_name) containing any of the listed substrings gets a hint
# score of 1.0 for that joint.
_NAME_HINTS: list[tuple[list[str], int]] = [
    (["hips","pelvis","root"], 0),
    (["leftupleg","l_upleg","leftthigh","lefthip","thigh_l"], 1),
    (["rightupleg","r_upleg","rightthigh","righthip","thigh_r"], 2),
    (["spine","spine0","spine_01"], 3),
    (["leftleg","leftknee","lowerleg_l","knee_l"], 4),
    (["rightleg","rightknee","lowerleg_r","knee_r"], 5),
    (["spine1","spine_02"], 6),
    (["leftfoot","l_foot","foot_l"], 7),
    (["rightfoot","r_foot","foot_r"], 8),
    (["spine2","spine_03","chest"], 9),
    (["lefttoebase","lefttoe","l_toe","toe_l"], 10),
    (["righttoebase","righttoe","r_toe","toe_r"], 11),
    (["neck"], 12),
    (["leftshoulder","leftcollar","clavicle_l"], 13),
    (["rightshoulder","rightcollar","clavicle_r"], 14),
    (["head"], 15),
    (["leftarm","upperarm_l","l_arm"], 16),
    (["rightarm","upperarm_r","r_arm"], 17),
    (["leftforearm","lowerarm_l","l_forearm"], 18),
    (["rightforearm","lowerarm_r","r_forearm"], 19),
    (["lefthand","hand_l","l_hand"], 20),
    (["righthand","hand_r","r_hand"], 21),
]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 154 |
+
# Quaternion helpers (scalar-first WXYZ convention throughout)
|
| 155 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 156 |
+
|
| 157 |
+
_ID_QUAT = np.array([1., 0., 0., 0.], dtype=np.float32)
|
| 158 |
+
_ID_MAT3 = np.eye(3, dtype=np.float64)
|
| 159 |
+
|
| 160 |
+
def _qmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
| 161 |
+
aw, ax, ay, az = a
|
| 162 |
+
bw, bx, by, bz = b
|
| 163 |
+
return np.array([
|
| 164 |
+
aw*bw - ax*bx - ay*by - az*bz,
|
| 165 |
+
aw*bx + ax*bw + ay*bz - az*by,
|
| 166 |
+
aw*by - ax*bz + ay*bw + az*bx,
|
| 167 |
+
aw*bz + ax*by - ay*bx + az*bw,
|
| 168 |
+
], dtype=np.float32)
|
| 169 |
+
|
| 170 |
+
def _qnorm(q: np.ndarray) -> np.ndarray:
|
| 171 |
+
n = np.linalg.norm(q)
|
| 172 |
+
return (q / n) if n > 1e-12 else _ID_QUAT.copy()
|
| 173 |
+
|
| 174 |
+
def _qinv(q: np.ndarray) -> np.ndarray:
|
| 175 |
+
"""Conjugate = inverse for unit quaternion."""
|
| 176 |
+
return q * np.array([1., -1., -1., -1.], dtype=np.float32)
|
| 177 |
+
|
| 178 |
+
def _quat_to_mat(q: np.ndarray) -> np.ndarray:
|
| 179 |
+
"""WXYZ quaternion -> 3x3 rotation matrix (float64)."""
|
| 180 |
+
w, x, y, z = q.astype(np.float64)
|
| 181 |
+
return np.array([
|
| 182 |
+
[1-2*(y*y+z*z), 2*(x*y-w*z), 2*(x*z+w*y)],
|
| 183 |
+
[ 2*(x*y+w*z), 1-2*(x*x+z*z), 2*(y*z-w*x)],
|
| 184 |
+
[ 2*(x*z-w*y), 2*(y*z+w*x), 1-2*(x*x+y*y)],
|
| 185 |
+
], dtype=np.float64)
|
| 186 |
+
|
| 187 |
+
def _mat_to_quat(m: np.ndarray) -> np.ndarray:
|
| 188 |
+
"""3x3 rotation matrix -> WXYZ quaternion (float32, positive-W)."""
|
| 189 |
+
from scipy.spatial.transform import Rotation
|
| 190 |
+
xyzw = Rotation.from_matrix(m.astype(np.float64)).as_quat()
|
| 191 |
+
wxyz = np.array([xyzw[3], xyzw[0], xyzw[1], xyzw[2]], dtype=np.float32)
|
| 192 |
+
if wxyz[0] < 0:
|
| 193 |
+
wxyz = -wxyz
|
| 194 |
+
return wxyz
|
| 195 |
+
|
| 196 |
+
def _r_between(u: np.ndarray, v: np.ndarray) -> np.ndarray:
|
| 197 |
+
"""
|
| 198 |
+
Minimal rotation matrix (3x3) that maps unit vector u to unit vector v.
|
| 199 |
+
Uses the Rodrigues formula; handles parallel/antiparallel cases.
|
| 200 |
+
"""
|
| 201 |
+
u = u / (np.linalg.norm(u) + 1e-12)
|
| 202 |
+
v = v / (np.linalg.norm(v) + 1e-12)
|
| 203 |
+
c = float(np.dot(u, v))
|
| 204 |
+
if c >= 1.0 - 1e-7:
|
| 205 |
+
return _ID_MAT3.copy()
|
| 206 |
+
if c <= -1.0 + 1e-7:
|
| 207 |
+
# 180 degree rotation: pick any perpendicular axis
|
| 208 |
+
perp = np.array([1., 0., 0.]) if abs(u[0]) < 0.9 else np.array([0., 1., 0.])
|
| 209 |
+
ax = np.cross(u, perp)
|
| 210 |
+
ax /= np.linalg.norm(ax)
|
| 211 |
+
return 2.0 * np.outer(ax, ax) - _ID_MAT3
|
| 212 |
+
ax = np.cross(u, v) # sin(theta) * rotation axis
|
| 213 |
+
s = np.linalg.norm(ax)
|
| 214 |
+
K = np.array([[ 0, -ax[2], ax[1]],
|
| 215 |
+
[ ax[2], 0, -ax[0]],
|
| 216 |
+
[-ax[1], ax[0], 0]], dtype=np.float64)
|
| 217 |
+
return _ID_MAT3 + K + K @ K * ((1.0 - c) / (s * s + 1e-12))
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 221 |
+
# GLB skin reader
|
| 222 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 223 |
+
|
| 224 |
+
def _read_glb_skin(rigged_glb: str):
    """
    Read the first skin from a GLB/glTF file.

    Return (gltf, skin, ibm[n,4,4], node_trs{name->(t,r_wxyz,s)},
            bone_names[], bone_parent_map{name->parent_name_or_None}).

    ibm is stored as-read from the binary blob (column-major from glTF spec).
    Callers must use inv(ibm[i].T)[:3,3] to get correct world positions.

    Raises ValueError if the file contains no skin.
    """
    import base64
    import pygltflib

    gltf = pygltflib.GLTF2().load(rigged_glb)
    if not gltf.skins:
        raise ValueError(f"No skin found in {rigged_glb}")
    # Only the first skin is considered; multi-skin files are not handled here.
    skin = gltf.skins[0]

    def _raw_bytes(buf):
        # Resolve a glTF buffer to raw bytes: embedded GLB blob, base64 data
        # URI, or an external file next to the GLB.
        if buf.uri is None:
            return bytes(gltf.binary_blob())
        if buf.uri.startswith("data:"):
            return base64.b64decode(buf.uri.split(",", 1)[1])
        from pathlib import Path
        return (Path(rigged_glb).parent / buf.uri).read_bytes()

    # Decode the inverse-bind-matrix accessor: n MAT4 float32 values
    # (16 floats * 4 bytes = 64 bytes per matrix).
    acc = gltf.accessors[skin.inverseBindMatrices]
    bv = gltf.bufferViews[acc.bufferView]
    raw = _raw_bytes(gltf.buffers[bv.buffer])
    start = (bv.byteOffset or 0) + (acc.byteOffset or 0)
    n = acc.count
    ibm = np.frombuffer(raw[start: start + n * 64], dtype=np.float32).reshape(n, 4, 4)

    # Build node parent map (node_index -> parent_node_index)
    node_parent: dict[int, int] = {}
    for ni, node in enumerate(gltf.nodes):
        for child_idx in (node.children or []):
            node_parent[child_idx] = ni

    joint_set = set(skin.joints)
    bone_names = []
    node_trs: dict[str, tuple] = {}
    bone_parent_map: dict[str, str | None] = {}

    for i, j_idx in enumerate(skin.joints):
        node = gltf.nodes[j_idx]
        # Unnamed joints get a synthetic stable name by skin-joint index.
        name = node.name or f"bone_{i}"
        bone_names.append(name)

        # Local TRS with glTF defaults when a channel is absent; rotation is
        # converted from glTF's XYZW to this module's WXYZ convention.
        t = np.array(node.translation or [0., 0., 0.], dtype=np.float32)
        r_xyzw = np.array(node.rotation or [0., 0., 0., 1.], dtype=np.float32)
        s = np.array(node.scale or [1., 1., 1.], dtype=np.float32)
        r_wxyz = np.array([r_xyzw[3], r_xyzw[0], r_xyzw[1], r_xyzw[2]], dtype=np.float32)
        node_trs[name] = (t, r_wxyz, s)

        # Find parent bone (walk up node hierarchy to nearest joint) —
        # intermediate non-joint nodes are skipped over.
        parent_node = node_parent.get(j_idx)
        parent_name: str | None = None
        while parent_node is not None:
            if parent_node in joint_set:
                pnode = gltf.nodes[parent_node]
                parent_name = pnode.name or f"bone_{skin.joints.index(parent_node)}"
                break
            parent_node = node_parent.get(parent_node)
        bone_parent_map[name] = parent_name

    print(f"[GLB] {len(bone_names)} bones from skin '{skin.name or 'Armature'}'")
    return gltf, skin, ibm, node_trs, bone_names, bone_parent_map
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 293 |
+
# Bone mapping
|
| 294 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 295 |
+
|
| 296 |
+
def _strip_name(name: str) -> str:
|
| 297 |
+
name = re.sub(r'^(mixamorig:|j_bip_[lcr]_|cc_base_|bip01_|rig:|chr:)',
|
| 298 |
+
"", name, flags=re.IGNORECASE)
|
| 299 |
+
return re.sub(r'[_\-\s.]', "", name).lower()
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def build_bone_map(
    rigged_glb: str,
    verbose: bool = True,
) -> tuple[dict, dict, float, dict, dict]:
    """
    Map UniRig bone names -> SMPL joint index by spatial proximity + name hints.

    Scoring is 0.6 * normalised-position similarity + 0.4 * name-hint match,
    followed by a greedy one-to-one assignment with a minimum-score cutoff.

    Returns
    -------
    bone_to_smpl : {bone_name: smpl_joint_index}
    node_trs : {bone_name: (t[3], r_wxyz[4], s[3])}
    height_scale : float (UniRig height / SMPL reference height)
    bone_parent_map : {bone_name: parent_bone_name_or_None}
    ur_pos_by_name : {bone_name: world_pos[3]}
    """
    _gltf, _skin, ibm, node_trs, bone_names, bone_parent_map = _read_glb_skin(rigged_glb)

    # FIX: glTF IBMs are stored column-major.
    # numpy reads as row-major, so the stored data is the TRANSPOSE of the actual matrix.
    # Correct world position = inv(actual_IBM)[:3,3] = inv(ibm[i].T)[:3,3]
    ur_pos = np.array([
        np.linalg.inv(ibm[i].T)[:3, 3] for i in range(len(bone_names))
    ], dtype=np.float32)

    ur_pos_by_name = {name: ur_pos[i] for i, name in enumerate(bone_names)}

    # Scale SMPL T-pose to match character height
    ur_h = ur_pos[:, 1].max() - ur_pos[:, 1].min()
    sm_h = SMPL_TPOSE[:, 1].max() - SMPL_TPOSE[:, 1].min()
    h_sc = (ur_h / sm_h) if sm_h > 1e-6 else 1.0
    sm_pos = SMPL_TPOSE * h_sc

    # FIX: Normalise ur and smpl Y ranges independently (floor=0, top=1 for each).
    # The old code used a shared reference which caused floor offsets to misalign.
    def _norm_independent(pos, own_range_min, own_range_max, x_range, z_range):
        # Y is remapped to [0, 1] per skeleton; X/Z are divided by a shared
        # scale so horizontal proportions remain comparable between skeletons.
        p = pos.copy().astype(np.float64)
        y_range = (own_range_max - own_range_min) or 1.0
        p[:, 0] /= (x_range or 1.0)
        p[:, 1] = (p[:, 1] - own_range_min) / y_range
        p[:, 2] /= (z_range or 1.0)
        return p

    # Common X/Z scale (use both skeletons' width for reference)
    x_range = max(
        abs(ur_pos[:, 0].max() - ur_pos[:, 0].min()),
        abs(sm_pos[:, 0].max() - sm_pos[:, 0].min()),
    ) or 1.0
    z_range = max(
        abs(ur_pos[:, 2].max() - ur_pos[:, 2].min()),
        abs(sm_pos[:, 2].max() - sm_pos[:, 2].min()),
    ) or 1.0

    ur_n = _norm_independent(ur_pos, ur_pos[:, 1].min(), ur_pos[:, 1].max(), x_range, z_range)
    sm_n = _norm_independent(sm_pos, sm_pos[:, 1].min(), sm_pos[:, 1].max(), x_range, z_range)

    # Pairwise distances between every UniRig bone and every SMPL joint,
    # converted to a similarity score in [0, 1] (1 = closest).
    dist = np.linalg.norm(ur_n[:, None] - sm_n[None], axis=-1)  # [M, 22]
    d_sc = 1.0 - np.clip(dist / (dist.max() + 1e-9), 0, 1)

    # Name hint score
    n_sc = np.zeros((len(bone_names), 22), dtype=np.float32)
    for mi, bname in enumerate(bone_names):
        stripped = _strip_name(bname)
        for kws, ji in _NAME_HINTS:
            if any(kw in stripped for kw in kws):
                n_sc[mi, ji] = 1.0

    combined = 0.6 * d_sc + 0.4 * n_sc  # [M, 22]

    # Greedy assignment: best-scoring (bone, joint) pairs first, each bone
    # and each joint used at most once, stopping below the cutoff.
    THRESHOLD = 0.35
    pairs = sorted(
        ((mi, ji, combined[mi, ji])
         for mi in range(len(bone_names))
         for ji in range(22)),
        key=lambda x: -x[2],
    )
    bone_to_smpl: dict[str, int] = {}
    taken: set[int] = set()
    for mi, ji, score in pairs:
        if score < THRESHOLD:
            break
        bname = bone_names[mi]
        if bname in bone_to_smpl or ji in taken:
            continue
        bone_to_smpl[bname] = ji
        taken.add(ji)

    if verbose:
        n_mapped = len(bone_to_smpl)
        print(f"\n[MAP] {n_mapped}/{len(bone_names)} bones mapped to SMPL joints:")
        for bname, ji in sorted(bone_to_smpl.items(), key=lambda x: x[1]):
            print(f"  {bname:<40} -> {SMPL_NAMES[ji]}")
        unmapped = [n for n in bone_names if n not in bone_to_smpl]
        if unmapped:
            preview = ", ".join(unmapped[:8])
            print(f"[MAP] {len(unmapped)} unmapped (identity): {preview}"
                  + (" ..." if len(unmapped) > 8 else ""))
        print()

    return bone_to_smpl, node_trs, h_sc, bone_parent_map, ur_pos_by_name
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 405 |
+
# T2M forward kinematics (world rotation matrices)
|
| 406 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 407 |
+
|
| 408 |
+
def _compute_t2m_world_rots(
    root_rot_wxyz: np.ndarray,    # [4] WXYZ
    local_rots_wxyz: np.ndarray,  # [21, 4] WXYZ (joints 1-21)
) -> np.ndarray:
    """
    Accumulated world rotation matrices for all 22 t2m joints at one frame.

    Mirrors skeleton.py's forward_kinematics_cont6d_np: every kinematic
    chain restarts its accumulation from the root rotation rather than
    continuing from a shared running product.

    Returns [22, 3, 3] world rotation matrices.
    """
    root_mat = _quat_to_mat(root_rot_wxyz)
    out = np.zeros((22, 3, 3), dtype=np.float64)
    out[0] = root_mat

    for chain in T2M_KINEMATIC_CHAIN:
        acc = root_mat.copy()  # each chain resets to the root (matches skeleton.py)
        for joint in chain[1:]:
            # Row joint-1 of local_rots_wxyz holds joint `joint` (joints 1..21).
            acc = acc @ _quat_to_mat(local_rots_wxyz[joint - 1])
            out[joint] = acc

    return out
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 434 |
+
# Keyframe builder — world-direction matching
|
| 435 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 436 |
+
|
| 437 |
+
def build_keyframes(
    motion: SMPLMotion,
    bone_to_smpl: dict[str, int],
    node_trs: dict[str, tuple],
    height_scale: float,
    bone_parent_map: dict[str, str | None],
    ur_pos_by_name: dict[str, np.ndarray],
) -> list[dict]:
    """
    Convert SMPLMotion -> List[Dict[bone_name -> (loc, rot_delta, scale)]]
    using world-direction matching retargeting.

    Parameters
    ----------
    motion
        Parsed motion; this function reads num_frames, root_pos [T, 3],
        root_rot [T, 4] (WXYZ) and local_rot [T, 21, 4].
    bone_to_smpl
        UniRig bone name -> SMPL joint index (0..21), as produced by
        build_bone_map.
    node_trs
        Bone name -> (rest translation, rest rotation quaternion, rest
        scale) from the rigged GLB's node transforms; scale is unused here.
    height_scale
        Multiplier applied to root world translation so the t2m-sized
        motion matches the rig's height.
    bone_parent_map / ur_pos_by_name
        Rest-pose hierarchy and world positions of the UniRig bones,
        used to derive rest bone directions.

    Returns
    -------
    One dict per frame mapping bone name -> (loc, rot_delta, scale),
    ready for write_gltf_animation. Non-root bones keep zero translation
    and unit scale; only the root receives a world translation.
    """
    T = motion.num_frames
    zeros3 = np.zeros(3, dtype=np.float32)
    ones3 = np.ones(3, dtype=np.float32)

    # Topological order: root joints (si==0) first, then by SMPL joint index
    # (parents always have lower SMPL indices in the kinematic chain)
    sorted_bones = sorted(bone_to_smpl.keys(), key=lambda b: bone_to_smpl[b])

    keyframes: list[dict] = []

    for ti in range(T):
        frame: dict = {}

        # T2M world rotation matrices for this frame
        world_rots_t2m = _compute_t2m_world_rots(
            motion.root_rot[ti].astype(np.float64),
            motion.local_rot[ti].astype(np.float64),
        )

        # Track UniRig world rotations per bone (needed for child local rotations)
        world_rot_ur: dict[str, np.ndarray] = {}

        for bname in sorted_bones:
            si = bone_to_smpl[bname]
            rest_t, rest_r, _rest_s = node_trs[bname]
            rest_t = rest_t.astype(np.float32)
            rest_r_mat = _quat_to_mat(rest_r)

            # ── Root bone (si == 0): drive world translation + facing rotation ──
            if si == 0:
                world_pos = motion.root_pos[ti].astype(np.float64) * height_scale
                # glTF animation translation is absolute; subtract the rest
                # translation so the track stores an offset from bind pose.
                pose_loc = (world_pos - rest_t.astype(np.float64)).astype(np.float32)

                # Root world rotation = t2m root rotation (Y-axis only)
                R_world_root = _quat_to_mat(motion.root_rot[ti])
                world_rot_ur[bname] = R_world_root

                # pose_rot_delta = inv(rest_r) @ target_world_rot
                pose_rot_mat = rest_r_mat.T @ R_world_root
                pose_rot = _mat_to_quat(pose_rot_mat)
                frame[bname] = (pose_loc, pose_rot, ones3)
                continue

            # ── Non-root bone: world-direction matching ──────────────────────

            # T2M world bone direction (in t2m coordinate frame)
            raw_dir_t2m = world_rots_t2m[si] @ T2M_RAW_OFFSETS[si]  # [3]

            # COORDINATE FRAME CONVERSION: t2m +X = character LEFT; SMPL +X = character RIGHT
            # Flip X to convert t2m world directions -> SMPL/UniRig world directions
            d_desired = np.array([-raw_dir_t2m[0], raw_dir_t2m[1], raw_dir_t2m[2]])
            d_desired_norm = d_desired / (np.linalg.norm(d_desired) + 1e-12)

            # UniRig rest bone direction (from inverse bind matrices, world space)
            parent_b = bone_parent_map.get(bname)
            if parent_b and parent_b in ur_pos_by_name:
                d_rest = (ur_pos_by_name[bname] - ur_pos_by_name[parent_b]).astype(np.float64)
            else:
                # No known parent: fall back to the bone's own world position
                # as a direction (root-adjacent bones).
                d_rest = ur_pos_by_name[bname].astype(np.float64)
            d_rest_norm = d_rest / (np.linalg.norm(d_rest) + 1e-12)

            # Minimal world-space rotation: rest direction -> desired direction
            R_world_desired = _r_between(d_rest_norm, d_desired_norm)  # [3, 3]
            world_rot_ur[bname] = R_world_desired

            # Local rotation = inv(parent_world) @ R_world_desired
            # (parent is guaranteed processed first by the sorted order above,
            # when the parent is itself a mapped bone)
            if parent_b and parent_b in world_rot_ur:
                R_parent = world_rot_ur[parent_b]
            else:
                R_parent = _ID_MAT3

            local_rot_mat = R_parent.T @ R_world_desired  # R_parent^-1 @ R_world

            # pose_rot_delta = inv(rest_r) @ local_rot
            # (glTF applies: final = rest_r @ pose_rot_delta = local_rot)
            pose_rot_mat = rest_r_mat.T @ local_rot_mat
            pose_rot = _mat_to_quat(pose_rot_mat)

            frame[bname] = (zeros3, pose_rot, ones3)

        keyframes.append(frame)

    return keyframes
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 535 |
+
# Public API
|
| 536 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 537 |
+
|
| 538 |
+
def animate_glb(
    motion: Union[np.ndarray, list, SMPLMotion],
    rigged_glb: str,
    output_glb: str,
    fps: float = 20.0,
    start_frame: int = 0,
    num_frames: int = -1,
) -> str:
    """
    Bake a HumanML3D motion clip onto a UniRig-rigged GLB.

    Parameters
    ----------
    motion : [T, 263] ndarray, list, or pre-parsed SMPLMotion
    rigged_glb : path to UniRig merge output (.glb with a skin)
    output_glb : destination path for animated GLB
    fps : frame rate embedded in the animation track
    start_frame / num_frames : optional clip range (-1 = all frames)

    Returns
    -------
    str : path to output_glb (as given).

    Raises
    ------
    ValueError   if the raw motion array is not 2-D with >= 193 features.
    RuntimeError if no UniRig bone could be matched to a SMPL joint.
    """
    from .io.gltf_io import write_gltf_animation

    # -- Parse the motion input into an SMPLMotion -------------------------
    if isinstance(motion, SMPLMotion):
        smpl = motion
    else:
        data = np.asarray(motion, dtype=np.float32)
        if data.ndim != 2 or data.shape[1] < 193:
            raise ValueError(f"Expected [T, 263] HumanML3D features, got {data.shape}")
        smpl = hml3d_to_smpl_motion(data, fps=fps)

    # -- Clip to the requested frame range ---------------------------------
    end = smpl.num_frames if num_frames <= 0 else start_frame + num_frames
    smpl = smpl.slice(start_frame, end)
    print(f"[animate] {smpl.num_frames} frames @ {fps:.0f} fps -> {output_glb}")

    # -- Map UniRig bones onto SMPL joints ---------------------------------
    mapping, node_trs, scale, parent_map, rest_positions = \
        build_bone_map(rigged_glb, verbose=True)
    if not mapping:
        raise RuntimeError(
            "build_bone_map returned 0 matches. "
            "Ensure the GLB has a valid skin with readable inverse bind matrices."
        )

    # -- World-direction-matching retarget ---------------------------------
    keyframes = build_keyframes(smpl, mapping, node_trs, scale,
                                parent_map, rest_positions)

    # -- Emit the animated GLB ---------------------------------------------
    parent_dir = os.path.dirname(os.path.abspath(output_glb))
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    write_gltf_animation(
        source_filepath=rigged_glb,
        dest_armature=None,
        keyframes=keyframes,
        output_filepath=output_glb,
        fps=float(fps),
    )

    return output_glb
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
# Backwards-compatibility alias
|
| 605 |
+
# Backwards-compatibility alias
def animate_glb_from_hml3d(
    motion, rigged_glb, output_glb, fps=20, start_frame=0, num_frames=-1
):
    """Deprecated name kept for older callers; delegates to animate_glb()."""
    return animate_glb(motion, rigged_glb, output_glb, fps=fps,
                       start_frame=start_frame, num_frames=num_frames)
|
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
cli.py
|
| 3 |
+
Command-line interface for rig_retarget.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python -m rig_retarget.cli \\
|
| 7 |
+
--source walk.bvh \\
|
| 8 |
+
--dest unirig_character.glb \\
|
| 9 |
+
--mapping radical2unirig.json \\
|
| 10 |
+
--output animated_character.glb \\
|
| 11 |
+
[--fps 30] [--start 0] [--frames 100] [--step 1]
|
| 12 |
+
|
| 13 |
+
# Calculate corrections only (no transfer):
|
| 14 |
+
python -m rig_retarget.cli --calc-corrections \\
|
| 15 |
+
--source walk.bvh --dest unirig_character.glb \\
|
| 16 |
+
--mapping mymap.json
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
import argparse
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _parse_args(argv=None):
|
| 25 |
+
p = argparse.ArgumentParser(
|
| 26 |
+
prog="rig_retarget",
|
| 27 |
+
description="Retarget animation from BVH/glTF source onto UniRig/glTF destination.",
|
| 28 |
+
)
|
| 29 |
+
p.add_argument("--source", required=True, help="Source animation file (.bvh or .glb/.gltf)")
|
| 30 |
+
p.add_argument("--dest", required=True, help="Destination skeleton file (.glb/.gltf, UniRig output)")
|
| 31 |
+
p.add_argument("--mapping", required=True, help="KeeMap-compatible JSON bone mapping file")
|
| 32 |
+
p.add_argument("--output", default=None, help="Output animated .glb (default: dest_retargeted.glb)")
|
| 33 |
+
p.add_argument("--fps", type=float, default=30.0)
|
| 34 |
+
p.add_argument("--start", type=int, default=0, help="Start frame index (0-based)")
|
| 35 |
+
p.add_argument("--frames", type=int, default=None, help="Number of frames to transfer (default: all)")
|
| 36 |
+
p.add_argument("--step", type=int, default=1, help="Keyframe every N source frames")
|
| 37 |
+
p.add_argument("--skin", type=int, default=0, help="Skin index in destination glTF")
|
| 38 |
+
p.add_argument("--calc-corrections", action="store_true",
|
| 39 |
+
help="Auto-calculate bone corrections and update the mapping JSON, then exit.")
|
| 40 |
+
p.add_argument("--verbose", action="store_true")
|
| 41 |
+
return p.parse_args(argv)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main(argv=None) -> None:
    """CLI entry point.

    Loads the bone mapping, source BVH animation and destination glTF
    skeleton, then either (a) with --calc-corrections: computes bone
    corrections, rewrites the mapping JSON and returns, or (b) transfers
    the animation and writes an animated .glb.

    Exits with status 1 on an unsupported source file extension.
    """
    args = _parse_args(argv)

    # Imports deferred to call time so module import stays cheap.
    # NOTE(review): KeeMapSettings is imported but not referenced below.
    from .io.mapping import load_mapping, save_mapping, KeeMapSettings
    from .io.gltf_io import load_gltf, write_gltf_animation
    from .retarget import (
        calc_all_corrections, transfer_animation,
    )

    # -----------------------------------------------------------------------
    # Load mapping
    # -----------------------------------------------------------------------
    print(f"[*] Loading mapping : {args.mapping}")
    settings, bone_items = load_mapping(args.mapping)

    # Override settings from CLI args
    settings.start_frame_to_apply = args.start
    settings.keyframe_every_n_frames = args.step

    # -----------------------------------------------------------------------
    # Load source animation
    # -----------------------------------------------------------------------
    src_path = Path(args.source)
    print(f"[*] Loading source : {src_path}")

    if src_path.suffix.lower() == ".bvh":
        from .io.bvh import load_bvh
        src_anim = load_bvh(str(src_path))
        if args.verbose:
            print(f" BVH: {src_anim.num_frames} frames, "
                  f"{src_anim.frame_time*1000:.1f} ms/frame, "
                  f"{len(src_anim.armature.pose_bones)} joints")
    elif src_path.suffix.lower() in (".glb", ".gltf"):
        # glTF source — load skeleton only; animation reading is TODO
        raise NotImplementedError(
            "glTF source animation reading is not yet implemented. "
            "Use a BVH file for the source animation."
        )
    else:
        print(f"[!] Unsupported source format: {src_path.suffix}", file=sys.stderr)
        sys.exit(1)

    # --frames wins; otherwise apply everything from --start to the end.
    if args.frames is not None:
        settings.number_of_frames_to_apply = args.frames
    else:
        settings.number_of_frames_to_apply = src_anim.num_frames - args.start

    # -----------------------------------------------------------------------
    # Load destination skeleton
    # -----------------------------------------------------------------------
    dst_path = Path(args.dest)
    print(f"[*] Loading dest : {dst_path}")
    dst_arm = load_gltf(str(dst_path), skin_index=args.skin)
    if args.verbose:
        print(f" Skeleton: {len(dst_arm.pose_bones)} bones")

    # -----------------------------------------------------------------------
    # Auto-correct pass (optional)
    # -----------------------------------------------------------------------
    if args.calc_corrections:
        print("[*] Calculating bone corrections ...")
        # Pose the source at the start frame so corrections are computed
        # against the same frame the transfer would begin from.
        src_anim.apply_frame(args.start)
        calc_all_corrections(bone_items, src_anim.armature, dst_arm, settings)
        save_mapping(args.mapping, settings, bone_items)
        print(f"[*] Updated mapping saved → {args.mapping}")
        return

    # -----------------------------------------------------------------------
    # Transfer
    # -----------------------------------------------------------------------
    print(f"[*] Transferring {settings.number_of_frames_to_apply} frames "
          f"(start={settings.start_frame_to_apply}, step={settings.keyframe_every_n_frames}) ...")
    keyframes = transfer_animation(src_anim, dst_arm, bone_items, settings)
    print(f"[*] Generated {len(keyframes)} keyframes")

    # -----------------------------------------------------------------------
    # Write output
    # -----------------------------------------------------------------------
    out_path = args.output or str(dst_path.with_name(dst_path.stem + "_retargeted.glb"))
    print(f"[*] Writing output : {out_path}")
    write_gltf_animation(str(dst_path), dst_arm, keyframes, out_path, fps=args.fps, skin_index=args.skin)
    print("[✓] Done")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
generate.py
|
| 3 |
+
───────────────────────────────────────────────────────────────────────────────
|
| 4 |
+
Text-to-motion generation.
|
| 5 |
+
|
| 6 |
+
Primary backend: MoMask inference server running on the Vast.ai instance.
|
| 7 |
+
Returns [T, 263] HumanML3D features directly — no SMPL
|
| 8 |
+
body mesh required.
|
| 9 |
+
|
| 10 |
+
Fallback backend: HumanML3D dataset keyword search (offline / no GPU needed).
|
| 11 |
+
|
| 12 |
+
Usage
|
| 13 |
+
─────
|
| 14 |
+
from Retarget.generate import generate_motion
|
| 15 |
+
|
| 16 |
+
# Use MoMask on instance
|
| 17 |
+
motion = generate_motion("a person walks forward",
|
| 18 |
+
backend_url="http://ssh4.vast.ai:8765")
|
| 19 |
+
|
| 20 |
+
# Local fallback (streams HuggingFace dataset)
|
| 21 |
+
motion = generate_motion("a person walks forward")
|
| 22 |
+
|
| 23 |
+
# Returned motion: np.ndarray [T, 263]
|
| 24 |
+
# Feed directly to animate_glb()
|
| 25 |
+
"""
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
import json
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 32 |
+
# Public API
|
| 33 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 34 |
+
|
| 35 |
+
def generate_motion(
    prompt: str,
    backend_url: str | None = None,
    num_frames: int = 196,
    fps: float = 20.0,
    seed: int = -1,
) -> np.ndarray:
    """
    Produce a HumanML3D [T, 263] motion array for a text prompt.

    Parameters
    ----------
    prompt
        Natural-language motion description, e.g. "a person walks forward",
        "someone does a jumping jack", "a man waves hello with his right hand".
    backend_url
        MoMask inference server URL (e.g. "http://ssh4.vast.ai:8765").
        When empty/None, or when the server call fails for any reason,
        the HumanML3D dataset keyword search is used instead.
    num_frames
        Requested clip length in frames (at 20 fps; max ~196 ≈ 9.8 s).
    fps
        Target fps (MoMask natively produces 20 fps).
    seed
        Random seed; -1 means non-deterministic.

    Returns
    -------
    np.ndarray of shape [T, 263] (HumanML3D feature layout).
    """
    if not backend_url:
        return _dataset_search_fallback(prompt)

    try:
        return _call_momask(prompt, backend_url, num_frames, seed)
    except Exception as exc:
        print(f"[generate] MoMask unreachable ({exc}) — falling back to dataset search")

    return _dataset_search_fallback(prompt)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 75 |
+
# MoMask backend
|
| 76 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 77 |
+
|
| 78 |
+
def _call_momask(
    prompt: str,
    url: str,
    num_frames: int,
    seed: int,
) -> np.ndarray:
    """POST to the MoMask inference server; return [T, 263] array."""
    import urllib.request

    # Build the JSON request body.
    body = json.dumps({
        "prompt": prompt,
        "num_frames": num_frames,
        "seed": seed,
    }).encode("utf-8")

    request = urllib.request.Request(
        f"{url.rstrip('/')}/generate",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # Generous timeout: generation can take a while on a busy server.
    with urllib.request.urlopen(request, timeout=180) as resp:
        reply = json.loads(resp.read())

    motion = np.array(reply["motion"], dtype=np.float32)
    # Sanity-check the feature layout before handing it downstream.
    if motion.ndim != 2 or motion.shape[1] < 193:
        raise ValueError(f"Server returned unexpected shape {motion.shape}")

    print(f"[generate] MoMask: {motion.shape[0]} frames for '{prompt}'")
    return motion
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 111 |
+
# Dataset search fallback
|
| 112 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 113 |
+
|
| 114 |
+
def _dataset_search_fallback(prompt: str) -> np.ndarray:
    """
    Keyword search in TeoGchx/HumanML3D dataset (streaming, HuggingFace).
    Used when no MoMask server is available.

    Raises RuntimeError when the search yields no candidates.
    """
    from .search import search_motions, format_choice_label

    print(f"[generate] Searching HumanML3D dataset for: '{prompt}'")
    hits = search_motions(prompt, top_k=5, split="test", max_scan=500)
    if not hits:
        raise RuntimeError(
            f"No motion found in dataset for prompt: {prompt!r}\n"
            "Check your internet connection or deploy MoMask on the instance."
        )

    top = hits[0]
    print(f"[generate] Best match: {format_choice_label(top)}")
    return np.array(top["motion"], dtype=np.float32)
|
|
@@ -0,0 +1,813 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
humanml3d_to_bvh.py
|
| 4 |
+
Convert HumanML3D .npy motion files → BVH animation.
|
| 5 |
+
When a UniRig-rigged GLB (or ASCII FBX) is supplied via --rig, the BVH is
|
| 6 |
+
built using the UniRig skeleton's own bone names and hierarchy, with
|
| 7 |
+
automatic bone-to-SMPL-joint mapping — no Blender required.
|
| 8 |
+
|
| 9 |
+
Dependencies
|
| 10 |
+
numpy (always required)
|
| 11 |
+
pygltflib pip install pygltflib (required for --rig GLB files)
|
| 12 |
+
|
| 13 |
+
Usage
|
| 14 |
+
# SMPL-named BVH (no rig needed)
|
| 15 |
+
python humanml3d_to_bvh.py 000001.npy
|
| 16 |
+
|
| 17 |
+
# Retargeted to UniRig skeleton
|
| 18 |
+
python humanml3d_to_bvh.py 000001.npy --rig rigged_mesh.glb
|
| 19 |
+
|
| 20 |
+
# Explicit output + fps
|
| 21 |
+
python humanml3d_to_bvh.py 000001.npy --rig rigged_mesh.glb -o anim.bvh --fps 20
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
import argparse, re, sys
|
| 26 |
+
from dataclasses import dataclass, field
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Optional
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
|
| 32 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 33 |
+
# SMPL 22-joint skeleton definition
|
| 34 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 35 |
+
|
| 36 |
+
SMPL_NAMES = [
|
| 37 |
+
"Hips", # 0 pelvis / root
|
| 38 |
+
"LeftUpLeg", # 1 left_hip
|
| 39 |
+
"RightUpLeg", # 2 right_hip
|
| 40 |
+
"Spine", # 3 spine1
|
| 41 |
+
"LeftLeg", # 4 left_knee
|
| 42 |
+
"RightLeg", # 5 right_knee
|
| 43 |
+
"Spine1", # 6 spine2
|
| 44 |
+
"LeftFoot", # 7 left_ankle
|
| 45 |
+
"RightFoot", # 8 right_ankle
|
| 46 |
+
"Spine2", # 9 spine3
|
| 47 |
+
"LeftToeBase", # 10 left_foot
|
| 48 |
+
"RightToeBase", # 11 right_foot
|
| 49 |
+
"Neck", # 12 neck
|
| 50 |
+
"LeftShoulder", # 13 left_collar
|
| 51 |
+
"RightShoulder", # 14 right_collar
|
| 52 |
+
"Head", # 15 head
|
| 53 |
+
"LeftArm", # 16 left_shoulder
|
| 54 |
+
"RightArm", # 17 right_shoulder
|
| 55 |
+
"LeftForeArm", # 18 left_elbow
|
| 56 |
+
"RightForeArm", # 19 right_elbow
|
| 57 |
+
"LeftHand", # 20 left_wrist
|
| 58 |
+
"RightHand", # 21 right_wrist
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
SMPL_PARENT = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
|
| 62 |
+
NUM_SMPL = 22
|
| 63 |
+
|
| 64 |
+
SMPL_TPOSE = np.array([
|
| 65 |
+
[ 0.000, 0.920, 0.000], # 0 Hips
|
| 66 |
+
[-0.095, 0.920, 0.000], # 1 LeftUpLeg
|
| 67 |
+
[ 0.095, 0.920, 0.000], # 2 RightUpLeg
|
| 68 |
+
[ 0.000, 0.980, 0.000], # 3 Spine
|
| 69 |
+
[-0.095, 0.495, 0.000], # 4 LeftLeg
|
| 70 |
+
[ 0.095, 0.495, 0.000], # 5 RightLeg
|
| 71 |
+
[ 0.000, 1.050, 0.000], # 6 Spine1
|
| 72 |
+
[-0.095, 0.075, 0.000], # 7 LeftFoot
|
| 73 |
+
[ 0.095, 0.075, 0.000], # 8 RightFoot
|
| 74 |
+
[ 0.000, 1.120, 0.000], # 9 Spine2
|
| 75 |
+
[-0.095, 0.000, -0.020], # 10 LeftToeBase
|
| 76 |
+
[ 0.095, 0.000, -0.020], # 11 RightToeBase
|
| 77 |
+
[ 0.000, 1.370, 0.000], # 12 Neck
|
| 78 |
+
[-0.130, 1.290, 0.000], # 13 LeftShoulder
|
| 79 |
+
[ 0.130, 1.290, 0.000], # 14 RightShoulder
|
| 80 |
+
[ 0.000, 1.500, 0.000], # 15 Head
|
| 81 |
+
[-0.330, 1.290, 0.000], # 16 LeftArm
|
| 82 |
+
[ 0.330, 1.290, 0.000], # 17 RightArm
|
| 83 |
+
[-0.630, 1.290, 0.000], # 18 LeftForeArm
|
| 84 |
+
[ 0.630, 1.290, 0.000], # 19 RightForeArm
|
| 85 |
+
[-0.910, 1.290, 0.000], # 20 LeftHand
|
| 86 |
+
[ 0.910, 1.290, 0.000], # 21 RightHand
|
| 87 |
+
], dtype=np.float32)
|
| 88 |
+
|
| 89 |
+
_SMPL_CHILDREN: list[list[int]] = [[] for _ in range(NUM_SMPL)]
|
| 90 |
+
for _j, _p in enumerate(SMPL_PARENT):
|
| 91 |
+
if _p >= 0:
|
| 92 |
+
_SMPL_CHILDREN[_p].append(_j)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _smpl_dfs() -> list[int]:
|
| 96 |
+
order, stack = [], [0]
|
| 97 |
+
while stack:
|
| 98 |
+
j = stack.pop()
|
| 99 |
+
order.append(j)
|
| 100 |
+
for c in reversed(_SMPL_CHILDREN[j]):
|
| 101 |
+
stack.append(c)
|
| 102 |
+
return order
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
SMPL_DFS = _smpl_dfs()
|
| 106 |
+
|
| 107 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 108 |
+
# Quaternion helpers (numpy, WXYZ)
|
| 109 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 110 |
+
|
| 111 |
+
def qnorm(q: np.ndarray) -> np.ndarray:
    """Normalize quaternion(s) to unit length; epsilon keeps zero input finite."""
    magnitude = np.linalg.norm(q, axis=-1, keepdims=True)
    return q / (magnitude + 1e-9)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def qmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hamilton product a⊗b of WXYZ quaternions (broadcasts over leading axes)."""
    w1, x1, y1, z1 = (a[..., k] for k in range(4))
    w2, x2, y2, z2 = (b[..., k] for k in range(4))
    out_w = w1*w2 - x1*x2 - y1*y2 - z1*z2
    out_x = w1*x2 + x1*w2 + y1*z2 - z1*y2
    out_y = w1*y2 - x1*z2 + y1*w2 + z1*x2
    out_z = w1*z2 + x1*y2 - y1*x2 + z1*w2
    return np.stack([out_w, out_x, out_y, out_z], axis=-1)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def qinv(q: np.ndarray) -> np.ndarray:
    """Conjugate of WXYZ quaternions — the inverse for unit quaternions."""
    conj_sign = np.array([1.0, -1.0, -1.0, -1.0], dtype=np.float32)
    return q * conj_sign
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def qrot(q: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Rotate vectors v by WXYZ quaternions q via q * (0, v) * q^-1."""
    zeros = np.zeros((*v.shape[:-1], 1), dtype=v.dtype)
    pure = np.concatenate([zeros, v], axis=-1)  # embed v as a pure quaternion
    rotated = qmul(qmul(q, pure), qinv(q))
    return rotated[..., 1:]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def qbetween(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Swing quaternion rotating unit-vectors a to b. [..., 3] to [..., 4]."""
    # Defensively re-normalise; the eps guards zero-length input.
    a = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-9)
    b = b / (np.linalg.norm(b, axis=-1, keepdims=True) + 1e-9)
    dot = np.clip((a * b).sum(axis=-1, keepdims=True), -1.0, 1.0)
    cross = np.cross(a, b)
    # Half-angle identities: w = cos(theta/2); cross = sin(theta)*axis
    # = 2*w*sin(theta/2)*axis, hence the division by 2w below.
    w = np.sqrt(np.maximum((1.0 + dot) * 0.5, 0.0))
    xyz = cross / (2.0 * w + 1e-9)
    # Near-antiparallel pairs: w -> 0 and the axis from `cross` is unstable,
    # so substitute an explicit 180-degree rotation about any axis
    # perpendicular to `a`.
    anti = (dot[..., 0] < -0.9999)
    if anti.any():
        # NOTE(review): `a[anti, 0:1]` indexes with a mask plus a column
        # slice, which assumes `a` is 2-D [N, 3] — confirm callers never
        # pass higher-rank batches into the antiparallel case.
        perp = np.where(
            np.abs(a[anti, 0:1]) < 0.9,
            np.tile([1, 0, 0], (anti.sum(), 1)),
            np.tile([0, 1, 0], (anti.sum(), 1)),
        ).astype(np.float32)
        ax_f = np.cross(a[anti], perp)
        ax_f = ax_f / (np.linalg.norm(ax_f, axis=-1, keepdims=True) + 1e-9)
        w[anti] = 0.0  # cos(180deg / 2) = 0
        xyz[anti] = ax_f
    return qnorm(np.concatenate([w, xyz], axis=-1))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def quat_to_euler_ZXY(q: np.ndarray) -> np.ndarray:
    """WXYZ quaternions to ZXY Euler degrees (rz, rx, ry) for BVH.

    The returned order matches the BVH channel spec
    "Zrotation Xrotation Yrotation" written by the exporters below.
    """
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    rz = np.arctan2(2.0*(w*z + x*y), 1.0 - 2.0*(x*x + z*z))
    # Clip the sine argument so numeric drift never escapes arcsin's domain.
    rx = np.arcsin(np.clip(2.0*(w*x - y*z), -1.0, 1.0))
    ry = np.arctan2(2.0*(w*y + x*z), 1.0 - 2.0*(x*x + y*y))
    return np.degrees(np.stack([rz, rx, ry], axis=-1))
|
| 166 |
+
|
| 167 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 168 |
+
# HumanML3D 263-dim recovery
|
| 169 |
+
#
|
| 170 |
+
# Layout per frame:
|
| 171 |
+
# [0] root Y-axis angular velocity (rad/frame)
|
| 172 |
+
# [1] root height Y (m)
|
| 173 |
+
# [2:4] root XZ velocity in local frame
|
| 174 |
+
# [4:67] local positions of joints 1-21 (21 x 3 = 63)
|
| 175 |
+
# [67:193] 6-D rotations for joints 1-21 (21 x 6 = 126, unused here)
|
| 176 |
+
# [193:259] joint velocities (22 x 3 = 66, unused here)
|
| 177 |
+
# [259:263] foot contact (4, unused here)
|
| 178 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 179 |
+
|
| 180 |
+
def _recover_root(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Integrate root yaw and planar velocity from HumanML3D features.

    Returns (r_rot [T, 4] WXYZ yaw quaternions, r_pos [T, 3] root positions).
    Channel 0 is per-frame yaw velocity, channel 1 absolute height,
    channels 2:4 the XZ velocity expressed in the root's local frame.
    """
    n_frames = data.shape[0]
    yaw = np.cumsum(data[:, 0])          # integrate angular velocity
    r_rot = np.zeros((n_frames, 4), dtype=np.float32)
    r_rot[:, 0] = np.cos(yaw * 0.5)      # W
    r_rot[:, 2] = np.sin(yaw * 0.5)      # Y (rotation about the up axis)
    vel_local = np.stack(
        [data[:, 2], np.zeros(n_frames, dtype=np.float32), data[:, 3]], axis=-1)
    vel_world = qrot(r_rot, vel_local)   # rotate velocity into world frame
    r_pos = np.zeros((n_frames, 3), dtype=np.float32)
    r_pos[:, 0] = np.cumsum(vel_world[:, 0])
    r_pos[:, 1] = data[:, 1]             # height is stored absolutely
    r_pos[:, 2] = np.cumsum(vel_world[:, 2])
    return r_rot, r_pos
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def recover_from_ric(data: np.ndarray, joints_num: int = 22) -> np.ndarray:
    """263-dim features to world-space positions [T, joints_num, 3].

    Root translation/rotation is integrated by _recover_root; the remaining
    joints are stored root-relative and are rotated/translated into world
    space here.
    """
    data = data.astype(np.float32)
    r_rot, r_pos = _recover_root(data)
    n_local = joints_num - 1
    local = data[:, 4:4 + n_local * 3].reshape(-1, n_local, 3)
    # Undo the root yaw per joint, then offset by the root trajectory.
    rot_inv = np.broadcast_to(qinv(r_rot)[:, None], (*local.shape[:2], 4)).copy()
    world = qrot(rot_inv, local) + r_pos[:, None]
    return np.concatenate([r_pos[:, None], world], axis=1)
|
| 204 |
+
|
| 205 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 206 |
+
# SMPL geometry helpers
|
| 207 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 208 |
+
|
| 209 |
+
def _scale_smpl_tpose(positions: np.ndarray) -> np.ndarray:
    """Return a copy of SMPL_TPOSE uniformly scaled to the motion's height."""
    motion_h = positions[:, :, 1].max() - positions[:, :, 1].min()
    rest_h = SMPL_TPOSE[:, 1].max() - SMPL_TPOSE[:, 1].min()
    if rest_h > 1e-6 and motion_h > 1e-6:
        return SMPL_TPOSE * (motion_h / rest_h)
    # Degenerate heights: keep the reference proportions unchanged.
    return SMPL_TPOSE * 1.0
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _rest_dirs(tpose: np.ndarray, children: list[list[int]],
|
| 217 |
+
parent: list[int]) -> np.ndarray:
|
| 218 |
+
N = tpose.shape[0]
|
| 219 |
+
dirs = np.zeros((N, 3), dtype=np.float32)
|
| 220 |
+
for j in range(N):
|
| 221 |
+
ch = children[j]
|
| 222 |
+
if ch:
|
| 223 |
+
avg = np.stack([tpose[c] - tpose[j] for c in ch]).mean(0)
|
| 224 |
+
dirs[j] = avg / (np.linalg.norm(avg) + 1e-9)
|
| 225 |
+
else:
|
| 226 |
+
v = tpose[j] - tpose[parent[j]]
|
| 227 |
+
dirs[j] = v / (np.linalg.norm(v) + 1e-9)
|
| 228 |
+
return dirs
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def positions_to_local_quats(positions: np.ndarray,
                             tpose: np.ndarray) -> np.ndarray:
    """World-space joint positions [T, 22, 3] to local quaternions [T, 22, 4].

    Per joint, the world rotation is the swing taking the rest-pose bone
    direction to the animated bone direction (qbetween is swing-only, so
    twist about the bone axis is not recovered). Local rotations are then
    obtained by removing the parent's world rotation along the hierarchy.
    """
    T = positions.shape[0]
    # Rest-pose bone directions of the (scaled, root-aligned) T-pose.
    rd = _rest_dirs(tpose, _SMPL_CHILDREN, SMPL_PARENT)

    # World rotations, initialised to identity (W=1).
    world_q = np.zeros((T, NUM_SMPL, 4), dtype=np.float32)
    world_q[:, :, 0] = 1.0

    for j in range(NUM_SMPL):
        ch = _SMPL_CHILDREN[j]
        if ch:
            # Animated bone direction = mean offset toward the children
            # (mirrors _rest_dirs so rest and animated dirs are comparable).
            vecs = np.stack([positions[:, c] - positions[:, j] for c in ch], 1).mean(1)
        else:
            # Leaf joint: direction away from the parent.
            vecs = positions[:, j] - positions[:, SMPL_PARENT[j]]
        cur = vecs / (np.linalg.norm(vecs, axis=-1, keepdims=True) + 1e-9)
        rd_b = np.broadcast_to(rd[j], cur.shape).copy()
        world_q[:, j] = qbetween(rd_b, cur)

    # Parent-local rotations: q_local = q_parent^-1 * q_world.
    local_q = np.zeros_like(world_q)
    local_q[:, :, 0] = 1.0
    for j in SMPL_DFS:
        p = SMPL_PARENT[j]
        if p < 0:
            local_q[:, j] = world_q[:, j]
        else:
            local_q[:, j] = qmul(qinv(world_q[:, p]), world_q[:, j])

    return qnorm(local_q)
|
| 260 |
+
|
| 261 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 262 |
+
# UniRig skeleton data structure
|
| 263 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 264 |
+
|
| 265 |
+
@dataclass
class Bone:
    """One joint of a rigged skeleton, described in world-space rest pose."""
    # Bone name as it appears in the source rig (GLB node / FBX model).
    name: str
    # Name of the parent bone; None for the root.
    parent: Optional[str]
    # Rest-pose position in world space (3-vector).
    world_rest_pos: np.ndarray
    # Names of child bones, filled in while parsing hierarchy connections.
    children: list[str] = field(default_factory=list)
    # Index of the matching SMPL joint once auto_map() has run; None if unmapped.
    smpl_idx: Optional[int] = None
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class UnirigSkeleton:
    """A named-bone hierarchy parsed from a rigged GLB or ASCII FBX file."""

    def __init__(self, bones: dict[str, Bone]):
        self.bones = bones
        # The (assumed unique) bone without a parent is the root.
        self.root = next(b for b in bones.values() if b.parent is None)

    def dfs_order(self) -> list[str]:
        """Bone names in depth-first order starting at the root."""
        order: list[str] = []
        stack = [self.root.name]
        while stack:
            current = stack.pop()
            order.append(current)
            # Reversed push so the first child is visited first.
            stack.extend(reversed(self.bones[current].children))
        return order

    def local_offsets(self) -> dict[str, np.ndarray]:
        """Rest-pose offset of each bone relative to its parent (root: absolute)."""
        result: dict[str, np.ndarray] = {}
        for name, bone in self.bones.items():
            if bone.parent is None:
                result[name] = bone.world_rest_pos.copy()
            else:
                parent_pos = self.bones[bone.parent].world_rest_pos
                result[name] = bone.world_rest_pos - parent_pos
        return result

    def rest_direction(self, name: str) -> np.ndarray:
        """Unit direction the bone points at rest (mean of child offsets)."""
        bone = self.bones[name]
        if bone.children:
            child_vecs = np.stack(
                [self.bones[c].world_rest_pos - bone.world_rest_pos
                 for c in bone.children])
            mean_vec = child_vecs.mean(0)
            return mean_vec / (np.linalg.norm(mean_vec) + 1e-9)
        if bone.parent is None:
            # Parentless leaf: default to "up".
            return np.array([0, 1, 0], dtype=np.float32)
        vec = bone.world_rest_pos - self.bones[bone.parent].world_rest_pos
        return vec / (np.linalg.norm(vec) + 1e-9)
|
| 308 |
+
|
| 309 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 310 |
+
# GLB skeleton parser
|
| 311 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 312 |
+
|
| 313 |
+
def parse_glb_skeleton(path: str) -> UnirigSkeleton:
    """Extract skeleton from a UniRig-rigged GLB (uses pygltflib).

    Reads the first skin's joints and inverse bind matrices, converts them
    to world-space rest positions, and reconstructs the parent/child
    hierarchy from the node graph. Exits with a message if pygltflib is
    missing or the file has no skin.
    """
    try:
        import pygltflib
    except ImportError:
        sys.exit("[ERROR] pygltflib not installed. pip install pygltflib")

    import base64

    gltf = pygltflib.GLTF2().load(path)
    if not gltf.skins:
        sys.exit(f"[ERROR] No skin found in {path}")

    skin = gltf.skins[0]
    joint_indices = skin.joints

    def get_buffer_bytes(buf_idx: int) -> bytes:
        # GLB-embedded blob, data: URI, or external .bin file.
        buf = gltf.buffers[buf_idx]
        if buf.uri is None:
            return bytes(gltf.binary_blob())
        if buf.uri.startswith("data:"):
            return base64.b64decode(buf.uri.split(",", 1)[1])
        return (Path(path).parent / buf.uri).read_bytes()

    def read_accessor(acc_idx: int) -> np.ndarray:
        # Decode one glTF accessor to a float32 [count, dim] array,
        # honouring interleaved bufferView strides.
        acc = gltf.accessors[acc_idx]
        bv = gltf.bufferViews[acc.bufferView]
        raw = get_buffer_bytes(bv.buffer)
        COMP = {5120: ('b',1), 5121: ('B',1), 5122: ('h',2),
                5123: ('H',2), 5125: ('I',4), 5126: ('f',4)}
        DIMS = {"SCALAR":1,"VEC2":2,"VEC3":3,"VEC4":4,"MAT2":4,"MAT3":9,"MAT4":16}
        fmt, sz = COMP[acc.componentType]
        dim = DIMS[acc.type]
        start = (bv.byteOffset or 0) + (acc.byteOffset or 0)
        stride = bv.byteStride
        if stride is None or stride == 0 or stride == sz * dim:
            # Tightly packed: one contiguous frombuffer.
            chunk = raw[start: start + acc.count * sz * dim]
            return np.frombuffer(chunk, dtype=fmt).reshape(acc.count, dim).astype(np.float32)
        rows = []
        for i in range(acc.count):
            off = start + i * stride
            rows.append(np.frombuffer(raw[off: off + sz * dim], dtype=fmt))
        return np.stack(rows).astype(np.float32)

    # FIX: glTF accessor matrices are stored column-major (glTF 2.0 spec),
    # so a row-major reshape yields the TRANSPOSE of each matrix. Transpose
    # back before inverting; otherwise inv(...)[:3, 3] reads the affine
    # bottom row and every bone rest position comes out as (0, 0, 0).
    ibm = read_accessor(skin.inverseBindMatrices).reshape(-1, 4, 4)
    ibm = ibm.transpose(0, 2, 1)
    joint_set = set(joint_indices)
    ni_name = {ni: (gltf.nodes[ni].name or f"bone_{ni}") for ni in joint_indices}

    bones: dict[str, Bone] = {}
    for i, ni in enumerate(joint_indices):
        name = ni_name[ni]
        # World rest transform = inverse of the inverse-bind matrix.
        world_mat = np.linalg.inv(ibm[i])
        bones[name] = Bone(name=name, parent=None,
                           world_rest_pos=world_mat[:3, 3].astype(np.float32))

    # Rebuild parent/child links, restricted to nodes that are skin joints.
    for ni in joint_indices:
        for ci in (gltf.nodes[ni].children or []):
            if ci in joint_set:
                p, c = ni_name[ni], ni_name[ci]
                bones[c].parent = p
                bones[p].children.append(c)

    print(f"[GLB] {len(bones)} bones from skin '{gltf.skins[0].name or 'Armature'}'")
    return UnirigSkeleton(bones)
|
| 377 |
+
|
| 378 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 379 |
+
# ASCII FBX skeleton parser
|
| 380 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 381 |
+
|
| 382 |
+
def parse_fbx_ascii_skeleton(path: str) -> UnirigSkeleton:
    """Parse ASCII-format FBX for LimbNode / Root bones.

    Extracts bone names and local translations from Model blocks, links
    them via "OO" connections, accumulates world-space rest positions, and
    returns a UnirigSkeleton. Exits on binary FBX or when no bones exist.
    """
    raw = Path(path).read_bytes()
    # Binary FBX starts with the "Kaydara FBX Binary" magic; only ASCII
    # FBX can be parsed by the regexes below.
    if raw[:4] == b"Kayd":
        sys.exit(
            f"[ERROR] {path} is binary FBX.\n"
            "Convert to GLB first, e.g.:\n"
            "  gltf-pipeline -i rigged.fbx -o rigged.glb"
        )
    text = raw.decode("utf-8", errors="replace")

    # Groups: (1) uid, (2) bone name, (3) node type, (4) Properties70 body.
    model_pat = re.compile(
        r'Model:\s*(\d+),\s*"Model::([^"]+)",\s*"(LimbNode|Root|Null)"'
        r'.*?Properties70:\s*\{(.*?)\}',
        re.DOTALL
    )
    # Three numeric components of the "Lcl Translation" property.
    trans_pat = re.compile(
        r'P:\s*"Lcl Translation".*?(-?[\d.e+\-]+),\s*(-?[\d.e+\-]+),\s*(-?[\d.e+\-]+)'
    )

    uid_name: dict[str, str] = {}
    uid_local: dict[str, np.ndarray] = {}

    for m in model_pat.finditer(text):
        uid, name = m.group(1), m.group(2)
        uid_name[uid] = name
        tm = trans_pat.search(m.group(4))
        # Bones without an explicit translation default to the origin.
        uid_local[uid] = (np.array([float(tm.group(i)) for i in (1,2,3)], dtype=np.float32)
                          if tm else np.zeros(3, dtype=np.float32))

    if not uid_name:
        sys.exit("[ERROR] No LimbNode/Root bones found in FBX")

    # "OO" object-object connections: child uid -> parent uid; only keep
    # links where both ends are bones we recognised above.
    conn_pat = re.compile(r'C:\s*"OO",\s*(\d+),\s*(\d+)')
    uid_parent: dict[str, str] = {}
    for m in conn_pat.finditer(text):
        child, par = m.group(1), m.group(2)
        if child in uid_name and par in uid_name:
            uid_parent[child] = par

    # Detect cm vs m
    # Heuristic: humanoid bone heights above 10 units imply centimetres.
    all_y = np.array([t[1] for t in uid_local.values()])
    scale = 0.01 if all_y.max() > 10.0 else 1.0
    if scale != 1.0:
        print(f"[FBX] Centimetre units detected — scaling by {scale}")
        for uid in uid_local:
            uid_local[uid] *= scale

    # Accumulate world translations (topological order)
    def topo(uid_to_par):
        # Parents-before-children ordering via DFS toward the root.
        visited, order = set(), []
        def visit(u):
            if u in visited: return
            visited.add(u)
            if u in uid_to_par: visit(uid_to_par[u])
            order.append(u)
        for u in uid_to_par: visit(u)
        # Bones with no connection entry (isolated roots) go last.
        for u in uid_name:
            if u not in visited: order.append(u)
        return order

    world: dict[str, np.ndarray] = {}
    for uid in topo(uid_parent):
        loc = uid_local.get(uid, np.zeros(3, dtype=np.float32))
        # World position = parent world position + local translation.
        world[uid] = (world.get(uid_parent[uid], np.zeros(3, dtype=np.float32)) + loc
                      if uid in uid_parent else loc.copy())

    bones: dict[str, Bone] = {}
    for uid, name in uid_name.items():
        bones[name] = Bone(name=name, parent=None, world_rest_pos=world[uid])

    # Second pass: wire up parent/children by name.
    for uid, p_uid in uid_parent.items():
        c, p = uid_name[uid], uid_name[p_uid]
        bones[c].parent = p
        if c not in bones[p].children:
            bones[p].children.append(c)

    print(f"[FBX] {len(bones)} bones parsed from ASCII FBX")
    return UnirigSkeleton(bones)
|
| 461 |
+
|
| 462 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 463 |
+
# Auto bone mapping: UniRig bones to SMPL joints
|
| 464 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 465 |
+
|
| 466 |
+
# Keyword table: normalised name fragments -> SMPL joint index
|
| 467 |
+
# One entry per SMPL joint 0-21: (list of candidate name fragments, joint index).
# Fragments are compared against bone names normalised by _strip_name().
_NAME_HINTS: list[tuple[list[str], int]] = [
    (["hips","pelvis","root","hip"], 0),
    (["leftupleg","l_upleg","lupleg","leftthigh","lefthip",
      "left_upper_leg","l_thigh","thigh_l","upperleg_l","j_bip_l_upperleg"], 1),
    (["rightupleg","r_upleg","rupleg","rightthigh","righthip",
      "right_upper_leg","r_thigh","thigh_r","upperleg_r","j_bip_r_upperleg"], 2),
    (["spine","spine0","spine_01","j_bip_c_spine"], 3),
    (["leftleg","leftknee","l_leg","lleg","leftlowerleg",
      "left_lower_leg","lowerleg_l","knee_l","j_bip_l_lowerleg"], 4),
    (["rightleg","rightknee","r_leg","rleg","rightlowerleg",
      "right_lower_leg","lowerleg_r","knee_r","j_bip_r_lowerleg"], 5),
    (["spine1","spine_02","j_bip_c_spine1"], 6),
    (["leftfoot","left_foot","l_foot","lfoot","foot_l","j_bip_l_foot"], 7),
    (["rightfoot","right_foot","r_foot","rfoot","foot_r","j_bip_r_foot"], 8),
    (["spine2","spine_03","j_bip_c_spine2","chest"], 9),
    (["lefttoebase","lefttoe","l_toe","ltoe","toe_l"], 10),
    (["righttoebase","righttoe","r_toe","rtoe","toe_r"], 11),
    (["neck","j_bip_c_neck"], 12),
    (["leftshoulder","leftcollar","l_shoulder","leftclavicle",
      "clavicle_l","j_bip_l_shoulder"], 13),
    (["rightshoulder","rightcollar","r_shoulder","rightclavicle",
      "clavicle_r","j_bip_r_shoulder"], 14),
    (["head","j_bip_c_head"], 15),
    (["leftarm","leftupper","l_arm","larm","leftupperarm",
      "upperarm_l","j_bip_l_upperarm"], 16),
    (["rightarm","rightupper","r_arm","rarm","rightupperarm",
      "upperarm_r","j_bip_r_upperarm"], 17),
    (["leftforearm","leftlower","l_forearm","lforearm",
      "lowerarm_l","j_bip_l_lowerarm"], 18),
    (["rightforearm","rightlower","r_forearm","rforearm",
      "lowerarm_r","j_bip_r_lowerarm"], 19),
    (["lefthand","l_hand","lhand","hand_l","j_bip_l_hand"], 20),
    (["righthand","r_hand","rhand","hand_r","j_bip_r_hand"], 21),
]
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def _strip_name(name: str) -> str:
|
| 504 |
+
"""Remove common rig namespace prefixes, then lower-case, remove separators."""
|
| 505 |
+
name = re.sub(r'^(mixamorig:|j_bip_[lcr]_|cc_base_|bip01_|rig:|chr:)',
|
| 506 |
+
"", name, flags=re.IGNORECASE)
|
| 507 |
+
return re.sub(r'[_\-\s.]', "", name).lower()
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def _normalise_positions(pos: np.ndarray) -> np.ndarray:
|
| 511 |
+
"""Normalise [N, 3] to [0,1] in Y, [-1,1] in X and Z."""
|
| 512 |
+
y_min, y_max = pos[:, 1].min(), pos[:, 1].max()
|
| 513 |
+
h = (y_max - y_min) or 1.0
|
| 514 |
+
xr = (pos[:, 0].max() - pos[:, 0].min()) or 1.0
|
| 515 |
+
zr = (pos[:, 2].max() - pos[:, 2].min()) or 1.0
|
| 516 |
+
out = pos.copy()
|
| 517 |
+
out[:, 0] /= xr
|
| 518 |
+
out[:, 1] = (out[:, 1] - y_min) / h
|
| 519 |
+
out[:, 2] /= zr
|
| 520 |
+
return out
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def auto_map(skel: UnirigSkeleton, verbose: bool = True) -> None:
    """
    Assign skel.bones[name].smpl_idx for each UniRig bone that best matches
    an SMPL joint. Score = 0.6 * position_proximity + 0.4 * name_hint.
    Greedy: each SMPL joint taken by at most one UniRig bone.
    Bones with combined score < 0.35 are left unmapped (identity in BVH).
    """
    names = list(skel.bones.keys())
    ur_pos = np.stack([skel.bones[n].world_rest_pos for n in names])  # [M, 3]

    # Scale SMPL T-pose to match UniRig height so both skeletons live in a
    # comparable normalised space below.
    ur_h = ur_pos[:, 1].max() - ur_pos[:, 1].min()
    sm_h = SMPL_TPOSE[:, 1].max() - SMPL_TPOSE[:, 1].min()
    sm_pos = SMPL_TPOSE * ((ur_h / sm_h) if sm_h > 1e-6 else 1.0)

    all_norm = _normalise_positions(np.concatenate([ur_pos, sm_pos]))
    ur_norm = all_norm[:len(names)]
    sm_norm = all_norm[len(names):]

    # Distance score [M, 22]: 1.0 at coincident joints, 0.0 at the farthest
    # observed pair.
    dist = np.linalg.norm(ur_norm[:, None] - sm_norm[None], axis=-1)
    dist_sc = 1.0 - np.clip(dist / (dist.max() + 1e-9), 0, 1)

    # Name score [M, 22].
    # FIX: hint keywords are normalised with the same _strip_name() rule as
    # the bone names. Previously, fragments containing separators (e.g.
    # "thigh_l", "left_upper_leg", "j_bip_l_hand") could never match,
    # because _strip_name() removes every "_-. " character from the bone
    # name before the membership test.
    norm_names = [_strip_name(n) for n in names]
    hint_sets = [{_strip_name(k) for k in kws} for kws, _ in _NAME_HINTS]
    name_sc = np.array(
        [[1.0 if norm in hs else 0.0 for hs in hint_sets]
         for norm in norm_names],
        dtype=np.float32,
    )  # [M, 22]

    combined = 0.6 * dist_sc + 0.4 * name_sc  # [M, 22]

    # Greedy assignment: best pairs first; each bone and each SMPL joint is
    # used at most once.
    THRESHOLD = 0.35
    taken_smpl: set[int] = set()
    pairs = sorted(
        ((i, j, combined[i, j])
         for i in range(len(names)) for j in range(NUM_SMPL)),
        key=lambda x: -x[2],
    )
    for bi, si, score in pairs:
        if score < THRESHOLD:
            break  # sorted descending, so everything after is worse
        name = names[bi]
        if skel.bones[name].smpl_idx is not None or si in taken_smpl:
            continue
        skel.bones[name].smpl_idx = si
        taken_smpl.add(si)

    if verbose:
        mapped = [(n, b.smpl_idx) for n, b in skel.bones.items() if b.smpl_idx is not None]
        unmapped = [n for n, b in skel.bones.items() if b.smpl_idx is None]
        print(f"\n[MAP] {len(mapped)}/{len(skel.bones)} bones mapped to SMPL joints:")
        for ur_name, si in sorted(mapped, key=lambda x: x[1]):
            sc = combined[names.index(ur_name), si]
            print(f"  {ur_name:40s} -> {SMPL_NAMES[si]:16s} score={sc:.2f}")
        if unmapped:
            print(f"[MAP] {len(unmapped)} unmapped (identity rotation): "
                  + ", ".join(unmapped[:8])
                  + (" ..." if len(unmapped) > 8 else ""))
        print()
|
| 586 |
+
|
| 587 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 588 |
+
# BVH writers
|
| 589 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 590 |
+
|
| 591 |
+
def _smpl_offsets(tpose: np.ndarray) -> np.ndarray:
    """Per-joint offset from the parent joint; the root keeps its position."""
    result = np.zeros_like(tpose)
    for joint, parent in enumerate(SMPL_PARENT):
        result[joint] = tpose[joint] - (tpose[parent] if parent >= 0 else 0.0)
    return result
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def write_bvh_smpl(output_path: str, positions: np.ndarray, fps: int = 20) -> None:
    """BVH with standard SMPL bone names (no rig file needed).

    positions: world-space joint positions [T, 22, 3] in SMPL joint order.
    fps: frames per second written into the BVH Frame Time field.
    Writes the file and prints a one-line summary.
    """
    T = positions.shape[0]
    # Scale the reference T-pose to the subject's height, then align its
    # root with the first frame's root so rest dirs are comparable.
    tpose = _scale_smpl_tpose(positions)
    offsets = _smpl_offsets(tpose)
    tp_w = tpose + (positions[0, 0] - tpose[0])
    local_q = positions_to_local_quats(positions, tp_w)
    euler = quat_to_euler_ZXY(local_q)

    with open(output_path, "w") as f:
        f.write("HIERARCHY\n")

        # Recursively write one ROOT/JOINT block; `ind` is the accumulated tabs.
        def wj(j, ind):
            off = offsets[j]
            f.write(f"{'ROOT' if SMPL_PARENT[j]<0 else ind+'JOINT'} {SMPL_NAMES[j]}\n")
            f.write(f"{ind}{{\n")
            f.write(f"{ind}\tOFFSET {off[0]:.6f} {off[1]:.6f} {off[2]:.6f}\n")
            if SMPL_PARENT[j] < 0:
                # Only the root carries translation channels.
                f.write(f"{ind}\tCHANNELS 6 Xposition Yposition Zposition "
                        "Zrotation Xrotation Yrotation\n")
            else:
                f.write(f"{ind}\tCHANNELS 3 Zrotation Xrotation Yrotation\n")
            for c in _SMPL_CHILDREN[j]:
                wj(c, ind + "\t")
            if not _SMPL_CHILDREN[j]:
                # Leaf joints get a small synthetic End Site.
                f.write(f"{ind}\tEnd Site\n{ind}\t{{\n"
                        f"{ind}\t\tOFFSET 0.000000 0.050000 0.000000\n{ind}\t}}\n")
            f.write(f"{ind}}}\n")

        wj(0, "")
        f.write(f"MOTION\nFrames: {T}\nFrame Time: {1.0/fps:.8f}\n")
        # Motion rows: root translation, then ZXY rotations in the same DFS
        # order the hierarchy above was written in.
        for t in range(T):
            rp = positions[t, 0]
            row = [f"{rp[0]:.6f}", f"{rp[1]:.6f}", f"{rp[2]:.6f}"]
            for j in SMPL_DFS:
                rz, rx, ry = euler[t, j]
                row += [f"{rz:.6f}", f"{rx:.6f}", f"{ry:.6f}"]
            f.write(" ".join(row) + "\n")

    print(f"[OK] {T} frames @ {fps} fps -> {output_path} (SMPL skeleton)")
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def write_bvh_unirig(output_path: str,
                     positions: np.ndarray,
                     skel: UnirigSkeleton,
                     fps: int = 20) -> None:
    """
    BVH using UniRig bone names and hierarchy.
    Mapped bones receive SMPL-derived local rotations with rest-pose correction.
    Unmapped bones (fingers, face bones, etc.) are set to identity.

    positions: world joint positions [T, 22, 3] in SMPL joint order.
    skel: target skeleton; auto_map() should have set smpl_idx on its bones.
    """
    T = positions.shape[0]

    # Compute SMPL local quaternions
    tpose = _scale_smpl_tpose(positions)
    tp_w = tpose + (positions[0, 0] - tpose[0])
    smpl_q = positions_to_local_quats(positions, tp_w)       # [T, 22, 4]
    smpl_rd = _rest_dirs(tp_w, _SMPL_CHILDREN, SMPL_PARENT)  # [22, 3]

    # Rest-pose correction quaternions per bone:
    #   q_corr = qbetween(unirig_rest_dir, smpl_rest_dir)
    #   unirig_local_q = smpl_local_q @ q_corr
    # This ensures: when applied to unirig_rest_dir, the result matches
    # the SMPL animated direction — accounting for any difference in
    # rest-pose bone orientations between the two skeletons.
    corrections: dict[str, np.ndarray] = {}
    for name, bone in skel.bones.items():
        if bone.smpl_idx is None:
            continue
        ur_rd = skel.rest_direction(name).astype(np.float32)
        sm_rd = smpl_rd[bone.smpl_idx].astype(np.float32)
        corrections[name] = qbetween(ur_rd[None], sm_rd[None])[0]  # [4]

    # Scale root translation from SMPL proportions to UniRig proportions
    ur_h = (max(b.world_rest_pos[1] for b in skel.bones.values())
            - min(b.world_rest_pos[1] for b in skel.bones.values()))
    sm_h = tp_w[:, 1].max() - tp_w[:, 1].min()
    pos_sc = (ur_h / sm_h) if sm_h > 1e-6 else 1.0

    dfs = skel.dfs_order()
    offsets = skel.local_offsets()

    # Pre-compute euler per bone [T, 3]
    # Unmapped bones all share this single zero (identity) track.
    ID_EUL = np.zeros((T, 3), dtype=np.float32)
    bone_euler: dict[str, np.ndarray] = {}
    for name, bone in skel.bones.items():
        if bone.smpl_idx is not None:
            q = smpl_q[:, bone.smpl_idx].copy()  # [T, 4]
            c = corrections.get(name)
            if c is not None:
                # Apply the rest-pose correction on the right (bone-local side).
                q = qnorm(qmul(q, np.broadcast_to(c[None], q.shape).copy()))
            bone_euler[name] = quat_to_euler_ZXY(q)  # [T, 3]
        else:
            bone_euler[name] = ID_EUL

    with open(output_path, "w") as f:
        f.write("HIERARCHY\n")

        # Recursively write one ROOT/JOINT block; `ind` is the accumulated tabs.
        def wj(name, ind):
            off = offsets[name]
            bone = skel.bones[name]
            f.write(f"{'ROOT' if bone.parent is None else ind+'JOINT'} {name}\n")
            f.write(f"{ind}{{\n")
            f.write(f"{ind}\tOFFSET {off[0]:.6f} {off[1]:.6f} {off[2]:.6f}\n")
            if bone.parent is None:
                # Only the root carries translation channels.
                f.write(f"{ind}\tCHANNELS 6 Xposition Yposition Zposition "
                        "Zrotation Xrotation Yrotation\n")
            else:
                f.write(f"{ind}\tCHANNELS 3 Zrotation Xrotation Yrotation\n")
            for c in bone.children:
                wj(c, ind + "\t")
            if not bone.children:
                # Leaf bones get a small synthetic End Site.
                f.write(f"{ind}\tEnd Site\n{ind}\t{{\n"
                        f"{ind}\t\tOFFSET 0.000000 0.050000 0.000000\n{ind}\t}}\n")
            f.write(f"{ind}}}\n")

        wj(skel.root.name, "")
        f.write(f"MOTION\nFrames: {T}\nFrame Time: {1.0/fps:.8f}\n")

        # Motion rows: scaled root translation, then ZXY rotations in the
        # same DFS order the hierarchy above was written in.
        for t in range(T):
            rp = positions[t, 0] * pos_sc
            row = [f"{rp[0]:.6f}", f"{rp[1]:.6f}", f"{rp[2]:.6f}"]
            for name in dfs:
                rz, rx, ry = bone_euler[name][t]
                row += [f"{rz:.6f}", f"{rx:.6f}", f"{ry:.6f}"]
            f.write(" ".join(row) + "\n")

    n_mapped = sum(1 for b in skel.bones.values() if b.smpl_idx is not None)
    print(f"[OK] {T} frames @ {fps} fps -> {output_path} "
          f"(UniRig: {n_mapped} driven, {len(skel.bones)-n_mapped} identity)")
|
| 728 |
+
|
| 729 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 730 |
+
# Motion loader
|
| 731 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 732 |
+
|
| 733 |
+
def load_motion(npy_path: str) -> tuple[np.ndarray, int]:
    """Return (positions [T, 22, 3], fps). Auto-detects HumanML3D format."""
    data = np.load(npy_path).astype(np.float32)
    print(f"[INFO] {npy_path} shape={data.shape}")

    # Already-decoded joint positions pass straight through.
    if data.ndim == 3 and data.shape[1:] == (22, 3):
        print("[INFO] Format: new_joints [T, 22, 3]")
        return data, 20

    if data.ndim == 2:
        width = data.shape[1]
        if width == 263:
            print("[INFO] Format: new_joint_vecs [T, 263]")
            pos = recover_from_ric(data, 22)
            print(f"[INFO] Recovered positions {pos.shape}")
            return pos, 20
        if width == 272:
            # 272-dim variant: first 263 columns share the 263-dim layout.
            print("[INFO] Format: 272-dim (30 fps)")
            return recover_from_ric(data[:, :263], 22), 30
        if width == 251:
            sys.exit("[ERROR] KIT-ML (21-joint) format not yet supported.")

    if data.ndim == 3 and data.shape[1] == 21:
        sys.exit("[ERROR] KIT-ML (21-joint) format not yet supported.")

    sys.exit(f"[ERROR] Unrecognised shape {data.shape}. "
             "Expected [T,22,3] or [T,263].")
|
| 758 |
+
|
| 759 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 760 |
+
# CLI
|
| 761 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 762 |
+
|
| 763 |
+
def main() -> None:
    """CLI entry point: convert a HumanML3D .npy motion to a BVH file.

    With --rig, the BVH is retargeted onto the rig's bone names via
    auto_map(); otherwise a standard SMPL-named skeleton is written.
    """
    ap = argparse.ArgumentParser(
        description="HumanML3D .npy -> BVH, optionally retargeted to UniRig skeleton",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples
  python humanml3d_to_bvh.py 000001.npy
        Standard SMPL-named BVH (no rig file needed)

  python humanml3d_to_bvh.py 000001.npy --rig rigged_mesh.glb
        BVH retargeted to UniRig bone names, auto-mapped by position + name

  python humanml3d_to_bvh.py 000001.npy --rig rigged_mesh.glb -o anim.bvh --fps 20

Supported --rig formats
  .glb / .gltf  UniRig merge.sh output (requires: pip install pygltflib)
  .fbx          ASCII FBX only (binary FBX: convert to GLB first)
""")
    ap.add_argument("input", help="HumanML3D .npy motion file")
    ap.add_argument("--rig", default=None,
                    help="UniRig-rigged mesh .glb or ASCII .fbx for auto-mapping")
    ap.add_argument("-o", "--output", default=None, help="Output .bvh path")
    ap.add_argument("--fps", type=int, default=0,
                    help="Override FPS (default: auto from format)")
    ap.add_argument("--quiet", action="store_true",
                    help="Suppress mapping table")
    args = ap.parse_args()

    inp = Path(args.input)
    # Default output: same path with .bvh extension.
    out = Path(args.output) if args.output else inp.with_suffix(".bvh")

    positions, auto_fps = load_motion(str(inp))
    fps = args.fps if args.fps > 0 else auto_fps

    if args.rig:
        # Pick the parser by extension; both yield a UnirigSkeleton.
        ext = Path(args.rig).suffix.lower()
        if ext in (".glb", ".gltf"):
            skel = parse_glb_skeleton(args.rig)
        elif ext == ".fbx":
            skel = parse_fbx_ascii_skeleton(args.rig)
        else:
            sys.exit(f"[ERROR] Unsupported rig format: {ext} (use .glb or .fbx)")

        auto_map(skel, verbose=not args.quiet)
        write_bvh_unirig(str(out), positions, skel, fps=fps)
    else:
        write_bvh_smpl(str(out), positions, fps=fps)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
# Allow use both as a CLI script and as an importable module.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""rig_retarget.io — file format readers / writers."""
|
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
io/bvh.py
|
| 3 |
+
BVH (Biovision Hierarchy) reader.
|
| 4 |
+
|
| 5 |
+
Returns an Armature in rest pose plus an iterator / list of frame states.
|
| 6 |
+
Each frame state sets the bone pose_rotation_quat / pose_location on the
|
| 7 |
+
source armature so that retarget.get_bone_ws_quat / get_bone_position_ws
|
| 8 |
+
return the correct world-space values.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
import math
|
| 12 |
+
import re
|
| 13 |
+
from typing import Dict, List, Optional, Tuple
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from ..skeleton import Armature, PoseBone
|
| 17 |
+
from ..math3d import translation_matrix, euler_to_quat, quat_identity, vec3
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
# Internal BVH data structures
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
class _BVHJoint:
    """One joint node parsed from a BVH HIERARCHY block."""

    def __init__(self, name: str):
        """Create an empty joint called *name* with default rest data."""
        self.name = name
        self.is_end_site: bool = False
        self.offset: np.ndarray = vec3()
        self.children: List["_BVHJoint"] = []
        self.channels: List[str] = []
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Parser
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
def _tokenize(text: str) -> List[str]:
|
| 38 |
+
return re.split(r"[\s]+", text.strip())
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _parse_hierarchy(tokens: List[str], idx: int) -> Tuple[_BVHJoint, int]:
    """Parse one joint block. idx should point to joint name.

    Returns (joint, next_idx) where next_idx points just past the joint's
    closing '}'. Nested JOINT blocks are parsed recursively; End Site
    blocks are consumed but their offsets are discarded.
    """
    name = tokens[idx]; idx += 1
    joint = _BVHJoint(name)
    assert tokens[idx] == "{", f"Expected '{{' got '{tokens[idx]}'"
    idx += 1
    while tokens[idx] != "}":
        kw = tokens[idx].upper()
        if kw == "OFFSET":
            # Three floats: rest offset of this joint relative to its parent.
            joint.offset = np.array([float(tokens[idx+1]), float(tokens[idx+2]), float(tokens[idx+3])])
            idx += 4
        elif kw == "CHANNELS":
            # CHANNELS <n> <name> * n — channel names normalized to upper case.
            n = int(tokens[idx+1]); idx += 2
            joint.channels = [tokens[idx+i].upper() for i in range(n)]
            idx += n
        elif kw == "JOINT":
            idx += 1
            child, idx = _parse_hierarchy(tokens, idx)
            joint.children.append(child)
        elif kw == "END" and tokens[idx+1].upper() == "SITE":
            # End Site block — just parse and discard
            idx += 2
            assert tokens[idx] == "{"; idx += 1
            while tokens[idx] != "}":
                idx += 1
            idx += 1  # skip '}'
        else:
            idx += 1  # unknown token, skip
    idx += 1  # skip '}'
    return joint, idx
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _collect_joints(joint: _BVHJoint) -> List[_BVHJoint]:
|
| 74 |
+
result = [joint]
|
| 75 |
+
for c in joint.children:
|
| 76 |
+
result.extend(_collect_joints(c))
|
| 77 |
+
return result
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
# Build Armature from BVH hierarchy
|
| 82 |
+
# ---------------------------------------------------------------------------
|
| 83 |
+
|
| 84 |
+
def _build_armature(root_joint: _BVHJoint) -> Armature:
    """Convert the parsed BVH joint tree into a rest-pose Armature.

    Each bone's local rest matrix is a pure translation by the joint's
    BVH OFFSET; joints are added in the same pre-order as the file.
    """
    arm = Armature("BVH_Source")

    # Explicit-stack pre-order walk: (joint, parent bone name, parent world).
    # Children are pushed reversed so they pop in declaration order.
    stack: List[Tuple[_BVHJoint, Optional[str], np.ndarray]] = [
        (root_joint, None, np.eye(4))
    ]
    while stack:
        j, parent_name, parent_world = stack.pop()
        rest_local = translation_matrix(j.offset)
        arm.add_bone(PoseBone(j.name, rest_local), parent_name)
        world = parent_world @ rest_local
        for child in reversed(j.children):
            stack.append((child, j.name, world))

    arm.update_fk()
    return arm
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ---------------------------------------------------------------------------
|
| 102 |
+
# Frame application
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
|
| 105 |
+
# BVH channel keyword → internal single-letter key. Rotation channels carry
# degrees in the file (converted to radians in _apply_frame); position
# channels are in the file's own length units.
_CHANNEL_MAP = {
    "XROTATION": ("rx",), "YROTATION": ("ry",), "ZROTATION": ("rz",),
    "XPOSITION": ("tx",), "YPOSITION": ("ty",), "ZPOSITION": ("tz",),
}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _apply_frame(arm: Armature, all_joints: List[_BVHJoint], values: List[float]) -> None:
    """Set bone poses for one BVH frame.

    *values* is one flat MOTION row; channel values are consumed in the
    joint order produced by _collect_joints, which matches the order the
    hierarchy was declared in the file (and hence the row layout).
    """
    vi = 0  # cursor into the flat channel-value row
    for j in all_joints:
        tx = ty = tz = 0.0
        rx = ry = rz = 0.0
        for ch in j.channels:
            key = _CHANNEL_MAP.get(ch, None)
            if key:
                val = values[vi]
                k = key[0]
                if k == "tx": tx = val
                elif k == "ty": ty = val
                elif k == "tz": tz = val
                elif k == "rx": rx = math.radians(val)  # BVH stores degrees
                elif k == "ry": ry = math.radians(val)
                elif k == "rz": rz = math.radians(val)
            vi += 1  # advance per channel so the row stays aligned with total_channels

        if j.name not in arm.pose_bones:
            continue
        bone = arm.pose_bones[j.name]

        # BVH rotation order is specified per channel list; rebuild from order
        rot_channels = [c for c in j.channels if "ROTATION" in c]
        order = "".join(c[0] for c in rot_channels)  # e.g. "ZXY"
        angles = {"X": rx, "Y": ry, "Z": rz}
        angle_seq = [angles[a] for a in order]
        bone.pose_rotation_quat = euler_to_quat(*angle_seq, order=order)

        # Translation — only root joints typically have it
        if tx or ty or tz:
            bone.pose_location = np.array([tx, ty, tz])

    arm.update_fk()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# ---------------------------------------------------------------------------
|
| 149 |
+
# Public API
|
| 150 |
+
# ---------------------------------------------------------------------------
|
| 151 |
+
|
| 152 |
+
class BVHAnimation:
    """Loaded BVH file. Iterate frames by calling apply_frame(frame_index)."""

    def __init__(
        self,
        armature: Armature,
        all_joints: List[_BVHJoint],
        frame_data: List[List[float]],
        frame_time: float,
    ):
        self.armature = armature        # rest-pose armature built from the hierarchy
        self._all_joints = all_joints   # joints in channel-consumption order
        self._frame_data = frame_data   # one flat row of channel values per frame
        self.frame_time = frame_time    # seconds per frame (1 / fps)
        self.num_frames = len(frame_data)

    def apply_frame(self, frame_index: int) -> None:
        """Advance armature to frame_index and update FK.

        Raises:
            IndexError: if frame_index is outside [0, num_frames).
        """
        if frame_index < 0 or frame_index >= self.num_frames:
            raise IndexError(f"Frame {frame_index} out of range [0, {self.num_frames})")
        _apply_frame(self.armature, self._all_joints, self._frame_data[frame_index])
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def load_bvh(filepath: str) -> BVHAnimation:
    """
    Parse a BVH file.
    Returns BVHAnimation with an Armature ready for retargeting.

    NOTE(review): structural expectations (ROOT/JOINT, FRAMES:, Frame Time:)
    are enforced with assert, so malformed input raises AssertionError and
    the checks vanish under ``python -O``.
    """
    with open(filepath, "r") as f:
        text = f.read()

    tokens = _tokenize(text)
    idx = 0

    # Expect HIERARCHY keyword
    while tokens[idx].upper() != "HIERARCHY":
        idx += 1
    idx += 1

    root_kw = tokens[idx].upper()
    assert root_kw in ("ROOT", "JOINT"), f"Expected ROOT/JOINT, got '{tokens[idx]}'"
    idx += 1
    root_joint, idx = _parse_hierarchy(tokens, idx)

    # MOTION section
    while tokens[idx].upper() != "MOTION":
        idx += 1
    idx += 1

    # "Frames: N" then "Frame Time: t" headers.
    assert tokens[idx].upper() == "FRAMES:"; idx += 1
    num_frames = int(tokens[idx]); idx += 1
    assert tokens[idx].upper() == "FRAME"; assert tokens[idx+1].upper() == "TIME:"; idx += 2
    frame_time = float(tokens[idx]); idx += 1

    all_joints = _collect_joints(root_joint)
    total_channels = sum(len(j.channels) for j in all_joints)

    # One flat row of channel values per frame, consumed left to right.
    frame_data: List[List[float]] = []
    for _ in range(num_frames):
        row = [float(tokens[idx + k]) for k in range(total_channels)]
        idx += total_channels
        frame_data.append(row)

    arm = _build_armature(root_joint)
    return BVHAnimation(arm, all_joints, frame_data, frame_time)
|
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
io/gltf_io.py
|
| 3 |
+
Load a glTF/GLB skeleton (e.g. UniRig output) into an Armature.
|
| 4 |
+
Write retargeted animation back into a glTF/GLB file.
|
| 5 |
+
|
| 6 |
+
Requires: pip install pygltflib
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
import base64
|
| 10 |
+
import json
|
| 11 |
+
import struct
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, List, Optional, Tuple
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import pygltflib
|
| 18 |
+
except ImportError:
|
| 19 |
+
raise ImportError("pip install pygltflib")
|
| 20 |
+
|
| 21 |
+
from ..skeleton import Armature, PoseBone
|
| 22 |
+
from ..math3d import (
|
| 23 |
+
quat_identity, quat_normalize, matrix4_to_quat, matrix4_to_trs,
|
| 24 |
+
trs_to_matrix4, vec3,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Helpers
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
def _node_local_trs(node: "pygltflib.Node"):
|
| 33 |
+
"""Extract TRS from a glTF node. Returns (t[3], r_wxyz[4], s[3])."""
|
| 34 |
+
t = np.array(node.translation or [0.0, 0.0, 0.0])
|
| 35 |
+
r_xyzw = np.array(node.rotation or [0.0, 0.0, 0.0, 1.0])
|
| 36 |
+
s = np.array(node.scale or [1.0, 1.0, 1.0])
|
| 37 |
+
# Convert glTF (x,y,z,w) → our (w,x,y,z)
|
| 38 |
+
r_wxyz = np.array([r_xyzw[3], r_xyzw[0], r_xyzw[1], r_xyzw[2]])
|
| 39 |
+
return t, r_wxyz, s
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _node_local_matrix(node: "pygltflib.Node") -> np.ndarray:
    """Local transform of *node* as a 4x4 row-major matrix."""
    if node.matrix:
        # glTF stores matrices column-major; transpose into row-major.
        return np.array(node.matrix, dtype=float).reshape(4, 4).T
    translation, rotation, scale = _node_local_trs(node)
    return trs_to_matrix4(translation, rotation, scale)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _read_accessor(gltf: "pygltflib.GLTF2", accessor_idx: int) -> np.ndarray:
    """Read a glTF accessor into a numpy array.

    Handles three buffer storage variants: inline base64 data URIs,
    external .bin files (resolved relative to gltf._path, which load_gltf
    attaches after loading), and the GLB binary chunk.

    Returns a float array of shape (count, n_components), squeezed so a
    SCALAR accessor comes back 1-D.
    """
    acc = gltf.accessors[accessor_idx]
    bv = gltf.bufferViews[acc.bufferView]
    buf = gltf.buffers[bv.buffer]

    # Inline base64 data URI
    if buf.uri and buf.uri.startswith("data:"):
        _, b64 = buf.uri.split(",", 1)
        raw = base64.b64decode(b64)
    elif buf.uri:
        # External .bin file, resolved relative to the glTF's own directory.
        base_dir = Path(gltf._path).parent if hasattr(gltf, "_path") and gltf._path else Path(".")
        raw = (base_dir / buf.uri).read_bytes()
    else:
        # Binary GLB — data stored in gltf.binary_blob
        raw = bytes(gltf.binary_blob())

    start = bv.byteOffset + (acc.byteOffset or 0)
    count = acc.count

    # glTF accessor type → number of components per element.
    type_to_components = {
        "SCALAR": 1, "VEC2": 2, "VEC3": 3, "VEC4": 4,
        "MAT2": 4, "MAT3": 9, "MAT4": 16,
    }
    # glTF componentType enum value → struct format character.
    component_type_to_fmt = {
        5120: "b", 5121: "B", 5122: "h", 5123: "H",
        5125: "I", 5126: "f",
    }
    n_comp = type_to_components[acc.type]
    fmt = component_type_to_fmt[acc.componentType]
    item_size = struct.calcsize(fmt) * n_comp
    # Respect interleaved buffer views: byteStride may exceed the item size.
    stride = bv.byteStride or item_size

    items = []
    for i in range(count):
        offset = start + i * stride
        vals = struct.unpack_from(f"{n_comp}{fmt}", raw, offset)
        items.append(vals)

    return np.array(items, dtype=float).squeeze()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# Load skeleton from glTF
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
|
| 97 |
+
def load_gltf(filepath: str, skin_index: int = 0) -> Armature:
    """
    Load the first (or specified) skin from a glTF/GLB file into an Armature.

    The armature world_matrix is set to identity (typical for UniRig output).

    Args:
        filepath: path to a .gltf/.glb file.
        skin_index: which skin to load when the file contains several.

    Raises:
        ValueError: if the file contains no skins.
    """
    gltf = pygltflib.GLTF2().load(filepath)
    gltf._path = filepath  # remembered so _read_accessor can resolve external .bin URIs

    if not gltf.skins:
        raise ValueError(f"No skins found in '{filepath}'")
    skin = gltf.skins[skin_index]

    # Read inverse bind matrices
    n_joints = len(skin.joints)
    ibm_array: Optional[np.ndarray] = None
    if skin.inverseBindMatrices is not None:
        raw = _read_accessor(gltf, skin.inverseBindMatrices)
        ibm_array = raw.reshape(n_joints, 4, 4)

    # Compute bind-pose world matrices: world_bind = inv(ibm)
    joint_world_bind: Dict[int, np.ndarray] = {}
    for i, j_idx in enumerate(skin.joints):
        if ibm_array is not None:
            ibm = ibm_array[i].T  # glTF column-major → numpy row-major
            joint_world_bind[j_idx] = np.linalg.inv(ibm)
        else:
            # Fallback: file has no IBMs — treat every joint as bound at identity.
            joint_world_bind[j_idx] = np.eye(4)

    # Build parent map for nodes
    parent_of: Dict[int, Optional[int]] = {}
    for ni, node in enumerate(gltf.nodes):
        for child_idx in (node.children or []):
            parent_of[child_idx] = ni

    arm = Armature(skin.name or f"Skin_{skin_index}")

    # Process joints in order (parent always before child in glTF spec)
    joint_set = set(skin.joints)
    processed: Dict[int, str] = {}

    for i, j_idx in enumerate(skin.joints):
        node = gltf.nodes[j_idx]
        bone_name = node.name or f"joint_{i}"

        # Walk up the node hierarchy to the nearest ancestor that is a skin joint.
        parent_node_idx = parent_of.get(j_idx)
        parent_bone_name: Optional[str] = None
        while parent_node_idx is not None:
            if parent_node_idx in joint_set:
                parent_bone_name = processed.get(parent_node_idx)
                break
            parent_node_idx = parent_of.get(parent_node_idx)

        # rest_matrix_local in parent space.
        # Fixed: reuse the parent node index found above directly, instead of
        # an O(n) membership scan over processed.values() plus a linear next()
        # reverse lookup that re-derived the same index.
        if parent_bone_name:
            parent_world = joint_world_bind.get(parent_node_idx, np.eye(4))
            rest_local = np.linalg.inv(parent_world) @ joint_world_bind[j_idx]
        else:
            rest_local = joint_world_bind[j_idx]

        bone = PoseBone(bone_name, rest_local)
        arm.add_bone(bone, parent_bone_name)
        processed[j_idx] = bone_name

    arm.update_fk()
    return arm
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
# Write animation to glTF
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
|
| 174 |
+
def write_gltf_animation(
    source_filepath: str,
    dest_armature: Armature,
    keyframes: List[Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]],
    output_filepath: str,
    fps: float = 30.0,
    skin_index: int = 0,
) -> None:
    """
    Embed animation keyframes into a copy of source_filepath (the UniRig GLB).

    keyframes: list of dicts, one per frame.
      Each dict maps bone_name → (pose_location, pose_rotation_quat, pose_scale)
      These are LOCAL values (relative to rest pose local matrix).

    The function adds one glTF Animation with channels for each bone that has data.

    Note: scale deltas are accumulated per frame but not exported — only
    rotation and translation channels are written.

    Raises:
        ValueError: if the source file has no skins.
    """
    # Hoisted out of the per-frame inner loop (was re-imported every iteration;
    # trs_to_matrix4 was also imported there but never used — module top
    # already imports it).
    from ..math3d import quat_mul

    gltf = pygltflib.GLTF2().load(source_filepath)
    gltf._path = source_filepath

    if not gltf.skins:
        raise ValueError("No skins in source file")
    skin = gltf.skins[skin_index]

    # Build node_name → node_index map for skin joints
    joint_name_to_node: Dict[str, int] = {}
    for j_idx in skin.joints:
        node = gltf.nodes[j_idx]
        name = node.name or f"joint_{j_idx}"
        joint_name_to_node[name] = j_idx

    n_frames = len(keyframes)
    times = np.array([i / fps for i in range(n_frames)], dtype=np.float32)

    # Raw float32 payloads to append after the existing GLB binary blob.
    binary_chunks: List[bytes] = []

    def _add_data(data: np.ndarray, acc_type: str) -> int:
        """Append numpy array to binary, return accessor index."""
        raw = data.astype(np.float32).tobytes()
        bv_offset = sum(len(c) for c in binary_chunks)
        binary_chunks.append(raw)
        bv_idx = len(gltf.bufferViews)
        gltf.bufferViews.append(pygltflib.BufferView(
            buffer=0,
            byteOffset=bv_offset,
            byteLength=len(raw),
        ))
        acc_idx = len(gltf.accessors)
        gltf.accessors.append(pygltflib.Accessor(
            bufferView=bv_idx,
            componentType=pygltflib.FLOAT,
            count=len(data),
            type=acc_type,
            max=data.max(axis=0).tolist() if data.ndim > 1 else [float(data.max())],
            min=data.min(axis=0).tolist() if data.ndim > 1 else [float(data.min())],
        ))
        return acc_idx

    time_acc_idx = _add_data(times, "SCALAR")

    channels: List[pygltflib.AnimationChannel] = []
    samplers: List[pygltflib.AnimationSampler] = []

    # Union of every bone that appears in at least one keyframe.
    bone_names = set()
    for frame in keyframes:
        bone_names |= frame.keys()

    for bone_name in sorted(bone_names):
        if bone_name not in joint_name_to_node:
            continue  # bone absent from this skin — skipped silently
        node_idx = joint_name_to_node[bone_name]
        node = gltf.nodes[node_idx]

        # Collect TRS arrays across frames
        rot_data = np.zeros((n_frames, 4), dtype=np.float32)  # glTF order (x,y,z,w)
        trans_data = np.zeros((n_frames, 3), dtype=np.float32)
        scale_data = np.ones((n_frames, 3), dtype=np.float32)

        rest_t, rest_r, rest_s = _node_local_trs(node)

        for fi, frame in enumerate(keyframes):
            if bone_name in frame:
                pose_loc, pose_rot, pose_scale = frame[bone_name]
            else:
                pose_loc = vec3()
                pose_rot = quat_identity()
                pose_scale = np.ones(3)

            # Final local = rest + delta (add translation, multiply rotation/scale)
            final_t = rest_t + pose_loc
            final_r = quat_mul(rest_r, pose_rot)  # (w,x,y,z)
            final_s = rest_s * pose_scale

            # Convert rotation to glTF (x,y,z,w)
            w, x, y, z = final_r
            rot_data[fi] = [x, y, z, w]
            trans_data[fi] = final_t
            scale_data[fi] = final_s

        s_idx = len(samplers)
        rot_acc = _add_data(rot_data, "VEC4")
        samplers.append(pygltflib.AnimationSampler(input=time_acc_idx, output=rot_acc, interpolation="LINEAR"))
        channels.append(pygltflib.AnimationChannel(
            sampler=s_idx,
            target=pygltflib.AnimationChannelTarget(node=node_idx, path="rotation"),
        ))

        s_idx = len(samplers)
        trans_acc = _add_data(trans_data, "VEC3")
        samplers.append(pygltflib.AnimationSampler(input=time_acc_idx, output=trans_acc, interpolation="LINEAR"))
        channels.append(pygltflib.AnimationChannel(
            sampler=s_idx,
            target=pygltflib.AnimationChannelTarget(node=node_idx, path="translation"),
        ))

    if not channels:
        print("[gltf_io] Warning: no channels written — check bone name mapping.")
        return

    gltf.animations.append(pygltflib.Animation(
        name="RetargetedAnimation",
        samplers=samplers,
        channels=channels,
    ))

    # Patch buffer 0 size with our new data
    new_blob = b"".join(binary_chunks)
    existing_blob = bytes(gltf.binary_blob()) if gltf.binary_blob() else b""
    full_blob = existing_blob + new_blob

    # Our freshly-added buffer views were offset relative to the new data only;
    # shift them past the pre-existing blob.
    for bv in gltf.bufferViews[-len(binary_chunks):]:
        bv.byteOffset += len(existing_blob)

    gltf.set_binary_blob(full_blob)
    gltf.buffers[0].byteLength = len(full_blob)

    gltf.save(output_filepath)
    print(f"[gltf_io] Saved animated GLB -> {output_filepath}")
|
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
io/mapping.py
|
| 3 |
+
Load / save bone mapping JSON in the exact same format as KeeMap.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
import json
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
import numpy as np
|
| 10 |
+
from ..math3d import quat_identity, vec3
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class BoneMappingItem:
    """One source→destination bone mapping entry, matching KeeMap's JSON schema."""

    name: str = ""
    label: str = ""
    description: str = ""

    source_bone_name: str = ""
    destination_bone_name: str = ""

    keyframe_this_bone: bool = True

    # Rotation correction (Euler, radians)
    correction_factor: np.ndarray = field(default_factory=lambda: vec3())

    # Quaternion correction
    quat_correction_factor: np.ndarray = field(default_factory=quat_identity)

    has_twist_bone: bool = False
    twist_bone_name: str = ""

    set_bone_position: bool = False
    set_bone_rotation: bool = True
    set_bone_scale: bool = False

    # Rotation options
    bone_rotation_application_axis: str = "XYZ"  # X Y Z XY XZ YZ XYZ
    bone_transpose_axis: str = "NONE"  # NONE ZXY ZYX XZY YZX YXZ

    # Position options
    # NOTE: "postion_type" spelling matches the JSON key used by
    # load_mapping/save_mapping for KeeMap compatibility — do not rename.
    postion_type: str = "SINGLE_BONE_OFFSET"  # SINGLE_BONE_OFFSET | POLE
    position_correction_factor: np.ndarray = field(default_factory=lambda: vec3())
    position_gain: float = 1.0
    position_pole_distance: float = 0.3

    # Scale options
    scale_secondary_bone_name: str = ""
    bone_scale_application_axis: str = "Y"
    scale_gain: float = 1.0
    scale_max: float = 1.0
    scale_min: float = 0.5
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
class KeeMapSettings:
    """Global retarget settings block of a KeeMap mapping file."""

    source_rig_name: str = ""
    destination_rig_name: str = ""
    bone_mapping_file: str = ""
    bone_rotation_mode: str = "EULER"  # EULER | QUATERNION
    start_frame_to_apply: int = 0
    number_of_frames_to_apply: int = 100
    keyframe_every_n_frames: int = 1
    # NOTE(review): not read by load_mapping in this module — presumably a
    # runtime-only flag rather than a persisted setting; confirm.
    keyframe_test: bool = False
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
# Load
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
|
| 71 |
+
def load_mapping(filepath: str):
    """
    Returns (KeeMapSettings, List[BoneMappingItem]).
    Reads the exact same JSON that KeeMap writes.
    """
    with open(filepath, "r") as f:
        data = json.load(f)

    settings = KeeMapSettings(
        source_rig_name=data.get("source_rig_name", ""),
        destination_rig_name=data.get("destination_rig_name", ""),
        bone_mapping_file=data.get("bone_mapping_file", ""),
        bone_rotation_mode=data.get("bone_rotation_mode", "EULER"),
        start_frame_to_apply=data.get("start_frame_to_apply", 0),
        number_of_frames_to_apply=data.get("number_of_frames_to_apply", 100),
        keyframe_every_n_frames=data.get("keyframe_every_n_frames", 1),
    )

    # Build each entry in a single constructor call instead of field-by-field
    # post-assignment; missing JSON keys fall back to the KeeMap defaults.
    bones: List[BoneMappingItem] = []
    for entry in data.get("bones", []):
        bones.append(BoneMappingItem(
            name=entry.get("name", ""),
            label=entry.get("label", ""),
            description=entry.get("description", ""),
            source_bone_name=entry.get("SourceBoneName", ""),
            destination_bone_name=entry.get("DestinationBoneName", ""),
            keyframe_this_bone=entry.get("keyframe_this_bone", True),
            correction_factor=np.array([
                entry.get("CorrectionFactorX", 0.0),
                entry.get("CorrectionFactorY", 0.0),
                entry.get("CorrectionFactorZ", 0.0),
            ]),
            quat_correction_factor=np.array([
                entry.get("QuatCorrectionFactorw", 1.0),
                entry.get("QuatCorrectionFactorx", 0.0),
                entry.get("QuatCorrectionFactory", 0.0),
                entry.get("QuatCorrectionFactorz", 0.0),
            ]),
            has_twist_bone=entry.get("has_twist_bone", False),
            twist_bone_name=entry.get("TwistBoneName", ""),
            set_bone_position=entry.get("set_bone_position", False),
            set_bone_rotation=entry.get("set_bone_rotation", True),
            set_bone_scale=entry.get("set_bone_scale", False),
            bone_rotation_application_axis=entry.get("bone_rotation_application_axis", "XYZ"),
            bone_transpose_axis=entry.get("bone_transpose_axis", "NONE"),
            postion_type=entry.get("postion_type", "SINGLE_BONE_OFFSET"),
            position_correction_factor=np.array([
                entry.get("position_correction_factorX", 0.0),
                entry.get("position_correction_factorY", 0.0),
                entry.get("position_correction_factorZ", 0.0),
            ]),
            position_gain=entry.get("position_gain", 1.0),
            position_pole_distance=entry.get("position_pole_distance", 0.3),
            scale_secondary_bone_name=entry.get("scale_secondary_bone_name", ""),
            bone_scale_application_axis=entry.get("bone_scale_application_axis", "Y"),
            scale_gain=entry.get("scale_gain", 1.0),
            scale_max=entry.get("scale_max", 1.0),
            scale_min=entry.get("scale_min", 0.5),
        ))

    return settings, bones
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# ---------------------------------------------------------------------------
|
| 139 |
+
# Save
|
| 140 |
+
# ---------------------------------------------------------------------------
|
| 141 |
+
|
| 142 |
+
def save_mapping(filepath: str, settings: KeeMapSettings, bones: List[BoneMappingItem]) -> None:
    """Write mapping JSON readable by KeeMap.

    Serializes *settings* at the document root and one record per entry of
    *bones* under the "bones" key, using the exact key spellings the KeeMap
    add-on reads back (including the historical "postion_type" typo).
    """
    bone_records = [
        {
            "name": b.name,
            "label": b.label,
            "description": b.description,
            "SourceBoneName": b.source_bone_name,
            "DestinationBoneName": b.destination_bone_name,
            "keyframe_this_bone": b.keyframe_this_bone,
            # numpy scalars are not JSON-serializable — cast to float
            "CorrectionFactorX": float(b.correction_factor[0]),
            "CorrectionFactorY": float(b.correction_factor[1]),
            "CorrectionFactorZ": float(b.correction_factor[2]),
            "QuatCorrectionFactorw": float(b.quat_correction_factor[0]),
            "QuatCorrectionFactorx": float(b.quat_correction_factor[1]),
            "QuatCorrectionFactory": float(b.quat_correction_factor[2]),
            "QuatCorrectionFactorz": float(b.quat_correction_factor[3]),
            "has_twist_bone": b.has_twist_bone,
            "TwistBoneName": b.twist_bone_name,
            "set_bone_position": b.set_bone_position,
            "set_bone_rotation": b.set_bone_rotation,
            "set_bone_scale": b.set_bone_scale,
            "bone_rotation_application_axis": b.bone_rotation_application_axis,
            "bone_transpose_axis": b.bone_transpose_axis,
            "postion_type": b.postion_type,
            "position_correction_factorX": float(b.position_correction_factor[0]),
            "position_correction_factorY": float(b.position_correction_factor[1]),
            "position_correction_factorZ": float(b.position_correction_factor[2]),
            "position_gain": b.position_gain,
            "position_pole_distance": b.position_pole_distance,
            "scale_secondary_bone_name": b.scale_secondary_bone_name,
            "bone_scale_application_axis": b.bone_scale_application_axis,
            "scale_gain": b.scale_gain,
            "scale_max": b.scale_max,
            "scale_min": b.scale_min,
        }
        for b in bones
    ]
    root = {
        "source_rig_name": settings.source_rig_name,
        "destination_rig_name": settings.destination_rig_name,
        "bone_mapping_file": settings.bone_mapping_file,
        "bone_rotation_mode": settings.bone_rotation_mode,
        "start_frame_to_apply": settings.start_frame_to_apply,
        "number_of_frames_to_apply": settings.number_of_frames_to_apply,
        "keyframe_every_n_frames": settings.keyframe_every_n_frames,
        "bones": bone_records,
    }
    with open(filepath, "w") as f:
        json.dump(root, f, indent=2)
|
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
math3d.py
|
| 3 |
+
Pure numpy / scipy replacement for Blender's mathutils.
|
| 4 |
+
Quaternion convention throughout: (w, x, y, z)
|
| 5 |
+
Matrix convention: row-major, right-multiplied (Numpy default)
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.spatial.transform import Rotation
|
| 10 |
+
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
# Quaternion helpers (w, x, y, z)
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
|
| 15 |
+
def quat_identity() -> np.ndarray:
    """Return the identity quaternion (w,x,y,z) = (1,0,0,0)."""
    return np.array((1.0, 0.0, 0.0, 0.0))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def quat_normalize(q: np.ndarray) -> np.ndarray:
    """Scale *q* to unit length; near-zero input falls back to identity."""
    length = float(np.linalg.norm(q))
    if length <= 1e-12:
        return quat_identity()
    return q / length
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def quat_conjugate(q: np.ndarray) -> np.ndarray:
    """Quaternion conjugate (equals the inverse for a unit quaternion)."""
    w, x, y, z = q
    return np.array([w, -x, -y, -z])
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def quat_mul(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
    """Hamilton product q1 ⊗ q2 in (w,x,y,z) order (Blender's @ operator)."""
    aw, ax, ay, az = q1
    bw, bx, by, bz = q2
    w = aw * bw - ax * bx - ay * by - az * bz
    x = aw * bx + ax * bw + ay * bz - az * by
    y = aw * by - ax * bz + ay * bw + az * bx
    z = aw * bz + ax * by - ay * bx + az * bw
    return np.array([w, x, y, z])
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def quat_rotation_difference(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
    """Rotation r satisfying q1 @ r == q2, i.e. r = conj(q1) @ q2.

    Mirrors Blender's Quaternion.rotation_difference(); result is normalized.
    """
    relative = quat_mul(quat_conjugate(q1), q2)
    return quat_normalize(relative)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def quat_dot(q1: np.ndarray, q2: np.ndarray) -> float:
    """Four-component dot product; used by the scale-retargeting pass."""
    return float((np.asarray(q1) * np.asarray(q2)).sum())
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def quat_to_matrix4(q: np.ndarray) -> np.ndarray:
    """Unit quaternion (w,x,y,z) → homogeneous 4×4 rotation matrix."""
    w, x, y, z = q
    rot3 = np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w),     2 * (x * z + y * w)],
        [2 * (x * y + z * w),     1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w),     2 * (y * z + x * w),     1 - 2 * (x * x + y * y)],
    ], dtype=float)
    out = np.eye(4)
    out[:3, :3] = rot3
    return out
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def matrix4_to_quat(m: np.ndarray) -> np.ndarray:
    """4×4 matrix → unit quaternion (w,x,y,z), w kept non-negative."""
    xyzw = Rotation.from_matrix(m[:3, :3]).as_quat()  # scipy is scalar-last
    q = np.array([xyzw[3], xyzw[0], xyzw[1], xyzw[2]])
    # Blender convention: canonical sign has w >= 0
    return quat_normalize(q if q[0] >= 0 else -q)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ---------------------------------------------------------------------------
|
| 80 |
+
# Euler ↔ Quaternion
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
|
| 83 |
+
def euler_to_quat(rx: float, ry: float, rz: float, order: str = "XYZ") -> np.ndarray:
    """Euler angles (radians, given axis *order*) → quaternion (w,x,y,z)."""
    xyzw = Rotation.from_euler(order, (rx, ry, rz)).as_quat()  # scalar-last
    wxyz = np.array([xyzw[3], xyzw[0], xyzw[1], xyzw[2]])
    return quat_normalize(wxyz)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def quat_to_euler(q: np.ndarray, order: str = "XYZ") -> np.ndarray:
    """Quaternion (w,x,y,z) → Euler angles in radians for axis *order*."""
    # np.roll moves w from the front to the back: (w,x,y,z) → (x,y,z,w),
    # which is scipy's scalar-last convention.
    return Rotation.from_quat(np.roll(np.asarray(q), -1)).as_euler(order)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
# Matrix constructors
|
| 99 |
+
# ---------------------------------------------------------------------------
|
| 100 |
+
|
| 101 |
+
def translation_matrix(v) -> np.ndarray:
    """Homogeneous 4×4 translation by the 3-vector *v*."""
    m = np.eye(4)
    m[:3, 3] = (v[0], v[1], v[2])
    return m
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def scale_matrix(s) -> np.ndarray:
    """Homogeneous 4×4 non-uniform scale with factors s[0..2] on the diagonal."""
    m = np.eye(4)
    for i in range(3):
        m[i, i] = s[i]
    return m
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def trs_to_matrix4(t, r_quat, s) -> np.ndarray:
    """Compose T @ R @ S from translation, (w,x,y,z) quaternion, and scale."""
    return translation_matrix(t) @ quat_to_matrix4(r_quat) @ scale_matrix(s)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def matrix4_to_trs(m: np.ndarray):
    """Decompose a 4×4 into (translation[3], quaternion[4] w-first, scale[3]).

    Scale factors are the column norms of the 3×3 basis; shear and negative
    scale are not recovered. The quaternion sign is canonicalized to w >= 0.
    """
    t = m[:3, 3].copy()
    basis = m[:3, :3].copy()
    s = np.linalg.norm(basis, axis=0)  # per-column lengths = sx, sy, sz
    for col in range(3):
        if s[col] > 1e-12:
            basis[:, col] /= s[col]
    xyzw = Rotation.from_matrix(basis).as_quat()  # scipy scalar-last
    q = np.array([xyzw[3], xyzw[0], xyzw[1], xyzw[2]])
    if q[0] < 0:
        q = -q
    return t, quat_normalize(q), s
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ---------------------------------------------------------------------------
|
| 145 |
+
# Vector helpers
|
| 146 |
+
# ---------------------------------------------------------------------------
|
| 147 |
+
|
| 148 |
+
def vec3(x=0.0, y=0.0, z=0.0) -> np.ndarray:
    """Float 3-vector, defaulting to the origin."""
    return np.asarray((x, y, z), dtype=float)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_point_on_vector(initial_pt: np.ndarray, terminal_pt: np.ndarray, distance: float) -> np.ndarray:
    """Point located *distance* from initial_pt along initial_pt → terminal_pt.

    Degenerate case (coincident points) returns a copy of initial_pt.
    Matches Blender's get_point_on_vector helper in KeeMapBoneOperators.
    """
    direction = terminal_pt - initial_pt
    length = np.linalg.norm(direction)
    if length < 1e-12:
        return initial_pt.copy()
    return initial_pt + distance * (direction / length)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def apply_rotation_matrix4(m: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Rotate 3-vector *v* by the upper-left 3×3 block of *m* (translation ignored)."""
    return np.matmul(m[:3, :3], v)
|
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
retarget.py
|
| 3 |
+
Pure-Python port of KeeMapBoneOperators.py core math.
|
| 4 |
+
|
| 5 |
+
Replaces bpy / mathutils with numpy. No Blender dependency.
|
| 6 |
+
Public API mirrors the Blender operator flow:
|
| 7 |
+
|
| 8 |
+
get_bone_position_ws(bone, arm) → np.ndarray(3)
|
| 9 |
+
get_bone_ws_quat(bone, arm) → np.ndarray(4) w,x,y,z
|
| 10 |
+
set_bone_position_ws(bone, arm, pos)
|
| 11 |
+
set_bone_rotation(...)
|
| 12 |
+
set_bone_position(...)
|
| 13 |
+
set_bone_position_pole(...)
|
| 14 |
+
set_bone_scale(...)
|
| 15 |
+
calc_rotation_offset(bone_item, src_arm, dst_arm, settings)
|
| 16 |
+
calc_location_offset(bone_item, src_arm, dst_arm)
|
| 17 |
+
transfer_frame(src_arm, dst_arm, bone_items, settings, do_keyframe)
|
| 18 |
+
→ Dict[bone_name → (pose_loc, pose_rot, pose_scale)]
|
| 19 |
+
transfer_animation(src_anim, dst_arm, bone_items, settings)
|
| 20 |
+
→ List[Dict[bone_name → (pose_loc, pose_rot, pose_scale)]]
|
| 21 |
+
"""
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
import math
|
| 24 |
+
import sys
|
| 25 |
+
from typing import Dict, List, Optional, Tuple
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
from .skeleton import Armature, PoseBone
|
| 29 |
+
from .math3d import (
|
| 30 |
+
quat_identity, quat_normalize, quat_mul, quat_conjugate,
|
| 31 |
+
quat_rotation_difference, quat_dot,
|
| 32 |
+
quat_to_matrix4, matrix4_to_quat,
|
| 33 |
+
euler_to_quat, quat_to_euler,
|
| 34 |
+
translation_matrix, vec3, get_point_on_vector,
|
| 35 |
+
)
|
| 36 |
+
from .io.mapping import BoneMappingItem, KeeMapSettings
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Progress bar (console)
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
def _update_progress(job: str, progress: float) -> None:
    """Render an in-place console progress bar for *job* (progress in [0, 1])."""
    width = 40
    filled = int(round(width * progress))
    bar = "#" * filled + "-" * (width - filled)
    line = f"\r{job}: [{bar}] {round(progress * 100, 1)}%"
    if progress >= 1:
        line += " DONE\r\n"
    sys.stdout.write(line)
    sys.stdout.flush()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# World-space position / quaternion getters
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
|
| 57 |
+
def get_bone_position_ws(bone: PoseBone, arm: Armature) -> np.ndarray:
    """World-space position of the bone head (Blender's GetBonePositionWS)."""
    world = arm.world_matrix @ bone.matrix_armature
    return np.array(world[:3, 3])
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_bone_ws_quat(bone: PoseBone, arm: Armature) -> np.ndarray:
    """World-space orientation of *bone* as a (w,x,y,z) quaternion.

    Equivalent to Blender's GetBoneWSQuat().
    """
    return matrix4_to_quat(arm.world_matrix @ bone.matrix_armature)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ---------------------------------------------------------------------------
|
| 76 |
+
# World-space position setter
|
| 77 |
+
# ---------------------------------------------------------------------------
|
| 78 |
+
|
| 79 |
+
def set_bone_position_ws(bone: PoseBone, arm: Armature, position: np.ndarray) -> None:
    """
    Move bone so its world-space head = position.
    Equivalent to Blender's SetBonePositionWS().

    Strategy:
        1. Build new armature-space matrix = old rotation + new translation
        2. Strip parent transform to get new local translation
        3. Update pose_location so FK matches

    Side effects: mutates bone.pose_location and recomputes FK for the
    bone (and, via _fk, its subtree). Returns nothing.
    """
    # Current armature-space matrix (rotation/scale part preserved)
    arm_mat = bone.matrix_armature.copy()

    # Target armature-space position: pull the world-space target back
    # through the armature object's world transform (homogeneous point).
    arm_world_inv = np.linalg.inv(arm.world_matrix)
    target_arm_pos = (arm_world_inv @ np.append(position, 1.0))[:3]

    # New armature-space matrix with replaced translation
    new_arm_mat = arm_mat.copy()
    new_arm_mat[:3, 3] = target_arm_pos

    # Convert to local (parent-relative) space
    if bone.parent is not None:
        parent_arm_mat = bone.parent.matrix_armature
        new_local = np.linalg.inv(parent_arm_mat) @ new_arm_mat
    else:
        new_local = new_arm_mat

    # Extract translation from new_local = rest_local @ T(pose_loc) @ ...
    # Approximate: strip rest_local rotation contribution to isolate pose_location
    # NOTE(review): exact only when the remaining pose rotation/scale factors
    # commute with the translation; the original port accepts this error.
    rest_inv = np.linalg.inv(bone.rest_matrix_local)
    pose_delta = rest_inv @ new_local
    bone.pose_location = pose_delta[:3, 3].copy()

    # Recompute FK for this bone and its subtree
    if bone.parent is not None:
        bone._fk(bone.parent.matrix_armature)
    else:
        bone._fk(np.eye(4))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ---------------------------------------------------------------------------
|
| 121 |
+
# Rotation setter (core retargeting math)
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
|
| 124 |
+
def set_bone_rotation(
    src_arm: Armature, src_name: str,
    dst_arm: Armature, dst_name: str,
    dst_twist_name: str,
    correction_quat: np.ndarray,
    has_twist: bool,
    xfer_axis: str,
    transpose: str,
    mode: str,
) -> None:
    """
    Port of Blender's SetBoneRotation().
    Drives dst bone rotation to match src bone world-space rotation.

    Args:
        src_arm, src_name: source armature and bone driving the motion.
        dst_arm, dst_name: destination armature and bone being posed.
        dst_twist_name: twist bone receiving the Y rotation when has_twist
            is set (EULER mode only).
        correction_quat: pre-computed correction rotation (w,x,y,z).
        has_twist: peel the Y euler component off onto the twist bone.
        xfer_axis: "X" "Y" "Z" "XY" "XZ" "YZ" "XYZ" — euler axes to transfer.
        transpose: "NONE" "ZYX" "ZXY" "XZY" "YZX" "YXZ" — axis permutation.
        mode: "EULER" | "QUATERNION".

    Side effects: mutates dst bone (and possibly the twist bone) pose
    rotation and recomputes FK for the dst bone.
    """
    src_bone = src_arm.get_bone(src_name)
    dst_bone = dst_arm.get_bone(dst_name)

    # World-space orientations of both bones at the current pose.
    src_ws_quat = get_bone_ws_quat(src_bone, src_arm)
    dst_ws_quat = get_bone_ws_quat(dst_bone, dst_arm)

    # Rotation difference: r such that dst_ws @ r ≈ src_ws
    diff = quat_rotation_difference(dst_ws_quat, src_ws_quat)

    # FinalQuat = dst_local_pose_delta @ diff @ correction
    final_quat = quat_normalize(
        quat_mul(quat_mul(dst_bone.pose_rotation_quat, diff), correction_quat)
    )

    if mode == "EULER":
        euler = quat_to_euler(final_quat, order="XYZ")

        # Permute axes per the transpose setting ("NONE" falls through).
        perms = {
            "ZYX": (2, 1, 0), "ZXY": (2, 0, 1), "XZY": (0, 2, 1),
            "YZX": (1, 2, 0), "YXZ": (1, 0, 2),
        }
        perm = perms.get(transpose)
        if perm is not None:
            euler = np.array([euler[perm[0]], euler[perm[1]], euler[perm[2]]])

        # Zero every axis not selected for transfer ("XYZ" keeps all).
        keep_map = {
            "X": (0,), "Y": (1,), "Z": (2,),
            "XY": (0, 1), "XZ": (0, 2), "YZ": (1, 2),
        }
        keep = keep_map.get(xfer_axis)
        if keep is not None:
            for axis in range(3):
                if axis not in keep:
                    euler[axis] = 0.0

        final_quat = euler_to_quat(euler[0], euler[1], euler[2], order="XYZ")

        # Twist bone: peel the Y rotation off the main bone and hand it
        # to the dedicated twist bone instead.
        if has_twist and dst_twist_name:
            twist_bone = dst_arm.get_bone(dst_twist_name)
            main_euler = quat_to_euler(final_quat, order="XYZ")
            y_angle = main_euler[1]  # radians
            main_euler[1] = 0.0
            final_quat = euler_to_quat(*main_euler, order="XYZ")
            twist_euler = quat_to_euler(twist_bone.pose_rotation_quat, order="XYZ")
            # BUG FIX: the original wrote math.degrees(y_angle) here, feeding
            # degrees into euler_to_quat, which (like quat_to_euler above)
            # works in radians — a ~57x over-rotation of the twist bone.
            twist_euler[1] = y_angle
            twist_bone.pose_rotation_quat = euler_to_quat(*twist_euler, order="XYZ")

    else:  # QUATERNION
        # Canonicalize sign to w >= 0 (Blender convention), then normalize.
        if final_quat[0] < 0:
            final_quat = -final_quat
        final_quat = quat_normalize(final_quat)

    dst_bone.pose_rotation_quat = final_quat

    # Propagate the new local rotation through FK.
    parent = dst_bone.parent
    dst_bone._fk(parent.matrix_armature if parent else np.eye(4))
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# ---------------------------------------------------------------------------
|
| 221 |
+
# Position setter
|
| 222 |
+
# ---------------------------------------------------------------------------
|
| 223 |
+
|
| 224 |
+
def set_bone_position(
    src_arm: Armature, src_name: str,
    dst_arm: Armature, dst_name: str,
    dst_twist_name: str,
    correction: np.ndarray,
    gain: float,
) -> None:
    """
    Port of Blender's SetBonePosition().

    Snaps the destination bone onto the source bone's world-space head,
    then offsets each pose-location component by *correction* and scales
    it by *gain*. Mutates the destination bone and recomputes its FK.
    """
    source = src_arm.get_bone(src_name)
    dest = dst_arm.get_bone(dst_name)

    # Match the source head position in world space first.
    set_bone_position_ws(dest, dst_arm, get_bone_position_ws(source, src_arm))

    # Then apply per-axis offset and overall gain, in place.
    for axis in range(3):
        dest.pose_location[axis] = (dest.pose_location[axis] + correction[axis]) * gain

    parent = dest.parent
    dest._fk(parent.matrix_armature if parent else np.eye(4))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# ---------------------------------------------------------------------------
|
| 251 |
+
# Pole bone position setter
|
| 252 |
+
# ---------------------------------------------------------------------------
|
| 253 |
+
|
| 254 |
+
def set_bone_position_pole(
    src_arm: Armature, src_name: str,
    dst_arm: Armature, dst_name: str,
    dst_twist_name: str,
    pole_distance: float,
) -> None:
    """
    Port of Blender's SetBonePositionPole().
    Positions an IK pole target relative to source limb geometry.

    The pole is pushed *pole_distance* away from the limb's mid-joint,
    along the line from the joint-chain average toward that joint, so the
    IK solver bends the limb the same way the source pose does.
    Mutates the destination bone; dst_twist_name is accepted for signature
    parity but unused here.
    """
    src_bone = src_arm.get_bone(src_name)
    dst_bone = dst_arm.get_bone(dst_name)

    # Root of the two-bone chain (first recursive parent, or self if none).
    parent_src = src_bone.parent_recursive[0] if src_bone.parent_recursive else src_bone

    base_parent_ws = get_bone_position_ws(parent_src, src_arm)
    base_child_ws = get_bone_position_ws(src_bone, src_arm)

    # Tail = head + Y-axis direction of bone in world space
    # NOTE(review): this assumes a unit-length bone in armature space —
    # the tail is taken one unit along local +Y; confirm against skeleton.py.
    src_ws_mat = src_arm.world_matrix @ src_bone.matrix_armature
    tail_ws = src_ws_mat[:3, 3] + src_ws_mat[:3, :3] @ np.array([0.0, 1.0, 0.0])

    # Relative lengths of the two chain segments.
    length_parent = np.linalg.norm(base_child_ws - base_parent_ws)
    length_child = np.linalg.norm(tail_ws - base_child_ws)
    total = length_parent + length_child

    # Guard against a degenerate (zero-length) chain.
    c_p_ratio = length_parent / total if total > 1e-12 else 0.5

    # Point on the root→tail line proportional to the parent segment's share.
    length_pp_to_tail = np.linalg.norm(base_parent_ws - tail_ws)
    average_location = get_point_on_vector(base_parent_ws, tail_ws, length_pp_to_tail * c_p_ratio)

    # How far the mid-joint bulges off the root→tail line.
    distance = np.linalg.norm(base_child_ws - average_location)

    # Only place the pole when the limb is actually bent; a straight limb
    # gives no usable bend direction.
    if distance > 0.001:
        pole_pos = get_point_on_vector(base_child_ws, average_location, pole_distance)
        set_bone_position_ws(dst_bone, dst_arm, pole_pos)
        parent = dst_bone.parent
        dst_bone._fk(parent.matrix_armature if parent else np.eye(4))
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# ---------------------------------------------------------------------------
|
| 295 |
+
# Scale setter
|
| 296 |
+
# ---------------------------------------------------------------------------
|
| 297 |
+
|
| 298 |
+
def set_bone_scale(
    src_arm: Armature, src_name: str,
    dst_arm: Armature, dst_name: str,
    src_scale_bone_name: str,
    gain: float,
    axis: str,
    max_scale: float,
    min_scale: float,
) -> None:
    """
    Port of Blender's SetBoneScale().

    Derives a scale amount from the dot product of two source-bone
    world-space quaternions (primary vs. secondary), scales it by *gain*,
    takes the absolute value, clamps to [min_scale, max_scale], and writes
    it onto the selected axes of the destination bone's pose scale.
    """
    src_bone = src_arm.get_bone(src_name)
    dst_bone = dst_arm.get_bone(dst_name)
    secondary = src_arm.get_bone(src_scale_bone_name)

    primary_quat = get_bone_ws_quat(src_bone, src_arm)
    secondary_quat = get_bone_ws_quat(secondary, src_arm)

    # Magnitude of the alignment measure, clamped into the allowed range.
    amount = abs(quat_dot(primary_quat, secondary_quat) * gain)
    amount = max(min_scale, min(max_scale, amount))

    targets = {
        "X": (0,), "Y": (1,), "Z": (2,),
        "XY": (0, 1), "XZ": (0, 2), "YZ": (1, 2),
    }.get(axis)
    scale = dst_bone.pose_scale
    if targets is None:  # "XYZ" (default) scales uniformly
        scale[:] = amount
    else:
        for idx in targets:
            scale[idx] = amount

    parent = dst_bone.parent
    dst_bone._fk(parent.matrix_armature if parent else np.eye(4))
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# ---------------------------------------------------------------------------
|
| 344 |
+
# Correction calculators
|
| 345 |
+
# ---------------------------------------------------------------------------
|
| 346 |
+
|
| 347 |
+
def calc_rotation_offset(
    item: BoneMappingItem,
    src_arm: Armature,
    dst_arm: Armature,
    settings: KeeMapSettings,
) -> None:
    """
    Auto-compute the rotation correction factor for one bone mapping.
    Port of Blender's CalcRotationOffset().
    Modifies item.correction_factor and item.quat_correction_factor in-place.

    Method: apply the retarget with an identity correction, measure how far
    the destination bone's world-space orientation drifted from where it
    started, and store that drift (as euler and quaternion) as the
    correction. The destination bone's pose is snapshotted and restored,
    so the armature is left unchanged apart from FK recomputation.
    """
    # Bail out early on unmapped or unknown bones — nothing to compute.
    if not item.source_bone_name or not item.destination_bone_name:
        return
    if not src_arm.has_bone(item.source_bone_name):
        return
    if not dst_arm.has_bone(item.destination_bone_name):
        return

    dst_bone = dst_arm.get_bone(item.destination_bone_name)

    # Snapshot destination bone state so it can be restored afterwards.
    snap_r = dst_bone.pose_rotation_quat.copy()
    snap_t = dst_bone.pose_location.copy()

    starting_ws_quat = get_bone_ws_quat(dst_bone, dst_arm)

    # Apply the retarget with identity correction (and twist disabled).
    set_bone_rotation(
        src_arm, item.source_bone_name,
        dst_arm, item.destination_bone_name,
        item.twist_bone_name,
        quat_identity(),
        False,
        item.bone_rotation_application_axis,
        item.bone_transpose_axis,
        settings.bone_rotation_mode,
    )
    dst_arm.update_fk()

    modified_ws_quat = get_bone_ws_quat(dst_bone, dst_arm)

    # Correction = rotation that takes modified_ws back to starting_ws
    q_diff = quat_rotation_difference(modified_ws_quat, starting_ws_quat)
    euler = quat_to_euler(q_diff, order="XYZ")
    item.correction_factor = euler.copy()
    item.quat_correction_factor = q_diff.copy()

    # Restore the snapshotted pose and re-run FK so the probe is invisible.
    dst_bone.pose_rotation_quat = snap_r
    dst_bone.pose_location = snap_t
    parent = dst_bone.parent
    dst_bone._fk(parent.matrix_armature if parent else np.eye(4))
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def calc_location_offset(
    item: BoneMappingItem,
    src_arm: Armature,
    dst_arm: Armature,
) -> None:
    """
    Auto-compute position correction for one bone mapping.
    Port of Blender's CalcLocationOffset().

    Probes by moving the destination bone onto the source bone's
    world-space head, records the resulting pose-space location, undoes
    the move, and stores (original − probed) pose location as
    item.position_correction_factor.
    """
    # Skip unmapped or unknown bones.
    if not item.source_bone_name or not item.destination_bone_name:
        return
    if not src_arm.has_bone(item.source_bone_name):
        return
    if not dst_arm.has_bone(item.destination_bone_name):
        return

    source = src_arm.get_bone(item.source_bone_name)
    dest = dst_arm.get_bone(item.destination_bone_name)

    source_ws = get_bone_position_ws(source, src_arm)
    dest_ws = get_bone_position_ws(dest, dst_arm)

    original_pose_loc = dest.pose_location.copy()

    # Probe: snap dest onto the source and capture the pose-space result.
    set_bone_position_ws(dest, dst_arm, source_ws)
    dst_arm.update_fk()
    probed_pose_loc = dest.pose_location.copy()

    # Undo the probe so the armature is left untouched.
    set_bone_position_ws(dest, dst_arm, dest_ws)
    dst_arm.update_fk()

    item.position_correction_factor = (original_pose_loc - probed_pose_loc).copy()
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def calc_all_corrections(
    bone_items: List[BoneMappingItem],
    src_arm: Armature,
    dst_arm: Armature,
    settings: KeeMapSettings,
) -> None:
    """Auto-calculate rotation and position corrections for all mapped bones."""
    for mapping in bone_items:
        calc_rotation_offset(mapping, src_arm, dst_arm, settings)
        # Mappings whose name mentions "pole" skip the location pass; their
        # position is handled by the pole-distance path at transfer time.
        is_pole = "pole" in mapping.name.lower()
        if not is_pole:
            calc_location_offset(mapping, src_arm, dst_arm)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# ---------------------------------------------------------------------------
|
| 453 |
+
# Single-frame transfer
|
| 454 |
+
# ---------------------------------------------------------------------------
|
| 455 |
+
|
| 456 |
+
def transfer_frame(
    src_arm: Armature,
    dst_arm: Armature,
    bone_items: List[BoneMappingItem],
    settings: KeeMapSettings,
) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Apply retargeting for all bone mappings at the current source frame.
    src_arm must already have FK updated for the current frame.

    Returns a dict of bone_name → (pose_location, pose_rotation_quat, pose_scale)
    suitable for writing into a keyframe list.
    """
    for item in bone_items:
        # Skip mappings that are unset or reference bones missing on either side.
        if not item.source_bone_name or not item.destination_bone_name:
            continue
        if not src_arm.has_bone(item.source_bone_name):
            continue
        if not dst_arm.has_bone(item.destination_bone_name):
            continue

        # Build correction quaternion
        # Two storage modes: Euler XYZ correction angles, or a pre-computed
        # quaternion correction (see calc_rotation_offset upstream).
        if settings.bone_rotation_mode == "EULER":
            cf = item.correction_factor
            correction_quat = euler_to_quat(cf[0], cf[1], cf[2], order="XYZ")
        else:
            correction_quat = quat_normalize(item.quat_correction_factor)

        # Rotation
        if item.set_bone_rotation:
            set_bone_rotation(
                src_arm, item.source_bone_name,
                dst_arm, item.destination_bone_name,
                item.twist_bone_name,
                correction_quat,
                item.has_twist_bone,
                item.bone_rotation_application_axis,
                item.bone_transpose_axis,
                settings.bone_rotation_mode,
            )
            # FK refresh so later mappings in this loop see the new matrices.
            dst_arm.update_fk()

        # Position
        if item.set_bone_position:
            # NOTE: "postion_type" spelling matches the BoneMappingItem field name.
            if item.postion_type == "SINGLE_BONE_OFFSET":
                set_bone_position(
                    src_arm, item.source_bone_name,
                    dst_arm, item.destination_bone_name,
                    item.twist_bone_name,
                    item.position_correction_factor,
                    item.position_gain,
                )
            else:
                # Pole-target style placement; distance is negated here —
                # presumably to push the pole away from the limb plane
                # (TODO confirm against set_bone_position_pole).
                set_bone_position_pole(
                    src_arm, item.source_bone_name,
                    dst_arm, item.destination_bone_name,
                    item.twist_bone_name,
                    -item.position_pole_distance,
                )
            dst_arm.update_fk()

        # Scale
        # Scale transfer needs a secondary source bone that actually exists.
        if item.set_bone_scale and item.scale_secondary_bone_name:
            if src_arm.has_bone(item.scale_secondary_bone_name):
                set_bone_scale(
                    src_arm, item.source_bone_name,
                    dst_arm, item.destination_bone_name,
                    item.scale_secondary_bone_name,
                    item.scale_gain,
                    item.bone_scale_application_axis,
                    item.scale_max,
                    item.scale_min,
                )
                dst_arm.update_fk()

    # Snapshot destination bone state for this frame
    # Copies are taken so later frames can't mutate earlier keyframes.
    result: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
    for item in bone_items:
        if not item.destination_bone_name:
            continue
        if not dst_arm.has_bone(item.destination_bone_name):
            continue
        dst_bone = dst_arm.get_bone(item.destination_bone_name)
        result[item.destination_bone_name] = (
            dst_bone.pose_location.copy(),
            dst_bone.pose_rotation_quat.copy(),
            dst_bone.pose_scale.copy(),
        )
    return result
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
# ---------------------------------------------------------------------------
|
| 548 |
+
# Full animation transfer
|
| 549 |
+
# ---------------------------------------------------------------------------
|
| 550 |
+
|
| 551 |
+
def transfer_animation(
    src_anim,  # BVHAnimation or any object with .armature + .apply_frame(i) + .num_frames
    dst_arm: Armature,
    bone_items: List[BoneMappingItem],
    settings: KeeMapSettings,
) -> List[Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]]:
    """
    Transfer all frames from src_anim to dst_arm.
    Returns list of keyframe dicts (one per frame sampled).

    Equivalent to Blender's PerformAnimationTransfer operator.
    """
    stride = max(1, settings.keyframe_every_n_frames)
    first = settings.start_frame_to_apply
    last = first + settings.number_of_frames_to_apply

    source_armature = src_anim.armature
    total_steps = len(range(first, last, stride))

    keyframes: List[Dict] = []
    for done, frame in enumerate(range(first, last, stride), start=1):
        # The source clip may be shorter than the requested frame window.
        if frame >= src_anim.num_frames:
            break

        src_anim.apply_frame(frame)  # poses + FK-updates the source armature
        dst_arm.update_fk()

        keyframes.append(transfer_frame(source_armature, dst_arm, bone_items, settings))
        _update_progress("Retargeting", done / total_steps)

    _update_progress("Retargeting", 1.0)
    return keyframes
|
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
search.py
|
| 3 |
+
Stream TeoGchx/HumanML3D from HuggingFace and match motions by keyword.
|
| 4 |
+
|
| 5 |
+
Dataset: https://huggingface.co/datasets/TeoGchx/HumanML3D
|
| 6 |
+
Format: motion column is [T, 263] inline in parquet (standard HumanML3D)
|
| 7 |
+
Splits: train (23 384), val (1 460), test (4 384)
|
| 8 |
+
|
| 9 |
+
Usage
|
| 10 |
+
-----
|
| 11 |
+
from Retarget.search import search_motions
|
| 12 |
+
|
| 13 |
+
results = search_motions("a person walks forward", top_k=5)
|
| 14 |
+
for r in results:
|
| 15 |
+
print(r["caption"], r["frames"], "frames")
|
| 16 |
+
# r["motion"] → np.ndarray [T, 263]
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
import re
|
| 20 |
+
from typing import List, Optional
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 25 |
+
# Caption cleaning
|
| 26 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 27 |
+
|
| 28 |
+
_SEP = re.compile(r'#|\|')
|
| 29 |
+
_POS_TAG = re.compile(r'^(?:[A-Z]{1,4}\s*)+$') # lines that look like POS tags
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _clean_caption(raw: str) -> str:
|
| 33 |
+
"""
|
| 34 |
+
HumanML3D captions are stored as multiple sentences joined by '#',
|
| 35 |
+
sometimes followed by POS tag strings. Return the first human-readable
|
| 36 |
+
sentence.
|
| 37 |
+
"""
|
| 38 |
+
parts = _SEP.split(raw)
|
| 39 |
+
for part in parts:
|
| 40 |
+
part = part.strip()
|
| 41 |
+
if not part:
|
| 42 |
+
continue
|
| 43 |
+
words = part.split()
|
| 44 |
+
# Skip if >50 % of tokens look like POS tags (all-caps, ≤4 chars)
|
| 45 |
+
pos_count = sum(1 for w in words if w.isupper() and len(w) <= 4)
|
| 46 |
+
if len(words) > 0 and pos_count / len(words) < 0.5:
|
| 47 |
+
return part
|
| 48 |
+
return parts[0].strip() if parts else raw.strip()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 52 |
+
# Search
|
| 53 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 54 |
+
|
| 55 |
+
def search_motions(
    query: str,
    top_k: int = 8,
    split: str = "test",
    max_scan: int = 4384,
    cached: bool = False,
) -> List[dict]:
    """
    Stream TeoGchx/HumanML3D and return up to top_k motions matching query.

    Parameters
    ----------
    query    Natural-language description, e.g. "a person walks forward"
    top_k    Maximum number of results to return
    split    Dataset split — "test" (4 384 rows) is fastest to stream
    max_scan Hard cap on rows examined before returning
    cached   Download + cache the split locally instead of streaming

    Returns
    -------
    List of dicts, sorted by relevance score (descending):
        caption  str         clean human-readable description
        motion   np.ndarray  shape [T, 263], standard HumanML3D features
        frames   int         number of frames (T)
        duration float       duration in seconds (at 20 fps)
        name     str         original clip ID from dataset
        score    int         keyword match score
    """
    try:
        from datasets import load_dataset
    except ImportError:
        raise ImportError(
            "pip install datasets (HuggingFace datasets library required)"
        )

    if cached:
        # Downloads the split once (~400MB) and caches to ~/.cache/huggingface.
        # Subsequent calls are instant. Use for local dev / testing.
        ds = load_dataset("TeoGchx/HumanML3D", split=split)
    else:
        # Streaming: no disk cache, re-downloads each run. Good for server use.
        ds = load_dataset("TeoGchx/HumanML3D", split=split, streaming=True)

    # Tokenise the query, dropping punctuation.
    keywords = re.sub(r"[^\w\s]", "", query.lower()).split()
    if not keywords:
        return []

    matches: List[dict] = []
    examined = 0

    for row in ds:
        if examined >= max_scan:
            break
        examined += 1

        caption = _clean_caption(row.get("caption", "") or "")
        haystack = caption.lower()

        # Relevance: whole-word hits count 2, bare substring hits count 1.
        relevance = 0
        for kw in keywords:
            if kw not in haystack:
                continue
            whole_word = re.search(r"\b" + re.escape(kw) + r"\b", haystack)
            relevance += 2 if whole_word else 1

        if relevance == 0:
            continue

        motion_raw = row.get("motion")
        if motion_raw is None:
            continue

        motion = np.array(motion_raw, dtype=np.float32)  # [T, 263]
        meta = row.get("meta_data") or {}
        n_frames = motion.shape[0]

        matches.append({
            "caption": caption,
            "motion": motion,
            "frames": int(meta.get("num_frames", n_frames)),
            "duration": float(meta.get("duration", n_frames / 20.0)),
            "name": str(meta.get("name", "")),
            "score": relevance,
        })

        # Early exit once we have enough candidates.
        if len(matches) >= top_k:
            break

    matches.sort(key=lambda m: m["score"], reverse=True)
    return matches[:top_k]
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def format_choice_label(result: dict) -> str:
    """Short label for Gradio Radio component."""
    text = result["caption"]
    # Keep the label compact: clip over-long captions at 72 characters.
    shown = text if len(text) <= 72 else text[:72] + "…"
    return f"{shown} ({result['frames']} frames, {result['duration']:.1f}s)"
|
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
skeleton.py
|
| 3 |
+
Pure-Python armature / pose-bone system.
|
| 4 |
+
|
| 5 |
+
Design matches Blender's pose-mode semantics:
|
| 6 |
+
- bone.rest_matrix_local = 4×4 rest pose in parent space (edit-mode)
|
| 7 |
+
- bone.pose_rotation_quat = local rotation DELTA from rest (≡ bone.rotation_quaternion)
|
| 8 |
+
- bone.pose_location = local translation DELTA from rest (≡ bone.location)
|
| 9 |
+
- bone.pose_scale = local scale (≡ bone.scale)
|
| 10 |
+
- bone.matrix_armature = FK-computed 4×4 in armature space (≡ bone.matrix in pose mode)
|
| 11 |
+
|
| 12 |
+
Armature.world_matrix corresponds to arm.matrix_world.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
import numpy as np
|
| 16 |
+
from typing import Dict, List, Optional, Tuple
|
| 17 |
+
from .math3d import (
|
| 18 |
+
quat_identity, quat_normalize, quat_mul,
|
| 19 |
+
quat_to_matrix4, matrix4_to_quat,
|
| 20 |
+
translation_matrix, scale_matrix, trs_to_matrix4, matrix4_to_trs,
|
| 21 |
+
vec3,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PoseBone:
    """
    One bone of an Armature in pose mode.

    Holds a fixed rest transform (parent-local 4×4) plus pose-space deltas
    (location / rotation / scale) that FK composes on top of it, mirroring
    Blender's pose-bone semantics (see module docstring).
    """

    def __init__(
        self,
        name: str,
        rest_matrix_local: np.ndarray,  # 4×4, in parent local space
        parent: Optional["PoseBone"] = None,
    ):
        self.name = name
        # NOTE(review): the parent's .children list is wired by
        # Armature.add_bone, not here — only the back-reference is stored.
        self.parent: Optional[PoseBone] = parent
        self.children: List[PoseBone] = []
        # Copied so later edits to the caller's array can't corrupt the rest pose.
        self.rest_matrix_local: np.ndarray = rest_matrix_local.copy()

        # Pose state — start at rest (delta = identity)
        self.pose_rotation_quat: np.ndarray = quat_identity()
        self.pose_location: np.ndarray = vec3()
        self.pose_scale: np.ndarray = np.ones(3)

        # Cached FK result — call armature.update_fk() to refresh
        self._matrix_armature: np.ndarray = np.eye(4)

    # -----------------------------------------------------------------------
    # Properties
    # -----------------------------------------------------------------------

    @property
    def matrix_armature(self) -> np.ndarray:
        """4×4 FK result in armature space. Refresh with armature.update_fk()."""
        return self._matrix_armature

    @property
    def head(self) -> np.ndarray:
        """Bone head position in armature space."""
        # Translation column of the cached FK matrix.
        return self._matrix_armature[:3, 3].copy()

    @property
    def tail(self) -> np.ndarray:
        """
        Approximate tail position (Y-axis in bone space, length 1).
        Works for Y-along-bone convention (Blender / BVH default).
        """
        # Head plus the bone's armature-space Y direction (unit length).
        y_axis = self._matrix_armature[:3, :3] @ np.array([0.0, 1.0, 0.0])
        return self._matrix_armature[:3, 3] + y_axis

    # -----------------------------------------------------------------------
    # FK
    # -----------------------------------------------------------------------

    def _compute_local_matrix(self) -> np.ndarray:
        """rest_local @ T(pose_loc) @ R(pose_rot) @ S(pose_scale)."""
        # Order matters: pose deltas are applied in the bone's rest frame.
        T = translation_matrix(self.pose_location)
        R = quat_to_matrix4(self.pose_rotation_quat)
        S = scale_matrix(self.pose_scale)
        return self.rest_matrix_local @ T @ R @ S

    def _fk(self, parent_matrix: np.ndarray) -> None:
        # Recursively propagate armature-space matrices down the hierarchy.
        self._matrix_armature = parent_matrix @ self._compute_local_matrix()
        for child in self.children:
            child._fk(self._matrix_armature)

    # -----------------------------------------------------------------------
    # Parent-chain helpers (Blender: bone.parent_recursive)
    # -----------------------------------------------------------------------

    @property
    def parent_recursive(self) -> List["PoseBone"]:
        # Nearest ancestor first, root last (matches Blender's ordering of
        # parent_recursive as walked here).
        chain: List[PoseBone] = []
        cur = self.parent
        while cur is not None:
            chain.append(cur)
            cur = cur.parent
        return chain
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Armature:
    """
    Collection of PoseBones with a world transform.
    Corresponds to a Blender armature object.
    """

    def __init__(self, name: str = "Armature"):
        self.name = name
        self.world_matrix: np.ndarray = np.eye(4)  # arm.matrix_world
        self._bones: Dict[str, PoseBone] = {}
        self._roots: List[PoseBone] = []

    # -----------------------------------------------------------------------
    # Construction helpers
    # -----------------------------------------------------------------------

    def add_bone(self, bone: PoseBone, parent_name: Optional[str] = None) -> PoseBone:
        """Register a bone; wire it under parent_name if that bone is known."""
        self._bones[bone.name] = bone
        parent = self._bones.get(parent_name) if parent_name else None
        if parent is not None:
            bone.parent = parent
            parent.children.append(bone)
        elif bone.parent is None:
            # No known parent and no pre-set one → this is a root bone.
            self._roots.append(bone)
        return bone

    @property
    def pose_bones(self) -> Dict[str, PoseBone]:
        """Name → PoseBone mapping (live view, not a copy)."""
        return self._bones

    def get_bone(self, name: str) -> PoseBone:
        """Look up a bone by name; KeyError with context if absent."""
        bone = self._bones.get(name)
        if bone is None:
            raise KeyError(f"Bone '{name}' not found in armature '{self.name}'")
        return bone

    def has_bone(self, name: str) -> bool:
        return name in self._bones

    # -----------------------------------------------------------------------
    # FK update
    # -----------------------------------------------------------------------

    def update_fk(self) -> None:
        """Recompute all bone armature-space matrices via FK."""
        identity = np.eye(4)
        for root_bone in self._roots:
            root_bone._fk(identity)

    # -----------------------------------------------------------------------
    # Snapshot / restore (for calc-correction passes)
    # -----------------------------------------------------------------------

    def snapshot(self) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Copy every bone's (rotation, location, scale) pose state."""
        saved: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
        for name, bone in self._bones.items():
            saved[name] = (
                bone.pose_rotation_quat.copy(),
                bone.pose_location.copy(),
                bone.pose_scale.copy(),
            )
        return saved

    def restore(self, snap: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]) -> None:
        """Write a snapshot back into matching bones, then refresh FK."""
        for name, (rot, loc, scl) in snap.items():
            bone = self._bones.get(name)
            if bone is None:
                continue
            bone.pose_rotation_quat = rot.copy()
            bone.pose_location = loc.copy()
            bone.pose_scale = scl.copy()
        self.update_fk()
|
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
smpl.py
|
| 3 |
+
───────────────────────────────────────────────────────────────────────────────
|
| 4 |
+
Parse HumanML3D [T, 263] feature vectors into structured SMPL motion data.
|
| 5 |
+
|
| 6 |
+
HumanML3D 263-dim layout per frame
|
| 7 |
+
[0] root angular-velocity (Y-axis, rad/frame)
|
| 8 |
+
[1] root height Y (metres)
|
| 9 |
+
[2:4] root XZ velocity (local-frame, metres/frame)
|
| 10 |
+
[4:67] joint local positions joints 1-21 relative to root, 21×3 (unused here)
|
| 11 |
+
[67:193] 6D joint rotations joints 1-21, 21×6
|
| 12 |
+
[193:259] joint velocities joints 0-21, 22×3 (unused here)
|
| 13 |
+
[259:263] foot contact flags (unused here)
|
| 14 |
+
|
| 15 |
+
Root rotation = cumulative integral of dim[0] → Y-axis quaternion.
|
| 16 |
+
Root position = dim[1] (height) + integrated XZ velocity.
|
| 17 |
+
Joint 1-21 rot = dims 67:193 as 6D continuous rotation representation
|
| 18 |
+
[Zhou et al. 2019] → Gram-Schmidt → 3×3 rotation matrix → quaternion.
|
| 19 |
+
These are LOCAL rotations relative to the SMPL parent joint's rest
|
| 20 |
+
frame, where the canonical T-pose is the zero (identity) rotation.
|
| 21 |
+
"""
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 27 |
+
# 6D rotation helpers
|
| 28 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 29 |
+
|
| 30 |
+
def rot6d_to_matrix(r6d: np.ndarray) -> np.ndarray:
    """
    [..., 6] → [..., 3, 3]
    Rebuild a rotation matrix from the 6D continuous representation
    (two raw column vectors) via Gram-Schmidt orthonormalisation.
    """
    col_a = r6d[..., 0:3].astype(np.float64)
    col_b = r6d[..., 3:6].astype(np.float64)

    # First basis vector: normalised col_a (epsilon guards zero input).
    e1 = col_a / (np.linalg.norm(col_a, axis=-1, keepdims=True) + 1e-12)

    # Second: col_b with its e1 component removed, then normalised.
    proj = (e1 * col_b).sum(axis=-1, keepdims=True)
    e2 = col_b - proj * e1
    e2 = e2 / (np.linalg.norm(e2, axis=-1, keepdims=True) + 1e-12)

    # Third: cross product completes the right-handed frame.
    e3 = np.cross(e1, e2)

    return np.stack([e1, e2, e3], axis=-1)  # columns → [..., 3, 3]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def matrix_to_quat(mat: np.ndarray) -> np.ndarray:
    """
    [..., 3, 3] → [..., 4] WXYZ quaternion, positive-W convention.
    Uses scipy for numerical stability.
    """
    from scipy.spatial.transform import Rotation

    batch_shape = mat.shape[:-2]
    as_xyzw = Rotation.from_matrix(
        mat.reshape(-1, 3, 3).astype(np.float64)
    ).as_quat()  # scipy → XYZW
    as_wxyz = as_xyzw[:, [3, 0, 1, 2]].astype(np.float32)
    # q and -q encode the same rotation; canonicalise so W ≥ 0.
    sign = np.where(as_wxyz[:, 0:1] < 0, -1.0, 1.0).astype(np.float32)
    as_wxyz = as_wxyz * sign
    return as_wxyz.reshape(*batch_shape, 4)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def rot6d_to_quat(r6d: np.ndarray) -> np.ndarray:
    """[..., 6] → [..., 4] WXYZ. Convenience: 6D → matrix → quaternion."""
    rot_matrices = rot6d_to_matrix(r6d)
    return matrix_to_quat(rot_matrices)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 65 |
+
# Root motion recovery
|
| 66 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 67 |
+
|
| 68 |
+
def _qrot_vec(q: np.ndarray, v: np.ndarray) -> np.ndarray:
|
| 69 |
+
"""Rotate [N, 3] vectors by [N, 4] WXYZ quaternions (batch)."""
|
| 70 |
+
w, x, y, z = q[:, 0:1], q[:, 1:2], q[:, 2:3], q[:, 3:4]
|
| 71 |
+
vx, vy, vz = v[:, 0:1], v[:, 1:2], v[:, 2:3]
|
| 72 |
+
# Rodrigues-style: v + 2w*(q.xyz × v) + 2*(q.xyz × (q.xyz × v))
|
| 73 |
+
tx = 2 * (y * vz - z * vy)
|
| 74 |
+
ty = 2 * (z * vx - x * vz)
|
| 75 |
+
tz = 2 * (x * vy - y * vx)
|
| 76 |
+
return np.concatenate([
|
| 77 |
+
vx + w * tx + y * tz - z * ty,
|
| 78 |
+
vy + w * ty + z * tx - x * tz,
|
| 79 |
+
vz + w * tz + x * ty - y * tx,
|
| 80 |
+
], axis=-1)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def recover_root_motion(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Recover root world-space position and rotation from [T, 263] features.

    Returns
    -------
    root_pos : [T, 3] world-space root position (Y = height above ground)
    root_rot : [T, 4] WXYZ quaternion — Y-axis only (global facing direction)
    """
    T = data.shape[0]

    # Facing direction: integrate Y-axis angular velocity
    # cumsum includes each frame's own velocity — presumably matching the
    # HumanML3D reference recovery convention; TODO confirm off-by-one.
    theta = np.cumsum(data[:, 0].astype(np.float32))
    half = theta * 0.5
    # Y-axis quaternion: (w, x, y, z) = (cos θ/2, 0, sin θ/2, 0).
    root_rot = np.zeros((T, 4), dtype=np.float32)
    root_rot[:, 0] = np.cos(half)
    root_rot[:, 2] = np.sin(half)

    # XZ velocity encoded in root-local frame → world frame
    vel_local = np.stack([
        data[:, 2].astype(np.float32),
        np.zeros(T, dtype=np.float32),   # no vertical velocity; height is absolute
        data[:, 3].astype(np.float32),
    ], axis=-1)
    vel_world = _qrot_vec(root_rot, vel_local)

    # X/Z come from integrating world-frame velocity; Y is stored directly.
    root_pos = np.zeros((T, 3), dtype=np.float32)
    root_pos[:, 0] = np.cumsum(vel_world[:, 0])
    root_pos[:, 1] = data[:, 1]
    root_pos[:, 2] = np.cumsum(vel_world[:, 2])

    return root_pos, root_rot
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 118 |
+
# SMPLMotion container
|
| 119 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 120 |
+
|
| 121 |
+
class SMPLMotion:
    """
    Structured SMPL motion data parsed from a single HumanML3D clip.

    Attributes
    ----------
    root_pos  : [T, 3]      world-space root position (metres)
    root_rot  : [T, 4]      WXYZ root Y-axis rotation (global facing)
    local_rot : [T, 21, 4]  WXYZ local quaternions for joints 1-21
                            T-pose = identity; relative to SMPL parent frame
    fps       : float       capture frame rate (20 for HumanML3D)
    """

    def __init__(
        self,
        root_pos: np.ndarray,
        root_rot: np.ndarray,
        local_rot: np.ndarray,
        fps: float = 20.0,
    ):
        # Everything is normalised to float32 ndarrays on construction.
        self.root_pos = np.asarray(root_pos, dtype=np.float32)
        self.root_rot = np.asarray(root_rot, dtype=np.float32)
        self.local_rot = np.asarray(local_rot, dtype=np.float32)
        self.fps = float(fps)

    @property
    def num_frames(self) -> int:
        """Number of frames T in the clip."""
        return len(self.root_pos)

    def slice(self, start: int = 0, end: int = -1) -> "SMPLMotion":
        """Return a new SMPLMotion over frames [start, end); end<=0 means 'to the end'."""
        stop = self.num_frames if end <= 0 else end
        return SMPLMotion(
            self.root_pos[start:stop],
            self.root_rot[start:stop],
            self.local_rot[start:stop],
            self.fps,
        )
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def hml3d_to_smpl_motion(data: np.ndarray, fps: float = 20.0) -> SMPLMotion:
    """
    Convert HumanML3D [T, 263] feature array to a SMPLMotion.

    Uses the actual 6D rotation data (dims 67:193) — NOT position-derived
    rotations. This preserves twist and gives physically correct limb poses.

    Parameters
    ----------
    data : [T, 263] raw HumanML3D features (e.g. from MoMask or dataset row)
    fps  : float    frame rate (default 20 = HumanML3D native)

    Raises
    ------
    ValueError if data is not 2-D with at least 193 feature dims.
    """
    feats = np.asarray(data, dtype=np.float32)
    if feats.ndim != 2 or feats.shape[1] < 193:
        raise ValueError(f"Expected [T, >=193] but got {feats.shape}")

    n_frames = feats.shape[0]
    root_pos, root_rot = recover_root_motion(feats)

    # Joints 1-21 each carry a 6D rotation in dims [67:193] → [T, 21, 6].
    joint_r6d = feats[:, 67:193].reshape(n_frames, 21, 6)
    joint_quats = rot6d_to_quat(joint_r6d)  # [T, 21, 4] WXYZ

    return SMPLMotion(root_pos, root_rot, joint_quats, fps)
|
|
@@ -6,6 +6,7 @@ import shutil
|
|
| 6 |
import traceback
|
| 7 |
import json
|
| 8 |
import random
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
|
| 11 |
# ── ZeroGPU: install packages that can't be built at Docker build time ─────────
|
|
@@ -130,8 +131,13 @@ _triposg_pipe = None
|
|
| 130 |
_rmbg_net = None
|
| 131 |
_rmbg_version = None
|
| 132 |
_last_glb_path = None
|
|
|
|
|
|
|
|
|
|
| 133 |
_init_seed = random.randint(0, 2**31 - 1)
|
| 134 |
|
|
|
|
|
|
|
| 135 |
ARCFACE_256 = (np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
|
| 136 |
[41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
|
| 137 |
* (256 / 112) + (256 - 112 * (256 / 112)) / 2)
|
|
@@ -167,6 +173,104 @@ def _ensure_ckpts():
|
|
| 167 |
|
| 168 |
# ── Model loaders ─────────────────────────────────────────────────────────────
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
def load_triposg():
|
| 171 |
global _triposg_pipe, _rmbg_net, _rmbg_version
|
| 172 |
if _triposg_pipe is not None:
|
|
@@ -309,6 +413,197 @@ def load_triposg():
|
|
| 309 |
return _triposg_pipe, _rmbg_net
|
| 310 |
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
# ── Background removal helper ─────────────────────────────────────────────────
|
| 313 |
|
| 314 |
def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):
|
|
@@ -343,8 +638,8 @@ def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):
|
|
| 343 |
|
| 344 |
rgb = np.array(img_pil.convert("RGB"), dtype=np.float32) / 255.0
|
| 345 |
alpha = mask[:, :, np.newaxis]
|
| 346 |
-
comp = (rgb * alpha + 0.5 * (1.0 - alpha) * 255
|
| 347 |
-
return Image.fromarray(comp)
|
| 348 |
|
| 349 |
|
| 350 |
def preview_rembg(input_image, do_remove_bg, threshold, erode_px):
|
|
@@ -357,6 +652,188 @@ def preview_rembg(input_image, do_remove_bg, threshold, erode_px):
|
|
| 357 |
return input_image
|
| 358 |
|
| 359 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
# ── Stage 1: Shape generation ─────────────────────────────────────────────────
|
| 361 |
|
| 362 |
@spaces.GPU(duration=180)
|
|
@@ -365,6 +842,28 @@ def generate_shape(input_image, remove_background, num_steps, guidance_scale,
|
|
| 365 |
if input_image is None:
|
| 366 |
return None, "Please upload an image."
|
| 367 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
progress(0.1, desc="Loading TripoSG...")
|
| 369 |
pipe, rmbg_net = load_triposg()
|
| 370 |
|
|
@@ -373,16 +872,17 @@ def generate_shape(input_image, remove_background, num_steps, guidance_scale,
|
|
| 373 |
img.save(img_path)
|
| 374 |
|
| 375 |
progress(0.5, desc="Generating shape (SDF diffusion)...")
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
|
|
|
| 386 |
|
| 387 |
out_path = "/tmp/triposg_shape.glb"
|
| 388 |
mesh.export(out_path)
|
|
@@ -609,7 +1109,7 @@ def gradio_rig(glb_state_path, export_fbx_flag, mdm_prompt, mdm_n_frames,
|
|
| 609 |
animated = mdm_result.get("animated_glb")
|
| 610 |
|
| 611 |
parts = ["Rigged: " + os.path.basename(rigged)]
|
| 612 |
-
if fbx:
|
| 613 |
if animated: parts.append("Animation: " + os.path.basename(animated))
|
| 614 |
|
| 615 |
torch.cuda.empty_cache()
|
|
@@ -633,6 +1133,7 @@ def gradio_enhance(glb_path, ref_img_np, do_normal, norm_res, norm_strength,
|
|
| 633 |
from pipeline.enhance_surface import (
|
| 634 |
run_stable_normal, run_depth_anything,
|
| 635 |
bake_normal_into_glb, bake_depth_as_occlusion,
|
|
|
|
| 636 |
)
|
| 637 |
import pipeline.enhance_surface as _enh_mod
|
| 638 |
|
|
@@ -704,18 +1205,271 @@ def render_views(glb_file):
|
|
| 704 |
return []
|
| 705 |
|
| 706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
# ── Full pipeline ─────────────────────────────────────────────────────────────
|
| 708 |
|
| 709 |
-
def run_full_pipeline(input_image, num_steps, guidance, seed, face_count,
|
| 710 |
-
variant, tex_seed, enhance_face,
|
| 711 |
export_fbx, mdm_prompt, mdm_n_frames, progress=gr.Progress()):
|
| 712 |
progress(0.0, desc="Stage 1/3: Generating shape...")
|
| 713 |
-
glb, status = generate_shape(input_image,
|
| 714 |
if not glb:
|
| 715 |
return None, None, None, None, None, None, status
|
| 716 |
|
| 717 |
progress(0.33, desc="Stage 2/3: Applying texture...")
|
| 718 |
-
glb, mv_img, status = apply_texture(glb, input_image,
|
|
|
|
| 719 |
if not glb:
|
| 720 |
return None, None, None, None, None, None, status
|
| 721 |
|
|
@@ -727,17 +1481,61 @@ def run_full_pipeline(input_image, num_steps, guidance, seed, face_count,
|
|
| 727 |
|
| 728 |
|
| 729 |
# ── UI ────────────────────────────────────────────────────────────────────────
|
| 730 |
-
with gr.Blocks(title="Image2Model") as demo:
|
| 731 |
gr.Markdown("# Image2Model — Portrait to Rigged 3D Mesh")
|
| 732 |
-
glb_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 733 |
|
| 734 |
-
|
|
|
|
|
|
|
|
|
|
| 735 |
|
| 736 |
# ════════════════════════════════════════════════════════════════════
|
| 737 |
-
with gr.Tab("Generate"):
|
| 738 |
with gr.Row():
|
| 739 |
with gr.Column(scale=1):
|
| 740 |
-
input_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 741 |
|
| 742 |
with gr.Accordion("Shape Settings", open=True):
|
| 743 |
num_steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
|
|
@@ -756,9 +1554,11 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 756 |
shape_btn = gr.Button("Generate Shape", variant="primary", scale=2, interactive=False)
|
| 757 |
texture_btn = gr.Button("Apply Texture", variant="secondary", scale=2)
|
| 758 |
render_btn = gr.Button("Render Views", variant="secondary", scale=1)
|
| 759 |
-
run_all_btn = gr.Button("▶ Run Full Pipeline", variant="primary", interactive=False)
|
| 760 |
|
| 761 |
with gr.Column(scale=1):
|
|
|
|
|
|
|
| 762 |
status = gr.Textbox(label="Status", lines=3, interactive=False)
|
| 763 |
model_3d = gr.Model3D(label="3D Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
|
| 764 |
download_file = gr.File(label="Download GLB")
|
|
@@ -766,6 +1566,8 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 766 |
|
| 767 |
render_gallery = gr.Gallery(label="Rendered Views", columns=5, height=300)
|
| 768 |
|
|
|
|
|
|
|
| 769 |
_pipeline_btns = [shape_btn, run_all_btn]
|
| 770 |
|
| 771 |
input_image.upload(
|
|
@@ -777,9 +1579,14 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 777 |
inputs=[], outputs=_pipeline_btns,
|
| 778 |
)
|
| 779 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
shape_btn.click(
|
| 781 |
-
fn=
|
| 782 |
-
inputs=[input_image, num_steps, guidance, seed, face_count],
|
| 783 |
outputs=[glb_state, status],
|
| 784 |
).then(
|
| 785 |
fn=lambda p: (p, p) if p else (None, None),
|
|
@@ -787,8 +1594,9 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 787 |
)
|
| 788 |
|
| 789 |
texture_btn.click(
|
| 790 |
-
fn=
|
| 791 |
-
inputs=[glb_state, input_image, variant, tex_seed,
|
|
|
|
| 792 |
outputs=[glb_state, multiview_img, status],
|
| 793 |
).then(
|
| 794 |
fn=lambda p: (p, p) if p else (None, None),
|
|
@@ -797,6 +1605,29 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 797 |
|
| 798 |
render_btn.click(fn=render_views, inputs=[download_file], outputs=[render_gallery])
|
| 799 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
# ════════════════════════════════════════════════════════════════════
|
| 801 |
with gr.Tab("Rig & Export"):
|
| 802 |
with gr.Row():
|
|
@@ -844,6 +1675,9 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 844 |
inputs=[glb_state, export_fbx_check, mdm_prompt_box, mdm_frames_slider],
|
| 845 |
outputs=[rig_glb_dl, rig_animated_dl, rig_fbx_dl, rig_status,
|
| 846 |
rig_model_3d, rigged_base_state, skel_glb_state],
|
|
|
|
|
|
|
|
|
|
| 847 |
)
|
| 848 |
|
| 849 |
show_skel_check.change(
|
|
@@ -852,6 +1686,103 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 852 |
outputs=[rig_model_3d],
|
| 853 |
)
|
| 854 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
# ════════════════════════════════════════════════════════════════════
|
| 856 |
with gr.Tab("Enhancement"):
|
| 857 |
gr.Markdown("**Surface Enhancement** — bakes normal + depth maps into the GLB as PBR textures.")
|
|
@@ -868,6 +1799,7 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 868 |
displacement_scale = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Displacement Scale")
|
| 869 |
|
| 870 |
enhance_btn = gr.Button("Run Enhancement", variant="primary")
|
|
|
|
| 871 |
|
| 872 |
with gr.Column(scale=2):
|
| 873 |
enhance_status = gr.Textbox(label="Status", lines=5, interactive=False)
|
|
@@ -886,12 +1818,111 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 886 |
enhanced_glb_dl, enhanced_model_3d, enhance_status],
|
| 887 |
)
|
| 888 |
|
| 889 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 890 |
run_all_btn.click(
|
| 891 |
fn=run_full_pipeline,
|
| 892 |
inputs=[
|
| 893 |
-
input_image, num_steps, guidance, seed, face_count,
|
| 894 |
-
variant, tex_seed, enhance_face_check,
|
| 895 |
export_fbx_check, mdm_prompt_box, mdm_frames_slider,
|
| 896 |
],
|
| 897 |
outputs=[glb_state, download_file, multiview_img,
|
|
@@ -901,6 +1932,23 @@ with gr.Blocks(title="Image2Model") as demo:
|
|
| 901 |
inputs=[glb_state], outputs=[model_3d, download_file],
|
| 902 |
)
|
| 903 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 904 |
|
| 905 |
if __name__ == "__main__":
|
| 906 |
-
demo.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft()
|
|
|
|
|
|
| 6 |
import traceback
|
| 7 |
import json
|
| 8 |
import random
|
| 9 |
+
import threading
|
| 10 |
from pathlib import Path
|
| 11 |
|
| 12 |
# ── ZeroGPU: install packages that can't be built at Docker build time ─────────
|
|
|
|
| 131 |
_rmbg_net = None
|
| 132 |
_rmbg_version = None
|
| 133 |
_last_glb_path = None
|
| 134 |
+
_hyperswap_sess = None
|
| 135 |
+
_gfpgan_restorer = None
|
| 136 |
+
_firered_pipe = None
|
| 137 |
_init_seed = random.randint(0, 2**31 - 1)
|
| 138 |
|
| 139 |
+
_model_load_lock = threading.Lock()
|
| 140 |
+
|
| 141 |
ARCFACE_256 = (np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
|
| 142 |
[41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
|
| 143 |
* (256 / 112) + (256 - 112 * (256 / 112)) / 2)
|
|
|
|
| 173 |
|
| 174 |
# ── Model loaders ─────────────────────────────────────────────────────────────
|
| 175 |
def _load_rmbg():
    """Load RMBG-2.0 from 1038lab mirror.

    Populates the module globals ``_rmbg_net`` / ``_rmbg_version`` in place;
    idempotent (returns immediately if the net is already loaded). On any
    failure both globals are reset to None and background removal is disabled
    rather than raising.
    """
    global _rmbg_net, _rmbg_version
    if _rmbg_net is not None:
        return
    try:
        from transformers import AutoModelForImageSegmentation
        from torch.overrides import TorchFunctionMode

        class _NoMetaMode(TorchFunctionMode):
            """Intercept device='meta' tensor construction and redirect to CPU.

            init_empty_weights() inside from_pretrained pushes a meta DeviceContext
            ON TOP of any torch.device("cpu") wrapper, so meta wins. This mode is
            pushed BELOW it; when meta DeviceContext adds device='meta' and chains
            down the stack, we see it here and flip it back to 'cpu'.
            """
            def __torch_function__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                dev = kwargs.get("device")
                if dev is not None:
                    # device may arrive as torch.device or a plain string.
                    dev_str = dev.type if isinstance(dev, torch.device) else str(dev)
                    if dev_str == "meta":
                        kwargs["device"] = "cpu"
                return func(*args, **kwargs)

        # transformers 5.x _finalize_model_loading calls mark_tied_weights_as_initialized
        # which accesses all_tied_weights_keys. BiRefNetConfig inherits from the old
        # PretrainedConfig alias which skips the new PreTrainedModel.__init__ section
        # that sets this attribute. Patch the method to be safe.
        from transformers import PreTrainedModel as _PTM
        _orig_mark_tied = _PTM.mark_tied_weights_as_initialized
        def _safe_mark_tied(self, loading_info):
            if not hasattr(self, "all_tied_weights_keys"):
                self.all_tied_weights_keys = {}
            return _orig_mark_tied(self, loading_info)
        _PTM.mark_tied_weights_as_initialized = _safe_mark_tied
        try:
            with _NoMetaMode():
                _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
                    "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False,
                )
        finally:
            # Always restore the original method, even if from_pretrained raised.
            _PTM.mark_tied_weights_as_initialized = _orig_mark_tied
        # DEVICE is a module-level global — presumably "cuda" when available; verify at call site.
        _rmbg_net.to(DEVICE).eval()
        _rmbg_version = "2.0"
        print("RMBG-2.0 loaded.")
    except Exception as e:
        _rmbg_net = None
        _rmbg_version = None
        print(f"RMBG-2.0 failed: {e} — background removal disabled.")
| 229 |
+
|
def load_rmbg_only():
    """Load RMBG standalone without loading TripoSG."""
    _load_rmbg()  # idempotent; populates the module global on first call
    return _rmbg_net
| 235 |
+
|
def load_gfpgan():
    """Lazily build the global GFPGAN face restorer.

    Caches the restorer in the module global ``_gfpgan_restorer``. Attaches a
    RealESRGAN x2 background upsampler when its checkpoint is present; returns
    None (without raising) if the GFPGAN checkpoint is missing or any import
    or construction step fails.
    """
    global _gfpgan_restorer
    if _gfpgan_restorer is not None:
        return _gfpgan_restorer
    try:
        from gfpgan import GFPGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        ckpt = str(CKPT_DIR / "GFPGANv1.4.pth")
        if not os.path.exists(ckpt):
            print(f"[GFPGAN] Not found at {ckpt}")
            return None

        # Optional background upsampler — only if its checkpoint was downloaded.
        esrgan_ckpt = str(CKPT_DIR / "RealESRGAN_x2plus.pth")
        upsampler = None
        if os.path.exists(esrgan_ckpt):
            rrdb = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                           num_block=23, num_grow_ch=32, scale=2)
            upsampler = RealESRGANer(
                scale=2, model_path=esrgan_ckpt, model=rrdb,
                tile=400, tile_pad=10, pre_pad=0, half=True,
            )
            print("[GFPGAN] RealESRGAN x2plus bg_upsampler loaded")
        else:
            print("[GFPGAN] RealESRGAN_x2plus.pth not found, running without upsampler")

        _gfpgan_restorer = GFPGANer(
            model_path=ckpt, upscale=2, arch="clean",
            channel_multiplier=2, bg_upsampler=upsampler,
        )
        print("[GFPGAN] Loaded GFPGANv1.4 (upscale=2 + RealESRGAN bg_upsampler)")
        return _gfpgan_restorer
    except Exception as e:
        print(f"[GFPGAN] Load failed: {e}")
        return None
| 273 |
+
|
| 274 |
def load_triposg():
|
| 275 |
global _triposg_pipe, _rmbg_net, _rmbg_version
|
| 276 |
if _triposg_pipe is not None:
|
|
|
|
| 413 |
return _triposg_pipe, _rmbg_net
|
| 414 |
|
| 415 |
def load_firered():
    """Lazy-load FireRed image-edit pipeline using GGUF-quantized transformer.

    Transformer: loaded from GGUF via from_single_file (Q4_K_M, ~12 GB on disk).
    Tries Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF first (fine-tuned, merged model).
    Falls back to unsloth/Qwen-Image-Edit-2511-GGUF (base model) if key mapping fails.

    text_encoder: 4-bit NF4 on GPU (~5.6 GB).
    GGUF transformer: dequantized on-the-fly, dispatched with 18 GiB GPU budget.
    Lightning scheduler: 4 steps, CFG 1.0 → ~1-2 min per inference.

    GPU budget: ~18 GB transformer + ~5.6 GB text_encoder + ~0.3 GB VAE ≈ 24 GB.

    Returns the cached QwenImageEditPlusPipeline (module global ``_firered_pipe``).
    """
    global _firered_pipe
    if _firered_pipe is not None:
        return _firered_pipe

    import math as _math
    from diffusers import QwenImageEditPlusPipeline, FlowMatchEulerDiscreteScheduler, GGUFQuantizationConfig
    from diffusers.models import QwenImageTransformer2DModel
    from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
    from accelerate import dispatch_model, infer_auto_device_map
    from huggingface_hub import hf_hub_download

    # Patch SDPA to cast K/V to match Q dtype.
    # NOTE: this is a process-wide patch; it is only applied once because the
    # early-return above guards repeat calls.
    import torch.nn.functional as _F
    _orig_sdpa = _F.scaled_dot_product_attention
    def _dtype_safe_sdpa(query, key, value, *a, **kw):
        if key.dtype != query.dtype: key = key.to(query.dtype)
        if value.dtype != query.dtype: value = value.to(query.dtype)
        return _orig_sdpa(query, key, value, *a, **kw)
    _F.scaled_dot_product_attention = _dtype_safe_sdpa

    torch.cuda.empty_cache()

    # Load RMBG NOW — before dispatch_model creates meta tensors that poison later loads
    _load_rmbg()

    gguf_config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)

    def _gguf_transformer(repo_id, filename, label):
        """Download one GGUF checkpoint and load it as a QwenImage transformer.

        Shared by both candidate repos below — the download/load sequence was
        previously duplicated verbatim for each attempt.
        """
        print(f"[FireRed] Downloading {repo_id} Q4_K_M (~12 GB)...")
        gguf_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"[FireRed] Loading {label} transformer from GGUF...")
        return QwenImageTransformer2DModel.from_single_file(
            gguf_path,
            quantization_config=gguf_config,
            torch_dtype=torch.bfloat16,
            config="Qwen/Qwen-Image-Edit-2511",
            subfolder="transformer",
        )

    # ── Transformer: GGUF Q4_K_M — try fine-tuned Rapid-AIO first, fall back to base ──
    transformer = None

    # Attempt 1: Arunk25 Rapid-AIO GGUF (fine-tuned, fully merged, ~12.4 GB)
    try:
        transformer = _gguf_transformer(
            "Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF",
            "v23/Qwen-Rapid-AIO-NSFW-v23-Q4_K_M.gguf",
            "Rapid-AIO",
        )
        print("[FireRed] Rapid-AIO GGUF transformer loaded OK.")
    except Exception as e:
        print(f"[FireRed] Rapid-AIO GGUF failed ({e}), falling back to unsloth base GGUF...")
        transformer = None

    # Attempt 2: unsloth base GGUF Q4_K_M (~12.3 GB) — no try/except: if the base
    # model also fails, let the exception propagate to the caller.
    if transformer is None:
        transformer = _gguf_transformer(
            "unsloth/Qwen-Image-Edit-2511-GGUF",
            "qwen-image-edit-2511-Q4_K_M.gguf",
            "base",
        )
        print("[FireRed] Base GGUF transformer loaded OK.")

    print("[FireRed] Dispatching transformer (18 GiB GPU, rest CPU)...")
    device_map = infer_auto_device_map(
        transformer,
        max_memory={0: "18GiB", "cpu": "90GiB"},
        dtype=torch.bfloat16,
    )
    n_gpu = sum(1 for d in device_map.values() if str(d) in ("0", "cuda", "cuda:0"))
    n_cpu = sum(1 for d in device_map.values() if str(d) == "cpu")
    print(f"[FireRed] Dispatched: {n_gpu} modules on GPU, {n_cpu} on CPU")
    transformer = dispatch_model(transformer, device_map=device_map)
    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
    print(f"[FireRed] Transformer dispatched — VRAM: {used_mb} MB")

    # ── text_encoder: 4-bit NF4 on GPU (~5.6 GB) ──────────────────────────────
    bnb_enc = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    print("[FireRed] Loading text_encoder (4-bit NF4)...")
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen-Image-Edit-2511",
        subfolder="text_encoder",
        quantization_config=bnb_enc,
        device_map="auto",
    )
    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
    print(f"[FireRed] Text encoder loaded — VRAM: {used_mb} MB")

    # ── Pipeline: VAE + scheduler + processor + tokenizer ─────────────────────
    print("[FireRed] Loading pipeline...")
    _firered_pipe = QwenImageEditPlusPipeline.from_pretrained(
        "Qwen/Qwen-Image-Edit-2511",
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch.bfloat16,
    )
    _firered_pipe.vae.to(DEVICE)

    # Lightning scheduler — 4 steps, use_dynamic_shifting, matches reference space config
    _firered_pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config({
        "base_image_seq_len": 256,
        "base_shift": _math.log(3),
        "max_image_seq_len": 8192,
        "max_shift": _math.log(3),
        "num_train_timesteps": 1000,
        "shift": 1.0,
        "time_shift_type": "exponential",
        "use_dynamic_shifting": True,
    })

    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
    print(f"[FireRed] Pipeline ready — total VRAM: {used_mb} MB")
    return _firered_pipe
|
| 552 |
+
|
| 553 |
+
def _gallery_to_pil_list(gallery_value):
|
| 554 |
+
"""Convert a Gradio Gallery value (list of various formats) to a list of PIL Images."""
|
| 555 |
+
pil_images = []
|
| 556 |
+
if not gallery_value:
|
| 557 |
+
return pil_images
|
| 558 |
+
for item in gallery_value:
|
| 559 |
+
try:
|
| 560 |
+
if isinstance(item, np.ndarray):
|
| 561 |
+
pil_images.append(Image.fromarray(item).convert("RGB"))
|
| 562 |
+
continue
|
| 563 |
+
if isinstance(item, Image.Image):
|
| 564 |
+
pil_images.append(item.convert("RGB"))
|
| 565 |
+
continue
|
| 566 |
+
# Gradio 6 Gallery returns dicts: {"image": FileData, "caption": ...}
|
| 567 |
+
if isinstance(item, dict):
|
| 568 |
+
img_data = item.get("image") or item
|
| 569 |
+
if isinstance(img_data, dict):
|
| 570 |
+
path = img_data.get("path") or img_data.get("url") or img_data.get("name")
|
| 571 |
+
else:
|
| 572 |
+
path = img_data
|
| 573 |
+
elif isinstance(item, (list, tuple)):
|
| 574 |
+
path = item[0]
|
| 575 |
+
else:
|
| 576 |
+
path = item
|
| 577 |
+
if path and os.path.exists(str(path)):
|
| 578 |
+
pil_images.append(Image.open(str(path)).convert("RGB"))
|
| 579 |
+
except Exception as e:
|
| 580 |
+
print(f"[FireRed] Could not load gallery image: {e}")
|
| 581 |
+
return pil_images
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def _firered_resize(img):
|
| 585 |
+
"""Resize to max 1024px maintaining aspect ratio, align dims to multiple of 8."""
|
| 586 |
+
w, h = img.size
|
| 587 |
+
if max(w, h) > 1024:
|
| 588 |
+
if w > h:
|
| 589 |
+
nw, nh = 1024, int(1024 * h / w)
|
| 590 |
+
else:
|
| 591 |
+
nw, nh = int(1024 * w / h), 1024
|
| 592 |
+
else:
|
| 593 |
+
nw, nh = w, h
|
| 594 |
+
nw, nh = max(8, (nw // 8) * 8), max(8, (nh // 8) * 8)
|
| 595 |
+
if (nw, nh) != (w, h):
|
| 596 |
+
img = img.resize((nw, nh), Image.LANCZOS)
|
| 597 |
+
return img
|
| 598 |
+
|
| 599 |
+
# Default negative prompt passed to every FireRed edit call — standard
# quality/artifact suppression terms; kept module-level so all callers agree.
_FIRERED_NEGATIVE = (
    "worst quality, low quality, bad anatomy, bad hands, text, error, "
    "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
    "signature, watermark, username, blurry"
)
|
| 606 |
+
|
| 607 |
# ── Background removal helper ─────────────────────────────────────────────────
|
| 608 |
|
| 609 |
def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):
|
|
|
|
| 638 |
|
| 639 |
rgb = np.array(img_pil.convert("RGB"), dtype=np.float32) / 255.0
|
| 640 |
alpha = mask[:, :, np.newaxis]
|
| 641 |
+
comp = (rgb * alpha + 0.5 * (1.0 - alpha)) * 255
|
| 642 |
+
return Image.fromarray(comp.clip(0, 255).astype(np.uint8))
|
| 643 |
|
| 644 |
|
| 645 |
def preview_rembg(input_image, do_remove_bg, threshold, erode_px):
|
|
|
|
| 652 |
return input_image
|
| 653 |
|
| 654 |
|
| 655 |
+
# ── RealESRGAN helpers ─────────────────────────────────────────────────────────
|
| 656 |
+
def _load_realesrgan(scale: int = 4):
    """Load RealESRGAN upsampler. Returns RealESRGANer or None."""
    try:
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        # scale==4 selects the x4plus checkpoint; anything else falls back to
        # the x2plus checkpoint/topology (matching the original behavior, the
        # RealESRGANer itself is still built with the caller-requested scale).
        factor = 4 if scale == 4 else 2
        ckpt = str(CKPT_DIR / f"RealESRGAN_x{factor}plus.pth")
        net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                      num_block=23, num_grow_ch=32, scale=factor)
        if not os.path.exists(ckpt):
            print(f"[RealESRGAN] {ckpt} not found")
            return None
        sampler = RealESRGANer(
            scale=scale, model_path=ckpt, model=net,
            tile=512, tile_pad=32, pre_pad=0, half=True,
        )
        print(f"[RealESRGAN] Loaded x{scale}plus")
        return sampler
    except Exception as e:
        print(f"[RealESRGAN] Load failed: {e}")
        return None
|
| 681 |
+
def _enhance_glb_texture(glb_path: str) -> bool:
    """
    Extract the base-color UV texture atlas from a GLB, upscale with RealESRGAN x4,
    downscale back to original resolution (sharper detail), then repack in-place.
    Returns True if enhancement was applied.

    Only the FIRST material with a decodable embedded base-color texture is
    processed; the function saves the GLB and returns immediately after it.
    """
    import pygltflib

    # Prefer the x4 model; fall back to x2 if its checkpoint is missing.
    upsampler = _load_realesrgan(scale=4)
    if upsampler is None:
        upsampler = _load_realesrgan(scale=2)
    if upsampler is None:
        print("[enhance_glb] No RealESRGAN checkpoint available")
        return False

    glb = pygltflib.GLTF2().load(glb_path)
    # Mutable copy of the whole binary chunk; texture bytes are spliced in place.
    blob = bytearray(glb.binary_blob() or b"")

    for mat in glb.materials:
        bct = getattr(mat.pbrMetallicRoughness, "baseColorTexture", None)
        if bct is None:
            continue
        tex = glb.textures[bct.index]
        if tex.source is None:
            continue
        img_obj = glb.images[tex.source]
        if img_obj.bufferView is None:
            # External (URI) image — nothing embedded to rewrite.
            continue
        bv = glb.bufferViews[img_obj.bufferView]
        offset, length = bv.byteOffset or 0, bv.byteLength

        # Decode the embedded PNG/JPEG bytes straight from the blob.
        img_arr = np.frombuffer(blob[offset:offset + length], dtype=np.uint8)
        atlas_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
        if atlas_bgr is None:
            continue

        orig_h, orig_w = atlas_bgr.shape[:2]
        print(f"[enhance_glb] atlas {orig_w}x{orig_h}, upscaling with RealESRGAN…")

        try:
            upscaled, _ = upsampler.enhance(atlas_bgr, outscale=4)
        except Exception as e:
            print(f"[enhance_glb] RealESRGAN enhance failed: {e}")
            continue

        # Downscale back to the original atlas size — net effect is sharpening.
        restored = cv2.resize(upscaled, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)

        ok, new_bytes = cv2.imencode(".png", restored)
        if not ok:
            continue
        new_bytes = new_bytes.tobytes()
        new_len = len(new_bytes)

        if new_len > length:
            # Re-encoded image grew: splice it in and shift everything after it.
            before = bytes(blob[:offset])
            after = bytes(blob[offset + length:])
            blob = bytearray(before + new_bytes + after)
            delta = new_len - length
            bv.byteLength = new_len
            # Shift byteOffset of every bufferView located after the image.
            # NOTE(review): this does not re-pad to the 4-byte bufferView
            # alignment the glTF spec expects for accessor data — confirm
            # downstream loaders tolerate the shifted offsets.
            for other_bv in glb.bufferViews:
                if (other_bv.byteOffset or 0) > offset:
                    other_bv.byteOffset += delta
            # Assumes a single-buffer GLB (buffers[0]) — typical for exported
            # GLBs, but verify against the producers used in this pipeline.
            glb.buffers[0].byteLength += delta
        else:
            # New image fits in the old slot: overwrite in place (old trailing
            # bytes beyond new_len stay in the buffer but are outside bv's range).
            blob[offset:offset + new_len] = new_bytes
            bv.byteLength = new_len

        glb.set_binary_blob(bytes(blob))
        glb.save(glb_path)
        print(f"[enhance_glb] GLB texture enhanced OK (was {length}B → {new_len}B)")
        return True

    print("[enhance_glb] No base-color texture found in GLB")
    return False
|
| 757 |
+
|
| 758 |
+
# ── FireRed GPU functions ──────────────────────────────────────────────────────
|
| 759 |
+
|
| 760 |
+
@spaces.GPU(duration=600)
def firered_generate(gallery_images, prompt, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
    """Run FireRed image-edit inference on one or more reference images (max 3 natively).

    Returns a (preview ndarray | None, seed actually used, status message) triple
    so the Gradio wiring can update the image, seed box, and status textbox.
    """
    pil_images = _gallery_to_pil_list(gallery_images)
    if not pil_images:
        return None, int(seed), "Please upload at least one image."
    if not prompt or not prompt.strip():
        return None, int(seed), "Please enter an edit prompt."
    try:
        import gc
        progress(0.05, desc="Loading FireRed pipeline...")
        pipe = load_firered()

        if randomize_seed:
            seed = random.randint(0, 2**31 - 1)

        # FireRed natively handles 1-3 images; cap silently and warn.
        # Remember the original count up front — the old code re-ran
        # _gallery_to_pil_list() at the end (re-decoding the whole gallery)
        # just to decide whether to print the truncation note.
        n_given = len(pil_images)
        if n_given > 3:
            print(f"[FireRed] {n_given} images given, truncating to 3 (native limit).")
            pil_images = pil_images[:3]

        # Resize to max 1024px and align to multiple of 8 (prevents padding bars)
        pil_images = [_firered_resize(img) for img in pil_images]
        height, width = pil_images[0].height, pil_images[0].width
        print(f"[FireRed] Input size after resize: {width}x{height}")

        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

        progress(0.4, desc=f"Running FireRed edit ({len(pil_images)} image(s))...")
        with torch.inference_mode():
            result = pipe(
                image=pil_images,
                prompt=prompt.strip(),
                negative_prompt=_FIRERED_NEGATIVE,
                num_inference_steps=int(steps),
                generator=generator,
                true_cfg_scale=float(guidance_scale),
                num_images_per_prompt=1,
                height=height,
                width=width,
            ).images[0]

        gc.collect()
        torch.cuda.empty_cache()
        progress(1.0, desc="Done!")
        n = len(pil_images)
        note = " (truncated to 3)" if n_given > 3 else ""
        return np.array(result), int(seed), f"Preview ready — {n} image(s) used{note}."
    except Exception:
        return None, int(seed), f"FireRed error:\n{traceback.format_exc()}"
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
@spaces.GPU(duration=60)
def firered_load_into_pipeline(firered_output, threshold, erode_px, progress=gr.Progress()):
    """Load a FireRed output into the main pipeline with automatic background removal."""
    if firered_output is None:
        return None, None, "No FireRed output — generate an image first."
    try:
        progress(0.1, desc="Loading RMBG model...")
        load_rmbg_only()

        source = Image.fromarray(firered_output).convert("RGB")
        if _rmbg_net is None:
            # RMBG could not be loaded — hand the image through untouched.
            pixels = firered_output
            message = "Loaded into pipeline (RMBG unavailable — background not removed)."
        else:
            progress(0.5, desc="Removing background...")
            cutout = _remove_bg_rmbg(source, threshold=float(threshold), erode_px=int(erode_px))
            pixels = np.array(cutout)
            message = "Loaded into pipeline — background removed."

        progress(1.0, desc="Done!")
        # Same array feeds both the pipeline input and the preview image.
        return pixels, pixels, message
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
|
| 835 |
+
|
| 836 |
+
|
| 837 |
# ── Stage 1: Shape generation ─────────────────────────────────────────────────
|
| 838 |
|
| 839 |
@spaces.GPU(duration=180)
|
|
|
|
| 842 |
if input_image is None:
|
| 843 |
return None, "Please upload an image."
|
| 844 |
try:
|
| 845 |
+
progress(0.05, desc="Freeing VRAM from FireRed (if loaded)...")
|
| 846 |
+
global _firered_pipe
|
| 847 |
+
if _firered_pipe is not None:
|
| 848 |
+
# dispatch_model attaches accelerate hooks — remove them before .to("cpu")
|
| 849 |
+
try:
|
| 850 |
+
from accelerate.hooks import remove_hook_from_submodules
|
| 851 |
+
remove_hook_from_submodules(_firered_pipe.transformer)
|
| 852 |
+
_firered_pipe.transformer.to("cpu")
|
| 853 |
+
except Exception as _e:
|
| 854 |
+
print(f"[TripoSG] Transformer CPU offload: {_e}")
|
| 855 |
+
try:
|
| 856 |
+
_firered_pipe.text_encoder.to("cpu")
|
| 857 |
+
except Exception:
|
| 858 |
+
pass
|
| 859 |
+
try:
|
| 860 |
+
_firered_pipe.vae.to("cpu")
|
| 861 |
+
except Exception:
|
| 862 |
+
pass
|
| 863 |
+
_firered_pipe = None
|
| 864 |
+
torch.cuda.empty_cache()
|
| 865 |
+
print("[TripoSG] FireRed offloaded — VRAM freed for shape generation.")
|
| 866 |
+
|
| 867 |
progress(0.1, desc="Loading TripoSG...")
|
| 868 |
pipe, rmbg_net = load_triposg()
|
| 869 |
|
|
|
|
| 872 |
img.save(img_path)
|
| 873 |
|
| 874 |
progress(0.5, desc="Generating shape (SDF diffusion)...")
|
| 875 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 876 |
+
from scripts.inference_triposg import run_triposg
|
| 877 |
+
mesh = run_triposg(
|
| 878 |
+
pipe=pipe,
|
| 879 |
+
image_input=img_path,
|
| 880 |
+
rmbg_net=rmbg_net if remove_background else None,
|
| 881 |
+
seed=int(seed),
|
| 882 |
+
num_inference_steps=int(num_steps),
|
| 883 |
+
guidance_scale=float(guidance_scale),
|
| 884 |
+
faces=int(face_count) if int(face_count) > 0 else -1,
|
| 885 |
+
)
|
| 886 |
|
| 887 |
out_path = "/tmp/triposg_shape.glb"
|
| 888 |
mesh.export(out_path)
|
|
|
|
| 1109 |
animated = mdm_result.get("animated_glb")
|
| 1110 |
|
| 1111 |
parts = ["Rigged: " + os.path.basename(rigged)]
|
| 1112 |
+
if fbx: parts.append("FBX: " + os.path.basename(fbx))
|
| 1113 |
if animated: parts.append("Animation: " + os.path.basename(animated))
|
| 1114 |
|
| 1115 |
torch.cuda.empty_cache()
|
|
|
|
| 1133 |
from pipeline.enhance_surface import (
|
| 1134 |
run_stable_normal, run_depth_anything,
|
| 1135 |
bake_normal_into_glb, bake_depth_as_occlusion,
|
| 1136 |
+
unload_models,
|
| 1137 |
)
|
| 1138 |
import pipeline.enhance_surface as _enh_mod
|
| 1139 |
|
|
|
|
| 1205 |
return []
|
| 1206 |
|
| 1207 |
|
| 1208 |
+
# ── HyperSwap views ───────────────────────────────────────────────────────────
|
| 1209 |
+
|
| 1210 |
+
@spaces.GPU(duration=120)
def hyperswap_views(embedding_json: str):
    """
    Stage 6 — run HyperSwap on the last rendered views.
    embedding_json: JSON string of the 512-d ArcFace embedding list.
    Returns a gallery of (swapped_image_path, view_name) tuples.

    Per-view failures (no face, degenerate alignment) fall back to the
    original render; only a fatal setup error returns an empty list.
    """
    global _hyperswap_sess
    try:
        import onnxruntime as ort
        from insightface.app import FaceAnalysis

        # HyperSwap expects a unit-norm identity embedding.
        embedding = np.array(json.loads(embedding_json), dtype=np.float32)
        embedding /= np.linalg.norm(embedding)

        # Load HyperSwap once; the ONNX session is cached across calls.
        if _hyperswap_sess is None:
            hs_path = str(CKPT_DIR / "hyperswap_1a_256.onnx")
            _hyperswap_sess = ort.InferenceSession(hs_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
            print(f"[hyperswap_views] Loaded {hs_path}")

        # Very low det_thresh: rendered views can contain small/profile faces.
        app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
        app.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.1)

        results = []
        for view_path, name in zip(VIEW_PATHS, VIEW_NAMES):
            if not os.path.exists(view_path):
                print(f"[hyperswap_views] Missing {view_path}, skipping")
                continue

            bgr = cv2.imread(view_path)
            faces = app.get(bgr)
            if not faces:
                print(f"[hyperswap_views] {name}: no face detected")
                out_path = view_path  # return original
            else:
                face = faces[0]
                # Align the detected face onto the 256x256 ArcFace template.
                M, _ = cv2.estimateAffinePartial2D(face.kps, ARCFACE_256,
                                                   method=cv2.RANSAC, ransacReprojThreshold=100)
                if M is None:
                    # FIX: estimateAffinePartial2D can return None on degenerate
                    # keypoints; previously warpAffine then raised and the outer
                    # except aborted ALL views. Fall back to the original render.
                    print(f"[hyperswap_views] {name}: alignment failed, keeping original")
                    results.append((view_path, name))
                    continue
                H, W = bgr.shape[:2]
                aligned = cv2.warpAffine(bgr, M, (256, 256), flags=cv2.INTER_LINEAR)
                # BGR→RGB, [0,255]→[-1,1], HWC→NCHW with a leading batch dim.
                t = ((aligned.astype(np.float32) / 255 - 0.5) / 0.5)[:, :, ::-1].copy().transpose(2, 0, 1)[None]
                out, mask = _hyperswap_sess.run(None, {
                    "source": embedding.reshape(1, -1),
                    "target": t,
                })
                out_bgr = (((out[0].transpose(1, 2, 0) + 1) / 2 * 255)
                           .clip(0, 255).astype(np.uint8))[:, :, ::-1].copy()
                m = (mask[0, 0] * 255).clip(0, 255).astype(np.uint8)
                # Paste the swapped crop back via the inverse affine, blending
                # with the soft mask so seams stay invisible.
                Mi = cv2.invertAffineTransform(M)
                of = cv2.warpAffine(out_bgr, Mi, (W, H), flags=cv2.INTER_LINEAR)
                mf = cv2.warpAffine(m, Mi, (W, H), flags=cv2.INTER_LINEAR).astype(np.float32)[:, :, None] / 255
                swapped = (of * mf + bgr * (1 - mf)).clip(0, 255).astype(np.uint8)

                # GFPGAN face restoration on a crop padded ~35% around the bbox.
                restorer = load_gfpgan()
                if restorer is not None:
                    b = face.bbox.astype(int)
                    h2, w2 = swapped.shape[:2]
                    pad = 0.35
                    bw2, bh2 = b[2]-b[0], b[3]-b[1]
                    cx1 = max(0, b[0]-int(bw2*pad)); cy1 = max(0, b[1]-int(bh2*pad))
                    cx2 = min(w2, b[2]+int(bw2*pad)); cy2 = min(h2, b[3]+int(bh2*pad))
                    crop = swapped[cy1:cy2, cx1:cx2]
                    try:
                        _, _, rest = restorer.enhance(
                            crop, has_aligned=False, only_center_face=True,
                            paste_back=True, weight=0.5)
                        if rest is not None:
                            ch, cw = cy2 - cy1, cx2 - cx1
                            if rest.shape[:2] != (ch, cw):
                                rest = cv2.resize(rest, (cw, ch), interpolation=cv2.INTER_LANCZOS4)
                            swapped[cy1:cy2, cx1:cx2] = rest
                    except Exception as _ge:
                        print(f"[hyperswap_views] GFPGAN failed: {_ge}")

                out_path = view_path.replace("render_", "swapped_")
                cv2.imwrite(out_path, swapped)
                print(f"[hyperswap_views] {name}: swapped+restored OK -> {out_path}")

            results.append((out_path, name))

        return results
    except Exception:
        err = traceback.format_exc()
        print(f"hyperswap_views FAILED:\n{err}")
        return []
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
# ── Animate tab functions ─────────────────────────────────────────────────────
|
| 1300 |
+
|
| 1301 |
+
def gradio_search_motions(query: str, progress=gr.Progress()):
    """Stream TeoGchx/HumanML3D and return matching motions as radio choices.

    Returns (radio update, raw results list for State, status text).
    """
    if not query.strip():
        return (
            gr.update(choices=[], visible=False),
            [],
            "Enter a motion description and click Search.",
        )
    try:
        progress(0.1, desc="Connecting to HumanML3D dataset…")
        # Make Retarget/ importable. FIX: guard the insert so repeated
        # searches don't grow sys.path without bound.
        if str(HERE) not in sys.path:
            sys.path.insert(0, str(HERE))
        from Retarget.search import search_motions, format_choice_label
        progress(0.3, desc="Streaming dataset…")
        results = search_motions(query, top_k=8)
        progress(1.0)
        if not results:
            return (
                gr.update(choices=["No matches — try different keywords"], visible=True),
                [],
                f"No motions matched '{query}'. Try broader terms.",
            )
        choices = [format_choice_label(r) for r in results]
        status = f"Found {len(results)} motions matching '{query}'"
        return (
            gr.update(choices=choices, value=choices[0], visible=True),
            results,
            status,
        )
    except Exception:
        return (
            gr.update(choices=[], visible=False),
            [],
            f"Search error:\n{traceback.format_exc()}",
        )
|
| 1335 |
+
|
| 1336 |
+
|
| 1337 |
+
@spaces.GPU(duration=180)
def gradio_animate(
    rigged_glb_path,
    selected_label: str,
    motion_results: list,
    fps: int,
    max_frames: int,
    progress=gr.Progress(),
):
    """Bake selected HumanML3D motion onto the rigged GLB.

    Returns (glb_for_download, status_text, glb_for_3d_preview).
    """
    try:
        glb = rigged_glb_path or "/tmp/rig_out/rigged.glb"
        if not os.path.exists(glb):
            return None, "No rigged GLB — run the Rig step first.", None

        if not motion_results or not selected_label:
            return None, "No motion selected — run Search first.", None

        # Make Retarget/ importable. FIX: guard the insert so repeated
        # animate calls don't grow sys.path without bound.
        if str(HERE) not in sys.path:
            sys.path.insert(0, str(HERE))
        from Retarget.search import format_choice_label

        # Resolve which search result the radio label refers to
        # (falls back to the first result if the label is stale).
        idx = next(
            (i for i, r in enumerate(motion_results)
             if format_choice_label(r) == selected_label),
            0,
        )

        chosen = motion_results[idx]
        motion = chosen["motion"]  # np.ndarray [T, 263]
        caption = chosen["caption"]
        T_total = motion.shape[0]
        # max_frames <= 0 means "use the full motion".
        n_frames = min(max_frames, T_total) if max_frames > 0 else T_total

        progress(0.2, desc="Parsing skeleton…")
        from Retarget.animate import animate_glb_from_hml3d

        out_path = "/tmp/animated_out/animated.glb"
        os.makedirs("/tmp/animated_out", exist_ok=True)

        progress(0.4, desc="Mapping bones to SMPL joints…")
        animated = animate_glb_from_hml3d(
            motion=motion,
            rigged_glb=glb,
            output_glb=out_path,
            fps=int(fps),
            num_frames=int(n_frames),
        )
        progress(1.0, desc="Done!")
        status = (
            f"Animated: {n_frames} frames @ {fps} fps\n"
            f"Motion: {caption[:120]}"
        )
        return animated, status, animated

    except Exception:
        return None, f"Error:\n{traceback.format_exc()}", None
|
| 1393 |
+
|
| 1394 |
+
|
| 1395 |
+
# ── PSHuman Face Transplant ────────────────────────────────────────────────────
|
| 1396 |
+
|
| 1397 |
+
def gradio_pshuman_face(
    input_image,
    rigged_glb_path,
    weight_threshold: float,
    retract_mm: float,
    pshuman_url: str,
    progress=gr.Progress(),
):
    """
    Full PSHuman face transplant pipeline:
      1. Run PSHuman on input_image → colored OBJ face mesh
      2. Run face_transplant.py → stitch face into rigged GLB
      3. Return the combined GLB

    PSHuman runs as a remote service (pshuman_url). On ZeroGPU the service_url
    must point to an externally-deployed PSHuman endpoint (PSHUMAN_URL env var
    or user-provided URL in the UI). Local localhost will not work on ZeroGPU.
    """
    try:
        if input_image is None:
            return None, "Upload a portrait image first.", None
        rigged = rigged_glb_path
        if not rigged or not os.path.exists(str(rigged)):
            return None, "No rigged GLB found — run the Rig step first.", None

        work_dir = tempfile.mkdtemp(prefix="pshuman_transplant_")
        img_path = os.path.join(work_dir, "portrait.png")
        if isinstance(input_image, np.ndarray):
            Image.fromarray(input_image).save(img_path)
        else:
            input_image.save(img_path)

        # pipeline/ is already in sys.path via PIPELINE_DIR insertion at startup
        # ── Step 1: PSHuman inference ──────────────────────────────────────────
        progress(0.05, desc="Step 1/2: Running PSHuman (generates multi-view face)...")
        from pipeline.pshuman_client import generate_pshuman_mesh
        face_obj = os.path.join(work_dir, "pshuman_face.obj")
        # FIX: a cleared Gradio textbox can yield None, and None.strip()
        # raised AttributeError instead of falling back to the default URL.
        generate_pshuman_mesh(
            image_path=img_path,
            output_path=face_obj,
            service_url=(pshuman_url or "").strip() or "http://localhost:7862",
        )

        # ── Step 2: Face transplant ────────────────────────────────────────────
        progress(0.7, desc="Step 2/2: Stitching PSHuman face into rigged GLB...")
        out_glb = os.path.join(work_dir, "rigged_pshuman_face.glb")

        from pipeline.face_transplant import transplant_face
        transplant_face(
            body_glb_path=str(rigged),
            pshuman_mesh_path=face_obj,
            output_path=out_glb,
            weight_threshold=float(weight_threshold),
            retract_amount=float(retract_mm) / 1000.0,  # mm → metres
        )

        progress(1.0, desc="Done!")
        return out_glb, "PSHuman face transplant complete.", out_glb

    except Exception:
        return None, f"Error:\n{traceback.format_exc()}", None
|
| 1458 |
+
|
| 1459 |
+
|
| 1460 |
# ── Full pipeline ─────────────────────────────────────────────────────────────
|
| 1461 |
|
| 1462 |
+
def run_full_pipeline(input_image, remove_background, num_steps, guidance, seed, face_count,
|
| 1463 |
+
variant, tex_seed, enhance_face, rembg_threshold, rembg_erode,
|
| 1464 |
export_fbx, mdm_prompt, mdm_n_frames, progress=gr.Progress()):
|
| 1465 |
progress(0.0, desc="Stage 1/3: Generating shape...")
|
| 1466 |
+
glb, status = generate_shape(input_image, remove_background, num_steps, guidance, seed, face_count)
|
| 1467 |
if not glb:
|
| 1468 |
return None, None, None, None, None, None, status
|
| 1469 |
|
| 1470 |
progress(0.33, desc="Stage 2/3: Applying texture...")
|
| 1471 |
+
glb, mv_img, status = apply_texture(glb, input_image, remove_background, variant, tex_seed,
|
| 1472 |
+
enhance_face, rembg_threshold, rembg_erode)
|
| 1473 |
if not glb:
|
| 1474 |
return None, None, None, None, None, None, status
|
| 1475 |
|
|
|
|
| 1481 |
|
| 1482 |
|
| 1483 |
# ── UI ────────────────────────────────────────────────────────────────────────
|
| 1484 |
+
with gr.Blocks(title="Image2Model", theme=gr.themes.Soft()) as demo:
|
| 1485 |
gr.Markdown("# Image2Model — Portrait to Rigged 3D Mesh")
|
| 1486 |
+
glb_state = gr.State(None)
|
| 1487 |
+
rigged_glb_state = gr.State(None) # persists rigged GLB for Animate + PSHuman tabs
|
| 1488 |
+
|
| 1489 |
+
with gr.Tabs() as tabs:
|
| 1490 |
+
|
| 1491 |
+
# ════════════════════════════════════════════════════════════════════
|
| 1492 |
+
with gr.Tab("Edit", id=0):
|
| 1493 |
+
gr.Markdown(
|
| 1494 |
+
"### Image Edit — FireRed\n"
|
| 1495 |
+
"Upload one or more reference images, write an edit prompt, preview the result, "
|
| 1496 |
+
"then click **Load to Generate** to send it to the 3D pipeline."
|
| 1497 |
+
)
|
| 1498 |
+
with gr.Row():
|
| 1499 |
+
with gr.Column(scale=1):
|
| 1500 |
+
firered_gallery = gr.Gallery(
|
| 1501 |
+
label="Reference Images (1–3 images, drag & drop)",
|
| 1502 |
+
interactive=True,
|
| 1503 |
+
columns=3,
|
| 1504 |
+
height=220,
|
| 1505 |
+
object_fit="contain",
|
| 1506 |
+
)
|
| 1507 |
+
firered_prompt = gr.Textbox(
|
| 1508 |
+
label="Edit Prompt",
|
| 1509 |
+
placeholder="make the person wear a red jacket",
|
| 1510 |
+
lines=2,
|
| 1511 |
+
)
|
| 1512 |
+
with gr.Row():
|
| 1513 |
+
firered_seed = gr.Number(value=_init_seed, label="Seed", precision=0)
|
| 1514 |
+
firered_rand = gr.Checkbox(label="Random Seed", value=True)
|
| 1515 |
+
with gr.Row():
|
| 1516 |
+
firered_guidance = gr.Slider(1.0, 10.0, value=1.0, step=0.5,
|
| 1517 |
+
label="Guidance Scale")
|
| 1518 |
+
firered_steps = gr.Slider(1, 40, value=4, step=1,
|
| 1519 |
+
label="Inference Steps")
|
| 1520 |
+
firered_btn = gr.Button("Generate Preview", variant="secondary")
|
| 1521 |
+
firered_status = gr.Textbox(label="Status", lines=2, interactive=False)
|
| 1522 |
|
| 1523 |
+
with gr.Column(scale=1):
|
| 1524 |
+
firered_output_img = gr.Image(label="FireRed Output", type="numpy",
|
| 1525 |
+
interactive=False)
|
| 1526 |
+
load_to_generate_btn = gr.Button("Load to Generate", variant="primary")
|
| 1527 |
|
| 1528 |
# ════════════════════════════════════════════════════════════════════
|
| 1529 |
+
with gr.Tab("Generate", id=1):
|
| 1530 |
with gr.Row():
|
| 1531 |
with gr.Column(scale=1):
|
| 1532 |
+
input_image = gr.Image(label="Input Image", type="numpy")
|
| 1533 |
+
remove_bg_check = gr.Checkbox(label="Remove Background", value=True)
|
| 1534 |
+
with gr.Row():
|
| 1535 |
+
rembg_threshold = gr.Slider(0.1, 0.95, value=0.5, step=0.05,
|
| 1536 |
+
label="BG Threshold (higher = stricter)")
|
| 1537 |
+
rembg_erode = gr.Slider(0, 8, value=2, step=1,
|
| 1538 |
+
label="Edge Erode (px)")
|
| 1539 |
|
| 1540 |
with gr.Accordion("Shape Settings", open=True):
|
| 1541 |
num_steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
|
|
|
|
| 1554 |
shape_btn = gr.Button("Generate Shape", variant="primary", scale=2, interactive=False)
|
| 1555 |
texture_btn = gr.Button("Apply Texture", variant="secondary", scale=2)
|
| 1556 |
render_btn = gr.Button("Render Views", variant="secondary", scale=1)
|
| 1557 |
+
run_all_btn = gr.Button("▶ Run Full Pipeline (Shape + Texture + Rig)", variant="primary", interactive=False)
|
| 1558 |
|
| 1559 |
with gr.Column(scale=1):
|
| 1560 |
+
rembg_preview = gr.Image(label="BG Removed Preview", type="numpy",
|
| 1561 |
+
interactive=False)
|
| 1562 |
status = gr.Textbox(label="Status", lines=3, interactive=False)
|
| 1563 |
model_3d = gr.Model3D(label="3D Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
|
| 1564 |
download_file = gr.File(label="Download GLB")
|
|
|
|
| 1566 |
|
| 1567 |
render_gallery = gr.Gallery(label="Rendered Views", columns=5, height=300)
|
| 1568 |
|
| 1569 |
+
# ── wiring: Generate tab ──────────────────────────────────────────
|
| 1570 |
+
_rembg_inputs = [input_image, remove_bg_check, rembg_threshold, rembg_erode]
|
| 1571 |
_pipeline_btns = [shape_btn, run_all_btn]
|
| 1572 |
|
| 1573 |
input_image.upload(
|
|
|
|
| 1579 |
inputs=[], outputs=_pipeline_btns,
|
| 1580 |
)
|
| 1581 |
|
| 1582 |
+
input_image.upload(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
|
| 1583 |
+
remove_bg_check.change(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
|
| 1584 |
+
rembg_threshold.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
|
| 1585 |
+
rembg_erode.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
|
| 1586 |
+
|
| 1587 |
shape_btn.click(
|
| 1588 |
+
fn=generate_shape,
|
| 1589 |
+
inputs=[input_image, remove_bg_check, num_steps, guidance, seed, face_count],
|
| 1590 |
outputs=[glb_state, status],
|
| 1591 |
).then(
|
| 1592 |
fn=lambda p: (p, p) if p else (None, None),
|
|
|
|
| 1594 |
)
|
| 1595 |
|
| 1596 |
texture_btn.click(
|
| 1597 |
+
fn=apply_texture,
|
| 1598 |
+
inputs=[glb_state, input_image, remove_bg_check, variant, tex_seed,
|
| 1599 |
+
enhance_face_check, rembg_threshold, rembg_erode],
|
| 1600 |
outputs=[glb_state, multiview_img, status],
|
| 1601 |
).then(
|
| 1602 |
fn=lambda p: (p, p) if p else (None, None),
|
|
|
|
| 1605 |
|
| 1606 |
render_btn.click(fn=render_views, inputs=[download_file], outputs=[render_gallery])
|
| 1607 |
|
| 1608 |
+
# ── Edit tab wiring (after Generate so all components are defined) ──
|
| 1609 |
+
firered_btn.click(
|
| 1610 |
+
fn=firered_generate,
|
| 1611 |
+
inputs=[firered_gallery, firered_prompt, firered_seed, firered_rand,
|
| 1612 |
+
firered_guidance, firered_steps],
|
| 1613 |
+
outputs=[firered_output_img, firered_seed, firered_status],
|
| 1614 |
+
api_name="firered_generate",
|
| 1615 |
+
)
|
| 1616 |
+
|
| 1617 |
+
load_to_generate_btn.click(
|
| 1618 |
+
fn=firered_load_into_pipeline,
|
| 1619 |
+
inputs=[firered_output_img, rembg_threshold, rembg_erode],
|
| 1620 |
+
outputs=[input_image, rembg_preview, firered_status],
|
| 1621 |
+
).then(
|
| 1622 |
+
fn=lambda img: (
|
| 1623 |
+
gr.update(interactive=img is not None),
|
| 1624 |
+
gr.update(interactive=img is not None),
|
| 1625 |
+
gr.update(selected=1),
|
| 1626 |
+
),
|
| 1627 |
+
inputs=[input_image],
|
| 1628 |
+
outputs=[shape_btn, run_all_btn, tabs],
|
| 1629 |
+
)
|
| 1630 |
+
|
| 1631 |
# ════════════════════════════════════════════════════════════════════
|
| 1632 |
with gr.Tab("Rig & Export"):
|
| 1633 |
with gr.Row():
|
|
|
|
| 1675 |
inputs=[glb_state, export_fbx_check, mdm_prompt_box, mdm_frames_slider],
|
| 1676 |
outputs=[rig_glb_dl, rig_animated_dl, rig_fbx_dl, rig_status,
|
| 1677 |
rig_model_3d, rigged_base_state, skel_glb_state],
|
| 1678 |
+
).then(
|
| 1679 |
+
fn=lambda p: p,
|
| 1680 |
+
inputs=[rigged_base_state], outputs=[rigged_glb_state],
|
| 1681 |
)
|
| 1682 |
|
| 1683 |
show_skel_check.change(
|
|
|
|
| 1686 |
outputs=[rig_model_3d],
|
| 1687 |
)
|
| 1688 |
|
| 1689 |
+
# ════════════════════════════════════════════════════════════════════
|
| 1690 |
+
with gr.Tab("Animate"):
|
| 1691 |
+
gr.Markdown(
|
| 1692 |
+
"### Motion Search & Animate\n"
|
| 1693 |
+
"Search the HumanML3D dataset for motions matching a description, "
|
| 1694 |
+
"then bake the selected motion onto your rigged GLB."
|
| 1695 |
+
)
|
| 1696 |
+
with gr.Row():
|
| 1697 |
+
with gr.Column(scale=1):
|
| 1698 |
+
motion_query = gr.Textbox(
|
| 1699 |
+
label="Motion Description",
|
| 1700 |
+
placeholder="a person walks forward slowly",
|
| 1701 |
+
lines=2,
|
| 1702 |
+
)
|
| 1703 |
+
search_btn = gr.Button("Search Motions", variant="secondary")
|
| 1704 |
+
motion_radio = gr.Radio(
|
| 1705 |
+
label="Select Motion", choices=[], visible=False,
|
| 1706 |
+
)
|
| 1707 |
+
motion_results_state = gr.State([])
|
| 1708 |
+
|
| 1709 |
+
gr.Markdown("### Animate Settings")
|
| 1710 |
+
animate_fps = gr.Slider(10, 60, value=30, step=5, label="FPS")
|
| 1711 |
+
animate_frames = gr.Slider(0, 600, value=0, step=30,
|
| 1712 |
+
label="Max Frames (0 = full motion)")
|
| 1713 |
+
animate_btn = gr.Button("Animate", variant="primary")
|
| 1714 |
+
|
| 1715 |
+
with gr.Column(scale=2):
|
| 1716 |
+
animate_status = gr.Textbox(label="Status", lines=4, interactive=False)
|
| 1717 |
+
animate_model_3d = gr.Model3D(label="Animated Preview",
|
| 1718 |
+
clear_color=[0.9, 0.9, 0.9, 1.0])
|
| 1719 |
+
animate_dl = gr.File(label="Download Animated GLB")
|
| 1720 |
+
|
| 1721 |
+
search_btn.click(
|
| 1722 |
+
fn=gradio_search_motions,
|
| 1723 |
+
inputs=[motion_query],
|
| 1724 |
+
outputs=[motion_radio, motion_results_state, animate_status],
|
| 1725 |
+
)
|
| 1726 |
+
|
| 1727 |
+
animate_btn.click(
|
| 1728 |
+
fn=gradio_animate,
|
| 1729 |
+
inputs=[rigged_glb_state, motion_radio, motion_results_state,
|
| 1730 |
+
animate_fps, animate_frames],
|
| 1731 |
+
outputs=[animate_dl, animate_status, animate_model_3d],
|
| 1732 |
+
)
|
| 1733 |
+
|
| 1734 |
+
# ════════════════════════════════════════════════════════════════════
|
| 1735 |
+
with gr.Tab("PSHuman Face"):
|
| 1736 |
+
gr.Markdown(
|
| 1737 |
+
"### PSHuman Face Transplant\n"
|
| 1738 |
+
"Generates a high-detail face mesh via PSHuman (multi-view diffusion), "
|
| 1739 |
+
"then transplants it into the rigged GLB.\n\n"
|
| 1740 |
+
"**Pipeline:** portrait → PSHuman (remote service) → colored OBJ → face_transplant → rigged GLB with HD face\n\n"
|
| 1741 |
+
"**Note:** On ZeroGPU, PSHuman must run as a remote service. "
|
| 1742 |
+
"Set `PSHUMAN_URL` environment variable or enter the URL below."
|
| 1743 |
+
)
|
| 1744 |
+
with gr.Row():
|
| 1745 |
+
with gr.Column(scale=1):
|
| 1746 |
+
pshuman_img_input = gr.Image(
|
| 1747 |
+
label="Portrait image (same as used for Generate)",
|
| 1748 |
+
type="pil",
|
| 1749 |
+
)
|
| 1750 |
+
with gr.Accordion("Advanced settings", open=False):
|
| 1751 |
+
pshuman_weight_thresh = gr.Slider(
|
| 1752 |
+
minimum=0.1, maximum=0.9, value=0.35, step=0.05,
|
| 1753 |
+
label="Head bone weight threshold",
|
| 1754 |
+
info="Vertices with head-bone weight above this get replaced",
|
| 1755 |
+
)
|
| 1756 |
+
pshuman_retract_mm = gr.Slider(
|
| 1757 |
+
minimum=0.0, maximum=20.0, value=4.0, step=0.5,
|
| 1758 |
+
label="Face retract (mm)",
|
| 1759 |
+
info="How far to push original face verts inward to avoid z-fighting",
|
| 1760 |
+
)
|
| 1761 |
+
pshuman_service_url = gr.Textbox(
|
| 1762 |
+
label="PSHuman service URL",
|
| 1763 |
+
value=os.environ.get("PSHUMAN_URL", "http://localhost:7862"),
|
| 1764 |
+
info="pshuman_app.py Gradio endpoint (deployed separately)",
|
| 1765 |
+
)
|
| 1766 |
+
pshuman_btn = gr.Button("Generate HD Face", variant="primary")
|
| 1767 |
+
|
| 1768 |
+
with gr.Column(scale=2):
|
| 1769 |
+
pshuman_status = gr.Textbox(label="Status", lines=4, interactive=False)
|
| 1770 |
+
pshuman_model_3d = gr.Model3D(
|
| 1771 |
+
label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
|
| 1772 |
+
pshuman_glb_dl = gr.File(label="Download GLB (with PSHuman face)")
|
| 1773 |
+
|
| 1774 |
+
pshuman_btn.click(
|
| 1775 |
+
fn=gradio_pshuman_face,
|
| 1776 |
+
inputs=[
|
| 1777 |
+
pshuman_img_input,
|
| 1778 |
+
rigged_glb_state,
|
| 1779 |
+
pshuman_weight_thresh,
|
| 1780 |
+
pshuman_retract_mm,
|
| 1781 |
+
pshuman_service_url,
|
| 1782 |
+
],
|
| 1783 |
+
outputs=[pshuman_glb_dl, pshuman_status, pshuman_model_3d],
|
| 1784 |
+
)
|
| 1785 |
+
|
| 1786 |
# ════════════════════════════════════════════════════════════════════
|
| 1787 |
with gr.Tab("Enhancement"):
|
| 1788 |
gr.Markdown("**Surface Enhancement** — bakes normal + depth maps into the GLB as PBR textures.")
|
|
|
|
| 1799 |
displacement_scale = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Displacement Scale")
|
| 1800 |
|
| 1801 |
enhance_btn = gr.Button("Run Enhancement", variant="primary")
|
| 1802 |
+
unload_btn = gr.Button("Unload Models (free VRAM)", variant="secondary")
|
| 1803 |
|
| 1804 |
with gr.Column(scale=2):
|
| 1805 |
enhance_status = gr.Textbox(label="Status", lines=5, interactive=False)
|
|
|
|
| 1818 |
enhanced_glb_dl, enhanced_model_3d, enhance_status],
|
| 1819 |
)
|
| 1820 |
|
| 1821 |
+
def _unload_enhancement_models():
|
| 1822 |
+
try:
|
| 1823 |
+
from pipeline.enhance_surface import unload_models
|
| 1824 |
+
unload_models()
|
| 1825 |
+
return "Enhancement models unloaded — VRAM freed."
|
| 1826 |
+
except Exception as e:
|
| 1827 |
+
return f"Unload failed: {e}"
|
| 1828 |
+
|
| 1829 |
+
unload_btn.click(
|
| 1830 |
+
fn=_unload_enhancement_models,
|
| 1831 |
+
inputs=[], outputs=[enhance_status],
|
| 1832 |
+
)
|
| 1833 |
+
|
| 1834 |
+
# ════════════════════════════════════════════════════════════════════
|
| 1835 |
+
with gr.Tab("Settings"):
|
| 1836 |
+
|
| 1837 |
+
def get_vram_status():
    """Build a human-readable report of GPU memory usage and loaded models."""
    report = []
    if torch.cuda.is_available():
        gib = 1024 ** 3
        allocated = torch.cuda.memory_allocated() / gib
        reserved = torch.cuda.memory_reserved() / gib
        capacity = torch.cuda.get_device_properties(0).total_memory / gib
        report.append(f"GPU: {torch.cuda.get_device_name(0)}")
        report.append(f"VRAM total: {capacity:.1f} GB")
        report.append(f"VRAM allocated: {allocated:.1f} GB")
        report.append(f"VRAM reserved: {reserved:.1f} GB")
        # "free" here means not reserved by the CUDA caching allocator.
        report.append(f"VRAM free: {capacity - reserved:.1f} GB")
    else:
        report.append("No CUDA device available.")
    report.append("")
    report.append("Loaded models:")
    report.append(f" TripoSG pipeline: {'loaded' if _triposg_pipe is not None else 'not loaded'}")
    report.append(f" RMBG-{_rmbg_version or '?'}: {'loaded' if _rmbg_net is not None else 'not loaded'}")
    report.append(f" FireRed: {'loaded' if _firered_pipe is not None else 'not loaded'}")
    try:
        # Enhancement models live in their own module; import lazily so the
        # status readout survives an absent/broken enhance_surface module.
        import pipeline.enhance_surface as _enh_mod
        report.append(f" StableNormal: {'loaded' if _enh_mod._normal_pipe is not None else 'not loaded'}")
        report.append(f" Depth-Anything: {'loaded' if _enh_mod._depth_pipe is not None else 'not loaded'}")
    except Exception:
        report.append(" StableNormal / Depth-Anything: (status unavailable)")
    return "\n".join(report)
|
| 1863 |
+
|
| 1864 |
+
def _preload_triposg():
|
| 1865 |
+
try:
|
| 1866 |
+
load_triposg()
|
| 1867 |
+
return get_vram_status()
|
| 1868 |
+
except Exception:
|
| 1869 |
+
return f"Preload failed:\n{traceback.format_exc()}"
|
| 1870 |
+
|
| 1871 |
+
def _unload_triposg():
    """Drop TripoSG + RMBG from VRAM and return the updated status string.

    Takes the model-load lock so a concurrent generation cannot race the
    unload; moves each model to CPU before dropping the last reference.
    """
    global _triposg_pipe, _rmbg_net
    with _model_load_lock:
        pipe, net = _triposg_pipe, _rmbg_net
        _triposg_pipe = None
        _rmbg_net = None
        if pipe is not None:
            pipe.to("cpu")
        del pipe
        if net is not None:
            net.to("cpu")
        del net
        # Return the cached allocator blocks to the driver.
        torch.cuda.empty_cache()
    return get_vram_status()
|
| 1884 |
+
|
| 1885 |
+
def _unload_enhancement():
    """Best-effort unload of the enhancement models, then report VRAM status."""
    try:
        from pipeline.enhance_surface import unload_models
        unload_models()
    except Exception:
        # Deliberately silent: a missing module or failed unload should
        # never break the Settings panel.
        pass
    return get_vram_status()
|
| 1892 |
+
|
| 1893 |
+
def _unload_all():
    """Unload every model group, then return the resulting VRAM status."""
    for release in (_unload_triposg, _unload_enhancement):
        release()
    return get_vram_status()
|
| 1897 |
+
|
| 1898 |
+
# Settings-tab layout: management buttons on the left, status readout on
# the right.
with gr.Row():
    with gr.Column(scale=1):
        gr.Markdown("### VRAM Management")
        preload_btn = gr.Button("Preload TripoSG + RMBG to VRAM", variant="primary")
        unload_triposg_btn = gr.Button("Unload TripoSG / RMBG")
        unload_enh_btn = gr.Button("Unload Enhancement Models (StableNormal / Depth)")
        unload_all_btn = gr.Button("Unload All Models", variant="stop")
        refresh_btn = gr.Button("Refresh Status")

    with gr.Column(scale=1):
        gr.Markdown("### GPU Status")
        vram_status = gr.Textbox(
            label="", lines=12, interactive=False,
            value="Click Refresh to check VRAM status.",
        )

# Every action finishes by rewriting the status textbox, so the panel is
# always current after a click.
preload_btn.click(fn=_preload_triposg, inputs=[], outputs=[vram_status])
unload_triposg_btn.click(fn=_unload_triposg, inputs=[], outputs=[vram_status])
unload_enh_btn.click(fn=_unload_enhancement, inputs=[], outputs=[vram_status])
unload_all_btn.click(fn=_unload_all, inputs=[], outputs=[vram_status])
refresh_btn.click(fn=get_vram_status, inputs=[], outputs=[vram_status])
|
| 1919 |
+
|
| 1920 |
+
# ── Run All wiring (after all tabs so components are defined) ────────
|
| 1921 |
run_all_btn.click(
|
| 1922 |
fn=run_full_pipeline,
|
| 1923 |
inputs=[
|
| 1924 |
+
input_image, remove_bg_check, num_steps, guidance, seed, face_count,
|
| 1925 |
+
variant, tex_seed, enhance_face_check, rembg_threshold, rembg_erode,
|
| 1926 |
export_fbx_check, mdm_prompt_box, mdm_frames_slider,
|
| 1927 |
],
|
| 1928 |
outputs=[glb_state, download_file, multiview_img,
|
|
|
|
| 1932 |
inputs=[glb_state], outputs=[model_3d, download_file],
|
| 1933 |
)
|
| 1934 |
|
| 1935 |
+
# ── Hidden API endpoints ──────────────────────────────────────────────────
|
| 1936 |
+
_api_render_gallery = gr.Gallery(visible=False)
|
| 1937 |
+
_api_swap_gallery = gr.Gallery(visible=False)
|
| 1938 |
+
|
| 1939 |
+
def _render_last():
|
| 1940 |
+
path = _last_glb_path or "/tmp/triposg_textured.glb"
|
| 1941 |
+
return render_views(path)
|
| 1942 |
+
|
| 1943 |
+
_hs_emb_input = gr.Textbox(visible=False)
|
| 1944 |
+
|
| 1945 |
+
gr.Button(visible=False).click(
|
| 1946 |
+
fn=_render_last, inputs=[], outputs=[_api_render_gallery], api_name="render_last")
|
| 1947 |
+
gr.Button(visible=False).click(
|
| 1948 |
+
fn=hyperswap_views, inputs=[_hs_emb_input], outputs=[_api_swap_gallery],
|
| 1949 |
+
api_name="hyperswap_views")
|
| 1950 |
+
|
| 1951 |
|
| 1952 |
if __name__ == "__main__":
    # NOTE: `theme` is not a parameter of Blocks.launch(); it must be passed
    # to gr.Blocks(theme=...) at construction time. Passing it to launch()
    # raises TypeError on current Gradio, so it is dropped here.
    demo.launch(server_name="0.0.0.0", server_port=7860,
                show_error=True, allowed_paths=["/tmp"])
|
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
face_inswap_bake.py — Proper face swap on rendered views, then UV-bake.
|
| 3 |
+
|
| 4 |
+
Pipeline:
|
| 5 |
+
1. Render the mesh from multiple views (front + L/R 3-quarter)
|
| 6 |
+
2. Run inswapper_128 to swap reference face onto each rendered view
|
| 7 |
+
3. uv_render_attr() bakes each swapped render directly into UV texture
|
| 8 |
+
(render-space coords shared with UV lookup — no coordinate transforms)
|
| 9 |
+
4. Composite multiple views (front takes priority, sides fill gaps)
|
| 10 |
+
5. Save updated GLB
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python face_inswap_bake.py \
|
| 14 |
+
--body /tmp/triposg_textured.glb \
|
| 15 |
+
--face /tmp/triposg_face_ref.png \
|
| 16 |
+
--out /tmp/face_swapped.glb \
|
| 17 |
+
[--uv_size 4096] [--debug_dir /tmp]
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import os, sys, argparse, warnings
|
| 21 |
+
warnings.filterwarnings('ignore')
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import cv2
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
from PIL import Image
|
| 28 |
+
import trimesh
|
| 29 |
+
from trimesh.visual.texture import TextureVisuals
|
| 30 |
+
from trimesh.visual.material import PBRMaterial
|
| 31 |
+
|
| 32 |
+
sys.path.insert(0, '/root/MV-Adapter')
|
| 33 |
+
from mvadapter.utils.mesh_utils import (
|
| 34 |
+
NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera, render,
|
| 35 |
+
)
|
| 36 |
+
from mvadapter.utils.mesh_utils.uv import (
|
| 37 |
+
uv_precompute, uv_render_geometry, uv_render_attr,
|
| 38 |
+
)
|
| 39 |
+
from insightface.app import FaceAnalysis
|
| 40 |
+
import insightface
|
| 41 |
+
from gfpgan import GFPGANer
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
GFPGAN_PATH = '/root/MV-Adapter/checkpoints/GFPGANv1.4.pth'
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ── helpers ───────────────────────────────────────────────────────────────────
|
| 48 |
+
|
| 49 |
+
def _build_front_face_uv_mask(mesh_t, tex_H, tex_W, neck_frac=0.76):
    """UV-space mask covering only front-facing head triangles (no back-of-head)."""
    verts = np.asarray(mesh_t.vertices, dtype=np.float64)
    faces = np.asarray(mesh_t.faces, dtype=np.int32)
    uvs = np.asarray(mesh_t.visual.uv, dtype=np.float64)

    # Everything above neck_frac of the vertical extent counts as "head".
    y_lo, y_hi = verts[:, 1].min(), verts[:, 1].max()
    neck_y = float(y_lo + (y_hi - y_lo) * neck_frac)
    head_idx = np.where(verts[:, 1] > neck_y)[0]
    head_verts = verts[head_idx]

    # Keep the front 60% of head vertices by depth; if the cut leaves too
    # few, keep all of them (degenerate / sideways meshes).
    z_thresh = float(np.percentile(head_verts[:, 2], 40))
    is_front = head_verts[:, 2] >= z_thresh
    if is_front.sum() < 30:
        is_front = np.ones(len(head_verts), bool)

    keep_vert = np.zeros(len(verts), bool)
    keep_vert[head_idx[is_front]] = True
    # A triangle qualifies only when all three corners are front-face verts.
    face_tri_mask = keep_vert[faces].all(axis=1)
    face_tris = faces[face_tri_mask]
    print(f' Geometry mask: {face_tri_mask.sum()} front-face triangles '
          f'(neck_y={neck_y:.3f}, z_thresh={z_thresh:.3f})')

    # Rasterize the qualifying triangles into UV space (V flipped to rows).
    geom_mask = np.zeros((tex_H, tex_W), dtype=np.float32)
    polygons = []
    for tri in face_tris:
        tri_uv = uvs[tri]
        cols = tri_uv[:, 0] * tex_W
        rows = (1.0 - tri_uv[:, 1]) * tex_H
        polygons.append(np.column_stack([cols, rows]).astype(np.int32))
    if polygons:
        cv2.fillPoly(geom_mask, polygons, 1.0)

    # Close small holes, then feather the boundary for soft blending.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    geom_mask = cv2.dilate(geom_mask, kernel, iterations=2)
    geom_mask = cv2.erode(geom_mask, kernel, iterations=1)
    return cv2.GaussianBlur(geom_mask, (31, 31), 8)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _detect_largest_face(img_bgr, app):
|
| 91 |
+
faces = app.get(img_bgr)
|
| 92 |
+
if not faces:
|
| 93 |
+
return None
|
| 94 |
+
return max(faces, key=lambda f: (f.bbox[2]-f.bbox[0])*(f.bbox[3]-f.bbox[1]))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _render_view(ctx, mesh_mv, uv_pre, azimuth_deg, H, W, device):
    """Render the mesh from a given azimuth; return (camera, uv_geom)."""
    # Fixed orthographic framing shared by every view of the pipeline.
    cam = get_orthogonal_camera(
        azimuth_deg=[azimuth_deg],
        elevation_deg=[0],
        distance=[1.8],
        left=-0.55, right=0.55, bottom=-0.55, top=0.55,
        device=device,
    )
    geometry = uv_render_geometry(
        ctx, mesh_mv, cam,
        view_height=H, view_width=W,
        uv_precompute_output=uv_pre,
        compute_depth_grad=False,
    )
    return cam, geometry
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def face_inswap_bake(body_glb, face_img_path, out_glb,
                     uv_size=4096, debug_dir=None):
    """Swap a reference face onto rendered views of a textured GLB, bake the
    swapped pixels back into the UV texture, and export a new GLB.

    Parameters
    ----------
    body_glb : str
        Path to the input textured GLB (must carry UVs + baseColorTexture).
    face_img_path : str
        Reference face image; the largest detected face is used.
    out_glb : str
        Output GLB path.
    uv_size : int
        Resolution of the intermediate UV bake grid.
    debug_dir : str | None
        When set, intermediate renders/masks are written there.

    Raises
    ------
    RuntimeError
        If no face is detected in the reference image.
    """
    device = 'cuda'
    INSWAPPER_PATH = '/root/MV-Adapter/checkpoints/inswapper_128.onnx'

    # ── Load GFPGAN enhancer ──────────────────────────────────────────────────
    print('[fib] Loading GFPGANv1.4 ...')
    enhancer = GFPGANer(
        model_path=GFPGAN_PATH,
        upscale=1,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,
    )

    # ── Load mesh ─────────────────────────────────────────────────────────────
    # Two parallel loads: mesh_mv feeds the nvdiffrast renderer, mesh_t is
    # the trimesh copy whose texture we edit and re-export.
    print(f'[fib] Loading mesh: {body_glb}')
    ctx = NVDiffRastContextWrapper(device=device, context_type='cuda')
    mesh_mv = load_mesh(body_glb, rescale=True, device=device)

    scene_t = trimesh.load(body_glb)
    if isinstance(scene_t, trimesh.Scene):
        # Single-geometry assumption: only the first geometry is edited.
        geom_name = list(scene_t.geometry.keys())[0]
        mesh_t = scene_t.geometry[geom_name]
    else:
        mesh_t = scene_t; geom_name = None

    orig_tex_np = np.array(mesh_t.visual.material.baseColorTexture, dtype=np.float32) / 255.0
    uvs = np.array(mesh_t.visual.uv, dtype=np.float64)
    tex_H, tex_W = orig_tex_np.shape[:2]
    print(f' Texture: {tex_W}×{tex_H}')

    # Build geometry mask (front-face head triangles only) at UV resolution
    print('[fib] Building front-face geometry UV mask ...')
    geom_uv_mask = _build_front_face_uv_mask(mesh_t, uv_size, uv_size)

    # Render dimensions (match triposg_app.py)
    H_r, W_r = 1024, 768

    # ── Precompute UV geometry ─────────────────────────────────────────────────
    print(f'[fib] Precomputing UV geometry ({uv_size}×{uv_size}) ...')
    uv_pre = uv_precompute(ctx, mesh_mv, height=uv_size, width=uv_size)

    # ── Load face swap model + face detector ──────────────────────────────────
    print('[fib] Loading inswapper_128 ...')
    swapper = insightface.model_zoo.get_model(
        INSWAPPER_PATH, download=False,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )

    app = FaceAnalysis(providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    ref_bgr = cv2.imread(face_img_path)
    ref_face = _detect_largest_face(ref_bgr, app)
    if ref_face is None:
        raise RuntimeError(f'No face detected in reference: {face_img_path}')
    print(f' Reference face detected: bbox={ref_face.bbox.astype(int).tolist()}')

    # ── Process each view ─────────────────────────────────────────────────────
    # Views: front (azimuth=-90), slight left (-60), slight right (-120)
    # Azimuth convention from MV-Adapter: -90 = front-facing
    views = [
        ('front', -90, 1.0),  # (name, azimuth_deg, priority_weight)
        ('threequarter_r', -60, 0.7),
        ('threequarter_l', -120, 0.7),
    ]

    # Accumulators for weighted UV compositing
    uv_colour_acc = np.zeros((uv_size, uv_size, 3), dtype=np.float32)
    uv_weight_acc = np.zeros((uv_size, uv_size), dtype=np.float32)

    for view_name, azimuth, weight in views:
        print(f'\n[fib] View: {view_name} (azimuth={azimuth}°)')

        # Create camera + UV geometry for this view
        camera, uv_geom = _render_view(ctx, mesh_mv, uv_pre, azimuth, H_r, W_r, device)

        # Render textured mesh from this view
        render_out = render(ctx, mesh_mv, camera, height=H_r, width=W_r,
                            render_attr=True, render_depth=False, render_normal=False,
                            attr_background=0.0)
        # render_out.attr: (1, H, W, 3) float in [0,1]
        rendered_np = (render_out.attr[0].cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
        rendered_bgr = cv2.cvtColor(rendered_np, cv2.COLOR_RGB2BGR)

        if debug_dir:
            cv2.imwrite(os.path.join(debug_dir, f'fib_render_{view_name}.png'), rendered_bgr)

        # Detect face in this rendered view; a miss just skips the view,
        # the remaining views still contribute to the bake.
        tgt_face = _detect_largest_face(rendered_bgr, app)
        if tgt_face is None:
            print(f' No face in {view_name} render — skipping')
            continue
        print(f' Target face: bbox={tgt_face.bbox.astype(int).tolist()}')

        # Swap face
        swapped_bgr = swapper.get(rendered_bgr.copy(), tgt_face, ref_face, paste_back=True)

        # Enhance face detail with GFPGAN (optional: keep swap if it fails)
        _, _, enhanced_bgr = enhancer.enhance(
            swapped_bgr, has_aligned=False, only_center_face=False, paste_back=True)
        if enhanced_bgr is not None:
            swapped_bgr = enhanced_bgr
            print(f' GFPGAN enhanced')

        if debug_dir:
            cv2.imwrite(os.path.join(debug_dir, f'fib_swapped_{view_name}.png'), swapped_bgr)

        swapped_rgb = cv2.cvtColor(swapped_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        # Build render-space face hull mask: convex hull of the 5 landmarks,
        # expanded 3.5× around its centroid, then feathered.
        kps = tgt_face.kps
        hull_pts = cv2.convexHull(kps.astype(np.float32)).squeeze(1)
        hull_cx, hull_cy = hull_pts.mean(axis=0)
        hull_exp = (hull_pts - [hull_cx, hull_cy]) * 3.5 + [hull_cx, hull_cy]
        face_mask = np.zeros((H_r, W_r), dtype=np.float32)
        cv2.fillPoly(face_mask, [hull_exp.astype(np.int32)], 1.0)
        face_mask = cv2.GaussianBlur(face_mask, (61, 61), 20)

        # Bake swapped render into UV space
        swapped_t = torch.tensor(swapped_rgb, device=device).unsqueeze(0)  # (1,H,W,3)
        mask_t = torch.tensor(face_mask[None], device=device)

        uv_out = uv_render_attr(
            images=swapped_t,
            masks=mask_t,
            uv_render_geometry_output=uv_geom,
        )
        uv_img = uv_out.uv_attr_proj[0].cpu().numpy()   # (uv, uv, 3)
        uv_mask = uv_out.uv_mask_proj[0].cpu().numpy()  # (uv, uv)

        # Kill back-of-head UV islands
        uv_mask = uv_mask * geom_uv_mask

        # Weighted accumulate
        w = uv_mask * weight
        uv_colour_acc += uv_img * w[..., None]
        uv_weight_acc += w
        print(f' Painted texels: {(uv_mask > 0.05).sum()}')

    # ── Composite ──────────────────────────────────────────────────────────────
    print('\n[fib] Compositing views ...')
    valid = uv_weight_acc > 0.01
    # NOTE(review): the fallback branch assumes the original texture is at
    # least uv_size in both dimensions (and square-ish) when
    # uv_size <= tex_H; if uv_size > tex_H the shapes will not broadcast —
    # TODO confirm intended uv_size/texture-size combinations.
    uv_final = np.where(valid[..., None],
                        uv_colour_acc / np.maximum(uv_weight_acc[..., None], 1e-6),
                        orig_tex_np[:uv_size, :uv_size] if uv_size <= tex_H else orig_tex_np)

    # Resize to texture resolution if needed
    if uv_size != tex_H or uv_size != tex_W:
        uv_final_rs = cv2.resize(uv_final, (tex_W, tex_H), interpolation=cv2.INTER_LINEAR)
        weight_rs = cv2.resize(uv_weight_acc, (tex_W, tex_H), interpolation=cv2.INTER_LINEAR)
    else:
        uv_final_rs = uv_final
        weight_rs = uv_weight_acc

    # Blend with original texture: use face-swap result where painted, orig elsewhere
    alpha = np.clip(weight_rs, 0, 1)[..., None]
    new_tex = uv_final_rs * alpha + orig_tex_np * (1.0 - alpha)
    print(f' Total painted texels (tex res): {(weight_rs > 0.05).sum()}')

    if debug_dir:
        Image.fromarray((uv_final_rs * 255).clip(0,255).astype(np.uint8)).save(
            os.path.join(debug_dir, 'fib_uv_composite.png'))

    # ── Save GLB ──────────────────────────────────────────────────────────────
    new_pil = Image.fromarray((new_tex * 255).clip(0, 255).astype(np.uint8))
    mesh_t.visual = TextureVisuals(uv=uvs, material=PBRMaterial(baseColorTexture=new_pil))

    if geom_name and isinstance(scene_t, trimesh.Scene):
        scene_t.geometry[geom_name] = mesh_t
        scene_t.export(out_glb)
    else:
        mesh_t.export(out_glb)

    print(f'[fib] Saved: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
    return out_glb
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
if __name__ == '__main__':
    # CLI entry point: mirrors face_inswap_bake()'s signature.
    parser = argparse.ArgumentParser()
    parser.add_argument('--body', required=True)
    parser.add_argument('--face', required=True)
    parser.add_argument('--out', required=True)
    parser.add_argument('--uv_size', type=int, default=4096)
    parser.add_argument('--debug_dir', default=None)
    opts = parser.parse_args()
    face_inswap_bake(opts.body, opts.face, opts.out,
                     uv_size=opts.uv_size, debug_dir=opts.debug_dir)
|
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
face_project.py — Project reference face image onto TripoSG mesh UV texture.
|
| 3 |
+
|
| 4 |
+
Keeps geometry 100% intact. Paints the face-region UV triangles using
|
| 5 |
+
barycentric rasterization — never interpolates across UV island boundaries.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python face_project.py --body /tmp/triposg_textured.glb \
|
| 9 |
+
--face /tmp/triposg_face_ref.png \
|
| 10 |
+
--out /tmp/face_projected.glb \
|
| 11 |
+
[--blend 0.9] [--neck_frac 0.84] [--debug_tex /tmp/tex.png]
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os, argparse, warnings
|
| 15 |
+
warnings.filterwarnings('ignore')
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import cv2
|
| 19 |
+
from PIL import Image
|
| 20 |
+
import trimesh
|
| 21 |
+
from trimesh.visual.texture import TextureVisuals
|
| 22 |
+
from trimesh.visual.material import PBRMaterial
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ── Face alignment ─────────────────────────────────────────────────────────────
|
| 26 |
+
|
| 27 |
+
def _aligned_face_bgr(face_img_bgr, target_size=512):
    """Detect + align face via InsightFace 5-pt warp; falls back to square crop."""
    try:
        from insightface.app import FaceAnalysis
        from insightface.utils import face_align
        detector = FaceAnalysis(providers=['CPUExecutionProvider'])
        detector.prepare(ctx_id=0, det_size=(640, 640))
        detections = detector.get(face_img_bgr)
        if detections:
            # Largest bounding box wins.
            biggest = max(
                detections,
                key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
            aligned = face_align.norm_crop(face_img_bgr, biggest.kps,
                                           image_size=target_size)
            print(f' InsightFace aligned: {aligned.shape}')
            return aligned
    except Exception as e:
        print(f' InsightFace unavailable ({e}), using centre-crop')
    # Fallback: centred square crop resized to target_size.
    h, w = face_img_bgr.shape[:2]
    side = min(h, w)
    y0, x0 = (h - side) // 2, (w - side) // 2
    crop = face_img_bgr[y0:y0 + side, x0:x0 + side]
    return cv2.resize(crop, (target_size, target_size))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ── Triangle rasterizer ────────────────────────────────────────────────────────
|
| 52 |
+
|
| 53 |
+
def _rasterize_triangles(face_tri_uvs_px, face_tri_img_xy,
|
| 54 |
+
face_img_rgb, tex, blend,
|
| 55 |
+
max_uv_span=300):
|
| 56 |
+
"""
|
| 57 |
+
Paint face_img_rgb colour into tex at UV locations, triangle by triangle.
|
| 58 |
+
|
| 59 |
+
face_tri_uvs_px : (M, 3, 2) UV pixel coords of M triangles
|
| 60 |
+
face_tri_img_xy : (M, 3, 2) projected image coords of M triangles
|
| 61 |
+
face_img_rgb : (H, W, 3) reference face image
|
| 62 |
+
tex : (texH, texW, 3) float32 texture (modified in-place)
|
| 63 |
+
blend : float 0–1
|
| 64 |
+
max_uv_span : skip triangles whose UV bounding box exceeds this (UV seams)
|
| 65 |
+
"""
|
| 66 |
+
H_f, W_f = face_img_rgb.shape[:2]
|
| 67 |
+
tex_H, tex_W = tex.shape[:2]
|
| 68 |
+
painted = 0
|
| 69 |
+
|
| 70 |
+
for fi in range(len(face_tri_uvs_px)):
|
| 71 |
+
uv = face_tri_uvs_px[fi] # (3, 2) in texture pixel coords
|
| 72 |
+
img = face_tri_img_xy[fi] # (3, 2) in face-image pixel coords
|
| 73 |
+
|
| 74 |
+
# Skip UV-seam triangles (vertices far apart in UV space)
|
| 75 |
+
if (uv[:, 0].max() - uv[:, 0].min() > max_uv_span or
|
| 76 |
+
uv[:, 1].max() - uv[:, 1].min() > max_uv_span):
|
| 77 |
+
continue
|
| 78 |
+
|
| 79 |
+
# Bounding box in texture space
|
| 80 |
+
u_lo = max(0, int(uv[:, 0].min()))
|
| 81 |
+
u_hi = min(tex_W, int(uv[:, 0].max()) + 2)
|
| 82 |
+
v_lo = max(0, int(uv[:, 1].min()))
|
| 83 |
+
v_hi = min(tex_H, int(uv[:, 1].max()) + 2)
|
| 84 |
+
if u_hi <= u_lo or v_hi <= v_lo:
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
# Grid of texel centres in this bounding box
|
| 88 |
+
gu, gv = np.meshgrid(np.arange(u_lo, u_hi), np.arange(v_lo, v_hi))
|
| 89 |
+
pts = np.column_stack([gu.ravel().astype(np.float32),
|
| 90 |
+
gv.ravel().astype(np.float32)]) # (K, 2)
|
| 91 |
+
|
| 92 |
+
# Barycentric coordinates (in UV pixel space)
|
| 93 |
+
A = uv[0].astype(np.float64)
|
| 94 |
+
AB = (uv[1] - uv[0]).astype(np.float64)
|
| 95 |
+
AC = (uv[2] - uv[0]).astype(np.float64)
|
| 96 |
+
denom = AB[0] * AC[1] - AB[1] * AC[0]
|
| 97 |
+
if abs(denom) < 0.5:
|
| 98 |
+
continue
|
| 99 |
+
P = pts.astype(np.float64) - A
|
| 100 |
+
b1 = (P[:, 0] * AC[1] - P[:, 1] * AC[0]) / denom
|
| 101 |
+
b2 = (P[:, 1] * AB[0] - P[:, 0] * AB[1]) / denom
|
| 102 |
+
b0 = 1.0 - b1 - b2
|
| 103 |
+
|
| 104 |
+
inside = (b0 >= 0) & (b1 >= 0) & (b2 >= 0)
|
| 105 |
+
if not inside.any():
|
| 106 |
+
continue
|
| 107 |
+
|
| 108 |
+
# Interpolate reference-face image coordinates
|
| 109 |
+
ix_f = (b0[inside] * img[0, 0] +
|
| 110 |
+
b1[inside] * img[1, 0] +
|
| 111 |
+
b2[inside] * img[2, 0])
|
| 112 |
+
iy_f = (b0[inside] * img[0, 1] +
|
| 113 |
+
b1[inside] * img[1, 1] +
|
| 114 |
+
b2[inside] * img[2, 1])
|
| 115 |
+
|
| 116 |
+
valid = ((ix_f >= 0) & (ix_f < W_f) & (iy_f >= 0) & (iy_f < H_f))
|
| 117 |
+
if not valid.any():
|
| 118 |
+
continue
|
| 119 |
+
|
| 120 |
+
ix = np.clip(ix_f[valid].astype(int), 0, W_f - 1)
|
| 121 |
+
iy = np.clip(iy_f[valid].astype(int), 0, H_f - 1)
|
| 122 |
+
colours = face_img_rgb[iy, ix].astype(np.float32) # (P, 3)
|
| 123 |
+
|
| 124 |
+
tu = pts[inside][valid, 0].astype(int)
|
| 125 |
+
tv = pts[inside][valid, 1].astype(int)
|
| 126 |
+
in_tex = (tu >= 0) & (tu < tex_W) & (tv >= 0) & (tv < tex_H)
|
| 127 |
+
|
| 128 |
+
tex[tv[in_tex], tu[in_tex]] = (
|
| 129 |
+
blend * colours[in_tex] +
|
| 130 |
+
(1.0 - blend) * tex[tv[in_tex], tu[in_tex]]
|
| 131 |
+
)
|
| 132 |
+
painted += int(in_tex.sum())
|
| 133 |
+
|
| 134 |
+
return painted
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ── Main ───────────────────────────────────────────────────────────────────────
|
| 138 |
+
|
| 139 |
+
def project_face(body_glb, face_img_path, out_glb,
|
| 140 |
+
blend=0.90, neck_frac=0.84, debug_tex=None):
|
| 141 |
+
"""
|
| 142 |
+
Project reference face onto TripoSG UV texture via per-triangle rasterization.
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
# ── Load mesh ─────────────────────────────────────────────────────────────
|
| 146 |
+
print(f'[face_project] Loading {body_glb}')
|
| 147 |
+
scene = trimesh.load(body_glb)
|
| 148 |
+
if isinstance(scene, trimesh.Scene):
|
| 149 |
+
geom_name = list(scene.geometry.keys())[0]
|
| 150 |
+
mesh = scene.geometry[geom_name]
|
| 151 |
+
else:
|
| 152 |
+
mesh = scene
|
| 153 |
+
geom_name = None
|
| 154 |
+
|
| 155 |
+
verts = np.array(mesh.vertices, dtype=np.float64) # (N, 3)
|
| 156 |
+
faces = np.array(mesh.faces, dtype=np.int32) # (F, 3)
|
| 157 |
+
uvs = np.array(mesh.visual.uv, dtype=np.float64) # (N, 2)
|
| 158 |
+
mat = mesh.visual.material
|
| 159 |
+
orig_tex = np.array(mat.baseColorTexture, dtype=np.float32) # (H, W, 3) RGB
|
| 160 |
+
tex_H, tex_W = orig_tex.shape[:2]
|
| 161 |
+
print(f' {len(verts)} verts | {len(faces)} faces | texture {orig_tex.shape}')
|
| 162 |
+
|
| 163 |
+
# ── Identify face-region vertices ─────────────────────────────────────────
|
| 164 |
+
y_min, y_max = verts[:, 1].min(), verts[:, 1].max()
|
| 165 |
+
neck_y = float(y_min + (y_max - y_min) * neck_frac)
|
| 166 |
+
|
| 167 |
+
head_mask = verts[:, 1] > neck_y
|
| 168 |
+
head_idx = np.where(head_mask)[0]
|
| 169 |
+
hv = verts[head_idx]
|
| 170 |
+
|
| 171 |
+
# Front half only (z >= median — face faces +Z)
|
| 172 |
+
z_med = float(np.median(hv[:, 2]))
|
| 173 |
+
front = hv[:, 2] >= z_med
|
| 174 |
+
if front.sum() < 30:
|
| 175 |
+
front = np.ones(len(hv), bool)
|
| 176 |
+
|
| 177 |
+
face_vert_idx = head_idx[front] # indices into the full vertex array
|
| 178 |
+
|
| 179 |
+
# Build boolean mask for fast triangle selection
|
| 180 |
+
face_vert_mask = np.zeros(len(verts), bool)
|
| 181 |
+
face_vert_mask[face_vert_idx] = True
|
| 182 |
+
|
| 183 |
+
# Select triangles where ALL 3 vertices are in the face region
|
| 184 |
+
face_tri_mask = face_vert_mask[faces].all(axis=1)
|
| 185 |
+
face_tris = faces[face_tri_mask] # (M, 3)
|
| 186 |
+
print(f' neck_y={neck_y:.4f} | head={len(head_idx)} '
|
| 187 |
+
f'| face-front={front.sum()} | face triangles={len(face_tris)}')
|
| 188 |
+
|
| 189 |
+
# ── Load and align reference face ─────────────────────────────────────────
|
| 190 |
+
print(f'[face_project] Reference face: {face_img_path}')
|
| 191 |
+
raw_bgr = cv2.imread(face_img_path)
|
| 192 |
+
aligned_bgr = _aligned_face_bgr(raw_bgr, target_size=512)
|
| 193 |
+
aligned_rgb = cv2.cvtColor(aligned_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
|
| 194 |
+
H_f, W_f = aligned_rgb.shape[:2]
|
| 195 |
+
|
| 196 |
+
# ── Compute face projection axes from actual face normal ─────────────────
|
| 197 |
+
fv = verts[face_vert_idx]
|
| 198 |
+
|
| 199 |
+
# Average normal of the front-facing face triangles defines projection dir
|
| 200 |
+
face_tri_normals = np.array(mesh.face_normals)[face_tri_mask]
|
| 201 |
+
face_fwd = face_tri_normals.mean(axis=0)
|
| 202 |
+
face_fwd /= np.linalg.norm(face_fwd)
|
| 203 |
+
|
| 204 |
+
# Build orthonormal right/up axes in the face plane
|
| 205 |
+
world_up = np.array([0., 1., 0.])
|
| 206 |
+
face_right = np.cross(face_fwd, world_up)
|
| 207 |
+
face_right /= np.linalg.norm(face_right)
|
| 208 |
+
face_up = np.cross(face_right, face_fwd)
|
| 209 |
+
face_up /= np.linalg.norm(face_up)
|
| 210 |
+
print(f' Face normal: {face_fwd.round(3)}')
|
| 211 |
+
|
| 212 |
+
# Project face vertices onto local (right, up) plane
|
| 213 |
+
fv_centroid = fv.mean(axis=0)
|
| 214 |
+
fv_c = fv - fv_centroid
|
| 215 |
+
lx = fv_c @ face_right
|
| 216 |
+
ly = fv_c @ face_up
|
| 217 |
+
x_span = float(lx.max() - lx.min())
|
| 218 |
+
y_span = float(ly.max() - ly.min())
|
| 219 |
+
|
| 220 |
+
# InsightFace norm_crop places eyes at ~37% from top of the 512px image.
|
| 221 |
+
# In 3D the eyes are ~78% up from neck → 28% above centroid.
|
| 222 |
+
# Shift the vertical origin up by 0.112*y_span so eye level → 37% in image.
|
| 223 |
+
cy_shift = 0.112 * y_span
|
| 224 |
+
pad = 0.10 # tighter crop so face features fill more of the image
|
| 225 |
+
|
| 226 |
+
def vert_to_img(v):
|
| 227 |
+
"""Project 3D vertex to reference-face image using the face normal."""
|
| 228 |
+
c = v - fv_centroid # (N, 3)
|
| 229 |
+
lx = c @ face_right
|
| 230 |
+
ly = c @ face_up
|
| 231 |
+
pu = lx / (x_span * (1 + 2*pad)) + 0.5
|
| 232 |
+
pv = -(ly - cy_shift) / (y_span * (1 + 2*pad)) + 0.5
|
| 233 |
+
return np.column_stack([pu * W_f, pv * H_f]) # (N, 2)
|
| 234 |
+
|
| 235 |
+
def vert_to_uv_px(v_idx):
|
| 236 |
+
"""Convert vertex UV coords to texture pixel coordinates."""
|
| 237 |
+
uv = uvs[v_idx]
|
| 238 |
+
# trimesh loads GLB UV with (0,0)=bottom-left; flip V for image row
|
| 239 |
+
col = uv[:, 0] * tex_W
|
| 240 |
+
row = (1.0 - uv[:, 1]) * tex_H
|
| 241 |
+
return np.column_stack([col, row]) # (N, 2)
|
| 242 |
+
|
| 243 |
+
# Pre-compute image + UV pixel coords for every vertex
|
| 244 |
+
all_img_px = vert_to_img(verts) # (N, 2)
|
| 245 |
+
all_uv_px = vert_to_uv_px(np.arange(len(verts))) # (N, 2)
|
| 246 |
+
|
| 247 |
+
# Gather per-triangle arrays
|
| 248 |
+
face_tri_uvs_px = all_uv_px[face_tris] # (M, 3, 2)
|
| 249 |
+
face_tri_img_xy = all_img_px[face_tris] # (M, 3, 2)
|
| 250 |
+
|
| 251 |
+
print(f' UV pixel range: u={face_tri_uvs_px[:,:,0].min():.0f}→'
|
| 252 |
+
f'{face_tri_uvs_px[:,:,0].max():.0f} '
|
| 253 |
+
f'v={face_tri_uvs_px[:,:,1].min():.0f}→'
|
| 254 |
+
f'{face_tri_uvs_px[:,:,1].max():.0f}')
|
| 255 |
+
print(f' Image coord range: x={face_tri_img_xy[:,:,0].min():.1f}→'
|
| 256 |
+
f'{face_tri_img_xy[:,:,0].max():.1f} '
|
| 257 |
+
f'y={face_tri_img_xy[:,:,1].min():.1f}→'
|
| 258 |
+
f'{face_tri_img_xy[:,:,1].max():.1f}')
|
| 259 |
+
|
| 260 |
+
# ── Rasterize face triangles into UV texture ──────────────────────────────
|
| 261 |
+
print(f'[face_project] Rasterizing {len(face_tris)} triangles into texture...')
|
| 262 |
+
new_tex = orig_tex.copy()
|
| 263 |
+
painted = _rasterize_triangles(
|
| 264 |
+
face_tri_uvs_px, face_tri_img_xy,
|
| 265 |
+
aligned_rgb, new_tex, blend,
|
| 266 |
+
max_uv_span=300
|
| 267 |
+
)
|
| 268 |
+
print(f' Painted {painted} texels across {len(face_tris)} triangles')
|
| 269 |
+
|
| 270 |
+
# ── Save debug texture if requested ──────────────────────────────────────
|
| 271 |
+
if debug_tex:
|
| 272 |
+
dbg = np.clip(new_tex, 0, 255).astype(np.uint8)
|
| 273 |
+
Image.fromarray(dbg).save(debug_tex)
|
| 274 |
+
print(f' Debug texture: {debug_tex}')
|
| 275 |
+
|
| 276 |
+
# ── Write modified texture back to mesh ───────────────────────────────────
|
| 277 |
+
new_pil = Image.fromarray(np.clip(new_tex, 0, 255).astype(np.uint8))
|
| 278 |
+
new_mat = PBRMaterial(baseColorTexture=new_pil)
|
| 279 |
+
mesh.visual = TextureVisuals(uv=uvs, material=new_mat)
|
| 280 |
+
|
| 281 |
+
os.makedirs(os.path.dirname(os.path.abspath(out_glb)), exist_ok=True)
|
| 282 |
+
if geom_name and isinstance(scene, trimesh.Scene):
|
| 283 |
+
scene.geometry[geom_name] = mesh
|
| 284 |
+
scene.export(out_glb)
|
| 285 |
+
else:
|
| 286 |
+
mesh.export(out_glb)
|
| 287 |
+
|
| 288 |
+
print(f'[face_project] Saved: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
|
| 289 |
+
return out_glb
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# ── CLI ────────────────────────────────────────────────────────────────────────
|
| 293 |
+
|
| 294 |
+
if __name__ == '__main__':
    # Command-line entry point: project a reference face photo onto the
    # UV texture of a GLB body mesh.
    parser = argparse.ArgumentParser()
    parser.add_argument('--body', required=True)
    parser.add_argument('--face', required=True)
    parser.add_argument('--out', required=True)
    parser.add_argument('--blend', type=float, default=0.90)
    parser.add_argument('--neck_frac', type=float, default=0.84)
    parser.add_argument('--debug_tex', default=None)
    cli = parser.parse_args()
    project_face(
        cli.body,
        cli.face,
        cli.out,
        blend=cli.blend,
        neck_frac=cli.neck_frac,
        debug_tex=cli.debug_tex,
    )
|
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
face_swap_render.py — Paint reference face onto TripoSG UV texture using
|
| 3 |
+
MV-Adapter's UV-baking pipeline.
|
| 4 |
+
|
| 5 |
+
Pipeline:
|
| 6 |
+
1. Load mesh with same params as triposg_app.py render stage
|
| 7 |
+
2. Create orthographic camera matching render_front.png (azimuth=-90)
|
| 8 |
+
3. Detect face landmarks in render_front.png + reference photo via InsightFace
|
| 9 |
+
4. norm_crop reference → canonical 512×512 frontal face
|
| 10 |
+
5. Estimate 4-DOF similarity (canonical → render) and warpAffine
|
| 11 |
+
→ produces face_on_render.png: reference face at correct render-space coords
|
| 12 |
+
6. uv_render_attr(images=face_on_render) → projects render image into UV space
|
| 13 |
+
No inverse transform, no scale mismatch — the render-space coordinate system
|
| 14 |
+
is shared between the camera projection and the UV lookup.
|
| 15 |
+
7. Blend projected face into original texture with geometry mask guard.
|
| 16 |
+
8. Save updated GLB
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
python face_swap_render.py \
|
| 20 |
+
--body /tmp/triposg_textured.glb \
|
| 21 |
+
--face /tmp/triposg_face_ref.png \
|
| 22 |
+
--render /tmp/render_front.png \
|
| 23 |
+
--out /tmp/face_swapped.glb \
|
| 24 |
+
[--blend 0.93] [--uv_size 4096] [--debug_dir /tmp]
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import os, sys, argparse, warnings
|
| 28 |
+
warnings.filterwarnings('ignore')
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import cv2
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
from PIL import Image
|
| 35 |
+
import trimesh
|
| 36 |
+
from trimesh.visual.texture import TextureVisuals
|
| 37 |
+
from trimesh.visual.material import PBRMaterial
|
| 38 |
+
from insightface.utils import face_align as insightface_align
|
| 39 |
+
|
| 40 |
+
sys.path.insert(0, '/root/MV-Adapter')
|
| 41 |
+
from mvadapter.utils.mesh_utils import (
|
| 42 |
+
NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera,
|
| 43 |
+
)
|
| 44 |
+
from mvadapter.utils.mesh_utils.uv import (
|
| 45 |
+
uv_precompute, uv_render_geometry, uv_render_attr,
|
| 46 |
+
)
|
| 47 |
+
from insightface.app import FaceAnalysis
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _detect_largest_face(img_bgr, app):
|
| 51 |
+
faces = app.get(img_bgr)
|
| 52 |
+
if not faces:
|
| 53 |
+
return None
|
| 54 |
+
faces.sort(key=lambda f: (f.bbox[2]-f.bbox[0])*(f.bbox[3]-f.bbox[1]), reverse=True)
|
| 55 |
+
return faces[0]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _build_front_face_uv_mask(mesh_t, tex_H, tex_W, neck_frac=0.84):
    """
    Build a UV-space mask covering only the front-facing face triangles.
    Excludes back-of-head, hair, and ears (lateral vertices).

    Parameters
    ----------
    mesh_t : trimesh mesh with per-vertex UVs available at mesh_t.visual.uv.
    tex_H, tex_W : int
        Texture resolution the mask is rasterized at.
    neck_frac : float
        Fraction of the mesh's Y extent below which vertices are treated as
        body; only vertices above this line are head candidates.

    Returns
    -------
    numpy.ndarray
        (tex_H, tex_W) float32 array in [0, 1] — soft-edged after the
        dilate/erode/blur chain at the end.
    """
    verts = np.array(mesh_t.vertices, dtype=np.float64)
    faces = np.array(mesh_t.faces, dtype=np.int32)
    uvs = np.array(mesh_t.visual.uv, dtype=np.float64)

    # Head vertices above neck
    y_min, y_max = verts[:, 1].min(), verts[:, 1].max()
    neck_y = float(y_min + (y_max - y_min) * neck_frac)
    head_idx = np.where(verts[:, 1] > neck_y)[0]
    hv = verts[head_idx]

    # Front half: z >= 40th percentile — generous to include jaw/cheek toward ears
    # No lateral exclusion — it splits UV islands through the eyes/mouth → duplicates
    # NOTE(review): assumes the character faces +Z — confirm against the
    # generation pipeline's canonical orientation.
    z_thresh = float(np.percentile(hv[:, 2], 40))
    front = hv[:, 2] >= z_thresh
    if front.sum() < 30:
        # Degenerate selection (too few front vertices): keep everything above neck.
        front = np.ones(len(hv), bool)

    face_vert_idx = head_idx[front]
    face_vert_mask = np.zeros(len(verts), bool)
    face_vert_mask[face_vert_idx] = True

    # A triangle counts as "face" only when all three of its corners are face verts.
    face_tri_mask = face_vert_mask[faces].all(axis=1)
    face_tris = faces[face_tri_mask]
    print(f' Geometry mask: {face_tri_mask.sum()} front-face triangles selected '
          f'(neck_y={neck_y:.3f}, z_thresh={z_thresh:.3f})')

    # Rasterize into UV-space mask (trimesh UV: y=0 is bottom-left → flip V)
    geom_mask = np.zeros((tex_H, tex_W), dtype=np.float32)
    pts_list = []
    for tri in face_tris:
        uv = uvs[tri] # (3, 2)
        px = uv[:, 0] * tex_W
        py = (1.0 - uv[:, 1]) * tex_H
        pts_list.append(np.column_stack([px, py]).astype(np.int32))
    if pts_list:
        cv2.fillPoly(geom_mask, pts_list, 1.0)

    # Morphology order matters: dilate first to bridge rasterization gaps
    # between triangles, then erode once to pull back from UV-island borders.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    geom_mask = cv2.dilate(geom_mask, kernel, iterations=2) # close intra-tri gaps
    geom_mask = cv2.erode(geom_mask, kernel, iterations=1) # retreat from island edges
    geom_mask = cv2.GaussianBlur(geom_mask, (31, 31), 8) # soft transition
    return geom_mask
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def face_swap_render(body_glb, face_img_path, render_img_path, out_glb,
                     blend=0.93, uv_size=4096, neck_frac=0.76, debug_dir=None):
    """Paint a reference face photo onto the GLB's UV texture via render-space warp.

    The aligned reference face is warped into the coordinate frame of a known
    front render, then projected into UV space with MV-Adapter's
    uv_render_attr, and alpha-blended into the original base-color texture.

    Parameters
    ----------
    body_glb : str — path to the input textured GLB.
    face_img_path : str — reference face photo.
    render_img_path : str — front render of the mesh (camera matched in Step 2).
    out_glb : str — output GLB path.
    blend : float — NOTE(review): accepted and printed but never applied in the
        blend below (alpha comes solely from the projected mask); confirm
        whether it should scale `alpha`.
    uv_size : int — resolution used for UV-space rasterization.
    neck_frac : float — fraction of mesh height used as the neck line for the
        geometry mask.
    debug_dir : str or None — when set, intermediate images are written there.

    Returns
    -------
    str — `out_glb` on success.

    Raises
    ------
    RuntimeError — when no face is detected in the render or the reference.
    """

    device = 'cuda'

    # ── Step 1: Load mesh ─────────────────────────────────────────────────────
    print(f'[fsr] Loading mesh: {body_glb}')
    ctx = NVDiffRastContextWrapper(device=device, context_type='cuda')
    mesh_mv = load_mesh(body_glb, rescale=True, device=device)

    # Load the same file a second time with trimesh to access UVs and texture.
    scene_t = trimesh.load(body_glb)
    if isinstance(scene_t, trimesh.Scene):
        geom_name = list(scene_t.geometry.keys())[0]
        mesh_t = scene_t.geometry[geom_name]
    else:
        mesh_t = scene_t; geom_name = None

    orig_tex = np.array(mesh_t.visual.material.baseColorTexture, dtype=np.float32) / 255.0
    uvs = np.array(mesh_t.visual.uv, dtype=np.float64)
    tex_H, tex_W = orig_tex.shape[:2]
    print(f' UV size: {tex_W}×{tex_H}')

    # ── Step 1b: Geometry mask (front-face UV islands only) ───────────────────
    print('[fsr] Building geometry front-face UV mask ...')
    geom_uv_mask = _build_front_face_uv_mask(mesh_t, tex_H, tex_W, neck_frac)

    # ── Step 2: Orthographic camera matching render_front.png ─────────────────
    render_img = cv2.imread(render_img_path)
    H_r, W_r = render_img.shape[:2]
    print(f' Render size: {W_r}×{H_r}')

    # Camera parameters must match the ones used to produce render_img,
    # otherwise the render-space ↔ UV correspondence in Step 7 breaks.
    camera = get_orthogonal_camera(
        elevation_deg=[0], distance=[1.8],
        left=-0.55, right=0.55, bottom=-0.55, top=0.55,
        azimuth_deg=[-90], device=device,
    )

    print(f'[fsr] Precomputing UV geometry ({uv_size}×{uv_size}) ...')
    uv_pre = uv_precompute(ctx, mesh_mv, height=uv_size, width=uv_size)
    uv_geom = uv_render_geometry(
        ctx, mesh_mv, camera,
        view_height=H_r, view_width=W_r,
        uv_precompute_output=uv_pre,
        compute_depth_grad=False,
    )

    # ── Step 3: Face landmark detection ───────────────────────────────────────
    print('[fsr] Detecting face landmarks ...')
    app = FaceAnalysis(providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    ref_bgr = cv2.imread(face_img_path)
    render_face = _detect_largest_face(render_img, app)
    if render_face is None:
        raise RuntimeError(f'No face detected in render: {render_img_path}')
    ref_face = _detect_largest_face(ref_bgr, app)
    if ref_face is None:
        raise RuntimeError(f'No face detected in reference: {face_img_path}')

    render_kps = render_face.kps # (5, 2)
    ref_kps = ref_face.kps
    print(f' render kps: x={render_kps[:,0].min():.0f}-{render_kps[:,0].max():.0f}'
          f' y={render_kps[:,1].min():.0f}-{render_kps[:,1].max():.0f}')

    # ── Step 4: norm_crop → canonical 512×512 frontal face ───────────────────
    CANONICAL_SIZE = 512
    aligned_bgr = insightface_align.norm_crop(ref_bgr, ref_kps, image_size=CANONICAL_SIZE)

    # Fixed ARCFACE 5-point positions scaled to CANONICAL_SIZE
    ARCFACE_112 = np.array([
        [38.2946, 51.6963],
        [73.5318, 51.5014],
        [56.0252, 71.7366],
        [41.5493, 92.3655],
        [70.7299, 92.2041],
    ], dtype=np.float32)
    canonical_kps = ARCFACE_112 * (CANONICAL_SIZE / 112.0)

    # ── Step 5: Forward warp: canonical → render space ────────────────────────
    # 4-DOF similarity (scale + rotation + translation) with all 5 kps.
    # FORWARD direction: canonical_kps → render_kps so that warpAffine places
    # the face at exactly the render-space coordinates, downsampling cleanly.
    fwd_M, inliers = cv2.estimateAffinePartial2D(
        canonical_kps.astype(np.float32),
        render_kps.astype(np.float32),
        method=cv2.LMEDS,
    )
    print(f' Forward warp M:\n{fwd_M}')

    face_on_render_bgr = cv2.warpAffine(
        aligned_bgr, fwd_M, (W_r, H_r),
        flags=cv2.INTER_LANCZOS4,
        borderMode=cv2.BORDER_CONSTANT, borderValue=0,
    )
    face_on_render_rgb = cv2.cvtColor(face_on_render_bgr,
                                      cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # ── Step 6: Render-space face hull mask ───────────────────────────────────
    # Only paint UV texels that correspond to pixels inside the face region.
    hull_pts = cv2.convexHull(render_kps.astype(np.float32)).squeeze(1)
    hull_cx, hull_cy = hull_pts.mean(axis=0)
    # Expand the 5-point hull 4× about its centroid to cover the whole face.
    hull_expanded = (hull_pts - [hull_cx, hull_cy]) * 4.0 + [hull_cx, hull_cy]
    face_mask_render = np.zeros((H_r, W_r), dtype=np.float32)
    cv2.fillPoly(face_mask_render, [hull_expanded.astype(np.int32)], 1.0)
    # Restrict to where the warped face actually has content
    # NOTE(review): face_on_render_bgr is uint8, so .mean(axis=2) is in 0–255;
    # comparing against 3.0/255.0 (~0.012) effectively tests "nonzero pixel".
    # Possibly intended threshold was > 3.0 — confirm.
    face_content = (face_on_render_bgr.mean(axis=2) > 3.0 / 255.0).astype(np.float32)
    face_mask_render = face_mask_render * face_content
    face_mask_render = cv2.GaussianBlur(face_mask_render, (51, 51), 15)

    # ── Step 7: Project face-on-render into UV space ──────────────────────────
    # uv_render_attr uses uv_pos_ndc as a lookup: for each UV texel, sample the
    # render-space image at that texel's render NDC position.
    # Since face_on_render is already in render-space coords, this is exact.
    print('[fsr] Projecting face into UV space via uv_render_attr ...')
    face_t = torch.tensor(face_on_render_rgb, device=device).unsqueeze(0) # (1,H,W,3)
    mask_t = torch.tensor(face_mask_render[None], device=device)

    uv_attr_out = uv_render_attr(
        images=face_t,
        masks=mask_t,
        uv_render_geometry_output=uv_geom,
    )
    uv_face_img = uv_attr_out.uv_attr_proj[0].cpu().numpy() # (uv, uv, 3)
    uv_face_mask = uv_attr_out.uv_mask_proj[0].cpu().numpy() # (uv, uv)

    # Rescale to tex resolution if needed
    if uv_size != tex_H or uv_size != tex_W:
        uv_face_img_rs = cv2.resize(uv_face_img, (tex_W, tex_H), interpolation=cv2.INTER_LINEAR)
        uv_face_mask_rs = cv2.resize(uv_face_mask, (tex_W, tex_H), interpolation=cv2.INTER_LINEAR)
    else:
        uv_face_img_rs = uv_face_img
        uv_face_mask_rs = uv_face_mask

    # ── Step 7b: Apply geometry mask — kill back-of-head / ear UV islands ────
    uv_face_mask_rs = uv_face_mask_rs * geom_uv_mask

    # Final blend alpha — use full blend=1.0 inside the face region so no
    # original texture leaks through and creates duplicate features
    alpha = np.clip(uv_face_mask_rs, 0, 1)[..., None]
    painted_px = int((alpha[..., 0] > 0.01).sum())
    print(f' Painted texels: {painted_px}')

    if debug_dir:
        cv2.imwrite(os.path.join(debug_dir, 'fsr_aligned_ref.png'), aligned_bgr)
        cv2.imwrite(os.path.join(debug_dir, 'fsr_face_on_render.png'), face_on_render_bgr)
        cv2.imwrite(os.path.join(debug_dir, 'fsr_face_mask_render.png'),
                    (face_mask_render * 255).astype(np.uint8))
        cv2.imwrite(os.path.join(debug_dir, 'fsr_geom_mask.png'),
                    (geom_uv_mask * 255).astype(np.uint8))
        cv2.imwrite(os.path.join(debug_dir, 'fsr_uv_mask.png'),
                    (uv_face_mask_rs * 255).astype(np.uint8))
        Image.fromarray((uv_face_img_rs * 255).clip(0, 255).astype(np.uint8)).save(
            os.path.join(debug_dir, 'fsr_uv_face.png'))
        print(f' Debug files saved to {debug_dir}')

    # ── Step 8: Blend into original texture ───────────────────────────────────
    print(f'[fsr] Blending (blend={blend}) ...')
    new_tex = uv_face_img_rs * alpha + orig_tex * (1.0 - alpha)

    # ── Step 9: Save GLB ──────────────────────────────────────────────────────
    new_pil = Image.fromarray((new_tex * 255).clip(0, 255).astype(np.uint8))
    mesh_t.visual = TextureVisuals(uv=uvs, material=PBRMaterial(baseColorTexture=new_pil))

    if geom_name and isinstance(scene_t, trimesh.Scene):
        scene_t.geometry[geom_name] = mesh_t
        scene_t.export(out_glb)
    else:
        mesh_t.export(out_glb)

    print(f'[fsr] Saved: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
    return out_glb
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == '__main__':
    # CLI entry point: swap a reference face into a textured GLB using a
    # matching front render as the coordinate bridge.
    parser = argparse.ArgumentParser()
    parser.add_argument('--body', required=True)
    parser.add_argument('--face', required=True)
    parser.add_argument('--render', required=True, help='Front render (e.g. render_front.png)')
    parser.add_argument('--out', required=True)
    parser.add_argument('--blend', type=float, default=0.93)
    parser.add_argument('--uv_size', type=int, default=4096)
    parser.add_argument('--neck_frac', type=float, default=0.76)
    parser.add_argument('--debug_dir', default=None)
    opts = parser.parse_args()
    face_swap_render(
        opts.body, opts.face, opts.render, opts.out,
        blend=opts.blend,
        uv_size=opts.uv_size,
        neck_frac=opts.neck_frac,
        debug_dir=opts.debug_dir,
    )
|
|
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
face_transplant.py
|
| 3 |
+
==================
|
| 4 |
+
Replace the face/head region of a rigged UniRig GLB with a higher-detail
|
| 5 |
+
PSHuman mesh, while preserving the skeleton, rig, and skinning weights.
|
| 6 |
+
|
| 7 |
+
Algorithm
|
| 8 |
+
---------
|
| 9 |
+
1. Parse rigged GLB → vertices, faces, UVs, JOINTS_0, WEIGHTS_0, bone list
|
| 10 |
+
2. Identify head vertices → any vert whose dominant bone is in HEAD_BONES
|
| 11 |
+
3. Load PSHuman mesh (OBJ or GLB, no rig)
|
| 12 |
+
4. Align PSHuman head to UniRig head bounding box (scale + translate)
|
| 13 |
+
5. Transfer skinning weights to PSHuman verts via K-nearest-neighbour from
|
| 14 |
+
UniRig head verts (scipy KDTree, weighted average)
|
| 15 |
+
6. Retract UniRig face verts slightly inward so PSHuman sits on top cleanly
|
| 16 |
+
7. Rebuild the GLB with two mesh primitives:
|
| 17 |
+
- Primitive 0 : UniRig body (face verts retracted)
|
| 18 |
+
- Primitive 1 : PSHuman face (new, with transferred weights)
|
| 19 |
+
8. Write output GLB
|
| 20 |
+
|
| 21 |
+
Usage
|
| 22 |
+
-----
|
| 23 |
+
python -m pipeline.face_transplant \\
|
| 24 |
+
--body rigged_body.glb \\
|
| 25 |
+
--face pshuman_output.obj \\
|
| 26 |
+
--output rigged_body_with_pshuman_face.glb
|
| 27 |
+
|
| 28 |
+
Optionally supply --head-bones as comma-separated bone-name substrings
|
| 29 |
+
(default: head,Head,skull). Any bone whose name contains one of these
|
| 30 |
+
substrings is treated as a head bone.
|
| 31 |
+
|
| 32 |
+
Requires: pygltflib numpy scipy trimesh (pip install each)
|
| 33 |
+
"""
|
| 34 |
+
from __future__ import annotations
|
| 35 |
+
|
| 36 |
+
import argparse
|
| 37 |
+
import base64
|
| 38 |
+
import struct
|
| 39 |
+
import json
|
| 40 |
+
from pathlib import Path
|
| 41 |
+
from typing import Dict, List, Optional, Tuple
|
| 42 |
+
import numpy as np
|
| 43 |
+
from scipy.spatial import KDTree
|
| 44 |
+
import trimesh
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# GLB low-level helpers (subset of Retarget/io/gltf_io.py re-used here)
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
# Hard dependency for low-level GLB parsing below; fail at import time with
# an actionable message instead of an opaque NameError at first use.
try:
    import pygltflib
except ImportError:
    raise ImportError("pip install pygltflib")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _read_accessor_raw(gltf: pygltflib.GLTF2, accessor_idx: int) -> np.ndarray:
|
| 58 |
+
acc = gltf.accessors[accessor_idx]
|
| 59 |
+
bv = gltf.bufferViews[acc.bufferView]
|
| 60 |
+
buf = gltf.buffers[bv.buffer]
|
| 61 |
+
|
| 62 |
+
if buf.uri and buf.uri.startswith("data:"):
|
| 63 |
+
_, b64 = buf.uri.split(",", 1)
|
| 64 |
+
raw = base64.b64decode(b64)
|
| 65 |
+
elif buf.uri:
|
| 66 |
+
base_dir = Path(gltf._path).parent if getattr(gltf, "_path", None) else Path(".")
|
| 67 |
+
raw = (base_dir / buf.uri).read_bytes()
|
| 68 |
+
else:
|
| 69 |
+
raw = bytes(gltf.binary_blob())
|
| 70 |
+
|
| 71 |
+
type_nc = {"SCALAR": 1, "VEC2": 2, "VEC3": 3, "VEC4": 4, "MAT4": 16}
|
| 72 |
+
fmt_map = {5120: "b", 5121: "B", 5122: "h", 5123: "H", 5125: "I", 5126: "f"}
|
| 73 |
+
|
| 74 |
+
n_comp = type_nc[acc.type]
|
| 75 |
+
fmt = fmt_map[acc.componentType]
|
| 76 |
+
item_sz = struct.calcsize(fmt) * n_comp
|
| 77 |
+
stride = bv.byteStride or item_sz
|
| 78 |
+
start = bv.byteOffset + (acc.byteOffset or 0)
|
| 79 |
+
|
| 80 |
+
items = []
|
| 81 |
+
for i in range(acc.count):
|
| 82 |
+
offset = start + i * stride
|
| 83 |
+
vals = struct.unpack_from(f"{n_comp}{fmt}", raw, offset)
|
| 84 |
+
items.append(vals)
|
| 85 |
+
|
| 86 |
+
arr = np.array(items)
|
| 87 |
+
if arr.ndim == 2 and arr.shape[1] == 1:
|
| 88 |
+
arr = arr[:, 0]
|
| 89 |
+
return arr
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _accessor_dtype(gltf: pygltflib.GLTF2, accessor_idx: int):
|
| 93 |
+
fmt_map = {5120: np.int8, 5121: np.uint8, 5122: np.int16, 5123: np.uint16,
|
| 94 |
+
5125: np.uint32, 5126: np.float32}
|
| 95 |
+
return fmt_map[gltf.accessors[accessor_idx].componentType]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ---------------------------------------------------------------------------
|
| 99 |
+
# Mesh extraction
|
| 100 |
+
# ---------------------------------------------------------------------------
|
| 101 |
+
|
| 102 |
+
class GLBMesh:
    """
    All data from the first skin's first mesh primitive in a GLB.

    Attributes populated on load:
      verts        -- (N, 3) float32 vertex positions
      normals      -- (N, 3) float32, or None when the primitive has no NORMAL
      uvs          -- (N, 2) float32, or None when the primitive has no TEXCOORD_0
      faces        -- (M, 3) int32 triangle indices
      joints4      -- (N, 4) int32 per-vertex bone indices, or None if unskinned
      weights4     -- (N, 4) float32 per-vertex bone weights, or None if unskinned
      joint_names  -- bone names in skin-joint order ("joint_k" when unnamed)
      material_idx -- material index of the primitive (carried over to output)
    """
    def __init__(self, path: str):
        # Load the GLB and extract geometry + skinning from the first skinned
        # primitive. Raises ValueError when the file carries no skin.
        self.path = path
        gltf = pygltflib.GLTF2().load(path)
        # Remember the source path so relative buffer URIs can be resolved
        # by _read_accessor_raw.
        gltf._path = path
        self.gltf = gltf

        if not gltf.skins:
            raise ValueError("No skin found in GLB — is this a rigged file?")
        self.skin = gltf.skins[0]
        # Synthesize "joint_k" names for joint nodes that carry no name.
        self.joint_names: List[str] = [gltf.nodes[j].name or f"joint_{k}"
                                       for k, j in enumerate(self.skin.joints)]

        # Find a mesh node that uses this skin
        self.mesh_prim, self.mesh_node_idx = self._find_skinned_prim()

        attrs = self.mesh_prim.attributes
        self.verts = _read_accessor_raw(gltf, attrs.POSITION).astype(np.float32)
        self.normals = (_read_accessor_raw(gltf, attrs.NORMAL).astype(np.float32)
                        if attrs.NORMAL is not None else None)
        self.uvs = (_read_accessor_raw(gltf, attrs.TEXCOORD_0).astype(np.float32)
                    if attrs.TEXCOORD_0 is not None else None)
        self.faces = _read_accessor_raw(gltf, self.mesh_prim.indices).astype(np.int32).reshape(-1, 3)

        # Skinning — may be JOINTS_0 / WEIGHTS_0 (uint8/uint16 + float)
        self.joints4 = None
        self.weights4 = None
        if attrs.JOINTS_0 is not None:
            self.joints4 = _read_accessor_raw(gltf, attrs.JOINTS_0).astype(np.int32)
            self.weights4 = _read_accessor_raw(gltf, attrs.WEIGHTS_0).astype(np.float32)

        # Material index (carry over to output)
        self.material_idx = self.mesh_prim.material

    def _find_skinned_prim(self):
        # Return (first primitive, node index) of a skinned mesh node.
        # NOTE(review): `node.skin == 0` hardcodes skin index 0 — consistent
        # with self.skin = gltf.skins[0] above, but a node referencing any
        # other skin would be skipped until the fallback loop.
        skin_node_indices = set(self.skin.joints)  # NOTE(review): computed but unused
        # find mesh node that references this skin
        for ni, node in enumerate(self.gltf.nodes):
            if node.skin == 0 and node.mesh is not None:
                mesh = self.gltf.meshes[node.mesh]
                return mesh.primitives[0], ni
        # fallback: first mesh node
        for ni, node in enumerate(self.gltf.nodes):
            if node.mesh is not None:
                mesh = self.gltf.meshes[node.mesh]
                return mesh.primitives[0], ni
        raise ValueError("No mesh primitive found")

    def head_bone_indices(self, substrings=("head", "Head", "skull", "Skull", "neck", "Neck")) -> List[int]:
        """Return joint indices (into self.joint_names) matching any substring.
        Falls back to positional heuristic (highest-Y dominant bone) when no
        bone names match (e.g. generic bone_0/bone_1 naming from UniRig)."""
        result = []
        for i, name in enumerate(self.joint_names):
            if any(s in name for s in substrings):
                result.append(i)
        if not result and self.joints4 is not None and self.weights4 is not None:
            # Positional fallback: pick bone whose dominant vertices have highest avg Y.
            n_bones = len(self.joint_names)
            bone_y_sum = np.zeros(n_bones)
            bone_y_cnt = np.zeros(n_bones, dtype=np.int32)
            for vi in range(len(self.verts)):
                # Dominant bone = the joint with the largest weight on this vertex.
                dom = int(self.joints4[vi, np.argmax(self.weights4[vi])])
                bone_y_sum[dom] += self.verts[vi, 1]
                bone_y_cnt[dom] += 1
            with np.errstate(invalid='ignore'):
                # Bones that dominate no vertex get -inf so argmax ignores them.
                bone_y_avg = np.where(bone_y_cnt > 0, bone_y_sum / bone_y_cnt, -np.inf)
            top = int(np.argmax(bone_y_avg))
            print(f"[face_transplant] No named head bones; positional fallback: "
                  f"bone {top} ({self.joint_names[top]}, avg_y={bone_y_avg[top]:.3f})")
            result = [top]
        return result
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
# Face-region identification
|
| 181 |
+
# ---------------------------------------------------------------------------
|
| 182 |
+
|
| 183 |
+
def find_face_verts(glb_mesh: "GLBMesh", head_joint_indices: List[int],
                    weight_threshold: float = 0.35) -> np.ndarray:
    """
    Return a boolean mask over glb_mesh.verts marking face/head vertices:
    every vertex whose summed skin weight on the head joints reaches
    weight_threshold.

    Vectorised (np.isin + masked row sum) instead of the original per-vertex
    Python loop. Assumes skin weights are non-negative (the glTF WEIGHTS_0
    contract), which makes the total-sum comparison equivalent to the
    original's early-exit running sum.

    Raises
    ------
    ValueError
        When the mesh carries no skinning data (joints or weights missing).
    """
    # Guard both arrays — weights4 is read below, so check it too.
    if glb_mesh.joints4 is None or glb_mesh.weights4 is None:
        raise ValueError("Mesh has no skinning weights — cannot identify face region")

    head = np.asarray(list(head_joint_indices), dtype=np.int64)
    on_head = np.isin(glb_mesh.joints4, head)                     # (N, 4) bool
    head_weight = np.where(on_head, glb_mesh.weights4, 0.0).sum(axis=1)
    return head_weight >= weight_threshold
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# ---------------------------------------------------------------------------
|
| 210 |
+
# PSHuman mesh loading + alignment
|
| 211 |
+
# ---------------------------------------------------------------------------
|
| 212 |
+
|
| 213 |
+
def _crop_to_head(mesh: trimesh.Trimesh, head_fraction: float = 0.22) -> trimesh.Trimesh:
    """
    Return a sub-mesh containing only the top `head_fraction` of the input
    by Y coordinate. PSHuman produces a full-body mesh; we only want the
    head/face portion.

    A triangle survives only when all three of its corners survive; kept
    vertices are re-indexed densely, and per-vertex UVs (when present) are
    carried over.
    """
    ys = mesh.vertices[:, 1]
    y_hi, y_lo = ys.max(), ys.min()
    cutoff = y_hi - (y_hi - y_lo) * head_fraction

    keep_vert = ys >= cutoff
    surviving_faces = mesh.faces[keep_vert[mesh.faces].all(axis=1)]

    # Dense re-index: old vertex id -> new compact id (unused stay -1).
    old_ids = np.unique(surviving_faces)
    lookup = np.full(len(mesh.vertices), -1, dtype=np.int32)
    lookup[old_ids] = np.arange(len(old_ids))

    cropped_verts = mesh.vertices[old_ids].astype(np.float32)
    cropped = trimesh.Trimesh(vertices=cropped_verts,
                              faces=lookup[surviving_faces],
                              process=False)
    if hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
        cropped.visual = trimesh.visual.TextureVisuals(uv=np.array(mesh.visual.uv)[old_ids])
    print(f"[face_transplant] PSHuman head crop ({head_fraction*100:.0f}%): "
          f"{len(mesh.vertices)} → {len(cropped_verts)} verts (Y ≥ {cutoff:.3f})")
    return cropped
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def load_and_align_pshuman(pshuman_path: str, target_verts: np.ndarray) -> trimesh.Trimesh:
    """
    Load a PSHuman mesh (OBJ/GLB/PLY), crop it to the head region, then fit
    it onto the bounding box of target_verts (the UniRig head vertices).

    The fit is a single uniform scale — matched on the target bbox's largest
    axis to avoid skew — plus a translation of bbox centre onto bbox centre.
    """
    mesh: trimesh.Trimesh = trimesh.load(pshuman_path, force="mesh", process=False)
    print(f"[face_transplant] PSHuman mesh: {len(mesh.vertices)} verts, {len(mesh.faces)} faces")

    # PSHuman is full-body — keep only the head before computing the fit.
    mesh = _crop_to_head(mesh)

    # Target bbox (UniRig head region).
    lo_t = target_verts.min(axis=0)
    hi_t = target_verts.max(axis=0)
    centre_t = (lo_t + hi_t) * 0.5
    extent_t = hi_t - lo_t

    # Source bbox (cropped PSHuman head).
    lo_s = mesh.vertices.min(axis=0).astype(np.float32)
    hi_s = mesh.vertices.max(axis=0).astype(np.float32)
    centre_s = (lo_s + hi_s) * 0.5
    extent_s = hi_s - lo_s

    # Uniform scale from the target's dominant axis.
    axis = np.argmax(extent_t)
    scale = float(extent_t[axis]) / float(extent_s[axis] + 1e-9)

    moved = mesh.vertices.astype(np.float32).copy()
    moved = (moved - centre_s) * scale + centre_t
    mesh.vertices = moved

    print(f"[face_transplant] PSHuman aligned: scale={scale:.4f}, center={centre_t}")
    return mesh
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# ---------------------------------------------------------------------------
|
| 271 |
+
# Weight transfer via KDTree
|
| 272 |
+
# ---------------------------------------------------------------------------
|
| 273 |
+
|
| 274 |
+
def transfer_weights(
    donor_verts: np.ndarray,      # (M, 3) UniRig face verts
    donor_joints: np.ndarray,     # (M, 4) uint16
    donor_weights: np.ndarray,    # (M, 4) float32
    recipient_verts: np.ndarray,  # (N, 3) PSHuman face verts
    k: int = 5,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    K-nearest-neighbour, inverse-distance-weighted skin-weight transfer.

    For each recipient vertex, the skin weights of its k nearest donor
    vertices are accumulated per joint (scaled by 1/distance), normalised,
    and the top-4 joints are packed back into glTF's 4-bone format.

    Returns (joints4, weights4): (N, 4) uint16 and (N, 4) float32.

    Fixes / improvements over the original:
    - vectorised with np.add.at / take_along_axis instead of triple-nested
      Python loops (O(N*k*4) array ops instead of interpreter iterations);
    - k=1 no longer crashes (scipy returns 1-D arrays for k=1; they are
      reshaped to (N, 1) before use).
    """
    tree = KDTree(donor_verts)
    dists, idxs = tree.query(recipient_verts, k=k)

    N = len(recipient_verts)
    # scipy's KDTree.query returns 1-D arrays when k == 1 — normalise shape.
    dists = np.asarray(dists, dtype=np.float64).reshape(N, -1)
    idxs = np.asarray(idxs).reshape(N, -1)

    n_joints_total = int(donor_joints.max()) + 1

    # Dense per-recipient weight vector, accumulated over the k neighbours.
    dense = np.zeros((N, n_joints_total), dtype=np.float64)
    rows = np.arange(N)
    for ki in range(dists.shape[1]):
        w_dist = 1.0 / (dists[:, ki] + 1e-8)          # inverse-distance
        j = donor_joints[idxs[:, ki]]                 # (N, 4) joint ids
        w = donor_weights[idxs[:, ki]] * w_dist[:, None]
        # Unbuffered scatter-add: duplicate joint ids within a row accumulate.
        np.add.at(dense, (rows[:, None], j), w)

    # Re-normalise rows.
    dense /= dense.sum(axis=1, keepdims=True) + 1e-12

    # Pack back into 4-bone format (top-4 by weight; fewer when the skeleton
    # has fewer than 4 joints).
    kk = min(4, n_joints_total)
    top = np.argsort(dense, axis=1)[:, ::-1][:, :kk]
    vals = np.take_along_axis(dense, top, axis=1)
    vals = vals / (vals.sum(axis=1, keepdims=True) + 1e-12)

    out_joints = np.zeros((N, 4), dtype=np.uint16)
    out_weights = np.zeros((N, 4), dtype=np.float32)
    out_joints[:, :kk] = top
    out_weights[:, :kk] = vals
    return out_joints, out_weights
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# ---------------------------------------------------------------------------
|
| 320 |
+
# GLB rebuild
|
| 321 |
+
# ---------------------------------------------------------------------------
|
| 322 |
+
|
| 323 |
+
def _pack_buffer_view(data_bytes: bytes, target: list, byte_offset: int,
                      byte_stride: Optional[int] = None) -> Tuple[int, int]:
    """
    Build a glTF BufferView describing `data_bytes` at `byte_offset` in
    buffer 0, and return (buffer_view, new_byte_offset).

    NOTE(review): despite the return annotation, the first element is the
    BufferView *object*, not an index, and the `target` list parameter is
    never touched. The GLB writer below uses its own local `add_chunk`
    helper instead of this function — presumably dead/legacy code; confirm
    before relying on it or removing it.
    """
    bv = pygltflib.BufferView(
        buffer=0,
        byteOffset=byte_offset,
        byteLength=len(data_bytes),
    )
    # byteStride is only set when truthy — stride 0 / None are both omitted.
    if byte_stride:
        bv.byteStride = byte_stride
    return bv, byte_offset + len(data_bytes)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def _make_accessor(component_type: int, type_str: str, count: int,
                   bv_idx: int, min_vals=None, max_vals=None) -> pygltflib.Accessor:
    """Build a glTF Accessor over buffer view `bv_idx` at offset 0,
    attaching optional min/max bounds coerced to lists of floats."""
    accessor = pygltflib.Accessor(
        bufferView=bv_idx,
        byteOffset=0,
        componentType=component_type,
        count=count,
        type=type_str,
    )
    # Bounds are optional; attach whichever were supplied.
    for attr, bounds in (("min", min_vals), ("max", max_vals)):
        if bounds is not None:
            setattr(accessor, attr, [float(v) for v in bounds])
    return accessor
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# glTF componentType enum values, aliased from pygltflib for readability.
FLOAT32 = pygltflib.FLOAT           # 5126
UINT16 = pygltflib.UNSIGNED_SHORT   # 5123
UINT32 = pygltflib.UNSIGNED_INT     # 5125
UBYTE = pygltflib.UNSIGNED_BYTE     # 5121
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def transplant_face(
    body_glb_path: str,
    pshuman_mesh_path: str,
    output_path: str,
    head_bone_substrings: Tuple[str, ...] = ("head", "Head", "skull", "Skull"),
    weight_threshold: float = 0.35,
    retract_amount: float = 0.004,  # metres — how far to push face verts inward
    knn: int = 5,
):
    """
    Main entry point: transplant a PSHuman face onto a rigged UniRig GLB.

    Pipeline: identify head joints → select face vertices by skin weight →
    align the head-cropped PSHuman mesh onto them → transfer skin weights
    by KNN → retract the original face slightly inward (to avoid z-fighting
    with the new face) → write a two-primitive GLB.

    Parameters
    ----------
    body_glb_path : rigged UniRig GLB
    pshuman_mesh_path : PSHuman output mesh (OBJ / GLB / PLY)
    output_path : result GLB path
    head_bone_substrings : bone name fragments that identify head joints
    weight_threshold : head-weight sum above which a vertex is "face"
    retract_amount : metres to push face verts inward to avoid z-fight
    knn : neighbours for weight transfer

    Raises
    ------
    RuntimeError
        When no head bones are found, or too few face vertices pass the
        weight threshold.
    """
    print(f"[face_transplant] Loading rigged GLB: {body_glb_path}")
    glb = GLBMesh(body_glb_path)
    print(f"  Verts: {len(glb.verts)}  Faces: {len(glb.faces)}")
    print(f"  Bones ({len(glb.joint_names)}): {', '.join(glb.joint_names[:8])} ...")

    # 1. Identify head joints (named match, else positional fallback inside
    #    head_bone_indices).
    head_ji = glb.head_bone_indices(substrings=head_bone_substrings)
    if not head_ji:
        raise RuntimeError(
            f"No head bones found with substrings {head_bone_substrings}.\n"
            f"Available bones: {glb.joint_names}"
        )
    print(f"  Head joints ({len(head_ji)}): {[glb.joint_names[i] for i in head_ji]}")

    # 2. Find face/head vertices by accumulated head-bone skin weight.
    face_mask = find_face_verts(glb, head_ji, weight_threshold=weight_threshold)
    print(f"  Face verts: {face_mask.sum()} / {len(glb.verts)}")

    # Sanity floor: scales with mesh size but is capped at 10 vertices.
    min_face_verts = max(3, min(10, len(glb.verts) // 4))
    if face_mask.sum() < min_face_verts:
        raise RuntimeError(
            f"Only {face_mask.sum()} face vertices found (need >= {min_face_verts}) — "
            f"try lowering --weight-threshold (current: {weight_threshold})"
        )

    # 3. Load + align PSHuman mesh onto the UniRig face bbox.
    face_verts_ur = glb.verts[face_mask]
    ps_mesh = load_and_align_pshuman(pshuman_mesh_path, face_verts_ur)
    ps_verts = np.array(ps_mesh.vertices, dtype=np.float32)
    ps_faces = np.array(ps_mesh.faces, dtype=np.int32)
    ps_uvs = None
    if hasattr(ps_mesh.visual, "uv") and ps_mesh.visual.uv is not None:
        ps_uvs = np.array(ps_mesh.visual.uv, dtype=np.float32)

    # 4. Transfer weights: donor = UniRig face verts, recipient = PSHuman verts
    print("[face_transplant] Transferring skinning weights via KNN ...")
    ps_joints, ps_weights = transfer_weights(
        donor_verts = glb.verts[face_mask].astype(np.float64),
        donor_joints = glb.joints4[face_mask],
        donor_weights = glb.weights4[face_mask],
        recipient_verts = ps_verts.astype(np.float64),
        k = knn,
    )
    print(f"  Done. Head joint coverage in PSHuman: "
          f"{(np.isin(ps_joints[:, 0], head_ji)).mean() * 100:.1f}% primary bone is head")

    # 5. Retract UniRig face verts inward (push along −normal) so the new
    #    face sits slightly proud of the old one.
    body_verts = glb.verts.copy()
    if glb.normals is not None:
        body_verts[face_mask] -= glb.normals[face_mask] * retract_amount
    else:
        # No stored normals: push toward the face centroid instead.
        centroid = body_verts[face_mask].mean(axis=0)
        dirs = centroid - body_verts[face_mask]
        norms = np.linalg.norm(dirs, axis=1, keepdims=True) + 1e-9
        body_verts[face_mask] += (dirs / norms) * retract_amount

    # 6. Rebuild GLB with both primitives.
    print("[face_transplant] Rebuilding GLB ...")
    _write_transplanted_glb(
        source_gltf = glb,
        body_verts = body_verts,
        ps_verts = ps_verts,
        ps_faces = ps_faces,
        ps_uvs = ps_uvs,
        ps_joints = ps_joints,
        ps_weights = ps_weights,
        output_path = output_path,
    )
    print(f"[face_transplant] Saved -> {output_path}")
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
# ---------------------------------------------------------------------------
|
| 455 |
+
# GLB writer
|
| 456 |
+
# ---------------------------------------------------------------------------
|
| 457 |
+
|
| 458 |
+
def _write_transplanted_glb(
    source_gltf: GLBMesh,
    body_verts: np.ndarray,
    ps_verts: np.ndarray,
    ps_faces: np.ndarray,
    ps_uvs: Optional[np.ndarray],
    ps_joints: np.ndarray,
    ps_weights: np.ndarray,
    output_path: str,
):
    """
    Copy the source GLB structure, replace mesh primitive 0 vertex data,
    and append a new primitive for the PSHuman face.

    The entire binary buffer is rebuilt from scratch: only geometry data is
    re-packed, so embedded textures are first rescued into data URIs, and
    animations (which reference the old accessor indices) are dropped.
    """
    import copy  # NOTE(review): unused — presumably leftover; confirm before removing
    # Reload from disk to get a fresh structure to mutate. `source_gltf.path`
    # is assumed to be set by GLBMesh.__init__ — TODO confirm.
    gltf = pygltflib.GLTF2().load(source_gltf.path)
    gltf._path = source_gltf.path

    # ------------------------------------------------------------------
    # Preserve embedded images as data URIs BEFORE we wipe buffer views.
    # The binary blob rebuild below only contains geometry; any image data
    # referenced via bufferView would otherwise be lost.
    # ------------------------------------------------------------------
    try:
        blob = bytes(gltf.binary_blob())
    except Exception:
        blob = b""
    for img in gltf.images:
        if img.bufferView is not None and img.uri is None and blob:
            bv = gltf.bufferViews[img.bufferView]
            img_bytes = blob[bv.byteOffset: bv.byteOffset + bv.byteLength]
            mime = img.mimeType or "image/png"
            img.uri = "data:{};base64,{}".format(mime, base64.b64encode(img_bytes).decode())
            img.bufferView = None

    # ------------------------------------------------------------------
    # We will rebuild the entire binary buffer from scratch.
    # Collect all data chunks; track buffer views + accessors.
    # ------------------------------------------------------------------
    chunks: List[bytes] = []
    bviews: List[pygltflib.BufferView] = []
    accors: List[pygltflib.Accessor] = []
    byte_offset = 0

    def add_chunk(data: bytes, component_type: int, type_str: str, count: int,
                  min_v=None, max_v=None, stride: int = None) -> int:
        """Append data, create buffer view + accessor, return accessor index."""
        nonlocal byte_offset
        bv = pygltflib.BufferView(buffer=0, byteOffset=byte_offset, byteLength=len(data))
        if stride:
            bv.byteStride = stride
        bviews.append(bv)
        bv_idx = len(bviews) - 1

        acc = pygltflib.Accessor(
            bufferView=bv_idx,
            byteOffset=0,
            componentType=component_type,
            count=count,
            type=type_str,
        )
        if min_v is not None:
            acc.min = [float(x) for x in np.atleast_1d(min_v)]
        if max_v is not None:
            acc.max = [float(x) for x in np.atleast_1d(max_v)]
        accors.append(acc)
        acc_idx = len(accors) - 1

        chunks.append(data)
        byte_offset += len(data)
        return acc_idx

    # ------------------------------------------------------------------
    # Primitive 0 — UniRig body (retracted face verts)
    # ------------------------------------------------------------------
    body_v = body_verts.astype(np.float32)
    body_i = source_gltf.faces.astype(np.uint32).flatten()
    body_n = (source_gltf.normals.astype(np.float32)
              if source_gltf.normals is not None else None)
    body_uv = (source_gltf.uvs.astype(np.float32)
               if source_gltf.uvs is not None else None)
    body_j = source_gltf.joints4.astype(np.uint16)
    body_w = source_gltf.weights4.astype(np.float32)

    # indices
    bi_idx = add_chunk(body_i.tobytes(), UINT32, "SCALAR", len(body_i),
                       min_v=[int(body_i.min())], max_v=[int(body_i.max())])
    # positions (min/max bounds are required by the glTF spec for POSITION)
    bv_idx = add_chunk(body_v.tobytes(), FLOAT32, "VEC3", len(body_v),
                       min_v=body_v.min(axis=0), max_v=body_v.max(axis=0))
    body_attrs = pygltflib.Attributes(POSITION=bv_idx)
    if body_n is not None:
        body_attrs.NORMAL = add_chunk(body_n.tobytes(), FLOAT32, "VEC3", len(body_n))
    if body_uv is not None:
        body_attrs.TEXCOORD_0 = add_chunk(body_uv.tobytes(), FLOAT32, "VEC2", len(body_uv))
    if body_j is not None:
        body_attrs.JOINTS_0 = add_chunk(body_j.tobytes(), UINT16, "VEC4", len(body_j))
        body_attrs.WEIGHTS_0 = add_chunk(body_w.tobytes(), FLOAT32, "VEC4", len(body_w))

    prim0 = pygltflib.Primitive(
        attributes=body_attrs,
        indices=bi_idx,
        material=source_gltf.material_idx,
        mode=4,  # TRIANGLES
    )

    # ------------------------------------------------------------------
    # Primitive 1 — PSHuman face
    # ------------------------------------------------------------------
    ps_v = ps_verts.astype(np.float32)
    ps_i = ps_faces.astype(np.uint32).flatten()
    ps_j4 = ps_joints.astype(np.uint16)
    ps_w4 = ps_weights.astype(np.float32)

    # PSHuman material — reuse body material for now (same texture look)
    # If PSHuman has its own texture, you'd add a new material here.
    face_mat_idx = source_gltf.material_idx

    fi_idx = add_chunk(ps_i.tobytes(), UINT32, "SCALAR", len(ps_i),
                       min_v=[int(ps_i.min())], max_v=[int(ps_i.max())])
    fv_idx = add_chunk(ps_v.tobytes(), FLOAT32, "VEC3", len(ps_v),
                       min_v=ps_v.min(axis=0), max_v=ps_v.max(axis=0))
    face_attrs = pygltflib.Attributes(POSITION=fv_idx)
    if ps_uvs is not None:
        face_attrs.TEXCOORD_0 = add_chunk(ps_uvs.tobytes(), FLOAT32, "VEC2", len(ps_uvs))
    face_attrs.JOINTS_0 = add_chunk(ps_j4.tobytes(), UINT16, "VEC4", len(ps_j4))
    face_attrs.WEIGHTS_0 = add_chunk(ps_w4.tobytes(), FLOAT32, "VEC4", len(ps_w4))

    prim1 = pygltflib.Primitive(
        attributes=face_attrs,
        indices=fi_idx,
        material=face_mat_idx,
        mode=4,
    )

    # ------------------------------------------------------------------
    # Patch gltf structure
    # ------------------------------------------------------------------
    # Find or create the mesh that uses our skin
    mesh_node = gltf.nodes[source_gltf.mesh_node_idx]
    old_mesh_idx = mesh_node.mesh

    new_mesh = pygltflib.Mesh(
        name="body_with_pshuman_face",
        primitives=[prim0, prim1],
    )

    # Replace in place when the old mesh slot exists, otherwise append.
    if old_mesh_idx is not None and old_mesh_idx < len(gltf.meshes):
        gltf.meshes[old_mesh_idx] = new_mesh
        target_mesh_idx = old_mesh_idx
    else:
        gltf.meshes.append(new_mesh)
        target_mesh_idx = len(gltf.meshes) - 1

    mesh_node.mesh = target_mesh_idx

    # Replace buffer views and accessors wholesale — everything referencing
    # the old indices (images handled above, animations dropped below).
    gltf.bufferViews = bviews
    gltf.accessors = accors

    # Rewrite buffer
    combined = b"".join(chunks)
    # Pad to 4-byte alignment (GLB binary chunks must be 4-byte aligned)
    if len(combined) % 4:
        combined += b"\x00" * (4 - len(combined) % 4)

    gltf.buffers = [pygltflib.Buffer(byteLength=len(combined))]
    gltf.set_binary_blob(combined)

    # Drop stale animation (it referenced old accessor indices)
    # The user can re-add animation later if needed.
    gltf.animations = []

    gltf.save(output_path)
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
# ---------------------------------------------------------------------------
|
| 636 |
+
# CLI
|
| 637 |
+
# ---------------------------------------------------------------------------
|
| 638 |
+
|
| 639 |
+
def main():
    """Command-line entry point: parse arguments and run transplant_face."""
    cli = argparse.ArgumentParser(description="Transplant PSHuman face into UniRig GLB")
    cli.add_argument("--body", required=True, help="Rigged UniRig GLB")
    cli.add_argument("--face", required=True, help="PSHuman mesh (OBJ/GLB/PLY)")
    cli.add_argument("--output", required=True, help="Output GLB path")
    cli.add_argument("--head-bones", default="head,Head,skull,Skull",
                     help="Comma-separated bone name substrings for head detection")
    cli.add_argument("--weight-threshold", type=float, default=0.35,
                     help="Minimum head-bone weight sum to classify a vert as face")
    cli.add_argument("--retract", type=float, default=0.004,
                     help="Metres to retract UniRig face verts inward (default 0.004)")
    cli.add_argument("--knn", type=int, default=5,
                     help="K nearest neighbours for weight transfer")
    opts = cli.parse_args()

    # Split the bone-name fragments, trimming stray whitespace around commas.
    fragments = tuple(frag.strip() for frag in opts.head_bones.split(","))

    transplant_face(
        body_glb_path=opts.body,
        pshuman_mesh_path=opts.face,
        output_path=opts.output,
        head_bone_substrings=fragments,
        weight_threshold=opts.weight_threshold,
        retract_amount=opts.retract,
        knn=opts.knn,
    )


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
head_replace.py — Replace TripoSG head with DECA-reconstructed head at mesh level.
|
| 3 |
+
|
| 4 |
+
Requires: trimesh, numpy, scipy, cv2, torch (+ face-alignment via DECA deps)
|
| 5 |
+
Optional: pymeshlab (for mesh clean-up)
|
| 6 |
+
|
| 7 |
+
Usage (standalone):
|
| 8 |
+
python head_replace.py --body /tmp/triposg_textured.glb \
|
| 9 |
+
--face /path/to/face.jpg \
|
| 10 |
+
--out /tmp/head_replaced.glb
|
| 11 |
+
|
| 12 |
+
Returns combined GLB with DECA head geometry + TripoSG body.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os, sys, argparse, warnings
|
| 16 |
+
warnings.filterwarnings('ignore')
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import cv2
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
# ──────────────────────────────────────────────────────────────────
|
| 23 |
+
# Patch DECA before importing it to avoid pytorch3d dependency
|
| 24 |
+
# ──────────────────────────────────────────────────────────────────
|
| 25 |
+
DECA_ROOT = '/root/DECA'
|
| 26 |
+
sys.path.insert(0, DECA_ROOT)
|
| 27 |
+
|
| 28 |
+
# Stub out the rasterizer so DECA doesn't try to import pytorch3d
|
| 29 |
+
import importlib, types  # NOTE(review): importlib appears unused in this block
# Fake out decalib's renderer module so importing decalib never reaches
# pytorch3d. Must be registered in sys.modules BEFORE decalib is imported.
_fake_renderer = types.ModuleType('decalib.utils.renderer')
_fake_renderer.set_rasterizer = lambda t='pytorch3d': None

class _FakeRender:
    """No-op renderer — we only need the mesh, not rendered images."""
    def __init__(self, *a, **kw): pass
    def to(self, *a, **kw): return self
    def __call__(self, *a, **kw): return {'images': None, 'alpha_images': None,
                                          'normal_images': None, 'grid': None,
                                          'transformed_normals': None, 'normals': None}
    def render_shape(self, *a, **kw): return None, None, None, None
    def world2uv(self, *a, **kw): return None
    def add_SHlight(self, *a, **kw): return None

_fake_renderer.SRenderY = _FakeRender
sys.modules['decalib.utils.renderer'] = _fake_renderer

# Patch deca.py: make _setup_renderer a no-op when renderer not available
from decalib import deca as _deca_mod
_orig_setup = _deca_mod.DECA._setup_renderer

def _patched_setup(self, model_cfg):
    # Try the real renderer setup first; on ANY failure fall back to the
    # stub renderer and hand-load only the UV assets the mesh path needs.
    try:
        _orig_setup(self, model_cfg)
    except Exception as e:
        print(f'[head_replace] Renderer disabled ({e})')
        self.render = _FakeRender()
        # Still load mask / displacement data we need for UV baking
        from skimage.io import imread
        import torch, torch.nn.functional as F
        # Each asset load is best-effort: a missing file must not abort setup.
        try:
            mask = imread(model_cfg.face_eye_mask_path).astype(np.float32) / 255.
            mask = torch.from_numpy(mask[:, :, 0])[None, None, :, :].contiguous()
            self.uv_face_eye_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size])
            mask2 = imread(model_cfg.face_mask_path).astype(np.float32) / 255.
            mask2 = torch.from_numpy(mask2[:, :, 0])[None, None, :, :].contiguous()
            self.uv_face_mask = F.interpolate(mask2, [model_cfg.uv_size, model_cfg.uv_size])
        except Exception:
            pass
        try:
            fixed_dis = np.load(model_cfg.fixed_displacement_path)
            self.fixed_uv_dis = torch.tensor(fixed_dis).float()
        except Exception:
            pass
        try:
            mean_tex_np = imread(model_cfg.mean_tex_path).astype(np.float32) / 255.
            mean_tex = torch.from_numpy(mean_tex_np.transpose(2, 0, 1))[None]
            self.mean_texture = F.interpolate(mean_tex, [model_cfg.uv_size, model_cfg.uv_size])
        except Exception:
            pass
        try:
            self.dense_template = np.load(model_cfg.dense_template_path,
                                          allow_pickle=True, encoding='latin1').item()
        except Exception:
            pass

_deca_mod.DECA._setup_renderer = _patched_setup
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ──────────────────────────────────────────────────────────────────
# FLAME mesh: parse head_template.obj for UV map
# ──────────────────────────────────────────────────────────────────
def _load_flame_template(obj_path=os.path.join(DECA_ROOT, 'data', 'head_template.obj')):
    """Return (verts, faces, uv_verts, uv_faces) from head_template.obj.

    Only ``v``, ``vt`` and triangular ``f`` records are parsed; non-triangle
    faces are skipped (the FLAME template is fully triangulated).

    Returns:
        verts    : (N, 3) float32 vertex positions
        faces_v  : (M, 3) int32   vertex indices per face
        uv_verts : (K, 2) float32 UV coordinates
        faces_uv : (M, 3) int32   UV indices per face
    """
    verts, uv_verts = [], []
    faces_v, faces_uv = [], []
    # Fix: context manager closes the file deterministically (the original
    # `for line in open(obj_path)` left the handle to the GC).
    with open(obj_path) as fh:
        for line in fh:
            t = line.split()
            if not t:
                continue
            if t[0] == 'v':
                verts.append([float(t[1]), float(t[2]), float(t[3])])
            elif t[0] == 'vt':
                uv_verts.append([float(t[1]), float(t[2])])
            elif t[0] == 'f':
                vi, uvi = [], []
                for tok in t[1:]:
                    # OBJ face tokens look like "v", "v/vt" or "v/vt/vn";
                    # a missing vt index defaults to UV 0.
                    parts = tok.split('/')
                    vi.append(int(parts[0]) - 1)
                    uvi.append(int(parts[1]) - 1 if len(parts) > 1 and parts[1] else 0)
                if len(vi) == 3:
                    faces_v.append(vi)
                    faces_uv.append(uvi)
    return (np.array(verts, dtype=np.float32),
            np.array(faces_v, dtype=np.int32),
            np.array(uv_verts, dtype=np.float32),
            np.array(faces_uv, dtype=np.int32))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ──────────────────────────────────────────────────────────────────
# UV texture baking (software rasteriser, no pytorch3d needed)
# ──────────────────────────────────────────────────────────────────
def _bake_uv_texture(verts3d, faces_v, uv_verts, faces_uv, cam, face_img_bgr, tex_size=256):
    """
    Project face_img_bgr onto the FLAME UV map using an orthographic camera.

    verts3d      : (N, 3) FLAME vertices in world space
    faces_v      : (M, 3) vertex indices per triangle
    uv_verts     : (K, 2) UV coordinates in [0, 1]
    faces_uv     : (M, 3) UV indices per triangle
    cam          : (3,) = [scale, tx, ty] orthographic camera
    face_img_bgr : source portrait image (H, W, 3), BGR
    tex_size     : output texture resolution
    Returns      : (tex_size, tex_size, 3) uint8 texture (BGR)

    Fix: removed the never-used `tex_acc` / `tex_cnt` / `z_buf` buffers that
    were allocated (~1.5 MB at 256²) but never read or written.
    """
    from scipy.interpolate import griddata

    H, W = face_img_bgr.shape[:2]
    scale, tx, ty = float(cam[0]), float(cam[1]), float(cam[2])

    # Orthographic project: DECA formula = (vert_2D + [tx,ty]) * scale, then flip y
    proj = np.zeros((len(verts3d), 2), dtype=np.float32)
    proj[:, 0] = (verts3d[:, 0] + tx) * scale
    proj[:, 1] = -((verts3d[:, 1] + ty) * scale)  # y-flip matches DECA convention

    # Map to pixel coords: image spans proj ∈ [-1,1] → pixel [0, WH]
    img_pts = (proj + 1.0) * 0.5 * np.array([W, H], dtype=np.float32)  # (N, 2)

    # UV pixel coords
    uv_px = uv_verts * tex_size  # (K, 2)

    # Front-facing mask (z > threshold) — only bake visible faces
    z_face = verts3d[faces_v, 2].mean(axis=1)  # (M,) mean z per face
    front_mask = z_face >= -0.02               # keep front and side faces

    # For each front-facing face corner, record a (uv_px → img_pts) sample pair
    corners_uv = uv_px[faces_uv[front_mask]]    # (K, 3, 2)
    corners_img = img_pts[faces_v[front_mask]]  # (K, 3, 2)

    # Flatten to (K*3, 2)
    src_uv = corners_uv.reshape(-1, 2)    # UV pixel destination
    src_img = corners_img.reshape(-1, 2)  # image pixel source

    # Remove out-of-bounds image samples
    valid = ((src_img[:, 0] >= 0) & (src_img[:, 0] < W) &
             (src_img[:, 1] >= 0) & (src_img[:, 1] < H))
    src_uv = src_uv[valid]
    src_img = src_img[valid]

    # Sample the face image at src_img positions (nearest pixel)
    ix = np.clip(src_img[:, 0].astype(int), 0, W - 1)
    iy = np.clip(src_img[:, 1].astype(int), 0, H - 1)
    colours = face_img_bgr[iy, ix].astype(np.float32)  # (P, 3)

    # Clip UV destinations to texture bounds
    uv_dest = np.clip(src_uv, 0, tex_size - 1 - 1e-6).astype(np.float32)

    # Interpolate the sparse corner samples onto the full UV grid (griddata
    # acts as a fast splat; unsampled texels come back as NaN).
    grid_u, grid_v = np.meshgrid(np.arange(tex_size), np.arange(tex_size))
    grid_pts = np.column_stack([grid_u.ravel(), grid_v.ravel()])

    tex_baked = np.zeros((tex_size * tex_size, 3), dtype=np.float32)
    for ch in range(3):
        tex_baked[:, ch] = griddata(uv_dest, colours[:, ch], grid_pts,
                                    method='linear', fill_value=np.nan)
    tex_baked = tex_baked.reshape(tex_size, tex_size, 3)
    face_baked_mask = ~np.isnan(tex_baked[:, :, 0])

    # Base texture: mean_texture (skin tone fallback for unsampled regions)
    mean_tex_path = os.path.join(DECA_ROOT, 'data', 'mean_texture.jpg')
    if os.path.exists(mean_tex_path):
        mt = cv2.resize(cv2.imread(mean_tex_path), (tex_size, tex_size)).astype(np.float32)
    else:
        mt = np.full((tex_size, tex_size, 3), 180.0, dtype=np.float32)

    # Blend: baked face over mean texture
    result = mt.copy()
    result[face_baked_mask] = np.clip(tex_baked[face_baked_mask], 0, 255)
    return result.astype(np.uint8)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ──────────────────────────────────────────────────────────────────
# DECA inference
# ──────────────────────────────────────────────────────────────────
def run_deca(face_img_path, device='cuda'):
    """
    Run DECA on face_img_path.
    Returns (verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_img_bgr)

    verts_np : (5023, 3) posed FLAME vertices
    cam_np   : (3,) orthographic camera [scale, tx, ty]
    tex_bgr  : baked UV texture (see _bake_uv_texture)
    """
    import torch
    from decalib.deca import DECA
    from decalib.utils import config as cfg_module
    from decalib.datasets import datasets

    cfg = cfg_module.get_cfg_defaults()
    # Disable the learned texture model — we bake our own texture below.
    cfg.model.use_tex = False

    print('[DECA] Loading model...')
    deca = DECA(config=cfg, device=device)
    deca.eval()

    print('[DECA] Preprocessing image...')
    # TestData handles face detection + crop/align to 224x224.
    testdata = datasets.TestData(face_img_path)
    img_tensor = testdata[0]['image'].to(device)[None, ...]

    print('[DECA] Encoding...')
    with torch.no_grad():
        codedict = deca.encode(img_tensor, use_detail=False)
        # Decode FLAME geometry directly (skip DECA's full decode/rendering).
        verts, _, _ = deca.flame(
            shape_params=codedict['shape'],
            expression_params=codedict['exp'],
            pose_params=codedict['pose']
        )

    verts_np = verts[0].cpu().numpy()           # (5023, 3)
    cam_np = codedict['cam'][0].cpu().numpy()   # (3,)
    print(f'[DECA] Mesh: {verts_np.shape}, cam={cam_np}')

    # Load FLAME UV map
    _, faces_v, uv_verts, faces_uv = _load_flame_template()

    # Get face image for texture baking (use the cropped/aligned 224x224)
    # NOTE(review): assumes img_tensor is CHW float in [0,1] RGB — matches
    # DECA's TestData convention, verify if decalib is upgraded.
    img_np = (img_tensor[0].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

    print('[DECA] Baking UV texture...')
    tex_bgr = _bake_uv_texture(verts_np, faces_v, uv_verts, faces_uv, cam_np, img_bgr, tex_size=256)

    return verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_bgr
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# ──────────────────────────────────────────────────────────────────
|
| 256 |
+
# Mesh helpers
|
| 257 |
+
# ──────────────────────────────────────────────────────────────────
|
| 258 |
+
def _find_neck_height(mesh):
|
| 259 |
+
"""
|
| 260 |
+
Find the best neck cut height in a body mesh.
|
| 261 |
+
Strategy: in the top 40% of the mesh, find the local minimum of
|
| 262 |
+
cross-sectional area (the neck is narrower than the head).
|
| 263 |
+
Returns the y-value of the cut plane.
|
| 264 |
+
"""
|
| 265 |
+
verts = mesh.vertices
|
| 266 |
+
y_min, y_max = verts[:, 1].min(), verts[:, 1].max()
|
| 267 |
+
y_range = y_max - y_min
|
| 268 |
+
|
| 269 |
+
# Scan [80%, 87%] to find the neck-base narrowing below the face.
|
| 270 |
+
# The range [83%, 91%] was picking the crown taper instead of the neck.
|
| 271 |
+
y_start = y_min + y_range * 0.80
|
| 272 |
+
y_end = y_min + y_range * 0.87
|
| 273 |
+
steps = 20
|
| 274 |
+
ys = np.linspace(y_start, y_end, steps)
|
| 275 |
+
band = y_range * 0.015
|
| 276 |
+
|
| 277 |
+
r10_vals = []
|
| 278 |
+
for y in ys:
|
| 279 |
+
pts = verts[(verts[:, 1] >= y - band) & (verts[:, 1] <= y + band)]
|
| 280 |
+
if len(pts) < 6:
|
| 281 |
+
r10_vals.append(1.0); continue
|
| 282 |
+
xz = pts[:, [0, 2]]
|
| 283 |
+
cx, cz = xz.mean(0)
|
| 284 |
+
radii = np.sqrt((xz[:, 0] - cx)**2 + (xz[:, 1] - cz)**2)
|
| 285 |
+
r10_vals.append(float(np.percentile(radii, 10)))
|
| 286 |
+
|
| 287 |
+
from scipy.ndimage import uniform_filter1d
|
| 288 |
+
r10 = uniform_filter1d(np.array(r10_vals), size=3)
|
| 289 |
+
neck_idx = int(np.argmin(r10[2:-2])) + 2
|
| 290 |
+
neck_y = float(ys[neck_idx])
|
| 291 |
+
frac = (neck_y - y_min) / y_range
|
| 292 |
+
print(f'[neck] Cut height: {neck_y:.4f} (y_range {y_min:.3f}–{y_max:.3f}, frac={frac:.2f})')
|
| 293 |
+
return neck_y
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _weld_mesh(mesh):
    """
    Merge duplicate vertices (UV-split mesh → geometric mesh).
    Returns a new trimesh with welded vertices.

    Vertices closer than 1e-5 are considered identical. Uses a simple
    union-find over cKDTree proximity pairs, then a flattening pass, then
    remaps faces to the compacted vertex set.
    """
    import trimesh
    from scipy.spatial import cKDTree
    verts = mesh.vertices
    tree = cKDTree(verts)
    # Build mapping: each vertex → canonical representative
    N = len(verts)
    mapping = np.arange(N, dtype=np.int64)
    # All vertex pairs within welding tolerance.
    pairs = tree.query_pairs(r=1e-5)
    # Union step: link roots. No union-by-rank / path compression here —
    # the explicit flattening pass below resolves any chains.
    for a, b in pairs:
        root_a = int(mapping[a])
        root_b = int(mapping[b])
        while mapping[root_a] != root_a:
            root_a = int(mapping[root_a])
        while mapping[root_b] != root_b:
            root_b = int(mapping[root_b])
        if root_a != root_b:
            mapping[root_b] = root_a
    # Flatten chains
    for i in range(N):
        root = int(mapping[i])
        while mapping[root] != root:
            root = int(mapping[root])
        mapping[i] = root
    # Compact the mapping
    unique_ids = np.unique(mapping)
    compact = np.full(N, -1, dtype=np.int64)
    compact[unique_ids] = np.arange(len(unique_ids))
    # Faces re-indexed through root, then compacted id.
    new_faces = compact[mapping[mesh.faces]]
    new_verts = verts[unique_ids]
    # process=False: keep face order / avoid trimesh re-welding differently.
    return trimesh.Trimesh(vertices=new_verts, faces=new_faces, process=False)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def _cut_mesh_below(mesh, y_cut):
    """Keep only faces where all vertices are at or below y_cut. Preserves UV/texture."""
    import trimesh
    from trimesh.visual.texture import TextureVisuals
    # A face survives only when every one of its three corners is below the plane.
    below = mesh.vertices[:, 1] <= y_cut
    keep = np.all(below[mesh.faces], axis=1)
    kept_faces = mesh.faces[keep]
    # Compact the vertex array to just the vertices still referenced.
    survivors = np.unique(kept_faces)
    remap = np.full(len(mesh.vertices), -1, dtype=np.int64)
    remap[survivors] = np.arange(len(survivors))
    out = trimesh.Trimesh(vertices=mesh.vertices[survivors],
                          faces=remap[kept_faces], process=False)
    # Carry over per-vertex UVs and the material when the source mesh has them.
    uv = getattr(mesh.visual, 'uv', None)
    if uv is not None:
        out.visual = TextureVisuals(uv=uv[survivors], material=mesh.visual.material)
    return out
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def _extract_neck_ring_geometric(mesh, neck_y, n_pts=64, band_frac=0.02):
    """
    Extract a neck ring using topological boundary edges near neck_y.
    Falls back to angle-sorted vertices if topology is non-manifold.
    Works on welded (geometric) meshes.

    mesh      : trimesh with .vertices / .edges
    neck_y    : y-height of the cut plane
    n_pts     : number of points in the returned ring
    band_frac : half-width of the search band as a fraction of the y-range
    Returns   : (n_pts, 3) ordered ring points, snapped to y = neck_y
    Raises    : ValueError when too few vertices lie near neck_y
    """
    verts = mesh.vertices
    y_range = verts[:, 1].max() - verts[:, 1].min()
    band = y_range * band_frac

    # --- Try topological boundary near neck_y first ---
    # Boundary edges appear in exactly one face.
    edges = np.sort(mesh.edges, axis=1)
    u, c2 = np.unique(edges, axis=0, return_counts=True)
    be = u[c2 == 1]  # boundary edges

    # Keep boundary edges where BOTH endpoints are near neck_y
    v_near = np.abs(verts[:, 1] - neck_y) <= band * 2
    neck_be = be[v_near[be[:, 0]] & v_near[be[:, 1]]]

    if len(neck_be) >= 8:
        # Build adjacency and walk loop
        adj = {}
        for e in neck_be:
            adj.setdefault(int(e[0]), []).append(int(e[1]))
            adj.setdefault(int(e[1]), []).append(int(e[0]))
        # Find the largest connected loop by greedily walking neighbours.
        visited = set()
        loops = []
        for start in adj:
            if start in visited: continue
            loop = [start]; visited.add(start); prev = -1; cur = start
            # Bounded walk: at most one step per boundary edge (+1 for closure).
            for _ in range(len(neck_be) + 1):
                nbrs = [v for v in adj.get(cur, []) if v != prev]
                if not nbrs: break
                nxt = nbrs[0]
                if nxt == start: break   # loop closed
                if nxt in visited: break # hit another walk — stop
                visited.add(nxt); prev = cur; cur = nxt; loop.append(cur)
            loops.append(loop)
        if loops:
            best = max(loops, key=len)
            if len(best) >= 8:
                ring_pts = verts[best]
                # Snap all ring points to neck_y (smooth the cut plane)
                ring_pts = ring_pts.copy()
                ring_pts[:, 1] = neck_y
                return _resample_loop(ring_pts, n_pts)

    # --- Fallback: use inner-cluster (neck column) vertices in the band ---
    mask = (verts[:, 1] >= neck_y - band) & (verts[:, 1] <= neck_y + band)
    pts = verts[mask]
    if len(pts) < 8:
        raise ValueError(f'Too few vertices near neck_y={neck_y:.4f}: {len(pts)}')

    # Keep only inner-ring vertices (below 35th percentile radius from centroid)
    # This excludes the outer face/head surface and keeps only the neck column
    xz = pts[:, [0, 2]]
    cx, cz = xz.mean(0)
    radii = np.sqrt((xz[:, 0] - cx)**2 + (xz[:, 1] - cz)**2)
    thresh = np.percentile(radii, 35)
    inner_mask = radii <= thresh
    if inner_mask.sum() >= 8:
        pts = pts[inner_mask]
        # Recompute centroid on inner pts
        cx, cz = pts[:, [0, 2]].mean(0)

    # Sort by angle in XZ plane to get a consistent winding order.
    angles = np.arctan2(pts[:, 2] - cz, pts[:, 0] - cx)
    pts_sorted = pts[np.argsort(angles)]
    pts_sorted = pts_sorted.copy()
    pts_sorted[:, 1] = neck_y  # snap to cut plane
    return _resample_loop(pts_sorted, n_pts)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def _extract_boundary_loop(mesh):
|
| 429 |
+
"""
|
| 430 |
+
Extract the boundary edge loop (ordered) from a welded mesh.
|
| 431 |
+
Returns (N, 3) ordered vertex positions.
|
| 432 |
+
"""
|
| 433 |
+
# Find boundary edges (edges used by exactly one face)
|
| 434 |
+
edges = np.sort(mesh.edges, axis=1)
|
| 435 |
+
unique, counts = np.unique(edges, axis=0, return_counts=True)
|
| 436 |
+
boundary_edges = unique[counts == 1]
|
| 437 |
+
|
| 438 |
+
if len(boundary_edges) == 0:
|
| 439 |
+
raise ValueError('No boundary edges found — mesh may be closed')
|
| 440 |
+
|
| 441 |
+
# Build adjacency for boundary edges
|
| 442 |
+
adj = {}
|
| 443 |
+
for e in boundary_edges:
|
| 444 |
+
adj.setdefault(int(e[0]), []).append(int(e[1]))
|
| 445 |
+
adj.setdefault(int(e[1]), []).append(int(e[0]))
|
| 446 |
+
|
| 447 |
+
# Walk the longest connected loop
|
| 448 |
+
# Find all loops
|
| 449 |
+
visited = set()
|
| 450 |
+
loops = []
|
| 451 |
+
for start_v in adj:
|
| 452 |
+
if start_v in visited:
|
| 453 |
+
continue
|
| 454 |
+
loop = [start_v]
|
| 455 |
+
visited.add(start_v)
|
| 456 |
+
prev = -1
|
| 457 |
+
cur = start_v
|
| 458 |
+
for _ in range(len(boundary_edges) + 1):
|
| 459 |
+
nbrs = [v for v in adj.get(cur, []) if v != prev]
|
| 460 |
+
if not nbrs:
|
| 461 |
+
break
|
| 462 |
+
nxt = nbrs[0]
|
| 463 |
+
if nxt == start_v:
|
| 464 |
+
break
|
| 465 |
+
if nxt in visited:
|
| 466 |
+
break
|
| 467 |
+
visited.add(nxt)
|
| 468 |
+
prev = cur
|
| 469 |
+
cur = nxt
|
| 470 |
+
loop.append(cur)
|
| 471 |
+
loops.append(loop)
|
| 472 |
+
|
| 473 |
+
# Use the longest loop
|
| 474 |
+
best = max(loops, key=len)
|
| 475 |
+
return mesh.vertices[best]
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def _resample_loop(loop_pts, N):
|
| 479 |
+
"""Resample an ordered set of 3D points to exactly N evenly-spaced points."""
|
| 480 |
+
from scipy.interpolate import interp1d
|
| 481 |
+
# Arc-length parameterisation
|
| 482 |
+
diffs = np.diff(loop_pts, axis=0, prepend=loop_pts[-1:])
|
| 483 |
+
seg_lens = np.linalg.norm(diffs, axis=1)
|
| 484 |
+
t = np.cumsum(seg_lens)
|
| 485 |
+
t = np.insert(t, 0, 0)
|
| 486 |
+
t /= t[-1]
|
| 487 |
+
# Close the loop
|
| 488 |
+
t[-1] = 1.0
|
| 489 |
+
loop_closed = np.vstack([loop_pts, loop_pts[0]])
|
| 490 |
+
interp = interp1d(t, loop_closed, axis=0)
|
| 491 |
+
t_new = np.linspace(0, 1, N, endpoint=False)
|
| 492 |
+
return interp(t_new)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def _bridge_loops(loop_a, loop_b):
|
| 496 |
+
"""
|
| 497 |
+
Create a triangle strip bridging two ordered loops of equal length N.
|
| 498 |
+
loop_a, loop_b: (N, 3) vertex positions
|
| 499 |
+
Returns (verts, faces) — just the bridge strip as a trimesh-ready array.
|
| 500 |
+
"""
|
| 501 |
+
N = len(loop_a)
|
| 502 |
+
verts = np.vstack([loop_a, loop_b]) # (2N, 3) — a:0..N-1, b:N..2N-1
|
| 503 |
+
faces = []
|
| 504 |
+
for i in range(N):
|
| 505 |
+
j = (i + 1) % N
|
| 506 |
+
ai, aj = i, j
|
| 507 |
+
bi, bj = i + N, j + N
|
| 508 |
+
faces.append([ai, aj, bi])
|
| 509 |
+
faces.append([aj, bj, bi])
|
| 510 |
+
return verts, np.array(faces, dtype=np.int32)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
# ──────────────────────────────────────────────────────────────────
# DECA head → trimesh
# ──────────────────────────────────────────────────────────────────
def deca_to_trimesh(verts_np, faces_v, uv_verts, faces_uv, tex_bgr):
    """
    Assemble a trimesh.Trimesh from DECA outputs with UV texture.
    Uses per-vertex UV (averaged over face corners sharing each vertex).

    verts_np : (N, 3) FLAME vertices
    faces_v  : (M, 3) vertex indices
    uv_verts : (K, 2) UV coordinates
    faces_uv : (M, 3) UV indices
    tex_bgr  : baked texture, BGR uint8
    Returns  : textured trimesh (falls back to flat vertex colours if the
               texture cannot be attached)

    Improvement: the per-vertex UV averaging is now a vectorised
    np.add.at scatter-add instead of a Python loop over every face corner
    (3·M iterations) — identical result, much faster for the 10k-face FLAME mesh.
    """
    import trimesh
    from trimesh.visual.texture import TextureVisuals
    from trimesh.visual.material import PBRMaterial

    # Average face-corner UVs per vertex (unbuffered scatter-add handles
    # repeated vertex indices correctly, unlike plain fancy-index +=).
    N = len(verts_np)
    uv_sum = np.zeros((N, 2), dtype=np.float64)
    uv_cnt = np.zeros(N, dtype=np.int32)
    np.add.at(uv_sum, faces_v.ravel(), uv_verts[faces_uv.ravel()])
    np.add.at(uv_cnt, faces_v.ravel(), 1)
    uv_cnt = np.maximum(uv_cnt, 1)  # avoid division by zero for unreferenced verts
    uv_per_vert = (uv_sum / uv_cnt[:, None]).astype(np.float32)

    mesh = trimesh.Trimesh(vertices=verts_np, faces=faces_v, process=False)

    tex_rgb = cv2.cvtColor(tex_bgr, cv2.COLOR_BGR2RGB)
    tex_pil = Image.fromarray(tex_rgb)

    try:
        mat = PBRMaterial(baseColorTexture=tex_pil)
        mesh.visual = TextureVisuals(uv=uv_per_vert, material=mat)
        print(f'[deca_to_trimesh] UV attached: {uv_per_vert.shape}, tex={tex_rgb.shape}')
    except Exception as e:
        # Best-effort fallback: untextured skin-tone vertex colours.
        print(f'[deca_to_trimesh] UV attach failed ({e}) — using vertex colours')
        mesh.visual.vertex_colors = np.tile([200, 175, 155, 255], (len(verts_np), 1))

    return mesh
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
# ──────────────────────────────────────────────────────────────────
# Main head-replacement function
# ──────────────────────────────────────────────────────────────────
def replace_head(body_glb: str, face_img_path: str, out_glb: str,
                 device: str = 'cuda', bridge_n: int = 64):
    """
    Main entry point: replace the head of a body GLB with a DECA/FLAME head
    reconstructed from a portrait image, stitched at the neck.

    body_glb      : path to TripoSG textured GLB
    face_img_path : path to reference face image
    out_glb       : output path for combined GLB
    device        : torch device for DECA inference ('cuda' or 'cpu')
    bridge_n      : number of vertices in the stitching ring
    Returns       : out_glb (the path written)
    """
    import trimesh
    import torch  # NOTE(review): unused here — DECA imports torch itself; kept for safety

    # ── 1. Load body GLB ──────────────────────────────────────────
    print('[replace_head] Loading body GLB...')
    scene = trimesh.load(body_glb)
    if isinstance(scene, trimesh.Scene):
        # GLB may contain multiple geometries — merge them into one mesh.
        body_mesh = trimesh.util.concatenate(
            [g for g in scene.geometry.values() if isinstance(g, trimesh.Trimesh)]
        )
    else:
        body_mesh = scene

    print(f' Body: {len(body_mesh.vertices)} verts, {len(body_mesh.faces)} faces')

    # ── 1b. Weld body mesh (UV-split → geometric) ─────────────────
    print('[replace_head] Welding mesh vertices...')
    body_welded = _weld_mesh(body_mesh)
    print(f' Welded: {len(body_welded.vertices)} verts (was {len(body_mesh.vertices)})')

    # ── 2. Find neck cut height ───────────────────────────────────
    neck_y = _find_neck_height(body_welded)

    # ── 3. Cut body at neck ───────────────────────────────────────
    print('[replace_head] Cutting body at neck...')
    # Work on welded mesh for topology; keep original mesh for geometry export
    # NOTE(review): body_lower_welded is computed but never used below — candidate
    # for removal, or the neck-ring extraction was meant to use it; confirm intent.
    body_lower_welded = _cut_mesh_below(body_welded, neck_y)
    body_lower = _cut_mesh_below(body_mesh, neck_y)  # keeps original UV/texture
    print(f' Body lower: {len(body_lower.vertices)} verts')

    # Extract neck ring geometrically (robust for non-manifold UV-split meshes)
    body_neck_loop = _extract_neck_ring_geometric(body_welded, neck_y, n_pts=bridge_n)
    print(f' Body neck ring: {len(body_neck_loop)} pts (geometric)')

    # ── 4. Run DECA ───────────────────────────────────────────────
    print('[replace_head] Running DECA...')
    verts_np, cam_np, faces_v, uv_verts, faces_uv, tex_bgr = run_deca(face_img_path, device=device)

    # ── 5. Align DECA head to body coordinate system ─────────────
    # TripoSG body is roughly in [-1,1] world space (y-up)
    # DECA/FLAME space: head centered around origin, scale ≈ 1.5-2.5 units for full head
    # We need to:
    #   a) Scale the FLAME head to match body scale
    #   b) Position the FLAME head so its neck base aligns with body neck ring

    # Get the bottom of the FLAME head (neck area)
    # FLAME template: bottom vertices are the neck boundary ring
    flame_mesh_tmp = __import__('trimesh').Trimesh(vertices=verts_np, faces=faces_v, process=False)
    try:
        flame_neck_loop = _extract_boundary_loop(flame_mesh_tmp)
        print(f' FLAME neck ring (topology): {len(flame_neck_loop)} verts')
    except Exception as e:
        print(f' FLAME boundary loop failed ({e}), using geometric extraction')
        # Geometric fallback: bottom 5% of head vertices
        flame_neck_y = verts_np[:, 1].min() + (verts_np[:, 1].max() - verts_np[:, 1].min()) * 0.08
        flame_neck_loop = _extract_neck_ring_geometric(flame_mesh_tmp, flame_neck_y, n_pts=bridge_n)
        print(f' FLAME neck ring (geometric): {len(flame_neck_loop)} pts')

    # ── 5b. Compute head position using NECK RING centroid ───────────────
    # Directly align FLAME neck ring center → body neck ring center in all 3 axes.
    # This is robust regardless of body pose or tilt.
    body_neck_center = body_neck_loop.mean(axis=0)

    # Estimate head height from WELDED mesh crown (more reliable than UV-split mesh)
    welded_y_max = float(body_welded.vertices[:, 1].max())
    body_head_height = welded_y_max - neck_y

    flame_neck_center_unscaled = flame_neck_loop.mean(axis=0)
    flame_y_min = verts_np[:, 1].min()
    flame_y_max = verts_np[:, 1].max()
    flame_head_height = flame_y_max - flame_y_min

    print(f' Body neck center: {body_neck_center.round(4)}')
    print(f' Body head space: {body_head_height:.4f} (neck_y={neck_y:.4f}, crown_y={welded_y_max:.4f})')
    print(f' FLAME head height (unscaled): {flame_head_height:.4f}')
    print(f' FLAME neck center (unscaled): {flame_neck_center_unscaled.round(4)}')

    # Scale FLAME head to match body head height
    if flame_head_height > 1e-5:
        head_scale = body_head_height / flame_head_height
    else:
        head_scale = 1.0  # degenerate head — leave unscaled
    print(f' Head scale: {head_scale:.4f}')

    # Translate: FLAME neck ring center → body neck ring center in XZ,
    # FLAME mesh bottom (flame_y_min) → neck_y in Y.
    # This ensures the head fills the full space from neck_y to body crown.
    translate = np.array([
        body_neck_center[0] - flame_neck_center_unscaled[0] * head_scale,
        neck_y - flame_y_min * head_scale,
        body_neck_center[2] - flame_neck_center_unscaled[2] * head_scale,
    ])
    print(f' Translate: {translate.round(4)}')
    verts_aligned = verts_np * head_scale + translate
    print(f' FLAME aligned y={verts_aligned[:,1].min():.4f}→{verts_aligned[:,1].max():.4f}'
          f' x={verts_aligned[:,0].min():.4f}→{verts_aligned[:,0].max():.4f}'
          f' z={verts_aligned[:,2].min():.4f}→{verts_aligned[:,2].max():.4f}')

    # Extract FLAME neck loop after alignment (at the cut plane y=neck_y)
    flame_verts_aligned = verts_aligned
    flame_mesh_aligned = __import__('trimesh').Trimesh(
        vertices=flame_verts_aligned, faces=faces_v, process=False)
    try:
        flame_neck_loop_aligned = _extract_boundary_loop(flame_mesh_aligned)
        print(f' FLAME neck ring (topology): {len(flame_neck_loop_aligned)} verts')
    except Exception:
        # Geometric fallback at 5% above the aligned head's bottom.
        flame_neck_y_aligned = flame_verts_aligned[:, 1].min() + (
            flame_verts_aligned[:, 1].max() - flame_verts_aligned[:, 1].min()) * 0.05
        flame_neck_loop_aligned = _extract_neck_ring_geometric(
            flame_mesh_aligned, flame_neck_y_aligned, n_pts=bridge_n)
        print(f' FLAME neck ring (geometric): {len(flame_neck_loop_aligned)} pts')

    # Diagnostic only: compare ring radii to sanity-check the scale.
    flame_neck_r = np.linalg.norm(flame_neck_loop_aligned - flame_neck_loop_aligned.mean(0), axis=1).mean()
    body_neck_r = np.linalg.norm(body_neck_loop - body_neck_loop.mean(0), axis=1).mean()
    print(f' Body neck radius: {body_neck_r:.4f} FLAME neck radius (scaled): {flame_neck_r:.4f}')

    # ── 6. Resample both neck loops to bridge_n points ────────────
    body_loop_r = _resample_loop(body_neck_loop, bridge_n)
    flame_loop_r = _resample_loop(flame_neck_loop_aligned, bridge_n)

    # Ensure loops are oriented consistently (both CW or both CCW)
    # Compute signed area to check orientation
    def _loop_orientation(loop):
        c = loop.mean(0)
        t = loop - c
        cross = np.cross(t[:-1], t[1:])
        return float(np.sum(cross[:, 1]))  # y-component

    o_body = _loop_orientation(body_loop_r)
    o_flame = _loop_orientation(flame_loop_r)
    if (o_body > 0) != (o_flame > 0):
        flame_loop_r = flame_loop_r[::-1]

    # ── 7. Align loop starting points (minimise bridge twist) ─────
    # Match starting vertex: find flame loop point closest to body loop start
    dists = np.linalg.norm(flame_loop_r - body_loop_r[0], axis=1)
    best_offset = int(np.argmin(dists))
    flame_loop_r = np.roll(flame_loop_r, -best_offset, axis=0)

    # ── 8. Build bridge strip ─────────────────────────────────────
    bridge_verts, bridge_faces = _bridge_loops(body_loop_r, flame_loop_r)
    bridge_mesh = __import__('trimesh').Trimesh(vertices=bridge_verts, faces=bridge_faces, process=False)

    # ── 9. Combine: body_lower + bridge + FLAME head ──────────────
    # Build FLAME head mesh with texture
    head_mesh = deca_to_trimesh(flame_verts_aligned, faces_v, uv_verts, faces_uv, tex_bgr)

    # Combine all parts (single-mesh fallback for export; textures are lost
    # in this merged copy — the scene export below preserves them)
    combined = __import__('trimesh').util.concatenate([body_lower, bridge_mesh, head_mesh])
    combined = __import__('trimesh').Trimesh(
        vertices=combined.vertices,
        faces=combined.faces,
        process=False
    )

    # Try to copy body texture to combined if available
    try:
        if hasattr(body_lower.visual, 'material'):
            pass  # Keep per-mesh materials — export as scene
    except Exception:
        pass

    # ── 10. Export ────────────────────────────────────────────────
    print(f'[replace_head] Exporting combined mesh: {len(combined.vertices)} verts...')
    os.makedirs(os.path.dirname(out_glb) or '.', exist_ok=True)

    # Export as GLB scene with separate submeshes (preserves textures)
    try:
        import trimesh
        scene_out = trimesh.Scene()
        scene_out.add_geometry(body_lower, geom_name='body')
        scene_out.add_geometry(bridge_mesh, geom_name='bridge')
        scene_out.add_geometry(head_mesh, geom_name='head')
        scene_out.export(out_glb)
        print(f'[replace_head] Saved scene GLB: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')
    except Exception as e:
        print(f'[replace_head] Scene export failed ({e}), trying single mesh...')
        combined.export(out_glb)
        print(f'[replace_head] Saved GLB: {out_glb} ({os.path.getsize(out_glb)//1024} KB)')

    return out_glb
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
# ──────────────────────────────────────────────────────────────────
# CLI
# ──────────────────────────────────────────────────────────────────
if __name__ == '__main__':
    ap = argparse.ArgumentParser()
    ap.add_argument('--body', required=True, help='TripoSG body GLB path')
    ap.add_argument('--face', required=True, help='Reference face image path')
    ap.add_argument('--out', required=True, help='Output GLB path')
    ap.add_argument('--bridge', type=int, default=64, help='Bridge ring vertex count')
    ap.add_argument('--cpu', action='store_true', help='Use CPU instead of CUDA')
    args = ap.parse_args()

    # Device selection: torch is imported lazily (and only when needed) so
    # that `--cpu` works even on machines without a working torch install.
    # Replaces the non-idiomatic `__import__('torch')` one-liner.
    if args.cpu:
        device = 'cpu'
    else:
        import torch
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    replace_head(args.body, args.face, args.out, device=device, bridge_n=args.bridge)
|
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
pshuman_client.py
|
| 3 |
+
=================
|
| 4 |
+
Call PSHuman to generate a high-detail 3D face mesh from a portrait image.
|
| 5 |
+
|
| 6 |
+
Two modes:
|
| 7 |
+
- Direct (default when service_url is localhost): runs PSHuman inference.py
|
| 8 |
+
as a subprocess without going through Gradio HTTP. Avoids the gradio_client
|
| 9 |
+
API-info bug that affects the pshuman Gradio env.
|
| 10 |
+
- Remote: uses gradio_client to call a running pshuman_app.py service.
|
| 11 |
+
|
| 12 |
+
Usage (standalone)
|
| 13 |
+
------------------
|
| 14 |
+
python -m pipeline.pshuman_client \\
|
| 15 |
+
--image /path/to/portrait.png \\
|
| 16 |
+
--output /tmp/pshuman_face.obj \\
|
| 17 |
+
[--url http://remote-host:7862] # omit for direct/local mode
|
| 18 |
+
|
| 19 |
+
Requires: gradio-client (remote mode only)
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import glob
|
| 25 |
+
import os
|
| 26 |
+
import shutil
|
| 27 |
+
import subprocess
|
| 28 |
+
import time
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
# Default: assume running on the same instance (local)
|
| 32 |
+
_DEFAULT_URL = os.environ.get("PSHUMAN_URL", "http://localhost:7862")
|
| 33 |
+
|
| 34 |
+
# ── Paths (on the Vast instance) ──────────────────────────────────────────────
|
| 35 |
+
PSHUMAN_DIR = "/root/PSHuman"
|
| 36 |
+
CONDA_PYTHON = "/root/miniconda/envs/pshuman/bin/python"
|
| 37 |
+
CONFIG = f"{PSHUMAN_DIR}/configs/inference-768-6view.yaml"
|
| 38 |
+
HF_MODEL_DIR = f"{PSHUMAN_DIR}/checkpoints/PSHuman_Unclip_768_6views"
|
| 39 |
+
HF_MODEL_HUB = "pengHTYX/PSHuman_Unclip_768_6views"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _run_pshuman_direct(image_path: str, work_dir: str) -> str:
    """
    Run PSHuman inference.py directly as a subprocess.
    Returns path to the colored OBJ mesh.

    Stages the portrait as ``{work_dir}/input/face.png``, launches
    inference in the dedicated "pshuman" conda env, then globs for the
    resulting OBJ in both the requested ``save_dir`` and PSHuman's own
    CWD output directory (it sometimes ignores ``save_dir``).

    Raises
    ------
    RuntimeError        on a non-zero inference exit code
    FileNotFoundError   when no mesh file can be located afterwards
    subprocess.TimeoutExpired  if inference exceeds the 600 s cap
    """
    img_dir = os.path.join(work_dir, "input")
    out_dir = os.path.join(work_dir, "out")
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)

    # The input image's basename becomes the "scene" name PSHuman uses
    # in its output paths, so fix it to "face" for predictable globbing.
    scene = "face"
    dst = os.path.join(img_dir, f"{scene}.png")
    shutil.copy(image_path, dst)

    # Prefer the locally-downloaded checkpoint; fall back to the HF hub id.
    hf_model = HF_MODEL_DIR if Path(HF_MODEL_DIR).exists() else HF_MODEL_HUB

    cmd = [
        CONDA_PYTHON, f"{PSHUMAN_DIR}/inference.py",
        "--config", CONFIG,
        f"pretrained_model_name_or_path={hf_model}",
        f"validation_dataset.root_dir={img_dir}",
        f"save_dir={out_dir}",
        "validation_dataset.crop_size=740",
        "with_smpl=false",
        "num_views=7",
        "save_mode=rgb",
        "seed=42",
    ]

    print(f"[pshuman] Running direct inference: {' '.join(cmd[:4])} ...")
    t0 = time.time()

    # Set CUDA_HOME + extra include dirs so nvdiffrast/torch JIT can compile.
    # On Vast.ai, triposg conda env ships nvcc at bin/nvcc and CUDA headers
    # scattered across site-packages/nvidia/{pkg}/include/ directories.
    env = os.environ.copy()
    if "CUDA_HOME" not in env:
        _triposg = "/root/miniconda/envs/triposg"
        _targets = os.path.join(_triposg, "targets", "x86_64-linux")
        _nvcc_bin = os.path.join(_triposg, "bin")
        _cuda_home = _targets  # has include/cuda_runtime_api.h

        _nvvm_bin = os.path.join(_triposg, "nvvm", "bin")  # contains cicc
        _nvcc_real = os.path.join(_targets, "bin")  # contains nvcc (real one)

        # Only export CUDA_HOME if both the runtime headers and some nvcc exist,
        # otherwise the JIT compile would fail with a misleading CUDA_HOME.
        if (os.path.exists(os.path.join(_cuda_home, "include", "cuda_runtime_api.h"))
                and (os.path.exists(os.path.join(_nvcc_bin, "nvcc"))
                     or os.path.exists(os.path.join(_nvcc_real, "nvcc")))):
            env["CUDA_HOME"] = _cuda_home
            # Build PATH: nvvm/bin (cicc) + targets/.../bin (nvcc real) + conda bin (nvcc wrapper)
            path_parts = []
            if os.path.isdir(_nvvm_bin):
                path_parts.append(_nvvm_bin)
            if os.path.isdir(_nvcc_real):
                path_parts.append(_nvcc_real)
            path_parts.append(_nvcc_bin)
            env["PATH"] = ":".join(path_parts) + ":" + env.get("PATH", "")

            # Collect all nvidia sub-package include dirs (cusparse, cublas, etc.)
            _nvidia_site = os.path.join(_triposg, "lib", "python3.10",
                                        "site-packages", "nvidia")
            _extra_incs = []
            if os.path.isdir(_nvidia_site):
                import glob as _glob
                for _inc in _glob.glob(os.path.join(_nvidia_site, "*/include")):
                    if os.path.isdir(_inc):
                        _extra_incs.append(_inc)
            if _extra_incs:
                # CPATH is honored by gcc/nvcc as implicit -I directories.
                _sep = ":"
                _existing = env.get("CPATH", "")
                env["CPATH"] = _sep.join(_extra_incs) + (_sep + _existing if _existing else "")
            print(f"[pshuman] CUDA_HOME={_cuda_home}, {len(_extra_incs)} nvidia include dirs added")

    # capture_output=False: stream PSHuman's progress straight to our stdout.
    # timeout=600 is a hard 10-minute cap; TimeoutExpired propagates upward.
    proc = subprocess.run(
        cmd, cwd=PSHUMAN_DIR,
        capture_output=False,
        text=True,
        timeout=600,
        env=env,
    )
    elapsed = time.time() - t0
    print(f"[pshuman] Inference done in {elapsed:.1f}s (exit={proc.returncode})")

    if proc.returncode != 0:
        raise RuntimeError(f"PSHuman inference failed (exit {proc.returncode})")

    # Locate output OBJ — PSHuman may save relative to its CWD (/root/PSHuman/out/)
    # rather than to the specified save_dir, so check both locations.
    # Patterns are ordered most-specific first; "clr" (colored) meshes win.
    cwd_out_dir = os.path.join(PSHUMAN_DIR, "out", scene)
    patterns = [
        f"{out_dir}/{scene}/result_clr_scale4_{scene}.obj",
        f"{out_dir}/{scene}/result_clr_scale*_{scene}.obj",
        f"{out_dir}/**/*.obj",
        f"{cwd_out_dir}/result_clr_scale*_{scene}.obj",
        f"{cwd_out_dir}/*.obj",
        f"{PSHUMAN_DIR}/out/**/*.obj",
    ]
    obj_path = None
    for pat in patterns:
        hits = sorted(glob.glob(pat, recursive=True))
        if hits:
            colored = [h for h in hits if "clr" in h]
            obj_path = (colored or hits)[-1]
            break

    # Fallback 1: any mesh-like file under save_dir.
    if not obj_path:
        all_files = list(Path(out_dir).rglob("*"))
        objs = [str(f) for f in all_files if f.suffix in (".obj", ".ply", ".glb")]
        if objs:
            obj_path = objs[-1]
    # Fallback 2: any OBJ under PSHuman's CWD output directory.
    if not obj_path and Path(cwd_out_dir).exists():
        for f in Path(cwd_out_dir).rglob("*.obj"):
            obj_path = str(f)
            break
    if not obj_path:
        raise FileNotFoundError(
            f"No mesh output found in {out_dir}. "
            f"Files: {[str(f) for f in all_files[:20]]}"
        )

    print(f"[pshuman] Output mesh: {obj_path}")
    return obj_path
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def generate_pshuman_mesh(
    image_path: str,
    output_path: str,
    service_url: str = _DEFAULT_URL,
    timeout: float = 600.0,
) -> str:
    """
    Generate a PSHuman face mesh and save it to *output_path*.

    When service_url points to localhost, PSHuman inference.py is run directly
    (no Gradio HTTP, avoids gradio_client API-info bug).
    For remote URLs, gradio_client is used.

    Parameters
    ----------
    image_path : local PNG/JPG path of the portrait
    output_path : where to save the downloaded OBJ
    service_url : base URL of pshuman_app.py, or "direct" to skip HTTP
    timeout : seconds to wait for inference (used in remote mode)

    Returns
    -------
    output_path (convenience)

    Raises
    ------
    RuntimeError  when the remote service reports an error or returns no mesh
    ImportError   when remote mode is requested without gradio-client installed
    """
    import tempfile

    output_path = str(output_path)
    os.makedirs(Path(output_path).parent, exist_ok=True)

    # localhost / 127.0.0.1 / "direct" / empty all mean "run in-process".
    is_local = (
        "localhost" in service_url
        or "127.0.0.1" in service_url
        or service_url.strip().lower() == "direct"
        or not service_url.strip()
    )

    if is_local:
        # ── Direct mode: run subprocess ───────────────────────────────────────
        print(f"[pshuman] Direct mode (no HTTP) — running inference on {image_path}")
        work_dir = tempfile.mkdtemp(prefix="pshuman_direct_")
        obj_tmp = _run_pshuman_direct(image_path, work_dir)
    else:
        # ── Remote mode: call Gradio service ──────────────────────────────────
        try:
            from gradio_client import Client
        except ImportError:
            raise ImportError("pip install gradio-client")

        print(f"[pshuman] Connecting to {service_url}")
        client = Client(service_url)

        print(f"[pshuman] Submitting: {image_path}")
        # submit() + result(timeout=...) so the documented *timeout* parameter
        # is actually honored — client.predict() has no timeout and can hang
        # indefinitely on a stuck service.
        job = client.submit(
            image=image_path,
            api_name="/gradio_generate_face",
        )
        result = job.result(timeout=timeout)

        # The service may return (path, status), a dict, or a bare path.
        if isinstance(result, (list, tuple)):
            obj_tmp = result[0]
            status = result[1] if len(result) > 1 else "ok"
        elif isinstance(result, dict):
            obj_tmp = result.get("obj_path") or result.get("value")
            status = result.get("status", "ok")
        else:
            obj_tmp = result
            status = "ok"

        if not obj_tmp or "Error" in str(status):
            raise RuntimeError(f"PSHuman service error: {status}")

        # gradio_client may wrap file payloads in a dict with path/name keys.
        if isinstance(obj_tmp, dict):
            obj_tmp = obj_tmp.get("path") or obj_tmp.get("name") or str(obj_tmp)

    # ── Copy OBJ + companions to output location ───────────────────────────
    shutil.copy(str(obj_tmp), output_path)
    print(f"[pshuman] Saved OBJ -> {output_path}")

    # OBJ files reference sidecar materials/textures by relative name, so
    # bring any .mtl/.png/.jpg neighbors along (never clobbering existing files).
    src_dir = Path(str(obj_tmp)).parent
    out_dir = Path(output_path).parent
    for ext in ("*.mtl", "*.png", "*.jpg"):
        for f in src_dir.glob(ext):
            dest = out_dir / f.name
            if not dest.exists():
                shutil.copy(str(f), str(dest))

    return output_path
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# ---------------------------------------------------------------------------
|
| 257 |
+
# CLI
|
| 258 |
+
# ---------------------------------------------------------------------------
|
| 259 |
+
|
| 260 |
+
def main():
    """Parse CLI arguments and run a single PSHuman mesh generation."""
    cli = argparse.ArgumentParser(
        description="Generate PSHuman face mesh from portrait image"
    )
    cli.add_argument("--image", required=True, help="Portrait image path")
    cli.add_argument("--output", required=True, help="Output OBJ path")
    cli.add_argument(
        "--url", default=_DEFAULT_URL,
        help="PSHuman service URL, or 'direct' to run inference locally "
             "(default: http://localhost:7862 → auto-selects direct mode)",
    )
    cli.add_argument("--timeout", type=float, default=600.0)
    opts = cli.parse_args()

    generate_pshuman_mesh(
        image_path=opts.image,
        output_path=opts.output,
        service_url=opts.url,
        timeout=opts.timeout,
    )


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Render a GLB from four orthographic turntable views into one horizontal grid image.

Usage: python render_glb.py <input.glb> <output.png>
Requires the MV-Adapter checkout (nvdiffrast-based renderer) and a CUDA device.
"""
import sys, cv2
sys.path.insert(0, '/root/MV-Adapter')
import numpy as np, torch
from mvadapter.utils.mesh_utils import NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera, render

glb = sys.argv[1]
out = sys.argv[2]
device = 'cuda'
ctx = NVDiffRastContextWrapper(device=device, context_type='cuda')
mesh = load_mesh(glb, rescale=True, device=device)

# (label, azimuth°) per panel; labels are informational only.
view_specs = [('front', -90), ('right', -180), ('back', -270), ('left', 0)]
panels = []
for _label, azimuth in view_specs:
    cam = get_orthogonal_camera(elevation_deg=[0], distance=[1.8],
                                left=-0.55, right=0.55, bottom=-0.55, top=0.55,
                                azimuth_deg=[azimuth], device=device)
    rendered = render(ctx, mesh, cam, height=512, width=384,
                      render_attr=True, render_depth=False, render_normal=False, attr_background=0.15)
    # Float RGB in [0,1] → uint8 BGR for OpenCV.
    rgb_u8 = (rendered.attr[0].cpu().numpy() * 255).clip(0, 255).astype('uint8')
    panels.append(cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR))

# Concatenate the four panels side by side into one strip.
grid = np.concatenate(panels, axis=1)
cv2.imwrite(out, grid)
print(f'Saved {grid.shape[1]}x{grid.shape[0]} grid to {out}')
|
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tpose.py — T-pose a humanoid GLB using YOLO pose estimation.
|
| 3 |
+
|
| 4 |
+
Pipeline:
|
| 5 |
+
1. Render the mesh from front view (azimuth=-90)
|
| 6 |
+
2. Run YOLOv8-pose to get 17 COCO keypoints in render-space
|
| 7 |
+
3. Unproject keypoints through the orthographic camera to 3D
|
| 8 |
+
4. Build Blender armature with bones at detected 3D joint positions (current pose)
|
| 9 |
+
5. Auto-weight skin the mesh to this armature
|
| 10 |
+
6. Rotate arm/leg bones to T-pose, apply deformation, export
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
blender --background --python tpose.py -- <input.glb> <output.glb>
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import bpy, sys, math, mathutils, os, tempfile
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
# ── Args ─────────────────────────────────────────────────────────────────────
|
| 20 |
+
argv = sys.argv
|
| 21 |
+
argv = argv[argv.index("--") + 1:] if "--" in argv else []
|
| 22 |
+
if len(argv) < 2:
|
| 23 |
+
print("Usage: blender --background --python tpose.py -- input.glb output.glb")
|
| 24 |
+
sys.exit(1)
|
| 25 |
+
input_glb = argv[0]
|
| 26 |
+
output_glb = argv[1]
|
| 27 |
+
|
| 28 |
+
# ── Step 1: Render front view using nvdiffrast (outside Blender) ───────────────
|
| 29 |
+
# We do this via a subprocess call before Blender scene setup,
|
| 30 |
+
# using the triposg Python env which has MV-Adapter + nvdiffrast.
|
| 31 |
+
import subprocess, json
|
| 32 |
+
|
| 33 |
+
TRIPOSG_PYTHON = '/root/miniconda/envs/triposg/bin/python'
|
| 34 |
+
RENDER_SCRIPT = '/tmp/_tpose_render.py'
|
| 35 |
+
RENDER_OUT = '/tmp/_tpose_front.png'
|
| 36 |
+
KP_OUT = '/tmp/_tpose_kp.json'
|
| 37 |
+
|
| 38 |
+
render_code = r"""
|
| 39 |
+
import sys, json
|
| 40 |
+
sys.path.insert(0, '/root/MV-Adapter')
|
| 41 |
+
import numpy as np, cv2, torch
|
| 42 |
+
from mvadapter.utils.mesh_utils import (
|
| 43 |
+
NVDiffRastContextWrapper, load_mesh, get_orthogonal_camera, render,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
body_glb = sys.argv[1]
|
| 47 |
+
out_png = sys.argv[2]
|
| 48 |
+
|
| 49 |
+
device = 'cuda'
|
| 50 |
+
ctx = NVDiffRastContextWrapper(device=device, context_type='cuda')
|
| 51 |
+
mesh_mv = load_mesh(body_glb, rescale=True, device=device)
|
| 52 |
+
camera = get_orthogonal_camera(
|
| 53 |
+
elevation_deg=[0], distance=[1.8],
|
| 54 |
+
left=-0.55, right=0.55, bottom=-0.55, top=0.55,
|
| 55 |
+
azimuth_deg=[-90], device=device,
|
| 56 |
+
)
|
| 57 |
+
out = render(ctx, mesh_mv, camera, height=1024, width=768,
|
| 58 |
+
render_attr=True, render_depth=False, render_normal=False,
|
| 59 |
+
attr_background=0.5)
|
| 60 |
+
img_np = (out.attr[0].cpu().numpy() * 255).clip(0,255).astype('uint8')
|
| 61 |
+
cv2.imwrite(out_png, cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR))
|
| 62 |
+
print(f"Rendered to {out_png}")
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
with open(RENDER_SCRIPT, 'w') as f:
|
| 66 |
+
f.write(render_code)
|
| 67 |
+
|
| 68 |
+
print("[tpose] Rendering front view ...")
|
| 69 |
+
r = subprocess.run([TRIPOSG_PYTHON, RENDER_SCRIPT, input_glb, RENDER_OUT],
|
| 70 |
+
capture_output=True, text=True)
|
| 71 |
+
print(r.stdout.strip()); print(r.stderr[-500:] if r.stderr else '')
|
| 72 |
+
|
| 73 |
+
# ── Step 2: YOLO pose estimation ──────────────────────────────────────────────
|
| 74 |
+
# ── Step 2: YOLO pose estimation ──────────────────────────────────────────────
# Like the render step, this runs in the triposg env (which has ultralytics)
# via a generated helper script; keypoints come back as JSON fractions.
YOLO_SCRIPT = '/tmp/_tpose_yolo.py'
yolo_code = r"""
import sys, json
import cv2
from ultralytics import YOLO
import numpy as np

img_path = sys.argv[1]
kp_path = sys.argv[2]

model = YOLO('yolov8n-pose.pt')
img = cv2.imread(img_path)
H, W = img.shape[:2]

results = model(img, verbose=False)
if not results or results[0].keypoints is None:
    print("ERROR: no person detected"); sys.exit(1)

# Pick detection with highest confidence
kps_all = results[0].keypoints.data.cpu().numpy()  # (N, 17, 3)
confs = kps_all[:, :, 2].mean(axis=1)
best = kps_all[confs.argmax()]  # (17, 3): x, y, conf

# COCO 17 keypoints:
# 0=nose 1=left_eye 2=right_eye 3=left_ear 4=right_ear
# 5=left_shoulder 6=right_shoulder 7=left_elbow 8=right_elbow
# 9=left_wrist 10=right_wrist 11=left_hip 12=right_hip
# 13=left_knee 14=right_knee 15=left_ankle 16=right_ankle

names = ['nose','left_eye','right_eye','left_ear','right_ear',
         'left_shoulder','right_shoulder','left_elbow','right_elbow',
         'left_wrist','right_wrist','left_hip','right_hip',
         'left_knee','right_knee','left_ankle','right_ankle']

kp_dict = {}
for i, name in enumerate(names):
    x, y, c = best[i]
    kp_dict[name] = {'x': float(x)/W, 'y': float(y)/H, 'conf': float(c)}
    print(f"  {name}: ({x:.1f},{y:.1f}) conf={c:.2f}")

kp_dict['img_hw'] = [int(H), int(W)]
with open(kp_path, 'w') as f:
    json.dump(kp_dict, f)
print(f"Keypoints saved to {kp_path}")
"""

with open(YOLO_SCRIPT, 'w') as f:
    f.write(yolo_code)

print("[tpose] Running YOLO pose estimation ...")
r2 = subprocess.run([TRIPOSG_PYTHON, YOLO_SCRIPT, RENDER_OUT, KP_OUT],
                    capture_output=True, text=True)
print(r2.stdout.strip()); print(r2.stderr[-300:] if r2.stderr else '')

# kp_data is None when detection failed; downstream joint lookups then fall
# back to hard-coded height-fraction heuristics.
if not os.path.exists(KP_OUT):
    print("ERROR: YOLO failed — falling back to heuristic")
    kp_data = None
else:
    with open(KP_OUT) as f:
        kp_data = json.load(f)
|
| 134 |
+
|
| 135 |
+
# ── Step 3: Unproject render-space keypoints to 3D ────────────────────────────
|
| 136 |
+
# Orthographic camera: left=-0.55, right=0.55, bottom=-0.55, top=0.55
|
| 137 |
+
# Render: 768×1024. NDC x = 2*(px/W)-1, ndc y = 1-2*(py/H)
|
| 138 |
+
# World X = ndc_x * 0.55, World Y (mesh up) = ndc_y * 0.55
|
| 139 |
+
# We need 3D positions in the ORIGINAL mesh coordinate space.
|
| 140 |
+
# After Blender GLB import, original mesh Y → Blender Z, original Z → Blender -Y
|
| 141 |
+
|
| 142 |
+
def kp_to_3d(name, z_default=0.0):
    """Convert YOLO keypoint (image fraction) → Blender 3D coords.

    Returns None when detection is missing or below the 0.3 confidence floor.
    """
    if kp_data is None or name not in kp_data:
        return None
    entry = kp_data[name]
    if entry['conf'] < 0.3:
        return None
    # Image fractions → NDC: x grows rightward; image y is flipped so that
    # up in the image becomes +NDC y.
    ndc_x = 2.0 * entry['x'] - 1.0
    ndc_y = -(2.0 * entry['y'] - 1.0)
    # The orthographic frustum spans ±0.55 in both axes.
    mesh_x = ndc_x * 0.55
    mesh_y = ndc_y * 0.55  # mesh-space vertical axis
    # GLB import swaps axes: mesh Y (up) → Blender Z. Depth along Blender Y
    # is unobservable from a single front view, so it stays at z_default.
    return (mesh_x, z_default, mesh_y)
|
| 160 |
+
|
| 161 |
+
# Key joint positions in Blender space (eyes/ears skipped — not needed for rigging)
J = {}
for name in ['nose','left_shoulder','right_shoulder','left_elbow','right_elbow',
             'left_wrist','right_wrist','left_hip','right_hip',
             'left_knee','right_knee','left_ankle','right_ankle']:
    p = kp_to_3d(name)
    if p: J[name] = p

print(f"[tpose] Detected joints: {list(J.keys())}")

# ── Step 4: Set up Blender scene ──────────────────────────────────────────────
# Start from an empty scene so the GLB's mesh is the only object present.
bpy.ops.wm.read_factory_settings(use_empty=True)
bpy.ops.import_scene.gltf(filepath=input_glb)
bpy.context.view_layer.update()

mesh_obj = next((o for o in bpy.data.objects if o.type == 'MESH'), None)
if not mesh_obj:
    print("ERROR: no mesh"); sys.exit(1)

# World-space bounding box of the mesh: used as reference for fallback
# joint placement (Z is up after GLB import).
verts_w = np.array([mesh_obj.matrix_world @ v.co for v in mesh_obj.data.vertices])
z_min, z_max = verts_w[:,2].min(), verts_w[:,2].max()
x_c = (verts_w[:,0].min() + verts_w[:,0].max()) / 2
y_c = (verts_w[:,1].min() + verts_w[:,1].max()) / 2
H_mesh = z_max - z_min

# zh(frac): world-space height at a fraction of the mesh's total height.
def zh(frac): return z_min + frac * H_mesh
def jv(name, fallback_frac=None, fallback_x=0.0):
    """Get joint position from YOLO or use fallback.

    Detected joints take the mesh centroid y_c as depth (front view cannot
    observe it); fallbacks place the joint at a canonical height fraction
    with a lateral x offset. Returns None when neither source applies.
    """
    if name in J:
        x, y, z = J[name]
        return (x, y_c, z)  # use mesh y_c for depth
    if fallback_frac is not None:
        return (x_c + fallback_x, y_c, zh(fallback_frac))
    return None
|
| 195 |
+
|
| 196 |
+
# ── Step 5: Build armature in CURRENT pose ────────────────────────────────────
# Bones are placed where the character currently IS (detected pose), so the
# auto-weight skinning in Step 6 binds correctly; Step 7 then rotates to T-pose.
bpy.ops.object.armature_add(location=(x_c, y_c, zh(0.5)))
arm_obj = bpy.context.object
arm_obj.name = 'PoseRig'
arm = arm_obj.data

bpy.ops.object.mode_set(mode='EDIT')
eb = arm.edit_bones

def V(xyz): return mathutils.Vector(xyz)

def add_bone(name, head, tail, parent=None, connect=False):
    """Create an edit bone from head to tail, optionally parented/connected."""
    b = eb.new(name)
    b.head = V(head)
    b.tail = V(tail)
    if parent and parent in eb:
        b.parent = eb[parent]
        b.use_connect = connect
    return b

# Helper: midpoint
def mid(a, b): return tuple((a[i]+b[i])/2 for i in range(3))
def offset(p, dx=0, dy=0, dz=0): return (p[0]+dx, p[1]+dy, p[2]+dz)

# ── Spine / hips ─────────────────────────────────────────────────────────────
# Numeric fallbacks (height fraction, lateral offset) approximate a standing
# humanoid when YOLO missed the joint.
hip_L = jv('left_hip', 0.48, -0.07)
hip_R = jv('right_hip', 0.48, 0.07)
sh_L = jv('left_shoulder', 0.77, -0.20)
sh_R = jv('right_shoulder', 0.77, 0.20)
nose = jv('nose', 0.92)

hips_c = mid(hip_L, hip_R) if (hip_L and hip_R) else (x_c, y_c, zh(0.48))
sh_c = mid(sh_L, sh_R) if (sh_L and sh_R) else (x_c, y_c, zh(0.77))

add_bone('Hips', hips_c, offset(hips_c, dz=H_mesh*0.08))
add_bone('Spine', hips_c, offset(hips_c, dz=(sh_c[2]-hips_c[2])*0.5), 'Hips')
add_bone('Chest', offset(hips_c, dz=(sh_c[2]-hips_c[2])*0.5), sh_c, 'Spine', True)
if nose:
    # Neck/head placed proportionally between shoulder line and nose height.
    neck_z = sh_c[2] + (nose[2]-sh_c[2])*0.35
    head_z = sh_c[2] + (nose[2]-sh_c[2])*0.65
    add_bone('Neck', (x_c, y_c, neck_z), (x_c, y_c, head_z), 'Chest')
    add_bone('Head', (x_c, y_c, head_z), (x_c, y_c, nose[2]+H_mesh*0.05), 'Neck', True)
else:
    add_bone('Neck', sh_c, offset(sh_c, dz=H_mesh*0.06), 'Chest')
    add_bone('Head', offset(sh_c, dz=H_mesh*0.06), offset(sh_c, dz=H_mesh*0.14), 'Neck', True)

# ── Arms (placed at DETECTED current pose positions) ─────────────────────────
el_L = jv('left_elbow', 0.60, -0.30)
el_R = jv('right_elbow', 0.60, 0.30)
wr_L = jv('left_wrist', 0.45, -0.25)
wr_R = jv('right_wrist', 0.45, 0.25)

for side, sh, el, wr in (('L', sh_L, el_L, wr_L), ('R', sh_R, el_R, wr_R)):
    if not sh: continue
    # Missing elbow/wrist: extend downward by canonical limb-length fractions.
    el_pos = el if el else offset(sh, dz=-H_mesh*0.15)
    wr_pos = wr if wr else offset(el_pos, dz=-H_mesh*0.15)
    hand = offset(wr_pos, dz=(wr_pos[2]-el_pos[2])*0.4)
    add_bone(f'UpperArm.{side}', sh, el_pos, 'Chest')
    add_bone(f'ForeArm.{side}', el_pos, wr_pos, f'UpperArm.{side}', True)
    add_bone(f'Hand.{side}', wr_pos, hand, f'ForeArm.{side}', True)

# ── Legs ─────────────────────────────────────────────────────────────────────
kn_L = jv('left_knee', 0.25, -0.07)
kn_R = jv('right_knee', 0.25, 0.07)
an_L = jv('left_ankle', 0.04, -0.06)
an_R = jv('right_ankle', 0.04, 0.06)

for side, hp, kn, an in (('L', hip_L, kn_L, an_L), ('R', hip_R, kn_R, an_R)):
    if not hp: continue
    kn_pos = kn if kn else offset(hp, dz=-H_mesh*0.23)
    an_pos = an if an else offset(kn_pos, dz=-H_mesh*0.22)
    # Toe points forward (−Y) and slightly down from the ankle.
    toe = offset(an_pos, dy=-H_mesh*0.06, dz=-H_mesh*0.02)
    add_bone(f'UpperLeg.{side}', hp, kn_pos, 'Hips')
    add_bone(f'LowerLeg.{side}', kn_pos, an_pos, f'UpperLeg.{side}', True)
    add_bone(f'Foot.{side}', an_pos, toe, f'LowerLeg.{side}', True)

bpy.ops.object.mode_set(mode='OBJECT')
|
| 273 |
+
|
| 274 |
+
# ── Step 6: Skin mesh to armature ────────────────────────────────────────────
# ARMATURE_AUTO parents the mesh to the rig with bone-heat auto weights.
bpy.context.view_layer.objects.active = arm_obj
mesh_obj.select_set(True)
arm_obj.select_set(True)
bpy.ops.object.parent_set(type='ARMATURE_AUTO')
print("[tpose] Auto-weights applied")

# ── Step 7: Pose arms to T-pose ───────────────────────────────────────────────
# Compute per-arm rotation: from (current elbow - shoulder) direction → horizontal ±X
bpy.context.view_layer.objects.active = arm_obj
bpy.ops.object.mode_set(mode='POSE')

pb = arm_obj.pose.bones

def set_tpose_arm(side, sh_pos, el_pos):
    """Rotate one upper arm so it points straight out along ±X; zero the forearm."""
    if not sh_pos or not el_pos:
        return
    if f'UpperArm.{side}' not in pb:
        return
    # Current upper-arm direction in armature local space
    sx = -1 if side == 'L' else 1
    # T-pose direction: ±X horizontal
    tpose_dir = mathutils.Vector((sx, 0, 0))
    # Current bone direction (head→tail) in world space
    bone = arm_obj.data.bones[f'UpperArm.{side}']
    cur_dir = (bone.tail_local - bone.head_local).normalized()
    # Rotation needed in bone's local space
    # NOTE(review): rotation_difference is computed in armature space but
    # assigned as a pose-space quaternion — this assumes the bone's rest
    # orientation aligns with armature axes; confirm on off-axis rigs.
    rot_quat = cur_dir.rotation_difference(tpose_dir)
    pb[f'UpperArm.{side}'].rotation_mode = 'QUATERNION'
    pb[f'UpperArm.{side}'].rotation_quaternion = rot_quat

    # Straighten forearm along the same axis
    if f'ForeArm.{side}' in pb:
        pb[f'ForeArm.{side}'].rotation_mode = 'QUATERNION'
        pb[f'ForeArm.{side}'].rotation_quaternion = mathutils.Quaternion((1,0,0,0))

set_tpose_arm('L', sh_L, el_L)
set_tpose_arm('R', sh_R, el_R)

bpy.context.view_layer.update()
bpy.ops.object.mode_set(mode='OBJECT')

# ── Step 8: Apply armature modifier ──────────────────────────────────────────
# Bake the posed deformation into the mesh, then discard the rig so the
# exported GLB is a static T-posed mesh.
bpy.context.view_layer.objects.active = mesh_obj
mesh_obj.select_set(True)
for mod in mesh_obj.modifiers:
    if mod.type == 'ARMATURE':
        bpy.ops.object.modifier_apply(modifier=mod.name)
        print(f"[tpose] Applied modifier: {mod.name}")
        break

bpy.data.objects.remove(arm_obj, do_unlink=True)

# ── Step 9: Export ────────────────────────────────────────────────────────────
bpy.ops.export_scene.gltf(
    filepath=output_glb, export_format='GLB',
    export_texcoords=True, export_normals=True,
    export_materials='EXPORT', use_selection=False)
print(f"[tpose] Done → {output_glb}")
|
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# HuggingFace ZeroGPU Space — Gradio SDK [cache-bust:
|
| 2 |
spaces
|
| 3 |
numpy>=2
|
| 4 |
|
|
@@ -79,3 +79,15 @@ typeguard
|
|
| 79 |
sentencepiece
|
| 80 |
spandrel
|
| 81 |
imageio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HuggingFace ZeroGPU Space — Gradio SDK [cache-bust: 3]
|
| 2 |
spaces
|
| 3 |
numpy>=2
|
| 4 |
|
|
|
|
| 79 |
sentencepiece
|
| 80 |
spandrel
|
| 81 |
imageio
|
| 82 |
+
gradio_client
|
| 83 |
+
|
| 84 |
+
# FireRed / GGUF quantization
|
| 85 |
+
bitsandbytes
|
| 86 |
+
|
| 87 |
+
# Motion search + retargeting
|
| 88 |
+
filterpy
|
| 89 |
+
pytorch-lightning
|
| 90 |
+
lightning-utilities
|
| 91 |
+
webdataset
|
| 92 |
+
hydra-core
|
| 93 |
+
matplotlib
|
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
pytorch3d_minimal.py
|
| 3 |
+
====================
|
| 4 |
+
Drop-in replacement for the pytorch3d subset used by PSHuman's project_mesh.py
|
| 5 |
+
and mesh_utils.py. Uses nvdiffrast for GPU rasterization.
|
| 6 |
+
|
| 7 |
+
Implements:
|
| 8 |
+
- Meshes / TexturesVertex
|
| 9 |
+
- look_at_view_transform
|
| 10 |
+
- FoVOrthographicCameras / OrthographicCameras (orthographic projection only)
|
| 11 |
+
- RasterizationSettings / MeshRasterizer (via nvdiffrast)
|
| 12 |
+
- render_pix2faces_py3d (compatibility shim)
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
import math
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Texture / Mesh containers
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
class TexturesVertex:
    """Per-vertex colour features — minimal stand-in for pytorch3d's TexturesVertex."""

    def __init__(self, verts_features):
        # verts_features: one [N, C] feature tensor per mesh in the batch.
        self._feats = verts_features

    def _map(self, op):
        # Build a new texture by applying *op* to each per-mesh tensor.
        return TexturesVertex([op(feat) for feat in self._feats])

    def verts_features_packed(self):
        # Single-mesh assumption: "packed" is just the first entry.
        return self._feats[0]

    def clone(self):
        return self._map(lambda feat: feat.clone())

    def detach(self):
        return self._map(lambda feat: feat.detach())

    def to(self, device):
        self._feats = [feat.to(device) for feat in self._feats]
        return self
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Meshes:
    """Minimal batched triangle-mesh container (pytorch3d-compatible subset)."""

    def __init__(self, verts, faces, textures=None):
        self._verts = verts      # list of [N, 3] float tensors
        self._faces = faces      # list of [F, 3] long tensors
        self.textures = textures

    # ---- accessors --------------------------------------------------------
    def verts_padded(self):
        return torch.stack(self._verts)

    def faces_padded(self):
        return torch.stack(self._faces)

    def verts_packed(self):
        return self._verts[0]

    def faces_packed(self):
        return self._faces[0]

    def verts_list(self):
        return self._verts

    def faces_list(self):
        return self._faces

    def verts_normals_packed(self):
        """Area-weighted, renormalized per-vertex normals of the first mesh."""
        verts, faces = self._verts[0], self._faces[0]
        corner_a = verts[faces[:, 0]]
        corner_b = verts[faces[:, 1]]
        corner_c = verts[faces[:, 2]]
        face_n = F.normalize(
            torch.cross(corner_b - corner_a, corner_c - corner_a, dim=1), dim=1)
        vert_n = torch.zeros_like(verts)
        # Accumulate each face normal onto its three corner vertices.
        for corner in (0, 1, 2):
            idx = faces[:, corner:corner + 1].expand(-1, 3)
            vert_n.scatter_add_(0, idx, face_n)
        return F.normalize(vert_n, dim=1)

    # ---- device / copy ----------------------------------------------------
    def to(self, device):
        self._verts = [v.to(device) for v in self._verts]
        self._faces = [f.to(device) for f in self._faces]
        if self.textures is not None:
            self.textures.to(device)
        return self

    def clone(self):
        tex = None if self.textures is None else self.textures.clone()
        return Meshes([v.clone() for v in self._verts],
                      [f.clone() for f in self._faces],
                      textures=tex)

    def detach(self):
        tex = None if self.textures is None else self.textures.detach()
        return Meshes([v.detach() for v in self._verts],
                      [f.detach() for f in self._faces],
                      textures=tex)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ---------------------------------------------------------------------------
|
| 92 |
+
# Camera math (mirrors pytorch3d look_at_view_transform + Orthographic)
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
|
| 95 |
+
def _look_at_rotation(camera_pos: torch.Tensor,
                      at: torch.Tensor,
                      up: torch.Tensor) -> torch.Tensor:
    """Return the (3, 3) world→camera rotation for a camera at *camera_pos*
    looking at *at* with approximate up-vector *up*.

    Columns of the result are the camera's x/y/z axes in world coordinates;
    the camera looks along its own -Z axis.
    """
    z_axis = F.normalize(camera_pos - at, dim=-1)
    x_axis = F.normalize(torch.cross(up, z_axis, dim=-1), dim=-1)
    y_axis = torch.cross(z_axis, x_axis, dim=-1)
    return torch.stack((x_axis, y_axis, z_axis), dim=-1)


def look_at_view_transform(dist=1.0, elev=0.0, azim=0.0,
                           degrees=True, device="cpu"):
    """Compute (R, T) for a camera orbiting the origin, matching the
    pytorch3d convention exactly.

    Returns:
        R: (1, 3, 3) rotation whose rows are the camera axes in world space.
        T: (1, 3) camera position expressed in camera coordinates.
    """
    elev_rad = math.radians(float(elev)) if degrees else elev
    azim_rad = math.radians(float(azim)) if degrees else azim

    # Spherical → Cartesian camera position in world space.
    eye = torch.tensor(
        [[dist * math.cos(elev_rad) * math.sin(azim_rad),
          dist * math.sin(elev_rad),
          dist * math.cos(elev_rad) * math.cos(azim_rad)]],
        dtype=torch.float32, device=device)
    target = torch.zeros(1, 3, device=device)
    up = torch.tensor([[0, 1, 0]], dtype=torch.float32, device=device)

    # pytorch3d stores R transposed (row = camera axis in world space).
    R = _look_at_rotation(eye[0], target[0], up[0]).T.unsqueeze(0)  # (1,3,3)
    # T is the camera origin expressed in camera space.
    T = torch.bmm(-R, eye.unsqueeze(-1)).squeeze(-1)  # (1,3)
    return R, T
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class _OrthoCamera:
    """Minimal orthographic camera matching the FoVOrthographicCameras API.

    R is the (B, 3, 3) world→camera rotation (pytorch3d row convention),
    T the (B, 3) translation; ``focal`` scales world units to NDC.
    """

    def __init__(self, R, T, focal_length=1.0, device="cpu"):
        self.R = R.to(device)            # (B,3,3)
        self.T = T.to(device)            # (B,3)
        self.focal = float(focal_length)
        self.device = device

    def to(self, device):
        self.R = self.R.to(device)
        self.T = self.T.to(device)
        self.device = device
        return self

    def get_znear(self):
        # Fixed near plane, only ever queried by callers.
        return torch.tensor(0.01, device=self.device)

    def is_perspective(self):
        return False

    def transform_points_ndc(self, points):
        """Map (B, N, 3) world points to NDC (X, Y in [-1, 1], Z = depth)."""
        cam_pts = torch.bmm(points, self.R) + self.T.unsqueeze(1)  # (B,N,3)
        # Orthographic projection: scale X/Y by focal; Y flipped to match
        # pytorch3d's image convention; Z passes through as depth.
        return torch.stack(
            [cam_pts[..., 0] * self.focal,
             -cam_pts[..., 1] * self.focal,
             cam_pts[..., 2]],
            dim=-1)

    def _world_to_clip(self, verts: torch.Tensor) -> torch.Tensor:
        """Project (N, 3) world verts to (N, 4) clip coords for nvdiffrast."""
        cam_pts = verts @ self.R[0].T + self.T[0]  # (N,3)
        homogeneous_w = torch.ones_like(cam_pts[:, 2])
        return torch.stack(
            [cam_pts[:, 0] * self.focal,
             -cam_pts[:, 1] * self.focal,  # same Y flip as the NDC path
             cam_pts[:, 2],
             homogeneous_w],
            dim=1)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# Aliases used in project_mesh.py
|
| 173 |
+
def FoVOrthographicCameras(device="cpu", R=None, T=None,
                           min_x=-1, max_x=1, min_y=-1, max_y=1,
                           focal_length=None, **kwargs):
    """Factory matching pytorch3d's FoVOrthographicCameras signature.

    When *focal_length* is omitted, the NDC scale is derived from the +X
    extent of the viewing volume (1 / max_x, epsilon-guarded).
    """
    if focal_length is None:
        focal_length = 1.0 / (max_x + 1e-9)
    return _OrthoCamera(R, T, focal_length=focal_length, device=device)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def FoVPerspectiveCameras(device="cpu", R=None, T=None, fov=60, degrees=True, **kwargs):
    """Fallback: approximate a perspective camera by an orthographic one
    scaled by 1 / tan(fov / 2) — sufficient for PSHuman's use."""
    half_fov = math.radians(fov) / 2 if degrees else fov / 2
    return _OrthoCamera(R, T, focal_length=1.0 / math.tan(half_fov), device=device)


# Alias used by project_mesh.py
OrthographicCameras = FoVOrthographicCameras
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
# Rasterizer (nvdiffrast-based)
|
| 191 |
+
# ---------------------------------------------------------------------------
|
| 192 |
+
|
| 193 |
+
class RasterizationSettings:
    """Image-size settings for MeshRasterizer.

    ``blur_radius`` and ``faces_per_pixel`` are accepted for pytorch3d API
    compatibility but are unused by the nvdiffrast backend.
    """

    def __init__(self, image_size=512, blur_radius=0.0, faces_per_pixel=1):
        if isinstance(image_size, (list, tuple)):
            # (H, W) pair — int()-cast both, mirroring the scalar branch
            # (the original tuple branch skipped the cast).
            self.H, self.W = int(image_size[0]), int(image_size[1])
        else:
            self.H = self.W = int(image_size)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class _Fragments:
    """Lightweight stand-in for pytorch3d's rasterizer Fragments output."""

    def __init__(self, pix_to_face):
        # Append the faces-per-pixel axis expected by pytorch3d-style
        # consumers: (1, H, W) -> (1, H, W, 1).
        self.pix_to_face = pix_to_face.unsqueeze(-1)  # (1,H,W,1)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class MeshRasterizer:
    """nvdiffrast-backed replacement for pytorch3d's MeshRasterizer.

    Rasterizes a single mesh through an _OrthoCamera and returns a
    _Fragments carrying the per-pixel face-index map. Requires a CUDA
    device (RasterizeCudaContext).
    """

    def __init__(self, cameras=None, raster_settings=None):
        self.cameras = cameras            # default camera; overridable per call
        self.settings = raster_settings   # RasterizationSettings providing H/W
        self._glctx = None                # lazily created nvdiffrast CUDA context

    def _get_ctx(self, device):
        # Create the nvdiffrast context on first use and cache it; the lazy
        # import keeps module import cheap on machines without nvdiffrast.
        if self._glctx is None:
            import nvdiffrast.torch as dr
            self._glctx = dr.RasterizeCudaContext(device=device)
        return self._glctx

    def __call__(self, meshes: Meshes, cameras=None):
        """Rasterize *meshes*; *cameras*, if given, overrides the default.

        Returns a _Fragments whose pix_to_face is (1, H, W, 1) of face
        indices, with -1 marking background pixels.
        """
        cam = cameras or self.cameras
        H, W = self.settings.H, self.settings.W
        device = meshes.verts_packed().device
        import nvdiffrast.torch as dr
        glctx = self._get_ctx(str(device))

        verts = meshes.verts_packed().to(device)
        # nvdiffrast requires int32 face indices.
        faces = meshes.faces_packed().to(torch.int32).to(device)
        clip = cam._world_to_clip(verts).unsqueeze(0)  # (1,N,4)
        rast, _ = dr.rasterize(glctx, clip, faces, resolution=(H, W))
        # The last rast channel is triangle_id + 1 (0 = no hit); shifting by
        # one makes -1 the background sentinel, matching pytorch3d.
        pix_to_face = rast[0, :, :, -1].to(torch.int32) - 1  # -1 = background
        return _Fragments(pix_to_face.unsqueeze(0))
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# ---------------------------------------------------------------------------
|
| 234 |
+
# render_pix2faces_py3d shim (used in get_visible_faces)
|
| 235 |
+
# ---------------------------------------------------------------------------
|
| 236 |
+
|
| 237 |
+
def render_pix2faces_py3d(meshes, cameras, H=512, W=512, **kwargs):
    """Rasterize *meshes* and return {'pix_to_face': (1, H, W)} — an integer
    tensor of per-pixel face indices, -1 for background."""
    rasterizer = MeshRasterizer(
        cameras=cameras,
        raster_settings=RasterizationSettings(image_size=(H, W)),
    )
    fragments = rasterizer(meshes)
    # Drop the trailing faces-per-pixel axis: (1, H, W, 1) -> (1, H, W).
    return {"pix_to_face": fragments.pix_to_face[..., 0]}
|