Spaces: Running on Zero
Commit 9d7cf7f
Parent(s):
Public release: SkinTokens · TokenRig demo
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +35 -0
- .gitignore +67 -0
- LICENSE +21 -0
- README.md +14 -0
- bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl +3 -0
- bpy_server.py +7 -0
- configs/skeleton/mixamo.yaml +59 -0
- configs/skeleton/vroid.yaml +59 -0
- demo.py +764 -0
- download.py +72 -0
- requirements.txt +37 -0
- runtime.txt +1 -0
- src/__init__.py +0 -0
- src/data/augment.py +706 -0
- src/data/datapath.py +344 -0
- src/data/dataset.py +319 -0
- src/data/order.py +132 -0
- src/data/sampler.py +189 -0
- src/data/spec.py +16 -0
- src/data/transform.py +70 -0
- src/data/vertex_group.py +257 -0
- src/model/__init__.py +0 -0
- src/model/michelangelo/__init__.py +1 -0
- src/model/michelangelo/get_model.py +30 -0
- src/model/michelangelo/models/__init__.py +1 -0
- src/model/michelangelo/models/modules/__init__.py +3 -0
- src/model/michelangelo/models/modules/checkpoint.py +69 -0
- src/model/michelangelo/models/modules/distributions.py +100 -0
- src/model/michelangelo/models/modules/embedder.py +213 -0
- src/model/michelangelo/models/modules/transformer_blocks.py +327 -0
- src/model/michelangelo/models/tsal/__init__.py +1 -0
- src/model/michelangelo/models/tsal/loss.py +454 -0
- src/model/michelangelo/models/tsal/sal_perceiver.py +723 -0
- src/model/michelangelo/models/tsal/tsal_base.py +121 -0
- src/model/michelangelo/utils/__init__.py +4 -0
- src/model/michelangelo/utils/eval.py +12 -0
- src/model/michelangelo/utils/misc.py +271 -0
- src/model/parse_encoder.py +28 -0
- src/model/skin_vae/attention_processor.py +283 -0
- src/model/skin_vae/autoencoders/FSQ.py +191 -0
- src/model/skin_vae/autoencoders/SimVQ.py +197 -0
- src/model/skin_vae/autoencoders/__init__.py +1 -0
- src/model/skin_vae/autoencoders/autoencoder_kl_tripo2.py +254 -0
- src/model/skin_vae/autoencoders/get_model.py +22 -0
- src/model/skin_vae/autoencoders/miche_transformer_blocks.py +395 -0
- src/model/skin_vae/autoencoders/skin_fsq_cvae_model.py +304 -0
- src/model/skin_vae/autoencoders/vae.py +73 -0
- src/model/skin_vae/embeddings.py +111 -0
- src/model/skin_vae/transformers/__init__.py +41 -0
- src/model/skin_vae/transformers/modeling_outputs.py +8 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
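
These are the HF Hub's stock LFS rules: matching blobs are stored as LFS objects and checked out as pointer stubs when LFS isn't available. A quick way to confirm which filter a path would get (a usage sketch; the filename is a placeholder):

$ git check-attr filter -- model.safetensors
model.safetensors: filter: lfs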
.gitignore
ADDED
@@ -0,0 +1,67 @@
# ignore all pycache
**/__pycache__/
*.py[cod]
*$py.class

# ignore tmp & output files
_data/
ckpt/
tmp/
tmp_fusion/
tmp_vae/
tmp_video/
*.glb
*.ply
*.obj
*.fbx
*.npz
*.blend
*.blend1
*.blend2

# ignore logs
wandb/
lightning_logs/
*.log

# ignore experiments
experiments/
results/
dataset_clean/
logs/
datalist/
dataset_inference/
dataset_inference_clean/
feature_viz/

# Distribution / packaging
dist/
build/
*.egg-info/
*.egg
*.whl

# Virtual environments
venv/
env/
.env/
.venv/

# IDE specific files
.idea/
.vscode/
*.swp
*.swo
.DS_Store

# Jupyter Notebook
.ipynb_checkpoints
*.ipynb

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
coverage.xml
*.cover
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 VAST-AI-Research

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: SkinTokens
emoji: 🌖
colorFrom: green
colorTo: pink
sdk: gradio
sdk_version: 6.12.0
python_version: 3.12.12
app_file: demo.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:39df5f78dc95d1fae6058a3134a40f645303c6e96540f87e9ee4c0fd436def1d
size 346159222
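
What's committed at this path is the Git LFS pointer stub shown above, not the 346 MB wheel itself. A minimal sketch of the detection `demo.py`'s `_ensure_bpy_installed()` performs to tell the two apart (real wheels are ZIP archives and begin with the `PK` magic bytes):

def is_lfs_pointer(path: str) -> bool:
    # A genuine .whl is a ZIP; an LFS stub is a small text file starting
    # with "version https://git-lfs.github.com/spec/v1".
    with open(path, "rb") as f:
        return not f.read(4).startswith(b"PK")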
bpy_server.py
ADDED
@@ -0,0 +1,7 @@
from src.server.bpy_server import run

def main():
    run()

if __name__ == "__main__":
    main()
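
bpy_server.py is only a thin launcher; the actual HTTP endpoints (/ping, /export, /transfer) live in src/server/bpy_server.py, which falls outside this 50-file view. A minimal liveness probe mirroring demo.py's is_bpy_server_alive() — the http://localhost:59876 address is taken from demo.py's comments:

import requests

def bpy_server_alive() -> bool:
    # The sidecar answers /ping with HTTP 200 once bpy has finished importing.
    try:
        return requests.get("http://localhost:59876/ping", timeout=1).status_code == 200
    except requests.RequestException:
        return False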
configs/skeleton/mixamo.yaml
ADDED
@@ -0,0 +1,59 @@
parts_order: [body, hand]

parts:
  body: [
    mixamorig:Hips,
    mixamorig:Spine,
    mixamorig:Spine1,
    mixamorig:Spine2,
    mixamorig:Neck,
    mixamorig:Head,
    mixamorig:LeftShoulder,
    mixamorig:LeftArm,
    mixamorig:LeftForeArm,
    mixamorig:LeftHand,
    mixamorig:RightShoulder,
    mixamorig:RightArm,
    mixamorig:RightForeArm,
    mixamorig:RightHand,
    mixamorig:LeftUpLeg,
    mixamorig:LeftLeg,
    mixamorig:LeftFoot,
    mixamorig:LeftToeBase,
    mixamorig:RightUpLeg,
    mixamorig:RightLeg,
    mixamorig:RightFoot,
    mixamorig:RightToeBase,
  ]
  hand: [
    mixamorig:LeftHandThumb1,
    mixamorig:LeftHandThumb2,
    mixamorig:LeftHandThumb3,
    mixamorig:LeftHandIndex1,
    mixamorig:LeftHandIndex2,
    mixamorig:LeftHandIndex3,
    mixamorig:LeftHandMiddle1,
    mixamorig:LeftHandMiddle2,
    mixamorig:LeftHandMiddle3,
    mixamorig:LeftHandRing1,
    mixamorig:LeftHandRing2,
    mixamorig:LeftHandRing3,
    mixamorig:LeftHandPinky1,
    mixamorig:LeftHandPinky2,
    mixamorig:LeftHandPinky3,
    mixamorig:RightHandIndex1,
    mixamorig:RightHandIndex2,
    mixamorig:RightHandIndex3,
    mixamorig:RightHandThumb1,
    mixamorig:RightHandThumb2,
    mixamorig:RightHandThumb3,
    mixamorig:RightHandMiddle1,
    mixamorig:RightHandMiddle2,
    mixamorig:RightHandMiddle3,
    mixamorig:RightHandRing1,
    mixamorig:RightHandRing2,
    mixamorig:RightHandRing3,
    mixamorig:RightHandPinky1,
    mixamorig:RightHandPinky2,
    mixamorig:RightHandPinky3,
  ]
configs/skeleton/vroid.yaml
ADDED
@@ -0,0 +1,59 @@
parts_order: [body, hand]

parts:
  body: [
    J_Bip_C_Hips,
    J_Bip_C_Spine,
    J_Bip_C_Chest,
    J_Bip_C_UpperChest,
    J_Bip_C_Neck,
    J_Bip_C_Head,
    J_Bip_L_Shoulder,
    J_Bip_L_UpperArm,
    J_Bip_L_LowerArm,
    J_Bip_L_Hand,
    J_Bip_R_Shoulder,
    J_Bip_R_UpperArm,
    J_Bip_R_LowerArm,
    J_Bip_R_Hand,
    J_Bip_L_UpperLeg,
    J_Bip_L_LowerLeg,
    J_Bip_L_Foot,
    J_Bip_L_ToeBase,
    J_Bip_R_UpperLeg,
    J_Bip_R_LowerLeg,
    J_Bip_R_Foot,
    J_Bip_R_ToeBase,
  ]
  hand: [
    J_Bip_L_Thumb1,
    J_Bip_L_Thumb2,
    J_Bip_L_Thumb3,
    J_Bip_L_Index1,
    J_Bip_L_Index2,
    J_Bip_L_Index3,
    J_Bip_L_Middle1,
    J_Bip_L_Middle2,
    J_Bip_L_Middle3,
    J_Bip_L_Ring1,
    J_Bip_L_Ring2,
    J_Bip_L_Ring3,
    J_Bip_L_Little1,
    J_Bip_L_Little2,
    J_Bip_L_Little3,
    J_Bip_R_Index1,
    J_Bip_R_Index2,
    J_Bip_R_Index3,
    J_Bip_R_Thumb1,
    J_Bip_R_Thumb2,
    J_Bip_R_Thumb3,
    J_Bip_R_Middle1,
    J_Bip_R_Middle2,
    J_Bip_R_Middle3,
    J_Bip_R_Ring1,
    J_Bip_R_Ring2,
    J_Bip_R_Ring3,
    J_Bip_R_Little1,
    J_Bip_R_Little2,
    J_Bip_R_Little3,
  ]
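
Both skeleton templates share one schema: parts_order fixes the serialization order of the part groups, and parts lists each group's joint names (22 body + 30 hand = 52 joints in each file). A minimal sketch of flattening one into a single joint ordering — it assumes only PyYAML; the repo's actual consumer is presumably src/data/order.py, whose API this diff doesn't show:

import yaml

with open("configs/skeleton/mixamo.yaml") as f:
    cfg = yaml.safe_load(f)

# Body joints first, then hand joints, exactly as parts_order declares.
joint_order = [j for part in cfg["parts_order"] for j in cfg["parts"][part]]
assert joint_order[0] == "mixamorig:Hips" and len(joint_order) == 52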
demo.py
ADDED
@@ -0,0 +1,764 @@
import argparse
import atexit
import importlib
import os
import signal
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import requests
from torch import Tensor
from tqdm import tqdm

# ---------------------------------------------------------------------------
# ZeroGPU compatibility shim. The hosted HF Space provides the `spaces`
# package; running locally we substitute a no-op.
# ---------------------------------------------------------------------------
try:
    spaces = importlib.import_module("spaces")
except Exception:
    class _SpacesCompat:
        @staticmethod
        def GPU(*args, **kwargs):
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def _decorator(fn):
                return fn

            return _decorator

    spaces = _SpacesCompat()

os.environ.setdefault("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "1")
gr.TEMP_DIR = "tmp_gradio"


# ---------------------------------------------------------------------------
# Install the bundled `bpy` wheel at runtime if it isn't already importable.
#
# Why this is non-trivial:
# - Putting the wheel in requirements.txt fails: HF Spaces' Docker build
#   mounts only requirements.txt BEFORE the repo COPY, so the wheel path
#   doesn't exist at pip-install time.
# - PyPI doesn't ship a bpy wheel matching this exact build (rc0 / cp312 /
#   manylinux_2_39).
# - The `bpy-*.whl` committed in this repo gets auto-tracked by HF's LFS
#   layer (Hub auto-LFS for blobs > ~10 MB even when .gitattributes doesn't
#   list `*.whl`). The container's COPY-from-repo only carries the LFS
#   *pointer* file — a ~150-byte text stub — not the actual wheel binary.
#   So `pip install <wheel>` and `zipfile.ZipFile(<wheel>)` both fail with
#   "is not a zip file" / "Wheel is invalid".
#
# So: we detect the LFS-pointer case and re-fetch the real wheel from the
# HF Hub at runtime (where the API resolves LFS server-side), then extract
# it directly into site-packages.
# ---------------------------------------------------------------------------
def _ensure_bpy_installed():
    try:
        import bpy  # noqa: F401
        return
    except Exception:
        pass

    import glob
    import sysconfig
    import zipfile

    here = os.path.dirname(os.path.abspath(__file__))
    wheels = sorted(glob.glob(os.path.join(here, "bpy-*.whl")))
    if not wheels:
        print("[demo] WARNING: bpy not importable and no bundled wheel found", flush=True)
        return

    wheel = wheels[-1]
    wheel_name = os.path.basename(wheel)

    # Detect LFS pointer (text stub starting with "version https://git-lfs...").
    is_real_zip = False
    try:
        with open(wheel, "rb") as f:
            is_real_zip = f.read(4).startswith(b"PK")
    except Exception:
        pass

    if not is_real_zip:
        print(
            f"[demo] {wheel_name} on disk is an LFS pointer ({os.path.getsize(wheel)} B); "
            f"fetching real wheel from HF Hub...",
            flush=True,
        )
        from huggingface_hub import hf_hub_download

        space_id = os.environ.get("SPACE_ID", "VAST-AI/SkinTokens")
        token = os.environ.get("HF_TOKEN")  # set as a Space secret for private repos
        wheel = hf_hub_download(
            repo_id=space_id,
            repo_type="space",
            filename=wheel_name,
            token=token,
        )
        print(f"[demo] fetched -> {wheel} ({os.path.getsize(wheel)} B)", flush=True)

    site = sysconfig.get_paths()["purelib"]
    print(f"[demo] Extracting {wheel_name} into {site}", flush=True)
    with zipfile.ZipFile(wheel) as z:
        z.extractall(site)
    print("[demo] bpy wheel extracted.", flush=True)


_ensure_bpy_installed()


# ---------------------------------------------------------------------------
# Download model checkpoints (TokenRig + SkinTokens FSQ-CVAE) and the Qwen3
# tokenizer/config on first cold-start.
#
# These live in the *model* repo `VAST-AI/SkinTokens` (private), separate
# from this Space repo, so they aren't COPYed into the container. Re-uses
# `HF_TOKEN` from the Space secrets.
# ---------------------------------------------------------------------------
def _ensure_models_downloaded():
    here = os.path.dirname(os.path.abspath(__file__))
    needed_ckpts = [
        "experiments/skin_vae_2_10_32768/last.ckpt",
        "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
    ]
    qwen_dir = os.path.join(here, "models", "Qwen3-0.6B")

    all_present = (
        all(os.path.exists(os.path.join(here, p)) for p in needed_ckpts)
        and os.path.exists(os.path.join(qwen_dir, "tokenizer.json"))
    )
    if all_present:
        return

    from huggingface_hub import hf_hub_download, snapshot_download

    token = os.environ.get("HF_TOKEN")

    for rel in needed_ckpts:
        target = os.path.join(here, rel)
        if os.path.exists(target):
            continue
        print(f"[demo] Downloading checkpoint: {rel}", flush=True)
        hf_hub_download(
            repo_id="VAST-AI/SkinTokens",
            filename=rel,
            local_dir=here,
            token=token,
        )

    if not os.path.exists(os.path.join(qwen_dir, "tokenizer.json")):
        print("[demo] Downloading Qwen3-0.6B tokenizer/config", flush=True)
        snapshot_download(
            repo_id="Qwen/Qwen3-0.6B",
            local_dir=qwen_dir,
            ignore_patterns=["*.bin", "*.safetensors"],
        )

    print("[demo] All checkpoints ready.", flush=True)


_ensure_models_downloaded()


from src.data.dataset import DatasetConfig, RigDatasetModule
from src.data.transform import Transform
from src.model.tokenrig import TokenRigResult
from src.tokenizer.parse import get_tokenizer
from src.server.spec import (
    BPY_SERVER,
    get_model,
    object_to_bytes,
    bytes_to_object,
)
from src.data.vertex_group import voxel_skin


# ---------------------------------------------------------------------------
# Run `bpy` in a separate `bpy_server` process instead of importing it here.
#
# Why this matters on ZeroGPU: each user request runs inside a fresh
# `@spaces.GPU` worker process with a hard time budget (≈60 s on free tier).
# Importing the Blender shared object inside that budget burns 30–60 s, so
# the worker is killed *during* bpy import — manifesting as
# "GPU task aborted" before any model code runs.
#
# The server is therefore started lazily, on the first request (see the note
# above `demo = build_gradio_app()` below for why it cannot be pre-warmed in
# the main Gradio process). Once it is up, requests just hit
# `localhost:59876` over HTTP — sub-millisecond, no startup cost.
# ---------------------------------------------------------------------------

MODEL_CKPTS = [
    "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
]

HF_PATHS = [
    "None",
]


def get_dataloader_workers() -> int:
    if os.getenv("SPACE_ID"):
        return 0
    return 1


# ---------------------------------------------------------------------------
# bpy_server lifecycle — lazy start so the heavy import doesn't fight ZeroGPU
# during module load.
# ---------------------------------------------------------------------------
_BPY_SERVER_PROC = None


def is_bpy_server_alive(timeout: float = 1.0) -> bool:
    try:
        resp = requests.get(f"{BPY_SERVER}/ping", timeout=timeout)
        return resp.status_code == 200
    except Exception:
        return False


def start_bpy_server():
    proc = subprocess.Popen(
        [sys.executable, "bpy_server.py"],
        stdout=None,
        stderr=None,
        preexec_fn=os.setsid,
    )
    print(f"[Main] bpy_server.py started (pid={proc.pid})")

    def cleanup():
        print(f"[Main] Terminating bpy_server.py (pid={proc.pid})")
        try:
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
        except ProcessLookupError:
            pass

    atexit.register(cleanup)
    return proc


def wait_for_bpy_server(timeout: float = 120):
    """Wait for bpy_server.py to come up. The first start of bpy_server is
    slow because importing the Blender `.so` (~200 MB shared object) takes
    30–60 s on a cold container. We allow up to 120 s."""
    t0 = time.time()
    last_log = 0.0
    while True:
        try:
            requests.get(f"{BPY_SERVER}/ping", timeout=1)
            print(f"[Main] bpy_server is ready (after {time.time() - t0:.1f}s)")
            return
        except Exception:
            now = time.time()
            if now - t0 > timeout:
                raise RuntimeError(
                    f"bpy_server failed to start after {timeout:.0f}s"
                )
            if now - last_log > 10:  # progress every 10s
                print(f"[Main] still waiting for bpy_server ({now - t0:.0f}s elapsed)")
                last_log = now
            time.sleep(0.5)


def ensure_bpy_server_started():
    global _BPY_SERVER_PROC
    if is_bpy_server_alive():
        return
    if _BPY_SERVER_PROC is not None and _BPY_SERVER_PROC.poll() is None:
        return
    _BPY_SERVER_PROC = start_bpy_server()
    wait_for_bpy_server()


# ---------------------------------------------------------------------------
# Lazy model loading.
# ---------------------------------------------------------------------------
model = None
tokenizer = None
transform = None
CURRENT_MODEL_CKPT: Optional[str] = None
CURRENT_HF_PATH: Optional[str] = None


def load_model(model_ckpt: str, hf_path: Optional[str]) -> Tuple[str, str]:
    global model, tokenizer, transform, CURRENT_MODEL_CKPT, CURRENT_HF_PATH
    if hf_path == "None":
        hf_path = None
    if model is not None and model_ckpt == CURRENT_MODEL_CKPT and hf_path == CURRENT_HF_PATH:
        return ("Model already loaded.", model_ckpt)

    if not model_ckpt:
        raise RuntimeError("model_ckpt is empty. Please select a checkpoint.")

    print(f"Loading model: {model_ckpt}, hf_path={hf_path}")
    model = get_model(model_ckpt, hf_path=hf_path)
    assert model.tokenizer_config is not None
    tokenizer = get_tokenizer(**model.tokenizer_config)
    transform = Transform.parse(**model.transform_config["predict_transform"])
    CURRENT_MODEL_CKPT = model_ckpt
    CURRENT_HF_PATH = hf_path
    return ("Model loaded.", model_ckpt)


# ---------------------------------------------------------------------------
# File utilities (CLI-side).
# ---------------------------------------------------------------------------
SUPPORTED_EXT = {".obj", ".fbx", ".glb"}


def collect_files(input_path: Path) -> List[Path]:
    if input_path.is_file():
        return [input_path]

    files = []
    for p in input_path.rglob("*"):
        if p.suffix.lower() in SUPPORTED_EXT:
            files.append(p)
    return files


def map_output_path(in_path: Path, input_root: Path, output_root: Path) -> Path:
    rel = in_path.relative_to(input_root)
    return (output_root / rel).with_suffix(".glb")


# ---------------------------------------------------------------------------
# Core inference (shared by CLI and Gradio).
# ---------------------------------------------------------------------------
def run_rig(
    filepaths: List[Path],
    top_k: int,
    top_p: float,
    temperature: float,
    repetition_penalty: float,
    num_beams: int,
    use_skeleton: bool,
    use_transfer: bool,
    use_postprocess: bool,
    output_paths: List[Path],
    model_ckpt: str,
    hf_path: Optional[str],
):
    assert len(filepaths) == len(output_paths)
    ensure_bpy_server_started()
    load_model(model_ckpt, hf_path)

    datapath = {
        "data_name": None,
        "loader": "bpy_server",
        "filepaths": {"articulation": [str(p) for p in filepaths]},
    }

    dataset_config = DatasetConfig.parse(
        shuffle=False,
        batch_size=1,
        num_workers=get_dataloader_workers(),
        pin_memory=get_dataloader_workers() > 0,
        persistent_workers=False,
        datapath=datapath,
    ).split_by_cls()

    module = RigDatasetModule(
        predict_dataset_config=dataset_config,
        predict_transform=transform,
        tokenizer=tokenizer,
        process_fn=model._process_fn,
    )

    dataloader = module.predict_dataloader()["articulation"]

    results_out = []
    infer_device = model.device if model is not None else "cuda"

    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        batch = {
            k: v.to(infer_device) if isinstance(v, Tensor) else v
            for k, v in batch.items()
        }

        if not use_skeleton:
            batch.pop("skeleton_tokens", None)
            batch.pop("skeleton_mask", None)

        batch["generate_kwargs"] = dict(
            max_length=2048,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=1,
            num_beams=int(num_beams),
            do_sample=True,
        )

        if "skeleton_tokens" in batch and "skeleton_mask" in batch:
            mask = batch["skeleton_mask"][0] == 1
            skeleton_tokens = batch["skeleton_tokens"][0][mask].cpu().numpy()
        else:
            skeleton_tokens = None

        preds: List[TokenRigResult] = model.predict_step(
            batch,
            skeleton_tokens=[skeleton_tokens] if skeleton_tokens is not None else None,
            make_asset=True,
        )["results"]

        asset = preds[0].asset
        assert asset is not None

        if use_postprocess:
            voxel = asset.voxel(resolution=196)
            asset.skin *= voxel_skin(
                grid=0,
                grid_coords=voxel.coords,
                joints=asset.joints,
                vertices=asset.vertices,
                faces=asset.faces,
                mode="square",
                voxel_size=voxel.voxel_size,
            )
            asset.normalize_skin()

        out_path = output_paths[i]
        out_path.parent.mkdir(parents=True, exist_ok=True)

        if use_transfer:
            payload = dict(
                source_asset=asset,
                target_path=asset.path,
                export_path=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/transfer",
                    data=object_to_bytes(payload),
                ).content
            )
        else:
            payload = dict(
                asset=asset,
                filepath=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/export",
                    data=object_to_bytes(payload),
                ).content
            )

        if res != "ok":
            print(f"[Error] {res}")
        else:
            print(f"[OK] Exported: {out_path}")

        results_out.append(out_path)

    return results_out


# ---------------------------------------------------------------------------
# CLI entry point.
# ---------------------------------------------------------------------------
def run_cli(args):
    input_path = Path(args.input).resolve()
    output_path = Path(args.output).resolve()

    files = collect_files(input_path)
    if not files:
        raise RuntimeError("No valid 3D files found.")

    if len(files) == 1 and output_path.suffix:
        outputs = [output_path]
    else:
        outputs = [map_output_path(f, input_path, output_path) for f in files]

    run_rig(
        files,
        args.top_k,
        args.top_p,
        args.temperature,
        args.repetition_penalty,
        args.num_beams,
        args.use_skeleton,
        args.use_transfer,
        args.use_postprocess,
        outputs,
        args.model_ckpt,
        args.hf_path,
    )


# ---------------------------------------------------------------------------
# Gradio wrapper (with ZeroGPU duration estimator).
# ---------------------------------------------------------------------------
TOT = 0


def _gpu_duration(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    # Cold workers spend ~30–60 s importing bpy + loading the model before
    # any GPU work. Give every request a generous 240 s floor.
    file_count = len(files) if files is not None else 1
    return min(900, max(240, 240 + 60 * file_count))
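# Worked examples of the estimate above (plain arithmetic, not from the repo):
#   1 file  -> min(900, max(240, 240 + 60 * 1)) = 300 s
#   5 files -> min(900, max(240, 240 + 60 * 5)) = 540 s
#   11 files reach the 900 s cap; anything larger is clamped to it.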

@spaces.GPU(duration=_gpu_duration)
def run_gradio(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    if not files:
        return "Please upload at least one 3D model.", None

    tmp_out = Path(tempfile.mkdtemp(prefix="tokenrig_"))
    filepaths = [Path(f.name) for f in files]
    global TOT
    outputs = []
    for filepath in filepaths:
        TOT += 1
        outputs.append(tmp_out / f"res_{TOT}.glb")

    run_rig(
        filepaths,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        num_beams,
        use_skeleton,
        use_transfer,
        use_postprocess,
        outputs,
        model_ckpt,
        hf_path,
    )

    return f"Processed {len(outputs)} models.", [str(p) for p in outputs]


# ---------------------------------------------------------------------------
# Gradio UI.
# ---------------------------------------------------------------------------
def build_gradio_app():
    model_ckpts = MODEL_CKPTS
    hf_paths = HF_PATHS
    default_ckpt = model_ckpts[0] if model_ckpts else ""
    default_hf = hf_paths[0] if hf_paths else "None"

    with gr.Blocks(title="SkinTokens · TokenRig Demo") as app:
        gr.Markdown(
            """
## 🦴 Mesh to Rig with [SkinTokens](https://zjp-shadow.github.io/works/SkinTokens/) · TokenRig

Automated **skeleton generation + skinning weight prediction** for any 3D mesh, via a unified
autoregressive model over learned *SkinTokens*. Successor to
[UniRig](https://github.com/VAST-AI-Research/UniRig) (SIGGRAPH '25).

* Upload one or more meshes → click **Run** → download a rigged `.glb`.
* **Paper**: [arXiv 2602.04805](https://arxiv.org/abs/2602.04805) ·
  **Code**: [VAST-AI-Research/SkinTokens](https://github.com/VAST-AI-Research/SkinTokens) ·
  **Weights**: [🤗 VAST-AI/SkinTokens](https://huggingface.co/VAST-AI/SkinTokens)
* Looking for **image → rigged 3D** instead? Try our sibling Space
  [🤗 VAST-AI/AniGen](https://huggingface.co/spaces/VAST-AI/AniGen).
* Want a full AI-powered 3D workspace? → [Tripo](https://www.tripo3d.ai)
"""
        )

        gr.HTML(
            """
<style>
@keyframes gentle-pulse {
  0%, 100% { opacity: 1; }
  50% { opacity: 0.35; }
}
</style>
<div style="text-align:left; color:#888; font-size:1em; line-height:1.6; margin: 4px 0 -4px 0;">
<span style="animation: gentle-pulse 3s ease-in-out infinite; display:inline-block;">💡 <b>Tips</b></span>
Defaults work well for most meshes.
• If your mesh already has a skeleton and you only want skinning, enable
<b>Use existing skeleton</b> below.
• To keep your original textures and world scale, enable <b>Preserve original texture & scale</b>.
</div>
"""
        )

        with gr.Row():
            with gr.Column(scale=1):
                files = gr.File(
                    label="3D Models ( .obj / .fbx / .glb, up to a few at a time )",
                    file_count="multiple",
                    file_types=[".obj", ".fbx", ".glb"],
                )

                with gr.Accordion("⚙️ Generation Settings", open=False):
                    model_ckpt = gr.Dropdown(
                        choices=model_ckpts,
                        value=default_ckpt,
                        label="Model checkpoint",
                        info="TokenRig autoregressive rigging model. The default is the GRPO-refined checkpoint recommended for most assets.",
                        interactive=True,
                    )
                    # Keep the hf_path component for callback compatibility, but hide it
                    # from the UI since it currently only exposes the default ("None") option.
                    hf_path = gr.Dropdown(
                        choices=hf_paths,
                        value=default_hf,
                        label="HF path (advanced)",
                        visible=False,
                    )

                    gr.Markdown("**Sampling parameters** — control autoregressive decoding of the rig.")
                    top_k = gr.Slider(
                        1, 200, value=5, step=1,
                        label="top_k",
                        info="Sample from the K most likely next tokens at each step. Lower = more deterministic output.",
                    )
                    top_p = gr.Slider(
                        0.1, 1.0, value=0.95, step=0.01,
                        label="top_p (nucleus)",
                        info="Sample from the smallest set of tokens whose cumulative probability ≥ p.",
                    )
                    temperature = gr.Slider(
                        0.1, 2.0, value=1.0, step=0.1,
                        label="temperature",
                        info="Softmax temperature. <1 sharpens the distribution (more conservative), >1 makes it flatter (more diverse).",
                    )
                    repetition_penalty = gr.Slider(
                        0.5, 3.0, value=2.0, step=0.1,
                        label="repetition_penalty",
                        info="Multiplicative penalty on tokens that have already been generated. 1.0 = no penalty.",
                    )
                    num_beams = gr.Slider(
                        1, 20, value=10, step=1,
                        label="num_beams",
                        info="Beam-search width. Larger = higher quality but slower; 1 disables beam search.",
                    )

                    gr.Markdown("**Pipeline toggles**")
                    use_skeleton = gr.Checkbox(
                        False,
                        label="Use existing skeleton (predict skinning only)",
                        info="If the uploaded file already contains a skeleton, keep it and only predict per-vertex skinning weights.",
                    )
                    use_transfer = gr.Checkbox(
                        False,
                        label="Preserve original texture & scale",
                        info="Transfer the predicted rig back onto the original (unprocessed) mesh, so textures and world units are preserved.",
                    )
                    use_postprocess = gr.Checkbox(
                        False,
                        label="Voxel skin post-processing",
                        info="Apply a voxel-based mask to the predicted skin weights before normalization. Slower.",
                    )

                run_btn = gr.Button("🚀 Run", variant="primary")

            with gr.Column(scale=1):
                log = gr.Textbox(label="Status", lines=2, interactive=False)
                output = gr.File(label="Rigged GLB output", interactive=False)
                gr.Markdown(
                    """
**Notes**
- The output `.glb` contains the predicted **skeleton + skinning weights**. Import it in Blender (File → Import → glTF 2.0) or any DCC tool that reads glTF.
- In Blender, if you see a `glTF_not_exported` placeholder node, you can safely remove it.
- At busy times, ZeroGPU may queue your request for ~10–30 s before inference starts — the status box will update once the GPU is attached.
- Please do **not** upload confidential or NSFW content. See the
  [project page](https://zjp-shadow.github.io/works/SkinTokens/) for paper-accurate results and the
  [code repo](https://github.com/VAST-AI-Research/SkinTokens) for local / batch inference.
"""
                )

        run_btn.click(
            run_gradio,
            inputs=[
                files,
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                num_beams,
                use_skeleton,
                use_transfer,
                use_postprocess,
                model_ckpt,
                hf_path,
            ],
            outputs=[log, output],
        )

    return app


demo = build_gradio_app()


# Note: we do NOT pre-warm `bpy_server` in the main process. `bpy_server.py`
# transitively imports `src.model.michelangelo.utils.misc`, whose
# module-level `use_flash3 = FLASH3()` calls `torch.cuda.get_device_name(0)`
# at import time. That call fails ("RuntimeError: No CUDA GPUs are
# available") in the main Gradio process on ZeroGPU, where the GPU is only
# attached inside `@spaces.GPU`-decorated workers. So the bpy_server boot
# happens on first request, inside the worker.


# ---------------------------------------------------------------------------
# Entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser("TokenRig Demo")
    parser.add_argument("--input", help="Input file or directory")
    parser.add_argument("--output", help="Output file or directory")

    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--repetition_penalty", type=float, default=2.0)
    parser.add_argument("--num_beams", type=int, default=10)

    parser.add_argument("--use_skeleton", action="store_true")
    parser.add_argument("--use_transfer", action="store_true")
    parser.add_argument("--use_postprocess", action="store_true")

    parser.add_argument("--model_ckpt", default=MODEL_CKPTS[0] if MODEL_CKPTS else "")
    parser.add_argument("--hf_path", default=None)

    parser.add_argument("--gradio", action="store_true")

    args = parser.parse_args()

    if args.gradio or not args.input:
        demo.queue()
        demo.launch(ssr_mode=False)
    else:
        ensure_bpy_server_started()
        run_cli(args)
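
The argparse block above makes demo.py dual-purpose: with no --input it launches the Gradio UI, otherwise it runs headless batch rigging. Typical invocations (a sketch; the mesh filenames are placeholders):

$ python demo.py --gradio
$ python demo.py --input mesh.glb --output rigged.glb --use_transfer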
download.py
ADDED
@@ -0,0 +1,72 @@
from huggingface_hub import hf_hub_download, snapshot_download

import argparse

REPO_ID = "VAST-AI/SkinTokens"

MODELS = [
    "experiments/skin_vae_2_10_32768/last.ckpt",
    "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
]

DATASETS = [
    "rignet.zip",
    "articulation.zip",
]

LLM_REPO = "Qwen/Qwen3-0.6B"
LLM_LOCAL_DIR = "models/Qwen3-0.6B"


def download_model(name: str):
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=name,
        local_dir=".",
    )
    print(f"[MODEL] {name} downloaded to: {local_path}")


def download_llm():
    local_path = snapshot_download(
        repo_id=LLM_REPO,
        local_dir=LLM_LOCAL_DIR,
        ignore_patterns=["*.bin", "*.safetensors"],
    )
    print(f"[LLM] Config downloaded to: {local_path}")


def download_data(name: str):
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=f"dataset_clean/{name}",
        local_dir=".",
    )
    name = name.removesuffix(".zip")
    local_path = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"datalist/{name}/*"],
        local_dir=".",
    )
    print(f"[DATA] {name} downloaded to: {local_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", action="store_true", help="Download model checkpoints")
    parser.add_argument("--data", action="store_true", help="Download datasets")
    args = parser.parse_args()
    if not args.model and not args.data:
        print("Please specify --model or --data")
        return
    if args.model:
        for model in MODELS:
            download_model(model)
        download_llm()
    if args.data:
        for data in DATASETS:
            download_data(data)


if __name__ == "__main__":
    main()
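
Typical usage of this helper, per the flags defined in main():

$ python download.py --model   # checkpoints plus the Qwen3-0.6B tokenizer/config
$ python download.py --data    # rignet.zip, articulation.zip and their datalist folders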
requirements.txt
ADDED
@@ -0,0 +1,37 @@
--extra-index-url https://download.pytorch.org/whl/cu128
--extra-index-url https://pypi.org/simple

# Pinned to match the flash-attn wheel below (cu12torch2.9cxx11abiTRUE).
# Don't bump torch without also bumping the flash-attn URL — they must agree on
# the (cu12 / torch 2.9 / cxx11-abi=TRUE / cp312) tuple.
torch==2.9.1
torchvision==0.24.1
torchaudio==2.9.1
https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.9cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
transformers==4.57.0
diffusers==0.36.0
python-box
einops
omegaconf
pytorch_lightning
lightning
addict
timm
fast-simplification
trimesh
open3d
pyrender
# bpy is NOT listed here. The bpy wheel committed in this repo
# (`bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl`) is installed at
# runtime by `demo.py` — see `_ensure_bpy_installed()`. Reason: HF Spaces'
# Docker build mounts only `requirements.txt` BEFORE the repo COPY, so the
# wheel path doesn't exist at pip-install time. Public PyPI also has no bpy
# wheel matching this exact build (`rc0` / manylinux_2_39 / cp312).
huggingface_hub
spaces
wandb
numpy==2.2.6
gradio
bottle
tornado
cython
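
The torch/flash-attn pin comment above encodes an invariant worth re-checking after any dependency bump. A minimal sanity check (an illustrative sketch, not part of the repo):

import sys
import torch

# Must match the flash-attn wheel's (cu12 / torch 2.9 / cp312) tuple.
assert torch.__version__.startswith("2.9"), torch.__version__
assert torch.version.cuda is not None and torch.version.cuda.startswith("12"), torch.version.cuda
assert sys.version_info[:2] == (3, 12), sys.version_info

import flash_attn  # will fail to import if the wheel's torch ABI doesn't match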
runtime.txt
ADDED
@@ -0,0 +1 @@
python-3.11
src/__init__.py
ADDED
File without changes
src/data/augment.py
ADDED
|
@@ -0,0 +1,706 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple, Union, List, Optional, Dict
from numpy import ndarray
from abc import ABC, abstractmethod
from scipy.spatial.transform import Rotation as R

import numpy as np
import random

from .spec import ConfigSpec

from ..rig_package.utils import axis_angle_to_matrix
from ..rig_package.info.asset import Asset

@dataclass(frozen=True)
class Augment(ConfigSpec):

    @classmethod
    @abstractmethod
    def parse(cls, **kwargs) -> 'Augment':
        pass

    @abstractmethod
    def transform(self, asset: Asset, **kwargs):
        pass

@dataclass(frozen=True)
class AugmentTrim(Augment):
    """trim the skeleton"""

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentTrim':
        cls.check_keys(kwargs)
        return AugmentTrim()

    def transform(self, asset: Asset, **kwargs):
        asset.trim_skeleton()

@dataclass(frozen=True)
class AugmentDelete(Augment):
    """randomly delete joints and vertices"""

    # probability
    p: float

    # how much to keep
    rate: float

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentDelete':
        cls.check_keys(kwargs)
        return AugmentDelete(
            p=kwargs.get('p', 0.),
            rate=kwargs.get('rate', 0.5),
        )

    def transform(self, asset: Asset, **kwargs):
        if asset.skin is None:
            raise ValueError("do not have skin")
        if asset.parents is None:
            raise ValueError("do not have parents")
        asset.normalize_skin()
        def select_k(arr: List, k: int):
            if len(arr) <= k:
                return arr
            else:
                rest_indices = list(range(1, len(arr)))
                selected_indices = sorted(random.sample(rest_indices, k))
                return [arr[i] for i in selected_indices]
        if np.random.rand() >= self.p:
            return
        ids = select_k([i for i in range(asset.J)], max(int(asset.J * (1 - np.random.rand() * self.rate)), 1))
        if len(ids) == 0:
            return
        # keep bones with no skin
        keep = {}
        for id in ids:
            keep[id] = True
        for id in range(asset.J):
            if np.all(asset.skin[:, id] < 0.1):
                keep[id] = True
        keep[asset.root] = True

        vertices_to_delete = np.zeros(asset.N, dtype=bool)
        for id in range(asset.J):
            if id not in keep:
                dominant = asset.skin.argmax(axis=1) == id
                x = (asset.skin[:, id] > 0.1) & dominant
                if np.all(~x) or x.sum() * asset.J < asset.N:  # avoid collapsing
                    keep[id] = True
                    continue
                vertices_to_delete[x] = True
        if np.all(vertices_to_delete):
            return
        if asset.faces is not None:
            indices = np.where(~vertices_to_delete)[0]
            face_mask = np.all(np.isin(asset.faces, indices), axis=1)
            if np.all(~face_mask):
                return

        joints_to_delete: List[int|str] = [i for i in range(asset.J) if i not in keep]
        asset.delete_joints(joints_to_delete)
        asset.delete_vertices(np.arange(asset.N)[vertices_to_delete])

@dataclass(frozen=True)
class AugmentDropPart(Augment):
    """randomly drop subtrees and their vertices"""

    # probability
    p: float

    # drop rate
    rate: float

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentDropPart':
        cls.check_keys(kwargs)
        return AugmentDropPart(
            p=kwargs.get('p', 0.),
            rate=kwargs.get('rate', 0.5),
        )

    def transform(self, asset: Asset, **kwargs):
        if np.random.rand() >= self.p:
            return
        if asset.parents is None:
            raise ValueError("do not have parents")
        if asset.skin is None:
            raise ValueError("do not have skin")
        keep = []
        for id in range(asset.J):
            if np.random.rand() < self.rate:
                keep.append(id)
        if len(keep) == 0:
            return
        for id in reversed(asset.dfs_order):
            p = asset.parents[id]
            if p == -1:
                continue
            if id in keep and p not in keep:
                keep.append(p)

        mask = np.zeros(asset.N, dtype=bool)
        for id in keep:
            mask[asset.skin[:, id] > 1e-5] = True
        vertices_to_delete = ~mask
        if np.all(vertices_to_delete):
            return
        if asset.faces is not None:
            indices = np.where(~vertices_to_delete)[0]
            face_mask = np.all(np.isin(asset.faces, indices), axis=1)
            if np.all(~face_mask):
                return

        joints_to_delete: List[int|str] = [i for i in range(asset.J) if i not in keep]
        asset.delete_joints(joints_to_delete)
        asset.delete_vertices(np.arange(asset.N)[vertices_to_delete])

    def inverse(self, asset: Asset):
        pass

@dataclass(frozen=True)
class AugmentCollapse(Augment):
    """randomly merge joints"""

    # collapse the skeleton with probability p
    p: float

    # probability to merge the bone
    rate: float

    # max bones
    max_bones: int

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentCollapse':
        cls.check_keys(kwargs)
        return AugmentCollapse(
            p=kwargs.get('p', 0.),
            rate=kwargs.get('rate', 0.),
            max_bones=kwargs.get('max_bones', 2147483647),
        )

    def transform(self, asset: Asset, **kwargs):
        def select_k(arr: List, k: int):
            if len(arr) <= k:
                return arr
            else:
                rest_indices = list(range(1, len(arr)))
                selected_indices = sorted(random.sample(rest_indices, k))
                return [arr[i] for i in selected_indices]

        root = asset.root
        if np.random.rand() < self.p:
            ids = []
            for id in range(asset.J):
                if np.random.rand() >= self.rate:
                    ids.append(id)
            if root not in ids:
                ids.append(root)
            keep: List[int|str] = select_k([i for i in range(asset.J) if i in ids], self.max_bones)
            if root not in keep:
                keep[0] = root
            asset.set_order(new_orders=keep)
        elif asset.J > self.max_bones:
            ids = select_k([i for i in range(asset.J)], k=self.max_bones)
            if root not in ids:
                ids[0] = root
            keep: List[int|str] = [i for i in range(asset.J) if i in ids]
            asset.set_order(new_orders=keep)

@dataclass(frozen=True)
class AugmentJointDiscrete(Augment):
    # perturb the skeleton with probability p
    p: float

    # num of discretized coords
    discrete: int

    # continuous range
    continuous_range: Tuple[float, float]

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentJointDiscrete':
        cls.check_keys(kwargs)
        return AugmentJointDiscrete(
            p=kwargs.get('p', 0.),
            discrete=kwargs.get('discrete', 256),
            continuous_range=kwargs.get('continuous_range', [-1., 1.]),
        )

    def _discretize(
        self,
        t: ndarray,
        continuous_range: Tuple[float, float],
        num_discrete: int,
    ) -> ndarray:
        lo, hi = continuous_range
        assert hi >= lo
        t = (t - lo) / (hi - lo)
        t *= num_discrete
        return np.clip(t.round(), 0, num_discrete - 1).astype(np.int64)

    def _undiscretize(
        self,
        t: ndarray,
        continuous_range: Tuple[float, float],
        num_discrete: int,
    ) -> ndarray:
        lo, hi = continuous_range
        assert hi >= lo
        t = t.astype(np.float32) + 0.5
        t /= num_discrete
        return t * (hi - lo) + lo

    def transform(self, asset: Asset, **kwargs):
        if np.random.rand() < self.p:
            joints = asset.joints
            if joints is not None and asset.matrix_local is not None:
                joints = self._undiscretize(
                    self._discretize(
                        joints,
                        self.continuous_range,
                        self.discrete,
                    ),
                    self.continuous_range,
                    self.discrete,
                )
                asset.matrix_local[:, :3, 3] = joints

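# Aside (illustration only, not part of augment.py): AugmentJointDiscrete snaps
# joint coordinates onto a uniform grid by quantizing and then dequantizing.
# A minimal sketch, assuming the default continuous_range [-1, 1] and
# discrete = 256:
#
#     x   = 0.3
#     bin = round((x + 1) / 2 * 256)      # 166, clipped to [0, 255]
#     x2  = (bin + 0.5) / 256 * 2 - 1     # ~0.3008, the bin center
#
# Every joint is moved to its nearest bin center, which mimics the precision a
# downstream discrete (tokenized) representation of the skeleton would see.
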
@dataclass(frozen=True)
class AugmentJointPerturb(Augment):
    # perturb the skeleton with probability p
    p: float

    # jitter sigma on joints
    sigma: float

    # jitter clip on joints
    clip: float

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentJointPerturb':
        cls.check_keys(kwargs)
        return AugmentJointPerturb(
            p=kwargs.get('p', 0.),
            sigma=kwargs.get('sigma', 0.),
            clip=kwargs.get('clip', 0.),
        )

    def transform(self, asset: Asset, **kwargs):
        if np.random.rand() < self.p and asset.matrix_local is not None:
            # jitter the joint positions stored in the translation column
            asset.matrix_local[:, :3, 3] += np.clip(
                np.random.normal(0, self.sigma, (asset.J, 3)),
                -self.clip,
                self.clip,
            )

@dataclass(frozen=True)
class AugmentLBS(Augment):
    # apply a random pose with probability p
    random_pose_p: float

    # random pose angle range
    random_pose_angle: float

    # random scale
    random_scale_range: Tuple[float, float]

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentLBS':
        cls.check_keys(kwargs)
        return AugmentLBS(
            random_pose_p=kwargs.get('random_pose_p', 0.),
            random_pose_angle=kwargs.get('random_pose_angle', 0.),
            random_scale_range=kwargs.get('random_scale_range', (1., 1.)),
        )

    def _apply(self, v: ndarray, trans: ndarray) -> ndarray:
        return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]

    def transform(self, asset: Asset, **kwargs):
        def get_matrix_basis(angle: float):
            matrix = axis_angle_to_matrix((np.random.rand(asset.J, 3) - 0.5) * angle / 180 * np.pi * 2).astype(np.float32)
            return matrix

        if np.random.rand() < self.random_pose_p and asset.joints is not None:
            matrix_basis = get_matrix_basis(self.random_pose_angle)
            max_offset = (asset.joints.max(axis=0) - asset.joints.min(axis=0)).max()
            matrix_basis[:, :3, :3] *= np.tile(np.random.uniform(low=self.random_scale_range[0], high=self.random_scale_range[1], size=(asset.J, 1, 1)), (1, 3, 3))
            asset.vertices_with_pose(matrix_basis=matrix_basis, inplace=True)

@dataclass(frozen=True)
class AugmentLinear(Augment):
    # apply random rotation with probability p
    random_rotate_p: float

    # random rotation angle (degrees)
    random_rotate_angle: float

    # swap x with probability p
    random_flip_x_p: float

    # swap y with probability p
    random_flip_y_p: float

    # swap z with probability p
    random_flip_z_p: float

    # probability to pick an angle in static_rotate_x
    static_rotate_x_p: float

    # rotate around the x axis among given angles (degrees)
    static_rotate_x: List[float]

    # probability to pick an angle in static_rotate_y
    static_rotate_y_p: float

    # rotate around the y axis among given angles (degrees)
    static_rotate_y: List[float]

    # probability to pick an angle in static_rotate_z
    static_rotate_z_p: float

    # rotate around the z axis among given angles (degrees)
    static_rotate_z: List[float]

    # apply random scaling with probability p
    random_scale_p: float

    # random scaling of the xyz axes
    random_scale: Tuple[float, float]

    # randomly change xyz orientation
    random_transpose: float

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentLinear':
        if kwargs.get('random_flip_x_p', 0) > 0 or kwargs.get('random_flip_y_p', 0) > 0 or kwargs.get('random_flip_z_p', 0) > 0:
            print("\033[31mWARNING: random flip is enabled and is very likely to confuse the AR model!\033[0m")
        cls.check_keys(kwargs)
        return AugmentLinear(
            random_rotate_p=kwargs.get('random_rotate_p', 0.),
            random_rotate_angle=kwargs.get('random_rotate_angle', 0.),
            random_flip_x_p=kwargs.get('random_flip_x_p', 0.),
            random_flip_y_p=kwargs.get('random_flip_y_p', 0.),
            random_flip_z_p=kwargs.get('random_flip_z_p', 0.),
            static_rotate_x_p=kwargs.get('static_rotate_x_p', 0.),
            static_rotate_x=kwargs.get('static_rotate_x', []),
            static_rotate_y_p=kwargs.get('static_rotate_y_p', 0.),
            static_rotate_y=kwargs.get('static_rotate_y', []),
            static_rotate_z_p=kwargs.get('static_rotate_z_p', 0.),
            static_rotate_z=kwargs.get('static_rotate_z', []),
            random_scale_p=kwargs.get('random_scale_p', 0.),
            random_scale=kwargs.get('random_scale', [1.0, 1.0]),
            random_transpose=kwargs.get('random_transpose', 0.),
        )

    def _apply(self, v: ndarray, trans: ndarray) -> ndarray:
        return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]

    def transform(self, asset: Asset, **kwargs):
        trans_vertex = np.eye(4, dtype=np.float32)
        r = np.eye(4, dtype=np.float32)
        if np.random.rand() < self.random_rotate_p:
            angle = self.random_rotate_angle
            axis_angle = (np.random.rand(3) - 0.5) * angle / 180 * np.pi * 2
            r = R.from_rotvec(axis_angle).as_matrix()
            r = np.pad(r, ((0, 1), (0, 1)), 'constant', constant_values=0.)
            r[3, 3] = 1.

        if np.random.uniform(0, 1) < self.random_flip_x_p:
            r @= np.array([
                [-1.0, 0.0, 0.0, 0.0],
                [ 0.0, 1.0, 0.0, 0.0],
                [ 0.0, 0.0, 1.0, 0.0],
                [ 0.0, 0.0, 0.0, 1.0],
            ])

        if np.random.uniform(0, 1) < self.random_flip_y_p:
            r @= np.array([
                [1.0,  0.0, 0.0, 0.0],
                [0.0, -1.0, 0.0, 0.0],
                [0.0,  0.0, 1.0, 0.0],
                [0.0,  0.0, 0.0, 1.0],
            ])

        if np.random.uniform(0, 1) < self.random_flip_z_p:
            r @= np.array([
                [1.0, 0.0,  0.0, 0.0],
                [0.0, 1.0,  0.0, 0.0],
                [0.0, 0.0, -1.0, 0.0],
                [0.0, 0.0,  0.0, 1.0],
            ])

        if np.random.uniform(0, 1) < self.static_rotate_x_p:
            assert len(self.static_rotate_x) > 0, "static rotation of x is enabled, but static_rotate_x is empty"
            angle = np.random.choice(self.static_rotate_x) / 180 * np.pi
            c = np.cos(angle)
            s = np.sin(angle)
            r @= np.array([
                [1.0, 0.0, 0.0, 0.0],
                [0.0,   c,   s, 0.0],
                [0.0,  -s,   c, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ])

        if np.random.uniform(0, 1) < self.static_rotate_y_p:
            assert len(self.static_rotate_y) > 0, "static rotation of y is enabled, but static_rotate_y is empty"
            angle = np.random.choice(self.static_rotate_y) / 180 * np.pi
            c = np.cos(angle)
            s = np.sin(angle)
            r @= np.array([
                [  c, 0.0,  -s, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [  s, 0.0,   c, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ])

        if np.random.uniform(0, 1) < self.static_rotate_z_p:
            assert len(self.static_rotate_z) > 0, "static rotation of z is enabled, but static_rotate_z is empty"
            angle = np.random.choice(self.static_rotate_z) / 180 * np.pi
            c = np.cos(angle)
            s = np.sin(angle)
            r @= np.array([
                [  c,   s, 0.0, 0.0],
                [ -s,   c, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ])

        if np.random.uniform(0, 1) < self.random_scale_p:
            scale_x = np.random.uniform(self.random_scale[0], self.random_scale[1])
            scale_y = np.random.uniform(self.random_scale[0], self.random_scale[1])
            scale_z = np.random.uniform(self.random_scale[0], self.random_scale[1])
            r @= np.array([
                [scale_x, 0.0, 0.0, 0.0],
                [0.0, scale_y, 0.0, 0.0],
                [0.0, 0.0, scale_z, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ])

        if np.random.uniform(0, 1) < self.random_transpose:
            permutations = [
                (0, 1, 2),  # x, y, z
                (0, 2, 1),  # x, z, y
                (1, 0, 2),  # y, x, z
                (1, 2, 0),  # y, z, x
                (2, 0, 1),  # z, x, y
                (2, 1, 0),  # z, y, x
            ]
            direction_signs = [
                (1, 1, 1),
                (1, 1, -1),
                (1, -1, 1),
                (1, -1, -1),
                (-1, 1, 1),
                (-1, 1, -1),
                (-1, -1, 1),
                (-1, -1, -1),
            ]
            perm = permutations[np.random.randint(0, 6)]
            sign = direction_signs[np.random.randint(0, 8)]
            m = np.zeros((4, 4))
            for i in range(3):
                m[i, perm[i]] = sign[i]
            m[3, 3] = 1.0
            r = m @ r

        trans_vertex = r @ trans_vertex

        # apply the transform here
        asset.transform(trans=trans_vertex)

@dataclass(frozen=True)
class AugmentAffine(Augment):
    # final normalization cube
    normalize_into: Tuple[float, float]

    # randomly scale coordinates with probability p
    random_scale_p: float

    # scale range (lower, upper)
    random_scale: Tuple[float, float]

    # randomly shift coordinates with probability p
    random_shift_p: float

    # shift range (lower, upper)
    random_shift: Tuple[float, float]

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentAffine':
        cls.check_keys(kwargs)
        return AugmentAffine(
            normalize_into=kwargs.get('normalize_into', [-1.0, 1.0]),
            random_scale_p=kwargs.get('random_scale_p', 0.),
            random_scale=kwargs.get('random_scale', [1., 1.]),
            random_shift_p=kwargs.get('random_shift_p', 0.),
            random_shift=kwargs.get('random_shift', [0., 0.]),
        )

    def transform(self, asset: Asset, **kwargs):
        if asset.vertices is None:
            raise ValueError("do not have vertices")
        bound_min = asset.vertices.min(axis=0)
        bound_max = asset.vertices.max(axis=0)
        if asset.joints is not None:
            joints_bound_min = asset.joints.min(axis=0)
            joints_bound_max = asset.joints.max(axis=0)
            bound_min = np.minimum(bound_min, joints_bound_min)
            bound_max = np.maximum(bound_max, joints_bound_max)

        trans_vertex = np.eye(4, dtype=np.float32)

        trans_vertex = _trans_to_m(-(bound_max + bound_min) / 2) @ trans_vertex

        if self.normalize_into is not None:
            # scale into the cube
            normalize_into = self.normalize_into
            scale = np.max((bound_max - bound_min) / (normalize_into[1] - normalize_into[0]))
            trans_vertex = _scale_to_m(1. / scale) @ trans_vertex

            bias = (normalize_into[0] + normalize_into[1]) / 2
            trans_vertex = _trans_to_m(np.array([bias, bias, bias], dtype=np.float32)) @ trans_vertex

        if np.random.rand() < self.random_scale_p:
            scale = _scale_to_m(np.random.uniform(self.random_scale[0], self.random_scale[1]))
            trans_vertex = scale @ trans_vertex

        if np.random.rand() < self.random_shift_p:
            l, r = self.random_shift
            shift_vals = np.array([
                np.random.uniform(l, r),
                np.random.uniform(l, r),
                np.random.uniform(l, r),
            ], dtype=np.float32)
            if self.normalize_into is not None:
                def _apply(v: ndarray, trans: ndarray) -> ndarray:
                    return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
                lo, hi = self.normalize_into
                pts_min = _apply(bound_min[None, :], trans_vertex)[0]
                pts_max = _apply(bound_max[None, :], trans_vertex)[0]
                low_allowed = lo - pts_min
                high_allowed = hi - pts_max
                shift_vals = np.array([
                    np.random.uniform(low_allowed[0], high_allowed[0]),
                    np.random.uniform(low_allowed[1], high_allowed[1]),
                    np.random.uniform(low_allowed[2], high_allowed[2]),
                ], dtype=np.float32)
            shift = _trans_to_m(shift_vals.astype(np.float32))
            trans_vertex = shift @ trans_vertex
        asset.transform(trans=trans_vertex)

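# Aside (illustration only, not part of augment.py): AugmentAffine first
# centers the combined vertex/joint bounds, then scales the longest extent
# into the normalize_into cube. A worked example with per-axis bounds [0, 4]
# and normalize_into = [-1, 1]:
#
#     translate by -(4 + 0) / 2 = -2        # bounds become [-2, 2]
#     scale by 1 / ((4 - 0) / (1 - -1))     # = 1/2, bounds become [-1, 1]
#
# When random_shift_p fires, the shift is re-drawn per axis inside the slack
# between the transformed bounds and the cube, so the asset stays in range.
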
@dataclass(frozen=True)
class AugmentJitter(Augment):
    # probability
    p: float

    # jitter sigma on vertices
    vertex_sigma: float

    # jitter clip on vertices
    vertex_clip: float

    # jitter sigma on normals
    normal_sigma: float

    # jitter clip on normals
    normal_clip: float

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentJitter':
        cls.check_keys(kwargs)
        return AugmentJitter(
            p=kwargs.get('p', 0.5),
            vertex_sigma=kwargs.get('vertex_sigma', 0.),
            vertex_clip=kwargs.get('vertex_clip', 0.),
            normal_sigma=kwargs.get('normal_sigma', 0.),
            normal_clip=kwargs.get('normal_clip', 0.),
        )

    def transform(self, asset: Asset, **kwargs):
        vertex_sigma = self.vertex_sigma
        vertex_clip = self.vertex_clip
        normal_sigma = self.normal_sigma
        normal_clip = self.normal_clip

        if np.random.rand() < self.p:
            scale = np.random.rand() + 1e-6
            vertex_sigma *= scale
            vertex_clip *= scale
            scale = np.random.rand() + 1e-6
            normal_sigma *= scale
            normal_clip *= scale
        if vertex_sigma > 0 and asset.vertices is not None:
            noise = np.clip(np.random.randn(*asset.vertices.shape) * vertex_sigma, -vertex_clip, vertex_clip).astype(np.float32)
            asset.vertices += noise

        if normal_sigma > 0:
            if asset.vertex_normals is not None:
                noise = np.clip(np.random.randn(*asset.vertex_normals.shape) * normal_sigma, -normal_clip, normal_clip).astype(np.float32)
                asset.vertex_normals += noise

            if asset.face_normals is not None:
                noise = np.clip(np.random.randn(*asset.face_normals.shape) * normal_sigma, -normal_clip, normal_clip).astype(np.float32)
                asset.face_normals += noise

@dataclass(frozen=True)
class AugmentNormalize(Augment):

    @classmethod
    def parse(cls, **kwargs) -> 'AugmentNormalize':
        cls.check_keys(kwargs)
        return AugmentNormalize()

    def transform(self, asset: Asset, **kwargs):
        epsilon = 1e-10
        if asset.vertex_normals is not None:
            vertex_norms = np.linalg.norm(asset.vertex_normals, axis=1, keepdims=True)
            vertex_norms = np.maximum(vertex_norms, epsilon)
            asset.vertex_normals = asset.vertex_normals / vertex_norms
            asset.vertex_normals = np.nan_to_num(asset.vertex_normals, nan=0., posinf=0., neginf=0.)  # type: ignore

        if asset.face_normals is not None:
            face_norms = np.linalg.norm(asset.face_normals, axis=1, keepdims=True)
            face_norms = np.maximum(face_norms, epsilon)
            asset.face_normals = asset.face_normals / face_norms
            asset.face_normals = np.nan_to_num(asset.face_normals, nan=0., posinf=0., neginf=0.)  # type: ignore

def _trans_to_m(v: ndarray):
    m = np.eye(4, dtype=np.float32)
    m[0:3, 3] = v
    return m

def _scale_to_m(r: ndarray|float):
    m = np.zeros((4, 4), dtype=np.float32)
    m[0, 0] = r
    m[1, 1] = r
    m[2, 2] = r
    m[3, 3] = 1.
    return m

def get_augments(*args) -> List[Augment]:
    MAP: Dict[str, type[Augment]] = {
        'trim': AugmentTrim,
        'delete': AugmentDelete,
        'drop_part': AugmentDropPart,
        'collapse': AugmentCollapse,
        'lbs': AugmentLBS,
        'linear': AugmentLinear,
        'affine': AugmentAffine,
        'jitter': AugmentJitter,
        'joint_perturb': AugmentJointPerturb,
        'joint_discrete': AugmentJointDiscrete,
        'normalize': AugmentNormalize,
    }
    augments = []
    for (i, config) in enumerate(args):
        __target__ = config.get('__target__')
        assert __target__ is not None, f"do not find `__target__` in augment of position {i}"
        c = deepcopy(config)
        del c['__target__']
        augments.append(MAP[__target__].parse(**c))
    return augments
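
A minimal usage sketch (not part of this commit; the `__target__` names follow the MAP in get_augments above, while the `asset` object is assumed to come from one of the loaders in datapath.py below):

    from src.data.augment import get_augments

    augments = get_augments(
        {'__target__': 'linear', 'random_rotate_p': 0.5, 'random_rotate_angle': 15.0},
        {'__target__': 'affine', 'normalize_into': [-1.0, 1.0]},
        {'__target__': 'normalize'},
    )
    for aug in augments:
        aug.transform(asset)  # each Augment mutates the Asset in place, in order
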
src/data/datapath.py
ADDED
@@ -0,0 +1,344 @@
from abc import abstractmethod, ABC
from collections import defaultdict
from dataclasses import dataclass, field
from numpy import ndarray
from random import shuffle
from typing import Dict, List, Optional

import numpy as np
import requests
import os

from ..rig_package.info.asset import Asset
from ..server.spec import BPY_SERVER, bytes_to_object, object_to_bytes
from .spec import ConfigSpec

@dataclass
class LazyAsset(ABC):
    """store the data path and load on demand"""
    path: str

    cls: Optional[str]=None

    @abstractmethod
    def load(self) -> 'Asset':
        raise NotImplementedError()

@dataclass
class BpyLazyAsset(LazyAsset):

    def load(self) -> 'Asset':
        from ..rig_package.parser.bpy import BpyParser
        asset = BpyParser.load(filepath=self.path)
        asset.cls = self.cls
        asset.path = self.path
        return asset

@dataclass
class BpyServerLazyAsset(LazyAsset):
    """workaround for running bpy from multiple worker threads"""
    def load(self) -> 'Asset':
        try:
            asset = bytes_to_object(requests.get(f"{BPY_SERVER}/load", data=object_to_bytes(self.path)).content)
            if isinstance(asset, str):
                raise RuntimeError(f"bpy server failed: {asset}")
            assert isinstance(asset, Asset)
            asset.cls = self.cls
            asset.path = self.path
            return asset
        except Exception as e:
            raise RuntimeError(f"bpy server failed: {str(e)}")

@dataclass
class NpzLazyAsset(LazyAsset):

    def load(self) -> 'Asset':
        d = np.load(self.path, allow_pickle=True)
        asset = Asset(
            vertices=d['vertices'],
            faces=d['faces'],
            mesh_names=d.get('mesh_names', None),
            joint_names=d.get('joint_names', None),
            parents=d.get('parents', None),
            lengths=d.get('lengths', None),
            matrix_world=d.get('matrix_world', None),
            matrix_local=d.get('matrix_local', None),
            armature_name=d.get('armature_name', None),
            skin=d.get('skin', None),
            cls=self.cls,
            path=self.path
        )
        asset.cls = self.cls
        asset.path = self.path
        return asset

@dataclass
class UniRigLazyAsset(LazyAsset):
    """map UniRig's data correctly"""

    def load(self) -> 'Asset':
        def bn(x):
            # unwrap 0-d object arrays produced by np.savez
            if isinstance(x, ndarray) and x.ndim == 0:
                return x.item()
            return x

        d = np.load(self.path, allow_pickle=True)
        parents = bn(d.get('parents', None))
        if parents is not None:
            parents = [-1 if x is None else x for x in parents]
            parents = np.array(parents)
        matrix_local = bn(d.get('matrix_local', None))
        joints = bn(d.get('joints', None))
        if matrix_local is not None and matrix_local.ndim != 3 and joints is not None:
            matrix_local = np.zeros((joints.shape[0], 4, 4))
            matrix_local[...] = np.eye(4)
            matrix_local[:, :3, 3] = joints
        asset = Asset(
            vertices=d['vertices'],
            faces=d['faces'],
            joint_names=bn(d.get('names', None)),
            parents=parents,  # type: ignore
            lengths=bn(d.get('lengths', None)),
            matrix_world=bn(d.get('matrix_world', None)),
            matrix_local=matrix_local,
            armature_name=bn(d.get('armature_name', None)),
            skin=bn(d.get('skin', None)),
            cls=self.cls,
            path=self.path
        ).change_dtype(float_dtype=np.float32, int_dtype=np.int32)
        asset.cls = self.cls
        asset.path = self.path
        return asset

@dataclass
class Datapath(ConfigSpec):
    """handle input data paths"""

    # all filepaths
    filepaths: List[str]

    # root to prepend to each path
    input_dataset_dir: str=''

    # name of class
    cls_name: Optional[List[str]]=None

    # bias in a single class
    cls_bias: Optional[List[int]]=None

    # num of files in a single class
    cls_length: Optional[List[int]]=None

    # how many files to return when using data sampling
    num_files: Optional[int]=None

    # use proportional data sampling
    use_prob: bool=False

    # weight
    cls_weight: Optional[List[float]]=None

    # which lazy loader to use
    loader: type[LazyAsset]=BpyLazyAsset

    # data name
    data_name: Optional[str]=None

    # skip checking whether paths exist
    ignore_check: bool=False

    #################################################################
    # other vertex groups
    vertex_groups: Dict[str, ndarray]=field(default_factory=dict)

    # sampled vertices
    sampled_vertices: Optional[ndarray]=None

    # sampled normals
    sampled_normals: Optional[ndarray]=None

    # sampled vertex groups
    sampled_vertex_groups: Optional[Dict[str, ndarray]]=None

    @classmethod
    def parse(cls, **kwargs) -> 'Datapath':
        MAP = {
            None: BpyLazyAsset,
            'bpy': BpyLazyAsset,
            'bpy_server': BpyServerLazyAsset,
            'npz': NpzLazyAsset,
            'unirig': UniRigLazyAsset,
        }
        input_dataset_dir = kwargs.get('input_dataset_dir', '')
        num_files = kwargs.get('num_files', None)
        use_prob = kwargs.get('use_prob', False)
        data_name = kwargs.get('data_name', 'raw_data.npz')
        data_path = kwargs.get('data_path', None)
        loader_cls = MAP[kwargs.get('loader', None)]
        ignore_check = kwargs.get('ignore_check', False)

        if data_path is not None:
            filepaths = []
            if isinstance(data_path, dict):
                cls_name = []
                cls_bias = []
                cls_length = []
                cls_weight = []
                for name, v in data_path.items():
                    assert isinstance(v, list), "items in the dict must be a list of data list paths"
                    for item in v:
                        if isinstance(item, str):
                            datalist_path = item
                            weight = 1.0
                        else:
                            datalist_path = item[0]
                            weight = item[1]
                        cls_name.append(name)
                        with open(datalist_path, "r") as f:
                            lines = [x.strip() for x in f.readlines()]
                        ok_lines = []
                        missing = 0
                        for line in lines:
                            if ignore_check:
                                ok_lines.append(line)
                            elif os.path.exists(os.path.join(input_dataset_dir, line, data_name)):
                                ok_lines.append(line)
                            else:
                                missing += 1
                        if missing != 0:
                            print(f"\033[31m{datalist_path}: {missing} missing files\033[0m")
                        cls_bias.append(len(filepaths))
                        cls_length.append(len(ok_lines))
                        cls_weight.append(weight)
                        filepaths.extend(ok_lines)
            else:
                raise NotImplementedError()
        else:
            _filepaths = kwargs['filepaths']
            if isinstance(_filepaths, list):
                filepaths = _filepaths
                cls_name = None
                cls_bias = None
                cls_length = None
                cls_weight = None
            elif isinstance(_filepaths, dict):
                filepaths = []
                cls_name = []
                cls_bias = []
                cls_length = []
                cls_weight = []
                for k, v in _filepaths.items():
                    assert isinstance(v, list), "items in the dict must be a list of paths"
                    cls_name.append(k)
                    cls_bias.append(len(filepaths))
                    cls_length.append(len(v))
                    cls_weight.append(1.0)
                    filepaths.extend(v)
            else:
                raise NotImplementedError()
        if cls_weight is not None:
            total = sum(cls_weight)
            cls_weight = [x / total for x in cls_weight]
        return Datapath(
            filepaths=filepaths,
            input_dataset_dir=input_dataset_dir,
            cls_name=cls_name,
            cls_bias=cls_bias,
            cls_length=cls_length,
            num_files=num_files,
            use_prob=use_prob,
            cls_weight=cls_weight,
            loader=loader_cls,
            data_name=data_name,
            ignore_check=ignore_check,
        )

    def make(self, path: str, cls: str|None) -> LazyAsset:
        return self.loader(path=path, cls=cls)

    def __getitem__(self, index: int) -> LazyAsset:
        if self.use_prob and self.cls_weight is not None:
            if self.cls_bias is None:
                raise ValueError("do not have cls_bias")
            if self.cls_length is None:
                raise ValueError("do not have cls_length")
            if not hasattr(self, "perms"):
                self.perms = []
                self.current_bias = []
                for i in range(len(self.cls_weight)):
                    self.perms.append([x for x in range(self.cls_length[i])])
                    self.current_bias.append(0)
            idx = np.random.choice(len(self.cls_weight), p=self.cls_weight)
            i = self.perms[idx][self.current_bias[idx]]
            self.current_bias[idx] += 1
            if self.current_bias[idx] >= self.cls_length[idx]:
                shuffle(self.perms[idx])
                self.current_bias[idx] = 0
            if self.cls_name is None:
                name = None
            else:
                name = self.cls_name[idx]
            path = os.path.join(self.input_dataset_dir, self.filepaths[i + self.cls_bias[idx]])
            if self.data_name is not None:
                path = os.path.join(path, self.data_name)
            return self.make(path=path, cls=name)
        else:
            if self.cls_name is None or self.cls_bias is None or self.cls_length is None:
                name = None
            else:
                name = None
                for i in range(len(self.cls_bias)):
                    start = self.cls_bias[i]
                    end = start + self.cls_length[i]
                    if start <= index < end:
                        name = self.cls_name[i]
                        break
            path = os.path.join(self.input_dataset_dir, self.filepaths[index])
            if self.data_name is not None:
                path = os.path.join(path, self.data_name)
            return self.make(path=path, cls=name)

    def get_data(self) -> List[LazyAsset]:
        return [self[i] for i in range(len(self))]

    def split_by_cls(self) -> Dict[str|None, 'Datapath']:
        res: Dict[str|None, Datapath] = {}
        if self.cls_name is None:
            res[None] = self
            return res
        if self.cls_bias is None:
            raise ValueError("do not have cls_bias")
        if self.cls_length is None:
            raise ValueError("do not have cls_length")
        d_filepaths = defaultdict(list)
        d_length = defaultdict(int)
        d_weight = defaultdict(list)
        for (i, cls) in enumerate(self.cls_name):
            s = slice(self.cls_bias[i], self.cls_bias[i] + self.cls_length[i])
            d_filepaths[cls].extend(self.filepaths[s].copy())
            d_length[cls] += self.cls_length[i]
            if self.cls_weight is not None:
                d_weight[cls].append(self.cls_weight[i])
        for cls in d_filepaths:
            cls_weight = None if self.cls_weight is None else d_weight[cls]
            if cls_weight is not None:
                total = sum(cls_weight)
                cls_weight = [x / total for x in cls_weight]
            res[cls] = Datapath(
                filepaths=d_filepaths[cls],
                input_dataset_dir=self.input_dataset_dir,
                cls_name=[cls],
                cls_bias=[0],
                cls_length=[len(d_filepaths[cls])],
                num_files=self.num_files,
                use_prob=self.use_prob,
                cls_weight=cls_weight,
                loader=self.loader,
                data_name=self.data_name,
            )
        return res

    def __len__(self):
        if self.use_prob:
            assert self.num_files is not None, 'num_files is not specified'
            return self.num_files
        return len(self.filepaths)

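A minimal configuration sketch for Datapath (the paths and weights are hypothetical; the keyword names mirror Datapath.parse above):

    datapath = Datapath.parse(
        input_dataset_dir='data',                           # hypothetical root
        data_path={'mixamo': [('lists/mixamo.txt', 2.0)]},  # (datalist, weight) items
        loader='npz',
        use_prob=True,    # sample a class by normalized cls_weight on each access
        num_files=1000,   # reported length when use_prob is enabled
    )
    asset = datapath[0].load()  # LazyAsset -> Asset
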
src/data/dataset.py
ADDED
@@ -0,0 +1,319 @@
| 1 |
+
from copy import deepcopy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
|
| 4 |
+
from numpy import ndarray
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torch.utils import data
|
| 7 |
+
from torch.utils.data import DataLoader, Dataset
|
| 8 |
+
from typing import Dict, List, Tuple, Callable, Optional
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import lightning.pytorch as pl
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from .datapath import Datapath, LazyAsset
|
| 16 |
+
from .spec import ConfigSpec
|
| 17 |
+
from .transform import Transform
|
| 18 |
+
|
| 19 |
+
from ..model.spec import ModelInput
|
| 20 |
+
from ..rig_package.info.asset import Asset
|
| 21 |
+
from ..tokenizer.spec import Tokenizer, TokenizeInput
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class DatasetConfig(ConfigSpec):
|
| 25 |
+
shuffle: bool
|
| 26 |
+
batch_size: int
|
| 27 |
+
num_workers: int
|
| 28 |
+
datapath: Datapath
|
| 29 |
+
pin_memory: bool=True
|
| 30 |
+
persistent_workers: bool=True
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def parse(cls, **kwargs) -> 'DatasetConfig':
|
| 34 |
+
cls.check_keys(kwargs)
|
| 35 |
+
return DatasetConfig(
|
| 36 |
+
shuffle=kwargs.get('shuffle', False),
|
| 37 |
+
batch_size=kwargs.get('batch_size', 1),
|
| 38 |
+
num_workers=kwargs.get('num_workers', 1),
|
| 39 |
+
pin_memory=kwargs.get('pin_memory', True),
|
| 40 |
+
persistent_workers=kwargs.get('persistent_workers', True),
|
| 41 |
+
datapath=Datapath.parse(**kwargs.get('datapath')),
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def split_by_cls(self) -> Dict[str|None, 'DatasetConfig']:
|
| 45 |
+
res: Dict[str|None, DatasetConfig] = {}
|
| 46 |
+
datapath_dict = self.datapath.split_by_cls()
|
| 47 |
+
for cls, v in datapath_dict.items():
|
| 48 |
+
res[cls] = DatasetConfig(
|
| 49 |
+
shuffle=self.shuffle,
|
| 50 |
+
batch_size=self.batch_size,
|
| 51 |
+
num_workers=self.num_workers,
|
| 52 |
+
datapath=v,
|
| 53 |
+
pin_memory=self.pin_memory,
|
| 54 |
+
persistent_workers=self.persistent_workers,
|
| 55 |
+
)
|
| 56 |
+
return res
|
| 57 |
+
|
| 58 |
+
class RigDatasetModule(pl.LightningDataModule):
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
process_fn: Optional[Callable[[List[ModelInput]], List[Dict]]]=None,
|
| 62 |
+
train_dataset_config: Optional[DatasetConfig]=None,
|
| 63 |
+
validate_dataset_config: Optional[Dict[str|None, DatasetConfig]]=None,
|
| 64 |
+
predict_dataset_config: Optional[Dict[str|None, DatasetConfig]]=None,
|
| 65 |
+
train_transform: Optional[Transform]=None,
|
| 66 |
+
validate_transform: Optional[Transform]=None,
|
| 67 |
+
predict_transform: Optional[Transform]=None,
|
| 68 |
+
tokenizer: Optional[Tokenizer]=None,
|
| 69 |
+
debug: bool=False,
|
| 70 |
+
):
|
| 71 |
+
super().__init__()
|
| 72 |
+
self.process_fn = process_fn
|
| 73 |
+
self.train_dataset_config = train_dataset_config
|
| 74 |
+
self.validate_dataset_config = validate_dataset_config
|
| 75 |
+
self.predict_dataset_config = predict_dataset_config
|
| 76 |
+
self.train_transform = train_transform
|
| 77 |
+
self.validate_transform = validate_transform
|
| 78 |
+
self.predict_transform = predict_transform
|
| 79 |
+
self.tokenizer = tokenizer
|
| 80 |
+
self.debug = debug
|
| 81 |
+
|
| 82 |
+
if debug:
|
| 83 |
+
print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")
|
| 84 |
+
|
| 85 |
+
# build train datapath
|
| 86 |
+
if self.train_dataset_config is not None:
|
| 87 |
+
self.train_datapath = self.train_dataset_config.datapath
|
| 88 |
+
else:
|
| 89 |
+
self.train_datapath = None
|
| 90 |
+
|
| 91 |
+
# build validate datapath
|
| 92 |
+
if self.validate_dataset_config is not None:
|
| 93 |
+
self.validate_datapath = {
|
| 94 |
+
cls: self.validate_dataset_config[cls].datapath
|
| 95 |
+
for cls in self.validate_dataset_config
|
| 96 |
+
}
|
| 97 |
+
else:
|
| 98 |
+
self.validate_datapath = None
|
| 99 |
+
|
| 100 |
+
# build predict datapath
|
| 101 |
+
if self.predict_dataset_config is not None:
|
| 102 |
+
self.predict_datapath = {
|
| 103 |
+
cls: self.predict_dataset_config[cls].datapath
|
| 104 |
+
for cls in self.predict_dataset_config
|
| 105 |
+
}
|
| 106 |
+
else:
|
| 107 |
+
self.predict_datapath = None
|
| 108 |
+
|
| 109 |
+
self.tokenizer = tokenizer
|
| 110 |
+
|
| 111 |
+
def prepare_data(self):
|
| 112 |
+
pass
|
| 113 |
+
|
| 114 |
+
def train_dataloader(self) -> TRAIN_DATALOADERS:
|
| 115 |
+
if self.train_dataset_config is None:
|
| 116 |
+
raise ValueError("do not have train_dataset_config")
|
| 117 |
+
if self.train_transform is None:
|
| 118 |
+
raise ValueError("do not have train_transform")
|
| 119 |
+
if self.train_datapath is not None:
|
| 120 |
+
self._train_ds = RigDataset(
|
| 121 |
+
process_fn=self.process_fn,
|
| 122 |
+
data=self.train_datapath.get_data(),
|
| 123 |
+
name="train",
|
| 124 |
+
tokenizer=self.tokenizer,
|
| 125 |
+
transform=self.train_transform,
|
| 126 |
+
debug=self.debug,
|
| 127 |
+
)
|
| 128 |
+
else:
|
| 129 |
+
return None
|
| 130 |
+
return self._create_dataloader(
|
| 131 |
+
dataset=self._train_ds,
|
| 132 |
+
config=self.train_dataset_config,
|
| 133 |
+
is_train=True,
|
| 134 |
+
drop_last=False,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
def val_dataloader(self) -> EVAL_DATALOADERS:
|
| 138 |
+
if self.validate_dataset_config is None:
|
| 139 |
+
raise ValueError("do not have validate_dataset_config")
|
| 140 |
+
if self.validate_transform is None:
|
| 141 |
+
raise ValueError("do not have validate_transform")
|
| 142 |
+
if self.validate_datapath is not None:
|
| 143 |
+
self._validation_ds = {}
|
| 144 |
+
for cls in self.validate_datapath:
|
| 145 |
+
self._validation_ds[cls] = RigDataset(
|
| 146 |
+
process_fn=self.process_fn,
|
| 147 |
+
data=self.validate_datapath[cls].get_data(),
|
| 148 |
+
name=f"validate-{cls}",
|
| 149 |
+
tokenizer=self.tokenizer,
|
| 150 |
+
transform=self.validate_transform,
|
| 151 |
+
debug=self.debug,
|
| 152 |
+
)
|
| 153 |
+
else:
|
| 154 |
+
return None
|
| 155 |
+
return self._create_dataloader(
|
| 156 |
+
dataset=self._validation_ds,
|
| 157 |
+
config=self.validate_dataset_config,
|
| 158 |
+
is_train=False,
|
| 159 |
+
drop_last=False,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
def predict_dataloader(self):
|
| 163 |
+
if self.predict_dataset_config is None:
|
| 164 |
+
raise ValueError("do not have predict_dataset_config")
|
| 165 |
+
if self.predict_transform is None:
|
| 166 |
+
raise ValueError("do not have predict_transform")
|
| 167 |
+
if self.predict_datapath is not None:
|
| 168 |
+
self._predict_ds = {}
|
| 169 |
+
for cls in self.predict_datapath:
|
| 170 |
+
self._predict_ds[cls] = RigDataset(
|
| 171 |
+
process_fn=self.process_fn,
|
| 172 |
+
data=self.predict_datapath[cls].get_data(),
|
| 173 |
+
name=f"predict-{cls}",
|
| 174 |
+
tokenizer=self.tokenizer,
|
| 175 |
+
transform=self.predict_transform,
|
| 176 |
+
debug=self.debug,
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
return None
|
| 180 |
+
return self._create_dataloader(
|
| 181 |
+
dataset=self._predict_ds,
|
| 182 |
+
config=self.predict_dataset_config,
|
| 183 |
+
is_train=False,
|
| 184 |
+
drop_last=False,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def _create_dataloader(
|
| 188 |
+
self,
|
| 189 |
+
dataset: Dataset|Dict[str, Dataset],
|
| 190 |
+
config: DatasetConfig|Dict[str|None, DatasetConfig],
|
| 191 |
+
is_train: bool,
|
| 192 |
+
**kwargs,
|
| 193 |
+
) -> DataLoader|Dict[str, DataLoader]:
|
| 194 |
+
def create_single_dataloader(dataset, config: DatasetConfig, **kwargs):
|
| 195 |
+
return DataLoader(
|
| 196 |
+
dataset,
|
| 197 |
+
batch_size=config.batch_size,
|
| 198 |
+
shuffle=config.shuffle,
|
| 199 |
+
num_workers=config.num_workers,
|
| 200 |
+
pin_memory=config.pin_memory,
|
| 201 |
+
persistent_workers=config.persistent_workers,
|
| 202 |
+
collate_fn=dataset.collate_fn,
|
| 203 |
+
**kwargs,
|
| 204 |
+
)
|
| 205 |
+
if isinstance(dataset, Dict):
|
| 206 |
+
assert isinstance(config, dict)
|
| 207 |
+
return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
|
| 208 |
+
else:
|
| 209 |
+
assert isinstance(config, DatasetConfig)
|
| 210 |
+
    return create_single_dataloader(dataset, config, **kwargs)


class RigDataset(Dataset):
    def __init__(
        self,
        data: List[LazyAsset],
        transform: Transform,
        name: Optional[str]=None,
        process_fn: Optional[Callable[[List[ModelInput]], List[Dict]]]=None,
        tokenizer: Optional[Tokenizer]=None,
        debug: bool=False,
    ) -> None:
        super().__init__()

        self.data = data
        self.name = name
        self.process_fn = process_fn
        self.tokenizer = tokenizer
        self.transform = transform
        self.debug = debug

        if not debug:
            assert self.process_fn is not None, 'missing data processing function'

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx) -> ModelInput:
        lazy_asset = self.data[idx]
        asset = lazy_asset.load()
        self.transform.apply(asset=asset)
        if self.tokenizer is not None and asset.parents is not None:
            x = TokenizeInput(
                joints=asset.joints,
                parents=asset.parents,
                cls=asset.cls,
                joint_names=asset.joint_names,
            )
            tokens = self.tokenizer.tokenize(input=x)
        else:
            tokens = None
        return ModelInput(asset=asset, tokens=tokens)

    def _collate_fn_debug(self, batch):
        return batch

    def _collate_fn(self, batch):
        processed_batch = self.process_fn(batch)  # type: ignore
        processed_batch: List[Dict]

        tensors_stack = {}
        tensors_cat = {}
        non_tensors = {}
        vis = {}

        def check(x):
            assert x not in vis, f"multiple keys found: {x}"
            vis[x] = True

        for k, v in processed_batch[0].items():
            if k == "cat":
                assert isinstance(v, dict)
                for k1 in v.keys():
                    check(k1)
                    tensors_cat[k1] = []
                    for i in range(len(processed_batch)):
                        v1 = processed_batch[i]['cat'][k1]
                        if isinstance(v1, ndarray):
                            v1 = torch.from_numpy(v1)
                        elif not isinstance(v1, Tensor):
                            raise ValueError(f"cannot concatenate non-tensor type of key {k1}, type: {type(v1)}")
                        tensors_cat[k1].append(v1)
            elif k == "non":
                assert isinstance(v, dict)
                for k1 in v.keys():
                    check(k1)
                    non_tensors[k1] = []
                    for i in range(len(processed_batch)):
                        v1 = processed_batch[i]['non'][k1]
                        if isinstance(v1, ndarray):
                            v1 = torch.from_numpy(v1)
                        non_tensors[k1].append(v1)
            else:
                check(k)
                tensors_stack[k] = []
                for i in range(len(processed_batch)):
                    v1 = processed_batch[i][k]
                    if isinstance(v1, ndarray):
                        v1 = torch.from_numpy(v1)
                    elif not isinstance(v1, Tensor):
                        raise ValueError(f"cannot stack type of key {k}, type: {type(v1)}")
                    tensors_stack[k].append(v1)

        collated_stack = {k: torch.stack(v) for k, v in tensors_stack.items()}
        collated_cat = {k: torch.concat(v, dim=1) for k, v in tensors_cat.items()}

        collated_batch = {
            **collated_stack,
            **collated_cat,
            **non_tensors,
        }
        return collated_batch

    def collate_fn(self, batch):
        if self.debug:
            return self._collate_fn_debug(batch)
        return self._collate_fn(batch)
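
A minimal sketch of the batch convention `_collate_fn` expects (illustrative only; the key names `pts`, `token_ids`, and `path` are assumptions, not part of the release): `process_fn` returns one dict per sample, top-level tensor keys are stacked into a batch dimension, entries under "cat" are concatenated along dim 1, and entries under "non" are collected into plain Python lists.

import numpy as np

def example_process_fn(batch):
    # one dict per ModelInput; names and shapes here are hypothetical
    out = []
    for i, _ in enumerate(batch):
        out.append({
            'pts': np.zeros((1024, 3), dtype=np.float32),            # stacked -> (B, 1024, 3)
            'cat': {'token_ids': np.zeros((1, 7), dtype=np.int64)},  # concatenated on dim 1 -> (1, 7 * B)
            'non': {'path': f'asset_{i}.glb'},                       # kept as a list of length B
        })
    return out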
src/data/order.py
ADDED
@@ -0,0 +1,132 @@
from collections import defaultdict
from dataclasses import dataclass
from numpy import ndarray
from omegaconf import OmegaConf
from typing import Dict, List, Tuple, Optional

from .spec import ConfigSpec

@dataclass
class Order(ConfigSpec):

    # {part_name: [bone_name_1, bone_name_2, ...]}
    parts: Dict[str, Dict[str, List[str]]]

    # parts of bones to be arranged in [part_name_1, part_name_2, ...]
    parts_order: Dict[str, List[str]]

    # {skeleton_name: path}
    skeleton_path: Optional[Dict[str, str]]=None

    sort_by_xyz: bool=False

    @classmethod
    def parse(cls, **kwargs) -> 'Order':
        cls.check_keys(kwargs)
        skeleton_path = kwargs.get('skeleton_path', None)
        if skeleton_path is not None:
            parts = {}
            parts_order = {}
            # the loop variable must not be named `cls`, which would shadow
            # the classmethod's own `cls`
            for (cls_name, path) in skeleton_path.items():
                assert cls_name not in parts, 'cls conflicts'
                d = OmegaConf.load(path)
                parts[cls_name] = d.parts
                parts_order[cls_name] = d.parts_order
        else:
            parts = kwargs.get('parts')
            parts_order = kwargs.get('parts_order')
        assert parts is not None
        assert parts_order is not None
        return Order(
            skeleton_path=skeleton_path,
            parts=parts,
            parts_order=parts_order,
            sort_by_xyz=kwargs.get('sort_by_xyz', False),
        )

    def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
        '''
        Check whether `part` is defined for `cls` and every bone of the part exists in `names`.
        '''
        if part not in self.parts[cls]:
            return False
        for name in self.parts[cls][part]:
            if name not in names:
                return False
        return True

    def make_names(self, cls: str|None, parts: List[str|None], num_bones: int) -> List[str]:
        '''
        Get bone names for the specified cls, padding with placeholder names up to `num_bones`.
        '''
        names = []
        for part in parts:
            if part is None:  # spring
                continue
            if cls in self.parts and part in self.parts[cls]:
                names.extend(self.parts[cls][part])
        assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones"
        for i in range(len(names), num_bones):
            names.append(f"bone_{i}")
        return names

    def arrange_names(self, cls: str|None, names: List[str], parents: List[int], joints: Optional[ndarray]=None) -> Tuple[List[str], Dict[int, str|None]]:
        '''
        Arrange names according to the required parts order.
        '''
        def sort_by_xyz(joints):
            return sorted(joints, key=lambda joint: (joint[1][2], joint[1][0], joint[1][1]))

        if self.sort_by_xyz:
            assert joints is not None
            new_names = []
            root = -1
            son = defaultdict(list)
            not_root = {}
            for (i, p) in enumerate(parents):
                if p != -1:
                    son[p].append(i)
                    not_root[i] = True
            for i in range(len(parents)):
                if not not_root.get(i, False):
                    root = i
                    break
            Q = [root]
            while Q:
                u = Q.pop(0)
                new_names.append(names[u])
                wait = []
                for v in son[u]:
                    wait.append((v, joints[v]))
                wait_sorted = sort_by_xyz(wait)
                new_wait = [v for v, _ in wait_sorted]
                Q = new_wait + Q
            return new_names, {}
        if cls not in self.parts_order:
            return names, {0: None}  # add a spring token
        vis = defaultdict(bool)
        name_to_id = {name: i for (i, name) in enumerate(names)}
        new_names = []
        parts_bias = {}
        for part in self.parts_order[cls]:
            if self.part_exists(cls=cls, part=part, names=names):
                for name in self.parts[cls][part]:
                    vis[name] = True
                flag = False
                for name in self.parts[cls][part]:
                    pid = parents[name_to_id[name]]
                    if pid == -1:
                        continue
                    if not vis[names[pid]]:
                        flag = True
                        break
                if flag:  # incorrect parts order; immediately add a spring token
                    break
                parts_bias[len(new_names)] = part
                new_names.extend(self.parts[cls][part])
        parts_bias[len(new_names)] = None  # add a spring token
        for name in names:
            if name not in new_names:
                new_names.append(name)
        return new_names, parts_bias
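
A minimal sketch of driving `Order` directly, assuming a made-up two-part skeleton (real part and bone names live in the YAML files under configs/skeleton/):

# assuming: from src.data.order import Order
order = Order.parse(
    parts={'biped': {'body': ['hips', 'spine'], 'arm': ['shoulder_l', 'hand_l']}},
    parts_order={'biped': ['body', 'arm']},
)
names = ['shoulder_l', 'hips', 'spine', 'hand_l', 'tail']
parents = [2, -1, 1, 0, 1]  # parent index per bone, -1 marks the root
new_names, parts_bias = order.arrange_names(cls='biped', names=names, parents=parents)
# new_names  == ['hips', 'spine', 'shoulder_l', 'hand_l', 'tail']
# parts_bias == {0: 'body', 2: 'arm', 4: None}  (spring bones start at offset 4)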
src/data/sampler.py
ADDED
@@ -0,0 +1,189 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod
from numpy import ndarray
from scipy.spatial import cKDTree  # type: ignore
from typing import Dict, Optional

import numpy as np
import random

from ..rig_package.info.asset import Asset
from ..rig_package.utils import sample_vertex_groups
from .spec import ConfigSpec

@dataclass
class SamplerResult:
    sampled_vertices: Optional[ndarray]=None
    sampled_normals: Optional[ndarray]=None
    sampled_vertex_groups: Optional[Dict[str, ndarray]]=None

    # number of sampled skin points
    skin_samples: Optional[int]=None

class Sampler(ABC):
    @abstractmethod
    def sample(
        self,
        asset: Asset,
    ) -> SamplerResult:
        '''
        Return sampled vertices, sampled normals and vertex groups.
        '''
        pass

    @classmethod
    @abstractmethod
    def parse(cls, **kwargs) -> 'Sampler':
        pass

@dataclass
class SamplerMix(Sampler, ConfigSpec):
    num_samples: int
    num_vertex_samples: int
    num_skin_samples: Optional[int]=None
    replace: bool=True
    all_skeleton: Optional[bool]=None
    max_distance: float=0.1
    rate_distance: float=0.1

    @classmethod
    def parse(cls, **kwargs) -> 'SamplerMix':
        cls.check_keys(kwargs)
        return SamplerMix(
            num_samples=kwargs.get('num_samples', 0),
            num_vertex_samples=kwargs.get('num_vertex_samples', 0),
            num_skin_samples=kwargs.get('num_skin_samples', None),
            replace=kwargs.get('replace', True),
            all_skeleton=kwargs.get('all_skeleton', None),
            max_distance=kwargs.get('max_distance', 0.1),
            rate_distance=kwargs.get('rate_distance', 0.1),
        )

    def sample_on_skin(
        self,
        skin: ndarray,
        vertices: ndarray,
        faces: ndarray,
    ):
        face_has_skin = np.any(skin[faces] > 0, axis=-1)
        if face_has_skin.sum() == 0:
            face_has_skin = np.ones_like(face_has_skin)
        elif self.max_distance < 1e-5:
            return face_has_skin
        else:
            # also include faces near the skinned region
            p = np.unique(faces[face_has_skin].reshape(-1))
            tree = cKDTree(vertices[p])
            dis, _ = tree.query(vertices, k=1)
            dis_skin = np.sqrt(((np.max(vertices[p], axis=0) - np.min(vertices[p], axis=0))**2).sum())
            mask_face_near = np.any(dis[faces] < min(self.max_distance, dis_skin * self.rate_distance), axis=-1)
            face_has_skin |= mask_face_near
        return face_has_skin

    def sample(
        self,
        asset: Asset,
    ) -> SamplerResult:
        if asset.vertices is None:
            raise ValueError("do not have vertices")
        if asset.faces is None:
            raise ValueError("do not have faces")
        vertex_groups = []
        mapping = {}
        tot = 0
        for k, v in asset.vertex_groups.items():
            if v.ndim == 1:
                v = v[:, None]
            elif v.ndim != 2:
                raise ValueError(f"ndim of key {k} is {v.ndim}")
            s = tot
            e = tot + v.shape[1]
            mapping[k] = slice(s, e)
            vertex_groups.append(v)
            tot = e  # advance the column offset so each group maps to its own slice
        if len(vertex_groups) > 0:
            vertex_groups = np.concatenate(vertex_groups, axis=1)
        else:
            vertex_groups = None
        final_sampled_vertices, final_sampled_normals, sampled_vertex_groups = sample_vertex_groups(
            vertices=asset.vertices,
            faces=asset.faces,
            num_samples=self.num_samples,
            vertex_normals=asset.vertex_normals,
            face_normals=asset.face_normals,
            vertex_groups=vertex_groups,
            face_mask=None,
            shuffle=True,
            same=True,
        )
        if vertex_groups is not None:
            final_sampled_vertices = final_sampled_vertices[:, 0]
            if final_sampled_normals is not None:
                final_sampled_normals = final_sampled_normals[:, 0]
        final_sampled_vertex_groups = {}
        if sampled_vertex_groups is not None:
            for k, s in mapping.items():
                final_sampled_vertex_groups[k] = sampled_vertex_groups[:, s]  # (N, k)
        if vertex_groups is not None and self.num_skin_samples is not None:
            dense_vertices = []
            dense_normals = []
            dense_skin = []
            if 'skin' not in mapping:
                raise ValueError("do not have skin")
            if self.all_skeleton:
                dense_indices = [i for i in range(asset.J)]
            else:
                dense_indices = [random.randint(0, asset.J-1)]
            for indice in dense_indices:
                _s = asset.vertex_groups['skin'][:, indice]
                face_has_skin = self.sample_on_skin(
                    skin=_s,
                    vertices=asset.vertices,
                    faces=asset.faces,
                )
                sampled_vertices, sampled_normals, sampled_skin = sample_vertex_groups(
                    vertices=asset.vertices,
                    faces=asset.faces,
                    vertex_normals=asset.vertex_normals,
                    face_normals=asset.face_normals,
                    vertex_groups=_s,
                    num_samples=self.num_skin_samples,
                    num_vertex_samples=self.num_vertex_samples,
                    face_mask=face_has_skin,
                    shuffle=True,
                    same=True,
                )
                assert sampled_skin is not None
                assert sampled_skin.ndim == 2
                dense_vertices.append(sampled_vertices[:, 0])
                if sampled_normals is not None:
                    dense_normals.append(sampled_normals[:, 0])
                dense_skin.append(sampled_skin[:, 0])
            dense_vertices = np.stack(dense_vertices, axis=0)  # (J, m, 3)
            if len(dense_normals) > 0:
                dense_normals = np.stack(dense_normals, axis=0)  # (J, m, 3)
            else:
                dense_normals = None
            dense_skin = np.stack(dense_skin, axis=0)  # (J, m, 1)
            final_sampled_vertex_groups['skin'] = final_sampled_vertex_groups['skin'][:, dense_indices]
            if asset.meta is None:
                asset.meta = {}
            asset.meta['dense_vertices'] = dense_vertices
            asset.meta['dense_normals'] = dense_normals
            asset.meta['dense_skin'] = dense_skin
            asset.meta['dense_indices'] = dense_indices
        return SamplerResult(
            sampled_vertices=final_sampled_vertices,
            sampled_normals=final_sampled_normals,
            sampled_vertex_groups=final_sampled_vertex_groups,
            skin_samples=self.num_skin_samples,
        )

def get_sampler(**kwargs) -> Sampler:
    __target__ = kwargs.get('__target__')
    assert __target__ is not None
    del kwargs['__target__']
    if __target__ == 'mix':
        sampler = SamplerMix.parse(**kwargs)
    else:
        raise ValueError(f"sampler method {__target__} not supported")
    return sampler
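
A minimal configuration sketch for `get_sampler` (the numbers are illustrative, not the demo's defaults):

# assuming: from src.data.sampler import get_sampler
sampler = get_sampler(
    __target__='mix',         # 'mix' is currently the only supported sampler
    num_samples=8192,         # surface samples returned for the whole mesh
    num_vertex_samples=2048,  # hypothetical value, forwarded to sample_vertex_groups
    num_skin_samples=1024,    # enables the dense per-joint skin sampling branch
    all_skeleton=True,        # sample around every joint instead of one random joint
)
# result = sampler.sample(asset=asset)  # -> SamplerResult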
src/data/spec.py
ADDED
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from dataclasses import fields

class ConfigSpec(ABC):
    @classmethod
    def check_keys(cls, config, expect=None):
        if expect is None:
            expect = [field.name for field in fields(cls)]  # type: ignore
        for key in config.keys():
            if key not in expect:
                raise ValueError(f"expect names {expect} in {cls.__name__}, found {key}")

    @classmethod
    @abstractmethod
    def parse(cls, **kwargs) -> 'ConfigSpec':
        raise NotImplementedError()
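
A minimal sketch of the pattern every config dataclass in this package follows (the spec below is hypothetical):

from dataclasses import dataclass

# assuming: from src.data.spec import ConfigSpec
@dataclass
class ToySpec(ConfigSpec):
    size: int = 16

    @classmethod
    def parse(cls, **kwargs) -> 'ToySpec':
        cls.check_keys(kwargs)  # rejects any key that is not a dataclass field
        return ToySpec(size=kwargs.get('size', 16))

ToySpec.parse(size=32)     # ok
# ToySpec.parse(sizes=32)  # ValueError: 'sizes' is not an expected key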
src/data/transform.py
ADDED
@@ -0,0 +1,70 @@
from dataclasses import dataclass
from typing import List, Optional

from ..rig_package.info.asset import Asset
from .augment import Augment, get_augments
from .order import Order
from .sampler import Sampler, get_sampler
from .spec import ConfigSpec
from .vertex_group import VertexGroup, get_vertex_groups

@dataclass
class Transform(ConfigSpec):

    order: Optional[Order]=None

    vertex_groups: Optional[List[VertexGroup]]=None

    augments: Optional[List[Augment]]=None

    sampler: Optional[Sampler]=None

    @classmethod
    def parse(cls, **kwargs) -> 'Transform':
        cls.check_keys(kwargs)
        order_config = kwargs.get('order')
        vertex_groups_config = kwargs.get('vertex_groups')
        augments_config = kwargs.get('augments')
        sampler_config = kwargs.get('sampler')

        d = {}
        if order_config is not None:
            d['order'] = Order.parse(**order_config)
        if vertex_groups_config is not None:
            d['vertex_groups'] = get_vertex_groups(*vertex_groups_config)
        if augments_config is not None:
            d['augments'] = get_augments(*augments_config)
        if sampler_config is not None:
            d['sampler'] = get_sampler(**sampler_config)
        return Transform(**d)

    def apply(self, asset: Asset, **kwargs):

        # 1. arrange bones
        if self.order is not None:
            if asset.joint_names is not None and asset.parents is not None:
                new_names, _ = self.order.arrange_names(cls=asset.cls, names=asset.joint_names, parents=asset.parents.tolist())
                asset.set_order(new_orders=new_names)  # type: ignore

        # 2. apply augments (the collapse augment must run first)
        if self.augments is not None:
            kwargs = {}
            for augment in self.augments:
                augment.transform(asset=asset, **kwargs)

        # 3. get vertex groups
        if self.vertex_groups is not None:
            d = {}
            for v in self.vertex_groups:
                d.update(v.get_vertex_group(asset=asset))
            asset.vertex_groups = d
        else:
            asset.vertex_groups = {}

        # 4. sample
        if self.sampler is not None:
            res = self.sampler.sample(asset=asset)
            asset.sampled_vertices = res.sampled_vertices
            asset.sampled_normals = res.sampled_normals
            asset.sampled_vertex_groups = res.sampled_vertex_groups
            asset.skin_samples = res.skin_samples
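
A minimal end-to-end sketch for `Transform.parse` (the nested dicts mirror the parse/get_* entry points above; the skeleton name 'vroid' and the numbers are illustrative):

# assuming: from src.data.transform import Transform
transform = Transform.parse(
    order={'skeleton_path': {'vroid': 'configs/skeleton/vroid.yaml'}},
    vertex_groups=[{'__target__': 'skin', 'normalize': True}],
    sampler={'__target__': 'mix', 'num_samples': 8192, 'num_vertex_samples': 2048},
)
# transform.apply(asset=asset)  # mutates the asset in place, running steps 1-4 above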
src/data/vertex_group.py
ADDED
@@ -0,0 +1,257 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from numpy import ndarray
from scipy.spatial import cKDTree  # type: ignore
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path, connected_components
from typing import Dict, List, Optional, Literal

import numpy as np

from ..rig_package.info.asset import Asset

@dataclass(frozen=True)
class VertexGroup(ABC):

    @classmethod
    @abstractmethod
    def parse(cls, **kwargs) -> 'VertexGroup':
        pass

    @abstractmethod
    def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
        pass

@dataclass(frozen=True)
class VertexGroupSkin(VertexGroup):
    """capture skin"""

    normalize: bool=True

    @classmethod
    def parse(cls, **kwargs) -> 'VertexGroupSkin':
        return VertexGroupSkin(normalize=kwargs.get('normalize', True))

    def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
        if asset.skin is None:
            raise ValueError("do not have skin")
        if self.normalize:
            asset.normalize_skin()
        return {'skin': asset.skin.copy()}

@dataclass(frozen=True)
class VertexGroupVoxelSkin(VertexGroup):
    """capture voxel skin"""

    grid: int
    alpha: float
    link_dis: float
    grid_query: int
    vertex_query: int
    grid_weight: float
    mode: Literal['square', 'exp']

    @classmethod
    def parse(cls, **kwargs) -> 'VertexGroupVoxelSkin':
        return VertexGroupVoxelSkin(
            grid=kwargs.get('grid', 64),
            alpha=kwargs.get('alpha', 0.5),
            link_dis=kwargs.get('link_dis', 0.00001),
            grid_query=kwargs.get('grid_query', 27),
            vertex_query=kwargs.get('vertex_query', 27),
            grid_weight=kwargs.get('grid_weight', 3.0),
            mode=kwargs.get('mode', 'square'),
        )

    def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
        if asset.vertices is None:
            raise ValueError("do not have vertices")
        if asset.faces is None:
            raise ValueError("do not have faces")
        if asset.joints is None:
            raise ValueError("do not have joints")
        # normalize into [-1, 1] first
        min_vals = np.min(asset.vertices, axis=0)
        max_vals = np.max(asset.vertices, axis=0)

        center = (min_vals + max_vals) / 2

        scale = np.max(max_vals - min_vals) / 2

        normalized_vertices = (asset.vertices - center) / scale
        normalized_joints = (asset.joints - center) / scale

        grid_coords = asset.voxel().coords
        skin = voxel_skin(
            grid=self.grid,
            grid_coords=grid_coords,
            joints=normalized_joints,
            vertices=normalized_vertices,
            faces=asset.faces,
            alpha=self.alpha,
            link_dis=self.link_dis,
            grid_query=self.grid_query,
            vertex_query=self.vertex_query,
            grid_weight=self.grid_weight,
            mode=self.mode,
        )
        skin = np.nan_to_num(skin, nan=0., posinf=0., neginf=0.)
        return {'voxel_skin': skin}

def voxel_skin(
    grid: int,
    grid_coords: ndarray,  # (M, 3)
    joints: ndarray,  # (J, 3)
    vertices: ndarray,  # (N, 3)
    faces: ndarray,  # (F, 3)
    alpha: float=0.5,
    link_dis: float=0.00001,
    grid_query: int=27,
    vertex_query: int=27,
    grid_weight: float=3.0,
    voxel_size: Optional[float]=None,
    mode: str='square',
    parents: Optional[ndarray]=None,
):
    # modified from https://dl.acm.org/doi/pdf/10.1145/2485895.2485919
    assert mode in ['square', 'exp']
    J = joints.shape[0]
    M = grid_coords.shape[0]
    N = vertices.shape[0]

    if voxel_size is None:
        _range = 2/grid*1.74
    else:
        _range = voxel_size*1.74

    grid_tree = cKDTree(grid_coords)
    vertex_tree = cKDTree(vertices)
    if parents is not None:
        son = defaultdict(list)
        for i, p in enumerate(parents):
            if p == -1:  # skip the root, which has parent index -1
                continue
            son[p].append(i)
        divide_joints = []
        joints_map = []
        for u in range(len(parents)):
            if len(son[u]) != 1:
                divide_joints.append(joints[u])
                joints_map.append(u)
            else:
                pu = joints[u]
                pv = joints[son[u][0]]
                seg = 10
                for i in range(seg+1):
                    p = (pu*i + pv*(seg-i)) / seg
                    divide_joints.append(p)
                    joints_map.append(u)
        divide_joints = np.stack(divide_joints)
        joints_map = np.array(joints_map)
    else:
        divide_joints = joints
        joints_map = np.arange(joints.shape[0])
    joint_tree = cKDTree(divide_joints)

    # make combined vertices
    # 0 ~ N-1: mesh vertices
    # N ~ N+M-1: grid vertices
    combined_vertices = np.concatenate([vertices, grid_coords], axis=0)

    # link adjacent grids
    dist, idx = grid_tree.query(grid_coords, grid_query)  # 3*3*3
    dist = dist[:, 1:]
    idx = idx[:, 1:]
    mask = (0 < dist) & (dist < _range)
    source_grid2grid = np.repeat(np.arange(M), grid_query-1)[mask.ravel()] + N
    to_grid2grid = idx[mask] + N
    weight_grid2grid = dist[mask] * grid_weight

    # link very close vertices
    dist, idx = vertex_tree.query(vertices, 4)
    dist = dist[:, 1:]
    idx = idx[:, 1:]
    mask = (0 < dist) & (dist < link_dis)
    source_close = np.repeat(np.arange(N), 3)[mask.ravel()]
    to_close = idx[mask]
    weight_close = dist[mask]

    # link grids to mesh vertices
    dist, idx = vertex_tree.query(grid_coords, vertex_query)
    mask = (0 < dist) & (dist < _range)  # sqrt(3)
    source_grid2vertex = np.repeat(np.arange(M), vertex_query)[mask.ravel()] + N
    to_grid2vertex = idx[mask]
    weight_grid2vertex = dist[mask]

    # build combined vertices tree
    combined_tree = cKDTree(combined_vertices)
    # link bones to the nearest vertices
    _, joint_indices = combined_tree.query(divide_joints)

    # build graph
    source_vertex2vertex = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]], axis=0)
    to_vertex2vertex = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]], axis=0)
    weight_vertex2vertex = np.sqrt(((vertices[source_vertex2vertex] - vertices[to_vertex2vertex])**2).sum(axis=-1))
    graph = csr_matrix(
        (np.concatenate([weight_close, weight_vertex2vertex, weight_grid2grid, weight_grid2vertex]),
         (
             np.concatenate([source_close, source_vertex2vertex, source_grid2grid, source_grid2vertex], axis=0),
             np.concatenate([to_close, to_vertex2vertex, to_grid2grid, to_grid2vertex], axis=0)),
         ),
        shape=(N+M, N+M),
    )

    # get shortest path (J, N+M)
    dist_matrix = shortest_path(graph, method='D', directed=False, indices=joint_indices)

    # (sum_J, N)
    dis_vertex2bone = dist_matrix[:, :N]
    unreachable = np.isinf(dis_vertex2bone).all(axis=0)
    k = min(J, 3)
    dist, idx = joint_tree.query(vertices[unreachable], k)

    # make sure at least one value in dis is not inf
    unreachable_indices = np.where(unreachable)[0]
    row_indices = idx
    col_indices = np.repeat(unreachable_indices, k).reshape(-1, k)
    dis_vertex2bone[row_indices, col_indices] = dist

    finite_vals = dis_vertex2bone[np.isfinite(dis_vertex2bone)]
    max_dis = np.max(finite_vals)
    dis_vertex2bone = np.nan_to_num(dis_vertex2bone, nan=max_dis, posinf=max_dis, neginf=max_dis)
    dis_vertex2bone = np.maximum(dis_vertex2bone, 1e-6)

    # reduce distances over subdivided bone samples to per-joint distances
    dis_vertex2joint = np.full((joints.shape[0], vertices.shape[0]), max_dis)
    for i in range(len(dis_vertex2bone)):
        dis_vertex2joint[joints_map[i]] = np.minimum(dis_vertex2bone[i], dis_vertex2joint[joints_map[i]])

    # (J, N)
    if mode == 'exp':
        skin = np.exp(-dis_vertex2joint / max_dis * 20.0)
    elif mode == 'square':
        skin = (1./((1-alpha)*dis_vertex2joint + alpha*dis_vertex2joint**2))**2
    else:
        assert False, f'invalid mode: {mode}'
    skin = skin / skin.sum(axis=0)
    # (N, J)
    skin = skin.transpose()
    return skin

def get_vertex_groups(*args) -> List[VertexGroup]:
    vertex_groups = []
    MAP = {
        'skin': VertexGroupSkin,
        'voxel_skin': VertexGroupVoxelSkin,
    }
    MAP: Dict[str, type[VertexGroup]]
    for (i, c) in enumerate(args):
        __target__ = c.get('__target__')
        assert __target__ is not None, f"do not find `__target__` in config of vertex_groups at position {i}"
        assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}"
        c = deepcopy(c)
        del c['__target__']
        vertex_groups.append(MAP[__target__].parse(**c))
    return vertex_groups
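
For reference, the distance-to-weight mapping that `voxel_skin` implements, written out (this restates the code: $d_{j,v}$ is the clamped graph distance from joint $j$ to vertex $v$, and $d_{\max}$ the largest finite distance):

$$
w_{j,v} =
\begin{cases}
\exp\left(-20\, d_{j,v} / d_{\max}\right) & \text{mode} = \text{exp} \\
\left((1-\alpha)\, d_{j,v} + \alpha\, d_{j,v}^{2}\right)^{-2} & \text{mode} = \text{square}
\end{cases}
\qquad
\mathrm{skin}_{v,j} = \frac{w_{j,v}}{\sum_{j'} w_{j',v}}
$$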
src/model/__init__.py
ADDED
File without changes
src/model/michelangelo/__init__.py
ADDED
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
src/model/michelangelo/get_model.py
ADDED
@@ -0,0 +1,30 @@
import torch
from typing import Optional

from .models.tsal.sal_perceiver import AlignedShapeLatentPerceiver, ShapeAsLatentPerceiverEncoder

def get_encoder(
    pretrained_path: Optional[str]=None,
    freeze_decoder: bool=False,
    **kwargs
) -> AlignedShapeLatentPerceiver:
    model = AlignedShapeLatentPerceiver(**kwargs)
    if pretrained_path is not None:
        state_dict = torch.load(pretrained_path, weights_only=True)
        model.load_state_dict(state_dict)
    if freeze_decoder:
        model.geo_decoder.requires_grad_(False)
        model.encoder.query.requires_grad_(False)
        model.pre_kl.requires_grad_(False)
        model.post_kl.requires_grad_(False)
        model.transformer.requires_grad_(False)
    return model

def get_encoder_simplified(
    pretrained_path: Optional[str]=None,
    **kwargs
) -> ShapeAsLatentPerceiverEncoder:
    model = ShapeAsLatentPerceiverEncoder(**kwargs)
    if pretrained_path is not None:
        state_dict = torch.load(pretrained_path, weights_only=True)
        model.load_state_dict(state_dict)
    return model
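
A minimal loading sketch (the checkpoint path is a placeholder; the encoder's constructor kwargs come from the model configs and are omitted here):

# assuming: from src.model.michelangelo.get_model import get_encoder
encoder = get_encoder(
    pretrained_path=None,  # or a local .pt state dict; None builds an untrained encoder
    freeze_decoder=True,   # freezes geo_decoder / encoder.query / pre_kl / post_kl / transformer
    # remaining **kwargs are forwarded to AlignedShapeLatentPerceiver(...)
)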
src/model/michelangelo/models/__init__.py
ADDED
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
src/model/michelangelo/models/modules/__init__.py
ADDED
@@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-

from .checkpoint import checkpoint
src/model/michelangelo/models/modules/checkpoint.py
ADDED
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""
Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
"""

import torch
from typing import Callable, Iterable, Sequence, Union


def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
    use_deepspeed: bool = False
):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    :param use_deepspeed: if True, use deepspeed
    """
    if flag:
        if use_deepspeed:
            import deepspeed
            return deepspeed.checkpointing.checkpoint(func, *inputs)

        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type='cuda')
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    @torch.amp.custom_bwd(device_type='cuda')
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
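
A minimal usage sketch of the `checkpoint` helper above:

import torch
import torch.nn as nn

# assuming: from src.model.michelangelo.models.modules.checkpoint import checkpoint
mlp = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
x = torch.randn(4, 8, requires_grad=True)

# with flag=True, activations inside `mlp` are recomputed during backward
# instead of being stored; the parameters are passed explicitly so their
# gradients are still produced
y = checkpoint(mlp, (x,), mlp.parameters(), flag=True)
y.sum().backward()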
src/model/michelangelo/models/modules/distributions.py
ADDED
@@ -0,0 +1,100 @@
import torch
import numpy as np
from typing import Union, List


class AbstractDistribution(object):
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
        self.feat_dim = feat_dim
        self.parameters = parameters

        if isinstance(parameters, list):
            self.mean = parameters[0]
            self.logvar = parameters[1]
        else:
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)

        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        x = self.mean + self.std * torch.randn_like(self.mean)
        return x

    def kl(self, other=None, dims=(1, 2, 3)):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.mean(torch.pow(self.mean, 2)
                                        + self.var - 1.0 - self.logvar,
                                        dim=dims)
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=dims)

    def nll(self, sample, dims=(1, 2, 3)):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
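
A minimal sketch of how the posterior object above is typically used (shapes are illustrative):

import torch

# assuming: from src.model.michelangelo.models.modules.distributions import DiagonalGaussianDistribution
params = torch.randn(4, 2 * 32, 256)  # (B, 2*C, L): mean and logvar stacked on dim 1
posterior = DiagonalGaussianDistribution(params, feat_dim=1)

z = posterior.sample()          # (4, 32, 256), reparameterized draw
kl = posterior.kl(dims=(1, 2))  # (4,), KL to a standard normal, reduced over C and L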
src/model/michelangelo/models/modules/embedder.py
ADDED
@@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-

import numpy as np
import torch
import torch.nn as nn
import math

VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]


class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(x[..., i]),
            sin(f_1*x[..., i]),
            sin(f_2*x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(x[..., i]),
            cos(f_1*x[..., i]),
            cos(f_2*x[..., i]),
            ...
            cos(f_N * x[..., i]),
            x[..., i]  # only present if include_input is True.
        ], here f_i is the frequency.

    If logspace is True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)];
    otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): if True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
        input_dim (int): the input dimension, default is 3;
        include_input (bool): include the input tensor or not, default is True.

    Attributes:
        frequencies (torch.Tensor): the frequency tensor described above;

        out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
            otherwise, it is input_dim * num_freqs * 2.

    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:

        """The initialization"""

        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """

        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x


class LearnedFourierEmbedder(nn.Module):
    """ following @crowsonkb's lead with learned sinusoidal pos emb """
    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """

    def __init__(self, in_channels, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        per_channel_dim = half_dim // in_channels
        self.weights = nn.Parameter(torch.randn(per_channel_dim))

    def forward(self, x):
        """

        Args:
            x (torch.FloatTensor): [..., c]

        Returns:
            x (torch.FloatTensor): [..., d]
        """

        # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
        freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
        fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
        return fouriered


class TriplaneLearnedFourierEmbedder(nn.Module):
    def __init__(self, in_channels, dim):
        super().__init__()

        self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
        self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
        self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)

        self.out_dim = in_channels + dim

    def forward(self, x):

        yz_embed = self.yz_plane_embedder(x)
        xz_embed = self.xz_plane_embedder(x)
        xy_embed = self.xy_plane_embedder(x)

        embed = yz_embed + xz_embed + xy_embed

        return embed


def sequential_pos_embed(num_len, embed_dim):
    assert embed_dim % 2 == 0

    pos = torch.arange(num_len, dtype=torch.float32)
    omega = torch.arange(embed_dim // 2, dtype=torch.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    embeddings = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)

    return embeddings


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
                 num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
                 log2_hashmap_size=19, desired_resolution=None):
    if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
        return nn.Identity(), input_dim

    elif embed_type == "fourier":
        embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
                                       logspace=True, include_input=True)
        return embedder_obj, embedder_obj.out_dim

    elif embed_type == "hashgrid":
        raise NotImplementedError

    elif embed_type == "sphere_harmonic":
        raise NotImplementedError

    else:
        raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")
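
A quick shape check for the Fourier embedding (the numbers follow directly from `get_dims` above):

import torch

# assuming: from src.model.michelangelo.models.modules.embedder import get_embedder
embedder, out_dim = get_embedder(embed_type="fourier", num_freqs=8, input_dim=3)
x = torch.rand(2, 1024, 3)  # e.g. a batch of sampled surface points
emb = embedder(x)           # (2, 1024, 3 * (8 * 2 + 1)) = (2, 1024, 51)
assert out_dim == 51 and emb.shape[-1] == 51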
src/model/michelangelo/models/modules/transformer_blocks.py
ADDED
@@ -0,0 +1,327 @@
# -*- coding: utf-8 -*-

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import os

from .checkpoint import checkpoint
from ...utils.misc import use_flash3


if use_flash3.is_use:
    from flash_attn_interface import flash_attn_func
    print("use flash attention 3.")
else:
    print("use flash attention 2.")

def init_linear(l, stddev):
    nn.init.normal_(l.weight, std=stddev)
    if l.bias is not None:
        nn.init.constant_(l.bias, 0.0)

def flash_attention(q, k, v):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        if use_flash3.is_use:
            out, _ = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous())
            # out = flash_attn_func(q, k, v)

            # q_ = q.transpose(1, 2)
            # k_ = k.transpose(1, 2)
            # v_ = v.transpose(1, 2)

            # # print(q.shape, k.shape, v.shape)
            # out_ = F.scaled_dot_product_attention(q_, k_, v_)
            # out_ = out_.transpose(1, 2)

            # # print(torch.abs(out - out_).mean())
            # assert torch.abs(out - out_).mean() < 1e-2, f"the error {torch.abs(out - out_).mean()} is too large"

            # out = out_

            # print("use flash_atten 3")
        else:
            # fall back to PyTorch SDPA; inputs are (batch, seq, heads, channels),
            # SDPA expects (batch, heads, seq, channels)
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            out = F.scaled_dot_product_attention(q, k, v)
            out = out.transpose(1, 2)
            # print("use flash atten 2")

    return out

class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float,
        qkv_bias: bool,
        flash: bool = False
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
        init_linear(self.c_qkv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        x = self.c_qkv(x)
        x = checkpoint(self.attention, (x,), (), False)
        x = self.c_proj(x)
        return x


class QKVMultiheadAttention(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_ctx = n_ctx
        self.flash = flash

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        if self.flash:
            out = flash_attention(q, k, v)
            out = out.reshape(out.shape[0], out.shape[1], -1)
        else:
            weight = torch.einsum(
                "bthc,bshc->bhts", q * scale, k * scale
            )  # More stable with f16 than dividing afterwards
            wdtype = weight.dtype
            weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
            out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)

        return out


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float = 1.0,
        qkv_bias: bool = True,
        flash: bool = False,
        use_checkpoint: bool = False
    ):
        super().__init__()

        self.use_checkpoint = use_checkpoint

        self.attn = MultiheadAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash
        )
        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)

    def _forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

    def forward(self, x: torch.Tensor):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)


class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        width: int,
        heads: int,
        init_scale: float,
        qkv_bias: bool = True,
        flash: bool = False,
        n_data: Optional[int] = None,
        data_width: Optional[int] = None,
    ):
        super().__init__()
        self.n_data = n_data
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadCrossAttention(
            device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
        )
        init_linear(self.c_q, init_scale)
        init_linear(self.c_kv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x, data):
        x = self.c_q(x)
        data = self.c_kv(data)
        x = checkpoint(self.attention, (x, data), (), False)
        x = self.c_proj(x)
        return x


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
                 flash: bool = False, n_data: Optional[int] = None):

        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_data = n_data
        self.flash = flash

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        attn_ch = width // self.heads // 2
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        if self.flash:
            out = flash_attention(q, k, v)
            out = out.reshape(out.shape[0], out.shape[1], -1)
        else:
            weight = torch.einsum(
                "bthc,bshc->bhts", q * scale, k * scale
            )  # More stable with f16 than dividing afterwards
            wdtype = weight.dtype
            weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
            out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
|
| 223 |
+
|
| 224 |
+
return out
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class ResidualCrossAttentionBlock(nn.Module):
|
| 228 |
+
def __init__(
|
| 229 |
+
self,
|
| 230 |
+
*,
|
| 231 |
+
device: Optional[torch.device],
|
| 232 |
+
dtype: Optional[torch.dtype],
|
| 233 |
+
n_data: Optional[int] = None,
|
| 234 |
+
width: int,
|
| 235 |
+
heads: int,
|
| 236 |
+
data_width: Optional[int] = None,
|
| 237 |
+
mlp_width_scale: int = 4,
|
| 238 |
+
init_scale: float = 0.25,
|
| 239 |
+
qkv_bias: bool = True,
|
| 240 |
+
flash: bool = False
|
| 241 |
+
):
|
| 242 |
+
super().__init__()
|
| 243 |
+
|
| 244 |
+
if data_width is None:
|
| 245 |
+
data_width = width
|
| 246 |
+
|
| 247 |
+
self.attn = MultiheadCrossAttention(
|
| 248 |
+
device=device,
|
| 249 |
+
dtype=dtype,
|
| 250 |
+
n_data=n_data,
|
| 251 |
+
width=width,
|
| 252 |
+
heads=heads,
|
| 253 |
+
data_width=data_width,
|
| 254 |
+
init_scale=init_scale,
|
| 255 |
+
qkv_bias=qkv_bias,
|
| 256 |
+
flash=flash,
|
| 257 |
+
)
|
| 258 |
+
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
|
| 259 |
+
self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
|
| 260 |
+
self.mlp = MLP(device=device, dtype=dtype, width=width, hidden_width_scale=mlp_width_scale, init_scale=init_scale)
|
| 261 |
+
self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
|
| 262 |
+
|
| 263 |
+
def forward(self, x: torch.Tensor, data: torch.Tensor):
|
| 264 |
+
x = x + self.attn(self.ln_1(x), self.ln_2(data))
|
| 265 |
+
x = x + self.mlp(self.ln_3(x))
|
| 266 |
+
return x
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class MLP(nn.Module):
|
| 270 |
+
def __init__(self, *,
|
| 271 |
+
device: Optional[torch.device],
|
| 272 |
+
dtype: Optional[torch.dtype],
|
| 273 |
+
width: int,
|
| 274 |
+
hidden_width_scale: int = 4,
|
| 275 |
+
init_scale: float):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.width = width
|
| 278 |
+
self.c_fc = nn.Linear(width, width * hidden_width_scale, device=device, dtype=dtype)
|
| 279 |
+
self.c_proj = nn.Linear(width * hidden_width_scale, width, device=device, dtype=dtype)
|
| 280 |
+
self.gelu = nn.GELU()
|
| 281 |
+
init_linear(self.c_fc, init_scale)
|
| 282 |
+
init_linear(self.c_proj, init_scale)
|
| 283 |
+
|
| 284 |
+
def forward(self, x):
|
| 285 |
+
return self.c_proj(self.gelu(self.c_fc(x)))
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class Transformer(nn.Module):
|
| 289 |
+
def __init__(
|
| 290 |
+
self,
|
| 291 |
+
*,
|
| 292 |
+
device: Optional[torch.device],
|
| 293 |
+
dtype: Optional[torch.dtype],
|
| 294 |
+
n_ctx: int,
|
| 295 |
+
width: int,
|
| 296 |
+
layers: int,
|
| 297 |
+
heads: int,
|
| 298 |
+
init_scale: float = 0.25,
|
| 299 |
+
qkv_bias: bool = True,
|
| 300 |
+
flash: bool = False,
|
| 301 |
+
use_checkpoint: bool = False
|
| 302 |
+
):
|
| 303 |
+
super().__init__()
|
| 304 |
+
self.n_ctx = n_ctx
|
| 305 |
+
self.width = width
|
| 306 |
+
self.layers = layers
|
| 307 |
+
self.resblocks = nn.ModuleList(
|
| 308 |
+
[
|
| 309 |
+
ResidualAttentionBlock(
|
| 310 |
+
device=device,
|
| 311 |
+
dtype=dtype,
|
| 312 |
+
n_ctx=n_ctx,
|
| 313 |
+
width=width,
|
| 314 |
+
heads=heads,
|
| 315 |
+
init_scale=init_scale,
|
| 316 |
+
qkv_bias=qkv_bias,
|
| 317 |
+
flash=flash,
|
| 318 |
+
use_checkpoint=use_checkpoint
|
| 319 |
+
)
|
| 320 |
+
for _ in range(layers)
|
| 321 |
+
]
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
def forward(self, x: torch.Tensor):
|
| 325 |
+
for block in self.resblocks:
|
| 326 |
+
x = block(x)
|
| 327 |
+
return x
|
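These blocks compose in the usual pre-norm fashion: `ResidualAttentionBlock` stacks into `Transformer`, and `ResidualCrossAttentionBlock` injects conditioning tokens. A minimal usage sketch (the sizes and the repo-root import path are assumptions, not the demo's real configuration):

import torch
from src.model.michelangelo.models.modules.transformer_blocks import Transformer

blocks = Transformer(
    device=torch.device("cpu"), dtype=torch.float32,
    n_ctx=257, width=768, layers=2, heads=12, flash=False,
)
x = torch.randn(1, 257, 768)  # [batch, tokens, width]
y = blocks(x)                 # shape preserved: [1, 257, 768]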
src/model/michelangelo/models/tsal/__init__.py
ADDED
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
src/model/michelangelo/models/tsal/loss.py
ADDED
@@ -0,0 +1,454 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Dict

from ..modules.distributions import DiagonalGaussianDistribution
from ...utils.eval import compute_psnr
from ...utils import misc
import numpy as np
from copy import deepcopy


def logits_to_sdf(logits):
    return torch.sigmoid(logits) * 2 - 1


class KLNearFar(nn.Module):
    def __init__(self,
                 near_weight: float = 0.1,
                 kl_weight: float = 1.0,
                 num_near_samples: Optional[int] = None):

        super().__init__()

        self.near_weight = near_weight
        self.kl_weight = kl_weight
        self.num_near_samples = num_near_samples
        self.geo_criterion = nn.BCEWithLogitsLoss()

    def forward(self,
                posteriors: Optional[DiagonalGaussianDistribution],
                logits: torch.FloatTensor,
                labels: torch.FloatTensor,
                split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
        """

        Args:
            posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
            logits (torch.FloatTensor): [B, 2*N]; logits[:, 0:N] are the volume points, logits[:, N:2N] are the near points.
            labels (torch.FloatTensor): [B, 2*N]; labels[:, 0:N] are the volume points, labels[:, N:2N] are the near points.
            split (str):
            **kwargs:

        Returns:
            loss (torch.Tensor): (,)
            log (dict):

        """

        if self.num_near_samples is None:
            num_vol = logits.shape[1] // 2
        else:
            num_vol = logits.shape[1] - self.num_near_samples

        vol_logits = logits[:, 0:num_vol]
        vol_labels = labels[:, 0:num_vol]

        near_logits = logits[:, num_vol:]
        near_labels = labels[:, num_vol:]

        # occupancy loss (cast to float32 for a numerically stable BCE)
        # vol_bce = self.geo_criterion(vol_logits, vol_labels)
        # near_bce = self.geo_criterion(near_logits, near_labels)
        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())

        if posteriors is None:
            kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
        else:
            kl_loss = posteriors.kl(dims=(1, 2))
            kl_loss = torch.mean(kl_loss)

        loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight

        with torch.no_grad():
            preds = logits >= 0
            accuracy = (preds == labels).float()
            accuracy = accuracy.mean()
            pos_ratio = torch.mean(labels)

        log = {
            "{}/total_loss".format(split): loss.clone().detach(),
            "{}/near".format(split): near_bce.detach(),
            "{}/far".format(split): vol_bce.detach(),
            "{}/kl".format(split): kl_loss.detach(),
            "{}/accuracy".format(split): accuracy,
            "{}/pos_ratio".format(split): pos_ratio
        }

        if posteriors is not None:
            log[f"{split}/mean"] = posteriors.mean.mean().detach()
            log[f"{split}/std_mean"] = posteriors.std.mean().detach()
            log[f"{split}/std_max"] = posteriors.std.max().detach()

        return loss, log
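
# Illustrative sketch (not part of the original file): driving KLNearFar with
# random stand-ins for decoder outputs; posteriors=None skips the KL term.
def _kl_near_far_sketch():
    criterion = KLNearFar(near_weight=0.1, kl_weight=1e-3)
    B, N = 4, 1024
    logits = torch.randn(B, 2 * N)                 # volume + near-surface query logits
    labels = (torch.rand(B, 2 * N) > 0.5).float()  # 0/1 occupancy targets
    loss, log = criterion(posteriors=None, logits=logits, labels=labels, split="train")
    return loss, log["train/accuracy"]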
class KLNearFarColor(nn.Module):
    def __init__(self,
                 near_weight: float = 0.1,
                 kl_weight: float = 1.0,
                 color_weight: float = 1.0,
                 color_criterion: str = "mse",
                 num_near_samples: Optional[int] = None):

        super().__init__()

        self.color_weight = color_weight
        self.near_weight = near_weight
        self.kl_weight = kl_weight
        self.num_near_samples = num_near_samples

        if color_criterion == "mse":
            self.color_criterion = nn.MSELoss()
        elif color_criterion == "l1":
            self.color_criterion = nn.L1Loss()
        else:
            raise ValueError(f"{color_criterion} must be [`mse`, `l1`].")

        self.geo_criterion = nn.BCEWithLogitsLoss()

    def forward(self,
                posteriors: Optional[DiagonalGaussianDistribution],
                logits: torch.FloatTensor,
                labels: torch.FloatTensor,
                pred_colors: torch.FloatTensor,
                gt_colors: torch.FloatTensor,
                split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
        """

        Args:
            posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
            logits (torch.FloatTensor): [B, 2*N]; logits[:, 0:N] are the volume points, logits[:, N:2N] are the near points.
            labels (torch.FloatTensor): [B, 2*N]; labels[:, 0:N] are the volume points, labels[:, N:2N] are the near points.
            pred_colors (torch.FloatTensor): [B, M, 3]
            gt_colors (torch.FloatTensor): [B, M, 3]
            split (str):
            **kwargs:

        Returns:
            loss (torch.Tensor): (,)
            log (dict):

        """

        if self.num_near_samples is None:
            num_vol = logits.shape[1] // 2
        else:
            num_vol = logits.shape[1] - self.num_near_samples

        vol_logits = logits[:, 0:num_vol]
        vol_labels = labels[:, 0:num_vol]

        near_logits = logits[:, num_vol:]
        near_labels = labels[:, num_vol:]

        # occupancy loss
        # vol_bce = self.geo_criterion(vol_logits, vol_labels)
        # near_bce = self.geo_criterion(near_logits, near_labels)
        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())

        # surface color loss
        color = self.color_criterion(pred_colors, gt_colors)

        if posteriors is None:
            kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device)
        else:
            kl_loss = posteriors.kl(dims=(1, 2))
            kl_loss = torch.mean(kl_loss)

        loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight

        with torch.no_grad():
            preds = logits >= 0
            accuracy = (preds == labels).float()
            accuracy = accuracy.mean()
            psnr = compute_psnr(pred_colors, gt_colors)

        log = {
            "{}/total_loss".format(split): loss.clone().detach(),
            "{}/near".format(split): near_bce.detach(),
            "{}/far".format(split): vol_bce.detach(),
            "{}/color".format(split): color.detach(),
            "{}/kl".format(split): kl_loss.detach(),
            "{}/psnr".format(split): psnr.detach(),
            "{}/accuracy".format(split): accuracy
        }

        return loss, log


class ContrastKLNearFar(nn.Module):
    def __init__(self,
                 contrast_weight: float = 1.0,
                 near_weight: float = 0.1,
                 kl_weight: float = 1.0,
                 normal_weight: float = 0.0,
                 surface_weight: float = 0.0,
                 eikonal_weight: float = 0.0,
                 sdf_bce_weight: float = 0.0,
                 sdf_l1l2_weight: float = 1.0,
                 num_near_samples: Optional[int] = None,
                 sdf_trunc_val: float = 0.05,
                 gt_sdf_soft: bool = False,
                 normal_supervision_type: str = "cosine",
                 supervision_type: str = 'occupancy'):

        super().__init__()

        self.labels = None
        self.last_local_batch_size = None
        self.supervision_type = supervision_type

        assert normal_supervision_type in ["l1", "l2", "cosine", "l1_cosine", "l2_cosine", "von_mises"]
        self.normal_supervision_type = normal_supervision_type

        self.contrast_weight = contrast_weight
        self.near_weight = near_weight
        self.kl_weight = kl_weight
        self.normal_weight = normal_weight
        self.surface_weight = surface_weight
        self.eikonal_weight = eikonal_weight
        self.sdf_bce_weight = sdf_bce_weight    # only used in sigmoid-sdf
        self.sdf_l1l2_weight = sdf_l1l2_weight  # only used in sigmoid-sdf
        self.sdf_trunc_val = sdf_trunc_val
        self.gt_sdf_soft = gt_sdf_soft
        self.num_near_samples = num_near_samples
        self.geo_criterion = nn.BCEWithLogitsLoss()
        self.geo_criterion_sdf = nn.MSELoss()

    def sdf_loss(self, pred_sdf, gt_sdf):
        scaled_sdf = gt_sdf / self.sdf_trunc_val
        # cast the region masks to float: subtracting bool tensors is not supported
        greater_mask = (scaled_sdf > 1.).float()
        smaller_mask = (scaled_sdf < -1.).float()
        inside_mask = 1. - greater_mask - smaller_mask
        greater_loss = F.smooth_l1_loss(F.relu(1. - pred_sdf), torch.zeros_like(pred_sdf), reduction="none") * greater_mask
        smaller_loss = F.smooth_l1_loss(F.relu(pred_sdf + 1.), torch.zeros_like(pred_sdf), reduction="none") * smaller_mask
        inside_loss = F.smooth_l1_loss(pred_sdf, gt_sdf, beta=1e-2, reduction="none") * inside_mask
        loss = (greater_loss + smaller_loss + inside_loss).mean()
        return loss

    def von_mises(self, x, y, k=1):
        cos = F.cosine_similarity(x, y, dim=-1)
        exp = torch.exp(k * (cos - 1))
        return 1 - exp

    def forward(self,
                shape_embed: torch.FloatTensor,
                text_embed: torch.FloatTensor,
                image_embed: torch.FloatTensor,
                logit_scale: torch.FloatTensor,
                posteriors: Optional[DiagonalGaussianDistribution],
                latents: torch.FloatTensor,
                shape_logits: torch.FloatTensor,
                shape_labels: torch.FloatTensor,
                surface_logits: Optional[torch.FloatTensor],
                surface_normals: Optional[torch.FloatTensor],
                gt_surface_normals: Optional[torch.FloatTensor],
                split: Optional[str] = "train", **kwargs):
        if self.supervision_type == 'occupancy':
            shape_logits = shape_logits.squeeze(-1)
            shape_labels[shape_labels >= 0] = 1
            shape_labels[shape_labels < 0] = 0

        elif self.supervision_type == 'occupancy-shapenet':
            shape_logits = shape_logits.squeeze(-1)

        elif self.supervision_type == 'occupancy-w-surface':
            shape_logits = shape_logits.squeeze(-1)
            shape_labels[shape_labels == 10] = 0
            shape_labels[shape_labels > 0] = 1
            shape_labels[shape_labels < 0] = 0

        elif 'sdf' in self.supervision_type:
            shape_logits = shape_logits.squeeze(-1)
            if self.gt_sdf_soft:
                shape_labels_sdf = torch.tanh(shape_labels / self.sdf_trunc_val)  # * self.sdf_trunc_val
            else:
                shape_labels_sdf = torch.clamp(shape_labels, min=-self.sdf_trunc_val, max=self.sdf_trunc_val) / self.sdf_trunc_val
        else:
            raise ValueError(f"Invalid supervision_type {self.supervision_type}")

        local_batch_size = shape_embed.size(0)

        if local_batch_size != self.last_local_batch_size:
            self.labels = local_batch_size * misc.get_rank() + torch.arange(
                local_batch_size, device=shape_embed.device
            ).long()
            self.last_local_batch_size = local_batch_size

        if text_embed is not None and image_embed is not None:
            # normalized features
            shape_embed = F.normalize(shape_embed, dim=-1, p=2)
            text_embed = F.normalize(text_embed, dim=-1, p=2)
            image_embed = F.normalize(image_embed, dim=-1, p=2)

            # gather features from all GPUs
            shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch(
                [shape_embed, text_embed, image_embed]
            )

            # cosine similarity as logits
            logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t()
            logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t()
            logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t()
            logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t()
            contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) +
                             F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \
                            (F.cross_entropy(logits_per_shape_image, self.labels) +
                             F.cross_entropy(logits_per_image_shape, self.labels)) / 2
        else:
            contrast_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)

        # shape reconstruction
        if self.num_near_samples is None:
            num_vol = shape_logits.shape[1] // 2
        else:
            num_vol = shape_logits.shape[1] - self.num_near_samples

        # occupancy/sdf loss
        if self.supervision_type == 'occupancy' or self.supervision_type == 'occupancy-shapenet':
            vol_logits = shape_logits[:, 0:num_vol]
            vol_labels = shape_labels[:, 0:num_vol]

            near_logits = shape_logits[:, num_vol:]
            near_labels = shape_labels[:, num_vol:]

            vol_loss = self.geo_criterion(vol_logits.float(), vol_labels.float())
            near_loss = self.geo_criterion(near_logits.float(), near_labels.float())

        elif 'sdf' in self.supervision_type:
            if self.supervision_type == "sigmoid-sdf":
                shape_sdfs = logits_to_sdf(shape_logits)
            else:
                shape_sdfs = shape_logits

            vol_logits = shape_logits[:, 0:num_vol]
            vol_sdfs = shape_sdfs[:, 0:num_vol]
            vol_labels_sdf = shape_labels_sdf[:, 0:num_vol]

            near_logits = shape_logits[:, num_vol:]
            near_sdfs = shape_sdfs[:, num_vol:]
            near_labels_sdf = shape_labels_sdf[:, num_vol:]

            # sdf regression: l1 + l2 on the truncated sdf (optionally plus an occupancy-style BCE below)
            vol_loss = torch.mean(torch.abs(vol_sdfs - vol_labels_sdf)) + torch.mean((vol_sdfs - vol_labels_sdf) ** 2)  # + self.geo_criterion(vol_logits_sdf, vol_labels)
            near_loss = torch.mean(torch.abs(near_sdfs - near_labels_sdf)) + torch.mean((near_sdfs - near_labels_sdf) ** 2)  # + self.geo_criterion(near_logits_sdf, near_labels)

            if self.supervision_type == "sigmoid-sdf":
                vol_labels = (vol_labels_sdf + 1) / 2
                near_labels = (near_labels_sdf + 1) / 2
                vol_loss = self.sdf_l1l2_weight * vol_loss + self.sdf_bce_weight * self.geo_criterion(vol_logits, vol_labels)
                near_loss = self.sdf_l1l2_weight * near_loss + self.sdf_bce_weight * self.geo_criterion(near_logits, near_labels)
                # print(vol_loss, self.sdf_bce_weight * self.geo_criterion(vol_logits, vol_labels))

        # surface loss
        if "sdf" in self.supervision_type and surface_logits is not None:
            if self.supervision_type == "sigmoid-sdf":
                surface_sdfs = logits_to_sdf(surface_logits)
            else:
                surface_sdfs = surface_logits
            surface_loss = torch.mean(surface_sdfs ** 2)
        else:
            surface_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)

        if surface_normals is not None and gt_surface_normals is not None and "sdf" in self.supervision_type:

            valid_mask = surface_sdfs.squeeze(-1) < (self.sdf_trunc_val * 0.8)

            if valid_mask is not None:
                surface_normals = surface_normals[valid_mask]
                gt_surface_normals = gt_surface_normals[valid_mask]

            # eikonal loss
            surface_normals_norm = torch.norm(surface_normals, dim=-1)
            eikonal_loss = F.mse_loss(surface_normals_norm * self.sdf_trunc_val, surface_normals_norm.new_ones(surface_normals_norm.shape), reduction="mean")

            # surface normal loss
            # surface_normals = F.normalize(surface_normals, dim=-1)
            surface_normals = surface_normals * self.sdf_trunc_val
            gt_surface_normals = F.normalize(gt_surface_normals, dim=-1)

            if self.normal_supervision_type == "cosine":
                # use cosine similarity loss
                normal_loss = 1 - F.cosine_similarity(F.normalize(surface_normals, dim=-1), gt_surface_normals, dim=-1).mean()
            elif self.normal_supervision_type == "l1":
                # use l1 loss
                normal_loss = F.l1_loss(surface_normals, gt_surface_normals)
            elif self.normal_supervision_type == "l2":
                normal_loss = F.mse_loss(surface_normals, gt_surface_normals)
            elif self.normal_supervision_type == "von_mises":
                normal_loss = self.von_mises(surface_normals, gt_surface_normals).mean()
            elif self.normal_supervision_type == "l1_cosine":
                normal_loss_cos = 1 - F.cosine_similarity(F.normalize(surface_normals, dim=-1), gt_surface_normals, dim=-1).mean()
                normal_loss_l1 = F.l1_loss(surface_normals, gt_surface_normals)
                normal_loss = normal_loss_cos + normal_loss_l1
            elif self.normal_supervision_type == "l2_cosine":
                normal_loss_cos = 1 - F.cosine_similarity(F.normalize(surface_normals, dim=-1), gt_surface_normals, dim=-1).mean()
                normal_loss_l2 = F.mse_loss(surface_normals, gt_surface_normals)
                normal_loss = normal_loss_cos + normal_loss_l2
            else:
                raise NotImplementedError
        else:
            normal_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)
            eikonal_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)
            surface_normals_norm = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)

        if posteriors is None:
            kl_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)
        else:
            kl_loss = posteriors.kl(dims=(1, 2))
            kl_loss = torch.mean(kl_loss)

        loss = vol_loss + near_loss * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight + normal_loss * self.normal_weight + self.eikonal_weight * eikonal_loss + self.surface_weight * surface_loss

        # compute accuracy
        with torch.no_grad():
            if "sdf" in self.supervision_type:
                preds = shape_sdfs >= 0
                sdf_labels = shape_labels_sdf >= 0
                accuracy = (preds == sdf_labels).float()
            else:
                preds = shape_logits >= 0
                accuracy = (preds == shape_labels).float()
            accuracy = accuracy.mean()

        log = {
            # "{}/contrast".format(split): contrast_loss.clone().detach(),
            "{}/near".format(split): near_loss.detach(),
            "{}/far".format(split): vol_loss.detach(),
            "{}/normal".format(split): normal_loss.detach(),
            "{}/surface".format(split): surface_loss.detach(),
            "{}/eikonal".format(split): eikonal_loss.detach(),
            "{}/kl".format(split): kl_loss.detach(),
            "{}/surface_grad_norm".format(split): surface_normals_norm.mean().detach(),
            # "{}/shape_text_acc".format(split): shape_text_acc,
            # "{}/shape_image_acc".format(split): shape_image_acc,
            "{}/total_loss".format(split): loss.clone().detach(),
            "{}/accuracy".format(split): accuracy,
        }

        if posteriors is not None:
            log[f"{split}/posteriors_mean"] = posteriors.mean.mean().detach()
            log[f"{split}/posteriors_std_mean"] = posteriors.std.mean().detach()
            log[f"{split}/posteriors_std_max"] = posteriors.std.max().detach()

        return loss, log, near_loss
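The two ground-truth SDF encodings used by `ContrastKLNearFar` differ only in how they saturate at the truncation boundary (`gt_sdf_soft`). A standalone sketch with illustrative values:

import torch

sdf_trunc_val = 0.05
gt = torch.tensor([-0.20, -0.05, 0.00, 0.01, 0.30])

# gt_sdf_soft=False: clamp to [-trunc, trunc], then rescale to [-1, 1]
hard = torch.clamp(gt, min=-sdf_trunc_val, max=sdf_trunc_val) / sdf_trunc_val
# gt_sdf_soft=True: tanh saturates smoothly at the same scale
soft = torch.tanh(gt / sdf_trunc_val)

print(hard)  # tensor([-1.0000, -1.0000,  0.0000,  0.2000,  1.0000])
print(soft)  # tensor([-0.9993, -0.7616,  0.0000,  0.1974,  1.0000])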
src/model/michelangelo/models/tsal/sal_perceiver.py
ADDED
@@ -0,0 +1,723 @@
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from typing import Optional, Union
from einops import repeat
import math
import random
import time
import numpy as np

from ..modules import checkpoint
from ..modules.embedder import FourierEmbedder
from ..modules.distributions import DiagonalGaussianDistribution
from ..modules.transformer_blocks import (
    ResidualCrossAttentionBlock,
    Transformer
)
from ...utils.misc import use_flash3

from .tsal_base import ShapeAsLatentModule
from .loss import logits_to_sdf

from ....utils import fps


class CrossAttentionEncoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 fourier_embedder: FourierEmbedder,
                 point_feats: int,
                 width: int,
                 heads: int,
                 layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 query_method: bool = False,
                 use_full_input: bool = True,
                 token_num: int = 256,
                 no_query: bool = False):

        super().__init__()

        self.query_method = query_method
        self.token_num = token_num
        self.use_full_input = use_full_input

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents

        if no_query:
            self.query = None
        else:
            self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)

        self.fourier_embedder = fourier_embedder
        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
        self.cross_attn = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
        )

        self.self_attn = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=False
        )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
        else:
            self.ln_post = None

    def _forward(self, pc, feats):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:

        """
        if self.query_method:
            # learned-query path: a fixed set of latent queries attends to all points
            token_num = self.num_latents
            bs = pc.shape[0]                  # pc: [B, N, 3]
            data = self.fourier_embedder(pc)  # [B, N, fourier_dim]
            if feats is not None:             # [B, N, C]
                data = torch.cat([data, feats], dim=-1)
            data = self.input_proj(data)      # [B, N, width]

            query = repeat(self.query, "m c -> b m c", b=bs)  # [B, num_latents, width]

            latents = self.cross_attn(query, data)
            latents = self.self_attn(latents)

            if self.ln_post is not None:
                latents = self.ln_post(latents)

            pre_pc = None
        else:
            # sampled-query path: queries are surface points picked by FPS
            if isinstance(self.token_num, int):
                token_num = self.token_num
            else:
                token_num = random.choice(self.token_num)

            if self.training:
                rng = np.random.default_rng()
            else:
                rng = np.random.default_rng(seed=0)
            ind = rng.choice(pc.shape[1], token_num * 4, replace=token_num * 4 > pc.shape[1])

            pre_pc = pc[:, ind, :]        # [B, 4 * token_num, 3]
            pre_feats = feats[:, ind, :]  # [B, 4 * token_num, C]

            B, N, D = pre_pc.shape
            C = pre_feats.shape[-1]

            # farthest point sampling over the random 4x oversample
            pos = pre_pc.view(B * N, D)
            pos_feats = pre_feats.view(B * N, C)
            batch = torch.arange(B).to(pc.device)
            batch = torch.repeat_interleave(batch, N)

            # ratio = 1.0 * token_num / N
            idx = fps(pos, batch, ratio=1. / 4, random_start=self.training)

            sampled_pc = pos[idx]
            sampled_pc = sampled_pc.view(B, -1, 3)

            sampled_feats = pos_feats[idx]
            sampled_feats = sampled_feats.view(B, -1, C)

            if self.use_full_input:
                data = self.fourier_embedder(pc)      # [B, N, fourier_dim]
            else:
                data = self.fourier_embedder(pre_pc)  # [B, 4 * token_num, fourier_dim]

            if feats is not None:
                if not self.use_full_input:
                    feats = pre_feats
                data = torch.cat([data, feats], dim=-1)
            data = self.input_proj(data)

            sampled_data = self.fourier_embedder(sampled_pc)  # [B, token_num, fourier_dim]
            if feats is not None:
                sampled_data = torch.cat([sampled_data, sampled_feats], dim=-1)
            sampled_data = self.input_proj(sampled_data)      # [B, token_num, width]

            latents = self.cross_attn(sampled_data, data)     # [B, token_num, width]
            latents = self.self_attn(latents)

            if self.ln_post is not None:
                latents = self.ln_post(latents)

            pre_pc = torch.cat([pre_pc, pre_feats], dim=-1)

        return latents, pc, token_num, pre_pc

    def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            dict
        """

        return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
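
# Illustrative sketch (not part of the original file): the non-query path above
# randomly oversamples 4 * token_num points and then farthest-point-samples back
# down by a 1/4 ratio. A naive single-cloud FPS stands in for the repo's
# batched `fps` util.
def _two_stage_sampling_sketch(pc: torch.Tensor, token_num: int = 256) -> torch.Tensor:
    ind = torch.randperm(pc.shape[0])[: token_num * 4]  # stage 1: random 4x oversample
    sub = pc[ind]
    chosen = torch.zeros(token_num, dtype=torch.long)   # stage 2: greedy FPS
    dist = torch.full((sub.shape[0],), float("inf"))
    for i in range(1, token_num):
        dist = torch.minimum(dist, (sub - sub[chosen[i - 1]]).pow(2).sum(-1))
        chosen[i] = dist.argmax()
    return sub[chosen]                                  # [token_num, 3] query centers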
class CrossAttentionDecoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 out_channels: int,
                 fourier_embedder: FourierEmbedder,
                 width: int,
                 heads: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_checkpoint: bool = False,
                 mlp_width_scale: int = 4,
                 supervision_type: str = 'occupancy'):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.fourier_embedder = fourier_embedder
        self.supervision_type = supervision_type

        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)

        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            n_data=num_latents,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            mlp_width_scale=mlp_width_scale,
        )

        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
        if self.supervision_type == 'occupancy-sdf':
            self.output_proj_sdf = nn.Linear(width, out_channels, device=device, dtype=dtype)

    def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        if next(self.query_proj.parameters()).dtype == torch.float16:
            queries = queries.half()
            latents = latents.half()
        # print(f"queries: {queries.dtype}, {queries.device}")
        # print(f"latents: {latents.dtype}, {latents.device}")
        queries = self.query_proj(self.fourier_embedder(queries))
        x = self.cross_attn_decoder(queries, latents)
        x = self.ln_post(x)
        x_1 = self.output_proj(x)
        if self.supervision_type == 'occupancy-sdf':
            x_2 = self.output_proj_sdf(x)
            return x_1, x_2
        else:
            return x_1

    def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)


class ShapeAsLatentPerceiver(ShapeAsLatentModule):
    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 decoder_width: Optional[int] = None,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 supervision_type: str = 'occupancy',
                 query_method: bool = False,
                 token_num: int = 256,
                 grad_type: str = "numerical",
                 grad_interval: float = 0.005,
                 use_full_input: bool = True,
                 freeze_encoder: bool = False,
                 decoder_mlp_width_scale: int = 4,
                 residual_kl: bool = False,
                 ):

        super().__init__()

        self.use_checkpoint = use_checkpoint

        self.num_latents = num_latents
        assert grad_type in ["numerical", "analytical"]
        self.grad_type = grad_type
        self.grad_interval = grad_interval
        self.supervision_type = supervision_type
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        init_scale = init_scale * math.sqrt(1.0 / width)
        self.encoder = CrossAttentionEncoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint,
            query_method=query_method,
            use_full_input=use_full_input,
            token_num=token_num
        )

        self.embed_dim = embed_dim
        self.residual_kl = residual_kl
        if decoder_width is None:
            decoder_width = width
        if embed_dim > 0:
            # VAE embed
            self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
            self.post_kl = nn.Linear(embed_dim, decoder_width, device=device, dtype=dtype)
            self.latent_shape = (num_latents, embed_dim)
            if self.residual_kl:
                assert self.post_kl.out_features % self.post_kl.in_features == 0
                assert self.pre_kl.in_features % self.pre_kl.out_features == 0
        else:
            self.latent_shape = (num_latents, width)

        print("decoder width = ", decoder_width)

        self.transformer = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=decoder_width,
            layers=num_decoder_layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint
        )

        # geometry decoder
        self.geo_decoder = CrossAttentionDecoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            width=decoder_width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint,
            supervision_type=supervision_type,
            mlp_width_scale=decoder_mlp_width_scale
        )

        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
            for p in self.pre_kl.parameters():
                p.requires_grad = False
            print("freeze encoder and pre kl")

    def encode(self,
               pc: torch.FloatTensor,
               feats: Optional[torch.FloatTensor] = None,
               sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            sample_posterior (bool):

        Returns:
            latents (torch.FloatTensor)
            center_pos (torch.FloatTensor or None):
            posterior (DiagonalGaussianDistribution or None):
        """

        latents, center_pos = self.encoder(pc, feats)

        posterior = None
        if self.embed_dim > 0:
            moments = self.pre_kl(latents)
            if self.residual_kl:
                B, N = latents.shape[:2]
                moments = moments + latents.view(B, N, -1, self.pre_kl.out_features).mean(dim=-2)
            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)

            if sample_posterior:
                latents = posterior.sample()
            else:
                latents = posterior.mode()

        return latents, center_pos, posterior

    def decode(self, latents: torch.FloatTensor):
        if self.residual_kl:
            latents = latents.repeat_interleave(self.post_kl.out_features // self.post_kl.in_features, dim=-1) + self.post_kl(latents)
        else:
            latents = self.post_kl(latents)

        return self.transformer(latents)

    def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor, grad: bool = False):
        # logits = self.geo_decoder(queries, latents).squeeze(-1)
        if grad:
            # with torch.autocast(device_type="cuda", dtype=torch.float32):
            if self.grad_type == "numerical":
                # finite-difference gradients are disabled; the code below is kept for reference
                raise NotImplementedError
                interval = self.grad_interval
                # print('grad interval = ', interval)
                grad_value = []
                for offset in [(interval, 0, 0), (0, interval, 0), (0, 0, interval)]:
                    offset_tensor = torch.tensor(offset, device=queries.device)[None, :]
                    res_p = self.geo_decoder(queries + offset_tensor, latents)[..., 0]
                    res_n = self.geo_decoder(queries - offset_tensor, latents)[..., 0]
                    grad_value.append((res_p - res_n) / (2 * interval))
                grad_value = torch.stack(grad_value, dim=-1)
            else:
                # print("auto grad")
                queries_d = torch.clone(queries)
                queries_d.requires_grad = True
                with torch.enable_grad():
                    with use_flash3.disable_flash3():
                        logits = self.geo_decoder(queries_d, latents)
                        if self.supervision_type == "sigmoid-sdf":
                            sdfs = logits_to_sdf(logits)
                        else:
                            sdfs = logits  # for non-sigmoid sdf supervision the logits are the sdf values directly
                        grad_value = torch.autograd.grad(sdfs, [queries_d],
                                                         grad_outputs=torch.ones_like(sdfs),
                                                         create_graph=self.geo_decoder.training)[0]
        else:
            logits = self.geo_decoder(queries, latents)
            grad_value = None

        return logits, grad_value

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor,
                sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            volume_queries (torch.FloatTensor): [B, P, 3]
            sample_posterior (bool):

        Returns:
            logits (torch.FloatTensor): [B, P]
            center_pos (torch.FloatTensor): [B, M, 3]
            posterior (DiagonalGaussianDistribution or None).

        """

        latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)

        latents = self.decode(latents)
        logits = self.query_geometry(volume_queries, latents)

        return logits, center_pos, posterior
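
# Illustrative sketch (not part of the original file): with residual_kl, decode()
# widens the sampled embedding by channel repetition so post_kl only learns a
# correction. Sizes are arbitrary; decoder_width must be a multiple of embed_dim.
def _residual_kl_sketch():
    embed_dim, decoder_width = 64, 768
    post_kl = nn.Linear(embed_dim, decoder_width)
    z = torch.randn(2, 257, embed_dim)                              # sampled KL embedding
    base = z.repeat_interleave(decoder_width // embed_dim, dim=-1)  # [2, 257, 768]
    return base + post_kl(z)                                        # residual correction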
class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[str],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 decoder_width: Optional[int] = None,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 supervision_type: str = 'occupancy',
                 grad_type: str = "numerical",
                 grad_interval: float = 0.005,
                 query_method: bool = False,
                 use_full_input: bool = True,
                 token_num: int = 256,
                 freeze_encoder: bool = False,
                 decoder_mlp_width_scale: int = 4,
                 residual_kl: bool = False,
                 ):

        MAP_DTYPE = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }
        if dtype is not None:
            dtype = MAP_DTYPE[dtype]
        super().__init__(
            device=device,
            dtype=dtype,
            num_latents=1 + num_latents,
            point_feats=point_feats,
            embed_dim=embed_dim,
            num_freqs=num_freqs,
            include_pi=include_pi,
            width=width,
            decoder_width=decoder_width,
            heads=heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint,
            supervision_type=supervision_type,
            grad_type=grad_type,
            grad_interval=grad_interval,
            query_method=query_method,
            token_num=token_num,
            use_full_input=use_full_input,
            freeze_encoder=freeze_encoder,
            decoder_mlp_width_scale=decoder_mlp_width_scale,
            residual_kl=residual_kl,
        )

        self.width = width

    def encode(self,
               pc: torch.FloatTensor,
               feats: Optional[torch.FloatTensor] = None,
               sample_posterior: bool = True,
               only_shape: bool = False):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, c]
            sample_posterior (bool):

        Returns:
            shape_embed (torch.FloatTensor)
            kl_embed (torch.FloatTensor):
            posterior (DiagonalGaussianDistribution or None):
        """

        shape_embed, latents, token_num, pre_pc = self.encode_latents(pc, feats)
        if only_shape:
            return shape_embed
        kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)

        return shape_embed, kl_embed, posterior, token_num, pre_pc

    def encode_latents(self,
                       pc: torch.FloatTensor,
                       feats: Optional[torch.FloatTensor] = None):

        x, _, token_num, pre_pc = self.encoder(pc, feats)

        shape_embed = x[:, 0]
        # latents = x[:, 1:]
        # use all tokens
        latents = x

        return shape_embed, latents, token_num, pre_pc

    def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
        posterior = None
        if self.embed_dim > 0:
            moments = self.pre_kl(latents)
            if self.residual_kl:
                B, N = latents.shape[:2]
                moments = moments + latents.view(B, N, -1, self.pre_kl.out_features).mean(dim=-2)
            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)

            if sample_posterior:
                kl_embed = posterior.sample()
            else:
                kl_embed = posterior.mode()
        else:
            kl_embed = latents

        return kl_embed, posterior

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor,
                sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            volume_queries (torch.FloatTensor): [B, P, 3]
            sample_posterior (bool):

        Returns:
            shape_embed (torch.FloatTensor): [B, projection_dim]
            logits (torch.FloatTensor): [B, M]
            posterior (DiagonalGaussianDistribution or None).

        """

        shape_embed, kl_embed, posterior, token_num, pre_pc = self.encode(pc, feats, sample_posterior=sample_posterior)

        latents = self.decode(kl_embed)
        logits, grad = self.query_geometry(volume_queries, latents)

        return shape_embed, logits, posterior, token_num, pre_pc, grad


#####################################################
# a simplified version of the perceiver encoder
#####################################################

class ShapeAsLatentPerceiverEncoder(ShapeAsLatentModule):
    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[Union[torch.dtype, str]],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 supervision_type: str = 'occupancy',
                 query_method: bool = False,
                 token_num: int = 256,
                 grad_type: str = "numerical",
                 grad_interval: float = 0.005,
                 use_full_input: bool = True,
                 freeze_encoder: bool = False,
                 residual_kl: bool = False,
                 ):

        super().__init__()

        MAP_DTYPE = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }

        if dtype is not None and isinstance(dtype, str):
            dtype = MAP_DTYPE[dtype]

        self.use_checkpoint = use_checkpoint

        self.num_latents = num_latents
        assert grad_type in ["numerical", "analytical"]
        self.grad_type = grad_type
        self.grad_interval = grad_interval
        self.supervision_type = supervision_type
|
| 680 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
| 681 |
+
|
| 682 |
+
init_scale = init_scale * math.sqrt(1.0 / width)
|
| 683 |
+
self.encoder = CrossAttentionEncoder(
|
| 684 |
+
device=device,
|
| 685 |
+
dtype=dtype,
|
| 686 |
+
fourier_embedder=self.fourier_embedder,
|
| 687 |
+
num_latents=num_latents,
|
| 688 |
+
point_feats=point_feats,
|
| 689 |
+
width=width,
|
| 690 |
+
heads=heads,
|
| 691 |
+
layers=num_encoder_layers,
|
| 692 |
+
init_scale=init_scale,
|
| 693 |
+
qkv_bias=qkv_bias,
|
| 694 |
+
flash=flash,
|
| 695 |
+
use_ln_post=use_ln_post,
|
| 696 |
+
use_checkpoint=use_checkpoint,
|
| 697 |
+
query_method=query_method,
|
| 698 |
+
use_full_input=use_full_input,
|
| 699 |
+
token_num=token_num,
|
| 700 |
+
no_query=True,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
self.embed_dim = embed_dim
|
| 704 |
+
self.residual_kl = residual_kl
|
| 705 |
+
if freeze_encoder:
|
| 706 |
+
for p in self.encoder.parameters():
|
| 707 |
+
p.requires_grad = False
|
| 708 |
+
print("freeze encoder")
|
| 709 |
+
self.width = width
|
| 710 |
+
|
| 711 |
+
def encode_latents(self,
|
| 712 |
+
pc: torch.FloatTensor,
|
| 713 |
+
feats: Optional[torch.FloatTensor] = None):
|
| 714 |
+
|
| 715 |
+
x, _, token_num, pre_pc = self.encoder(pc, feats)
|
| 716 |
+
|
| 717 |
+
shape_embed = x[:, 0]
|
| 718 |
+
latents = x
|
| 719 |
+
|
| 720 |
+
return shape_embed, latents, token_num, pre_pc
|
| 721 |
+
|
| 722 |
+
def forward(self):
|
| 723 |
+
raise NotImplementedError()
|
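For orientation, here is a minimal, self-contained sketch of the sampling step that `encode_kl_embed` delegates to `DiagonalGaussianDistribution.sample()`. The mean/logvar packing along the feature axis and the clamp range are assumptions modeled on common VAE implementations, not details taken from this repo:

    import torch

    def sample_diagonal_gaussian(moments: torch.Tensor, feat_dim: int = -1) -> torch.Tensor:
        # `moments` is assumed to pack [mean, logvar] along `feat_dim`, as the
        # `pre_kl` projection above would produce; sampling uses the
        # reparameterization trick so gradients flow through mean and std.
        mean, logvar = moments.chunk(2, dim=feat_dim)
        logvar = logvar.clamp(-30.0, 20.0)  # assumed stabilization, as in many VAEs
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(std)

    moments = torch.randn(2, 257, 2 * 64)   # [B, 1 + num_latents, 2 * embed_dim]
    kl_embed = sample_diagonal_gaussian(moments)
    print(kl_embed.shape)                    # torch.Size([2, 257, 64])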
src/model/michelangelo/models/tsal/tsal_base.py
ADDED
@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-

import torch.nn as nn
from typing import Tuple, List, Optional
import lightning.pytorch as pl


class Point2MeshOutput(object):
    def __init__(self):
        self.mesh_v = None
        self.mesh_f = None
        self.center = None
        self.pc = None


class Latent2MeshOutput(object):

    def __init__(self):
        self.mesh_v = None
        self.mesh_f = None


class AlignedMeshOutput(object):

    def __init__(self):
        self.mesh_v = None
        self.mesh_f = None
        self.surface = None
        self.image = None
        self.text: Optional[str] = None
        self.shape_text_similarity: Optional[float] = None
        self.shape_image_similarity: Optional[float] = None


class ShapeAsLatentPLModule(pl.LightningModule):
    latent_shape: Tuple[int]

    def encode(self, surface, *args, **kwargs):
        raise NotImplementedError

    def decode(self, z_q, *args, **kwargs):
        raise NotImplementedError

    def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
        raise NotImplementedError

    def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
        raise NotImplementedError


class ShapeAsLatentModule(nn.Module):
    latent_shape: Tuple[int, int]

    def __init__(self, *args, **kwargs):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError

    def decode(self, *args, **kwargs):
        raise NotImplementedError

    def query_geometry(self, *args, **kwargs):
        raise NotImplementedError


class AlignedShapeAsLatentPLModule(pl.LightningModule):
    latent_shape: Tuple[int]

    def set_shape_model_only(self):
        raise NotImplementedError

    def encode(self, surface, *args, **kwargs):
        raise NotImplementedError

    def decode(self, z_q, *args, **kwargs):
        raise NotImplementedError

    def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
        raise NotImplementedError

    def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
        raise NotImplementedError


class AlignedShapeAsLatentModule(nn.Module):
    shape_model: ShapeAsLatentModule
    latent_shape: Tuple[int, int]

    def __init__(self, *args, **kwargs):
        super().__init__()

    def set_shape_model_only(self):
        raise NotImplementedError

    def encode_image_embed(self, *args, **kwargs):
        raise NotImplementedError

    def encode_text_embed(self, *args, **kwargs):
        raise NotImplementedError

    def encode_shape_embed(self, *args, **kwargs):
        raise NotImplementedError


class TexturedShapeAsLatentModule(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError

    def decode(self, *args, **kwargs):
        raise NotImplementedError

    def query_geometry(self, *args, **kwargs):
        raise NotImplementedError

    def query_color(self, *args, **kwargs):
        raise NotImplementedError
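As a quick illustration of the contract these bases define, a hypothetical minimal subclass (not part of the repo) that fills in encode/decode/query_geometry for point clouds:

    import torch
    import torch.nn as nn

    class ToyShapeModule(ShapeAsLatentModule):
        # Hypothetical subclass: encode a point cloud to a fixed-size latent
        # and decode occupancy logits at query points.
        latent_shape = (1, 64)

        def __init__(self):
            super().__init__()
            self.enc = nn.Linear(3, 64)
            self.dec = nn.Linear(64 + 3, 1)

        def encode(self, pc):                    # pc: [B, N, 3]
            return self.enc(pc).mean(dim=1)      # [B, 64]

        def decode(self, z):
            return z                             # identity for the toy case

        def query_geometry(self, queries, z):    # queries: [B, P, 3]
            z = z[:, None].expand(-1, queries.shape[1], -1)
            return self.dec(torch.cat([queries, z], dim=-1)).squeeze(-1)

    pc = torch.rand(2, 1024, 3)
    queries = torch.rand(2, 256, 3)
    m = ToyShapeModule()
    logits = m.query_geometry(queries, m.decode(m.encode(pc)))
    print(logits.shape)   # torch.Size([2, 256])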
src/model/michelangelo/utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-

from .misc import get_config_from_file
from .misc import instantiate_from_config
src/model/michelangelo/utils/eval.py
ADDED
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

import torch


def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):

    mse = torch.mean((x - y) ** 2)
    psnr = 10 * torch.log10(data_range / (mse + eps))

    return psnr
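A quick sanity check of compute_psnr. Note that the textbook PSNR definition uses the data range squared in the numerator, 10·log10(MAX²/MSE), so values from this helper differ from that convention by a constant 10·log10(data_range) (about 3.01 dB for data_range=2):

    import torch

    x = torch.rand(1, 3, 64, 64) * 2 - 1        # signal in [-1, 1], range 2
    y = x + 0.01 * torch.randn_like(x)          # slightly noisy copy

    psnr = compute_psnr(x, y)                   # helper above, data_range=2
    mse = torch.mean((x - y) ** 2)
    textbook = 10 * torch.log10(torch.tensor(2.0) ** 2 / mse)
    print(psnr.item(), textbook.item())         # differ by ~10*log10(2) ≈ 3.01 dB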
src/model/michelangelo/utils/misc.py
ADDED
@@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-

import importlib
from omegaconf import OmegaConf, DictConfig, ListConfig
import time
import torch
import torch.distributed as dist
from typing import Union, Any, Optional
from collections import defaultdict
from torch.optim import lr_scheduler
import os
from dataclasses import dataclass, field
from contextlib import contextmanager

import logging
logger = logging.getLogger(__name__)


def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8):
    return int(num_data / (num_nodes * num_cards * batch_size)) * max_epochs


OmegaConf.register_new_resolver("calc_num_train_steps", calc_num_train_steps)
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)


@dataclass
class ExperimentConfig:
    task: str = "vae"
    output_dir: str = "outputs"
    resume: Optional[str] = None

    data: dict = field(default_factory=dict)
    model: dict = field(default_factory=dict)

    trainer: dict = field(default_factory=dict)
    checkpoint: dict = field(default_factory=dict)

    wandb: dict = field(default_factory=dict)


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
    return scfg


def get_config_from_file(config_file: str, cli_args: list = [], **kwargs) -> Union[DictConfig, ListConfig]:
    config_file = OmegaConf.load(config_file)
    cli_conf = OmegaConf.from_cli(cli_args)

    if 'base_config' in config_file.keys():
        if config_file['base_config'] == "default_base":
            base_config = OmegaConf.create()
            # base_config = get_default_config()
        elif config_file['base_config'].endswith(".yaml"):
            base_config = get_config_from_file(config_file['base_config'])
        else:
            raise ValueError(f"{config_file} must be a `.yaml` file or contain a `base_config` key.")

        config_file = {key: value for key, value in config_file.items() if key != "base_config"}

        cfg = OmegaConf.merge(base_config, config_file, cli_conf, kwargs)
    else:
        cfg = OmegaConf.merge(config_file, cli_conf, kwargs)

    scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg)

    return scfg


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def get_obj_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    return get_obj_from_str(config["target"])


def instantiate_from_config(config, **kwargs):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    cls = get_obj_from_str(config["target"])

    params = config.get("params", dict())
    # params.update(kwargs)
    # instance = cls(**params)
    kwargs.update(params)
    instance = cls(**kwargs)

    return instance
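A small, self-contained sketch of the target/params convention that instantiate_from_config expects; torch.nn.Linear stands in here for any importable class:

    config = {
        "target": "torch.nn.Linear",
        "params": {"in_features": 8, "out_features": 4},
    }
    layer = instantiate_from_config(config)   # equivalent to torch.nn.Linear(8, 4)
    print(layer)                               # Linear(in_features=8, out_features=4, bias=True)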
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_free_space(path):
    fs_stats = os.statvfs(path)
    free_space = fs_stats.f_bsize * fs_stats.f_bfree
    return free_space


def get_device_type():
    # Returns an empty string when no CUDA device is available so that
    # callers like `FLASH3.__init__` (which only check `"H100" in ...`) can
    # be imported safely on CPU-only / ZeroGPU-main processes without
    # raising "No CUDA GPUs are available".
    try:
        if not torch.cuda.is_available():
            return ""
        return torch.cuda.get_device_name(0)
    except (RuntimeError, AssertionError):
        return ""


def get_hostname():
    import socket
    return socket.gethostname()


def all_gather_batch(tensors):
    """
    Performs the all_gather operation on the provided tensors.
    """
    # Queue the gathered tensors
    world_size = get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors
    tensor_list = []
    output_tensor = []
    for tensor in tensors:
        tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
        dist.all_gather(
            tensor_all,
            tensor,
            async_op=False  # performance opt
        )

        tensor_list.append(tensor_all)

    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    return output_tensor


def get_scheduler(name):
    if hasattr(lr_scheduler, name):
        return getattr(lr_scheduler, name)
    else:
        raise NotImplementedError


def parse_scheduler(config, optimizer):
    interval = config.get("interval", "epoch")
    assert interval in ["epoch", "step"]
    if config.name == "SequentialLR":
        scheduler = {
            "scheduler": lr_scheduler.SequentialLR(
                optimizer,
                [
                    parse_scheduler(conf, optimizer)["scheduler"]
                    for conf in config.schedulers
                ],
                milestones=config.milestones,
            ),
            "interval": interval,
        }
    elif config.name == "ChainedScheduler":
        scheduler = {
            "scheduler": lr_scheduler.ChainedScheduler(
                [
                    parse_scheduler(conf, optimizer)["scheduler"]
                    for conf in config.schedulers
                ]
            ),
            "interval": interval,
        }
    else:
        scheduler = {
            "scheduler": get_scheduler(config.name)(optimizer, **config.args),
            "interval": interval,
        }
    return scheduler
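parse_scheduler expects an OmegaConf-style node with name/args (and optionally interval). A minimal sketch, assuming a plain optimizer; the scheduler choice and T_max value are illustrative:

    import torch
    from omegaconf import OmegaConf

    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    cfg = OmegaConf.create({
        "name": "CosineAnnealingLR",   # any torch.optim.lr_scheduler class name
        "interval": "step",
        "args": {"T_max": 1000},
    })
    sched = parse_scheduler(cfg, optimizer)
    print(sched["interval"], type(sched["scheduler"]).__name__)  # step CosineAnnealingLR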
class TimeRecorder:
    _instance = None

    def __init__(self):
        self.items = {}
        self.accumulations = defaultdict(list)
        self.time_scale = 1000.0  # ms
        self.time_unit = "ms"
        self.enabled = False

    def __new__(cls):
        # singleton
        if cls._instance is None:
            cls._instance = super(TimeRecorder, cls).__new__(cls)
        return cls._instance

    def enable(self, enabled: bool) -> None:
        self.enabled = enabled

    def start(self, name: str) -> None:
        if not self.enabled:
            return
        torch.cuda.synchronize()
        self.items[name] = time.time()

    def end(self, name: str, accumulate: bool = False) -> float:
        if not self.enabled or name not in self.items:
            return
        torch.cuda.synchronize()
        start_time = self.items.pop(name)
        delta = time.time() - start_time
        if accumulate:
            self.accumulations[name].append(delta)
        t = delta * self.time_scale
        logger.info(f"{name}: {t:.2f}{self.time_unit}")

    def get_accumulation(self, name: str, average: bool = False) -> float:
        if not self.enabled or name not in self.accumulations:
            return
        acc = self.accumulations.pop(name)
        total = sum(acc)
        if average:
            t = total / len(acc) * self.time_scale
        else:
            t = total * self.time_scale
        logger.info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}")


### global time recorder
time_recorder = TimeRecorder()


class FLASH3:
    def __init__(self) -> None:
        self.available = "H100" in get_device_type()
        self.use = os.environ.get("USE_FLASH3", False)

    @property
    def is_use(self):
        return self.available and self.use

    @contextmanager
    def disable_flash3(self):
        use = self.use
        self.set_use(False)
        yield
        self.set_use(use)

    def set_use(self, use=True):
        self.use = use


use_flash3 = FLASH3()
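A usage sketch for the time_recorder singleton above; start/end call torch.cuda.synchronize(), so this only runs as-is on a CUDA machine, and the timings are emitted via the module logger at INFO level:

    import logging
    import torch

    logging.basicConfig(level=logging.INFO)

    time_recorder.enable(True)
    for _ in range(3):
        time_recorder.start("matmul")
        a = torch.randn(1024, 1024, device="cuda")
        (a @ a).sum().item()
        time_recorder.end("matmul", accumulate=True)         # logs each call in ms
    time_recorder.get_accumulation("matmul", average=True)   # logs the mean over 3 runs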
src/model/parse_encoder.py
ADDED
@@ -0,0 +1,28 @@
from copy import deepcopy
from dataclasses import dataclass

from .michelangelo.get_model import get_encoder as get_encoder_michelangelo
from .michelangelo.get_model import AlignedShapeLatentPerceiver
from .michelangelo.get_model import get_encoder_simplified as get_encoder_michelangelo_encoder
from .michelangelo.get_model import ShapeAsLatentPerceiverEncoder
from .skin_vae.autoencoders.autoencoder_kl_tripo2 import Tripo2Encoder


@dataclass(frozen=True)
class _MAP_MESH_ENCODER:
    michelangelo = AlignedShapeLatentPerceiver
    michelangelo_encoder = ShapeAsLatentPerceiverEncoder
    tripo = Tripo2Encoder


MAP_MESH_ENCODER = _MAP_MESH_ENCODER()


def get_mesh_encoder(**kwargs):
    MAP = {
        'michelangelo': get_encoder_michelangelo,
        'michelangelo_encoder': get_encoder_michelangelo_encoder,
        'tripo': Tripo2Encoder,
    }
    __target__ = kwargs['__target__']
    del kwargs['__target__']
    assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}"
    return MAP[__target__](**deepcopy(kwargs))
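A call sketch for get_mesh_encoder: the '__target__' key selects the factory, and every remaining keyword is deep-copied and forwarded to it. The keyword values shown are illustrative placeholders, not a shipped config:

    # Hypothetical call; with __target__="tripo" the kwargs go straight to
    # Tripo2Encoder's constructor.
    encoder = get_mesh_encoder(
        __target__="tripo",    # one of: michelangelo, michelangelo_encoder, tripo
        in_channels=3,
        dim=512,
    )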
src/model/skin_vae/attention_processor.py
ADDED
@@ -0,0 +1,283 @@
import inspect
import math
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.utils import logging
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
from torch import nn

try:
    from flash_attn_interface import flash_attn_func
except Exception as e:
    # Fallback when FlashAttention-3 is not installed: emulate its
    # (B, L, H, D) interface with PyTorch's scaled_dot_product_attention.
    def flash_attn_func(q, k, v):
        q = q.permute(0, 2, 1, 3)  # (B, H, L, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        if q.shape[1] != k.shape[1]:
            repeat_factor = q.shape[1] // k.shape[1]
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return out.permute(0, 2, 1, 3), None  # (B, L, H, D)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Tripo2AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Tripo2DiT model. It applies a norm layer and rotary embedding on the query and key vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # NOTE that tripo2 splits heads first and then splits qkv (or kv), i.e.
        # .view(..., attn.heads, 3, dim) instead of .view(..., 3, attn.heads, dim),
        # so we need to re-split here.
        if not attn.is_cross_attention:
            qkv = torch.cat((query, key, value), dim=-1)
            split_size = qkv.shape[-1] // attn.heads // 3
            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            kv = torch.cat((key, value), dim=-1)
            split_size = kv.shape[-1] // attn.heads // 2
            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
            key, value = torch.split(kv, split_size, dim=-1)

        head_dim = key.shape[-1]

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1

        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
        hidden_states = flash_attn_func(query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2))
        if type(hidden_states) == tuple:
            hidden_states = hidden_states[0]
        # hidden_states = F.scaled_dot_product_attention(
        #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        # )
        hidden_states = hidden_states.reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class FusedTripo2AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
    projection layers. This is used in the Tripo2DiT model. It applies a norm layer and rotary embedding on the query
    and key vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FusedTripo2AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        # NOTE that tripo2 splits heads first, then splits qkv
        if encoder_hidden_states is None:
            qkv = attn.to_qkv(hidden_states)
            split_size = qkv.shape[-1] // attn.heads // 3
            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )
            query = attn.to_q(hidden_states)

            kv = attn.to_kv(encoder_hidden_states)
            split_size = kv.shape[-1] // attn.heads // 2
            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
            key, value = torch.split(kv, split_size, dim=-1)

        head_dim = key.shape[-1]

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
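A quick check of the fallback flash_attn_func defined at the top of this file: it takes (B, L, H, D) tensors, mirroring FlashAttention-3's layout, and should agree with plain scaled_dot_product_attention. This only exercises the fallback path, i.e. when flash_attn_interface is not installed:

    import torch
    import torch.nn.functional as F

    B, L, H, D = 2, 16, 4, 32
    q = torch.randn(B, L, H, D)
    k = torch.randn(B, L, H, D)
    v = torch.randn(B, L, H, D)

    out, _ = flash_attn_func(q, k, v)           # fallback path, returns (B, L, H, D)
    ref = F.scaled_dot_product_attention(
        q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
    ).permute(0, 2, 1, 3)
    print(torch.allclose(out, ref, atol=1e-5))  # True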
src/model/skin_vae/autoencoders/FSQ.py
ADDED
@@ -0,0 +1,191 @@
from __future__ import annotations
from functools import wraps, partial
from contextlib import nullcontext
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from torch.amp import autocast

from einops import rearrange, pack, unpack

# helper functions

def exists(v):
    return v is not None

def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# tensor helpers

def round_ste(z: Tensor) -> Tensor:
    """Round with straight-through gradients."""
    zhat = z.round()
    return z + (zhat - z).detach()

# main class

class FSQ(Module):
    def __init__(
        self,
        levels: List[int],
        dim: int | None = None,
        num_codebooks = 1,
        keep_num_codebooks_dim: bool | None = None,
        scale: float | None = None,
        allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
        channel_first: bool = False,
        projection_has_bias: bool = True,
        return_indices = True,
        force_quantization_f32 = True
    ):
        super().__init__()
        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent = False)

        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent = False)

        self.scale = scale

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(dim, len(_levels) * num_codebooks)

        self.channel_first = channel_first

        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = projection_has_bias) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = projection_has_bias) if has_projections else nn.Identity()

        self.has_projections = has_projections

        self.return_indices = return_indices
        if return_indices:
            self.codebook_size: int = self._levels.prod().item()
            implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
            self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)

        self.allowed_dtypes = allowed_dtypes
        self.force_quantization_f32 = force_quantization_f32

    def bound(self, z, eps: float = 1e-3):
        """ Bound `z`, an array of shape (..., d). """
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z):
        """ Quantizes z, returns quantized zhat, same shape as z. """
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized):
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat):
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def _indices_to_codes(self, indices):
        level_indices = self.indices_to_level_indices(indices)
        codes = self._scale_and_shift_inverse(level_indices)
        return codes

    def codes_to_indices(self, zhat):
        """ Converts a `code` to an index in the codebook. """
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_level_indices(self, indices):
        """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
        indices = rearrange(indices, '... -> ... 1')
        codes_non_centered = (indices // self._basis) % self._levels
        return codes_non_centered

    def indices_to_codes(self, indices):
        """ Inverse of `codes_to_indices`. """
        assert exists(indices)
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        codes = self._indices_to_codes(indices)
        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')
        codes = self.project_out(codes)
        if is_img_or_video or self.channel_first:
            codes = rearrange(codes, 'b ... d -> b d ...')
        return codes

    def dequantize(self, indices):
        codes = self._indices_to_codes(indices)
        out = self.project_out(codes)
        return out

    def forward(self, z):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension
        c - number of codebook dim
        """
        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)
        z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)

        # whether to force the quantization step to run in full precision
        force_f32 = self.force_quantization_f32
        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

        with quantization_context():
            orig_dtype = z.dtype
            if force_f32 and orig_dtype not in self.allowed_dtypes:
                z = z.float()
            codes = self.quantize(z)
            # returning indices could be optional
            indices = None
            if self.return_indices:
                indices = self.codes_to_indices(codes)
            codes = rearrange(codes, 'b n c d -> b n (c d)')
            codes = codes.type(orig_dtype)

        # project out
        out = self.project_out(codes)

        if not self.keep_num_codebooks_dim and self.return_indices:
            indices = maybe(rearrange)(indices, '... 1 -> ...')
        # return quantized output and indices
        return out, indices, None
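A minimal roundtrip through the FSQ module above; levels [8, 5, 5, 5] gives an implicit codebook of 8·5·5·5 = 1000 codes, and dequantizing the returned indices reproduces the quantized output. The sizes are illustrative:

    import torch

    fsq = FSQ(levels=[8, 5, 5, 5], dim=64)   # projects 64-d features to 4 quantized dims
    x = torch.randn(2, 128, 64)              # [batch, tokens, dim]

    out, indices, _ = fsq(x)
    print(out.shape, indices.shape)          # torch.Size([2, 128, 64]) torch.Size([2, 128])
    print(fsq.codebook_size)                 # 1000

    recon = fsq.dequantize(indices)          # indices -> codes -> project_out
    print(torch.allclose(recon, out, atol=1e-5))   # True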
src/model/skin_vae/autoencoders/SimVQ.py
ADDED
@@ -0,0 +1,197 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange


class SimVQ(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """
    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta=0.25, remap=None, unknown_index="random",
                 same_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.codebook_size = self.n_e
        nn.init.normal_(self.embedding.weight, mean=0, std=self.e_dim**-0.5)
        for p in self.embedding.parameters():
            p.requires_grad = False

        self.embedding_proj = nn.Linear(self.e_dim, self.e_dim)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.same_index_shape = same_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        assert z.shape[-1] == self.e_dim
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z

        quant_codebook = self.embedding_proj(self.embedding.weight)

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(quant_codebook**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(quant_codebook, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = F.embedding(min_encoding_indices, quant_codebook).view(z.shape)

        # compute loss for embedding
        if not self.legacy:
            quantization_loss = self.beta * torch.mean((z_q.detach() - z)**2) + \
                torch.mean((z_q - z.detach()) ** 2)
        else:
            quantization_loss = torch.mean((z_q.detach() - z)**2) + self.beta * \
                torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.same_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, min_encoding_indices, quantization_loss

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q

    def indices_to_codes(self, indices):
        return self.get_codebook_entry(indices, None)


def entropy(prob):
    # note: `log` was unqualified in the original source; torch.log is assumed here
    return (-prob * torch.log(prob)).sum(dim=-1)


class SimVQ1D(SimVQ):

    def __init__(self, n_e, e_dim, dim, beta=0.25, remap=None, unknown_index="random", same_index_shape=True, legacy=True):
        super().__init__(n_e, e_dim, beta, remap, unknown_index, same_index_shape, legacy)

        self.project_in = nn.Linear(dim, e_dim)
        self.project_out = nn.Linear(e_dim, dim)

    def forward(self, z):
        # z: (batch, tokens, dim); project into the codebook space first
        # assert z.shape[-1] == self.e_dim
        z = self.project_in(z)

        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z

        quant_codebook = self.embedding_proj(self.embedding.weight)

        # # Use IBQ
        # logits = torch.matmul(z_flattened, quant_codebook.t())
        # Ind_soft = torch.softmax(logits, dim=1)
        # indices = torch.argmax(Ind_soft, dim=1)
        # Ind_hard = F.one_hot(indices, num_classes=Ind_soft.shape[1])
        # Ind = Ind_hard - Ind_soft.detach() + Ind_soft
        # z_q = torch.matmul(Ind, quant_codebook).view(z.shape)

        # if not self.legacy:
        #     quantization_loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
        #         torch.mean((z_q - z.detach()) ** 2)
        # else:
        #     quantization_loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
        #         torch.mean((z_q - z.detach()) ** 2)

        # return z_q, indices, quantization_loss

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(quant_codebook**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(quant_codebook, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = F.embedding(min_encoding_indices, quant_codebook).view(z.shape)

        # compute loss for embedding
        if not self.legacy:
            quantization_loss = self.beta * torch.mean((z_q.detach() - z)**2) + \
                torch.mean((z_q - z.detach()) ** 2)
        else:
            quantization_loss = torch.mean((z_q.detach() - z)**2) + self.beta * \
                torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.same_index_shape:
            min_encoding_indices = min_encoding_indices.view(z.shape[0], z.shape[1])
        z_q = self.project_out(z_q.view(z.shape))

        return z_q, min_encoding_indices, quantization_loss
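A sketch exercising SimVQ1D on token sequences; the codebook is frozen and only the embedding_proj / projection layers train. The sizes here are illustrative:

    import torch

    vq = SimVQ1D(n_e=1024, e_dim=32, dim=64)   # 1024 frozen codes, projected in/out of 64-d
    z = torch.randn(2, 128, 64)                # [batch, tokens, dim]

    z_q, indices, vq_loss = vq(z)
    print(z_q.shape)       # torch.Size([2, 128, 64])
    print(indices.shape)   # torch.Size([2, 128]) with same_index_shape=True (the default)
    print(vq_loss.item())  # scalar commitment + codebook loss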
src/model/skin_vae/autoencoders/__init__.py
ADDED
@@ -0,0 +1 @@
from .skin_fsq_cvae_model import SkinFSQCVAEModel
src/model/skin_vae/autoencoders/autoencoder_kl_tripo2.py
ADDED
@@ -0,0 +1,254 @@
| 1 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from diffusers.models.normalization import LayerNorm
|
| 7 |
+
from diffusers.utils import logging
|
| 8 |
+
from einops import repeat
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
from ..embeddings import FrequencyPositionalEmbedding
|
| 12 |
+
from ..transformers.tripo2_transformer import DiTBlock
|
| 13 |
+
from ...utils import fps
|
| 14 |
+
|
| 15 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 16 |
+
|
| 17 |
+
def init_linear(l, stddev):
|
| 18 |
+
nn.init.normal_(l.weight, std=stddev)
|
| 19 |
+
if l.bias is not None:
|
| 20 |
+
nn.init.constant_(l.bias, 0.0)
|
| 21 |
+
|
| 22 |
+
class Tripo2Encoder(nn.Module):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
in_channels: int = 3,
|
| 26 |
+
dim: int = 512,
|
| 27 |
+
num_attention_heads: int = 8,
|
| 28 |
+
num_layers: int = 8,
|
| 29 |
+
is_learned_queries: bool = False,
|
| 30 |
+
sample_tokens: int = 32,
|
| 31 |
+
embed_frequency: int = 8,
|
| 32 |
+
embed_include_pi: bool = False,
|
| 33 |
+
fps: bool = False,
|
| 34 |
+
is_miche: bool = False,
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
self.fps = fps
|
| 39 |
+
if fps and not is_learned_queries:
|
| 40 |
+
self.embedder = FrequencyPositionalEmbedding(
|
| 41 |
+
num_freqs=embed_frequency,
|
| 42 |
+
logspace=True,
|
| 43 |
+
input_dim=3,
|
| 44 |
+
include_pi=embed_include_pi,
|
| 45 |
+
)
|
| 46 |
+
self.proj_k = nn.Linear(3+self.embedder.out_dim, dim, bias=True)
|
| 47 |
+
self.proj_in = nn.Linear(in_channels-3+self.embedder.out_dim, dim, bias=True)
|
| 48 |
+
else:
|
| 49 |
+
self.proj_in = nn.Linear(in_channels, dim, bias=True)
|
| 50 |
+
self.output_channels = dim
|
| 51 |
+
self.is_miche = is_miche
|
| 52 |
+
init_scale = 0.25 * math.sqrt(1.0 / dim)
|
| 53 |
+
init_linear(self.proj_in, init_scale)
|
| 54 |
+
|
| 55 |
+
self.blocks = nn.ModuleList(
|
| 56 |
+
[
|
| 57 |
+
DiTBlock(
|
| 58 |
+
dim=dim,
|
| 59 |
+
num_attention_heads=num_attention_heads,
|
| 60 |
+
use_self_attention=False,
|
| 61 |
+
use_cross_attention=True,
|
| 62 |
+
cross_attention_dim=dim,
|
| 63 |
+
cross_attention_norm_type="layer_norm",
|
| 64 |
+
activation_fn="gelu",
|
| 65 |
+
norm_type="fp32_layer_norm",
|
| 66 |
+
norm_eps=1e-5,
|
| 67 |
+
qk_norm=False,
|
| 68 |
+
qkv_bias=False,
|
| 69 |
+
) # cross attention
|
| 70 |
+
]
|
| 71 |
+
+ [
|
| 72 |
+
DiTBlock(
|
| 73 |
+
dim=dim,
|
| 74 |
+
num_attention_heads=num_attention_heads,
|
| 75 |
+
use_self_attention=True,
|
| 76 |
+
self_attention_norm_type="fp32_layer_norm",
|
| 77 |
+
use_cross_attention=False,
|
| 78 |
+
use_cross_attention_2=False,
|
| 79 |
+
activation_fn="gelu",
|
| 80 |
+
norm_type="fp32_layer_norm",
|
| 81 |
+
norm_eps=1e-5,
|
| 82 |
+
qk_norm=False,
|
| 83 |
+
qkv_bias=False,
|
| 84 |
+
)
|
| 85 |
+
for _ in range(num_layers) # self attention
|
| 86 |
+
]
|
| 87 |
+
)
|
| 88 |
+
self.norm_out = LayerNorm(dim)
|
| 89 |
+
self.is_learned_queries = is_learned_queries
|
| 90 |
+
if is_learned_queries:
|
| 91 |
+
self.learned_queries = nn.Parameter(torch.randn(sample_tokens, dim) * 0.02)
|
| 92 |
+
|
| 93 |
+
def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor, num_tokens: int=1024):
|
| 94 |
+
if self.is_learned_queries or not self.fps:
|
| 95 |
+
hidden_states = self.proj_in(sample_1) if not self.is_learned_queries else repeat(self.learned_queries[:sample_1.shape[1], :], 'n d -> b n d', b=sample_1.shape[0])
|
| 96 |
+
encoder_hidden_states = self.proj_in(sample_2)
|
| 97 |
+
else:
|
| 98 |
+
x_q, x_kv = self.get_qkv(x=sample_1, num_tokens=num_tokens)
|
| 99 |
+
hidden_states = self.proj_k(x_q)
|
| 100 |
+
encoder_hidden_states = self.proj_in(x_kv)
|
| 101 |
+
|
| 102 |
+
if not self.is_miche:
|
| 103 |
+
for layer, block in enumerate(self.blocks):
|
| 104 |
+
if layer == 0:
|
| 105 |
+
hidden_states = block(
|
| 106 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states
|
| 107 |
+
)
|
| 108 |
+
else:
|
| 109 |
+
hidden_states = block(hidden_states)
|
| 110 |
+
else:
|
| 111 |
+
for layer, block in enumerate(self.blocks):
|
| 112 |
+
if layer == 0:
|
| 113 |
+
hidden_states = block(hidden_states, encoder_hidden_states)
|
| 114 |
+
else:
|
| 115 |
+
hidden_states = block(hidden_states)
|
| 116 |
+
|
| 117 |
+
hidden_states = self.norm_out(hidden_states)
|
| 118 |
+
|
| 119 |
+
return hidden_states
|
| 120 |
+
|
| 121 |
+
def _sample_features(
|
| 122 |
+
self, x: torch.Tensor, num_tokens: int = 1024, seed: Optional[int] = None
|
| 123 |
+
):
|
| 124 |
+
"""
|
| 125 |
+
Sample points from features of the input point cloud.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
x (torch.Tensor): The input point cloud. shape: (B, N, C)
|
| 129 |
+
num_tokens (int, optional): The number of points to sample. Defaults to 1024.
|
| 130 |
+
seed (Optional[int], optional): The random seed. Defaults to None.
|
| 131 |
+
"""
|
| 132 |
+
rng = np.random.default_rng(seed)
|
| 133 |
+
indices = rng.choice(
|
| 134 |
+
x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
|
| 135 |
+
)
|
| 136 |
+
selected_points = x[:, indices]
|
| 137 |
+
|
| 138 |
+
batch_size, num_points, num_channels = selected_points.shape
|
| 139 |
+
flattened_points = selected_points.view(batch_size * num_points, num_channels)
|
| 140 |
+
batch_indices = (
|
| 141 |
+
torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# fps sampling
|
| 145 |
+
sampling_ratio = 1.0 / 4
|
| 146 |
+
sampled_indices = fps(
|
| 147 |
+
flattened_points[:, :3],
|
| 148 |
+
batch_indices,
|
| 149 |
+
ratio=sampling_ratio,
|
| 150 |
+
random_start=self.training,
|
| 151 |
+
)
|
| 152 |
+
sampled_points = flattened_points[sampled_indices].view(
|
| 153 |
+
batch_size, -1, num_channels
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
return sampled_points
|
| 157 |
+
|
| 158 |
+
def get_qkv(self, x: torch.Tensor, num_tokens: int = 1024, seed: Optional[int] = None):
|
| 159 |
+
positions, features = x[..., :3], x[..., 3:]
|
| 160 |
+
x_kv = torch.cat([self.embedder(positions), features], dim=-1)
|
| 161 |
+
|
| 162 |
+
sampled_x = self._sample_features(x, num_tokens, seed)
|
| 163 |
+
positions, features = (
|
| 164 |
+
sampled_x[..., :3],
|
| 165 |
+
sampled_x[..., 3:],
|
| 166 |
+
)
|
| 167 |
+
x_q = torch.cat([self.embedder(positions), features], dim=-1)
|
| 168 |
+
return x_q, x_kv
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class Tripo2Decoder(nn.Module):
|
| 172 |
+
def __init__(
|
| 173 |
+
self,
|
| 174 |
+
in_channels: int = 3,
|
| 175 |
+
out_channels: int = 1,
|
| 176 |
+
dim: int = 512,
|
| 177 |
+
num_attention_heads: int = 8,
|
| 178 |
+
num_layers: int = 16,
|
| 179 |
+
grad_type: str = "analytical",
|
| 180 |
+
grad_interval: float = 0.001,
|
| 181 |
+
is_miche: bool = False,
|
| 182 |
+
):
|
| 183 |
+
super().__init__()
|
| 184 |
+
|
| 185 |
+
if grad_type not in ["numerical", "analytical"]:
|
| 186 |
+
raise ValueError(f"grad_type must be one of ['numerical', 'analytical']")
|
| 187 |
+
self.grad_type = grad_type
|
| 188 |
+
self.grad_interval = grad_interval
|
| 189 |
+
self.is_miche = is_miche
|
| 190 |
+
|
| 191 |
+
self.blocks = nn.ModuleList(
|
| 192 |
+
[
|
| 193 |
+
DiTBlock(
|
| 194 |
+
dim=dim,
|
| 195 |
+
num_attention_heads=num_attention_heads,
|
| 196 |
+
use_self_attention=True,
|
| 197 |
+
self_attention_norm_type="fp32_layer_norm",
|
| 198 |
+
use_cross_attention=False,
|
| 199 |
+
use_cross_attention_2=False,
|
| 200 |
+
activation_fn="gelu",
|
| 201 |
+
norm_type="fp32_layer_norm",
|
| 202 |
+
norm_eps=1e-5,
|
| 203 |
+
qk_norm=False,
|
| 204 |
+
qkv_bias=False,
|
| 205 |
+
)
|
| 206 |
+
for _ in range(num_layers) # self attention
|
| 207 |
+
]
|
| 208 |
+
+ [
|
| 209 |
+
DiTBlock(
|
| 210 |
+
dim=dim,
|
| 211 |
+
num_attention_heads=num_attention_heads,
|
| 212 |
+
use_self_attention=False,
|
| 213 |
+
use_cross_attention=True,
|
| 214 |
+
cross_attention_dim=dim,
|
| 215 |
+
cross_attention_norm_type="layer_norm",
|
| 216 |
+
activation_fn="gelu",
|
| 217 |
+
norm_type="fp32_layer_norm",
|
| 218 |
+
norm_eps=1e-5,
|
| 219 |
+
qk_norm=False,
|
| 220 |
+
qkv_bias=False,
|
| 221 |
+
) # cross attention
|
| 222 |
+
]
|
| 223 |
+
)
|
| 224 |
+
self.proj_query = nn.Linear(in_channels, dim, bias=True)
|
| 225 |
+
|
| 226 |
+
self.norm_out = LayerNorm(dim)
|
| 227 |
+
self.proj_out = nn.Linear(dim, out_channels, bias=True)
|
| 228 |
+
self.sigmoid = nn.Sigmoid()
|
| 229 |
+
init_scale = 0.25 * math.sqrt(1.0 / dim)
|
| 230 |
+
init_linear(self.proj_query, init_scale)
|
| 231 |
+
init_linear(self.proj_out, init_scale)
|
| 232 |
+
|
| 233 |
+
def forward(
|
| 234 |
+
self,
|
| 235 |
+
sample: torch.Tensor,
|
| 236 |
+
queries: torch.Tensor,
|
| 237 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 238 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 239 |
+
if kv_cache is None:
|
| 240 |
+
hidden_states = sample
|
| 241 |
+
for _, block in enumerate(self.blocks[:-1]):
|
| 242 |
+
hidden_states = block(hidden_states)
|
| 243 |
+
kv_cache = hidden_states
|
| 244 |
+
# query grid logits by cross attention
|
| 245 |
+
q = self.proj_query(queries)
|
| 246 |
+
if self.is_miche:
|
| 247 |
+
l = self.blocks[-1](q, kv_cache)
|
| 248 |
+
else:
|
| 249 |
+
l = self.blocks[-1](q, encoder_hidden_states=kv_cache)
|
| 250 |
+
logits = self.proj_out(self.norm_out(l))
|
| 251 |
+
|
| 252 |
+
logits = self.sigmoid(logits)
|
| 253 |
+
assert kv_cache is not None
|
| 254 |
+
return logits, kv_cache
|
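Tripo2Encoder compresses a variable-size point set into a fixed token set: layer 0 cross-attends from a small query set (FPS-sampled points or learned queries) to the full embedded cloud, and the remaining layers are self-attention over the compressed tokens. A minimal sketch of the same pattern, independent of the repo's DiTBlock (nn.MultiheadAttention stands in for it; all dims are toy values):

import torch
import torch.nn as nn

# Perceiver-style compression: a few learned queries cross-attend to many points.
dim, n_points, n_queries = 64, 2048, 32

points = torch.randn(2, n_points, dim)                       # embedded point cloud (x_kv after proj_in)
queries = nn.Parameter(torch.randn(n_queries, dim) * 0.02)   # learned queries, as with is_learned_queries=True

cross = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
self_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

q = queries.unsqueeze(0).expand(points.shape[0], -1, -1)
tokens, _ = cross(q, points, points)                     # layer 0: cross-attend queries -> point set
tokens = tokens + self_attn(tokens, tokens, tokens)[0]   # remaining layers: self-attention over tokens
print(tokens.shape)                                      # torch.Size([2, 32, 64])

Tripo2Decoder inverts the idea: the self-attention blocks refine the latent tokens once (cached as kv_cache), then arbitrary query points cross-attend into the cache, so decoding cost scales with the number of queries rather than re-running the latent stack.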
src/model/skin_vae/autoencoders/get_model.py
ADDED
@@ -0,0 +1,22 @@
import torch  # required by torch.load below

from .skin_cvae_model import SkinCVAEModel
from .skin_fsq_cvae_model import SkinFSQCVAEModel


def get_model_cvae(
    pretrained_path: str = None,
    **kwargs
) -> SkinCVAEModel:
    model = SkinCVAEModel(**kwargs)
    if pretrained_path is not None:
        state_dict = torch.load(pretrained_path, weights_only=True)
        model.load_state_dict(state_dict)
    return model


def get_model_fsq_cvae(
    pretrained_path: str = None,
    **kwargs
) -> SkinFSQCVAEModel:
    model = SkinFSQCVAEModel(**kwargs)
    if pretrained_path is not None:
        state_dict = torch.load(pretrained_path, weights_only=True)
        model.load_state_dict(state_dict)
    return model
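A hypothetical call site for the FSQ variant (the checkpoint path and every keyword below are illustrative placeholders, not values shipped with this repo; `SkinFSQCVAEModel` additionally expects `is_learned_queries` and one of `FSQ_dict`/`SimVQ_dict` in its kwargs, as its constructor below shows):

# hypothetical usage; path and kwargs are placeholders
model = get_model_fsq_cvae(
    pretrained_path="checkpoints/skin_fsq_cvae.pt",
    is_learned_queries=False,
    SimVQ_dict=dict(n_e=1024, e_dim=8, dim=64),
)
model.eval()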
src/model/skin_vae/autoencoders/miche_transformer_blocks.py
ADDED
@@ -0,0 +1,395 @@
# -*- coding: utf-8 -*-

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import os

# -*- coding: utf-8 -*-
"""
Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
"""

import torch
from typing import Callable, Iterable, Sequence, Union


def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
    use_deepspeed: bool = False
):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    :param use_deepspeed: if True, use deepspeed
    """
    if flag:
        if use_deepspeed:
            import deepspeed
            return deepspeed.checkpointing.checkpoint(func, *inputs)

        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type='cuda')
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    @torch.amp.custom_bwd(device_type='cuda')
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


try:
    from flash_attn_interface import flash_attn_func
    print("use flash attention 3.")
    _use_flash3 = True
except:
    print("use flash attention 2.")
    _use_flash3 = False


def init_linear(l, stddev):
    nn.init.normal_(l.weight, std=stddev)
    if l.bias is not None:
        nn.init.constant_(l.bias, 0.0)


def flash_attention(q, k, v):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        if _use_flash3:
            out, _ = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous())
            # out = flash_attn_func(q, k, v)

            # q_ = q.transpose(1, 2)
            # k_ = k.transpose(1, 2)
            # v_ = v.transpose(1, 2)

            # # print(q.shape, k.shape, v.shape)
            # out_ = F.scaled_dot_product_attention(q_, k_, v_)
            # out_ = out_.transpose(1, 2)

            # # print(torch.abs(out - out_).mean())
            # assert torch.abs(out - out_).mean() < 1e-2, f"the error {torch.abs(out - out_).mean()} is too large"

            # out = out_

            # print("use flash_atten 3")
        else:
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            out = F.scaled_dot_product_attention(q, k, v)
            out = out.transpose(1, 2)
            # print("use flash atten 2")

    return out


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float,
        qkv_bias: bool,
        flash: bool = False
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
        init_linear(self.c_qkv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        x = self.c_qkv(x)
        x = checkpoint(self.attention, (x,), (), False)
        x = self.c_proj(x)
        return x


class QKVMultiheadAttention(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_ctx = n_ctx
        self.flash = flash

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        if self.flash:
            out = flash_attention(q, k, v)
            out = out.reshape(out.shape[0], out.shape[1], -1)
        else:
            weight = torch.einsum(
                "bthc,bshc->bhts", q * scale, k * scale
            )  # More stable with f16 than dividing afterwards
            wdtype = weight.dtype
            weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
            out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)

        return out


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float = 1.0,
        qkv_bias: bool = True,
        flash: bool = False,
        use_checkpoint: bool = False
    ):
        super().__init__()

        self.use_checkpoint = use_checkpoint

        self.attn = MultiheadAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash
        )
        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)

    def _forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

    def forward(self, x: torch.Tensor):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)


class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        width: int,
        heads: int,
        init_scale: float,
        qkv_bias: bool = True,
        flash: bool = False,
        n_data: Optional[int] = None,
        data_width: Optional[int] = None,
    ):
        super().__init__()
        self.n_data = n_data
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadCrossAttention(
            device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
        )
        init_linear(self.c_q, init_scale)
        init_linear(self.c_kv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x, data):
        x = self.c_q(x)
        data = self.c_kv(data)
        x = checkpoint(self.attention, (x, data), (), False)
        x = self.c_proj(x)
        return x


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
                 flash: bool = False, n_data: Optional[int] = None):

        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_data = n_data
        self.flash = flash

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        attn_ch = width // self.heads // 2
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        if self.flash:
            out = flash_attention(q, k, v)
            out = out.reshape(out.shape[0], out.shape[1], -1)
        else:
            weight = torch.einsum(
                "bthc,bshc->bhts", q * scale, k * scale
            )  # More stable with f16 than dividing afterwards
            wdtype = weight.dtype
            weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
            out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)

        return out


class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        device: Optional[torch.device],
        dtype: Optional[torch.dtype],
        n_data: Optional[int] = None,
        width: int,
        heads: int,
        data_width: Optional[int] = None,
        mlp_width_scale: int = 4,
        init_scale: float = 0.25,
        qkv_bias: bool = True,
        flash: bool = False
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            device=device,
            dtype=dtype,
            n_data=n_data,
            width=width,
            heads=heads,
            data_width=data_width,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
        )
        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
        self.mlp = MLP(device=device, dtype=dtype, width=width, hidden_width_scale=mlp_width_scale, init_scale=init_scale)
        self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x


class MLP(nn.Module):
    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 width: int,
                 hidden_width_scale: int = 4,
                 init_scale: float):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * hidden_width_scale, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width * hidden_width_scale, width, device=device, dtype=dtype)
        self.gelu = nn.GELU()
        init_linear(self.c_fc, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        device: Optional[torch.device],
        dtype: Optional[torch.dtype],
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = True,
        flash: bool = False,
        use_checkpoint: bool = False
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                    qkv_bias=qkv_bias,
                    flash=flash,
                    use_checkpoint=use_checkpoint
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x
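The `checkpoint` helper above trades memory for compute: with `flag=True`, activations inside `func` are discarded after the forward pass and recomputed during backward. A small sketch of calling it on a toy module (runs on CPU; the helper is the one defined in this file):

import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32))
x = torch.randn(4, 32, requires_grad=True)

# flag=True: block's intermediate activations are recomputed inside backward()
y = checkpoint(lambda t: block(t), (x,), block.parameters(), flag=True)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 32])

Passing `block.parameters()` as `params` matters: `CheckpointFunction.backward` requests gradients for those parameters explicitly, since they are not part of `inputs`.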
src/model/skin_vae/autoencoders/skin_fsq_cvae_model.py
ADDED
@@ -0,0 +1,304 @@
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import Attention, AttentionProcessor
from diffusers.models.modeling_utils import ModelMixin
from einops import repeat
import math

from ..attention_processor import Tripo2AttnProcessor2_0
from ..embeddings import FrequencyPositionalEmbedding
from .autoencoder_kl_tripo2 import Tripo2Encoder, Tripo2Decoder
from .FSQ import FSQ
from .SimVQ import SimVQ1D

from ...utils import fps


def init_linear(l, stddev):
    nn.init.normal_(l.weight, std=stddev)
    if l.bias is not None:
        nn.init.constant_(l.bias, 0.0)


class SkinFSQCVAEModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        cond_channels: int = 3,
        latent_channels: int = 64,
        num_attention_heads: int = 8,
        width_encoder: int = 512,
        width_decoder: int = 1024,
        num_layers_encoder: int = 8,
        num_layers_decoder: int = 16,
        embedding_type: str = "frequency",
        embed_frequency: int = 8,
        embed_include_pi: bool = False,
        sample_tokens: int = 32,
        **kwargs
    ):
        super().__init__()

        self.out_channels = 1

        if embedding_type == "frequency":
            self.embedder = FrequencyPositionalEmbedding(
                num_freqs=embed_frequency,
                logspace=True,
                input_dim=3,
                include_pi=embed_include_pi,
                use_pmpe=kwargs.get('use_pmpe', False),
            )
        else:
            raise NotImplementedError(
                f"Embedding type {embedding_type} is not supported."
            )

        self.is_learned_queries = kwargs['is_learned_queries']

        is_miche = kwargs.get('is_miche', False)
        self.encoder = Tripo2Encoder(
            in_channels=in_channels + self.embedder.out_dim,
            dim=width_encoder,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers_encoder,
            is_learned_queries=self.is_learned_queries,
            sample_tokens=sample_tokens,
            is_miche=is_miche,
        )

        self.cond_encoder = Tripo2Encoder(
            in_channels=cond_channels + self.embedder.out_dim,
            dim=width_encoder,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers_encoder,
            is_miche=is_miche,
        )

        self.decoder = Tripo2Decoder(
            # self.cond_channels resolves through the diffusers config
            # registered above by @register_to_config.
            in_channels=self.embedder.out_dim + self.cond_channels,
            out_channels=self.out_channels,
            dim=width_decoder,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers_decoder,
            is_miche=is_miche,
        )

        self.cond_quant = nn.Linear(width_encoder, latent_channels, bias=True)

        self.quant = nn.Linear(width_encoder, latent_channels, bias=True)
        self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)

        init_scale = 0.25 * math.sqrt(1.0 / width_encoder)
        init_linear(self.cond_quant, init_scale)
        init_linear(self.quant, init_scale)
        init_scale = 0.25 * math.sqrt(1.0 / latent_channels)
        init_linear(self.post_quant, init_scale)
        self.use_slicing = False
        self.slicing_length = 1
        if kwargs.get('FSQ_dict', None) is not None:
            self.FSQ = FSQ(**kwargs['FSQ_dict'])
        else:
            self.FSQ = SimVQ1D(**kwargs['SimVQ_dict'])

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(Tripo2AttnProcessor2_0())

    def enable_slicing(self, slicing_length: int = 1) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True
        self.slicing_length = slicing_length

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _sample_features(
        self, x: torch.Tensor, num_tokens: int = 128, seed: Optional[int] = None
    ):
        """
        Sample points from features of the input point cloud.

        Args:
            x (torch.Tensor): The input point cloud. shape: (B, N, C)
            num_tokens (int, optional): The number of points to sample. Defaults to 128.
            seed (Optional[int], optional): The random seed. Defaults to None.
        """
        rng = np.random.default_rng(seed)
        indices = rng.choice(
            x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
        )
        selected_points = x[:, indices]

        batch_size, num_points, num_channels = selected_points.shape
        flattened_points = selected_points.view(batch_size * num_points, num_channels)
        batch_indices = (
            torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
        )

        # fps sampling
        sampling_ratio = 1.0 / 4
        sampled_indices = fps(
            flattened_points[:, :3],
            batch_indices,
            ratio=sampling_ratio,
            random_start=self.training,
        )
        sampled_points = flattened_points[sampled_indices].view(
            batch_size, -1, num_channels
        )

        return sampled_points

    def get_qkv(self, x: torch.Tensor, num_tokens: int = 128, seed: Optional[int] = None, not_get_q: bool = False):
        positions, features = x[..., :3], x[..., 3:]
        x_kv = torch.cat([self.embedder(positions), features], dim=-1)

        if not_get_q:
            x_q = torch.zeros((x.shape[0], num_tokens, x.shape[-1]), dtype=x.dtype, device=x.device)
        else:
            sampled_x = self._sample_features(x, num_tokens, seed)
            positions, features = (
                sampled_x[..., :3],
                sampled_x[..., 3:],
            )
            x_q = torch.cat([self.embedder(positions), features], dim=-1)
        return x_q, x_kv

    def _encode(
        self, x: torch.Tensor | None, cond: torch.Tensor | None, num_tokens: int = 128, cond_tokens: int = 128, seed: Optional[int] = None,
        return_z: bool = True, return_cond: bool = True,
    ):
        position_channels = 3
        if return_z:
            assert x is not None
            x_q, x_kv = self.get_qkv(x, num_tokens, seed, not_get_q=self.is_learned_queries)
            x = self.encoder(x_q, x_kv)
            x = self.quant(x)
        else:
            x = None

        if return_cond:
            assert cond is not None
            cond_q, cond_kv = self.get_qkv(cond, cond_tokens, seed)
            cond_embed = self.cond_encoder(cond_q, cond_kv)
            cond = self.cond_quant(cond_embed)
        else:
            cond = None

        return x, cond

    def _decode(
        self, z: torch.Tensor,
        cond: torch.Tensor,
        sampled_points: torch.Tensor,
        num_chunks: Optional[int] = None,
    ) -> torch.Tensor:
        xyz_samples = sampled_points
        z = self.post_quant(torch.cat([z, cond], dim=1))

        num_points = xyz_samples.shape[1]
        if num_chunks is None:
            num_chunks = num_points

        queries = sampled_points.to(z.device, dtype=z.dtype)
        positions, features = (
            queries[..., :3],
            queries[..., 3:],
        )

        kv_cache = None
        dec = []
        for i in range(0, num_points, num_chunks):
            queries = torch.cat([self.embedder(positions[:, i:i + num_chunks, :]), features[:, i:i + num_chunks, :]], dim=-1)
            # After the first chunk, kv_cache is populated and the decoder ignores
            # its first argument, so reusing `z` for the per-chunk logits is safe.
            z, kv_cache = self.decoder(z, queries, kv_cache)
            dec.append(z)

        return torch.cat(dec, dim=1)

    def compile_model(self):
        self.encoder = torch.compile(self.encoder)
        self.cond_encoder = torch.compile(self.cond_encoder)
        self.decoder = torch.compile(self.decoder)

    def forward(self, x: torch.Tensor):
        pass
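Pulling the pieces together, inference follows encode, quantize, then chunked decode. The sketch below is hypothetical (kwargs, shapes, and tensor contents are illustrative placeholders; `SimVQ1D` returns `(z_q, indices, loss)` as defined earlier in this commit):

import torch

# hypothetical wiring; all values are placeholders
model = SkinFSQCVAEModel(is_learned_queries=False, SimVQ_dict=dict(n_e=1024, e_dim=8, dim=64))
x = torch.randn(1, 8192, 4)      # skin samples: xyz plus one feature channel
cond = torch.randn(1, 8192, 3)   # conditioning point cloud: xyz

z, c = model._encode(x, cond, num_tokens=128, cond_tokens=128)
z_q, codes, vq_loss = model.FSQ(z)          # discrete skin tokens
queries = torch.randn(1, 4096, 3)           # positions to decode at
out = model._decode(z_q, c, queries, num_chunks=1024)

Decoding in chunks keeps peak memory bounded: the latent self-attention runs once to fill kv_cache, after which each chunk of queries only pays for a single cross-attention pass.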
src/model/skin_vae/autoencoders/vae.py
ADDED
@@ -0,0 +1,73 @@
from typing import Optional, Tuple

import numpy as np
import torch
from diffusers.utils.torch_utils import randn_tensor


class DiagonalGaussianDistribution(object):
    def __init__(
        self,
        parameters: torch.Tensor,
        deterministic: bool = False,
        feature_dim: int = 1,
    ):
        self.parameters = parameters
        self.feature_dim = feature_dim
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                )
            elif isinstance(other, DiagonalGaussianDistribution):
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                )
            elif isinstance(other, torch.Tensor):
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other, 2) + self.var - 1.0 - self.logvar,
                )
            else:
                raise ValueError("Other must be a DiagonalGaussianDistribution or torch.Tensor")

    def nll(
        self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
    ) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
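Minimal usage of the distribution above: the encoder output is split in half along `feature_dim` into (mean, logvar), sampling uses the reparameterization trick, and `kl()` gives the regularizer against a standard normal (toy shapes below):

import torch

params = torch.randn(2, 16, 8)   # toy encoder output; 8 = 2 * latent_dim
dist = DiagonalGaussianDistribution(params, feature_dim=-1)

z = dist.sample()    # mean + std * eps, differentiable w.r.t. params
kl = dist.kl()       # mean KL to N(0, I)
print(z.shape, float(kl))  # torch.Size([2, 16, 4]) <scalar>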
src/model/skin_vae/embeddings.py
ADDED
@@ -0,0 +1,111 @@
import torch
import torch.nn as nn


class FrequencyPositionalEmbedding(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(x[..., i]),
            sin(f_1*x[..., i]),
            sin(f_2*x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(x[..., i]),
            cos(f_1*x[..., i]),
            cos(f_2*x[..., i]),
            ...
            cos(f_N * x[..., i]),
            x[..., i]  # only present if include_input is True.
        ], here f_i is the frequency.

    If logspace is True, the frequencies are the powers of two [2^0, 2^1, ..., 2^(num_freqs - 1)];
    otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): if True, the frequencies are [2^0, ..., 2^(num_freqs - 1)],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
        input_dim (int): the input dimension, default is 3;
        include_input (bool): include the input tensor or not, default is True.

    Attributes:
        frequencies (torch.Tensor): if logspace is True, the frequencies are [2^0, ..., 2^(num_freqs - 1)],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];

        out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
            otherwise, it is input_dim * num_freqs * 2.

    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
        use_pmpe: bool = False,
    ) -> None:
        """The initialization"""

        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs
        self.use_pmpe = use_pmpe
        if use_pmpe:
            phase = torch.arange(num_freqs, dtype=torch.float32)
            for i in range(num_freqs):
                phase[i] = torch.pow(torch.tensor(num_freqs), 1.0 - (i + 1) / num_freqs) + (i + 1) / num_freqs
            phase *= torch.pi * 2
            self.register_buffer("phase", phase, persistent=False)

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """

        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies).view(
                *x.shape[:-1], -1
            )
            if self.use_pmpe:
                phase = (x[..., None].contiguous() * torch.pi * 0.5 + self.phase).view(
                    *x.shape[:-1], -1
                )
                res = torch.cat((embed.sin() + phase.sin(), embed.cos() + phase.cos()), dim=-1)
            else:
                res = torch.cat((embed.sin(), embed.cos()), dim=-1)
            if self.include_input:
                return torch.cat((x, res), dim=-1)
            else:
                return res
        else:
            return x
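A quick shape check for the embedding above: with `num_freqs=8`, `include_input=True` (the default) and 3-D inputs, `out_dim = 3 * (2 * 8 + 1) = 51`:

import torch

embedder = FrequencyPositionalEmbedding(num_freqs=8, logspace=True, input_dim=3, include_pi=False)
xyz = torch.rand(2, 1024, 3)
emb = embedder(xyz)
print(embedder.out_dim, emb.shape)  # 51 torch.Size([2, 1024, 51])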
src/model/skin_vae/transformers/__init__.py
ADDED
@@ -0,0 +1,41 @@
from typing import Callable, Optional

from .tripo2_transformer import Tripo2DiTModel


def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    return ori_attn_proc


def set_transformer_attn_processor(
    transformer: Tripo2DiTModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
) -> None:
    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        hidden_size = transformer.config.width
        if name.endswith("attn1.processor"):
            # self attention
            attn_procs[name] = set_self_attn_proc_func(
                name, hidden_size, None, attn_processor
            )
        elif name.endswith("attn2.processor"):
            # cross attention
            cross_attention_dim = transformer.config.cross_attention_dim
            attn_procs[name] = set_cross_attn_proc_func(
                name, hidden_size, cross_attention_dim, attn_processor
            )
        elif name.endswith("attn2_2.processor"):
            # cross attention 2
            cross_attention_dim = transformer.config.cross_attention_2_dim
            attn_procs[name] = set_cross_attn_proc_func(
                name, hidden_size, cross_attention_dim, attn_processor
            )

    transformer.set_attn_processor(attn_procs)
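A hypothetical use of the helper above, swapping only the cross-attention processors while self-attention keeps its default (`transformer` is assumed to be an instantiated Tripo2DiTModel and `MyAttnProcessor` a placeholder class):

def swap_cross_attn(name, hidden_size, cross_attention_dim, ori_attn_proc):
    # placeholder processor; any object honouring the diffusers processor interface works
    return MyAttnProcessor(hidden_size, cross_attention_dim)

set_transformer_attn_processor(
    transformer,
    set_cross_attn_proc_func=swap_cross_attn,   # self-attention falls back to default_set_attn_proc_func
)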
src/model/skin_vae/transformers/modeling_outputs.py
ADDED
@@ -0,0 +1,8 @@
from dataclasses import dataclass

import torch


@dataclass
class Transformer1DModelOutput:
    sample: torch.FloatTensor