pookiefoof committed
Commit 9d7cf7f · 0 parent(s)

Public release: SkinTokens · TokenRig demo

Files changed (50)
(This view is limited to 50 files because the commit contains too many changes.)
  1. .gitattributes +35 -0
  2. .gitignore +67 -0
  3. LICENSE +21 -0
  4. README.md +14 -0
  5. bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl +3 -0
  6. bpy_server.py +7 -0
  7. configs/skeleton/mixamo.yaml +59 -0
  8. configs/skeleton/vroid.yaml +59 -0
  9. demo.py +764 -0
  10. download.py +72 -0
  11. requirements.txt +37 -0
  12. runtime.txt +1 -0
  13. src/__init__.py +0 -0
  14. src/data/augment.py +706 -0
  15. src/data/datapath.py +344 -0
  16. src/data/dataset.py +319 -0
  17. src/data/order.py +132 -0
  18. src/data/sampler.py +189 -0
  19. src/data/spec.py +16 -0
  20. src/data/transform.py +70 -0
  21. src/data/vertex_group.py +257 -0
  22. src/model/__init__.py +0 -0
  23. src/model/michelangelo/__init__.py +1 -0
  24. src/model/michelangelo/get_model.py +30 -0
  25. src/model/michelangelo/models/__init__.py +1 -0
  26. src/model/michelangelo/models/modules/__init__.py +3 -0
  27. src/model/michelangelo/models/modules/checkpoint.py +69 -0
  28. src/model/michelangelo/models/modules/distributions.py +100 -0
  29. src/model/michelangelo/models/modules/embedder.py +213 -0
  30. src/model/michelangelo/models/modules/transformer_blocks.py +327 -0
  31. src/model/michelangelo/models/tsal/__init__.py +1 -0
  32. src/model/michelangelo/models/tsal/loss.py +454 -0
  33. src/model/michelangelo/models/tsal/sal_perceiver.py +723 -0
  34. src/model/michelangelo/models/tsal/tsal_base.py +121 -0
  35. src/model/michelangelo/utils/__init__.py +4 -0
  36. src/model/michelangelo/utils/eval.py +12 -0
  37. src/model/michelangelo/utils/misc.py +271 -0
  38. src/model/parse_encoder.py +28 -0
  39. src/model/skin_vae/attention_processor.py +283 -0
  40. src/model/skin_vae/autoencoders/FSQ.py +191 -0
  41. src/model/skin_vae/autoencoders/SimVQ.py +197 -0
  42. src/model/skin_vae/autoencoders/__init__.py +1 -0
  43. src/model/skin_vae/autoencoders/autoencoder_kl_tripo2.py +254 -0
  44. src/model/skin_vae/autoencoders/get_model.py +22 -0
  45. src/model/skin_vae/autoencoders/miche_transformer_blocks.py +395 -0
  46. src/model/skin_vae/autoencoders/skin_fsq_cvae_model.py +304 -0
  47. src/model/skin_vae/autoencoders/vae.py +73 -0
  48. src/model/skin_vae/embeddings.py +111 -0
  49. src/model/skin_vae/transformers/__init__.py +41 -0
  50. src/model/skin_vae/transformers/modeling_outputs.py +8 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,67 @@
1
+ # ignore all pycache and bytecode files
2
+ **/__pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # ignore tmp & output files
7
+ _data/
8
+ ckpt/
9
+ tmp/
10
+ tmp_fusion/
11
+ tmp_vae/
12
+ tmp_video/
13
+ *.glb
14
+ *.ply
15
+ *.obj
16
+ *.fbx
17
+ *.npz
18
+ *.blend
19
+ *.blend1
20
+ *.blend2
21
+
22
+ # ignore logs
23
+ wandb/
24
+ lightning_logs/
25
+ *.log
26
+
27
+ # ignore experiments
28
+ experiments/
29
+ results/
30
+ dataset_clean/
31
+ logs/
32
+ datalist/
33
+ dataset_inference/
34
+ dataset_inference_clean/
35
+ feature_viz/
36
+
37
+ # Distribution / packaging
38
+ dist/
39
+ build/
40
+ *.egg-info/
41
+ *.egg
42
+ *.whl
43
+
44
+ # Virtual environments
45
+ venv/
46
+ env/
47
+ .env/
48
+ .venv/
49
+
50
+ # IDE specific files
51
+ .idea/
52
+ .vscode/
53
+ *.swp
54
+ *.swo
55
+ .DS_Store
56
+
57
+ # Jupyter Notebook
58
+ .ipynb_checkpoints
59
+ *.ipynb
60
+
61
+ # Unit test / coverage reports
62
+ htmlcov/
63
+ .tox/
64
+ .coverage
65
+ .coverage.*
66
+ coverage.xml
67
+ *.cover
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 VAST-AI-Research
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: SkinTokens
3
+ emoji: 🌖
4
+ colorFrom: green
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 6.12.0
8
+ python_version: 3.12.12
9
+ app_file: demo.py
10
+ pinned: false
11
+ license: mit
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39df5f78dc95d1fae6058a3134a40f645303c6e96540f87e9ee4c0fd436def1d
3
+ size 346159222
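What is shown above is the Git LFS pointer that stands in for the ~346 MB wheel; `demo.py` later has to detect exactly this situation when the Space container receives the pointer instead of the binary. A minimal check, mirroring the `startswith(b"PK")` test in `_ensure_bpy_installed` below:

```python
# Check whether the checked-out wheel is the real zip or just the LFS pointer stub.
with open("bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl", "rb") as f:
    print("real wheel" if f.read(4).startswith(b"PK") else "LFS pointer stub")
```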
bpy_server.py ADDED
@@ -0,0 +1,7 @@
1
+ from src.server.bpy_server import run
2
+
3
+ def main():
4
+ run()
5
+
6
+ if __name__ == "__main__":
7
+ main()
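`bpy_server.py` is only a launcher; the real server lives in `src/server/bpy_server.py`, which falls outside this 50-file view. For orientation, here is a rough sketch of the HTTP contract that `demo.py` relies on later (a `/ping` health check plus pickle-over-HTTP endpoints). The `bottle` framework, the pickle serialization, and the port number are assumptions pieced together from `requirements.txt` and the comments in `demo.py`, not the actual implementation:

```python
# Hypothetical sketch only; the real server is src/server/bpy_server.py.
import pickle

from bottle import Bottle, request, run

app = Bottle()


def bytes_to_object(data: bytes):
    # Placeholder for src.server.spec.bytes_to_object.
    return pickle.loads(data)


def object_to_bytes(obj) -> bytes:
    # Placeholder for src.server.spec.object_to_bytes.
    return pickle.dumps(obj)


@app.get("/ping")
def ping():
    # demo.py polls this endpoint to detect when the slow bpy import has finished.
    return "pong"


@app.post("/export")
def export():
    # demo.py POSTs a pickled payload (asset, filepath, group_per_vertex)
    # and unpickles the response, expecting the string "ok" on success.
    payload = bytes_to_object(request.body.read())
    # ... perform the bpy export using payload["asset"] / payload["filepath"] ...
    return object_to_bytes("ok")


if __name__ == "__main__":
    run(app, host="127.0.0.1", port=59876)  # port taken from the demo.py comment
```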
configs/skeleton/mixamo.yaml ADDED
@@ -0,0 +1,59 @@
1
+ parts_order: [body, hand]
2
+
3
+ parts:
4
+ body: [
5
+ mixamorig:Hips,
6
+ mixamorig:Spine,
7
+ mixamorig:Spine1,
8
+ mixamorig:Spine2,
9
+ mixamorig:Neck,
10
+ mixamorig:Head,
11
+ mixamorig:LeftShoulder,
12
+ mixamorig:LeftArm,
13
+ mixamorig:LeftForeArm,
14
+ mixamorig:LeftHand,
15
+ mixamorig:RightShoulder,
16
+ mixamorig:RightArm,
17
+ mixamorig:RightForeArm,
18
+ mixamorig:RightHand,
19
+ mixamorig:LeftUpLeg,
20
+ mixamorig:LeftLeg,
21
+ mixamorig:LeftFoot,
22
+ mixamorig:LeftToeBase,
23
+ mixamorig:RightUpLeg,
24
+ mixamorig:RightLeg,
25
+ mixamorig:RightFoot,
26
+ mixamorig:RightToeBase,
27
+ ]
28
+ hand: [
29
+ mixamorig:LeftHandThumb1,
30
+ mixamorig:LeftHandThumb2,
31
+ mixamorig:LeftHandThumb3,
32
+ mixamorig:LeftHandIndex1,
33
+ mixamorig:LeftHandIndex2,
34
+ mixamorig:LeftHandIndex3,
35
+ mixamorig:LeftHandMiddle1,
36
+ mixamorig:LeftHandMiddle2,
37
+ mixamorig:LeftHandMiddle3,
38
+ mixamorig:LeftHandRing1,
39
+ mixamorig:LeftHandRing2,
40
+ mixamorig:LeftHandRing3,
41
+ mixamorig:LeftHandPinky1,
42
+ mixamorig:LeftHandPinky2,
43
+ mixamorig:LeftHandPinky3,
44
+ mixamorig:RightHandIndex1,
45
+ mixamorig:RightHandIndex2,
46
+ mixamorig:RightHandIndex3,
47
+ mixamorig:RightHandThumb1,
48
+ mixamorig:RightHandThumb2,
49
+ mixamorig:RightHandThumb3,
50
+ mixamorig:RightHandMiddle1,
51
+ mixamorig:RightHandMiddle2,
52
+ mixamorig:RightHandMiddle3,
53
+ mixamorig:RightHandRing1,
54
+ mixamorig:RightHandRing2,
55
+ mixamorig:RightHandRing3,
56
+ mixamorig:RightHandPinky1,
57
+ mixamorig:RightHandPinky2,
58
+ mixamorig:RightHandPinky3,
59
+ ]
configs/skeleton/vroid.yaml ADDED
@@ -0,0 +1,59 @@
1
+ parts_order: [body, hand]
2
+
3
+ parts:
4
+ body: [
5
+ J_Bip_C_Hips,
6
+ J_Bip_C_Spine,
7
+ J_Bip_C_Chest,
8
+ J_Bip_C_UpperChest,
9
+ J_Bip_C_Neck,
10
+ J_Bip_C_Head,
11
+ J_Bip_L_Shoulder,
12
+ J_Bip_L_UpperArm,
13
+ J_Bip_L_LowerArm,
14
+ J_Bip_L_Hand,
15
+ J_Bip_R_Shoulder,
16
+ J_Bip_R_UpperArm,
17
+ J_Bip_R_LowerArm,
18
+ J_Bip_R_Hand,
19
+ J_Bip_L_UpperLeg,
20
+ J_Bip_L_LowerLeg,
21
+ J_Bip_L_Foot,
22
+ J_Bip_L_ToeBase,
23
+ J_Bip_R_UpperLeg,
24
+ J_Bip_R_LowerLeg,
25
+ J_Bip_R_Foot,
26
+ J_Bip_R_ToeBase,
27
+ ]
28
+ hand: [
29
+ J_Bip_L_Thumb1,
30
+ J_Bip_L_Thumb2,
31
+ J_Bip_L_Thumb3,
32
+ J_Bip_L_Index1,
33
+ J_Bip_L_Index2,
34
+ J_Bip_L_Index3,
35
+ J_Bip_L_Middle1,
36
+ J_Bip_L_Middle2,
37
+ J_Bip_L_Middle3,
38
+ J_Bip_L_Ring1,
39
+ J_Bip_L_Ring2,
40
+ J_Bip_L_Ring3,
41
+ J_Bip_L_Little1,
42
+ J_Bip_L_Little2,
43
+ J_Bip_L_Little3,
44
+ J_Bip_R_Index1,
45
+ J_Bip_R_Index2,
46
+ J_Bip_R_Index3,
47
+ J_Bip_R_Thumb1,
48
+ J_Bip_R_Thumb2,
49
+ J_Bip_R_Thumb3,
50
+ J_Bip_R_Middle1,
51
+ J_Bip_R_Middle2,
52
+ J_Bip_R_Middle3,
53
+ J_Bip_R_Ring1,
54
+ J_Bip_R_Ring2,
55
+ J_Bip_R_Ring3,
56
+ J_Bip_R_Little1,
57
+ J_Bip_R_Little2,
58
+ J_Bip_R_Little3,
59
+ ]
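Both skeleton configs share the same two-key layout: `parts_order` fixes the ordering of the parts and `parts` maps each part to its joint names. As a minimal illustration (not the repo's own loader, which may go through `omegaconf` or `python-box` instead), flattening one of these files into a single joint list with plain PyYAML looks like this:

```python
# Illustrative only; assumes PyYAML and the file layout shown above.
import yaml

with open("configs/skeleton/mixamo.yaml") as f:
    cfg = yaml.safe_load(f)

# Concatenate the per-part joint lists in the declared order (body, then hand).
joint_names = [name for part in cfg["parts_order"] for name in cfg["parts"][part]]

print(len(joint_names))  # 22 body joints + 30 hand joints = 52
```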
demo.py ADDED
@@ -0,0 +1,764 @@
1
+ import argparse
2
+ import atexit
3
+ import importlib
4
+ import os
5
+ import signal
6
+ import subprocess
7
+ import sys
8
+ import tempfile
9
+ import time
10
+ from pathlib import Path
11
+ from typing import List, Optional, Tuple
12
+
13
+ import gradio as gr
14
+ import requests
15
+ from torch import Tensor
16
+ from tqdm import tqdm
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # ZeroGPU compatibility shim. The hosted HF Space provides the `spaces`
20
+ # package; running locally we substitute a no-op.
21
+ # ---------------------------------------------------------------------------
22
+ try:
23
+ spaces = importlib.import_module("spaces")
24
+ except Exception:
25
+ class _SpacesCompat:
26
+ @staticmethod
27
+ def GPU(*args, **kwargs):
28
+ if len(args) == 1 and callable(args[0]) and not kwargs:
29
+ return args[0]
30
+
31
+ def _decorator(fn):
32
+ return fn
33
+
34
+ return _decorator
35
+
36
+ spaces = _SpacesCompat()
37
+
38
+ os.environ.setdefault("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "1")
39
+ gr.TEMP_DIR = "tmp_gradio"
40
+
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # Install the bundled `bpy` wheel at runtime if it isn't already importable.
44
+ #
45
+ # Why this is non-trivial:
46
+ # - Putting the wheel in requirements.txt fails: HF Spaces' Docker build
47
+ # mounts only requirements.txt BEFORE the repo COPY, so the wheel path
48
+ # doesn't exist at pip-install time.
49
+ # - PyPI doesn't ship a bpy wheel matching this exact build (rc0 / cp312 /
50
+ # manylinux_2_39).
51
+ # - The `bpy-*.whl` committed in this repo gets auto-tracked by HF's LFS
52
+ # layer (Hub auto-LFS for blobs > ~10 MB even when .gitattributes doesn't
53
+ # list `*.whl`). The container's COPY-from-repo only carries the LFS
54
+ # *pointer* file — a ~150-byte text stub — not the actual wheel binary.
55
+ # So `pip install <wheel>` and `zipfile.ZipFile(<wheel>)` both fail with
56
+ # "is not a zip file" / "Wheel is invalid".
57
+ #
58
+ # So: we detect the LFS-pointer case and re-fetch the real wheel from the
59
+ # HF Hub at runtime (where the API resolves LFS server-side), then extract
60
+ # it directly into site-packages.
61
+ # ---------------------------------------------------------------------------
62
+ def _ensure_bpy_installed():
63
+ try:
64
+ import bpy # noqa: F401
65
+ return
66
+ except Exception:
67
+ pass
68
+
69
+ import glob
70
+ import sysconfig
71
+ import zipfile
72
+
73
+ here = os.path.dirname(os.path.abspath(__file__))
74
+ wheels = sorted(glob.glob(os.path.join(here, "bpy-*.whl")))
75
+ if not wheels:
76
+ print("[demo] WARNING: bpy not importable and no bundled wheel found", flush=True)
77
+ return
78
+
79
+ wheel = wheels[-1]
80
+ wheel_name = os.path.basename(wheel)
81
+
82
+ # Detect LFS pointer (text stub starting with "version https://git-lfs...").
83
+ is_real_zip = False
84
+ try:
85
+ with open(wheel, "rb") as f:
86
+ is_real_zip = f.read(4).startswith(b"PK")
87
+ except Exception:
88
+ pass
89
+
90
+ if not is_real_zip:
91
+ print(
92
+ f"[demo] {wheel_name} on disk is an LFS pointer ({os.path.getsize(wheel)} B); "
93
+ f"fetching real wheel from HF Hub...",
94
+ flush=True,
95
+ )
96
+ from huggingface_hub import hf_hub_download
97
+
98
+ space_id = os.environ.get("SPACE_ID", "VAST-AI/SkinTokens")
99
+ token = os.environ.get("HF_TOKEN") # set as a Space secret for private repos
100
+ wheel = hf_hub_download(
101
+ repo_id=space_id,
102
+ repo_type="space",
103
+ filename=wheel_name,
104
+ token=token,
105
+ )
106
+ print(f"[demo] fetched -> {wheel} ({os.path.getsize(wheel)} B)", flush=True)
107
+
108
+ site = sysconfig.get_paths()["purelib"]
109
+ print(f"[demo] Extracting {wheel_name} into {site}", flush=True)
110
+ with zipfile.ZipFile(wheel) as z:
111
+ z.extractall(site)
112
+ print("[demo] bpy wheel extracted.", flush=True)
113
+
114
+
115
+ _ensure_bpy_installed()
116
+
117
+
118
+ # ---------------------------------------------------------------------------
119
+ # Download model checkpoints (TokenRig + SkinTokens FSQ-CVAE) and the Qwen3
120
+ # tokenizer/config on first cold-start.
121
+ #
122
+ # These live in the *model* repo `VAST-AI/SkinTokens` (private), separate
123
+ # from this Space repo, so they aren't COPYed into the container. Re-uses
124
+ # `HF_TOKEN` from the Space secrets.
125
+ # ---------------------------------------------------------------------------
126
+ def _ensure_models_downloaded():
127
+ here = os.path.dirname(os.path.abspath(__file__))
128
+ needed_ckpts = [
129
+ "experiments/skin_vae_2_10_32768/last.ckpt",
130
+ "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
131
+ ]
132
+ qwen_dir = os.path.join(here, "models", "Qwen3-0.6B")
133
+
134
+ all_present = (
135
+ all(os.path.exists(os.path.join(here, p)) for p in needed_ckpts)
136
+ and os.path.exists(os.path.join(qwen_dir, "tokenizer.json"))
137
+ )
138
+ if all_present:
139
+ return
140
+
141
+ from huggingface_hub import hf_hub_download, snapshot_download
142
+
143
+ token = os.environ.get("HF_TOKEN")
144
+
145
+ for rel in needed_ckpts:
146
+ target = os.path.join(here, rel)
147
+ if os.path.exists(target):
148
+ continue
149
+ print(f"[demo] Downloading checkpoint: {rel}", flush=True)
150
+ hf_hub_download(
151
+ repo_id="VAST-AI/SkinTokens",
152
+ filename=rel,
153
+ local_dir=here,
154
+ token=token,
155
+ )
156
+
157
+ if not os.path.exists(os.path.join(qwen_dir, "tokenizer.json")):
158
+ print("[demo] Downloading Qwen3-0.6B tokenizer/config", flush=True)
159
+ snapshot_download(
160
+ repo_id="Qwen/Qwen3-0.6B",
161
+ local_dir=qwen_dir,
162
+ ignore_patterns=["*.bin", "*.safetensors"],
163
+ )
164
+
165
+ print("[demo] All checkpoints ready.", flush=True)
166
+
167
+
168
+ _ensure_models_downloaded()
169
+
170
+
171
+ from src.data.dataset import DatasetConfig, RigDatasetModule
172
+ from src.data.transform import Transform
173
+ from src.model.tokenrig import TokenRigResult
174
+ from src.tokenizer.parse import get_tokenizer
175
+ from src.server.spec import (
176
+ BPY_SERVER,
177
+ get_model,
178
+ object_to_bytes,
179
+ bytes_to_object,
180
+ )
181
+ from src.data.vertex_group import voxel_skin
182
+
183
+
184
+ # ---------------------------------------------------------------------------
185
+ # Run `bpy_server.py` as a separate long-lived process so the slow Blender import is paid only once.
186
+ #
187
+ # Why this is necessary on ZeroGPU: each user request runs inside a fresh
188
+ # `@spaces.GPU` worker process with a hard time budget (≈60 s on free tier).
189
+ # Importing the Blender shared object inside that budget burns 30–60 s, so
190
+ # the worker is killed *during* bpy import — manifesting as
191
+ # "GPU task aborted" before any model code runs.
192
+ #
193
+ # Ideally `bpy_server.py` would be started here, in the always-running main process, so the
194
+ # slow bpy import is paid exactly once at Space boot. On ZeroGPU that fails, because the
195
+ # server's import chain probes CUDA at import time (see the note at the bottom of this file),
+ # so it is instead started lazily, on the first request inside the worker; later calls just
+ # hit `localhost:59876` over HTTP with no startup cost.
196
+ # ---------------------------------------------------------------------------
197
+
198
+ MODEL_CKPTS = [
199
+ "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
200
+ ]
201
+
202
+ HF_PATHS = [
203
+ "None",
204
+ ]
205
+
206
+
207
+ def get_dataloader_workers() -> int:
208
+ if os.getenv("SPACE_ID"):
209
+ return 0
210
+ return 1
211
+
212
+
213
+ # ---------------------------------------------------------------------------
214
+ # bpy_server lifecycle — lazy start so the heavy import doesn't fight ZeroGPU
215
+ # during module load.
216
+ # ---------------------------------------------------------------------------
217
+ _BPY_SERVER_PROC = None
218
+
219
+
220
+ def is_bpy_server_alive(timeout: float = 1.0) -> bool:
221
+ try:
222
+ resp = requests.get(f"{BPY_SERVER}/ping", timeout=timeout)
223
+ return resp.status_code == 200
224
+ except Exception:
225
+ return False
226
+
227
+
228
+ def start_bpy_server():
229
+ proc = subprocess.Popen(
230
+ [sys.executable, "bpy_server.py"],
231
+ stdout=None,
232
+ stderr=None,
233
+ preexec_fn=os.setsid,
234
+ )
235
+ print(f"[Main] bpy_server.py started (pid={proc.pid})")
236
+
237
+ def cleanup():
238
+ print(f"[Main] Terminating bpy_server.py (pid={proc.pid})")
239
+ try:
240
+ os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
241
+ except ProcessLookupError:
242
+ pass
243
+
244
+ atexit.register(cleanup)
245
+ return proc
246
+
247
+
248
+ def wait_for_bpy_server(timeout: float = 120):
249
+ """Wait for bpy_server.py to come up. The first start of bpy_server is
250
+ slow because importing the Blender `.so` (~200 MB shared object) takes
251
+ 30–60 s on a cold container. We allow up to 120 s."""
252
+ t0 = time.time()
253
+ last_log = 0.0
254
+ while True:
255
+ try:
256
+ requests.get(f"{BPY_SERVER}/ping", timeout=1)
257
+ print(f"[Main] bpy_server is ready (after {time.time() - t0:.1f}s)")
258
+ return
259
+ except Exception:
260
+ now = time.time()
261
+ if now - t0 > timeout:
262
+ raise RuntimeError(
263
+ f"bpy_server failed to start after {timeout:.0f}s"
264
+ )
265
+ if now - last_log > 10: # progress every 10s
266
+ print(f"[Main] still waiting for bpy_server ({now - t0:.0f}s elapsed)")
267
+ last_log = now
268
+ time.sleep(0.5)
269
+
270
+
271
+ def ensure_bpy_server_started():
272
+ global _BPY_SERVER_PROC
273
+ if is_bpy_server_alive():
274
+ return
275
+ if _BPY_SERVER_PROC is not None and _BPY_SERVER_PROC.poll() is None:
276
+ return
277
+ _BPY_SERVER_PROC = start_bpy_server()
278
+ wait_for_bpy_server()
279
+
280
+
281
+ # ---------------------------------------------------------------------------
282
+ # Lazy model loading.
283
+ # ---------------------------------------------------------------------------
284
+ model = None
285
+ tokenizer = None
286
+ transform = None
287
+ CURRENT_MODEL_CKPT: Optional[str] = None
288
+ CURRENT_HF_PATH: Optional[str] = None
289
+
290
+
291
+ def load_model(model_ckpt: str, hf_path: Optional[str]) -> Tuple[str, str]:
292
+ global model, tokenizer, transform, CURRENT_MODEL_CKPT, CURRENT_HF_PATH
293
+ if hf_path == "None":
294
+ hf_path = None
295
+ if model is not None and model_ckpt == CURRENT_MODEL_CKPT and hf_path == CURRENT_HF_PATH:
296
+ return ("Model already loaded.", model_ckpt)
297
+
298
+ if not model_ckpt:
299
+ raise RuntimeError("model_ckpt is empty. Please select a checkpoint.")
300
+
301
+ print(f"Loading model: {model_ckpt}, hf_path={hf_path}")
302
+ model = get_model(model_ckpt, hf_path=hf_path)
303
+ assert model.tokenizer_config is not None
304
+ tokenizer = get_tokenizer(**model.tokenizer_config)
305
+ transform = Transform.parse(**model.transform_config["predict_transform"])
306
+ CURRENT_MODEL_CKPT = model_ckpt
307
+ CURRENT_HF_PATH = hf_path
308
+ return ("Model loaded.", model_ckpt)
309
+
310
+
311
+ # ---------------------------------------------------------------------------
312
+ # File utilities (CLI-side).
313
+ # ---------------------------------------------------------------------------
314
+ SUPPORTED_EXT = {".obj", ".fbx", ".glb"}
315
+
316
+
317
+ def collect_files(input_path: Path) -> List[Path]:
318
+ if input_path.is_file():
319
+ return [input_path]
320
+
321
+ files = []
322
+ for p in input_path.rglob("*"):
323
+ if p.suffix.lower() in SUPPORTED_EXT:
324
+ files.append(p)
325
+ return files
326
+
327
+
328
+ def map_output_path(in_path: Path, input_root: Path, output_root: Path) -> Path:
329
+ rel = in_path.relative_to(input_root)
330
+ return (output_root / rel).with_suffix(".glb")
331
+
332
+
333
+ # ---------------------------------------------------------------------------
334
+ # Core inference (shared by CLI and Gradio).
335
+ # ---------------------------------------------------------------------------
336
+ def run_rig(
337
+ filepaths: List[Path],
338
+ top_k: int,
339
+ top_p: float,
340
+ temperature: float,
341
+ repetition_penalty: float,
342
+ num_beams: int,
343
+ use_skeleton: bool,
344
+ use_transfer: bool,
345
+ use_postprocess: bool,
346
+ output_paths: List[Path],
347
+ model_ckpt: str,
348
+ hf_path: Optional[str],
349
+ ):
350
+ assert len(filepaths) == len(output_paths)
351
+ ensure_bpy_server_started()
352
+ load_model(model_ckpt, hf_path)
353
+
354
+ datapath = {
355
+ "data_name": None,
356
+ "loader": "bpy_server",
357
+ "filepaths": {"articulation": [str(p) for p in filepaths]},
358
+ }
359
+
360
+ dataset_config = DatasetConfig.parse(
361
+ shuffle=False,
362
+ batch_size=1,
363
+ num_workers=get_dataloader_workers(),
364
+ pin_memory=get_dataloader_workers() > 0,
365
+ persistent_workers=False,
366
+ datapath=datapath,
367
+ ).split_by_cls()
368
+
369
+ module = RigDatasetModule(
370
+ predict_dataset_config=dataset_config,
371
+ predict_transform=transform,
372
+ tokenizer=tokenizer,
373
+ process_fn=model._process_fn,
374
+ )
375
+
376
+ dataloader = module.predict_dataloader()["articulation"]
377
+
378
+ results_out = []
379
+ infer_device = model.device if model is not None else "cuda"
380
+
381
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
382
+ batch = {
383
+ k: v.to(infer_device) if isinstance(v, Tensor) else v
384
+ for k, v in batch.items()
385
+ }
386
+
387
+ if not use_skeleton:
388
+ batch.pop("skeleton_tokens", None)
389
+ batch.pop("skeleton_mask", None)
390
+
391
+ batch["generate_kwargs"] = dict(
392
+ max_length=2048,
393
+ top_k=int(top_k),
394
+ top_p=float(top_p),
395
+ temperature=float(temperature),
396
+ repetition_penalty=float(repetition_penalty),
397
+ num_return_sequences=1,
398
+ num_beams=int(num_beams),
399
+ do_sample=True,
400
+ )
401
+
402
+ if "skeleton_tokens" in batch and "skeleton_mask" in batch:
403
+ mask = batch["skeleton_mask"][0] == 1
404
+ skeleton_tokens = batch["skeleton_tokens"][0][mask].cpu().numpy()
405
+ else:
406
+ skeleton_tokens = None
407
+
408
+ preds: List[TokenRigResult] = model.predict_step(
409
+ batch,
410
+ skeleton_tokens=[skeleton_tokens] if skeleton_tokens is not None else None,
411
+ make_asset=True,
412
+ )["results"]
413
+
414
+ asset = preds[0].asset
415
+ assert asset is not None
416
+
417
+ if use_postprocess:
418
+ voxel = asset.voxel(resolution=196)
419
+ asset.skin *= voxel_skin(
420
+ grid=0,
421
+ grid_coords=voxel.coords,
422
+ joints=asset.joints,
423
+ vertices=asset.vertices,
424
+ faces=asset.faces,
425
+ mode="square",
426
+ voxel_size=voxel.voxel_size,
427
+ )
428
+ asset.normalize_skin()
429
+
430
+ out_path = output_paths[i]
431
+ out_path.parent.mkdir(parents=True, exist_ok=True)
432
+
433
+ if use_transfer:
434
+ payload = dict(
435
+ source_asset=asset,
436
+ target_path=asset.path,
437
+ export_path=str(out_path),
438
+ group_per_vertex=4,
439
+ )
440
+ res = bytes_to_object(
441
+ requests.post(
442
+ f"{BPY_SERVER}/transfer",
443
+ data=object_to_bytes(payload),
444
+ ).content
445
+ )
446
+ else:
447
+ payload = dict(
448
+ asset=asset,
449
+ filepath=str(out_path),
450
+ group_per_vertex=4,
451
+ )
452
+ res = bytes_to_object(
453
+ requests.post(
454
+ f"{BPY_SERVER}/export",
455
+ data=object_to_bytes(payload),
456
+ ).content
457
+ )
458
+
459
+ if res != "ok":
460
+ print(f"[Error] {res}")
461
+ else:
462
+ print(f"[OK] Exported: {out_path}")
463
+
464
+ results_out.append(out_path)
465
+
466
+ return results_out
467
+
468
+
469
+ # ---------------------------------------------------------------------------
470
+ # CLI entry point.
471
+ # ---------------------------------------------------------------------------
472
+ def run_cli(args):
473
+ input_path = Path(args.input).resolve()
474
+ output_path = Path(args.output).resolve()
475
+
476
+ files = collect_files(input_path)
477
+ if not files:
478
+ raise RuntimeError("No valid 3D files found.")
479
+
480
+ if len(files) == 1 and output_path.suffix:
481
+ outputs = [output_path]
482
+ else:
483
+ outputs = [map_output_path(f, input_path, output_path) for f in files]
484
+
485
+ run_rig(
486
+ files,
487
+ args.top_k,
488
+ args.top_p,
489
+ args.temperature,
490
+ args.repetition_penalty,
491
+ args.num_beams,
492
+ args.use_skeleton,
493
+ args.use_transfer,
494
+ args.use_postprocess,
495
+ outputs,
496
+ args.model_ckpt,
497
+ args.hf_path,
498
+ )
499
+
500
+
501
+ # ---------------------------------------------------------------------------
502
+ # Gradio wrapper (with ZeroGPU duration estimator).
503
+ # ---------------------------------------------------------------------------
504
+ TOT = 0
505
+
506
+
507
+ def _gpu_duration(
508
+ files,
509
+ top_k,
510
+ top_p,
511
+ temperature,
512
+ repetition_penalty,
513
+ num_beams,
514
+ use_skeleton,
515
+ use_transfer,
516
+ use_postprocess,
517
+ model_ckpt,
518
+ hf_path,
519
+ ):
520
+ # Cold workers spend ~30–60 s importing bpy + loading the model before
521
+ # any GPU work. Give every request a generous 240 s floor.
522
+ file_count = len(files) if files is not None else 1
523
+ return min(900, max(240, 240 + 60 * file_count))
524
+
525
+
526
+ @spaces.GPU(duration=_gpu_duration)
527
+ def run_gradio(
528
+ files,
529
+ top_k,
530
+ top_p,
531
+ temperature,
532
+ repetition_penalty,
533
+ num_beams,
534
+ use_skeleton,
535
+ use_transfer,
536
+ use_postprocess,
537
+ model_ckpt,
538
+ hf_path,
539
+ ):
540
+ if not files:
541
+ return "Please upload at least one 3D model.", None
542
+
543
+ tmp_out = Path(tempfile.mkdtemp(prefix="tokenrig_"))
544
+ filepaths = [Path(f.name) for f in files]
545
+ global TOT
546
+ outputs = []
547
+ for filepath in filepaths:
548
+ TOT += 1
549
+ outputs.append(tmp_out / f"res_{TOT}.glb")
550
+
551
+ run_rig(
552
+ filepaths,
553
+ top_k,
554
+ top_p,
555
+ temperature,
556
+ repetition_penalty,
557
+ num_beams,
558
+ use_skeleton,
559
+ use_transfer,
560
+ use_postprocess,
561
+ outputs,
562
+ model_ckpt,
563
+ hf_path,
564
+ )
565
+
566
+ return f"Processed {len(outputs)} models.", [str(p) for p in outputs]
567
+
568
+
569
+ # ---------------------------------------------------------------------------
570
+ # Gradio UI.
571
+ # ---------------------------------------------------------------------------
572
+ def build_gradio_app():
573
+ model_ckpts = MODEL_CKPTS
574
+ hf_paths = HF_PATHS
575
+ default_ckpt = model_ckpts[0] if model_ckpts else ""
576
+ default_hf = hf_paths[0] if hf_paths else "None"
577
+
578
+ with gr.Blocks(title="SkinTokens · TokenRig Demo") as app:
579
+ gr.Markdown(
580
+ """
581
+ ## 🦴 Mesh to Rig with [SkinTokens](https://zjp-shadow.github.io/works/SkinTokens/) · TokenRig
582
+
583
+ Automated **skeleton generation + skinning weight prediction** for any 3D mesh, via a unified
584
+ autoregressive model over learned *SkinTokens*. Successor to
585
+ [UniRig](https://github.com/VAST-AI-Research/UniRig) (SIGGRAPH&nbsp;'25).
586
+
587
+ * Upload one or more meshes → click **Run** → download a rigged `.glb`.
588
+ * **Paper**: [arXiv&nbsp;2602.04805](https://arxiv.org/abs/2602.04805) &nbsp;·&nbsp;
589
+ **Code**: [VAST-AI-Research/SkinTokens](https://github.com/VAST-AI-Research/SkinTokens) &nbsp;·&nbsp;
590
+ **Weights**: [🤗&nbsp;VAST-AI/SkinTokens](https://huggingface.co/VAST-AI/SkinTokens)
591
+ * Looking for **image → rigged 3D** instead? Try our sibling Space
592
+ [🤗&nbsp;VAST-AI/AniGen](https://huggingface.co/spaces/VAST-AI/AniGen).
593
+ * Want a full AI-powered 3D workspace? → [Tripo](https://www.tripo3d.ai)
594
+ """
595
+ )
596
+
597
+ gr.HTML(
598
+ """
599
+ <style>
600
+ @keyframes gentle-pulse {
601
+ 0%, 100% { opacity: 1; }
602
+ 50% { opacity: 0.35; }
603
+ }
604
+ </style>
605
+ <div style="text-align:left; color:#888; font-size:1em; line-height:1.6; margin: 4px 0 -4px 0;">
606
+ <span style="animation: gentle-pulse 3s ease-in-out infinite; display:inline-block;">&#128161; <b>Tips</b></span>&ensp;
607
+ Defaults work well for most meshes.
608
+ &nbsp;• If your mesh already has a skeleton and you only want skinning, enable
609
+ <b>Use existing skeleton</b> below.
610
+ &nbsp;• To keep your original textures and world scale, enable <b>Preserve original texture &amp; scale</b>.
611
+ </div>
612
+ """
613
+ )
614
+
615
+ with gr.Row():
616
+ with gr.Column(scale=1):
617
+ files = gr.File(
618
+ label="3D Models ( .obj / .fbx / .glb, up to a few at a time )",
619
+ file_count="multiple",
620
+ file_types=[".obj", ".fbx", ".glb"],
621
+ )
622
+
623
+ with gr.Accordion("⚙️ Generation Settings", open=False):
624
+ model_ckpt = gr.Dropdown(
625
+ choices=model_ckpts,
626
+ value=default_ckpt,
627
+ label="Model checkpoint",
628
+ info="TokenRig autoregressive rigging model. The default is the GRPO-refined checkpoint recommended for most assets.",
629
+ interactive=True,
630
+ )
631
+ # Keep the hf_path component for callback compatibility, but hide it
632
+ # from the UI since it currently only exposes the default ("None") option.
633
+ hf_path = gr.Dropdown(
634
+ choices=hf_paths,
635
+ value=default_hf,
636
+ label="HF path (advanced)",
637
+ visible=False,
638
+ )
639
+
640
+ gr.Markdown("**Sampling parameters** — control autoregressive decoding of the rig.")
641
+ top_k = gr.Slider(
642
+ 1, 200, value=5, step=1,
643
+ label="top_k",
644
+ info="Sample from the K most likely next tokens at each step. Lower = more deterministic output.",
645
+ )
646
+ top_p = gr.Slider(
647
+ 0.1, 1.0, value=0.95, step=0.01,
648
+ label="top_p (nucleus)",
649
+ info="Sample from the smallest set of tokens whose cumulative probability ≥ p.",
650
+ )
651
+ temperature = gr.Slider(
652
+ 0.1, 2.0, value=1.0, step=0.1,
653
+ label="temperature",
654
+ info="Softmax temperature. <1 sharpens the distribution (more conservative), >1 makes it flatter (more diverse).",
655
+ )
656
+ repetition_penalty = gr.Slider(
657
+ 0.5, 3.0, value=2.0, step=0.1,
658
+ label="repetition_penalty",
659
+ info="Multiplicative penalty on tokens that have already been generated. 1.0 = no penalty.",
660
+ )
661
+ num_beams = gr.Slider(
662
+ 1, 20, value=10, step=1,
663
+ label="num_beams",
664
+ info="Beam-search width. Larger = higher quality but slower; 1 disables beam search.",
665
+ )
666
+
667
+ gr.Markdown("**Pipeline toggles**")
668
+ use_skeleton = gr.Checkbox(
669
+ False,
670
+ label="Use existing skeleton (predict skinning only)",
671
+ info="If the uploaded file already contains a skeleton, keep it and only predict per-vertex skinning weights.",
672
+ )
673
+ use_transfer = gr.Checkbox(
674
+ False,
675
+ label="Preserve original texture & scale",
676
+ info="Transfer the predicted rig back onto the original (unprocessed) mesh, so textures and world units are preserved.",
677
+ )
678
+ use_postprocess = gr.Checkbox(
679
+ False,
680
+ label="Voxel skin post-processing",
681
+ info="Apply a voxel-based mask to the predicted skin weights before normalization. Slower.",
682
+ )
683
+
684
+ run_btn = gr.Button("🚀 Run", variant="primary")
685
+
686
+ with gr.Column(scale=1):
687
+ log = gr.Textbox(label="Status", lines=2, interactive=False)
688
+ output = gr.File(label="Rigged GLB output", interactive=False)
689
+ gr.Markdown(
690
+ """
691
+ **Notes**
692
+ - The output `.glb` contains the predicted **skeleton + skinning weights**. Import it in Blender (File → Import → glTF&nbsp;2.0) or any DCC tool that reads glTF.
693
+ - In Blender, if you see a `glTF_not_exported` placeholder node, you can safely remove it.
694
+ - At busy times, ZeroGPU may queue your request for ~10–30&nbsp;s before inference starts — the status box will update once the GPU is attached.
695
+ - Please do **not** upload confidential or NSFW content. See the
696
+ [project page](https://zjp-shadow.github.io/works/SkinTokens/) for paper-accurate results and the
697
+ [code repo](https://github.com/VAST-AI-Research/SkinTokens) for local / batch inference.
698
+ """
699
+ )
700
+
701
+ run_btn.click(
702
+ run_gradio,
703
+ inputs=[
704
+ files,
705
+ top_k,
706
+ top_p,
707
+ temperature,
708
+ repetition_penalty,
709
+ num_beams,
710
+ use_skeleton,
711
+ use_transfer,
712
+ use_postprocess,
713
+ model_ckpt,
714
+ hf_path,
715
+ ],
716
+ outputs=[log, output],
717
+ )
718
+
719
+ return app
720
+
721
+
722
+ demo = build_gradio_app()
723
+
724
+
725
+ # Note: we do NOT pre-warm `bpy_server` in the main process. `bpy_server.py`
726
+ # transitively imports `src.model.michelangelo.utils.misc`, whose
727
+ # module-level `use_flash3 = FLASH3()` calls `torch.cuda.get_device_name(0)`
728
+ # at import time. That call fails ("RuntimeError: No CUDA GPUs are
729
+ # available") in the main Gradio process on ZeroGPU, where the GPU is only
730
+ # attached inside `@spaces.GPU`-decorated workers. So the bpy_server boot
731
+ # happens on first request, inside the worker.
732
+
733
+
734
+ # ---------------------------------------------------------------------------
735
+ # Entry point.
736
+ # ---------------------------------------------------------------------------
737
+ if __name__ == "__main__":
738
+ parser = argparse.ArgumentParser("TokenRig Demo")
739
+ parser.add_argument("--input", help="Input file or directory")
740
+ parser.add_argument("--output", help="Output file or directory")
741
+
742
+ parser.add_argument("--top_k", type=int, default=5)
743
+ parser.add_argument("--top_p", type=float, default=0.95)
744
+ parser.add_argument("--temperature", type=float, default=1.0)
745
+ parser.add_argument("--repetition_penalty", type=float, default=2.0)
746
+ parser.add_argument("--num_beams", type=int, default=10)
747
+
748
+ parser.add_argument("--use_skeleton", action="store_true")
749
+ parser.add_argument("--use_transfer", action="store_true")
750
+ parser.add_argument("--use_postprocess", action="store_true")
751
+
752
+ parser.add_argument("--model_ckpt", default=MODEL_CKPTS[0] if MODEL_CKPTS else "")
753
+ parser.add_argument("--hf_path", default=None)
754
+
755
+ parser.add_argument("--gradio", action="store_true")
756
+
757
+ args = parser.parse_args()
758
+
759
+ if args.gradio or not args.input:
760
+ demo.queue()
761
+ demo.launch(ssr_mode=False)
762
+ else:
763
+ ensure_bpy_server_started()
764
+ run_cli(args)
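The note near the end of `demo.py` describes a common ZeroGPU pitfall: a module-level CUDA probe (here `use_flash3 = FLASH3()` calling `torch.cuda.get_device_name(0)`) makes the import itself fail in the GPU-less main process. A hedged sketch of the usual workaround, with `use_flash3` recast as a lazy, cached probe; the Hopper name check is an assumption, not the real logic in `misc.py`:

```python
# Illustrative pattern only; not the actual contents of
# src/model/michelangelo/utils/misc.py.
import functools

import torch


@functools.lru_cache(maxsize=1)
def use_flash3() -> bool:
    """Probe the GPU on first call (inside a @spaces.GPU worker) instead of
    at import time, where ZeroGPU exposes no CUDA device."""
    if not torch.cuda.is_available():
        return False
    name = torch.cuda.get_device_name(0)
    # Assumed check: FlashAttention-3 kernels target Hopper-class GPUs.
    return "H100" in name or "H200" in name
```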
download.py ADDED
@@ -0,0 +1,72 @@
1
+ from huggingface_hub import hf_hub_download, snapshot_download
2
+
3
+ import argparse
4
+
5
+ REPO_ID = "VAST-AI/SkinTokens"
6
+
7
+ MODELS = [
8
+ "experiments/skin_vae_2_10_32768/last.ckpt",
9
+ "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
10
+ ]
11
+
12
+ DATASETS = [
13
+ "rignet.zip",
14
+ "articulation.zip",
15
+ ]
16
+
17
+ LLM_REPO = "Qwen/Qwen3-0.6B"
18
+ LLM_LOCAL_DIR = "models/Qwen3-0.6B"
19
+
20
+
21
+ def download_model(name: str):
22
+ local_path = hf_hub_download(
23
+ repo_id=REPO_ID,
24
+ filename=name,
25
+ local_dir=".",
26
+ )
27
+ print(f"[MODEL] {name} downloaded to: {local_path}")
28
+
29
+
30
+ def download_llm():
31
+ local_path = snapshot_download(
32
+ repo_id=LLM_REPO,
33
+ local_dir=LLM_LOCAL_DIR,
34
+ ignore_patterns=["*.bin", "*.safetensors"],
35
+ )
36
+ print(f"[LLM] Config downloaded to: {local_path}")
37
+
38
+
39
+ def download_data(name: str):
40
+ local_path = hf_hub_download(
41
+ repo_id=REPO_ID,
42
+ filename=f"dataset_clean/{name}",
43
+ local_dir=".",
44
+ )
45
+ name = name.removesuffix(".zip")
46
+ local_path = snapshot_download(
47
+ repo_id=REPO_ID,
48
+ allow_patterns=[f"datalist/{name}/*"],
49
+ local_dir=".",
50
+ )
51
+ print(f"[DATA] {name} downloaded to: {local_path}")
52
+
53
+
54
+ def main():
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument("--model", action="store_true", help="Download model checkpoints")
57
+ parser.add_argument("--data", action="store_true", help="Download datasets")
58
+ args = parser.parse_args()
59
+ if not args.model and not args.data:
60
+ print("Please specify --model or --data")
61
+ return
62
+ if args.model:
63
+ for model in MODELS:
64
+ download_model(model)
65
+ download_llm()
66
+ if args.data:
67
+ for data in DATASETS:
68
+ download_data(data)
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
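Because `VAST-AI/SkinTokens` is a private model repo (per the comments in `demo.py`), running this script assumes Hugging Face credentials are already configured, e.g. via `HF_TOKEN` or a prior `huggingface-cli login`. For a single file, the one-off equivalent of `python download.py --model` is just the same `hf_hub_download` call:

```python
# One-off download of a single checkpoint; assumes HF credentials are set up
# for the private VAST-AI/SkinTokens repo.
from huggingface_hub import hf_hub_download

ckpt = hf_hub_download(
    repo_id="VAST-AI/SkinTokens",
    filename="experiments/skin_vae_2_10_32768/last.ckpt",
    local_dir=".",
)
print(ckpt)
```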
requirements.txt ADDED
@@ -0,0 +1,37 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu128
2
+ --extra-index-url https://pypi.org/simple
3
+
4
+ # Pinned to match the flash-attn wheel below (cu12torch2.9cxx11abiTRUE).
5
+ # Don't bump torch without also bumping the flash-attn URL — they must agree on
6
+ # the (cu12 / torch 2.9 / cxx11-abi=TRUE / cp312) tuple.
7
+ torch==2.9.1
8
+ torchvision==0.24.1
9
+ torchaudio==2.9.1
10
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.9cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
11
+ transformers==4.57.0
12
+ diffusers==0.36.0
13
+ python-box
14
+ einops
15
+ omegaconf
16
+ pytorch_lightning
17
+ lightning
18
+ addict
19
+ timm
20
+ fast-simplification
21
+ trimesh
22
+ open3d
23
+ pyrender
24
+ # bpy is NOT listed here. The bpy wheel committed in this repo
25
+ # (`bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl`) is installed at
26
+ # runtime by `demo.py` — see `_ensure_bpy_installed()`. Reason: HF Spaces'
27
+ # Docker build mounts only `requirements.txt` BEFORE the repo COPY, so the
28
+ # wheel path doesn't exist at pip-install time. Public PyPI also has no bpy
29
+ # wheel matching this exact build (`rc0` / manylinux_2_39 / cp312).
30
+ huggingface_hub
31
+ spaces
32
+ wandb
33
+ numpy==2.2.6
34
+ gradio
35
+ bottle
36
+ tornado
37
+ cython
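Since the `torch` pins and the flash-attn wheel URL must agree on the cu12 / torch 2.9 / cp312 tuple, a quick post-install sanity check (a sketch, assuming both packages import cleanly) can catch an accidental version bump:

```python
# Run after `pip install -r requirements.txt` to confirm the torch / flash-attn pairing.
import flash_attn
import torch

print("torch", torch.__version__, "built for CUDA", torch.version.cuda)  # expect 2.9.1 / 12.x
print("flash-attn", flash_attn.__version__)                              # expect 2.8.3
assert torch.__version__.startswith("2.9"), "flash-attn wheel was built against torch 2.9"
```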
runtime.txt ADDED
@@ -0,0 +1 @@
1
+ python-3.11
src/__init__.py ADDED
File without changes
src/data/augment.py ADDED
@@ -0,0 +1,706 @@
1
+ from copy import deepcopy
2
+ from dataclasses import dataclass
3
+ from typing import Tuple, Union, List, Optional, Dict
4
+ from numpy import ndarray
5
+ from abc import ABC, abstractmethod
6
+ from scipy.spatial.transform import Rotation as R
7
+
8
+ import numpy as np
9
+ import random
10
+
11
+ from .spec import ConfigSpec
12
+
13
+ from ..rig_package.utils import axis_angle_to_matrix
14
+ from ..rig_package.info.asset import Asset
15
+
16
+ @dataclass(frozen=True)
17
+ class Augment(ConfigSpec):
18
+
19
+ @classmethod
20
+ @abstractmethod
21
+ def parse(cls, **kwargs) -> 'Augment':
22
+ pass
23
+
24
+ @abstractmethod
25
+ def transform(self, asset: Asset, **kwargs):
26
+ pass
27
+
28
+ @dataclass(frozen=True)
29
+ class AugmentTrim(Augment):
30
+ """randomly delete joints and vertices"""
31
+
32
+ @classmethod
33
+ def parse(cls, **kwargs) -> 'AugmentTrim':
34
+ cls.check_keys(kwargs)
35
+ return AugmentTrim()
36
+
37
+ def transform(self, asset: Asset, **kwargs):
38
+ asset.trim_skeleton()
39
+
40
+ @dataclass(frozen=True)
41
+ class AugmentDelete(Augment):
42
+ """randomly delete joints and vertices"""
43
+
44
+ # probability
45
+ p: float
46
+
47
+ # how much to keep
48
+ rate: float
49
+
50
+ @classmethod
51
+ def parse(cls, **kwargs) -> 'AugmentDelete':
52
+ cls.check_keys(kwargs)
53
+ return AugmentDelete(
54
+ p=kwargs.get('p', 0.),
55
+ rate=kwargs.get('rate', 0.5),
56
+ )
57
+
58
+ def transform(self, asset: Asset, **kwargs):
59
+ if asset.skin is None:
60
+ raise ValueError("do not have skin")
61
+ if asset.parents is None:
62
+ raise ValueError("do not have parents")
63
+ asset.normalize_skin()
64
+ def select_k(arr: List, k: int):
65
+ if len(arr) <= k:
66
+ return arr
67
+ else:
68
+ rest_indices = list(range(1, len(arr)))
69
+ selected_indices = sorted(random.sample(rest_indices, k))
70
+ return [arr[i] for i in selected_indices]
71
+ if np.random.rand() >= self.p:
72
+ return
73
+ ids = select_k([i for i in range(asset.J)], max(int(asset.J * (1 - np.random.rand() * self.rate)), 1))
74
+ if len(ids) == 0:
75
+ return
76
+ # keep bones with no skin
77
+ keep = {}
78
+ for id in ids:
79
+ keep[id] = True
80
+ for id in range(asset.J):
81
+ if np.all(asset.skin[:, id] < 0.1):
82
+ keep[id] = True
83
+ keep[asset.root] = True
84
+
85
+ vertices_to_delete = np.zeros(asset.N, dtype=bool)
86
+ for id in range(asset.J):
87
+ if id not in keep:
88
+ dominant = asset.skin.argmax(axis=1) == id
89
+ x = (asset.skin[:, id] > 0.1) & dominant
90
+ if np.all(~x) or x.sum() * asset.J < asset.N: # avoid collapsing
91
+ keep[id] = 1
92
+ continue
93
+ vertices_to_delete[x] = True
94
+ if np.all(vertices_to_delete):
95
+ return
96
+ if asset.faces is not None:
97
+ indices = np.where(~vertices_to_delete)[0]
98
+ face_mask = np.all(np.isin(asset.faces, indices), axis=1)
99
+ if np.all(~face_mask):
100
+ return
101
+
102
+ joints_to_delete: List[int|str] = [i for i in range(asset.J) if i not in keep]
103
+ asset.delete_joints(joints_to_delete)
104
+ asset.delete_vertices(np.arange(asset.N)[vertices_to_delete])
105
+
106
+ @dataclass(frozen=True)
107
+ class AugmentDropPart(Augment):
108
+ """randomly drop subtrees and their vertices"""
109
+
110
+ # probability
111
+ p: float
112
+
113
+ # drop rate
114
+ rate: float
115
+
116
+ @classmethod
117
+ def parse(cls, **kwargs) -> 'AugmentDropPart':
118
+ cls.check_keys(kwargs)
119
+ return AugmentDropPart(
120
+ p=kwargs.get('p', 0.),
121
+ rate=kwargs.get('rate', 0.5),
122
+ )
123
+
124
+ def transform(self, asset: Asset, **kwargs):
125
+ if np.random.rand() >= self.p:
126
+ return
127
+ if asset.parents is None:
128
+ raise ValueError("do not have parents")
129
+ if asset.skin is None:
130
+ raise ValueError("do not have skin")
131
+ keep = []
132
+ for id in range(asset.J):
133
+ if np.random.rand() < self.rate:
134
+ keep.append(id)
135
+ if len(keep) == 0:
136
+ return
137
+ for id in reversed(asset.dfs_order):
138
+ p = asset.parents[id]
139
+ if p == -1:
140
+ continue
141
+ if id in keep and p not in keep:
142
+ keep.append(p)
143
+
144
+ mask = np.zeros(asset.N, dtype=bool)
145
+ for id in keep:
146
+ mask[asset.skin[:, id] > 1e-5] = True
147
+ vertices_to_delete = ~mask
148
+ if np.all(vertices_to_delete):
149
+ return
150
+ if asset.faces is not None:
151
+ indices = np.where(~vertices_to_delete)[0]
152
+ face_mask = np.all(np.isin(asset.faces, indices), axis=1)
153
+ if np.all(~face_mask):
154
+ return
155
+
156
+ joints_to_delete: List[int|str] = [i for i in range(asset.J) if i not in keep]
157
+ asset.delete_joints(joints_to_delete)
158
+ asset.delete_vertices(np.arange(asset.N)[vertices_to_delete])
159
+
160
+ def inverse(self, asset: Asset):
161
+ pass
162
+
163
+ @dataclass(frozen=True)
164
+ class AugmentCollapse(Augment):
165
+ """randomly merge joints"""
166
+
167
+ # collapse the skeleton with probability p
168
+ p: float
169
+
170
+ # probability to merge the bone
171
+ rate: float
172
+
173
+ # max bones
174
+ max_bones: int
175
+
176
+ @classmethod
177
+ def parse(cls, **kwargs) -> 'AugmentCollapse':
178
+ cls.check_keys(kwargs)
179
+ return AugmentCollapse(
180
+ p=kwargs.get('p', 0.),
181
+ rate=kwargs.get('rate', 0.),
182
+ max_bones=kwargs.get('max_bones', 2147483647),
183
+ )
184
+
185
+ def transform(self, asset: Asset, **kwargs):
186
+ def select_k(arr: List, k: int):
187
+ if len(arr) <= k:
188
+ return arr
189
+ else:
190
+ rest_indices = list(range(1, len(arr)))
191
+ selected_indices = sorted(random.sample(rest_indices, k))
192
+ return [arr[i] for i in selected_indices]
193
+
194
+ root = asset.root
195
+ if np.random.rand() < self.p:
196
+ ids = []
197
+ for id in range(asset.J):
198
+ if np.random.rand() >= self.rate:
199
+ ids.append(id)
200
+ if root not in ids:
201
+ ids.append(root)
202
+ keep: List[int|str] = select_k([i for i in range(asset.J) if i in ids], self.max_bones)
203
+ if root not in keep:
204
+ keep[0] = root
205
+ asset.set_order(new_orders=keep)
206
+ elif asset.J > self.max_bones:
207
+ ids = select_k([i for i in range(asset.J)], k=self.max_bones)
208
+ if root not in ids:
209
+ ids[0] = root
210
+ keep: List[int|str] = [i for i in range(asset.J) if i in ids]
211
+ asset.set_order(new_orders=keep)
212
+
213
+ @dataclass(frozen=True)
214
+ class AugmentJointDiscrete(Augment):
215
+ # perturb the skeleton with probability p
216
+ p: float
217
+
218
+ # num of discretized coord
219
+ discrete: int
220
+
221
+ # continuous range
222
+ continuous_range: Tuple[float, float]
223
+
224
+ @classmethod
225
+ def parse(cls, **kwargs) -> 'AugmentJointDiscrete':
226
+ cls.check_keys(kwargs)
227
+ return AugmentJointDiscrete(
228
+ p=kwargs.get('p', 0.),
229
+ discrete=kwargs.get('discrete', 256),
230
+ continuous_range=kwargs.get('continuous_range', [-1., 1.]),
231
+ )
232
+
233
+ def _discretize(
234
+ self,
235
+ t: ndarray,
236
+ continuous_range: Tuple[float, float],
237
+ num_discrete: int,
238
+ ) -> ndarray:
239
+ lo, hi = continuous_range
240
+ assert hi >= lo
241
+ t = (t - lo) / (hi - lo)
242
+ t *= num_discrete
243
+ return np.clip(t.round(), 0, num_discrete - 1).astype(np.int64)
244
+
245
+ def _undiscretize(
246
+ self,
247
+ t: ndarray,
248
+ continuous_range: Tuple[float, float],
249
+ num_discrete: int,
250
+ ) -> ndarray:
251
+ lo, hi = continuous_range
252
+ assert hi >= lo
253
+ t = t.astype(np.float32) + 0.5
254
+ t /= num_discrete
255
+ return t * (hi - lo) + lo
256
+
257
+ def transform(self, asset: Asset, **kwargs):
258
+ if np.random.rand() < self.p:
259
+ joints = asset.joints
260
+ if joints is not None and asset.matrix_local is not None:
261
+ joints = self._undiscretize(self._discretize(
262
+ joints,
263
+ self.continuous_range,
264
+ self.discrete,
265
+ ),
266
+ self.continuous_range,
267
+ self.discrete,
268
+ )
269
+ asset.matrix_local[:, :3, 3] = joints
270
+
271
+ @dataclass(frozen=True)
272
+ class AugmentJointPerturb(Augment):
273
+ # perturb the skeleton with probability p
274
+ p: float
275
+
276
+ # jitter sigma on joints
277
+ sigma: float
278
+
279
+ # jitter clip on joints
280
+ clip: float
281
+
282
+ @classmethod
283
+ def parse(cls, **kwargs) -> 'AugmentJointPerturb':
284
+ cls.check_keys(kwargs)
285
+ return AugmentJointPerturb(
286
+ p=kwargs.get('p', 0.),
287
+ sigma=kwargs.get('sigma', 0.),
288
+ clip=kwargs.get('clip', 0.),
289
+ )
290
+
291
+ def transform(self, asset: Asset, **kwargs):
292
+ if np.random.rand() < self.p and asset.matrix_local is not None:
293
+ asset.matrix_local[:, :3, 3] += np.clip(
294
+ np.random.normal(0, self.sigma, (asset.J, 3)),
295
+ -self.clip,
296
+ self.clip,
297
+ )
298
+
299
+ @dataclass(frozen=True)
300
+ class AugmentLBS(Augment):
301
+ # apply a random pose with probability p
302
+ random_pose_p: float
303
+
304
+ # random pose angle range
305
+ random_pose_angle: float
306
+
307
+ # random scale
308
+ random_scale_range: Tuple[float, float]
309
+
310
+ @classmethod
311
+ def parse(cls, **kwargs) -> 'AugmentLBS':
312
+ cls.check_keys(kwargs)
313
+ return AugmentLBS(
314
+ random_pose_p=kwargs.get('random_pose_p', 0.),
315
+ random_pose_angle=kwargs.get('random_pose_angle', 0.),
316
+ random_scale_range=kwargs.get('random_scale_range', (1., 1.)),
317
+ )
318
+
319
+ def _apply(self, v: ndarray, trans: ndarray) -> ndarray:
320
+ return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
321
+
322
+ def transform(self, asset: Asset, **kwargs):
323
+ def get_matrix_basis(angle: float):
324
+ matrix = axis_angle_to_matrix((np.random.rand(asset.J, 3) - 0.5) * angle / 180 * np.pi * 2).astype(np.float32)
325
+ return matrix
326
+
327
+ if np.random.rand() < self.random_pose_p and asset.joints is not None:
328
+ matrix_basis = get_matrix_basis(self.random_pose_angle)
329
+ max_offset = (asset.joints.max(axis=0) - asset.joints.min(axis=0)).max()
330
+ matrix_basis[:, :3, :3] *= np.tile(np.random.uniform(low=self.random_scale_range[0], high=self.random_scale_range[1], size=(asset.J, 1, 1)), (1, 3, 3))
331
+ asset.vertices_with_pose(matrix_basis=matrix_basis, inplace=True)
332
+
333
+ @dataclass(frozen=True)
334
+ class AugmentLinear(Augment):
335
+ # apply random rotation with probability p
336
+ random_rotate_p: float
337
+
338
+ # random rotation angle(degree)
339
+ random_rotate_angle: float
340
+
341
+ # swap x with probability p
342
+ random_flip_x_p: float
343
+
344
+ # swap y with probability p
345
+ random_flip_y_p: float
346
+
347
+ # swap z with probability p
348
+ random_flip_z_p: float
349
+
350
+ # probability to pick an angle in static_rotate_x
351
+ static_rotate_x_p: float
352
+
353
+ # rotate around x axis among given angles(degrees)
354
+ static_rotate_x: List[float]
355
+
356
+ # probability to pick an angle in static_rotate_y
357
+ static_rotate_y_p: float
358
+
359
+ # rotate around y axis among given angles(degrees)
360
+ static_rotate_y: List[float]
361
+
362
+ # probability to pick an angle in static_rotate_z
363
+ static_rotate_z_p: float
364
+
365
+ # rotate around z axis among given angles(degrees)
366
+ static_rotate_z: List[float]
367
+
368
+ # apply random scaling with probability p
369
+ random_scale_p: float
370
+
371
+ # random scaling xyz axis
372
+ random_scale: Tuple[float, float]
373
+
374
+ # randomly change xyz orientation
375
+ random_transpose: float
376
+
377
+ @classmethod
378
+ def parse(cls, **kwargs) -> 'AugmentLinear':
379
+ if kwargs.get('random_flip_x_p', 0) > 0 or kwargs.get('random_flip_y_p', 0) > 0 or kwargs.get('random_flip_z_p', 0) > 0:
380
+ print("\033[31mWARNING: random flip is enabled and is very likely to confuse ar model !\033[0m")
381
+ cls.check_keys(kwargs)
382
+ return AugmentLinear(
383
+ random_rotate_p=kwargs.get('random_rotate_p', 0.),
384
+ random_rotate_angle=kwargs.get('random_rotate_angle', 0.),
385
+ random_flip_x_p=kwargs.get('random_flip_x_p', 0.),
386
+ random_flip_y_p=kwargs.get('random_flip_y_p', 0.),
387
+ random_flip_z_p=kwargs.get('random_flip_z_p', 0.),
388
+ static_rotate_x_p=kwargs.get('static_rotate_x_p', 0.),
389
+ static_rotate_x=kwargs.get('static_rotate_x', []),
390
+ static_rotate_y_p=kwargs.get('static_rotate_y_p', 0.),
391
+ static_rotate_y=kwargs.get('static_rotate_y', []),
392
+ static_rotate_z_p=kwargs.get('static_rotate_z_p', 0.),
393
+ static_rotate_z=kwargs.get('static_rotate_z', []),
394
+ random_scale_p=kwargs.get('random_scale_p', 0.),
395
+ random_scale=kwargs.get('random_scale', [1.0, 1.0]),
396
+ random_transpose=kwargs.get('random_transpose', 0.),
397
+ )
398
+
399
+ def _apply(self, v: ndarray, trans: ndarray) -> ndarray:
400
+ return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
401
+
402
+ def transform(self, asset: Asset, **kwargs):
403
+ trans_vertex = np.eye(4, dtype=np.float32)
404
+ r = np.eye(4, dtype=np.float32)
405
+ if np.random.rand() < self.random_rotate_p:
406
+ angle = self.random_rotate_angle
407
+ axis_angle = (np.random.rand(3) - 0.5) * angle / 180 * np.pi * 2
408
+ r = R.from_rotvec(axis_angle).as_matrix()
409
+ r = np.pad(r, ((0, 1), (0, 1)), 'constant', constant_values=0.)
410
+ r[3, 3] = 1.
411
+
412
+ if np.random.uniform(0, 1) < self.random_flip_x_p:
413
+ r @= np.array([
414
+ [-1.0, 0.0, 0.0, 0.0],
415
+ [ 0.0, 1.0, 0.0, 0.0],
416
+ [ 0.0, 0.0, 1.0, 0.0],
417
+ [ 0.0, 0.0, 0.0, 1.0],
418
+ ])
419
+
420
+ if np.random.uniform(0, 1) < self.random_flip_y_p:
421
+ r @= np.array([
422
+ [1.0, 0.0, 0.0, 0.0],
423
+ [0.0, -1.0, 0.0, 0.0],
424
+ [0.0, 0.0, 1.0, 0.0],
425
+ [0.0, 0.0, 0.0, 1.0],
426
+ ])
427
+
428
+ if np.random.uniform(0, 1) < self.random_flip_z_p:
429
+ r @= np.array([
430
+ [1.0, 0.0, 0.0, 0.0],
431
+ [0.0, 1.0, 0.0, 0.0],
432
+ [0.0, 0.0, -1.0, 0.0],
433
+ [0.0, 0.0, 0.0, 1.0],
434
+ ])
435
+
436
+ if np.random.uniform(0, 1) < self.static_rotate_x_p:
437
+ assert len(self.static_rotate_x) > 0, "static rotation of x is enabled, but static_rotate_x is empty"
438
+ angle = np.random.choice(self.static_rotate_x) / 180 * np.pi
439
+ c = np.cos(angle)
440
+ s = np.sin(angle)
441
+ r @= np.array([
442
+ [ 1.0, 0.0, 0.0, 0.0],
443
+ [ 0.0, c, s, 0.0],
444
+ [ 0.0, -s, c, 0.0],
445
+ [ 0.0, 0.0, 0.0, 1.0],
446
+ ])
447
+
448
+ if np.random.uniform(0, 1) < self.static_rotate_y_p:
449
+ assert len(self.static_rotate_y) > 0, "static rotation of y is enabled, but static_rotate_y is empty"
450
+ angle = np.random.choice(self.static_rotate_y) / 180 * np.pi
451
+ c = np.cos(angle)
452
+ s = np.sin(angle)
453
+ r @= np.array([
454
+ [ c, 0.0, -s, 0.0],
455
+ [ 0.0, 1.0, 0.0, 0.0],
456
+ [ s, 0.0, c, 0.0],
457
+ [ 0.0, 0.0, 0.0, 1.0],
458
+ ])
459
+
460
+ if np.random.uniform(0, 1) < self.static_rotate_z_p:
461
+ assert len(self.static_rotate_z) > 0, "static rotation of z is enabled, but static_rotate_z is empty"
462
+ angle = np.random.choice(self.static_rotate_z) / 180 * np.pi
463
+ c = np.cos(angle)
464
+ s = np.sin(angle)
465
+ r @= np.array([
466
+ [ c, s, 0.0, 0.0],
467
+ [ -s, c, 0.0, 0.0],
468
+ [ 0.0, 0.0, 1.0, 0.0],
469
+ [ 0.0, 0.0, 0.0, 1.0],
470
+ ])
471
+
472
+ if np.random.uniform(0, 1) < self.random_scale_p:
473
+ scale_x = np.random.uniform(self.random_scale[0], self.random_scale[1])
474
+ scale_y = np.random.uniform(self.random_scale[0], self.random_scale[1])
475
+ scale_z = np.random.uniform(self.random_scale[0], self.random_scale[1])
476
+ r @= np.array([
477
+ [scale_x, 0.0, 0.0, 0.0],
478
+ [0.0, scale_y, 0.0, 0.0],
479
+ [0.0, 0.0, scale_z, 0.0],
480
+ [0.0, 0.0, 0.0, 1.0],
481
+ ])
482
+
483
+ if np.random.uniform(0, 1) < self.random_transpose:
484
+ permutations = [
485
+ (0, 1, 2), # x, y, z
486
+ (0, 2, 1), # x, z, y
487
+ (1, 0, 2), # y, x, z
488
+ (1, 2, 0), # y, z, x
489
+ (2, 0, 1), # z, x, y
490
+ (2, 1, 0), # z, y, x
491
+ ]
492
+ direction_signs = [
493
+ (1, 1, 1),
494
+ (1, 1, -1),
495
+ (1, -1, 1),
496
+ (1, -1, -1),
497
+ (-1, 1, 1),
498
+ (-1, 1, -1),
499
+ (-1, -1, 1),
500
+ (-1, -1, -1),
501
+ ]
502
+ perm = permutations[np.random.randint(0, 6)]
503
+ sign = direction_signs[np.random.randint(0, 8)]
504
+ m = np.zeros((4, 4))
505
+ for i in range(3):
506
+ m[i, perm[i]] = sign[i]
507
+ m[3, 3] = 1.0
508
+ r = m @ r
509
+
510
+ trans_vertex = r @ trans_vertex
511
+
512
+ # apply transform here
513
+ asset.transform(trans=trans_vertex)
514
+
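For reference, the rotation branch above builds a random axis-angle rotation, embeds it in a homogeneous 4x4 matrix, and later applies it with a rotate-then-translate product. A minimal standalone sketch of that composition (NumPy/SciPy only, independent of the `Asset` class):

```python
import numpy as np
from scipy.spatial.transform import Rotation as R

def random_rotation_matrix(max_angle_deg: float) -> np.ndarray:
    # random axis-angle vector with each component in [-max_angle_deg, max_angle_deg]
    axis_angle = (np.random.rand(3) - 0.5) * max_angle_deg / 180 * np.pi * 2
    m = np.eye(4, dtype=np.float32)
    m[:3, :3] = R.from_rotvec(axis_angle).as_matrix()
    return m

points = np.random.rand(100, 3).astype(np.float32)
trans = random_rotation_matrix(30.0)
# same convention as _apply: row vectors, rotate then translate
rotated = points @ trans[:3, :3].T + trans[:3, 3]
```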
515
+ @dataclass(frozen=True)
516
+ class AugmentAffine(Augment):
517
+ # final normalization cube
518
+ normalize_into: Tuple[float, float]
519
+
520
+ # randomly scale coordinates with probability p
521
+ random_scale_p: float
522
+
523
+ # scale range (lower, upper)
524
+ random_scale: Tuple[float, float]
525
+
526
+ # randomly shift coordinates with probability p
527
+ random_shift_p: float
528
+
529
+ # shift range (lower, upper)
530
+ random_shift: Tuple[float, float]
531
+
532
+ @classmethod
533
+ def parse(cls, **kwargs) -> 'AugmentAffine':
534
+ cls.check_keys(kwargs)
535
+ return AugmentAffine(
536
+ normalize_into=kwargs.get('normalize_into', [-1.0, 1.0]),
537
+ random_scale_p=kwargs.get('random_scale_p', 0.),
538
+ random_scale=kwargs.get('random_scale', [1., 1.]),
539
+ random_shift_p=kwargs.get('random_shift_p', 0.),
540
+ random_shift=kwargs.get('random_shift', [0., 0.]),
541
+ )
542
+
543
+ def transform(self, asset: Asset, **kwargs):
544
+ if asset.vertices is None:
545
+ raise ValueError("do not have vertices")
546
+ bound_min = asset.vertices.min(axis=0)
547
+ bound_max = asset.vertices.max(axis=0)
548
+ if asset.joints is not None:
549
+ joints_bound_min = asset.joints.min(axis=0)
550
+ joints_bound_max = asset.joints.max(axis=0)
551
+ bound_min = np.minimum(bound_min, joints_bound_min)
552
+ bound_max = np.maximum(bound_max, joints_bound_max)
553
+
554
+ trans_vertex = np.eye(4, dtype=np.float32)
555
+
556
+ trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex
557
+
558
+ if self.normalize_into is not None:
559
+ # scale into the cube
560
+ normalize_into = self.normalize_into
561
+ scale = np.max((bound_max - bound_min) / (normalize_into[1] - normalize_into[0]))
562
+ trans_vertex = _scale_to_m(1. / scale) @ trans_vertex
563
+
564
+ bias = (normalize_into[0] + normalize_into[1]) / 2
565
+ trans_vertex = _trans_to_m(np.array([bias, bias, bias], dtype=np.float32)) @ trans_vertex
566
+
567
+ if np.random.rand() < self.random_scale_p:
568
+ scale = _scale_to_m(np.random.uniform(self.random_scale[0], self.random_scale[1]))
569
+ trans_vertex = scale @ trans_vertex
570
+
571
+ if np.random.rand() < self.random_shift_p:
572
+ l, r = self.random_shift
573
+ shift_vals = np.array([
574
+ np.random.uniform(l, r),
575
+ np.random.uniform(l, r),
576
+ np.random.uniform(l, r),
577
+ ], dtype=np.float32)
578
+ if self.normalize_into is not None:
579
+ def _apply(v: ndarray, trans: ndarray) -> ndarray:
580
+ return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]
581
+ lo, hi = self.normalize_into
582
+ pts_min = _apply(bound_min[None, :], trans_vertex)[0]
583
+ pts_max = _apply(bound_max[None, :], trans_vertex)[0]
584
+ low_allowed = lo - pts_min
585
+ high_allowed = hi - pts_max
586
+ shift_vals = np.array([
587
+ np.random.uniform(low_allowed[0], high_allowed[0]),
588
+ np.random.uniform(low_allowed[1], high_allowed[1]),
589
+ np.random.uniform(low_allowed[2], high_allowed[2]),
590
+ ], dtype=np.float32)
591
+ shift = _trans_to_m(shift_vals.astype(np.float32))
592
+ trans_vertex = shift @ trans_vertex
593
+ asset.transform(trans=trans_vertex)
594
+
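The normalization above centers the combined vertex/joint bounding box at the origin and uniformly scales its largest extent into `normalize_into`. A small numeric sketch of the same arithmetic:

```python
import numpy as np

vertices = np.array([[0.0, 0.0, 0.0],
                     [2.0, 1.0, 4.0]], dtype=np.float32)
lo, hi = -1.0, 1.0  # normalize_into

bound_min = vertices.min(axis=0)
bound_max = vertices.max(axis=0)
center = (bound_max + bound_min) / 2                 # [1.0, 0.5, 2.0]
scale = np.max((bound_max - bound_min) / (hi - lo))  # 4.0 / 2.0 = 2.0
bias = (lo + hi) / 2                                 # 0.0

normalized = (vertices - center) / scale + bias
# the longest axis (z) now spans exactly [-1, 1]; the others stay centered inside the cube
```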
595
+ @dataclass(frozen=True)
596
+ class AugmentJitter(Augment):
597
+ # probability
598
+ p: float
599
+
600
+ # jitter sigma on vertices
601
+ vertex_sigma: float
602
+
603
+ # jitter clip on vertices
604
+ vertex_clip: float
605
+
606
+ # jitter sigma on normals
607
+ normal_sigma: float
608
+
609
+ # jitter clip on normals
610
+ normal_clip: float
611
+
612
+ @classmethod
613
+ def parse(cls, **kwargs) -> 'AugmentJitter':
614
+ cls.check_keys(kwargs)
615
+ return AugmentJitter(
616
+ p=kwargs.get('p', 0.5),
617
+ vertex_sigma=kwargs.get('vertex_sigma', 0.),
618
+ vertex_clip=kwargs.get('vertex_clip', 0.),
619
+ normal_sigma=kwargs.get('normal_sigma', 0.),
620
+ normal_clip=kwargs.get('normal_clip', 0.),
621
+ )
622
+
623
+ def transform(self, asset: Asset, **kwargs):
624
+ vertex_sigma = self.vertex_sigma
625
+ vertex_clip = self.vertex_clip
626
+ normal_sigma = self.normal_sigma
627
+ normal_clip = self.normal_clip
628
+
629
+ if np.random.rand() < self.p:
630
+ scale = np.random.rand() + 1e-6
631
+ vertex_sigma *= scale
632
+ vertex_clip *= scale
633
+ scale = np.random.rand() + 1e-6
634
+ normal_sigma *= scale
635
+ normal_clip *= scale
636
+ if vertex_sigma > 0 and asset.vertices is not None:
637
+ noise = np.clip(np.random.randn(*asset.vertices.shape) * vertex_sigma, -vertex_clip, vertex_clip).astype(np.float32)
638
+ asset.vertices += noise
639
+
640
+ if normal_sigma > 0:
641
+ if asset.vertex_normals is not None:
642
+ noise = np.clip(np.random.randn(*asset.vertex_normals.shape) * normal_sigma, -normal_clip, normal_clip).astype(np.float32)
643
+ asset.vertex_normals += noise
644
+
645
+ if asset.face_normals is not None:
646
+ noise = np.clip(np.random.randn(*asset.face_normals.shape) * normal_sigma, -normal_clip, normal_clip).astype(np.float32)
647
+ asset.face_normals += noise
648
+
649
+ @dataclass(frozen=True)
650
+ class AugmentNormalize(Augment):
651
+
652
+ @classmethod
653
+ def parse(cls, **kwargs) -> 'AugmentNormalize':
654
+ cls.check_keys(kwargs)
655
+ return AugmentNormalize()
656
+
657
+ def transform(self, asset: Asset, **kwargs):
658
+ epsilon = 1e-10
659
+ if asset.vertex_normals is not None:
660
+ vertex_norms = np.linalg.norm(asset.vertex_normals, axis=1, keepdims=True)
661
+ vertex_norms = np.maximum(vertex_norms, epsilon)
662
+ asset.vertex_normals = asset.vertex_normals / vertex_norms
663
+ asset.vertex_normals = np.nan_to_num(asset.vertex_normals, nan=0., posinf=0., neginf=0.) # type: ignore
664
+
665
+ if asset.face_normals is not None:
666
+ face_norms = np.linalg.norm(asset.face_normals, axis=1, keepdims=True)
667
+ face_norms = np.maximum(face_norms, epsilon)
668
+ asset.face_normals = asset.face_normals / face_norms
669
+ asset.face_normals = np.nan_to_num(asset.face_normals, nan=0., posinf=0., neginf=0.) # type: ignore
670
+
671
+ def _trans_to_m(v: ndarray):
672
+ m = np.eye(4, dtype=np.float32)
673
+ m[0:3, 3] = v
674
+ return m
675
+
676
+ def _scale_to_m(r: ndarray|float):
677
+ m = np.zeros((4, 4), dtype=np.float32)
678
+ m[0, 0] = r
679
+ m[1, 1] = r
680
+ m[2, 2] = r
681
+ m[3, 3] = 1.
682
+ return m
683
+
684
+ def get_augments(*args) -> List[Augment]:
685
+ MAP = {
686
+ 'trim': AugmentTrim,
687
+ 'delete': AugmentDelete,
688
+ 'drop_part': AugmentDropPart,
689
+ 'collapse': AugmentCollapse,
690
+ 'lbs': AugmentLBS,
691
+ 'linear': AugmentLinear,
692
+ 'affine': AugmentAffine,
693
+ 'jitter': AugmentJitter,
694
+ 'joint_perturb': AugmentJointPerturb,
695
+ 'joint_discrete': AugmentJointDiscrete,
696
+ 'normalize': AugmentNormalize,
697
+ }
698
+ MAP: Dict[str, type[Augment]]
699
+ augments = []
700
+ for (i, config) in enumerate(args):
701
+ __target__ = config.get('__target__')
702
+ assert __target__ is not None, f"do not find `__target__` in augment of position {i}"
703
+ c = deepcopy(config)
704
+ del c['__target__']
705
+ augments.append(MAP[__target__].parse(**c))
706
+ return augments
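As a usage sketch, an augment pipeline is assembled from a list of configs keyed by `__target__`; the probabilities and ranges below are illustrative placeholders, not values taken from the release configs:

```python
from src.data.augment import get_augments  # assuming the repository root is on sys.path

augment_configs = [
    {'__target__': 'linear', 'random_rotate_p': 0.5, 'random_rotate_angle': 15.0},
    {'__target__': 'affine', 'normalize_into': [-1.0, 1.0]},
    {'__target__': 'normalize'},
]
augments = get_augments(*augment_configs)
# each Augment is then applied in order, e.g. inside Transform.apply:
#   for augment in augments:
#       augment.transform(asset=asset)
```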
src/data/datapath.py ADDED
@@ -0,0 +1,344 @@
1
+ from abc import abstractmethod, ABC
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass, field
4
+ from numpy import ndarray
5
+ from random import shuffle
6
+ from typing import Dict, List, Optional
7
+
8
+ import numpy as np
9
+ import requests
10
+ import os
11
+
12
+ from ..rig_package.info.asset import Asset
13
+ from ..server.spec import BPY_SERVER, bytes_to_object, object_to_bytes
14
+ from .spec import ConfigSpec
15
+
16
+ @dataclass
17
+ class LazyAsset(ABC):
18
+ """store the data path and load the asset on demand"""
19
+ path: str
20
+
21
+ cls: Optional[str]=None
22
+
23
+ @abstractmethod
24
+ def load(self) -> 'Asset':
25
+ raise NotImplementedError()
26
+
27
+ @dataclass
28
+ class BpyLazyAsset(LazyAsset):
29
+
30
+ def load(self) -> 'Asset':
31
+ from ..rig_package.parser.bpy import BpyParser
32
+ asset = BpyParser.load(filepath=self.path)
33
+ asset.cls = self.cls
34
+ asset.path = self.path
35
+ return asset
36
+
37
+ @dataclass
38
+ class BpyServerLazyAsset(LazyAsset):
39
+ """workaround for loading with bpy when multiple dataloader workers are used"""
40
+ def load(self) -> 'Asset':
41
+ try:
42
+ asset = bytes_to_object(requests.get(f"{BPY_SERVER}/load", data=object_to_bytes(self.path)).content)
43
+ if isinstance(asset, str):
44
+ raise RuntimeError(f"bpy server failed: {asset}")
45
+ assert isinstance(asset, Asset)
46
+ asset.cls = self.cls
47
+ asset.path = self.path
48
+ return asset
49
+ except Exception as e:
50
+ raise RuntimeError(f"bpy server failed: {str(e)}")
51
+
52
+ @dataclass
53
+ class NpzLazyAsset(LazyAsset):
54
+
55
+ def load(self) -> 'Asset':
56
+ d = np.load(self.path, allow_pickle=True)
57
+ asset = Asset(
58
+ vertices=d['vertices'],
59
+ faces=d['faces'],
60
+ mesh_names=d.get('mesh_names', None),
61
+ joint_names=d.get('joint_names', None),
62
+ parents=d.get('parents', None),
63
+ lengths=d.get('lengths', None),
64
+ matrix_world=d.get('matrix_world', None),
65
+ matrix_local=d.get('matrix_local', None),
66
+ armature_name=d.get('armature_name', None),
67
+ skin=d.get('skin', None),
68
+ cls=self.cls,
69
+ path=self.path
70
+ )
71
+ asset.cls = self.cls
72
+ asset.path = self.path
73
+ return asset
74
+
75
+ @dataclass
76
+ class UniRigLazyAsset(LazyAsset):
77
+ """map unirig's data correctly"""
78
+
79
+ def load(self) -> 'Asset':
80
+ def bn(x):
81
+ if isinstance(x, ndarray) and x.ndim==0:
82
+ return x.item()
83
+ return x
84
+
85
+ d = np.load(self.path, allow_pickle=True)
86
+ parents = bn(d.get('parents', None))
87
+ if parents is not None:
88
+ parents = [-1 if x is None else x for x in parents]
89
+ parents = np.array(parents)
90
+ matrix_local = bn(d.get('matrix_local', None))
91
+ joints = bn(d.get('joints', None))
92
+ if matrix_local is not None and matrix_local.ndim != 3 and joints is not None:
93
+ matrix_local = np.zeros((joints.shape[0], 4, 4))
94
+ matrix_local[...] = np.eye(4)
95
+ matrix_local[:, :3, 3] = joints
96
+ asset = Asset(
97
+ vertices=d['vertices'],
98
+ faces=d['faces'],
99
+ joint_names=bn(d.get('names', None)),
100
+ parents=parents, # type: ignore
101
+ lengths=bn(d.get('lengths', None)),
102
+ matrix_world=bn(d.get('matrix_world', None)),
103
+ matrix_local=matrix_local,
104
+ armature_name=bn(d.get('armature_name', None)),
105
+ skin=bn(d.get('skin', None)),
106
+ cls=self.cls,
107
+ path=self.path
108
+ ).change_dtype(float_dtype=np.float32, int_dtype=np.int32)
109
+ asset.cls = self.cls
110
+ asset.path = self.path
111
+ return asset
112
+
113
+ @dataclass
114
+ class Datapath(ConfigSpec):
115
+ """handle input data paths"""
116
+
117
+ # all filepaths
118
+ filepaths: List[str]
119
+
120
+ # root to add to prefix
121
+ input_dataset_dir: str=''
122
+
123
+ # name of class
124
+ cls_name: Optional[List[str]]=None
125
+
126
+ # bias in a single class
127
+ cls_bias: Optional[List[int]]=None
128
+
129
+ # num of files in a single class
130
+ cls_length: Optional[List[int]]=None
131
+
132
+ # how many files to return when using data sampling
133
+ num_files: Optional[int]=None
134
+
135
+ # use proportion data sampling
136
+ use_prob: bool=False
137
+
138
+ # weight
139
+ cls_weight: Optional[List[float]]=None
140
+
141
+ # use bpy loader
142
+ loader: type[LazyAsset]=BpyLazyAsset
143
+
144
+ # data name
145
+ data_name: Optional[str]=None
146
+
147
+ # check if path exists
148
+ ignore_check: bool=False
149
+
150
+ #################################################################
151
+ # other vertex groups
152
+ vertex_groups: Dict[str, ndarray]=field(default_factory=dict)
153
+
154
+ # sampled vertices
155
+ sampled_vertices: Optional[ndarray]=None
156
+
157
+ # sampled normals
158
+ sampled_normals: Optional[ndarray]=None
159
+
160
+ # sampled vertex groups
161
+ sampled_vertex_groups: Optional[Dict[str, ndarray]]=None
162
+
163
+ @classmethod
164
+ def parse(cls, **kwargs) -> 'Datapath':
165
+ MAP = {
166
+ None: BpyLazyAsset,
167
+ 'bpy': BpyLazyAsset,
168
+ 'bpy_server': BpyServerLazyAsset,
169
+ 'npz': NpzLazyAsset,
170
+ 'unirig': UniRigLazyAsset,
171
+ }
172
+ input_dataset_dir = kwargs.get('input_dataset_dir', '')
173
+ num_files = kwargs.get('num_files', None)
174
+ use_prob = kwargs.get('use_prob', False)
175
+ data_name = kwargs.get('data_name', 'raw_data.npz')
176
+ data_path = kwargs.get('data_path', None)
177
+ loader_cls = MAP[kwargs.get('loader', None)]
178
+ ignore_check = kwargs.get('ignore_check', False)
179
+
180
+ if data_path is not None:
181
+ filepaths = []
182
+ if isinstance(data_path, dict):
183
+ cls_name = []
184
+ cls_bias = []
185
+ cls_length = []
186
+ cls_weight = []
187
+ for name, v in data_path.items():
188
+ assert isinstance(v, list), "items in the dict must be a list of data list paths"
189
+ for item in v:
190
+ if isinstance(item, str):
191
+ datalist_path = item
192
+ weight = 1.0
193
+ else:
194
+ datalist_path = item[0]
195
+ weight = item[1]
196
+ cls_name.append(name)
197
+ lines = [x.strip() for x in open(datalist_path, "r").readlines()]
198
+ ok_lines = []
199
+ missing = 0
200
+ for line in lines:
201
+ if ignore_check:
202
+ ok_lines.append(line)
203
+ elif os.path.exists(os.path.join(input_dataset_dir, line, data_name)):
204
+ ok_lines.append(line)
205
+ else:
206
+ missing += 1
207
+ if missing != 0:
208
+ print(f"\033[31m{datalist_path}: {missing} missing files\033[0m")
209
+ cls_bias.append(len(filepaths))
210
+ cls_length.append(len(ok_lines))
211
+ cls_weight.append(weight)
212
+ filepaths.extend(ok_lines)
213
+ else:
214
+ raise NotImplementedError()
215
+ else:
216
+ _filepaths = kwargs['filepaths']
217
+ if isinstance(_filepaths, list):
218
+ filepaths = _filepaths
219
+ cls_name = None
220
+ cls_bias = None
221
+ cls_length = None
222
+ cls_weight = None
223
+ elif isinstance(_filepaths, dict):
224
+ filepaths = []
225
+ cls_name = []
226
+ cls_bias = []
227
+ cls_length = []
228
+ cls_weight = []
229
+ for k, v in _filepaths.items():
230
+ assert isinstance(v, list), "items in the dict must be a list of paths"
231
+ cls_name.append(k)
232
+ cls_bias.append(len(filepaths))
233
+ cls_length.append(len(v))
234
+ cls_weight.append(1.0)
235
+ filepaths.extend(v)
236
+ else:
237
+ raise NotImplementedError()
238
+ if cls_weight is not None:
239
+ total = sum(cls_weight)
240
+ cls_weight = [x/total for x in cls_weight]
241
+ return Datapath(
242
+ filepaths=filepaths,
243
+ input_dataset_dir=input_dataset_dir,
244
+ cls_name=cls_name,
245
+ cls_bias=cls_bias,
246
+ cls_length=cls_length,
247
+ num_files=num_files,
248
+ use_prob=use_prob,
249
+ cls_weight=cls_weight,
250
+ loader=loader_cls,
251
+ data_name=data_name,
252
+ ignore_check=ignore_check,
253
+ )
254
+
255
+ def make(self, path: str, cls: str|None) -> LazyAsset:
256
+ return self.loader(path=path, cls=cls)
257
+
258
+ def __getitem__(self, index: int) -> LazyAsset:
259
+ if self.use_prob and self.cls_weight is not None:
260
+ if self.cls_bias is None:
261
+ raise ValueError("do not have cls_bias")
262
+ if self.cls_length is None:
263
+ raise ValueError("do not have cls_length")
264
+ if not hasattr(self, "perms"):
265
+ self.perms = []
266
+ self.current_bias = []
267
+ for i in range(len(self.cls_weight)):
268
+ self.perms.append([x for x in range(self.cls_length[i])])
269
+ self.current_bias.append(0)
270
+ idx = np.random.choice(len(self.cls_weight), p=self.cls_weight)
271
+ i = self.perms[idx][self.current_bias[idx]]
272
+ self.current_bias[idx] += 1
273
+ if self.current_bias[idx] >= self.cls_length[idx]:
274
+ shuffle(self.perms[idx])
275
+ self.current_bias[idx] = 0
276
+ if self.cls_name is None:
277
+ name = None
278
+ else:
279
+ name = self.cls_name[idx]
280
+ path = os.path.join(self.input_dataset_dir, self.filepaths[i+self.cls_bias[idx]])
281
+ if self.data_name is not None:
282
+ path = os.path.join(path, self.data_name)
283
+ return self.make(path=path, cls=name)
284
+ else:
285
+ if self.cls_name is None or self.cls_bias is None or self.cls_length is None:
286
+ name = None
287
+ else:
288
+ name = None
289
+ for i in range(len(self.cls_bias)):
290
+ start = self.cls_bias[i]
291
+ end = start + self.cls_length[i]
292
+ if start <= index < end:
293
+ name = self.cls_name[i]
294
+ break
295
+ path = os.path.join(self.input_dataset_dir, self.filepaths[index])
296
+ if self.data_name is not None:
297
+ path = os.path.join(path, self.data_name)
298
+ return self.make(path=path, cls=name)
299
+
300
+ def get_data(self) -> List[LazyAsset]:
301
+ return [self[i] for i in range(len(self))]
302
+
303
+ def split_by_cls(self) -> Dict[str|None, 'Datapath']:
304
+ res: Dict[str|None, Datapath] = {}
305
+ if self.cls_name is None:
306
+ res[None] = self
307
+ return res
308
+ if self.cls_bias is None:
309
+ raise ValueError("do not have cls_bias")
310
+ if self.cls_length is None:
311
+ raise ValueError("do not have cls_length")
312
+ d_filepaths = defaultdict(list)
313
+ d_length = defaultdict(int)
314
+ d_weight = defaultdict(list)
315
+ for (i, cls) in enumerate(self.cls_name):
316
+ s = slice(self.cls_bias[i], self.cls_bias[i]+self.cls_length[i])
317
+ d_filepaths[cls].extend(self.filepaths[s].copy())
318
+ d_length[cls] += self.cls_length[i]
319
+ if self.cls_weight is not None:
320
+ d_weight[cls].append(self.cls_weight[i])
321
+ for cls in d_filepaths:
322
+ cls_weight = None if self.cls_weight is None else d_weight[cls]
323
+ if cls_weight is not None:
324
+ total = sum(cls_weight)
325
+ cls_weight = [x/total for x in cls_weight]
326
+ res[cls] = Datapath(
327
+ filepaths=d_filepaths[cls],
328
+ input_dataset_dir=self.input_dataset_dir,
329
+ cls_name=[cls],
330
+ cls_bias=[0],
331
+ cls_length=[len(d_filepaths[cls])],
332
+ num_files=self.num_files,
333
+ use_prob=self.use_prob,
334
+ cls_weight=cls_weight,
335
+ loader=self.loader,
336
+ data_name=self.data_name,
337
+ )
338
+ return res
339
+
340
+ def __len__(self):
341
+ if self.use_prob:
342
+ assert self.num_files is not None, 'num_files is not specified'
343
+ return self.num_files
344
+ return len(self.filepaths)
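A minimal sketch of building a `Datapath` directly from a dict of file lists; every path below is a placeholder. `parse` fills in the per-class bias/length bookkeeping and normalizes the class weights:

```python
from src.data.datapath import Datapath  # assuming the repository root is on sys.path

datapath = Datapath.parse(
    filepaths={
        'character': ['models/a.npz', 'models/b.npz'],  # placeholder paths
        'creature': ['models/c.npz'],
    },
    input_dataset_dir='/data',  # prepended to every path
    loader='npz',               # NpzLazyAsset instead of the default bpy loader
    data_name=None,             # the paths above already point at the .npz files
)

lazy = datapath[0]   # a LazyAsset; nothing is read from disk yet
asset = lazy.load()  # loads the npz into an Asset on demand
```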
src/data/dataset.py ADDED
@@ -0,0 +1,319 @@
1
+ from copy import deepcopy
2
+ from dataclasses import dataclass
3
+ from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
4
+ from numpy import ndarray
5
+ from torch import Tensor
6
+ from torch.utils import data
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from typing import Dict, List, Tuple, Callable, Optional
9
+
10
+ import os
11
+ import lightning.pytorch as pl
12
+ import numpy as np
13
+ import torch
14
+
15
+ from .datapath import Datapath, LazyAsset
16
+ from .spec import ConfigSpec
17
+ from .transform import Transform
18
+
19
+ from ..model.spec import ModelInput
20
+ from ..rig_package.info.asset import Asset
21
+ from ..tokenizer.spec import Tokenizer, TokenizeInput
22
+
23
+ @dataclass
24
+ class DatasetConfig(ConfigSpec):
25
+ shuffle: bool
26
+ batch_size: int
27
+ num_workers: int
28
+ datapath: Datapath
29
+ pin_memory: bool=True
30
+ persistent_workers: bool=True
31
+
32
+ @classmethod
33
+ def parse(cls, **kwargs) -> 'DatasetConfig':
34
+ cls.check_keys(kwargs)
35
+ return DatasetConfig(
36
+ shuffle=kwargs.get('shuffle', False),
37
+ batch_size=kwargs.get('batch_size', 1),
38
+ num_workers=kwargs.get('num_workers', 1),
39
+ pin_memory=kwargs.get('pin_memory', True),
40
+ persistent_workers=kwargs.get('persistent_workers', True),
41
+ datapath=Datapath.parse(**kwargs.get('datapath')),
42
+ )
43
+
44
+ def split_by_cls(self) -> Dict[str|None, 'DatasetConfig']:
45
+ res: Dict[str|None, DatasetConfig] = {}
46
+ datapath_dict = self.datapath.split_by_cls()
47
+ for cls, v in datapath_dict.items():
48
+ res[cls] = DatasetConfig(
49
+ shuffle=self.shuffle,
50
+ batch_size=self.batch_size,
51
+ num_workers=self.num_workers,
52
+ datapath=v,
53
+ pin_memory=self.pin_memory,
54
+ persistent_workers=self.persistent_workers,
55
+ )
56
+ return res
57
+
58
+ class RigDatasetModule(pl.LightningDataModule):
59
+ def __init__(
60
+ self,
61
+ process_fn: Optional[Callable[[List[ModelInput]], List[Dict]]]=None,
62
+ train_dataset_config: Optional[DatasetConfig]=None,
63
+ validate_dataset_config: Optional[Dict[str|None, DatasetConfig]]=None,
64
+ predict_dataset_config: Optional[Dict[str|None, DatasetConfig]]=None,
65
+ train_transform: Optional[Transform]=None,
66
+ validate_transform: Optional[Transform]=None,
67
+ predict_transform: Optional[Transform]=None,
68
+ tokenizer: Optional[Tokenizer]=None,
69
+ debug: bool=False,
70
+ ):
71
+ super().__init__()
72
+ self.process_fn = process_fn
73
+ self.train_dataset_config = train_dataset_config
74
+ self.validate_dataset_config = validate_dataset_config
75
+ self.predict_dataset_config = predict_dataset_config
76
+ self.train_transform = train_transform
77
+ self.validate_transform = validate_transform
78
+ self.predict_transform = predict_transform
79
+ self.tokenizer = tokenizer
80
+ self.debug = debug
81
+
82
+ if debug:
83
+ print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m")
84
+
85
+ # build train datapath
86
+ if self.train_dataset_config is not None:
87
+ self.train_datapath = self.train_dataset_config.datapath
88
+ else:
89
+ self.train_datapath = None
90
+
91
+ # build validate datapath
92
+ if self.validate_dataset_config is not None:
93
+ self.validate_datapath = {
94
+ cls: self.validate_dataset_config[cls].datapath
95
+ for cls in self.validate_dataset_config
96
+ }
97
+ else:
98
+ self.validate_datapath = None
99
+
100
+ # build predict datapath
101
+ if self.predict_dataset_config is not None:
102
+ self.predict_datapath = {
103
+ cls: self.predict_dataset_config[cls].datapath
104
+ for cls in self.predict_dataset_config
105
+ }
106
+ else:
107
+ self.predict_datapath = None
108
+
109
+ self.tokenizer = tokenizer
110
+
111
+ def prepare_data(self):
112
+ pass
113
+
114
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
115
+ if self.train_dataset_config is None:
116
+ raise ValueError("do not have train_dataset_config")
117
+ if self.train_transform is None:
118
+ raise ValueError("do not have train_transform")
119
+ if self.train_datapath is not None:
120
+ self._train_ds = RigDataset(
121
+ process_fn=self.process_fn,
122
+ data=self.train_datapath.get_data(),
123
+ name="train",
124
+ tokenizer=self.tokenizer,
125
+ transform=self.train_transform,
126
+ debug=self.debug,
127
+ )
128
+ else:
129
+ return None
130
+ return self._create_dataloader(
131
+ dataset=self._train_ds,
132
+ config=self.train_dataset_config,
133
+ is_train=True,
134
+ drop_last=False,
135
+ )
136
+
137
+ def val_dataloader(self) -> EVAL_DATALOADERS:
138
+ if self.validate_dataset_config is None:
139
+ raise ValueError("do not have validate_dataset_config")
140
+ if self.validate_transform is None:
141
+ raise ValueError("do not have validate_transform")
142
+ if self.validate_datapath is not None:
143
+ self._validation_ds = {}
144
+ for cls in self.validate_datapath:
145
+ self._validation_ds[cls] = RigDataset(
146
+ process_fn=self.process_fn,
147
+ data=self.validate_datapath[cls].get_data(),
148
+ name=f"validate-{cls}",
149
+ tokenizer=self.tokenizer,
150
+ transform=self.validate_transform,
151
+ debug=self.debug,
152
+ )
153
+ else:
154
+ return None
155
+ return self._create_dataloader(
156
+ dataset=self._validation_ds,
157
+ config=self.validate_dataset_config,
158
+ is_train=False,
159
+ drop_last=False,
160
+ )
161
+
162
+ def predict_dataloader(self):
163
+ if self.predict_dataset_config is None:
164
+ raise ValueError("do not have predict_dataset_config")
165
+ if self.predict_transform is None:
166
+ raise ValueError("do not have predict_transform")
167
+ if self.predict_datapath is not None:
168
+ self._predict_ds = {}
169
+ for cls in self.predict_datapath:
170
+ self._predict_ds[cls] = RigDataset(
171
+ process_fn=self.process_fn,
172
+ data=self.predict_datapath[cls].get_data(),
173
+ name=f"predict-{cls}",
174
+ tokenizer=self.tokenizer,
175
+ transform=self.predict_transform,
176
+ debug=self.debug,
177
+ )
178
+ else:
179
+ return None
180
+ return self._create_dataloader(
181
+ dataset=self._predict_ds,
182
+ config=self.predict_dataset_config,
183
+ is_train=False,
184
+ drop_last=False,
185
+ )
186
+
187
+ def _create_dataloader(
188
+ self,
189
+ dataset: Dataset|Dict[str, Dataset],
190
+ config: DatasetConfig|Dict[str|None, DatasetConfig],
191
+ is_train: bool,
192
+ **kwargs,
193
+ ) -> DataLoader|Dict[str, DataLoader]:
194
+ def create_single_dataloader(dataset, config: DatasetConfig, **kwargs):
195
+ return DataLoader(
196
+ dataset,
197
+ batch_size=config.batch_size,
198
+ shuffle=config.shuffle,
199
+ num_workers=config.num_workers,
200
+ pin_memory=config.pin_memory,
201
+ persistent_workers=config.persistent_workers,
202
+ collate_fn=dataset.collate_fn,
203
+ **kwargs,
204
+ )
205
+ if isinstance(dataset, Dict):
206
+ assert isinstance(config, dict)
207
+ return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()}
208
+ else:
209
+ assert isinstance(config, DatasetConfig)
210
+ return create_single_dataloader(dataset, config, **kwargs)
211
+
212
+ class RigDataset(Dataset):
213
+ def __init__(
214
+ self,
215
+ data: List[LazyAsset],
216
+ transform: Transform,
217
+ name: Optional[str]=None,
218
+ process_fn: Optional[Callable[[List[ModelInput]], List[Dict]]]=None,
219
+ tokenizer: Optional[Tokenizer]=None,
220
+ debug: bool=False,
221
+ ) -> None:
222
+ super().__init__()
223
+
224
+ self.data = data
225
+ self.name = name
226
+ self.process_fn = process_fn
227
+ self.tokenizer = tokenizer
228
+ self.transform = transform
229
+ self.debug = debug
230
+
231
+ if not debug:
232
+ assert self.process_fn is not None, 'missing data processing function'
233
+
234
+ def __len__(self) -> int:
235
+ return len(self.data)
236
+
237
+ def __getitem__(self, idx) -> ModelInput:
238
+ lazy_asset = self.data[idx]
239
+ asset = lazy_asset.load()
240
+ self.transform.apply(asset=asset)
241
+ if self.tokenizer is not None and asset.parents is not None:
242
+ x = TokenizeInput(
243
+ joints=asset.joints,
244
+ parents=asset.parents,
245
+ cls=asset.cls,
246
+ joint_names=asset.joint_names,
247
+ )
248
+ tokens = self.tokenizer.tokenize(input=x)
249
+ else:
250
+ tokens = None
251
+ return ModelInput(asset=asset, tokens=tokens)
252
+
253
+ def _collate_fn_debug(self, batch):
254
+ return batch
255
+
256
+ def _collate_fn(self, batch):
257
+ processed_batch = self.process_fn(batch) # type: ignore
258
+ processed_batch: List[Dict]
259
+
260
+ tensors_stack = {}
261
+ tensors_cat = {}
262
+ non_tensors = {}
263
+ vis = {}
264
+ def check(x):
265
+ assert x not in vis, f"multiple keys found: {x}"
266
+ vis[x] = True
267
+
268
+ for k, v in processed_batch[0].items():
269
+ if k == "cat":
270
+ assert isinstance(v, dict)
271
+ for k1 in v.keys():
272
+ check(k1)
273
+ tensors_cat[k1] = []
274
+ for i in range(len(processed_batch)):
275
+ v1 = processed_batch[i]['cat'][k1]
276
+ if isinstance(v1, ndarray):
277
+ v1 = torch.from_numpy(v1)
278
+ elif isinstance(v1, Tensor):
279
+ v1 = v1
280
+ else:
281
+ raise ValueError(f"cannot concatenate non-tensor type of key {k1}, type: {type(v1)}")
282
+ tensors_cat[k1].append(v1)
283
+ elif k == "non":
284
+ assert isinstance(v, dict)
285
+ for k1 in v.keys():
286
+ check(k1)
287
+ non_tensors[k1] = []
288
+ for i in range(len(processed_batch)):
289
+ v1 = processed_batch[i]['non'][k1]
290
+ if isinstance(v1, ndarray):
291
+ v1 = torch.from_numpy(v1)
292
+ non_tensors[k1].append(v1)
293
+ else:
294
+ check(k)
295
+ tensors_stack[k] = []
296
+ for i in range(len(processed_batch)):
297
+ v1 = processed_batch[i][k]
298
+ if isinstance(v1, ndarray):
299
+ v1 = torch.from_numpy(v1)
300
+ elif isinstance(v1, Tensor):
301
+ v1 = v1
302
+ else:
303
+ raise ValueError(f"cannot stack type of key {k}, type: {type(v1)}")
304
+ tensors_stack[k].append(v1)
305
+
306
+ collated_stack = {k: torch.stack(v) for k, v in tensors_stack.items()}
307
+ collated_cat = {k: torch.concat(v, dim=1) for k, v in tensors_cat.items()}
308
+
309
+ collated_batch = {
310
+ **collated_stack,
311
+ **collated_cat,
312
+ **non_tensors,
313
+ }
314
+ return collated_batch
315
+
316
+ def collate_fn(self, batch):
317
+ if self.debug:
318
+ return self._collate_fn_debug(batch)
319
+ return self._collate_fn(batch)
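The collate function above expects each processed sample to be a flat dict of equally shaped arrays (stacked into a batch dimension), plus an optional `'cat'` sub-dict whose entries are concatenated along dim 1 and an optional `'non'` sub-dict whose entries are collected into plain per-sample lists. A sketch of a `process_fn` following this convention; the field names are illustrative, not the ones used by the release model:

```python
import numpy as np

def process_fn(batch):
    # batch is a list of ModelInput; return one dict per sample
    processed = []
    for model_input in batch:
        asset = model_input.asset
        processed.append({
            # plain keys: identical shapes across samples, stacked into (B, ...)
            'vertices': asset.sampled_vertices.astype(np.float32),
            # 'cat' keys: variable length per sample, concatenated along dim=1
            'cat': {'tokens': np.asarray(model_input.tokens, dtype=np.int64)[None, :]},
            # 'non' keys: kept as a per-sample list (here, variable-sized joint arrays)
            'non': {'joints': asset.joints},
        })
    return processed
```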
src/data/order.py ADDED
@@ -0,0 +1,132 @@
1
+
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass
4
+ from numpy import ndarray
5
+ from omegaconf import OmegaConf
6
+ from typing import Dict, List, Tuple, Optional
7
+
8
+ from .spec import ConfigSpec
9
+
10
+ @dataclass
11
+ class Order(ConfigSpec):
12
+
13
+ # {cls_name: {part_name: [bone_name_1, bone_name_2, ...]}}
14
+ parts: Dict[str, Dict[str, List[str]]]
15
+
16
+ # per-class order in which parts are arranged: {cls_name: [part_name_1, part_name_2, ...]}
17
+ parts_order: Dict[str, List[str]]
18
+
19
+ # {skeleton_name: path}
20
+ skeleton_path: Optional[Dict[str, str]]=None
21
+
22
+ sort_by_xyz: bool=False
23
+
24
+ @classmethod
25
+ def parse(cls, **kwargs) -> 'Order':
26
+ cls.check_keys(kwargs)
27
+ skeleton_path = kwargs.get('skeleton_path', None)
28
+ if skeleton_path is not None:
29
+ parts = {}
30
+ parts_order = {}
31
+ for (cls, path) in skeleton_path.items():
32
+ assert cls not in parts, 'cls conflicts'
33
+ d = OmegaConf.load(path)
34
+ parts[cls] = d.parts
35
+ parts_order[cls] = d.parts_order
36
+ else:
37
+ parts = kwargs.get('parts')
38
+ parts_order = kwargs.get('parts_order')
39
+ assert parts is not None
40
+ assert parts_order is not None
41
+ return Order(
42
+ skeleton_path=skeleton_path,
43
+ parts=parts,
44
+ parts_order=parts_order,
45
+ sort_by_xyz=kwargs.get('sort_by_xyz', False),
46
+ )
47
+
48
+ def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
49
+ '''
50
+ Check if part exists.
51
+ '''
52
+ if part not in self.parts[cls]:
53
+ return False
54
+ for name in self.parts[cls][part]:
55
+ if name not in names:
56
+ return False
57
+ return True
58
+
59
+ def make_names(self, cls: str|None, parts: List[str|None], num_bones: int) -> List[str]:
60
+ '''
61
+ Get names for specified cls.
62
+ '''
63
+ names = []
64
+ for part in parts:
65
+ if part is None: # spring
66
+ continue
67
+ if cls in self.parts and part in self.parts[cls]:
68
+ names.extend(self.parts[cls][part])
69
+ assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones"
70
+ for i in range(len(names), num_bones):
71
+ names.append(f"bone_{i}")
72
+ return names
73
+
74
+ def arrange_names(self, cls: str|None, names: List[str], parents: List[int], joints: Optional[ndarray]=None) -> Tuple[List[str], Dict[int, str|None]]:
75
+ '''
76
+ Arrange names according to required parts order.
77
+ '''
78
+ def sort_by_xyz(joints):
79
+ return sorted(joints, key=lambda joint: (joint[1][2], joint[1][0], joint[1][1]))
80
+
81
+ if self.sort_by_xyz:
82
+ assert joints is not None
83
+ new_names = []
84
+ root = -1
85
+ son = defaultdict(list)
86
+ not_root = {}
87
+ for (i, p) in enumerate(parents):
88
+ if p != -1:
89
+ son[p].append(i)
90
+ not_root[i] = True
91
+ for i in range(len(parents)):
92
+ if not_root.get(i, False) == False:
93
+ root = i
94
+ break
95
+ Q = [root]
96
+ while Q:
97
+ u = Q.pop(0)
98
+ new_names.append(names[u])
99
+ wait = []
100
+ for v in son[u]:
101
+ wait.append((v, joints[v]))
102
+ wait_sorted = sort_by_xyz(wait)
103
+ new_wait = [v for v, _ in wait_sorted]
104
+ Q = new_wait + Q
105
+ return new_names, {}
106
+ if cls not in self.parts_order:
107
+ return names, {0: None} # add a spring token
108
+ vis = defaultdict(bool)
109
+ name_to_id = {name: i for (i, name) in enumerate(names)}
110
+ new_names = []
111
+ parts_bias = {}
112
+ for part in self.parts_order[cls]:
113
+ if self.part_exists(cls=cls, part=part, names=names):
114
+ for name in self.parts[cls][part]:
115
+ vis[name] = True
116
+ flag = False
117
+ for name in self.parts[cls][part]:
118
+ pid = parents[name_to_id[name]]
119
+ if pid==-1:
120
+ continue
121
+ if not vis[names[pid]]:
122
+ flag = True
123
+ break
124
+ if flag: # incorrect parts order and should immediately add a spring token
125
+ break
126
+ parts_bias[len(new_names)] = part
127
+ new_names.extend(self.parts[cls][part])
128
+ parts_bias[len(new_names)] = None # add a spring token
129
+ for name in names:
130
+ if name not in new_names:
131
+ new_names.append(name)
132
+ return new_names, parts_bias
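A small sketch of how `Order` rearranges bone names by named parts; the class key, part names, and bone names below are invented for illustration (the real layouts ship in `configs/skeleton/*.yaml`):

```python
from src.data.order import Order  # assuming the repository root is on sys.path

order = Order.parse(
    parts={'humanoid': {'spine': ['hips', 'spine', 'chest'],
                        'head': ['neck', 'head']}},
    parts_order={'humanoid': ['spine', 'head']},
)

names = ['head', 'hips', 'neck', 'spine', 'chest']
parents = [2, -1, 4, 1, 3]  # hips -> spine -> chest -> neck -> head
new_names, parts_bias = order.arrange_names(cls='humanoid', names=names, parents=parents)
# new_names  == ['hips', 'spine', 'chest', 'neck', 'head']
# parts_bias == {0: 'spine', 3: 'head', 5: None}  (None marks where a spring token goes)
```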
src/data/sampler.py ADDED
@@ -0,0 +1,189 @@
1
+ from dataclasses import dataclass
2
+ from abc import ABC, abstractmethod
3
+ from numpy import ndarray
4
+ from scipy.spatial import cKDTree # type: ignore
5
+ from typing import Dict, Optional
6
+
7
+ import numpy as np
8
+ import random
9
+
10
+ from ..rig_package.info.asset import Asset
11
+ from ..rig_package.utils import sample_vertex_groups
12
+ from .spec import ConfigSpec
13
+
14
+ @dataclass
15
+ class SamplerResult():
16
+ sampled_vertices: Optional[ndarray]=None
17
+ sampled_normals: Optional[ndarray]=None
18
+ sampled_vertex_groups: Optional[Dict[str, ndarray]]=None
19
+
20
+ # number of sampled skin points
21
+ skin_samples: Optional[int]=None
22
+
23
+ class Sampler(ABC):
24
+ @abstractmethod
25
+ def sample(
26
+ self,
27
+ asset: Asset,
28
+ ) -> SamplerResult:
29
+ '''
30
+ Return sampled vertices, sampled normals and vertex groups.
31
+ '''
32
+ pass
33
+
34
+ @classmethod
35
+ @abstractmethod
36
+ def parse(cls, **kwargs) -> 'Sampler':
37
+ pass
38
+
39
+ @dataclass
40
+ class SamplerMix(Sampler, ConfigSpec):
41
+ num_samples: int
42
+ num_vertex_samples: int
43
+ num_skin_samples: Optional[int]=None
44
+ replace: bool=True
45
+ all_skeleton: Optional[bool]=None
46
+ max_distance: float=0.1
47
+ rate_distance: float=0.1
48
+
49
+ @classmethod
50
+ def parse(cls, **kwargs) -> 'SamplerMix':
51
+ cls.check_keys(kwargs)
52
+ return SamplerMix(
53
+ num_samples=kwargs.get('num_samples', 0),
54
+ num_vertex_samples=kwargs.get('num_vertex_samples', 0),
55
+ num_skin_samples=kwargs.get('num_skin_samples', None),
56
+ replace=kwargs.get('replace', True),
57
+ all_skeleton=kwargs.get('all_skeleton', None),
58
+ max_distance=kwargs.get('max_distance', 0.1),
59
+ rate_distance=kwargs.get('rate_distance', 0.1),
60
+ )
61
+
62
+ def sample_on_skin(
63
+ self,
64
+ skin: ndarray,
65
+ vertices: ndarray,
66
+ faces: ndarray,
67
+ ):
68
+ face_has_skin = np.any(skin[faces] > 0, axis=-1)
69
+ if face_has_skin.sum() == 0:
70
+ face_has_skin = np.ones_like(face_has_skin)
71
+ elif self.max_distance < 1e-5:
72
+ return face_has_skin
73
+ else:
74
+ # sample near points
75
+ p = np.unique(faces[face_has_skin].reshape(-1))
76
+ tree = cKDTree(vertices[p])
77
+ dis, _ = tree.query(vertices, k=1)
78
+ dis_skin = np.sqrt(((np.max(vertices[p], axis=0) - np.min(vertices[p], axis=0))**2).sum())
79
+ mask_face_near = np.any(dis[faces] < min(self.max_distance, dis_skin * self.rate_distance), axis=-1)
80
+ face_has_skin |= mask_face_near
81
+ return face_has_skin
82
+
83
+ def sample(
84
+ self,
85
+ asset: Asset,
86
+ ) -> SamplerResult:
87
+ if asset.vertices is None:
88
+ raise ValueError("do not have vertices")
89
+ if asset.faces is None:
90
+ raise ValueError("do not have faces")
91
+ vertex_groups = []
92
+ mapping = {}
93
+ tot = 0
94
+ for k, v in asset.vertex_groups.items():
95
+ if v.ndim == 1:
96
+ v = v[:, None]
97
+ elif v.ndim != 2:
98
+ raise ValueError(f"ndim of key {k} is {v.ndim}")
99
+ s = tot
100
+ e = tot + v.shape[1]
101
+ mapping[k] = slice(s,e)
102
+ vertex_groups.append(v)
103
+ if len(vertex_groups) > 0:
104
+ vertex_groups = np.concatenate(vertex_groups, axis=1)
105
+ else:
106
+ vertex_groups = None
107
+ final_sampled_vertices, final_sampled_normals, sampled_vertex_groups = sample_vertex_groups(
108
+ vertices=asset.vertices,
109
+ faces=asset.faces,
110
+ num_samples=self.num_samples,
111
+ vertex_normals=asset.vertex_normals,
112
+ face_normals=asset.face_normals,
113
+ vertex_groups=vertex_groups,
114
+ face_mask=None,
115
+ shuffle=True,
116
+ same=True,
117
+ )
118
+ if vertex_groups is not None:
119
+ final_sampled_vertices = final_sampled_vertices[:, 0]
120
+ if final_sampled_normals is not None:
121
+ final_sampled_normals = final_sampled_normals[:, 0]
122
+ final_sampled_vertex_groups = {}
123
+ if sampled_vertex_groups is not None:
124
+ for k, s in mapping.items():
125
+ final_sampled_vertex_groups[k] = sampled_vertex_groups[:, s] # (N, k)
126
+ if vertex_groups is not None and self.num_skin_samples is not None:
127
+ dense_vertices = []
128
+ dense_normals = []
129
+ dense_skin = []
130
+ if 'skin' not in mapping:
131
+ raise ValueError("do not have skin")
132
+ if self.all_skeleton:
133
+ dense_indices = [i for i in range(asset.J)]
134
+ else:
135
+ dense_indices = [random.randint(0, asset.J-1)]
136
+ for indice in dense_indices:
137
+ _s = asset.vertex_groups['skin'][:, indice]
138
+ face_has_skin = self.sample_on_skin(
139
+ skin=_s,
140
+ vertices=asset.vertices,
141
+ faces=asset.faces,
142
+ )
143
+ sampled_vertices, sampled_normals, sampled_skin = sample_vertex_groups(
144
+ vertices=asset.vertices,
145
+ faces=asset.faces,
146
+ vertex_normals=asset.vertex_normals,
147
+ face_normals=asset.face_normals,
148
+ vertex_groups=_s,
149
+ num_samples=self.num_skin_samples,
150
+ num_vertex_samples=self.num_vertex_samples,
151
+ face_mask=face_has_skin,
152
+ shuffle=True,
153
+ same=True,
154
+ )
155
+ assert sampled_skin is not None
156
+ assert sampled_skin.ndim == 2
157
+ dense_vertices.append(sampled_vertices[:, 0])
158
+ if sampled_normals is not None:
159
+ dense_normals.append(sampled_normals[:, 0])
160
+ dense_skin.append(sampled_skin[:, 0])
161
+ dense_vertices = np.stack(dense_vertices, axis=0) # (J, m, 3)
162
+ if len(dense_normals) > 0:
163
+ dense_normals = np.stack(dense_normals, axis=0) # (J, m, 3)
164
+ else:
165
+ dense_normals = None
166
+ dense_skin = np.stack(dense_skin, axis=0) # (J, m, 1)
167
+ final_sampled_vertex_groups['skin'] = final_sampled_vertex_groups['skin'][:, dense_indices]
168
+ if asset.meta is None:
169
+ asset.meta = {}
170
+ asset.meta['dense_vertices'] = dense_vertices
171
+ asset.meta['dense_normals'] = dense_normals
172
+ asset.meta['dense_skin'] = dense_skin
173
+ asset.meta['dense_indices'] = dense_indices
174
+ return SamplerResult(
175
+ sampled_vertices=final_sampled_vertices,
176
+ sampled_normals=final_sampled_normals if final_sampled_normals is not None else None,
177
+ sampled_vertex_groups=final_sampled_vertex_groups,
178
+ skin_samples=self.num_skin_samples,
179
+ )
180
+
181
+ def get_sampler(**kwargs) -> Sampler:
182
+ __target__ = kwargs.get('__target__')
183
+ assert __target__ is not None
184
+ del kwargs['__target__']
185
+ if __target__ == 'mix':
186
+ sampler = SamplerMix.parse(**kwargs)
187
+ else:
188
+ raise ValueError(f"sampler method {__target__} not supported")
189
+ return sampler
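A usage sketch for the sampler factory; the sample counts are illustrative placeholders rather than the values used by the released configs:

```python
from src.data.sampler import get_sampler  # assuming the repository root is on sys.path

sampler = get_sampler(
    __target__='mix',
    num_samples=8192,        # surface samples over the whole mesh
    num_vertex_samples=1024,
    num_skin_samples=2048,   # extra samples concentrated around the queried bone(s)
    all_skeleton=False,      # sample around one random bone instead of every bone
)
# result = sampler.sample(asset=asset)
# result.sampled_vertices / result.sampled_normals / result.sampled_vertex_groups
# dense per-bone samples end up in asset.meta['dense_vertices'] etc.
```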
src/data/spec.py ADDED
@@ -0,0 +1,16 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import fields
3
+
4
+ class ConfigSpec(ABC):
5
+ @classmethod
6
+ def check_keys(cls, config, expect=None):
7
+ if expect is None:
8
+ expect = [field.name for field in fields(cls)] # type: ignore
9
+ for key in config.keys():
10
+ if key not in expect:
11
+ raise ValueError(f"unexpected key '{key}' in {cls.__name__}, expected one of {expect}")
12
+
13
+ @classmethod
14
+ @abstractmethod
15
+ def parse(cls, **kwargs) -> 'ConfigSpec':
16
+ raise NotImplementedError()
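Every config class in the data package follows this pattern: a dataclass plus a `parse` classmethod that validates incoming keys through `check_keys`. A minimal sketch of defining a new config this way:

```python
from dataclasses import dataclass
from src.data.spec import ConfigSpec  # assuming the repository root is on sys.path

@dataclass
class ExampleConfig(ConfigSpec):
    name: str
    count: int = 0

    @classmethod
    def parse(cls, **kwargs) -> 'ExampleConfig':
        cls.check_keys(kwargs)  # raises ValueError on unknown keys
        return ExampleConfig(name=kwargs['name'], count=kwargs.get('count', 0))

config = ExampleConfig.parse(name='demo', count=3)
```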
src/data/transform.py ADDED
@@ -0,0 +1,70 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+
4
+ from ..rig_package.info.asset import Asset
5
+ from .augment import Augment, get_augments
6
+ from .order import Order
7
+ from .sampler import Sampler, get_sampler
8
+ from .spec import ConfigSpec
9
+ from .vertex_group import VertexGroup, get_vertex_groups
10
+
11
+ @dataclass
12
+ class Transform(ConfigSpec):
13
+
14
+ order: Optional[Order]=None
15
+
16
+ vertex_groups: Optional[List[VertexGroup]]=None
17
+
18
+ augments: Optional[List[Augment]]=None
19
+
20
+ sampler: Optional[Sampler]=None
21
+
22
+ @classmethod
23
+ def parse(cls, **kwargs) -> 'Transform':
24
+ cls.check_keys(kwargs)
25
+ order_config = kwargs.get('order')
26
+ vertex_groups_config = kwargs.get('vertex_groups')
27
+ augments_config = kwargs.get('augments')
28
+ sampler_config = kwargs.get('sampler')
29
+
30
+ d = {}
31
+ if order_config is not None:
32
+ d['order'] = Order.parse(**order_config)
33
+ if vertex_groups_config is not None:
34
+ d['vertex_groups'] = get_vertex_groups(*vertex_groups_config)
35
+ if augments_config is not None:
36
+ d['augments'] = get_augments(*augments_config)
37
+ if sampler_config is not None:
38
+ d['sampler'] = get_sampler(**sampler_config)
39
+ return Transform(**d)
40
+
41
+ def apply(self, asset: Asset, **kwargs):
42
+
43
+ # 1. arrange bones
44
+ if self.order is not None:
45
+ if asset.joint_names is not None and asset.parents is not None:
46
+ new_names, _ = self.order.arrange_names(cls=asset.cls, names=asset.joint_names, parents=asset.parents.tolist())
47
+ asset.set_order(new_orders=new_names) # type: ignore
48
+
49
+ # 2. collapse must perform first
50
+ if self.augments is not None:
51
+ kwargs = {}
52
+ for augment in self.augments:
53
+ augment.transform(asset=asset, **kwargs)
54
+
55
+ # 3. get vertex groups
56
+ if self.vertex_groups is not None:
57
+ d = {}
58
+ for v in self.vertex_groups:
59
+ d.update(v.get_vertex_group(asset=asset))
60
+ asset.vertex_groups = d
61
+ else:
62
+ asset.vertex_groups = {}
63
+
64
+ # 4. sample
65
+ if self.sampler is not None:
66
+ res = self.sampler.sample(asset=asset)
67
+ asset.sampled_vertices = res.sampled_vertices
68
+ asset.sampled_normals = res.sampled_normals
69
+ asset.sampled_vertex_groups = res.sampled_vertex_groups
70
+ asset.skin_samples = res.skin_samples
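Putting the pieces together, a `Transform` is parsed from one nested config and then applied per asset inside the dataset. The sub-config values below are illustrative; only the `__target__` names and the skeleton YAML path come from this repository:

```python
from src.data.transform import Transform  # assuming the repository root is on sys.path

transform = Transform.parse(
    order={'skeleton_path': {'vroid': 'configs/skeleton/vroid.yaml'}},
    augments=[
        {'__target__': 'affine', 'normalize_into': [-1.0, 1.0]},
        {'__target__': 'normalize'},
    ],
    vertex_groups=[{'__target__': 'skin'}],
    sampler={'__target__': 'mix', 'num_samples': 8192, 'num_vertex_samples': 1024},
)
# transform.apply(asset=asset)
# 1. reorder bones, 2. run augments, 3. build vertex groups, 4. sample the surface
```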
src/data/vertex_group.py ADDED
@@ -0,0 +1,257 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections import defaultdict
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass
5
+ from numpy import ndarray
6
+ from scipy.spatial import cKDTree # type: ignore
7
+ from scipy.sparse import csr_matrix
8
+ from scipy.sparse.csgraph import shortest_path, connected_components
9
+ from typing import Dict, List, Optional, Literal
10
+
11
+ import numpy as np
12
+
13
+ from ..rig_package.info.asset import Asset
14
+
15
+ @dataclass(frozen=True)
16
+ class VertexGroup(ABC):
17
+
18
+ @classmethod
19
+ @abstractmethod
20
+ def parse(cls, **kwargs) -> 'VertexGroup':
21
+ pass
22
+
23
+ @abstractmethod
24
+ def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
25
+ pass
26
+
27
+ @dataclass(frozen=True)
28
+ class VertexGroupSkin(VertexGroup):
29
+ """capture skin"""
30
+
31
+ normalize: bool=True
32
+
33
+ @classmethod
34
+ def parse(cls, **kwargs) -> 'VertexGroupSkin':
35
+ return VertexGroupSkin(normalize=kwargs.get('normalize', True))
36
+
37
+ def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
38
+ if asset.skin is None:
39
+ raise ValueError("do not have skin")
40
+ if self.normalize:
41
+ asset.normalize_skin()
42
+ return {'skin': asset.skin.copy()}
43
+
44
+ @dataclass(frozen=True)
45
+ class VertexGroupVoxelSkin(VertexGroup):
46
+ """capture voxel skin"""
47
+
48
+ grid: int
49
+ alpha: float
50
+ link_dis: float
51
+ grid_query: int
52
+ vertex_query: int
53
+ grid_weight: float
54
+ mode: Literal['square', 'exp']
55
+
56
+ @classmethod
57
+ def parse(cls, **kwargs) -> 'VertexGroupVoxelSkin':
58
+ return VertexGroupVoxelSkin(
59
+ grid=kwargs.get('grid', 64),
60
+ alpha=kwargs.get('alpha', 0.5),
61
+ link_dis=kwargs.get('link_dis', 0.00001),
62
+ grid_query=kwargs.get('grid_query', 27),
63
+ vertex_query=kwargs.get('vertex_query', 27),
64
+ grid_weight=kwargs.get('grid_weight', 3.0),
65
+ mode=kwargs.get('mode', 'square'),
66
+ )
67
+
68
+ def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]:
69
+ if asset.vertices is None:
70
+ raise ValueError("do not have vertices")
71
+ if asset.faces is None:
72
+ raise ValueError("do not have faces")
73
+ if asset.joints is None:
74
+ raise ValueError("do not have joints")
75
+ # normalize into [-1, 1] first
76
+ min_vals = np.min(asset.vertices, axis=0)
77
+ max_vals = np.max(asset.vertices, axis=0)
78
+
79
+ center = (min_vals + max_vals) / 2
80
+
81
+ scale = np.max(max_vals - min_vals) / 2
82
+
83
+ normalized_vertices = (asset.vertices - center) / scale
84
+ normalized_joints = (asset.joints - center) / scale
85
+
86
+ grid_coords = asset.voxel().coords
87
+ skin = voxel_skin(
88
+ grid=self.grid,
89
+ grid_coords=grid_coords,
90
+ joints=normalized_joints,
91
+ vertices=normalized_vertices,
92
+ faces=asset.faces,
93
+ alpha=self.alpha,
94
+ link_dis=self.link_dis,
95
+ grid_query=self.grid_query,
96
+ vertex_query=self.vertex_query,
97
+ grid_weight=self.grid_weight,
98
+ mode=self.mode,
99
+ )
100
+ skin = np.nan_to_num(skin, nan=0., posinf=0., neginf=0.)
101
+ return {'voxel_skin': skin,}
102
+
103
+ def voxel_skin(
104
+ grid: int,
105
+ grid_coords: ndarray, # (M, 3)
106
+ joints: ndarray, # (J, 3)
107
+ vertices: ndarray, # (N, 3)
108
+ faces: ndarray, # (F, 3)
109
+ alpha: float=0.5,
110
+ link_dis: float=0.00001,
111
+ grid_query: int=27,
112
+ vertex_query: int=27,
113
+ grid_weight: float=3.0,
114
+ voxel_size: Optional[float]=None,
115
+ mode: str='square',
116
+ parents: Optional[ndarray]=None,
117
+ ):
118
+ # modified from https://dl.acm.org/doi/pdf/10.1145/2485895.2485919
119
+ assert mode in ['square', 'exp']
120
+ J = joints.shape[0]
121
+ M = grid_coords.shape[0]
122
+ N = vertices.shape[0]
123
+
124
+ if voxel_size is None:
125
+ _range = 2/grid*1.74
126
+ else:
127
+ _range = voxel_size*1.74
128
+
129
+ grid_tree = cKDTree(grid_coords)
130
+ vertex_tree = cKDTree(vertices)
131
+ if parents is not None:
132
+ son = defaultdict(list)
133
+ for i, p in enumerate(parents):
134
+ if p == -1:
135
+ continue
136
+ son[p].append(i)
137
+ divide_joints = []
138
+ joints_map = []
139
+ for u in range(len(parents)):
140
+ if len(son[u]) != 1:
141
+ divide_joints.append(joints[u])
142
+ joints_map.append(u)
143
+ else:
144
+ pu = joints[u]
145
+ pv = joints[son[u][0]]
146
+ seg = 10
147
+ for i in range(seg+1):
148
+ p = (pu*i + pv*(seg-i)) / seg
149
+ divide_joints.append(p)
150
+ joints_map.append(u)
151
+ divide_joints = np.stack(divide_joints)
152
+ joints_map = np.array(joints_map)
153
+ else:
154
+ divide_joints = joints
155
+ joints_map = np.arange(joints.shape[0])
156
+ joint_tree = cKDTree(divide_joints)
157
+
158
+ # make combined vertices
159
+ # 0 ~ N-1: mesh vertices
160
+ # N ~ N+M-1: grid vertices
161
+ combined_vertices = np.concatenate([vertices, grid_coords], axis=0)
162
+
163
+ # link adjacent grids
164
+ dist, idx = grid_tree.query(grid_coords, grid_query) # 3*3*3
165
+ dist = dist[:, 1:]
166
+ idx = idx[:, 1:]
167
+ mask = (0 < dist) & (dist < _range)
168
+ source_grid2grid = np.repeat(np.arange(M), grid_query-1)[mask.ravel()] + N
169
+ to_grid2grid = idx[mask] + N
170
+ weight_grid2grid = dist[mask] * grid_weight
171
+
172
+ # link very close vertices
173
+ dist, idx = vertex_tree.query(vertices, 4)
174
+ dist = dist[:, 1:]
175
+ idx = idx[:, 1:]
176
+ mask = (0 < dist) & (dist < link_dis)
177
+ source_close = np.repeat(np.arange(N), 3)[mask.ravel()]
178
+ to_close = idx[mask]
179
+ weight_close = dist[mask]
180
+
181
+ # link grids to mesh vertices
182
+ dist, idx = vertex_tree.query(grid_coords, vertex_query)
183
+ mask = (0 < dist) & (dist < _range) # sqrt(3)
184
+ source_grid2vertex = np.repeat(np.arange(M), vertex_query)[mask.ravel()] + N
185
+ to_grid2vertex = idx[mask]
186
+ weight_grid2vertex = dist[mask]
187
+
188
+ # build combined vertices tree
189
+ combined_tree = cKDTree(combined_vertices)
190
+ # link bones to the nearest vertices
191
+ _, joint_indices = combined_tree.query(divide_joints)
192
+
193
+ # build graph
194
+ source_vertex2vertex = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]], axis=0)
195
+ to_vertex2vertex = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]], axis=0)
196
+ weight_vertex2vertex = np.sqrt(((vertices[source_vertex2vertex] - vertices[to_vertex2vertex])**2).sum(axis=-1))
197
+ graph = csr_matrix(
198
+ (np.concatenate([weight_close, weight_vertex2vertex, weight_grid2grid, weight_grid2vertex]),
199
+ (
200
+ np.concatenate([source_close, source_vertex2vertex, source_grid2grid, source_grid2vertex], axis=0),
201
+ np.concatenate([to_close, to_vertex2vertex, to_grid2grid, to_grid2vertex], axis=0)),
202
+ ),
203
+ shape=(N+M, N+M),
204
+ )
205
+
206
+ # get shortest path (J, N+M)
207
+ dist_matrix = shortest_path(graph, method='D', directed=False, indices=joint_indices)
208
+
209
+ # (sum_J, N)
210
+ dis_vertex2bone = dist_matrix[:, :N]
211
+ unreachable = np.isinf(dis_vertex2bone).all(axis=0)
212
+ k = min(J, 3)
213
+ dist, idx = joint_tree.query(vertices[unreachable], k)
214
+
215
+ # make sure at least one value in dis is not inf
216
+ unreachable_indices = np.where(unreachable)[0]
217
+ row_indices = idx
218
+ col_indices = np.repeat(unreachable_indices, k).reshape(-1, k)
219
+ dis_vertex2bone[row_indices, col_indices] = dist
220
+
221
+ finite_vals = dis_vertex2bone[np.isfinite(dis_vertex2bone)]
222
+ max_dis = np.max(finite_vals)
223
+ dis_vertex2bone = np.nan_to_num(dis_vertex2bone, nan=max_dis, posinf=max_dis, neginf=max_dis)
224
+ dis_vertex2bone = np.maximum(dis_vertex2bone, 1e-6)
225
+
226
+ # turn dis2bone to dis2vertex
227
+ dis_vertex2joint = np.full((joints.shape[0], vertices.shape[0]), max_dis)
228
+ for i in range(len(dis_vertex2bone)):
229
+ dis_vertex2joint[joints_map[i]] = np.minimum(dis_vertex2bone[i], dis_vertex2joint[joints_map[i]])
230
+
231
+ # (J, N)
232
+ if mode == 'exp':
233
+ skin = np.exp(-dis_vertex2joint / max_dis * 20.0)
234
+ elif mode == 'square':
235
+ skin = (1./((1-alpha)*dis_vertex2joint + alpha*dis_vertex2joint**2))**2
236
+ else:
237
+ assert False, f'invalid mode: {mode}'
238
+ skin = skin / skin.sum(axis=0)
239
+ # (N, J)
240
+ skin = skin.transpose()
241
+ return skin
242
+
243
+ def get_vertex_groups(*args) -> List[VertexGroup]:
244
+ vertex_groups = []
245
+ MAP = {
246
+ 'skin': VertexGroupSkin,
247
+ 'voxel_skin': VertexGroupVoxelSkin,
248
+ }
249
+ MAP: Dict[str, type[VertexGroup]]
250
+ for (i, c) in enumerate(args):
251
+ __target__ = c.get('__target__')
252
+ assert __target__ is not None, f"do not find `__target__` in config of vertex_groups of position {i}"
253
+ assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}"
254
+ c = deepcopy(c)
255
+ del c['__target__']
256
+ vertex_groups.append(MAP[__target__].parse(**c))
257
+ return vertex_groups
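The `'square'` weighting at the end of `voxel_skin` turns geodesic joint-to-vertex distances into normalized skin weights. A tiny numeric sketch of that final step in isolation:

```python
import numpy as np

alpha = 0.5
# geodesic distances from 2 joints to 3 vertices, shape (J, N)
dis = np.array([[0.1, 0.5, 1.0],
                [0.9, 0.4, 0.2]])

# 'square' mode: closer joints receive sharply larger raw weights
raw = (1.0 / ((1 - alpha) * dis + alpha * dis ** 2)) ** 2
skin = raw / raw.sum(axis=0)  # normalize over joints so each vertex column sums to 1
skin = skin.T                 # (N, J), the per-vertex layout returned by voxel_skin
```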
src/model/__init__.py ADDED
File without changes
src/model/michelangelo/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
src/model/michelangelo/get_model.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+
3
+ from .models.tsal.sal_perceiver import AlignedShapeLatentPerceiver, ShapeAsLatentPerceiverEncoder
4
+
5
+ def get_encoder(
6
+ pretrained_path: str=None,
7
+ freeze_decoder: bool=False,
8
+ **kwargs
9
+ ) -> AlignedShapeLatentPerceiver:
10
+ model = AlignedShapeLatentPerceiver(**kwargs)
11
+ if pretrained_path is not None:
12
+ state_dict = torch.load(pretrained_path, weights_only=True)
13
+ model.load_state_dict(state_dict)
14
+ if freeze_decoder:
15
+ model.geo_decoder.requires_grad_(False)
16
+ model.encoder.query.requires_grad_(False)
17
+ model.pre_kl.requires_grad_(False)
18
+ model.post_kl.requires_grad_(False)
19
+ model.transformer.requires_grad_(False)
20
+ return model
21
+
22
+ def get_encoder_simplified(
23
+ pretrained_path: str=None,
24
+ **kwargs
25
+ ) -> ShapeAsLatentPerceiverEncoder:
26
+ model = ShapeAsLatentPerceiverEncoder(**kwargs)
27
+ if pretrained_path is not None:
28
+ state_dict = torch.load(pretrained_path, weights_only=True)
29
+ model.load_state_dict(state_dict)
30
+ return model
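Both loaders above expect `pretrained_path` to point at a bare state_dict saved with torch.save and loadable with weights_only=True; a small self-contained sketch of that convention using a stand-in module (the real checkpoint would hold AlignedShapeLatentPerceiver weights):

import torch
import torch.nn as nn

# Stand-in for the perceiver: save a plain state_dict, not a wrapped checkpoint dict.
stand_in = nn.Linear(8, 8)
torch.save(stand_in.state_dict(), "/tmp/encoder_state_dict.pt")

# This mirrors what get_encoder / get_encoder_simplified do internally.
restored = nn.Linear(8, 8)
restored.load_state_dict(torch.load("/tmp/encoder_state_dict.pt", weights_only=True))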
src/model/michelangelo/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
src/model/michelangelo/models/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .checkpoint import checkpoint
src/model/michelangelo/models/modules/checkpoint.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
4
+ """
5
+
6
+ import torch
7
+ from typing import Callable, Iterable, Sequence, Union
8
+
9
+
10
+ def checkpoint(
11
+ func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12
+ inputs: Sequence[torch.Tensor],
13
+ params: Iterable[torch.Tensor],
14
+ flag: bool,
15
+ use_deepspeed: bool = False
16
+ ):
17
+ """
18
+ Evaluate a function without caching intermediate activations, allowing for
19
+ reduced memory at the expense of extra compute in the backward pass.
20
+ :param func: the function to evaluate.
21
+ :param inputs: the argument sequence to pass to `func`.
22
+ :param params: a sequence of parameters `func` depends on but does not
23
+ explicitly take as arguments.
24
+ :param flag: if False, disable gradient checkpointing.
25
+ :param use_deepspeed: if True, use deepspeed
26
+ """
27
+ if flag:
28
+ if use_deepspeed:
29
+ import deepspeed
30
+ return deepspeed.checkpointing.checkpoint(func, *inputs)
31
+
32
+ args = tuple(inputs) + tuple(params)
33
+ return CheckpointFunction.apply(func, len(inputs), *args)
34
+ else:
35
+ return func(*inputs)
36
+
37
+
38
+ class CheckpointFunction(torch.autograd.Function):
39
+ @staticmethod
40
+ @torch.amp.custom_fwd(device_type='cuda')
41
+ def forward(ctx, run_function, length, *args):
42
+ ctx.run_function = run_function
43
+ ctx.input_tensors = list(args[:length])
44
+ ctx.input_params = list(args[length:])
45
+
46
+ with torch.no_grad():
47
+ output_tensors = ctx.run_function(*ctx.input_tensors)
48
+ return output_tensors
49
+
50
+ @staticmethod
51
+ @torch.amp.custom_bwd(device_type='cuda')
52
+ def backward(ctx, *output_grads):
53
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
54
+ with torch.enable_grad():
55
+ # Fixes a bug where the first op in run_function modifies the
56
+ # Tensor storage in place, which is not allowed for detach()'d
57
+ # Tensors.
58
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
59
+ output_tensors = ctx.run_function(*shallow_copies)
60
+ input_grads = torch.autograd.grad(
61
+ output_tensors,
62
+ ctx.input_tensors + ctx.input_params,
63
+ output_grads,
64
+ allow_unused=True,
65
+ )
66
+ del ctx.input_tensors
67
+ del ctx.input_params
68
+ del output_tensors
69
+ return (None, None) + input_grads
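A minimal sketch of how `checkpoint` is intended to wrap a module's forward pass (the same pattern the transformer blocks in this release use); the module here is illustrative, and `params` must list every parameter the wrapped function touches so gradients still reach them when activations are recomputed.

import torch
import torch.nn as nn
from src.model.michelangelo.models.modules.checkpoint import checkpoint  # assumes the repo root is on PYTHONPATH

class TinyBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def _forward(self, x):
        return torch.relu(self.fc(x))

    def forward(self, x, use_checkpoint=True):
        # flag=False simply calls self._forward(x) without checkpointing.
        return checkpoint(self._forward, (x,), self.parameters(), use_checkpoint)

x = torch.randn(2, 16, 64, requires_grad=True)
TinyBlock()(x).sum().backward()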
src/model/michelangelo/models/modules/distributions.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ import numpy as np
3
+ from typing import Union, List
4
+
5
+
6
+ class AbstractDistribution(object):
7
+ def sample(self):
8
+ raise NotImplementedError()
9
+
10
+ def mode(self):
11
+ raise NotImplementedError()
12
+
13
+
14
+ class DiracDistribution(AbstractDistribution):
15
+ def __init__(self, value):
16
+ self.value = value
17
+
18
+ def sample(self):
19
+ return self.value
20
+
21
+ def mode(self):
22
+ return self.value
23
+
24
+
25
+ class DiagonalGaussianDistribution(object):
26
+ def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
27
+ self.feat_dim = feat_dim
28
+ self.parameters = parameters
29
+
30
+ if isinstance(parameters, list):
31
+ self.mean = parameters[0]
32
+ self.logvar = parameters[1]
33
+ else:
34
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
35
+
36
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
37
+ self.deterministic = deterministic
38
+ self.std = torch.exp(0.5 * self.logvar)
39
+ self.var = torch.exp(self.logvar)
40
+ if self.deterministic:
41
+ self.var = self.std = torch.zeros_like(self.mean)
42
+
43
+ def sample(self):
44
+ x = self.mean + self.std * torch.randn_like(self.mean)
45
+ return x
46
+
47
+ def kl(self, other=None, dims=(1, 2, 3)):
48
+ if self.deterministic:
49
+ return torch.Tensor([0.])
50
+ else:
51
+ if other is None:
52
+ return 0.5 * torch.mean(torch.pow(self.mean, 2)
53
+ + self.var - 1.0 - self.logvar,
54
+ dim=dims)
55
+ else:
56
+ return 0.5 * torch.mean(
57
+ torch.pow(self.mean - other.mean, 2) / other.var
58
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
59
+ dim=dims)
60
+
61
+ def nll(self, sample, dims=(1, 2, 3)):
62
+ if self.deterministic:
63
+ return torch.Tensor([0.])
64
+ logtwopi = np.log(2.0 * np.pi)
65
+ return 0.5 * torch.sum(
66
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
67
+ dim=dims)
68
+
69
+ def mode(self):
70
+ return self.mean
71
+
72
+
73
+ def normal_kl(mean1, logvar1, mean2, logvar2):
74
+ """
75
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
76
+ Compute the KL divergence between two gaussians.
77
+ Shapes are automatically broadcasted, so batches can be compared to
78
+ scalars, among other use cases.
79
+ """
80
+ tensor = None
81
+ for obj in (mean1, logvar1, mean2, logvar2):
82
+ if isinstance(obj, torch.Tensor):
83
+ tensor = obj
84
+ break
85
+ assert tensor is not None, "at least one argument must be a Tensor"
86
+
87
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
88
+ # Tensors, but it does not work for torch.exp().
89
+ logvar1, logvar2 = [
90
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
91
+ for x in (logvar1, logvar2)
92
+ ]
93
+
94
+ return 0.5 * (
95
+ -1.0
96
+ + logvar2
97
+ - logvar1
98
+ + torch.exp(logvar1 - logvar2)
99
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
100
+ )
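A short sketch of the latent-space pattern the VAE code in this release builds on: `moments` of shape [B, N, 2*C] (e.g. the output of a pre_kl linear layer) are split along the last dimension via feat_dim=-1, sampled, and regularized with kl(dims=(1, 2)); shapes and values below are illustrative.

import torch
from src.model.michelangelo.models.modules.distributions import DiagonalGaussianDistribution

B, N, C = 2, 256, 64
moments = torch.randn(B, N, 2 * C)

posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
z = posterior.sample()                 # [B, N, C]: mean + std * eps
z_det = posterior.mode()               # [B, N, C]: just the mean
kl = posterior.kl(dims=(1, 2)).mean()  # scalar KL against a unit Gaussian
print(z.shape, z_det.shape, kl.item())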
src/model/michelangelo/models/modules/embedder.py ADDED
@@ -0,0 +1,213 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+
8
+ VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
9
+
10
+
11
+ class FourierEmbedder(nn.Module):
12
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
13
+ each feature dimension of `x[..., i]` into:
14
+ [
15
+ sin(x[..., i]),
16
+ sin(f_1*x[..., i]),
17
+ sin(f_2*x[..., i]),
18
+ ...
19
+ sin(f_N * x[..., i]),
20
+ cos(x[..., i]),
21
+ cos(f_1*x[..., i]),
22
+ cos(f_2*x[..., i]),
23
+ ...
24
+ cos(f_N * x[..., i]),
25
+ x[..., i] # only present if include_input is True.
26
+ ], here f_i is the frequency.
27
+
28
+ The frequencies are indexed by i = 0, 1, 2, ..., num_freqs - 1.
29
+ If logspace is True, then the frequency f_i is 2^i;
30
+ Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
31
+
32
+ Args:
33
+ num_freqs (int): the number of frequencies, default is 6;
34
+ logspace (bool): If logspace is True, then the frequency f_i is 2^i,
35
+ otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
36
+ input_dim (int): the input dimension, default is 3;
37
+ include_input (bool): include the input tensor or not, default is True.
38
+
39
+ Attributes:
40
+ frequencies (torch.Tensor): If logspace is True, then the frequency f_i is 2^i,
41
+ otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
42
+
43
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
44
+ otherwise, it is input_dim * num_freqs * 2.
45
+
46
+ """
47
+
48
+ def __init__(self,
49
+ num_freqs: int = 6,
50
+ logspace: bool = True,
51
+ input_dim: int = 3,
52
+ include_input: bool = True,
53
+ include_pi: bool = True) -> None:
54
+
55
+ """The initialization"""
56
+
57
+ super().__init__()
58
+
59
+ if logspace:
60
+ frequencies = 2.0 ** torch.arange(
61
+ num_freqs,
62
+ dtype=torch.float32
63
+ )
64
+ else:
65
+ frequencies = torch.linspace(
66
+ 1.0,
67
+ 2.0 ** (num_freqs - 1),
68
+ num_freqs,
69
+ dtype=torch.float32
70
+ )
71
+
72
+ if include_pi:
73
+ frequencies *= torch.pi
74
+
75
+ self.register_buffer("frequencies", frequencies, persistent=False)
76
+ self.include_input = include_input
77
+ self.num_freqs = num_freqs
78
+
79
+ self.out_dim = self.get_dims(input_dim)
80
+
81
+ def get_dims(self, input_dim):
82
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
83
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
84
+
85
+ return out_dim
86
+
87
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ """ Forward process.
89
+
90
+ Args:
91
+ x: tensor of shape [..., dim]
92
+
93
+ Returns:
94
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
95
+ where temp is 1 if include_input is True and 0 otherwise.
96
+ """
97
+
98
+ if self.num_freqs > 0:
99
+ embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
100
+ if self.include_input:
101
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
102
+ else:
103
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
104
+ else:
105
+ return x
106
+
107
+
108
+ class LearnedFourierEmbedder(nn.Module):
109
+ """ following @crowsonkb "s lead with learned sinusoidal pos emb """
110
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
111
+
112
+ def __init__(self, in_channels, dim):
113
+ super().__init__()
114
+ assert (dim % 2) == 0
115
+ half_dim = dim // 2
116
+ per_channel_dim = half_dim // in_channels
117
+ self.weights = nn.Parameter(torch.randn(per_channel_dim))
118
+
119
+ def forward(self, x):
120
+ """
121
+
122
+ Args:
123
+ x (torch.FloatTensor): [..., c]
124
+
125
+ Returns:
126
+ x (torch.FloatTensor): [..., d]
127
+ """
128
+
129
+ # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
130
+ freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
131
+ fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
132
+ return fouriered
133
+
134
+
135
+ class TriplaneLearnedFourierEmbedder(nn.Module):
136
+ def __init__(self, in_channels, dim):
137
+ super().__init__()
138
+
139
+ self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
140
+ self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
141
+ self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
142
+
143
+ self.out_dim = in_channels + dim
144
+
145
+ def forward(self, x):
146
+
147
+ yz_embed = self.yz_plane_embedder(x)
148
+ xz_embed = self.xz_plane_embedder(x)
149
+ xy_embed = self.xy_plane_embedder(x)
150
+
151
+ embed = yz_embed + xz_embed + xy_embed
152
+
153
+ return embed
154
+
155
+
156
+ def sequential_pos_embed(num_len, embed_dim):
157
+ assert embed_dim % 2 == 0
158
+
159
+ pos = torch.arange(num_len, dtype=torch.float32)
160
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32)
161
+ omega /= embed_dim / 2.
162
+ omega = 1. / 10000 ** omega # (D/2,)
163
+
164
+ pos = pos.reshape(-1) # (M,)
165
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
166
+
167
+ emb_sin = torch.sin(out) # (M, D/2)
168
+ emb_cos = torch.cos(out) # (M, D/2)
169
+
170
+ embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
171
+
172
+ return embeddings
173
+
174
+
175
+ def timestep_embedding(timesteps, dim, max_period=10000):
176
+ """
177
+ Create sinusoidal timestep embeddings.
178
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
179
+ These may be fractional.
180
+ :param dim: the dimension of the output.
181
+ :param max_period: controls the minimum frequency of the embeddings.
182
+ :return: an [N x dim] Tensor of positional embeddings.
183
+ """
184
+ half = dim // 2
185
+ freqs = torch.exp(
186
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
187
+ ).to(device=timesteps.device)
188
+ args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
189
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190
+ if dim % 2:
191
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
192
+ return embedding
193
+
194
+
195
+ def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
196
+ num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
197
+ log2_hashmap_size=19, desired_resolution=None):
198
+ if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
199
+ return nn.Identity(), input_dim
200
+
201
+ elif embed_type == "fourier":
202
+ embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
203
+ logspace=True, include_input=True)
204
+ return embedder_obj, embedder_obj.out_dim
205
+
206
+ elif embed_type == "hashgrid":
207
+ raise NotImplementedError
208
+
209
+ elif embed_type == "sphere_harmonic":
210
+ raise NotImplementedError
211
+
212
+ else:
213
+ raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}")
src/model/michelangelo/models/modules/transformer_blocks.py ADDED
@@ -0,0 +1,327 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from typing import Optional
8
+ import os
9
+
10
+ from .checkpoint import checkpoint
11
+ from ...utils.misc import use_flash3
12
+
13
+
14
+ if use_flash3.is_use:
15
+ from flash_attn_interface import flash_attn_func
16
+ print("use flash attention 3.")
17
+ else:
18
+ print("use flash attention 2.")
19
+
20
+ def init_linear(l, stddev):
21
+ nn.init.normal_(l.weight, std=stddev)
22
+ if l.bias is not None:
23
+ nn.init.constant_(l.bias, 0.0)
24
+
25
+ def flash_attention(q, k, v):
26
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
27
+ if use_flash3.is_use:
28
+ out, _ = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous())
29
+ # out = flash_attn_func(q, k, v)
30
+
31
+ # q_ = q.transpose(1, 2)
32
+ # k_ = k.transpose(1, 2)
33
+ # v_ = v.transpose(1, 2)
34
+
35
+ # # print(q.shape, k.shape, v.shape)
36
+ # out_ = F.scaled_dot_product_attention(q_, k_, v_)
37
+ # out_ = out_.transpose(1, 2)
38
+
39
+ # # print(torch.abs(out - out_).mean())
40
+ # assert torch.abs(out - out_).mean() < 1e-2, f"the error {torch.abs(out - out_).mean()} is too large"
41
+
42
+ # out = out_
43
+
44
+ # print("use flash_atten 3")
45
+ else:
46
+ q = q.transpose(1, 2)
47
+ k = k.transpose(1, 2)
48
+ v = v.transpose(1, 2)
49
+ out = F.scaled_dot_product_attention(q, k, v)
50
+ out = out.transpose(1, 2)
51
+ # print("use flash atten 2")
52
+
53
+ return out
54
+
55
+ class MultiheadAttention(nn.Module):
56
+ def __init__(
57
+ self,
58
+ *,
59
+ device: torch.device,
60
+ dtype: torch.dtype,
61
+ n_ctx: int,
62
+ width: int,
63
+ heads: int,
64
+ init_scale: float,
65
+ qkv_bias: bool,
66
+ flash: bool = False
67
+ ):
68
+ super().__init__()
69
+ self.n_ctx = n_ctx
70
+ self.width = width
71
+ self.heads = heads
72
+ self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
73
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
74
+ self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
75
+ init_linear(self.c_qkv, init_scale)
76
+ init_linear(self.c_proj, init_scale)
77
+
78
+ def forward(self, x):
79
+ x = self.c_qkv(x)
80
+ x = checkpoint(self.attention, (x,), (), False)
81
+ x = self.c_proj(x)
82
+ return x
83
+
84
+
85
+ class QKVMultiheadAttention(nn.Module):
86
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
87
+ super().__init__()
88
+ self.device = device
89
+ self.dtype = dtype
90
+ self.heads = heads
91
+ self.n_ctx = n_ctx
92
+ self.flash = flash
93
+
94
+ def forward(self, qkv):
95
+ bs, n_ctx, width = qkv.shape
96
+ attn_ch = width // self.heads // 3
97
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
98
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
99
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
100
+
101
+ if self.flash:
102
+ out = flash_attention(q, k, v)
103
+ out = out.reshape(out.shape[0], out.shape[1], -1)
104
+ else:
105
+ weight = torch.einsum(
106
+ "bthc,bshc->bhts", q * scale, k * scale
107
+ ) # More stable with f16 than dividing afterwards
108
+ wdtype = weight.dtype
109
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
110
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
111
+
112
+ return out
113
+
114
+
115
+ class ResidualAttentionBlock(nn.Module):
116
+ def __init__(
117
+ self,
118
+ *,
119
+ device: torch.device,
120
+ dtype: torch.dtype,
121
+ n_ctx: int,
122
+ width: int,
123
+ heads: int,
124
+ init_scale: float = 1.0,
125
+ qkv_bias: bool = True,
126
+ flash: bool = False,
127
+ use_checkpoint: bool = False
128
+ ):
129
+ super().__init__()
130
+
131
+ self.use_checkpoint = use_checkpoint
132
+
133
+ self.attn = MultiheadAttention(
134
+ device=device,
135
+ dtype=dtype,
136
+ n_ctx=n_ctx,
137
+ width=width,
138
+ heads=heads,
139
+ init_scale=init_scale,
140
+ qkv_bias=qkv_bias,
141
+ flash=flash
142
+ )
143
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
144
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
145
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
146
+
147
+ def _forward(self, x: torch.Tensor):
148
+ x = x + self.attn(self.ln_1(x))
149
+ x = x + self.mlp(self.ln_2(x))
150
+ return x
151
+
152
+ def forward(self, x: torch.Tensor):
153
+ return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
154
+
155
+
156
+ class MultiheadCrossAttention(nn.Module):
157
+ def __init__(
158
+ self,
159
+ *,
160
+ device: torch.device,
161
+ dtype: torch.dtype,
162
+ width: int,
163
+ heads: int,
164
+ init_scale: float,
165
+ qkv_bias: bool = True,
166
+ flash: bool = False,
167
+ n_data: Optional[int] = None,
168
+ data_width: Optional[int] = None,
169
+ ):
170
+ super().__init__()
171
+ self.n_data = n_data
172
+ self.width = width
173
+ self.heads = heads
174
+ self.data_width = width if data_width is None else data_width
175
+ self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
176
+ self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
177
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
178
+ self.attention = QKVMultiheadCrossAttention(
179
+ device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
180
+ )
181
+ init_linear(self.c_q, init_scale)
182
+ init_linear(self.c_kv, init_scale)
183
+ init_linear(self.c_proj, init_scale)
184
+
185
+ def forward(self, x, data):
186
+ x = self.c_q(x)
187
+ data = self.c_kv(data)
188
+ x = checkpoint(self.attention, (x, data), (), False)
189
+ x = self.c_proj(x)
190
+ return x
191
+
192
+
193
+ class QKVMultiheadCrossAttention(nn.Module):
194
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
195
+ flash: bool = False, n_data: Optional[int] = None):
196
+
197
+ super().__init__()
198
+ self.device = device
199
+ self.dtype = dtype
200
+ self.heads = heads
201
+ self.n_data = n_data
202
+ self.flash = flash
203
+
204
+ def forward(self, q, kv):
205
+ _, n_ctx, _ = q.shape
206
+ bs, n_data, width = kv.shape
207
+ attn_ch = width // self.heads // 2
208
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
209
+ q = q.view(bs, n_ctx, self.heads, -1)
210
+ kv = kv.view(bs, n_data, self.heads, -1)
211
+ k, v = torch.split(kv, attn_ch, dim=-1)
212
+
213
+ if self.flash:
214
+ out = flash_attention(q, k, v)
215
+ out = out.reshape(out.shape[0], out.shape[1], -1)
216
+ else:
217
+ weight = torch.einsum(
218
+ "bthc,bshc->bhts", q * scale, k * scale
219
+ ) # More stable with f16 than dividing afterwards
220
+ wdtype = weight.dtype
221
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
222
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
223
+
224
+ return out
225
+
226
+
227
+ class ResidualCrossAttentionBlock(nn.Module):
228
+ def __init__(
229
+ self,
230
+ *,
231
+ device: Optional[torch.device],
232
+ dtype: Optional[torch.dtype],
233
+ n_data: Optional[int] = None,
234
+ width: int,
235
+ heads: int,
236
+ data_width: Optional[int] = None,
237
+ mlp_width_scale: int = 4,
238
+ init_scale: float = 0.25,
239
+ qkv_bias: bool = True,
240
+ flash: bool = False
241
+ ):
242
+ super().__init__()
243
+
244
+ if data_width is None:
245
+ data_width = width
246
+
247
+ self.attn = MultiheadCrossAttention(
248
+ device=device,
249
+ dtype=dtype,
250
+ n_data=n_data,
251
+ width=width,
252
+ heads=heads,
253
+ data_width=data_width,
254
+ init_scale=init_scale,
255
+ qkv_bias=qkv_bias,
256
+ flash=flash,
257
+ )
258
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
259
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
260
+ self.mlp = MLP(device=device, dtype=dtype, width=width, hidden_width_scale=mlp_width_scale, init_scale=init_scale)
261
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
262
+
263
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
264
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
265
+ x = x + self.mlp(self.ln_3(x))
266
+ return x
267
+
268
+
269
+ class MLP(nn.Module):
270
+ def __init__(self, *,
271
+ device: Optional[torch.device],
272
+ dtype: Optional[torch.dtype],
273
+ width: int,
274
+ hidden_width_scale: int = 4,
275
+ init_scale: float):
276
+ super().__init__()
277
+ self.width = width
278
+ self.c_fc = nn.Linear(width, width * hidden_width_scale, device=device, dtype=dtype)
279
+ self.c_proj = nn.Linear(width * hidden_width_scale, width, device=device, dtype=dtype)
280
+ self.gelu = nn.GELU()
281
+ init_linear(self.c_fc, init_scale)
282
+ init_linear(self.c_proj, init_scale)
283
+
284
+ def forward(self, x):
285
+ return self.c_proj(self.gelu(self.c_fc(x)))
286
+
287
+
288
+ class Transformer(nn.Module):
289
+ def __init__(
290
+ self,
291
+ *,
292
+ device: Optional[torch.device],
293
+ dtype: Optional[torch.dtype],
294
+ n_ctx: int,
295
+ width: int,
296
+ layers: int,
297
+ heads: int,
298
+ init_scale: float = 0.25,
299
+ qkv_bias: bool = True,
300
+ flash: bool = False,
301
+ use_checkpoint: bool = False
302
+ ):
303
+ super().__init__()
304
+ self.n_ctx = n_ctx
305
+ self.width = width
306
+ self.layers = layers
307
+ self.resblocks = nn.ModuleList(
308
+ [
309
+ ResidualAttentionBlock(
310
+ device=device,
311
+ dtype=dtype,
312
+ n_ctx=n_ctx,
313
+ width=width,
314
+ heads=heads,
315
+ init_scale=init_scale,
316
+ qkv_bias=qkv_bias,
317
+ flash=flash,
318
+ use_checkpoint=use_checkpoint
319
+ )
320
+ for _ in range(layers)
321
+ ]
322
+ )
323
+
324
+ def forward(self, x: torch.Tensor):
325
+ for block in self.resblocks:
326
+ x = block(x)
327
+ return x
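A minimal smoke-test sketch for the blocks above, using the 257-token / width-768 shapes suggested by the perceiver configuration elsewhere in this release; device=None / dtype=None fall back to default CPU float32 parameters, and flash=False keeps the scaled_dot_product_attention / einsum path.

import torch
from src.model.michelangelo.models.modules.transformer_blocks import Transformer, ResidualCrossAttentionBlock

cross_attn = ResidualCrossAttentionBlock(device=None, dtype=None, width=768, heads=12, flash=False)
self_attn = Transformer(device=None, dtype=None, n_ctx=257, width=768, layers=2, heads=12, flash=False)

latents = torch.randn(2, 257, 768)   # learned queries / latent tokens
data = torch.randn(2, 4096, 768)     # projected point-cloud features

latents = cross_attn(latents, data)  # queries attend to the point features
latents = self_attn(latents)         # self-attention over the latent tokens
print(latents.shape)                 # torch.Size([2, 257, 768])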
src/model/michelangelo/models/tsal/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
src/model/michelangelo/models/tsal/loss.py ADDED
@@ -0,0 +1,454 @@
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from typing import Optional, Tuple, Dict
7
+
8
+ from ..modules.distributions import DiagonalGaussianDistribution
9
+ from ...utils.eval import compute_psnr
10
+ from ...utils import misc
11
+ import numpy as np
12
+ from copy import deepcopy
13
+
14
+
15
+ def logits_to_sdf(logits):
16
+ return torch.sigmoid(logits) * 2 - 1
17
+
18
+ class KLNearFar(nn.Module):
19
+ def __init__(self,
20
+ near_weight: float = 0.1,
21
+ kl_weight: float = 1.0,
22
+ num_near_samples: Optional[int] = None):
23
+
24
+ super().__init__()
25
+
26
+ self.near_weight = near_weight
27
+ self.kl_weight = kl_weight
28
+ self.num_near_samples = num_near_samples
29
+ self.geo_criterion = nn.BCEWithLogitsLoss()
30
+
31
+ def forward(self,
32
+ posteriors: Optional[DiagonalGaussianDistribution],
33
+ logits: torch.FloatTensor,
34
+ labels: torch.FloatTensor,
35
+ split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
36
+
37
+ """
38
+
39
+ Args:
40
+ posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
41
+ logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
42
+ labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
43
+ split (str):
44
+ **kwargs:
45
+
46
+ Returns:
47
+ loss (torch.Tensor): (,)
48
+ log (dict):
49
+
50
+ """
51
+
52
+ if self.num_near_samples is None:
53
+ num_vol = logits.shape[1] // 2
54
+ else:
55
+ num_vol = logits.shape[1] - self.num_near_samples
56
+
57
+ vol_logits = logits[:, 0:num_vol]
58
+ vol_labels = labels[:, 0:num_vol]
59
+
60
+ near_logits = logits[:, num_vol:]
61
+ near_labels = labels[:, num_vol:]
62
+
63
+ # occupancy loss
64
+ # vol_bce = self.geo_criterion(vol_logits, vol_labels)
65
+ # near_bce = self.geo_criterion(near_logits, near_labels)
66
+ vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
67
+ near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
68
+
69
+ if posteriors is None:
70
+ kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
71
+ else:
72
+ kl_loss = posteriors.kl(dims=(1, 2))
73
+ kl_loss = torch.mean(kl_loss)
74
+
75
+ loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight
76
+
77
+ with torch.no_grad():
78
+ preds = logits >= 0
79
+ accuracy = (preds == labels).float()
80
+ accuracy = accuracy.mean()
81
+ pos_ratio = torch.mean(labels)
82
+
83
+ log = {
84
+ "{}/total_loss".format(split): loss.clone().detach(),
85
+ "{}/near".format(split): near_bce.detach(),
86
+ "{}/far".format(split): vol_bce.detach(),
87
+ "{}/kl".format(split): kl_loss.detach(),
88
+ "{}/accuracy".format(split): accuracy,
89
+ "{}/pos_ratio".format(split): pos_ratio
90
+ }
91
+
92
+ if posteriors is not None:
93
+ log[f"{split}/mean"] = posteriors.mean.mean().detach()
94
+ log[f"{split}/std_mean"] = posteriors.std.mean().detach()
95
+ log[f"{split}/std_max"] = posteriors.std.max().detach()
96
+
97
+ return loss, log
98
+
99
+
100
+ class KLNearFarColor(nn.Module):
101
+ def __init__(self,
102
+ near_weight: float = 0.1,
103
+ kl_weight: float = 1.0,
104
+ color_weight: float = 1.0,
105
+ color_criterion: str = "mse",
106
+ num_near_samples: Optional[int] = None):
107
+
108
+ super().__init__()
109
+
110
+ self.color_weight = color_weight
111
+ self.near_weight = near_weight
112
+ self.kl_weight = kl_weight
113
+ self.num_near_samples = num_near_samples
114
+
115
+ if color_criterion == "mse":
116
+ self.color_criterion = nn.MSELoss()
117
+
118
+ elif color_criterion == "l1":
119
+ self.color_criterion = nn.L1Loss()
120
+
121
+ else:
122
+ raise ValueError(f"{color_criterion} must be [`mse`, `l1`].")
123
+
124
+ self.geo_criterion = nn.BCEWithLogitsLoss()
125
+
126
+ def forward(self,
127
+ posteriors: Optional[DiagonalGaussianDistribution],
128
+ logits: torch.FloatTensor,
129
+ labels: torch.FloatTensor,
130
+ pred_colors: torch.FloatTensor,
131
+ gt_colors: torch.FloatTensor,
132
+ split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
133
+
134
+ """
135
+
136
+ Args:
137
+ posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
138
+ logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
139
+ labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
140
+ pred_colors (torch.FloatTensor): [B, M, 3]
141
+ gt_colors (torch.FloatTensor): [B, M, 3]
142
+ split (str):
143
+ **kwargs:
144
+
145
+ Returns:
146
+ loss (torch.Tensor): (,)
147
+ log (dict):
148
+
149
+ """
150
+
151
+ if self.num_near_samples is None:
152
+ num_vol = logits.shape[1] // 2
153
+ else:
154
+ num_vol = logits.shape[1] - self.num_near_samples
155
+
156
+ vol_logits = logits[:, 0:num_vol]
157
+ vol_labels = labels[:, 0:num_vol]
158
+
159
+ near_logits = logits[:, num_vol:]
160
+ near_labels = labels[:, num_vol:]
161
+
162
+ # occupancy loss
163
+ # vol_bce = self.geo_criterion(vol_logits, vol_labels)
164
+ # near_bce = self.geo_criterion(near_logits, near_labels)
165
+ vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
166
+ near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
167
+
168
+ # surface color loss
169
+ color = self.color_criterion(pred_colors, gt_colors)
170
+
171
+ if posteriors is None:
172
+ kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device)
173
+ else:
174
+ kl_loss = posteriors.kl(dims=(1, 2))
175
+ kl_loss = torch.mean(kl_loss)
176
+
177
+ loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight
178
+
179
+ with torch.no_grad():
180
+ preds = logits >= 0
181
+ accuracy = (preds == labels).float()
182
+ accuracy = accuracy.mean()
183
+ psnr = compute_psnr(pred_colors, gt_colors)
184
+
185
+ log = {
186
+ "{}/total_loss".format(split): loss.clone().detach(),
187
+ "{}/near".format(split): near_bce.detach(),
188
+ "{}/far".format(split): vol_bce.detach(),
189
+ "{}/color".format(split): color.detach(),
190
+ "{}/kl".format(split): kl_loss.detach(),
191
+ "{}/psnr".format(split): psnr.detach(),
192
+ "{}/accuracy".format(split): accuracy
193
+ }
194
+
195
+ return loss, log
196
+
197
+
198
+ class ContrastKLNearFar(nn.Module):
199
+ def __init__(self,
200
+ contrast_weight: float = 1.0,
201
+ near_weight: float = 0.1,
202
+ kl_weight: float = 1.0,
203
+ normal_weight: float = 0.0,
204
+ surface_weight: float = 0.0,
205
+ eikonal_weight: float = 0.0,
206
+ sdf_bce_weight: float = 0.0,
207
+ sdf_l1l2_weight: float = 1.0,
208
+ num_near_samples: Optional[int] = None,
209
+ sdf_trunc_val: float = 0.05,
210
+ gt_sdf_soft: bool = False,
211
+ normal_supervision_type: str = "cosine",
212
+ supervision_type: str = 'occupancy'):
213
+
214
+ super().__init__()
215
+
216
+ self.labels = None
217
+ self.last_local_batch_size = None
218
+ self.supervision_type = supervision_type
219
+
220
+ assert normal_supervision_type in ["l1", "l2", "cosine", "l1_cosine", "l2_cosine", "von_mises"]
221
+ self.normal_supervision_type = normal_supervision_type
222
+
223
+ self.contrast_weight = contrast_weight
224
+ self.near_weight = near_weight
225
+ self.kl_weight = kl_weight
226
+ self.normal_weight = normal_weight
227
+ self.surface_weight = surface_weight
228
+ self.eikonal_weight = eikonal_weight
229
+ self.sdf_bce_weight = sdf_bce_weight # only used in sigmoid-sdf
230
+ self.sdf_l1l2_weight = sdf_l1l2_weight # only used in sigmoid-sdf
231
+ self.sdf_trunc_val = sdf_trunc_val
232
+ self.gt_sdf_soft = gt_sdf_soft
233
+ self.num_near_samples = num_near_samples
234
+ self.geo_criterion = nn.BCEWithLogitsLoss()
235
+ self.geo_criterion_sdf = nn.MSELoss()
236
+
237
+ def sdf_loss(self, pred_sdf, gt_sdf):
238
+ scaled_sdf = gt_sdf / self.sdf_trunc_val
239
+ greater_mask = (scaled_sdf > 1.).float()  # cast to float; the mask arithmetic below is not defined for bool tensors
240
+ smaller_mask = (scaled_sdf < -1.).float()
241
+ inside_mask = 1. - greater_mask - smaller_mask
242
+ greater_loss = F.smooth_l1_loss(F.relu(1. - pred_sdf), torch.zeros_like(pred_sdf), reduction="none") * greater_mask
243
+ smaller_loss = F.smooth_l1_loss(F.relu(pred_sdf + 1.), torch.zeros_like(pred_sdf), reduction="none") * smaller_mask
244
+ inside_loss = F.smooth_l1_loss(pred_sdf, gt_sdf, beta=1e-2, reduction="none") * inside_mask
245
+ loss = (greater_loss + smaller_loss + inside_loss).mean()
246
+ return loss
247
+
248
+ def von_mises(self, x, y, k=1):
249
+ cos = F.cosine_similarity(x, y, dim=-1)
250
+ exp = torch.exp(k * (cos - 1))
251
+ return 1 - exp
252
+
253
+ def forward(self,
254
+ shape_embed: torch.FloatTensor,
255
+ text_embed: torch.FloatTensor,
256
+ image_embed: torch.FloatTensor,
257
+ logit_scale: torch.FloatTensor,
258
+ posteriors: Optional[DiagonalGaussianDistribution],
259
+ latents: torch.FloatTensor,
260
+ shape_logits: torch.FloatTensor,
261
+ shape_labels: torch.FloatTensor,
262
+ surface_logits: Optional[torch.FloatTensor],
263
+ surface_normals: Optional[torch.FloatTensor],
264
+ gt_surface_normals: Optional[torch.FloatTensor],
265
+ split: Optional[str] = "train", **kwargs):
266
+ if self.supervision_type == 'occupancy':
267
+ shape_logits = shape_logits.squeeze(-1)
268
+ shape_labels[shape_labels>=0] = 1
269
+ shape_labels[shape_labels<0] = 0
270
+
271
+ elif self.supervision_type == 'occupancy-shapenet':
272
+ shape_logits = shape_logits.squeeze(-1)
273
+
274
+ elif self.supervision_type == 'occupancy-w-surface':
275
+ shape_logits = shape_logits.squeeze(-1)
276
+ shape_labels[shape_labels==10] = 0
277
+ shape_labels[shape_labels>0] = 1
278
+ shape_labels[shape_labels<0] = 0
279
+
280
+ elif 'sdf' in self.supervision_type:
281
+ shape_logits = shape_logits.squeeze(-1)
282
+ if self.gt_sdf_soft:
283
+ shape_labels_sdf = torch.tanh(shape_labels / self.sdf_trunc_val)# * self.sdf_trunc_val
284
+ else:
285
+ shape_labels_sdf = torch.clamp(shape_labels, min=-self.sdf_trunc_val, max=self.sdf_trunc_val) / self.sdf_trunc_val
286
+ else:
287
+ raise ValueError(f"Invalid supervision_type {self.supervision_type}")
288
+
289
+ local_batch_size = shape_embed.size(0)
290
+
291
+ if local_batch_size != self.last_local_batch_size:
292
+ self.labels = local_batch_size * misc.get_rank() + torch.arange(
293
+ local_batch_size, device=shape_embed.device
294
+ ).long()
295
+ self.last_local_batch_size = local_batch_size
296
+
297
+
298
+ if text_embed is not None and image_embed is not None:
299
+ # normalized features
300
+ shape_embed = F.normalize(shape_embed, dim=-1, p=2)
301
+ text_embed = F.normalize(text_embed, dim=-1, p=2)
302
+ image_embed = F.normalize(image_embed, dim=-1, p=2)
303
+
304
+ # gather features from all GPUs
305
+ shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch(
306
+ [shape_embed, text_embed, image_embed]
307
+ )
308
+
309
+ # cosine similarity as logits
310
+ logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t()
311
+ logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t()
312
+ logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t()
313
+ logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t()
314
+ contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) +
315
+ F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \
316
+ (F.cross_entropy(logits_per_shape_image, self.labels) +
317
+ F.cross_entropy(logits_per_image_shape, self.labels)) / 2
318
+ else:
319
+ contrast_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)
320
+
321
+ # shape reconstruction
322
+ if self.num_near_samples is None:
323
+ num_vol = shape_logits.shape[1] // 2
324
+ else:
325
+ num_vol = shape_logits.shape[1] - self.num_near_samples
326
+
327
+ # occupancy/sdf loss
328
+ if self.supervision_type == 'occupancy' or self.supervision_type == 'occupancy-shapenet':
329
+ vol_logits = shape_logits[:, 0:num_vol]
330
+ vol_labels = shape_labels[:, 0:num_vol]
331
+
332
+ near_logits = shape_logits[:, num_vol:]
333
+ near_labels = shape_labels[:, num_vol:]
334
+
335
+ vol_loss = self.geo_criterion(vol_logits.float(), vol_labels.float())
336
+ near_loss = self.geo_criterion(near_logits.float(), near_labels.float())
337
+
338
+ elif 'sdf' in self.supervision_type:
339
+ if self.supervision_type == "sigmoid-sdf":
340
+ shape_sdfs = logits_to_sdf(shape_logits)
341
+ else:
342
+ shape_sdfs = shape_logits
343
+
344
+ vol_logits = shape_logits[:, 0:num_vol]
345
+ vol_sdfs = shape_sdfs[:, 0:num_vol]
346
+ vol_labels_sdf = shape_labels_sdf[:, 0:num_vol]
347
+
348
+ near_logits= shape_logits[:, num_vol:]
349
+ near_sdfs = shape_sdfs[:, num_vol:]
350
+ near_labels_sdf = shape_labels_sdf[:, num_vol:]
351
+
352
+ # use both sdf loss and occupancy loss
353
+ vol_loss = torch.mean(torch.abs(vol_sdfs - vol_labels_sdf)) + torch.mean((vol_sdfs - vol_labels_sdf) ** 2) #+ self.geo_criterion(vol_logits_sdf, vol_labels)
354
+ near_loss = torch.mean(torch.abs(near_sdfs - near_labels_sdf)) + torch.mean((near_sdfs - near_labels_sdf) ** 2) #+ self.geo_criterion(near_logits_sdf, near_labels)
355
+
356
+ if self.supervision_type == "sigmoid-sdf":
357
+ vol_labels = (vol_labels_sdf + 1) / 2
358
+ near_labels = (near_labels_sdf + 1) / 2
359
+ vol_loss = self.sdf_l1l2_weight * vol_loss + self.sdf_bce_weight * self.geo_criterion(vol_logits, vol_labels)
360
+ near_loss = self.sdf_l1l2_weight * near_loss + self.sdf_bce_weight * self.geo_criterion(near_logits, near_labels)
361
+ # print(vol_loss, self.sdf_bce_weight * self.geo_criterion(vol_logits, vol_labels))
362
+
363
+ # surface loss
364
+ if "sdf" in self.supervision_type and surface_logits is not None:
365
+ if self.supervision_type == "sigmoid-sdf":
366
+ surface_sdfs = logits_to_sdf(surface_logits)
367
+ else:
368
+ surface_sdfs = surface_logits
369
+ surface_loss = torch.mean(surface_sdfs ** 2)
370
+ else:
371
+ surface_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)
372
+
373
+ if surface_normals is not None and gt_surface_normals is not None and "sdf" in self.supervision_type:
374
+
375
+ valid_mask = surface_sdfs.squeeze(-1) < (self.sdf_trunc_val * 0.8)
376
+
377
+ if valid_mask is not None:
378
+ surface_normals = surface_normals[valid_mask]
379
+ gt_surface_normals = gt_surface_normals[valid_mask]
380
+
381
+ # eikonal loss
382
+ surface_normals_norm = torch.norm(surface_normals, dim=-1)
383
+ eikonal_loss = F.mse_loss(surface_normals_norm * self.sdf_trunc_val, surface_normals_norm.new_ones(surface_normals_norm.shape), reduction="mean")
384
+
385
+ # surface normal loss
386
+ # surface_normals = F.normalize(surface_normals, dim=-1)
387
+ surface_normals = surface_normals * self.sdf_trunc_val
388
+ gt_surface_normals = F.normalize(gt_surface_normals, dim=-1)
389
+
390
+ if self.normal_supervision_type == "cosine":
391
+ # use cosine similarity loss
392
+ normal_loss = 1 - F.cosine_similarity(F.normalize(surface_normals, dim=-1), gt_surface_normals, dim=-1).mean()
393
+ elif self.normal_supervision_type == "l1":
394
+ # use l1 loss
395
+ normal_loss = F.l1_loss(surface_normals, gt_surface_normals)
396
+ elif self.normal_supervision_type == "l2":
397
+ normal_loss = F.mse_loss(surface_normals, gt_surface_normals)
398
+ elif self.normal_supervision_type == "von_mises":
399
+ normal_loss = self.von_mises(surface_normals, gt_surface_normals).mean()
400
+ elif self.normal_supervision_type == "l1_cosine":
401
+ normal_loss_cos = 1 - F.cosine_similarity(F.normalize(surface_normals, dim=-1), gt_surface_normals, dim=-1).mean()
402
+ normal_loss_l1 = F.l1_loss(surface_normals, gt_surface_normals)
403
+ normal_loss = normal_loss_cos + normal_loss_l1
404
+ elif self.normal_supervision_type == "l2_cosine":
405
+ normal_loss_cos = 1 - F.cosine_similarity(F.normalize(surface_normals, dim=-1), gt_surface_normals, dim=-1).mean()
406
+ normal_loss_l2 = F.mse_loss(surface_normals, gt_surface_normals)
407
+ normal_loss = normal_loss_cos + normal_loss_l2
408
+ else:
409
+ raise NotImplementedError
410
+ else:
411
+ normal_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)
412
+ eikonal_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)
413
+ surface_normals_norm = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)
414
+
415
+ if posteriors is None:
416
+ kl_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device)
417
+ else:
418
+ kl_loss = posteriors.kl(dims=(1, 2))
419
+ kl_loss = torch.mean(kl_loss)
420
+
421
+ loss = vol_loss + near_loss * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight + normal_loss * self.normal_weight + self.eikonal_weight * eikonal_loss + self.surface_weight * surface_loss
422
+
423
+ # compute accuracy
424
+ with torch.no_grad():
425
+ if "sdf" in self.supervision_type:
426
+ preds = shape_sdfs >= 0
427
+ sdf_labels = shape_labels_sdf >= 0
428
+ accuracy = (preds == sdf_labels).float()
429
+ else:
430
+ preds = shape_logits >= 0
431
+ accuracy = (preds == shape_labels).float()
432
+ accuracy = accuracy.mean()
433
+
434
+ log = {
435
+ # "{}/contrast".format(split): contrast_loss.clone().detach(),
436
+ "{}/near".format(split): near_loss.detach(),
437
+ "{}/far".format(split): vol_loss.detach(),
438
+ "{}/normal".format(split): normal_loss.detach(),
439
+ "{}/surface".format(split): surface_loss.detach(),
440
+ "{}/eikonal".format(split): eikonal_loss.detach(),
441
+ "{}/kl".format(split): kl_loss.detach(),
442
+ "{}/surface_grad_norm".format(split): surface_normals_norm.mean().detach(),
443
+ # "{}/shape_text_acc".format(split): shape_text_acc,
444
+ # "{}/shape_image_acc".format(split): shape_image_acc,
445
+ "{}/total_loss".format(split): loss.clone().detach(),
446
+ "{}/accuracy".format(split): accuracy,
447
+ }
448
+
449
+ if posteriors is not None:
450
+ log[f"{split}/posteriors_mean"] = posteriors.mean.mean().detach()
451
+ log[f"{split}/posteriors_std_mean"] = posteriors.std.mean().detach()
452
+ log[f"{split}/posteriors_std_max"] = posteriors.std.max().detach()
453
+
454
+ return loss, log, near_loss
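A hedged sketch of the simplest criterion above, KLNearFar, on random data: logits and labels are laid out as [volume samples | near-surface samples] along dim 1, and passing posteriors=None skips the KL term; the sizes and weights below are illustrative.

import torch
from src.model.michelangelo.models.tsal.loss import KLNearFar

criterion = KLNearFar(near_weight=0.1, kl_weight=1e-3, num_near_samples=1024)

B, n_vol, n_near = 2, 2048, 1024
logits = torch.randn(B, n_vol + n_near)                 # raw occupancy logits
labels = (torch.rand(B, n_vol + n_near) > 0.5).float()  # 0/1 occupancy targets

loss, log = criterion(posteriors=None, logits=logits, labels=labels, split="val")
print(loss.item(), log["val/accuracy"].item())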
src/model/michelangelo/models/tsal/sal_perceiver.py ADDED
@@ -0,0 +1,723 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Optional, Union
6
+ from einops import repeat
7
+ import math
8
+ import random
9
+ import time
10
+ import numpy as np
11
+
12
+ from ..modules import checkpoint
13
+ from ..modules.embedder import FourierEmbedder
14
+ from ..modules.distributions import DiagonalGaussianDistribution
15
+ from ..modules.transformer_blocks import (
16
+ ResidualCrossAttentionBlock,
17
+ Transformer
18
+ )
19
+ from ...utils.misc import use_flash3
20
+
21
+ from .tsal_base import ShapeAsLatentModule
22
+ from .loss import logits_to_sdf
23
+
24
+ from ....utils import fps
25
+
26
+ class CrossAttentionEncoder(nn.Module):
27
+
28
+ def __init__(self, *,
29
+ device: Optional[torch.device],
30
+ dtype: Optional[torch.dtype],
31
+ num_latents: int,
32
+ fourier_embedder: FourierEmbedder,
33
+ point_feats: int,
34
+ width: int,
35
+ heads: int,
36
+ layers: int,
37
+ init_scale: float = 0.25,
38
+ qkv_bias: bool = True,
39
+ flash: bool = False,
40
+ use_ln_post: bool = False,
41
+ use_checkpoint: bool = False,
42
+ query_method: bool = False,
43
+ use_full_input: bool = True,
44
+ token_num: int = 256,
45
+ no_query: bool=False):
46
+
47
+ super().__init__()
48
+
49
+ self.query_method = query_method
50
+ self.token_num = token_num
51
+ self.use_full_input = use_full_input
52
+
53
+ self.use_checkpoint = use_checkpoint
54
+ self.num_latents = num_latents
55
+
56
+ if no_query:
57
+ self.query = None
58
+ else:
59
+ self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
60
+
61
+ self.fourier_embedder = fourier_embedder
62
+ self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
63
+ self.cross_attn = ResidualCrossAttentionBlock(
64
+ device=device,
65
+ dtype=dtype,
66
+ width=width,
67
+ heads=heads,
68
+ init_scale=init_scale,
69
+ qkv_bias=qkv_bias,
70
+ flash=flash,
71
+ )
72
+
73
+ self.self_attn = Transformer(
74
+ device=device,
75
+ dtype=dtype,
76
+ n_ctx=num_latents,
77
+ width=width,
78
+ layers=layers,
79
+ heads=heads,
80
+ init_scale=init_scale,
81
+ qkv_bias=qkv_bias,
82
+ flash=flash,
83
+ use_checkpoint=False
84
+ )
85
+
86
+ if use_ln_post:
87
+ self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
88
+ else:
89
+ self.ln_post = None
90
+
91
+ def _forward(self, pc, feats):
92
+ """
93
+
94
+ Args:
95
+ pc (torch.FloatTensor): [B, N, 3]
96
+ feats (torch.FloatTensor or None): [B, N, C]
97
+
98
+ Returns:
99
+
100
+ """
101
+ if self.query_method:
102
+ token_num = self.num_latents
103
+ bs = pc.shape[0] #pc [10, 204800, 3]
104
+ data = self.fourier_embedder(pc) #[10, 204800, 51]
105
+ if feats is not None: #[10, 204800, 3]
106
+ data = torch.cat([data, feats], dim=-1)
107
+ data = self.input_proj(data) #[10, 204800, 768]
108
+
109
+ query = repeat(self.query, "m c -> b m c", b=bs) #[10, 257, 768]
110
+
111
+ latents = self.cross_attn(query, data)
112
+ latents = self.self_attn(latents)
113
+
114
+ if self.ln_post is not None:
115
+ latents = self.ln_post(latents)
116
+
117
+ pre_pc = None
118
+ else:
119
+
120
+ if isinstance(self.token_num, int):
121
+ token_num = self.token_num
122
+ else:
123
+ token_num = random.choice(self.token_num)
124
+ # print(token_num,'-----------------------', flush=True)
125
+
126
+ if self.training:
127
+ rng = np.random.default_rng()
128
+ else:
129
+ rng = np.random.default_rng(seed=0)
130
+ ind = rng.choice(pc.shape[1], token_num * 4, replace=token_num * 4 > pc.shape[1])
131
+
132
+ pre_pc = pc[:,ind,:]
133
+ pre_feats = feats[:,ind,:]
134
+
135
+
136
+ B, N, D = pre_pc.shape #[10, 204800, 3]
137
+ C = pre_feats.shape[-1]
138
+ ###### fps
139
+ pos = pre_pc.view(B*N, D)
140
+ pos_feats = pre_feats.view(B*N, C)
141
+ batch = torch.arange(B).to(pc.device)
142
+ batch = torch.repeat_interleave(batch, N)
143
+
144
+ # ratio = 1.0 * token_num / N
145
+ idx = fps(pos, batch, ratio=1. / 4, random_start=self.training)
146
+
147
+ sampled_pc = pos[idx]
148
+ sampled_pc = sampled_pc.view(B, -1, 3)
149
+
150
+ sampled_feats = pos_feats[idx]
151
+ sampled_feats = sampled_feats.view(B, -1, C)
152
+
153
+ ######
154
+ if self.use_full_input:
155
+ data = self.fourier_embedder(pc) #[B, 20480, 51]
156
+ else:
157
+ data = self.fourier_embedder(pre_pc) # [B, 4 * token_num, 51]
158
+
159
+ if feats is not None: #[10, 204800, 3]
160
+ if not self.use_full_input:
161
+ feats = pre_feats
162
+ data = torch.cat([data, feats], dim=-1) #[10, 204800, 54]
163
+ data = self.input_proj(data) #[10, 204800, 768]
164
+
165
+ # print(data.shape)
166
+
167
+ sampled_data = self.fourier_embedder(sampled_pc) #[10, 256, 51]
168
+ if feats is not None: #[10, 256, 3]
169
+ sampled_data = torch.cat([sampled_data, sampled_feats], dim=-1) #[10, 256, 54]
170
+ sampled_data = self.input_proj(sampled_data) #[10, 256, 768]
171
+
172
+ latents = self.cross_attn(sampled_data, data) #[10, 256, 768]
173
+ latents = self.self_attn(latents)
174
+
175
+ if self.ln_post is not None:
176
+ latents = self.ln_post(latents)
177
+
178
+ pre_pc = torch.cat([pre_pc, pre_feats], dim=-1)
179
+
180
+ return latents, pc, token_num, pre_pc
181
+
182
+ def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
183
+ """
184
+
185
+ Args:
186
+ pc (torch.FloatTensor): [B, N, 3]
187
+ feats (torch.FloatTensor or None): [B, N, C]
188
+
189
+ Returns:
190
+ dict
191
+ """
192
+
193
+ return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
194
+
195
+
196
+ class CrossAttentionDecoder(nn.Module):
197
+
198
+ def __init__(self, *,
199
+ device: Optional[torch.device],
200
+ dtype: Optional[torch.dtype],
201
+ num_latents: int,
202
+ out_channels: int,
203
+ fourier_embedder: FourierEmbedder,
204
+ width: int,
205
+ heads: int,
206
+ init_scale: float = 0.25,
207
+ qkv_bias: bool = True,
208
+ flash: bool = False,
209
+ use_checkpoint: bool = False,
210
+ mlp_width_scale: int = 4,
211
+ supervision_type: str = 'occupancy'):
212
+
213
+ super().__init__()
214
+
215
+ self.use_checkpoint = use_checkpoint
216
+ self.fourier_embedder = fourier_embedder
217
+ self.supervision_type = supervision_type
218
+
219
+ self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
220
+
221
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
222
+ device=device,
223
+ dtype=dtype,
224
+ n_data=num_latents,
225
+ width=width,
226
+ heads=heads,
227
+ init_scale=init_scale,
228
+ qkv_bias=qkv_bias,
229
+ flash=flash,
230
+ mlp_width_scale=mlp_width_scale,
231
+ )
232
+
233
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
234
+ self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
235
+ if self.supervision_type == 'occupancy-sdf':
236
+ self.output_proj_sdf = nn.Linear(width, out_channels, device=device, dtype=dtype)
237
+
238
+
239
+
240
+ def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
241
+ if next(self.query_proj.parameters()).dtype == torch.float16:
242
+ queries = queries.half()
243
+ latents = latents.half()
244
+ # print(f"queries: {queries.dtype}, {queries.device}")
245
+ # print(f"latents: {latents.dtype}, {latents.device}"z)
246
+ queries = self.query_proj(self.fourier_embedder(queries))
247
+ x = self.cross_attn_decoder(queries, latents)
248
+ x = self.ln_post(x)
249
+ x_1 = self.output_proj(x)
250
+ if self.supervision_type == 'occupancy-sdf':
251
+ x_2 = self.output_proj_sdf(x)
252
+ return x_1, x_2
253
+ else:
254
+ return x_1
255
+
256
+ def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
257
+ return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
258
+
259
+
260
+ class ShapeAsLatentPerceiver(ShapeAsLatentModule):
261
+ def __init__(self, *,
262
+ device: Optional[torch.device],
263
+ dtype: Optional[torch.dtype],
264
+ num_latents: int,
265
+ point_feats: int = 0,
266
+ embed_dim: int = 0,
267
+ num_freqs: int = 8,
268
+ include_pi: bool = True,
269
+ width: int,
270
+ heads: int,
271
+ num_encoder_layers: int,
272
+ num_decoder_layers: int,
273
+ decoder_width: Optional[int] = None,
274
+ init_scale: float = 0.25,
275
+ qkv_bias: bool = True,
276
+ flash: bool = False,
277
+ use_ln_post: bool = False,
278
+ use_checkpoint: bool = False,
279
+ supervision_type: str = 'occupancy',
280
+ query_method: bool = False,
281
+ token_num: int = 256,
282
+ grad_type: str = "numerical",
283
+ grad_interval: float = 0.005,
284
+ use_full_input: bool = True,
285
+ freeze_encoder: bool = False,
286
+ decoder_mlp_width_scale: int = 4,
287
+ residual_kl: bool = False,
288
+ ):
289
+
290
+ super().__init__()
291
+
292
+ self.use_checkpoint = use_checkpoint
293
+
294
+ self.num_latents = num_latents
295
+ assert grad_type in ["numerical", "analytical"]
296
+ self.grad_type = grad_type
297
+ self.grad_interval = grad_interval
298
+ self.supervision_type = supervision_type
299
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
300
+
301
+ init_scale = init_scale * math.sqrt(1.0 / width)
302
+ self.encoder = CrossAttentionEncoder(
303
+ device=device,
304
+ dtype=dtype,
305
+ fourier_embedder=self.fourier_embedder,
306
+ num_latents=num_latents,
307
+ point_feats=point_feats,
308
+ width=width,
309
+ heads=heads,
310
+ layers=num_encoder_layers,
311
+ init_scale=init_scale,
312
+ qkv_bias=qkv_bias,
313
+ flash=flash,
314
+ use_ln_post=use_ln_post,
315
+ use_checkpoint=use_checkpoint,
316
+ query_method=query_method,
317
+ use_full_input=use_full_input,
318
+ token_num=token_num
319
+ )
320
+
321
+ self.embed_dim = embed_dim
322
+ self.residual_kl = residual_kl
323
+ if decoder_width is None:
324
+ decoder_width = width
325
+ if embed_dim > 0:
326
+ # VAE embed
327
+ self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
328
+ self.post_kl = nn.Linear(embed_dim, decoder_width, device=device, dtype=dtype)
329
+ self.latent_shape = (num_latents, embed_dim)
330
+ if self.residual_kl:
331
+ assert self.post_kl.out_features % self.post_kl.in_features == 0
332
+ assert self.pre_kl.in_features % self.pre_kl.out_features == 0
333
+ else:
334
+ self.latent_shape = (num_latents, width)
335
+
336
+ print("decoder width = ", decoder_width)
337
+
338
+ self.transformer = Transformer(
339
+ device=device,
340
+ dtype=dtype,
341
+ n_ctx=num_latents,
342
+ width=decoder_width,
343
+ layers=num_decoder_layers,
344
+ heads=heads,
345
+ init_scale=init_scale,
346
+ qkv_bias=qkv_bias,
347
+ flash=flash,
348
+ use_checkpoint=use_checkpoint
349
+ )
350
+
351
+ # geometry decoder
352
+ self.geo_decoder = CrossAttentionDecoder(
353
+ device=device,
354
+ dtype=dtype,
355
+ fourier_embedder=self.fourier_embedder,
356
+ out_channels=1,
357
+ num_latents=num_latents,
358
+ width=decoder_width,
359
+ heads=heads,
360
+ init_scale=init_scale,
361
+ qkv_bias=qkv_bias,
362
+ flash=flash,
363
+ use_checkpoint=use_checkpoint,
364
+ supervision_type=supervision_type,
365
+ mlp_width_scale=decoder_mlp_width_scale
366
+ )
367
+
368
+ if freeze_encoder:
369
+ for p in self.encoder.parameters():
370
+ p.requires_grad = False
371
+ for p in self.pre_kl.parameters():
372
+ p.requires_grad = False
373
+ print("freeze encoder and pre kl")
374
+
375
+ def encode(self,
376
+ pc: torch.FloatTensor,
377
+ feats: Optional[torch.FloatTensor] = None,
378
+ sample_posterior: bool = True):
379
+ """
380
+
381
+ Args:
382
+ pc (torch.FloatTensor): [B, N, 3]
383
+ feats (torch.FloatTensor or None): [B, N, C]
384
+ sample_posterior (bool):
385
+
386
+ Returns:
387
+ latents (torch.FloatTensor)
388
+ center_pos (torch.FloatTensor or None):
389
+ posterior (DiagonalGaussianDistribution or None):
390
+ """
391
+
392
+ latents, center_pos = self.encoder(pc, feats)
393
+
394
+ posterior = None
395
+ if self.embed_dim > 0:
396
+ moments = self.pre_kl(latents)
397
+ if self.residual_kl:
398
+ B, N = latents.shape[:2]
399
+ moments = moments + latents.view(B, N, -1, self.pre_kl.out_features).mean(dim=-2)
400
+ posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
401
+
402
+ if sample_posterior:
403
+ latents = posterior.sample()
404
+ else:
405
+ latents = posterior.mode()
406
+
407
+ return latents, center_pos, posterior
408
+
409
+ def decode(self, latents: torch.FloatTensor):
410
+ if self.residual_kl:
411
+ latents = latents.repeat_interleave(self.post_kl.out_features // self.post_kl.in_features, dim=-1) + self.post_kl(latents)
412
+ else:
413
+ latents = self.post_kl(latents)
414
+
415
+ return self.transformer(latents)
416
+
417
+ def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor, grad: bool = False):
418
+ # logits = self.geo_decoder(queries, latents).squeeze(-1)
419
+ if grad:
420
+ # with torch.autocast(device_type="cuda", dtype=torch.float32):
421
+ if self.grad_type == "numerical":
422
+ raise NotImplementedError
423
+ interval = self.grad_interval
424
+ # print('grad interval = ', interval)
425
+ grad_value = []
426
+ for offset in [(interval, 0, 0), (0, interval, 0), (0, 0, interval)]:
427
+ offset_tensor = torch.tensor(offset, device=queries.device)[None, :]
428
+ res_p = self.geo_decoder(queries + offset_tensor, latents)[..., 0]
429
+ res_n = self.geo_decoder(queries - offset_tensor, latents)[..., 0]
430
+ grad_value.append((res_p - res_n) / (2 * interval))
431
+ grad_value = torch.stack(grad_value, dim=-1)
432
+ else:
433
+ # print("auto grad")
434
+ queries_d = torch.clone(queries)
435
+ queries_d.requires_grad = True
436
+ with torch.enable_grad():
437
+ with use_flash3.disable_flash3():
438
+ logits = self.geo_decoder(queries_d, latents)
439
+ if self.supervision_type == "sigmoid-sdf":
440
+ sdfs = logits_to_sdf(logits)
441
+ grad_value = torch.autograd.grad(sdfs, [queries_d],
442
+ grad_outputs=torch.ones_like(sdfs),
443
+ create_graph=self.geo_decoder.training)[0]
444
+ else:
445
+ logits = self.geo_decoder(queries, latents)
446
+ grad_value = None
447
+
448
+ return logits, grad_value
449
+
450
+ def forward(self,
451
+ pc: torch.FloatTensor,
452
+ feats: torch.FloatTensor,
453
+ volume_queries: torch.FloatTensor,
454
+ sample_posterior: bool = True):
455
+ """
456
+
457
+ Args:
458
+ pc (torch.FloatTensor): [B, N, 3]
459
+ feats (torch.FloatTensor or None): [B, N, C]
460
+ volume_queries (torch.FloatTensor): [B, P, 3]
461
+ sample_posterior (bool):
462
+
463
+ Returns:
464
+ logits (torch.FloatTensor): [B, P]
465
+ center_pos (torch.FloatTensor): [B, M, 3]
466
+ posterior (DiagonalGaussianDistribution or None).
467
+
468
+ """
469
+
470
+ latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
471
+
472
+ latents = self.decode(latents)
473
+ logits, _ = self.query_geometry(volume_queries, latents)
474
+
475
+ return logits, center_pos, posterior
476
+
477
+
478
+ class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):
479
+
480
+ def __init__(self, *,
481
+ device: Optional[torch.device],
482
+ dtype: Optional[str],
483
+ num_latents: int,
484
+ point_feats: int = 0,
485
+ embed_dim: int = 0,
486
+ num_freqs: int = 8,
487
+ include_pi: bool = True,
488
+ width: int,
489
+ heads: int,
490
+ num_encoder_layers: int,
491
+ num_decoder_layers: int,
492
+ decoder_width: Optional[int] = None,
493
+ init_scale: float = 0.25,
494
+ qkv_bias: bool = True,
495
+ flash: bool = False,
496
+ use_ln_post: bool = False,
497
+ use_checkpoint: bool = False,
498
+ supervision_type: str = 'occupancy',
499
+ grad_type: str = "numerical",
500
+ grad_interval: float = 0.005,
501
+ query_method: bool = False,
502
+ use_full_input: bool = True,
503
+ token_num: int = 256,
504
+ freeze_encoder: bool = False,
505
+ decoder_mlp_width_scale: int = 4,
506
+ residual_kl: bool = False,
507
+ ):
508
+
509
+ MAP_DTYPE = {
510
+ 'float32': torch.float32,
511
+ 'float16': torch.float16,
512
+ 'bfloat16': torch.bfloat16,
513
+ }
514
+ if dtype is not None:
515
+ dtype = MAP_DTYPE[dtype]
516
+ super().__init__(
517
+ device=device,
518
+ dtype=dtype,
519
+ num_latents=1 + num_latents,
520
+ point_feats=point_feats,
521
+ embed_dim=embed_dim,
522
+ num_freqs=num_freqs,
523
+ include_pi=include_pi,
524
+ width=width,
525
+ decoder_width=decoder_width,
526
+ heads=heads,
527
+ num_encoder_layers=num_encoder_layers,
528
+ num_decoder_layers=num_decoder_layers,
529
+ init_scale=init_scale,
530
+ qkv_bias=qkv_bias,
531
+ flash=flash,
532
+ use_ln_post=use_ln_post,
533
+ use_checkpoint=use_checkpoint,
534
+ supervision_type=supervision_type,
535
+ grad_type=grad_type,
536
+ grad_interval=grad_interval,
537
+ query_method=query_method,
538
+ token_num=token_num,
539
+ use_full_input=use_full_input,
540
+ freeze_encoder=freeze_encoder,
541
+ decoder_mlp_width_scale=decoder_mlp_width_scale,
542
+ residual_kl=residual_kl,
543
+ )
544
+
545
+ self.width = width
546
+
547
+ def encode(self,
548
+ pc: torch.FloatTensor,
549
+ feats: Optional[torch.FloatTensor] = None,
550
+ sample_posterior: bool = True,
551
+ only_shape: bool=False):
552
+ """
553
+
554
+ Args:
555
+ pc (torch.FloatTensor): [B, N, 3]
556
+ feats (torch.FloatTensor or None): [B, N, c]
557
+ sample_posterior (bool):
558
+
559
+ Returns:
560
+ shape_embed (torch.FloatTensor)
561
+ kl_embed (torch.FloatTensor):
562
+ posterior (DiagonalGaussianDistribution or None):
563
+ """
564
+
565
+ shape_embed, latents, token_num, pre_pc = self.encode_latents(pc, feats)
566
+ if only_shape:
567
+ return shape_embed
568
+ kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)
569
+
570
+ return shape_embed, kl_embed, posterior, token_num, pre_pc
571
+
572
+ def encode_latents(self,
573
+ pc: torch.FloatTensor,
574
+ feats: Optional[torch.FloatTensor] = None):
575
+
576
+ x, _, token_num, pre_pc = self.encoder(pc, feats)
577
+
578
+ shape_embed = x[:, 0]
579
+ # latents = x[:, 1:]
580
+ # use all tokens
581
+ latents = x
582
+
583
+ return shape_embed, latents, token_num, pre_pc
584
+
585
+ def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
586
+ posterior = None
587
+ if self.embed_dim > 0:
588
+ moments = self.pre_kl(latents)
589
+ if self.residual_kl:
590
+ B, N = latents.shape[:2]
591
+ moments = moments + latents.view(B, N, -1, self.pre_kl.out_features).mean(dim=-2)
592
+ posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
593
+
594
+ if sample_posterior:
595
+ kl_embed = posterior.sample()
596
+ else:
597
+ kl_embed = posterior.mode()
598
+ else:
599
+ kl_embed = latents
600
+
601
+ return kl_embed, posterior
602
+
603
+ def forward(self,
604
+ pc: torch.FloatTensor,
605
+ feats: torch.FloatTensor,
606
+ volume_queries: torch.FloatTensor,
607
+ sample_posterior: bool = True):
608
+ """
609
+
610
+ Args:
611
+ pc (torch.FloatTensor): [B, N, 3]
612
+ feats (torch.FloatTensor or None): [B, N, C]
613
+ volume_queries (torch.FloatTensor): [B, P, 3]
614
+ sample_posterior (bool):
615
+
616
+ Returns:
617
+ shape_embed (torch.FloatTensor): [B, projection_dim]
618
+ logits (torch.FloatTensor): [B, M]
619
+ posterior (DiagonalGaussianDistribution or None), token_num, pre_pc, grad.
620
+
621
+ """
622
+
623
+ shape_embed, kl_embed, posterior, token_num, pre_pc = self.encode(pc, feats, sample_posterior=sample_posterior)
624
+
625
+ latents = self.decode(kl_embed)
626
+ logits, grad = self.query_geometry(volume_queries, latents)
627
+
628
+ return shape_embed, logits, posterior, token_num, pre_pc, grad
629
+
630
+ #####################################################
631
+ # a simplified version of the perceiver encoder
632
+ #####################################################
633
+
634
+ class ShapeAsLatentPerceiverEncoder(ShapeAsLatentModule):
635
+ def __init__(self, *,
636
+ device: Optional[torch.device],
637
+ dtype: Optional[Union[torch.dtype, str]],
638
+ num_latents: int,
639
+ point_feats: int = 0,
640
+ embed_dim: int = 0,
641
+ num_freqs: int = 8,
642
+ include_pi: bool = True,
643
+ width: int,
644
+ heads: int,
645
+ num_encoder_layers: int,
646
+ init_scale: float = 0.25,
647
+ qkv_bias: bool = True,
648
+ flash: bool = False,
649
+ use_ln_post: bool = False,
650
+ use_checkpoint: bool = False,
651
+ supervision_type: str = 'occupancy',
652
+ query_method: bool = False,
653
+ token_num: int = 256,
654
+ grad_type: str = "numerical",
655
+ grad_interval: float = 0.005,
656
+ use_full_input: bool = True,
657
+ freeze_encoder: bool = False,
658
+ residual_kl: bool = False,
659
+ ):
660
+
661
+ super().__init__()
662
+
663
+
664
+ MAP_DTYPE = {
665
+ 'float32': torch.float32,
666
+ 'float16': torch.float16,
667
+ 'bfloat16': torch.bfloat16,
668
+ }
669
+
670
+ if dtype is not None and isinstance(dtype, str):
671
+ dtype = MAP_DTYPE[dtype]
672
+
673
+ self.use_checkpoint = use_checkpoint
674
+
675
+ self.num_latents = num_latents
676
+ assert grad_type in ["numerical", "analytical"]
677
+ self.grad_type = grad_type
678
+ self.grad_interval = grad_interval
679
+ self.supervision_type = supervision_type
680
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
681
+
682
+ init_scale = init_scale * math.sqrt(1.0 / width)
683
+ self.encoder = CrossAttentionEncoder(
684
+ device=device,
685
+ dtype=dtype,
686
+ fourier_embedder=self.fourier_embedder,
687
+ num_latents=num_latents,
688
+ point_feats=point_feats,
689
+ width=width,
690
+ heads=heads,
691
+ layers=num_encoder_layers,
692
+ init_scale=init_scale,
693
+ qkv_bias=qkv_bias,
694
+ flash=flash,
695
+ use_ln_post=use_ln_post,
696
+ use_checkpoint=use_checkpoint,
697
+ query_method=query_method,
698
+ use_full_input=use_full_input,
699
+ token_num=token_num,
700
+ no_query=True,
701
+ )
702
+
703
+ self.embed_dim = embed_dim
704
+ self.residual_kl = residual_kl
705
+ if freeze_encoder:
706
+ for p in self.encoder.parameters():
707
+ p.requires_grad = False
708
+ print("freeze encoder")
709
+ self.width = width
710
+
711
+ def encode_latents(self,
712
+ pc: torch.FloatTensor,
713
+ feats: Optional[torch.FloatTensor] = None):
714
+
715
+ x, _, token_num, pre_pc = self.encoder(pc, feats)
716
+
717
+ shape_embed = x[:, 0]
718
+ latents = x
719
+
720
+ return shape_embed, latents, token_num, pre_pc
721
+
722
+ def forward(self):
723
+ raise NotImplementedError()
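A minimal sketch of the analytical-gradient path used in `query_geometry` above, with a toy MLP standing in for the cross-attention `geo_decoder` (the real decoder and `logits_to_sdf` are not reproduced here). It shows how per-query SDF gradients are taken with `torch.autograd.grad` and a `grad_outputs` of ones:

import torch

# hypothetical stand-in for the cross-attention geometry decoder
geo_decoder = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.GELU(), torch.nn.Linear(64, 1))

queries = torch.rand(2, 128, 3) * 2 - 1            # [B, P, 3] query points in [-1, 1]
queries_d = queries.clone().requires_grad_(True)   # copy that tracks gradients w.r.t. positions

with torch.enable_grad():
    sdfs = geo_decoder(queries_d)                  # [B, P, 1] predicted signed distances
    grad_value = torch.autograd.grad(
        sdfs, [queries_d],
        grad_outputs=torch.ones_like(sdfs),        # one gradient per query point
        create_graph=geo_decoder.training,         # keep the graph only while training
    )[0]                                           # [B, P, 3] spatial gradient of the SDF

print(grad_value.shape)  # torch.Size([2, 128, 3])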
src/model/michelangelo/models/tsal/tsal_base.py ADDED
@@ -0,0 +1,121 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch.nn as nn
4
+ from typing import Tuple, List, Optional
5
+ import lightning.pytorch as pl
6
+
7
+
8
+ class Point2MeshOutput(object):
9
+ def __init__(self):
10
+ self.mesh_v = None
11
+ self.mesh_f = None
12
+ self.center = None
13
+ self.pc = None
14
+
15
+
16
+ class Latent2MeshOutput(object):
17
+
18
+ def __init__(self):
19
+ self.mesh_v = None
20
+ self.mesh_f = None
21
+
22
+
23
+ class AlignedMeshOutput(object):
24
+
25
+ def __init__(self):
26
+ self.mesh_v = None
27
+ self.mesh_f = None
28
+ self.surface = None
29
+ self.image = None
30
+ self.text: Optional[str] = None
31
+ self.shape_text_similarity: Optional[float] = None
32
+ self.shape_image_similarity: Optional[float] = None
33
+
34
+
35
+ class ShapeAsLatentPLModule(pl.LightningModule):
36
+ latent_shape: Tuple[int]
37
+
38
+ def encode(self, surface, *args, **kwargs):
39
+ raise NotImplementedError
40
+
41
+ def decode(self, z_q, *args, **kwargs):
42
+ raise NotImplementedError
43
+
44
+ def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
45
+ raise NotImplementedError
46
+
47
+ def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
48
+ raise NotImplementedError
49
+
50
+
51
+ class ShapeAsLatentModule(nn.Module):
52
+ latent_shape: Tuple[int, int]
53
+
54
+ def __init__(self, *args, **kwargs):
55
+ super().__init__()
56
+
57
+ def encode(self, *args, **kwargs):
58
+ raise NotImplementedError
59
+
60
+ def decode(self, *args, **kwargs):
61
+ raise NotImplementedError
62
+
63
+ def query_geometry(self, *args, **kwargs):
64
+ raise NotImplementedError
65
+
66
+
67
+ class AlignedShapeAsLatentPLModule(pl.LightningModule):
68
+ latent_shape: Tuple[int]
69
+
70
+ def set_shape_model_only(self):
71
+ raise NotImplementedError
72
+
73
+ def encode(self, surface, *args, **kwargs):
74
+ raise NotImplementedError
75
+
76
+ def decode(self, z_q, *args, **kwargs):
77
+ raise NotImplementedError
78
+
79
+ def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
80
+ raise NotImplementedError
81
+
82
+ def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
83
+ raise NotImplementedError
84
+
85
+
86
+ class AlignedShapeAsLatentModule(nn.Module):
87
+ shape_model: ShapeAsLatentModule
88
+ latent_shape: Tuple[int, int]
89
+
90
+ def __init__(self, *args, **kwargs):
91
+ super().__init__()
92
+
93
+ def set_shape_model_only(self):
94
+ raise NotImplementedError
95
+
96
+ def encode_image_embed(self, *args, **kwargs):
97
+ raise NotImplementedError
98
+
99
+ def encode_text_embed(self, *args, **kwargs):
100
+ raise NotImplementedError
101
+
102
+ def encode_shape_embed(self, *args, **kwargs):
103
+ raise NotImplementedError
104
+
105
+
106
+ class TexturedShapeAsLatentModule(nn.Module):
107
+
108
+ def __init__(self, *args, **kwargs):
109
+ super().__init__()
110
+
111
+ def encode(self, *args, **kwargs):
112
+ raise NotImplementedError
113
+
114
+ def decode(self, *args, **kwargs):
115
+ raise NotImplementedError
116
+
117
+ def query_geometry(self, *args, **kwargs):
118
+ raise NotImplementedError
119
+
120
+ def query_color(self, *args, **kwargs):
121
+ raise NotImplementedError
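The classes in this file only pin down interfaces; concrete models such as ShapeAsLatentPerceiver fill them in. A purely hypothetical subclass, just to illustrate the expected encode / decode / query_geometry contract:

import torch.nn as nn

class ToyShapeAsLatent(ShapeAsLatentModule):
    latent_shape = (16, 8)

    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(3, 1)

    def encode(self, pc):                          # pc: [B, N, 3]
        return pc[:, :16, :]                       # pretend the first 16 points are the latents

    def decode(self, latents):
        return latents

    def query_geometry(self, queries, latents):    # queries: [B, P, 3]
        return self.decoder(queries).squeeze(-1)   # [B, P] occupancy logits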
src/model/michelangelo/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .misc import get_config_from_file
4
+ from .misc import instantiate_from_config
src/model/michelangelo/utils/eval.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+
5
+
6
+ def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):
7
+
8
+ mse = torch.mean((x - y) ** 2)
9
+ psnr = 10 * torch.log10(data_range / (mse + eps))
10
+
11
+ return psnr
12
+
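With the default `data_range=2`, `compute_psnr` is presumably meant for tensors normalized to [-1, 1]. A quick sanity check (values are illustrative only):

import torch

x = torch.rand(4, 2048, 3) * 2 - 1                    # reconstruction in [-1, 1]
y = (x + 0.01 * torch.randn_like(x)).clamp(-1, 1)     # slightly perturbed target
print(compute_psnr(x, y))                             # high PSNR: the error is small
print(compute_psnr(x, -x))                            # much lower PSNR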
src/model/michelangelo/utils/misc.py ADDED
@@ -0,0 +1,271 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import importlib
4
+ from omegaconf import OmegaConf, DictConfig, ListConfig
5
+ import time
6
+ import torch
7
+ import torch.distributed as dist
8
+ from typing import Union, Any, Optional
9
+ from collections import defaultdict
10
+ from torch.optim import lr_scheduler
11
+ import os
12
+ from dataclasses import dataclass, field
13
+ from contextlib import contextmanager
14
+
15
+ import logging
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+
20
+ def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8):
21
+ return int(num_data / (num_nodes * num_cards * batch_size)) * max_epochs
22
+
23
+
24
+ OmegaConf.register_new_resolver("calc_num_train_steps", calc_num_train_steps)
25
+ OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
26
+
27
+ @dataclass
28
+ class ExperimentConfig:
29
+ task: str = "vae"
30
+ output_dir: str = "outputs"
31
+ resume: Optional[str] = None
32
+
33
+ data: dict = field(default_factory=dict)
34
+ model: dict = field(default_factory=dict)
35
+
36
+ trainer: dict = field(default_factory=dict)
37
+ checkpoint: dict = field(default_factory=dict)
38
+
39
+ wandb: dict = field(default_factory=dict)
40
+
41
+ def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
42
+ scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
43
+ return scfg
44
+
45
+ def get_config_from_file(config_file: str, cli_args: list = [], **kwargs) -> Union[DictConfig, ListConfig]:
46
+ config_file = OmegaConf.load(config_file)
47
+ cli_conf = OmegaConf.from_cli(cli_args)
48
+
49
+ if 'base_config' in config_file.keys():
50
+ if config_file['base_config'] == "default_base":
51
+ base_config = OmegaConf.create()
52
+ # base_config = get_default_config()
53
+ elif config_file['base_config'].endswith(".yaml"):
54
+ base_config = get_config_from_file(config_file['base_config'])
55
+ else:
56
+ raise ValueError(f"{config_file} must be a `.yaml` file or contain a `base_config` key.")
57
+
58
+ config_file = {key: value for key, value in config_file.items() if key != "base_config"}
59
+
60
+ cfg = OmegaConf.merge(base_config, config_file, cli_conf, kwargs)
61
+ else:
62
+ cfg = OmegaConf.merge(config_file, cli_conf, kwargs)
63
+
64
+ scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg)
65
+
66
+ return scfg
67
+
68
+ def get_obj_from_str(string, reload=False):
69
+ module, cls = string.rsplit(".", 1)
70
+ if reload:
71
+ module_imp = importlib.import_module(module)
72
+ importlib.reload(module_imp)
73
+ return getattr(importlib.import_module(module, package=None), cls)
74
+
75
+
76
+ def get_obj_from_config(config):
77
+ if "target" not in config:
78
+ raise KeyError("Expected key `target` to instantiate.")
79
+
80
+ return get_obj_from_str(config["target"])
81
+
82
+
83
+ def instantiate_from_config(config, **kwargs):
84
+ if "target" not in config:
85
+ raise KeyError("Expected key `target` to instantiate.")
86
+
87
+ cls = get_obj_from_str(config["target"])
88
+
89
+ params = config.get("params", dict())
90
+ # params.update(kwargs)
91
+ # instance = cls(**params)
92
+ kwargs.update(params)
93
+ instance = cls(**kwargs)
94
+
95
+ return instance
96
+
97
+
98
+ def is_dist_avail_and_initialized():
99
+ if not dist.is_available():
100
+ return False
101
+ if not dist.is_initialized():
102
+ return False
103
+ return True
104
+
105
+
106
+ def get_rank():
107
+ if not is_dist_avail_and_initialized():
108
+ return 0
109
+ return dist.get_rank()
110
+
111
+
112
+ def get_world_size():
113
+ if not is_dist_avail_and_initialized():
114
+ return 1
115
+ return dist.get_world_size()
116
+
117
+ def get_free_space(path):
118
+ fs_stats = os.statvfs(path)
119
+ free_space = fs_stats.f_bsize * fs_stats.f_bfree
120
+ return free_space
121
+
122
+ def get_device_type():
123
+ # Returns an empty string when no CUDA device is available so that
124
+ # callers like `FLASH3.__init__` (which only check `"H100" in ...`) can
125
+ # be imported safely on CPU-only / ZeroGPU-main processes without
126
+ # raising "No CUDA GPUs are available".
127
+ try:
128
+ if not torch.cuda.is_available():
129
+ return ""
130
+ return torch.cuda.get_device_name(0)
131
+ except (RuntimeError, AssertionError):
132
+ return ""
133
+
134
+ def get_hostname():
135
+ import socket
136
+ return socket.gethostname()
137
+
138
+ def all_gather_batch(tensors):
139
+ """
140
+ Performs all_gather operation on the provided tensors.
141
+ """
142
+ # Queue the gathered tensors
143
+ world_size = get_world_size()
144
+ # There is no need for reduction in the single-proc case
145
+ if world_size == 1:
146
+ return tensors
147
+ tensor_list = []
148
+ output_tensor = []
149
+ for tensor in tensors:
150
+ tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
151
+ dist.all_gather(
152
+ tensor_all,
153
+ tensor,
154
+ async_op=False # performance opt
155
+ )
156
+
157
+ tensor_list.append(tensor_all)
158
+
159
+ for tensor_all in tensor_list:
160
+ output_tensor.append(torch.cat(tensor_all, dim=0))
161
+ return output_tensor
162
+
163
+ def get_scheduler(name):
164
+ if hasattr(lr_scheduler, name):
165
+ return getattr(lr_scheduler, name)
166
+ else:
167
+ raise NotImplementedError
168
+
169
+ def parse_scheduler(config, optimizer):
170
+ interval = config.get("interval", "epoch")
171
+ assert interval in ["epoch", "step"]
172
+ if config.name == "SequentialLR":
173
+ scheduler = {
174
+ "scheduler": lr_scheduler.SequentialLR(
175
+ optimizer,
176
+ [
177
+ parse_scheduler(conf, optimizer)["scheduler"]
178
+ for conf in config.schedulers
179
+ ],
180
+ milestones=config.milestones,
181
+ ),
182
+ "interval": interval,
183
+ }
184
+ elif config.name == "ChainedScheduler":
185
+ scheduler = {
186
+ "scheduler": lr_scheduler.ChainedScheduler(
187
+ [
188
+ parse_scheduler(conf, optimizer)["scheduler"]
189
+ for conf in config.schedulers
190
+ ]
191
+ ),
192
+ "interval": interval,
193
+ }
194
+ else:
195
+ scheduler = {
196
+ "scheduler": get_scheduler(config.name)(optimizer, **config.args),
197
+ "interval": interval,
198
+ }
199
+ return scheduler
200
+
201
+ class TimeRecorder:
202
+ _instance = None
203
+
204
+ def __init__(self):
205
+ self.items = {}
206
+ self.accumulations = defaultdict(list)
207
+ self.time_scale = 1000.0 # ms
208
+ self.time_unit = "ms"
209
+ self.enabled = False
210
+
211
+ def __new__(cls):
212
+ # singleton
213
+ if cls._instance is None:
214
+ cls._instance = super(TimeRecorder, cls).__new__(cls)
215
+ return cls._instance
216
+
217
+ def enable(self, enabled: bool) -> None:
218
+ self.enabled = enabled
219
+
220
+ def start(self, name: str) -> None:
221
+ if not self.enabled:
222
+ return
223
+ torch.cuda.synchronize()
224
+ self.items[name] = time.time()
225
+
226
+ def end(self, name: str, accumulate: bool = False) -> float:
227
+ if not self.enabled or name not in self.items:
228
+ return
229
+ torch.cuda.synchronize()
230
+ start_time = self.items.pop(name)
231
+ delta = time.time() - start_time
232
+ if accumulate:
233
+ self.accumulations[name].append(delta)
234
+ t = delta * self.time_scale
235
+ logger.info(f"{name}: {t:.2f}{self.time_unit}")
236
+
237
+ def get_accumulation(self, name: str, average: bool = False) -> float:
238
+ if not self.enabled or name not in self.accumulations:
239
+ return
240
+ acc = self.accumulations.pop(name)
241
+ total = sum(acc)
242
+ if average:
243
+ t = total / len(acc) * self.time_scale
244
+ else:
245
+ t = total * self.time_scale
246
+ logger.info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}")
247
+
248
+
249
+ ### global time recorder
250
+ time_recorder = TimeRecorder()
251
+
252
+ class FLASH3:
253
+ def __init__(self) -> None:
254
+ self.available = "H100" in get_device_type()
255
+ self.use = os.environ.get("USE_FLASH3", False)
256
+
257
+ @property
258
+ def is_use(self):
259
+ return self.available and self.use
260
+
261
+ @contextmanager
262
+ def disable_flash3(self):
263
+ use = self.use
264
+ self.set_use(False)
265
+ yield
266
+ self.set_use(use)
267
+
268
+ def set_use(self, use=True):
269
+ self.use = use
270
+
271
+ use_flash3 = FLASH3()
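`instantiate_from_config` resolves the dotted `target` path and merges `params` over call-site kwargs (note that `kwargs.update(params)` runs last, so `params` wins on conflicts). A minimal, self-contained example using a standard PyTorch class rather than a repo class:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "target": "torch.nn.Linear",
    "params": {"in_features": 8, "out_features": 4},
})
layer = instantiate_from_config(cfg, bias=False)   # bias from kwargs, dims from params
print(layer)                                       # Linear(in_features=8, out_features=4, bias=False)

The same file also exposes `use_flash3.disable_flash3()`, the context manager that the analytical-gradient path in sal_perceiver.py uses to temporarily force the non-flash attention kernels.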
src/model/parse_encoder.py ADDED
@@ -0,0 +1,28 @@
1
+ from copy import deepcopy
2
+ from dataclasses import dataclass
3
+
4
+ from .michelangelo.get_model import get_encoder as get_encoder_michelangelo
5
+ from .michelangelo.get_model import AlignedShapeLatentPerceiver
6
+ from .michelangelo.get_model import get_encoder_simplified as get_encoder_michelangelo_encoder
7
+ from .michelangelo.get_model import ShapeAsLatentPerceiverEncoder
8
+ from .skin_vae.autoencoders.autoencoder_kl_tripo2 import Tripo2Encoder
9
+
10
+ @dataclass(frozen=True)
11
+ class _MAP_MESH_ENCODER:
12
+ michelangelo = AlignedShapeLatentPerceiver
13
+ michelangelo_encoder = ShapeAsLatentPerceiverEncoder
14
+ tripo = Tripo2Encoder
15
+
16
+ MAP_MESH_ENCODER = _MAP_MESH_ENCODER()
17
+
18
+
19
+ def get_mesh_encoder(**kwargs):
20
+ MAP = {
21
+ 'michelangelo': get_encoder_michelangelo,
22
+ 'michelangelo_encoder': get_encoder_michelangelo_encoder,
23
+ 'tripo': Tripo2Encoder,
24
+ }
25
+ __target__ = kwargs['__target__']
26
+ del kwargs['__target__']
27
+ assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}"
28
+ return MAP[__target__](**deepcopy(kwargs))
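`get_mesh_encoder` is a small factory keyed on the `__target__` kwarg; everything else is forwarded to the selected constructor. A hedged sketch of the call shape (the kwargs below are placeholders; the real ones come from the YAML configs):

# hypothetical kwargs; only '__target__' is required by the dispatcher itself
encoder = get_mesh_encoder(__target__="tripo", in_channels=3, dim=512)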
src/model/skin_vae/attention_processor.py ADDED
@@ -0,0 +1,283 @@
1
+ import inspect
2
+ import math
3
+ from typing import Callable, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.attention_processor import Attention
8
+ from diffusers.utils import logging
9
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
10
+ from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
11
+ from torch import nn
12
+
13
+ try:
14
+ from flash_attn_interface import flash_attn_func
15
+ except Exception as e:
16
+ def flash_attn_func(q, k, v):
17
+ q = q.permute(0, 2, 1, 3) # (B, H, L, D)
18
+ k = k.permute(0, 2, 1, 3)
19
+ v = v.permute(0, 2, 1, 3)
20
+
21
+ if q.shape[1] != k.shape[1]:
22
+ repeat_factor = q.shape[1] // k.shape[1]
23
+ k = k.repeat_interleave(repeat_factor, dim=1)
24
+ v = v.repeat_interleave(repeat_factor, dim=1)
25
+
26
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
27
+ return out.permute(0, 2, 1, 3), None # (B, L, H, D)
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ class Tripo2AttnProcessor2_0:
33
+ r"""
34
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
35
+ used in the Tripo2DiT model. It applies a normalization layer and rotary embedding to the query and key vectors.
36
+ """
37
+
38
+ def __init__(self):
39
+ if not hasattr(F, "scaled_dot_product_attention"):
40
+ raise ImportError(
41
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
42
+ )
43
+
44
+ def __call__(
45
+ self,
46
+ attn: Attention,
47
+ hidden_states: torch.Tensor,
48
+ encoder_hidden_states: Optional[torch.Tensor] = None,
49
+ attention_mask: Optional[torch.Tensor] = None,
50
+ temb: Optional[torch.Tensor] = None,
51
+ image_rotary_emb: Optional[torch.Tensor] = None,
52
+ ) -> torch.Tensor:
53
+ from diffusers.models.embeddings import apply_rotary_emb
54
+
55
+ residual = hidden_states
56
+ if attn.spatial_norm is not None:
57
+ hidden_states = attn.spatial_norm(hidden_states, temb)
58
+
59
+ input_ndim = hidden_states.ndim
60
+
61
+ if input_ndim == 4:
62
+ batch_size, channel, height, width = hidden_states.shape
63
+ hidden_states = hidden_states.view(
64
+ batch_size, channel, height * width
65
+ ).transpose(1, 2)
66
+
67
+ batch_size, sequence_length, _ = (
68
+ hidden_states.shape
69
+ if encoder_hidden_states is None
70
+ else encoder_hidden_states.shape
71
+ )
72
+
73
+ if attention_mask is not None:
74
+ attention_mask = attn.prepare_attention_mask(
75
+ attention_mask, sequence_length, batch_size
76
+ )
77
+ # scaled_dot_product_attention expects attention_mask shape to be
78
+ # (batch, heads, source_length, target_length)
79
+ attention_mask = attention_mask.view(
80
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
81
+ )
82
+
83
+ if attn.group_norm is not None:
84
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
85
+ 1, 2
86
+ )
87
+
88
+ query = attn.to_q(hidden_states)
89
+
90
+ if encoder_hidden_states is None:
91
+ encoder_hidden_states = hidden_states
92
+ elif attn.norm_cross:
93
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
94
+ encoder_hidden_states
95
+ )
96
+
97
+ key = attn.to_k(encoder_hidden_states)
98
+ value = attn.to_v(encoder_hidden_states)
99
+
100
+ # NOTE that tripo2 split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
101
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
102
+ if not attn.is_cross_attention:
103
+ qkv = torch.cat((query, key, value), dim=-1)
104
+ split_size = qkv.shape[-1] // attn.heads // 3
105
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
106
+ query, key, value = torch.split(qkv, split_size, dim=-1)
107
+ else:
108
+ kv = torch.cat((key, value), dim=-1)
109
+ split_size = kv.shape[-1] // attn.heads // 2
110
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
111
+ key, value = torch.split(kv, split_size, dim=-1)
112
+
113
+ head_dim = key.shape[-1]
114
+
115
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
116
+
117
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
118
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
119
+
120
+ if attn.norm_q is not None:
121
+ query = attn.norm_q(query)
122
+ if attn.norm_k is not None:
123
+ key = attn.norm_k(key)
124
+
125
+ # Apply RoPE if needed
126
+ if image_rotary_emb is not None:
127
+ query = apply_rotary_emb(query, image_rotary_emb)
128
+ if not attn.is_cross_attention:
129
+ key = apply_rotary_emb(key, image_rotary_emb)
130
+
131
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
132
+ # TODO: add support for attn.scale when we move to Torch 2.1
133
+
134
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
135
+ hidden_states = flash_attn_func(query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2))
136
+ if type(hidden_states) == tuple:
137
+ hidden_states = hidden_states[0]
138
+ # hidden_states = F.scaled_dot_product_attention(
139
+ # query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
140
+ # )
141
+ #hidden_states =
142
+ hidden_states = hidden_states.reshape(
143
+ batch_size, -1, attn.heads * head_dim
144
+ )
145
+ hidden_states = hidden_states.to(query.dtype)
146
+
147
+ # linear proj
148
+ hidden_states = attn.to_out[0](hidden_states)
149
+ # dropout
150
+ hidden_states = attn.to_out[1](hidden_states)
151
+
152
+ if input_ndim == 4:
153
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
154
+ batch_size, channel, height, width
155
+ )
156
+
157
+ if attn.residual_connection:
158
+ hidden_states = hidden_states + residual
159
+
160
+ hidden_states = hidden_states / attn.rescale_output_factor
161
+
162
+ return hidden_states
163
+
164
+
165
+ class FusedTripo2AttnProcessor2_0:
166
+ r"""
167
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
168
+ projection layers. This is used in the Tripo2DiT model. It applies a normalization layer and rotary embedding to the
169
+ query and key vectors.
170
+ """
171
+
172
+ def __init__(self):
173
+ if not hasattr(F, "scaled_dot_product_attention"):
174
+ raise ImportError(
175
+ "FusedTripo2AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
176
+ )
177
+
178
+ def __call__(
179
+ self,
180
+ attn: Attention,
181
+ hidden_states: torch.Tensor,
182
+ encoder_hidden_states: Optional[torch.Tensor] = None,
183
+ attention_mask: Optional[torch.Tensor] = None,
184
+ temb: Optional[torch.Tensor] = None,
185
+ image_rotary_emb: Optional[torch.Tensor] = None,
186
+ ) -> torch.Tensor:
187
+ from diffusers.models.embeddings import apply_rotary_emb
188
+
189
+ residual = hidden_states
190
+ if attn.spatial_norm is not None:
191
+ hidden_states = attn.spatial_norm(hidden_states, temb)
192
+
193
+ input_ndim = hidden_states.ndim
194
+
195
+ if input_ndim == 4:
196
+ batch_size, channel, height, width = hidden_states.shape
197
+ hidden_states = hidden_states.view(
198
+ batch_size, channel, height * width
199
+ ).transpose(1, 2)
200
+
201
+ batch_size, sequence_length, _ = (
202
+ hidden_states.shape
203
+ if encoder_hidden_states is None
204
+ else encoder_hidden_states.shape
205
+ )
206
+
207
+ if attention_mask is not None:
208
+ attention_mask = attn.prepare_attention_mask(
209
+ attention_mask, sequence_length, batch_size
210
+ )
211
+ # scaled_dot_product_attention expects attention_mask shape to be
212
+ # (batch, heads, source_length, target_length)
213
+ attention_mask = attention_mask.view(
214
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
215
+ )
216
+
217
+ if attn.group_norm is not None:
218
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
219
+ 1, 2
220
+ )
221
+
222
+ # NOTE that tripo2 split heads first, then split qkv
223
+ if encoder_hidden_states is None:
224
+ qkv = attn.to_qkv(hidden_states)
225
+ split_size = qkv.shape[-1] // attn.heads // 3
226
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
227
+ query, key, value = torch.split(qkv, split_size, dim=-1)
228
+ else:
229
+ if attn.norm_cross:
230
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
231
+ encoder_hidden_states
232
+ )
233
+ query = attn.to_q(hidden_states)
234
+
235
+ kv = attn.to_kv(encoder_hidden_states)
236
+ split_size = kv.shape[-1] // attn.heads // 2
237
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
238
+ key, value = torch.split(kv, split_size, dim=-1)
239
+
240
+ head_dim = key.shape[-1]
241
+
242
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
243
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
244
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
245
+
246
+ if attn.norm_q is not None:
247
+ query = attn.norm_q(query)
248
+ if attn.norm_k is not None:
249
+ key = attn.norm_k(key)
250
+
251
+ # Apply RoPE if needed
252
+ if image_rotary_emb is not None:
253
+ query = apply_rotary_emb(query, image_rotary_emb)
254
+ if not attn.is_cross_attention:
255
+ key = apply_rotary_emb(key, image_rotary_emb)
256
+
257
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
258
+ # TODO: add support for attn.scale when we move to Torch 2.1
259
+ hidden_states = F.scaled_dot_product_attention(
260
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
261
+ )
262
+
263
+ hidden_states = hidden_states.transpose(1, 2).reshape(
264
+ batch_size, -1, attn.heads * head_dim
265
+ )
266
+ hidden_states = hidden_states.to(query.dtype)
267
+
268
+ # linear proj
269
+ hidden_states = attn.to_out[0](hidden_states)
270
+ # dropout
271
+ hidden_states = attn.to_out[1](hidden_states)
272
+
273
+ if input_ndim == 4:
274
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
275
+ batch_size, channel, height, width
276
+ )
277
+
278
+ if attn.residual_connection:
279
+ hidden_states = hidden_states + residual
280
+
281
+ hidden_states = hidden_states / attn.rescale_output_factor
282
+
283
+ return hidden_states
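The `flash_attn_func` fallback defined at the top of this file takes and returns tensors in (B, L, H, D) layout, matching the flash-attention-3 interface it substitutes for. A quick shape check of the fallback path (assuming `flash_attn_interface` is not installed, so the pure-PyTorch branch is active):

import torch

q = torch.randn(2, 64, 8, 32)    # (batch, seq_len, heads, head_dim)
k = torch.randn(2, 77, 8, 32)
v = torch.randn(2, 77, 8, 32)

out, _ = flash_attn_func(q, k, v)
print(out.shape)                 # torch.Size([2, 64, 8, 32])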
src/model/skin_vae/autoencoders/FSQ.py ADDED
@@ -0,0 +1,191 @@
1
+ from __future__ import annotations
2
+ from functools import wraps, partial
3
+ from contextlib import nullcontext
4
+ from typing import List, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn import Module
9
+ from torch import Tensor, int32
10
+ from torch.amp import autocast
11
+
12
+ from einops import rearrange, pack, unpack
13
+
14
+ # helper functions
15
+
16
+ def exists(v):
17
+ return v is not None
18
+
19
+ def default(*args):
20
+ for arg in args:
21
+ if exists(arg):
22
+ return arg
23
+ return None
24
+
25
+ def maybe(fn):
26
+ @wraps(fn)
27
+ def inner(x, *args, **kwargs):
28
+ if not exists(x):
29
+ return x
30
+ return fn(x, *args, **kwargs)
31
+ return inner
32
+
33
+ def pack_one(t, pattern):
34
+ return pack([t], pattern)
35
+
36
+ def unpack_one(t, ps, pattern):
37
+ return unpack(t, ps, pattern)[0]
38
+
39
+ # tensor helpers
40
+
41
+ def round_ste(z: Tensor) -> Tensor:
42
+ """Round with straight through gradients."""
43
+ zhat = z.round()
44
+ return z + (zhat - z).detach()
45
+
46
+ # main class
47
+
48
+ class FSQ(Module):
49
+ def __init__(
50
+ self,
51
+ levels: List[int],
52
+ dim: int | None = None,
53
+ num_codebooks = 1,
54
+ keep_num_codebooks_dim: bool | None = None,
55
+ scale: float | None = None,
56
+ allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
57
+ channel_first: bool = False,
58
+ projection_has_bias: bool = True,
59
+ return_indices = True,
60
+ force_quantization_f32 = True
61
+ ):
62
+ super().__init__()
63
+ _levels = torch.tensor(levels, dtype=int32)
64
+ self.register_buffer("_levels", _levels, persistent = False)
65
+
66
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
67
+ self.register_buffer("_basis", _basis, persistent = False)
68
+
69
+ self.scale = scale
70
+
71
+ codebook_dim = len(levels)
72
+ self.codebook_dim = codebook_dim
73
+
74
+ effective_codebook_dim = codebook_dim * num_codebooks
75
+ self.num_codebooks = num_codebooks
76
+ self.effective_codebook_dim = effective_codebook_dim
77
+
78
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
79
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
80
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
81
+
82
+ self.dim = default(dim, len(_levels) * num_codebooks)
83
+
84
+ self.channel_first = channel_first
85
+
86
+ has_projections = self.dim != effective_codebook_dim
87
+ self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = projection_has_bias) if has_projections else nn.Identity()
88
+ self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = projection_has_bias) if has_projections else nn.Identity()
89
+
90
+ self.has_projections = has_projections
91
+
92
+ self.return_indices = return_indices
93
+ if return_indices:
94
+ self.codebook_size: int = self._levels.prod().item()
95
+ implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
96
+ self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
97
+
98
+ self.allowed_dtypes = allowed_dtypes
99
+ self.force_quantization_f32 = force_quantization_f32
100
+
101
+ def bound(self, z, eps: float = 1e-3):
102
+ """ Bound `z`, an array of shape (..., d). """
103
+ half_l = (self._levels - 1) * (1 + eps) / 2
104
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
105
+ shift = (offset / half_l).atanh()
106
+ return (z + shift).tanh() * half_l - offset
107
+
108
+ def quantize(self, z):
109
+ """ Quantizes z, returns quantized zhat, same shape as z. """
110
+ quantized = round_ste(self.bound(z))
111
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
112
+ return quantized / half_width
113
+
114
+ def _scale_and_shift(self, zhat_normalized):
115
+ half_width = self._levels // 2
116
+ return (zhat_normalized * half_width) + half_width
117
+
118
+ def _scale_and_shift_inverse(self, zhat):
119
+ half_width = self._levels // 2
120
+ return (zhat - half_width) / half_width
121
+
122
+ def _indices_to_codes(self, indices):
123
+ level_indices = self.indices_to_level_indices(indices)
124
+ codes = self._scale_and_shift_inverse(level_indices)
125
+ return codes
126
+
127
+ def codes_to_indices(self, zhat):
128
+ """ Converts a `code` to an index in the codebook. """
129
+ assert zhat.shape[-1] == self.codebook_dim
130
+ zhat = self._scale_and_shift(zhat)
131
+ return (zhat * self._basis).sum(dim=-1).to(int32)
132
+
133
+ def indices_to_level_indices(self, indices):
134
+ """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
135
+ indices = rearrange(indices, '... -> ... 1')
136
+ codes_non_centered = (indices // self._basis) % self._levels
137
+ return codes_non_centered
138
+
139
+ def indices_to_codes(self, indices):
140
+ """ Inverse of `codes_to_indices`. """
141
+ assert exists(indices)
142
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
143
+ codes = self._indices_to_codes(indices)
144
+ if self.keep_num_codebooks_dim:
145
+ codes = rearrange(codes, '... c d -> ... (c d)')
146
+ codes = self.project_out(codes)
147
+ if is_img_or_video or self.channel_first:
148
+ codes = rearrange(codes, 'b ... d -> b d ...')
149
+ return codes
150
+
151
+ def dequantize(self, indices):
152
+ codes = self._indices_to_codes(indices)
153
+ out = self.project_out(codes)
154
+ return out
155
+
156
+ def forward(self, z):
157
+ """
158
+ einstein notation
159
+ b - batch
160
+ n - sequence (or flattened spatial dimensions)
161
+ d - feature dimension
162
+ c - number of codebook dim
163
+ """
164
+ assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
165
+
166
+ z = self.project_in(z)
167
+ z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
168
+
169
+ # whether to force quantization step to be full precision or not
170
+ force_f32 = self.force_quantization_f32
171
+ quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
172
+
173
+ with quantization_context():
174
+ orig_dtype = z.dtype
175
+ if force_f32 and orig_dtype not in self.allowed_dtypes:
176
+ z = z.float()
177
+ codes = self.quantize(z)
178
+ # returning indices could be optional
179
+ indices = None
180
+ if self.return_indices:
181
+ indices = self.codes_to_indices(codes)
182
+ codes = rearrange(codes, 'b n c d -> b n (c d)')
183
+ codes = codes.type(orig_dtype)
184
+
185
+ # project out
186
+ out = self.project_out(codes)
187
+
188
+ if not self.keep_num_codebooks_dim and self.return_indices:
189
+ indices = maybe(rearrange)(indices, '... 1 -> ...')
190
+ # return quantized output and indices
191
+ return out, indices, None
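A short round trip through the FSQ module above with the common [8, 5, 5, 5] level spec (8 * 5 * 5 * 5 = 1000 implicit codes; with `dim` left at its default there are no input/output projections):

import torch

fsq = FSQ(levels=[8, 5, 5, 5])               # codebook_dim = 4, dim defaults to 4
x = torch.randn(2, 16, 4)                    # (batch, tokens, dim)

quantized, indices, _ = fsq(x)
print(quantized.shape, indices.shape)        # torch.Size([2, 16, 4]) torch.Size([2, 16])

# indices -> codes is exact because quantization snaps x onto the implicit grid
recovered = fsq.indices_to_codes(indices)
print(torch.allclose(recovered, quantized))  # True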
src/model/skin_vae/autoencoders/SimVQ.py ADDED
@@ -0,0 +1,197 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from einops import rearrange
6
+
7
+ class SimVQ(nn.Module):
8
+ """
9
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
10
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
11
+ """
12
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
13
+ # backwards compatibility we use the buggy version by default, but you can
14
+ # specify legacy=False to fix it.
15
+ def __init__(self, n_e, e_dim, beta=0.25, remap=None, unknown_index="random",
16
+ same_index_shape=False, legacy=True):
17
+ super().__init__()
18
+ self.n_e = n_e
19
+ self.e_dim = e_dim
20
+ self.beta = beta
21
+ self.legacy = legacy
22
+
23
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
24
+ self.codebook_size = self.n_e
25
+ nn.init.normal_(self.embedding.weight, mean=0, std=self.e_dim**-0.5)
26
+ for p in self.embedding.parameters():
27
+ p.requires_grad = False
28
+
29
+ self.embedding_proj = nn.Linear(self.e_dim, self.e_dim)
30
+
31
+ self.remap = remap
32
+ if self.remap is not None:
33
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
34
+ self.re_embed = self.used.shape[0]
35
+ self.unknown_index = unknown_index # "random" or "extra" or integer
36
+ if self.unknown_index == "extra":
37
+ self.unknown_index = self.re_embed
38
+ self.re_embed = self.re_embed+1
39
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
40
+ f"Using {self.unknown_index} for unknown indices.")
41
+ else:
42
+ self.re_embed = n_e
43
+
44
+ self.same_index_shape = same_index_shape
45
+
46
+ def remap_to_used(self, inds):
47
+ ishape = inds.shape
48
+ assert len(ishape)>1
49
+ inds = inds.reshape(ishape[0],-1)
50
+ used = self.used.to(inds)
51
+ match = (inds[:,:,None]==used[None,None,...]).long()
52
+ new = match.argmax(-1)
53
+ unknown = match.sum(2)<1
54
+ if self.unknown_index == "random":
55
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
56
+ else:
57
+ new[unknown] = self.unknown_index
58
+ return new.reshape(ishape)
59
+
60
+ def unmap_to_all(self, inds):
61
+ ishape = inds.shape
62
+ assert len(ishape)>1
63
+ inds = inds.reshape(ishape[0],-1)
64
+ used = self.used.to(inds)
65
+ if self.re_embed > self.used.shape[0]: # extra token
66
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
67
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
68
+ return back.reshape(ishape)
69
+
70
+ def forward(self, z):
71
+ # reshape z -> (batch, height, width, channel) and flatten
72
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
73
+ assert z.shape[-1] == self.e_dim
74
+ z_flattened = z.view(-1, self.e_dim)
75
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
76
+
77
+ quant_codebook = self.embedding_proj(self.embedding.weight)
78
+
79
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
80
+ torch.sum(quant_codebook**2, dim=1) - 2 * \
81
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(quant_codebook, 'n d -> d n'))
82
+
83
+ min_encoding_indices = torch.argmin(d, dim=1)
84
+ z_q = F.embedding(min_encoding_indices, quant_codebook).view(z.shape)
85
+
86
+ # compute loss for embedding
87
+ if not self.legacy:
88
+ quantization_loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
89
+ torch.mean((z_q - z.detach()) ** 2)
90
+ else:
91
+ quantization_loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
92
+ torch.mean((z_q - z.detach()) ** 2)
93
+
94
+ # preserve gradients
95
+ z_q = z + (z_q - z).detach()
96
+
97
+ # reshape back to match original input shape
98
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
99
+
100
+ if self.remap is not None:
101
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
102
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
103
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
104
+
105
+ if self.same_index_shape:
106
+ min_encoding_indices = min_encoding_indices.reshape(
107
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
108
+
109
+ return z_q, min_encoding_indices, quantization_loss
110
+
111
+ def get_codebook_entry(self, indices, shape):
112
+ # shape specifying (batch, height, width, channel)
113
+ if self.remap is not None:
114
+ indices = indices.reshape(shape[0],-1) # add batch axis
115
+ indices = self.unmap_to_all(indices)
116
+ indices = indices.reshape(-1) # flatten again
117
+
118
+ # get quantized latent vectors
119
+ z_q = self.embedding(indices)
120
+
121
+ if shape is not None:
122
+ z_q = z_q.view(shape)
123
+ # reshape back to match original input shape
124
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
125
+
126
+ return z_q
127
+
128
+ def indices_to_codes(self, indices):
129
+ return self.get_codebook_entry(indices, None)
130
+
131
+ def entropy(prob):
132
+ return (-prob * torch.log(prob)).sum(dim=-1)
133
+
134
+ class SimVQ1D(SimVQ):
135
+
136
+ def __init__(self, n_e, e_dim, dim, beta=0.25, remap=None, unknown_index="random", same_index_shape=True, legacy=True):
137
+ super().__init__(n_e, e_dim, beta, remap, unknown_index, same_index_shape, legacy)
138
+
139
+ self.project_in = nn.Linear(dim, e_dim)
140
+ self.project_out = nn.Linear(e_dim, dim)
141
+
142
+ def forward(self, z):
143
+ # reshape z -> (batch, height, width, channel) and flatten
144
+ #assert z.shape[-1] == self.e_dim
145
+ z = self.project_in(z)
146
+
147
+ z_flattened = z.view(-1, self.e_dim)
148
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
149
+
150
+ quant_codebook = self.embedding_proj(self.embedding.weight)
151
+
152
+ # # Use IBQ
153
+ # logits = torch.matmul(z_flattened, quant_codebook.t())
154
+ # Ind_soft = torch.softmax(logits, dim=1)
155
+ # indices = torch.argmax(Ind_soft, dim=1)
156
+ # Ind_hard = F.one_hot(indices, num_classes=Ind_soft.shape[1])
157
+ # Ind = Ind_hard - Ind_soft.detach() + Ind_soft
158
+ # z_q = torch.matmul(Ind, quant_codebook).view(z.shape)
159
+
160
+ # if not self.legacy:
161
+ # quantization_loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
162
+ # torch.mean((z_q - z.detach()) ** 2)
163
+ # else:
164
+ # quantization_loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
165
+ # torch.mean((z_q - z.detach()) ** 2)
166
+
167
+ # return z_q, indices, quantization_loss
168
+
169
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
170
+ torch.sum(quant_codebook**2, dim=1) - 2 * \
171
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(quant_codebook, 'n d -> d n'))
172
+
173
+ min_encoding_indices = torch.argmin(d, dim=1)
174
+ z_q = F.embedding(min_encoding_indices, quant_codebook).view(z.shape)
175
+
176
+ # compute loss for embedding
177
+ if not self.legacy:
178
+ quantization_loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
179
+ torch.mean((z_q - z.detach()) ** 2)
180
+ else:
181
+ quantization_loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
182
+ torch.mean((z_q - z.detach()) ** 2)
183
+
184
+ # preserve gradients
185
+ z_q = z + (z_q - z).detach()
186
+
187
+ if self.remap is not None:
188
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
189
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
190
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
191
+
192
+ if self.same_index_shape:
193
+ min_encoding_indices = min_encoding_indices.view(z.shape[0], z.shape[1])
194
+ z_q = self.project_out(z_q.view(z.shape))
195
+
196
+ return z_q, min_encoding_indices, quantization_loss
197
+
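A quick sketch of `SimVQ1D` on a token sequence; note that the codebook embedding itself is frozen and only `embedding_proj` (plus the in/out projections) learns, which is the SimVQ idea:

import torch

vq = SimVQ1D(n_e=1024, e_dim=64, dim=256)   # 1024 codes of width 64, tokens of width 256
tokens = torch.randn(2, 32, 256)            # (batch, sequence, dim)

z_q, indices, vq_loss = vq(tokens)
print(z_q.shape, indices.shape)             # torch.Size([2, 32, 256]) torch.Size([2, 32])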
src/model/skin_vae/autoencoders/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .skin_fsq_cvae_model import SkinFSQCVAEModel
src/model/skin_vae/autoencoders/autoencoder_kl_tripo2.py ADDED
@@ -0,0 +1,254 @@
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.models.normalization import LayerNorm
7
+ from diffusers.utils import logging
8
+ from einops import repeat
9
+ import math
10
+
11
+ from ..embeddings import FrequencyPositionalEmbedding
12
+ from ..transformers.tripo2_transformer import DiTBlock
13
+ from ...utils import fps
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+ def init_linear(l, stddev):
18
+ nn.init.normal_(l.weight, std=stddev)
19
+ if l.bias is not None:
20
+ nn.init.constant_(l.bias, 0.0)
21
+
22
+ class Tripo2Encoder(nn.Module):
23
+ def __init__(
24
+ self,
25
+ in_channels: int = 3,
26
+ dim: int = 512,
27
+ num_attention_heads: int = 8,
28
+ num_layers: int = 8,
29
+ is_learned_queries: bool = False,
30
+ sample_tokens: int = 32,
31
+ embed_frequency: int = 8,
32
+ embed_include_pi: bool = False,
33
+ fps: bool = False,
34
+ is_miche: bool = False,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.fps = fps
39
+ if fps and not is_learned_queries:
40
+ self.embedder = FrequencyPositionalEmbedding(
41
+ num_freqs=embed_frequency,
42
+ logspace=True,
43
+ input_dim=3,
44
+ include_pi=embed_include_pi,
45
+ )
46
+ self.proj_k = nn.Linear(3+self.embedder.out_dim, dim, bias=True)
47
+ self.proj_in = nn.Linear(in_channels-3+self.embedder.out_dim, dim, bias=True)
48
+ else:
49
+ self.proj_in = nn.Linear(in_channels, dim, bias=True)
50
+ self.output_channels = dim
51
+ self.is_miche = is_miche
52
+ init_scale = 0.25 * math.sqrt(1.0 / dim)
53
+ init_linear(self.proj_in, init_scale)
54
+
55
+ self.blocks = nn.ModuleList(
56
+ [
57
+ DiTBlock(
58
+ dim=dim,
59
+ num_attention_heads=num_attention_heads,
60
+ use_self_attention=False,
61
+ use_cross_attention=True,
62
+ cross_attention_dim=dim,
63
+ cross_attention_norm_type="layer_norm",
64
+ activation_fn="gelu",
65
+ norm_type="fp32_layer_norm",
66
+ norm_eps=1e-5,
67
+ qk_norm=False,
68
+ qkv_bias=False,
69
+ ) # cross attention
70
+ ]
71
+ + [
72
+ DiTBlock(
73
+ dim=dim,
74
+ num_attention_heads=num_attention_heads,
75
+ use_self_attention=True,
76
+ self_attention_norm_type="fp32_layer_norm",
77
+ use_cross_attention=False,
78
+ use_cross_attention_2=False,
79
+ activation_fn="gelu",
80
+ norm_type="fp32_layer_norm",
81
+ norm_eps=1e-5,
82
+ qk_norm=False,
83
+ qkv_bias=False,
84
+ )
85
+ for _ in range(num_layers) # self attention
86
+ ]
87
+ )
88
+ self.norm_out = LayerNorm(dim)
89
+ self.is_learned_queries = is_learned_queries
90
+ if is_learned_queries:
91
+ self.learned_queries = nn.Parameter(torch.randn(sample_tokens, dim) * 0.02)
92
+
93
+ def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor, num_tokens: int=1024):
94
+ if self.is_learned_queries or not self.fps:
95
+ hidden_states = self.proj_in(sample_1) if not self.is_learned_queries else repeat(self.learned_queries[:sample_1.shape[1], :], 'n d -> b n d', b=sample_1.shape[0])
96
+ encoder_hidden_states = self.proj_in(sample_2)
97
+ else:
98
+ x_q, x_kv = self.get_qkv(x=sample_1, num_tokens=num_tokens)
99
+ hidden_states = self.proj_k(x_q)
100
+ encoder_hidden_states = self.proj_in(x_kv)
101
+
102
+ if not self.is_miche:
103
+ for layer, block in enumerate(self.blocks):
104
+ if layer == 0:
105
+ hidden_states = block(
106
+ hidden_states, encoder_hidden_states=encoder_hidden_states
107
+ )
108
+ else:
109
+ hidden_states = block(hidden_states)
110
+ else:
111
+ for layer, block in enumerate(self.blocks):
112
+ if layer == 0:
113
+ hidden_states = block(hidden_states, encoder_hidden_states)
114
+ else:
115
+ hidden_states = block(hidden_states)
116
+
117
+ hidden_states = self.norm_out(hidden_states)
118
+
119
+ return hidden_states
120
+
121
+ def _sample_features(
122
+ self, x: torch.Tensor, num_tokens: int = 1024, seed: Optional[int] = None
123
+ ):
124
+ """
125
+ Sample points from features of the input point cloud.
126
+
127
+ Args:
128
+ x (torch.Tensor): The input point cloud. shape: (B, N, C)
129
+ num_tokens (int, optional): The number of points to sample. Defaults to 1024.
130
+ seed (Optional[int], optional): The random seed. Defaults to None.
131
+ """
132
+ rng = np.random.default_rng(seed)
133
+ indices = rng.choice(
134
+ x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
135
+ )
136
+ selected_points = x[:, indices]
137
+
138
+ batch_size, num_points, num_channels = selected_points.shape
139
+ flattened_points = selected_points.view(batch_size * num_points, num_channels)
140
+ batch_indices = (
141
+ torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
142
+ )
143
+
144
+ # fps sampling
145
+ sampling_ratio = 1.0 / 4
146
+ sampled_indices = fps(
147
+ flattened_points[:, :3],
148
+ batch_indices,
149
+ ratio=sampling_ratio,
150
+ random_start=self.training,
151
+ )
152
+ sampled_points = flattened_points[sampled_indices].view(
153
+ batch_size, -1, num_channels
154
+ )
155
+
156
+ return sampled_points
157
+
158
+ def get_qkv(self, x: torch.Tensor, num_tokens: int = 1024, seed: Optional[int] = None):
159
+ positions, features = x[..., :3], x[..., 3:]
160
+ x_kv = torch.cat([self.embedder(positions), features], dim=-1)
161
+
162
+ sampled_x = self._sample_features(x, num_tokens, seed)
163
+ positions, features = (
164
+ sampled_x[..., :3],
165
+ sampled_x[..., 3:],
166
+ )
167
+ x_q = torch.cat([self.embedder(positions), features], dim=-1)
168
+ return x_q, x_kv
169
+
170
+
171
+ class Tripo2Decoder(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int = 3,
175
+ out_channels: int = 1,
176
+ dim: int = 512,
177
+ num_attention_heads: int = 8,
178
+ num_layers: int = 16,
179
+ grad_type: str = "analytical",
180
+ grad_interval: float = 0.001,
181
+ is_miche: bool = False,
182
+ ):
183
+ super().__init__()
184
+
185
+ if grad_type not in ["numerical", "analytical"]:
186
+ raise ValueError(f"grad_type must be one of ['numerical', 'analytical'], got {grad_type!r}")
187
+ self.grad_type = grad_type
188
+ self.grad_interval = grad_interval
189
+ self.is_miche = is_miche
190
+
191
+ self.blocks = nn.ModuleList(
192
+ [
193
+ DiTBlock(
194
+ dim=dim,
195
+ num_attention_heads=num_attention_heads,
196
+ use_self_attention=True,
197
+ self_attention_norm_type="fp32_layer_norm",
198
+ use_cross_attention=False,
199
+ use_cross_attention_2=False,
200
+ activation_fn="gelu",
201
+ norm_type="fp32_layer_norm",
202
+ norm_eps=1e-5,
203
+ qk_norm=False,
204
+ qkv_bias=False,
205
+ )
206
+ for _ in range(num_layers) # self attention
207
+ ]
208
+ + [
209
+ DiTBlock(
210
+ dim=dim,
211
+ num_attention_heads=num_attention_heads,
212
+ use_self_attention=False,
213
+ use_cross_attention=True,
214
+ cross_attention_dim=dim,
215
+ cross_attention_norm_type="layer_norm",
216
+ activation_fn="gelu",
217
+ norm_type="fp32_layer_norm",
218
+ norm_eps=1e-5,
219
+ qk_norm=False,
220
+ qkv_bias=False,
221
+ ) # cross attention
222
+ ]
223
+ )
224
+ self.proj_query = nn.Linear(in_channels, dim, bias=True)
225
+
226
+ self.norm_out = LayerNorm(dim)
227
+ self.proj_out = nn.Linear(dim, out_channels, bias=True)
228
+ self.sigmoid = nn.Sigmoid()
229
+ init_scale = 0.25 * math.sqrt(1.0 / dim)
230
+ init_linear(self.proj_query, init_scale)
231
+ init_linear(self.proj_out, init_scale)
232
+
233
+ def forward(
234
+ self,
235
+ sample: torch.Tensor,
236
+ queries: torch.Tensor,
237
+ kv_cache: Optional[torch.Tensor] = None,
238
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
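+ # the self-attention stack below runs only when kv_cache is None; its output is cached so later query batches only pay for the final cross-attention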
239
+ if kv_cache is None:
240
+ hidden_states = sample
241
+ for _, block in enumerate(self.blocks[:-1]):
242
+ hidden_states = block(hidden_states)
243
+ kv_cache = hidden_states
244
+ # query grid logits by cross attention
245
+ q = self.proj_query(queries)
246
+ if self.is_miche:
247
+ l = self.blocks[-1](q, kv_cache)
248
+ else:
249
+ l = self.blocks[-1](q, encoder_hidden_states=kv_cache)
250
+ logits = self.proj_out(self.norm_out(l))
251
+
252
+ logits = self.sigmoid(logits)
253
+ assert kv_cache is not None
254
+ return logits, kv_cache
src/model/skin_vae/autoencoders/get_model.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
+
+ from .skin_cvae_model import SkinCVAEModel
2
+ from .skin_fsq_cvae_model import SkinFSQCVAEModel
3
+
4
+ def get_model_cvae(
5
+ pretrained_path: str=None,
6
+ **kwargs
7
+ ) -> SkinCVAEModel:
8
+ model = SkinCVAEModel(**kwargs)
9
+ if pretrained_path is not None:
10
+ state_dict = torch.load(pretrained_path, weights_only=True)
11
+ model.load_state_dict(state_dict)
12
+ return model
13
+
14
+ def get_model_fsq_cvae(
15
+ pretrained_path: str=None,
16
+ **kwargs
17
+ ) -> SkinFSQCVAEModel:
18
+ model = SkinFSQCVAEModel(**kwargs)
19
+ if pretrained_path is not None:
20
+ state_dict = torch.load(pretrained_path, weights_only=True)
21
+ model.load_state_dict(state_dict)
22
+ return model
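+
+ # Minimal usage sketch (illustrative, not part of the original file). The checkpoint path and
+ # the constructor kwargs below are assumptions; the real values depend on the released
+ # SkinFSQCVAEModel checkpoints and configs.
+ #
+ # model = get_model_fsq_cvae(
+ #     pretrained_path="checkpoints/skin_fsq_cvae.pt",   # hypothetical path
+ #     is_learned_queries=False,
+ #     FSQ_dict={"levels": [8, 8, 8, 8]},                # hypothetical quantizer config
+ # )
+ # model.eval()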
src/model/skin_vae/autoencoders/miche_transformer_blocks.py ADDED
@@ -0,0 +1,395 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from typing import Optional
8
+ import os
9
+
10
+
11
+ """
12
+ Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
13
+ """
14
+
15
+
16
+ from typing import Callable, Iterable, Sequence, Union
17
+
18
+
19
+ def checkpoint(
20
+ func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
21
+ inputs: Sequence[torch.Tensor],
22
+ params: Iterable[torch.Tensor],
23
+ flag: bool,
24
+ use_deepspeed: bool = False
25
+ ):
26
+ """
27
+ Evaluate a function without caching intermediate activations, allowing for
28
+ reduced memory at the expense of extra compute in the backward pass.
29
+ :param func: the function to evaluate.
30
+ :param inputs: the argument sequence to pass to `func`.
31
+ :param params: a sequence of parameters `func` depends on but does not
32
+ explicitly take as arguments.
33
+ :param flag: if False, disable gradient checkpointing.
34
+ :param use_deepspeed: if True, use deepspeed
35
+ """
36
+ if flag:
37
+ if use_deepspeed:
38
+ import deepspeed
39
+ return deepspeed.checkpointing.checkpoint(func, *inputs)
40
+
41
+ args = tuple(inputs) + tuple(params)
42
+ return CheckpointFunction.apply(func, len(inputs), *args)
43
+ else:
44
+ return func(*inputs)
45
+
46
+
47
+ class CheckpointFunction(torch.autograd.Function):
48
+ @staticmethod
49
+ @torch.amp.custom_fwd(device_type='cuda')
50
+ def forward(ctx, run_function, length, *args):
51
+ ctx.run_function = run_function
52
+ ctx.input_tensors = list(args[:length])
53
+ ctx.input_params = list(args[length:])
54
+
55
+ with torch.no_grad():
56
+ output_tensors = ctx.run_function(*ctx.input_tensors)
57
+ return output_tensors
58
+
59
+ @staticmethod
60
+ @torch.amp.custom_bwd(device_type='cuda')
61
+ def backward(ctx, *output_grads):
62
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
63
+ with torch.enable_grad():
64
+ # Fixes a bug where the first op in run_function modifies the
65
+ # Tensor storage in place, which is not allowed for detach()'d
66
+ # Tensors.
67
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
68
+ output_tensors = ctx.run_function(*shallow_copies)
69
+ input_grads = torch.autograd.grad(
70
+ output_tensors,
71
+ ctx.input_tensors + ctx.input_params,
72
+ output_grads,
73
+ allow_unused=True,
74
+ )
75
+ del ctx.input_tensors
76
+ del ctx.input_params
77
+ del output_tensors
78
+ return (None, None) + input_grads
79
+
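+
+ # Illustrative sketch (not part of the original file): trade activation memory for recompute.
+ # Parameters used inside `func` must be passed through `params` so they still receive gradients.
+ #
+ # lin = nn.Linear(8, 8)
+ # x = torch.randn(2, 8, requires_grad=True)
+ # y = checkpoint(lambda t: lin(t).relu(), (x,), lin.parameters(), flag=True)
+ # y.sum().backward()  # the lambda is re-run during backward; x.grad and lin.weight.grad are populated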
80
+ try:
81
+ from flash_attn_interface import flash_attn_func
82
+ print("use flash attention 3.")
83
+ _use_flash3 = True
84
+ except Exception:
85
+ print("use flash attention 2.")
86
+ _use_flash3 = False
87
+
88
+ def init_linear(l, stddev):
89
+ nn.init.normal_(l.weight, std=stddev)
90
+ if l.bias is not None:
91
+ nn.init.constant_(l.bias, 0.0)
92
+
93
+ def flash_attention(q, k, v):
94
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
95
+ if _use_flash3:
96
+ out, _ = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous())
97
+ # out = flash_attn_func(q, k, v)
98
+
99
+ # q_ = q.transpose(1, 2)
100
+ # k_ = k.transpose(1, 2)
101
+ # v_ = v.transpose(1, 2)
102
+
103
+ # # print(q.shape, k.shape, v.shape)
104
+ # out_ = F.scaled_dot_product_attention(q_, k_, v_)
105
+ # out_ = out_.transpose(1, 2)
106
+
107
+ # # print(torch.abs(out - out_).mean())
108
+ # assert torch.abs(out - out_).mean() < 1e-2, f"the error {torch.abs(out - out_).mean()} is too large"
109
+
110
+ # out = out_
111
+
112
+ # print("use flash_atten 3")
113
+ else:
114
+ q = q.transpose(1, 2)
115
+ k = k.transpose(1, 2)
116
+ v = v.transpose(1, 2)
117
+ out = F.scaled_dot_product_attention(q, k, v)
118
+ out = out.transpose(1, 2)
119
+ # print("use flash atten 2")
120
+
121
+ return out
122
+
123
+ class MultiheadAttention(nn.Module):
124
+ def __init__(
125
+ self,
126
+ *,
127
+ device: torch.device,
128
+ dtype: torch.dtype,
129
+ n_ctx: int,
130
+ width: int,
131
+ heads: int,
132
+ init_scale: float,
133
+ qkv_bias: bool,
134
+ flash: bool = False
135
+ ):
136
+ super().__init__()
137
+ self.n_ctx = n_ctx
138
+ self.width = width
139
+ self.heads = heads
140
+ self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
141
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
142
+ self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
143
+ init_linear(self.c_qkv, init_scale)
144
+ init_linear(self.c_proj, init_scale)
145
+
146
+ def forward(self, x):
147
+ x = self.c_qkv(x)
148
+ x = checkpoint(self.attention, (x,), (), False)
149
+ x = self.c_proj(x)
150
+ return x
151
+
152
+
153
+ class QKVMultiheadAttention(nn.Module):
154
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
155
+ super().__init__()
156
+ self.device = device
157
+ self.dtype = dtype
158
+ self.heads = heads
159
+ self.n_ctx = n_ctx
160
+ self.flash = flash
161
+
162
+ def forward(self, qkv):
163
+ bs, n_ctx, width = qkv.shape
164
+ attn_ch = width // self.heads // 3
165
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
166
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
167
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
168
+
169
+ if self.flash:
170
+ out = flash_attention(q, k, v)
171
+ out = out.reshape(out.shape[0], out.shape[1], -1)
172
+ else:
173
+ weight = torch.einsum(
174
+ "bthc,bshc->bhts", q * scale, k * scale
175
+ ) # More stable with f16 than dividing afterwards
176
+ wdtype = weight.dtype
177
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
178
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
179
+
180
+ return out
181
+
182
+
183
+ class ResidualAttentionBlock(nn.Module):
184
+ def __init__(
185
+ self,
186
+ *,
187
+ device: torch.device,
188
+ dtype: torch.dtype,
189
+ n_ctx: int,
190
+ width: int,
191
+ heads: int,
192
+ init_scale: float = 1.0,
193
+ qkv_bias: bool = True,
194
+ flash: bool = False,
195
+ use_checkpoint: bool = False
196
+ ):
197
+ super().__init__()
198
+
199
+ self.use_checkpoint = use_checkpoint
200
+
201
+ self.attn = MultiheadAttention(
202
+ device=device,
203
+ dtype=dtype,
204
+ n_ctx=n_ctx,
205
+ width=width,
206
+ heads=heads,
207
+ init_scale=init_scale,
208
+ qkv_bias=qkv_bias,
209
+ flash=flash
210
+ )
211
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
212
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
213
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
214
+
215
+ def _forward(self, x: torch.Tensor):
216
+ x = x + self.attn(self.ln_1(x))
217
+ x = x + self.mlp(self.ln_2(x))
218
+ return x
219
+
220
+ def forward(self, x: torch.Tensor):
221
+ return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
222
+
223
+
224
+ class MultiheadCrossAttention(nn.Module):
225
+ def __init__(
226
+ self,
227
+ *,
228
+ device: torch.device,
229
+ dtype: torch.dtype,
230
+ width: int,
231
+ heads: int,
232
+ init_scale: float,
233
+ qkv_bias: bool = True,
234
+ flash: bool = False,
235
+ n_data: Optional[int] = None,
236
+ data_width: Optional[int] = None,
237
+ ):
238
+ super().__init__()
239
+ self.n_data = n_data
240
+ self.width = width
241
+ self.heads = heads
242
+ self.data_width = width if data_width is None else data_width
243
+ self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
244
+ self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
245
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
246
+ self.attention = QKVMultiheadCrossAttention(
247
+ device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
248
+ )
249
+ init_linear(self.c_q, init_scale)
250
+ init_linear(self.c_kv, init_scale)
251
+ init_linear(self.c_proj, init_scale)
252
+
253
+ def forward(self, x, data):
254
+ x = self.c_q(x)
255
+ data = self.c_kv(data)
256
+ x = checkpoint(self.attention, (x, data), (), False)
257
+ x = self.c_proj(x)
258
+ return x
259
+
260
+
261
+ class QKVMultiheadCrossAttention(nn.Module):
262
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
263
+ flash: bool = False, n_data: Optional[int] = None):
264
+
265
+ super().__init__()
266
+ self.device = device
267
+ self.dtype = dtype
268
+ self.heads = heads
269
+ self.n_data = n_data
270
+ self.flash = flash
271
+
272
+ def forward(self, q, kv):
273
+ _, n_ctx, _ = q.shape
274
+ bs, n_data, width = kv.shape
275
+ attn_ch = width // self.heads // 2
276
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
277
+ q = q.view(bs, n_ctx, self.heads, -1)
278
+ kv = kv.view(bs, n_data, self.heads, -1)
279
+ k, v = torch.split(kv, attn_ch, dim=-1)
280
+
281
+ if self.flash:
282
+ out = flash_attention(q, k, v)
283
+ out = out.reshape(out.shape[0], out.shape[1], -1)
284
+ else:
285
+ weight = torch.einsum(
286
+ "bthc,bshc->bhts", q * scale, k * scale
287
+ ) # More stable with f16 than dividing afterwards
288
+ wdtype = weight.dtype
289
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
290
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
291
+
292
+ return out
293
+
294
+
295
+ class ResidualCrossAttentionBlock(nn.Module):
296
+ def __init__(
297
+ self,
298
+ *,
299
+ device: Optional[torch.device],
300
+ dtype: Optional[torch.dtype],
301
+ n_data: Optional[int] = None,
302
+ width: int,
303
+ heads: int,
304
+ data_width: Optional[int] = None,
305
+ mlp_width_scale: int = 4,
306
+ init_scale: float = 0.25,
307
+ qkv_bias: bool = True,
308
+ flash: bool = False
309
+ ):
310
+ super().__init__()
311
+
312
+ if data_width is None:
313
+ data_width = width
314
+
315
+ self.attn = MultiheadCrossAttention(
316
+ device=device,
317
+ dtype=dtype,
318
+ n_data=n_data,
319
+ width=width,
320
+ heads=heads,
321
+ data_width=data_width,
322
+ init_scale=init_scale,
323
+ qkv_bias=qkv_bias,
324
+ flash=flash,
325
+ )
326
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
327
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
328
+ self.mlp = MLP(device=device, dtype=dtype, width=width, hidden_width_scale=mlp_width_scale, init_scale=init_scale)
329
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
330
+
331
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
332
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
333
+ x = x + self.mlp(self.ln_3(x))
334
+ return x
335
+
336
+
337
+ class MLP(nn.Module):
338
+ def __init__(self, *,
339
+ device: Optional[torch.device],
340
+ dtype: Optional[torch.dtype],
341
+ width: int,
342
+ hidden_width_scale: int = 4,
343
+ init_scale: float):
344
+ super().__init__()
345
+ self.width = width
346
+ self.c_fc = nn.Linear(width, width * hidden_width_scale, device=device, dtype=dtype)
347
+ self.c_proj = nn.Linear(width * hidden_width_scale, width, device=device, dtype=dtype)
348
+ self.gelu = nn.GELU()
349
+ init_linear(self.c_fc, init_scale)
350
+ init_linear(self.c_proj, init_scale)
351
+
352
+ def forward(self, x):
353
+ return self.c_proj(self.gelu(self.c_fc(x)))
354
+
355
+
356
+ class Transformer(nn.Module):
357
+ def __init__(
358
+ self,
359
+ *,
360
+ device: Optional[torch.device],
361
+ dtype: Optional[torch.dtype],
362
+ n_ctx: int,
363
+ width: int,
364
+ layers: int,
365
+ heads: int,
366
+ init_scale: float = 0.25,
367
+ qkv_bias: bool = True,
368
+ flash: bool = False,
369
+ use_checkpoint: bool = False
370
+ ):
371
+ super().__init__()
372
+ self.n_ctx = n_ctx
373
+ self.width = width
374
+ self.layers = layers
375
+ self.resblocks = nn.ModuleList(
376
+ [
377
+ ResidualAttentionBlock(
378
+ device=device,
379
+ dtype=dtype,
380
+ n_ctx=n_ctx,
381
+ width=width,
382
+ heads=heads,
383
+ init_scale=init_scale,
384
+ qkv_bias=qkv_bias,
385
+ flash=flash,
386
+ use_checkpoint=use_checkpoint
387
+ )
388
+ for _ in range(layers)
389
+ ]
390
+ )
391
+
392
+ def forward(self, x: torch.Tensor):
393
+ for block in self.resblocks:
394
+ x = block(x)
395
+ return x
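+
+
+ # Illustrative sketch (not part of the original file): a tiny self-attention Transformer over a
+ # set of 16 tokens, with flash attention disabled so it also runs on CPU.
+ if __name__ == "__main__":
+     _enc = Transformer(device=torch.device("cpu"), dtype=torch.float32,
+                        n_ctx=16, width=64, layers=2, heads=4, flash=False)
+     _tokens = torch.randn(2, 16, 64)
+     print(_enc(_tokens).shape)  # torch.Size([2, 16, 64])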
src/model/skin_vae/autoencoders/skin_fsq_cvae_model.py ADDED
@@ -0,0 +1,304 @@
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from einops import repeat
10
+ import math
11
+
12
+ from ..attention_processor import Tripo2AttnProcessor2_0
13
+ from ..embeddings import FrequencyPositionalEmbedding
14
+ from .autoencoder_kl_tripo2 import Tripo2Encoder, Tripo2Decoder
15
+ from .FSQ import FSQ
16
+ from .SimVQ import SimVQ1D
17
+
18
+ from ...utils import fps
19
+
20
+ def init_linear(l, stddev):
21
+ nn.init.normal_(l.weight, std=stddev)
22
+ if l.bias is not None:
23
+ nn.init.constant_(l.bias, 0.0)
24
+
25
+
26
+ class SkinFSQCVAEModel(ModelMixin, ConfigMixin):
27
+ @register_to_config
28
+ def __init__(
29
+ self,
30
+ in_channels: int = 4,
31
+ cond_channels: int = 3,
32
+ latent_channels: int = 64,
33
+ num_attention_heads: int = 8,
34
+ width_encoder: int = 512,
35
+ width_decoder: int = 1024,
36
+ num_layers_encoder: int = 8,
37
+ num_layers_decoder: int = 16,
38
+ embedding_type: str = "frequency",
39
+ embed_frequency: int = 8,
40
+ embed_include_pi: bool = False,
41
+ sample_tokens: int = 32,
42
+ **kwargs
43
+ ):
44
+ super().__init__()
45
+
46
+ self.out_channels = 1
47
+
48
+ if embedding_type == "frequency":
49
+ self.embedder = FrequencyPositionalEmbedding(
50
+ num_freqs=embed_frequency,
51
+ logspace=True,
52
+ input_dim=3,
53
+ include_pi=embed_include_pi,
54
+ use_pmpe=kwargs.get('use_pmpe', False),
55
+ )
56
+ else:
57
+ raise NotImplementedError(
58
+ f"Embedding type {embedding_type} is not supported."
59
+ )
60
+
61
+ self.is_learned_queries = kwargs.get('is_learned_queries', False)
62
+
63
+ is_miche = kwargs.get('is_miche', False)
64
+ self.encoder = Tripo2Encoder(
65
+ in_channels=in_channels + self.embedder.out_dim,
66
+ dim=width_encoder,
67
+ num_attention_heads=num_attention_heads,
68
+ num_layers=num_layers_encoder,
69
+ is_learned_queries=self.is_learned_queries,
70
+ sample_tokens=sample_tokens,
71
+ is_miche=is_miche,
72
+ )
73
+
74
+ self.cond_encoder = Tripo2Encoder(
75
+ in_channels=cond_channels + self.embedder.out_dim,
76
+ dim=width_encoder,
77
+ num_attention_heads=num_attention_heads,
78
+ num_layers=num_layers_encoder,
79
+ is_miche=is_miche,
80
+ )
81
+
82
+ self.decoder = Tripo2Decoder(
83
+ in_channels=self.embedder.out_dim + cond_channels,
84
+ out_channels=self.out_channels,
85
+ dim=width_decoder,
86
+ num_attention_heads=num_attention_heads,
87
+ num_layers=num_layers_decoder,
88
+ is_miche=is_miche,
89
+ )
90
+
91
+ self.cond_quant = nn.Linear(width_encoder, latent_channels, bias=True)
92
+
93
+ self.quant = nn.Linear(width_encoder, latent_channels, bias=True)
94
+ self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)
95
+
96
+ init_scale = 0.25 * math.sqrt(1.0 / width_encoder)
97
+ init_linear(self.cond_quant, init_scale)
98
+ init_linear(self.quant, init_scale)
99
+ init_scale = 0.25 * math.sqrt(1.0 / latent_channels)
100
+ init_linear(self.post_quant, init_scale)
101
+ self.use_slicing = False
102
+ self.slicing_length = 1
103
+ if kwargs.get('FSQ_dict', None) is not None:
104
+ self.FSQ = FSQ(**kwargs['FSQ_dict'])
105
+ else:
106
+ self.FSQ = SimVQ1D(**kwargs['SimVQ_dict'])
107
+
108
+ @property
109
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
110
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
111
+ r"""
112
+ Returns:
113
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
114
+ indexed by its weight name.
115
+ """
116
+ # set recursively
117
+ processors = {}
118
+
119
+ def fn_recursive_add_processors(
120
+ name: str,
121
+ module: torch.nn.Module,
122
+ processors: Dict[str, AttentionProcessor],
123
+ ):
124
+ if hasattr(module, "get_processor"):
125
+ processors[f"{name}.processor"] = module.get_processor()
126
+
127
+ for sub_name, child in module.named_children():
128
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
129
+
130
+ return processors
131
+
132
+ for name, module in self.named_children():
133
+ fn_recursive_add_processors(name, module, processors)
134
+
135
+ return processors
136
+
137
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
138
+ def set_attn_processor(
139
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
140
+ ):
141
+ r"""
142
+ Sets the attention processor to use to compute attention.
143
+
144
+ Parameters:
145
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
146
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
147
+ for **all** `Attention` layers.
148
+
149
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
150
+ processor. This is strongly recommended when setting trainable attention processors.
151
+
152
+ """
153
+ count = len(self.attn_processors.keys())
154
+
155
+ if isinstance(processor, dict) and len(processor) != count:
156
+ raise ValueError(
157
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
158
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
159
+ )
160
+
161
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
162
+ if hasattr(module, "set_processor"):
163
+ if not isinstance(processor, dict):
164
+ module.set_processor(processor)
165
+ else:
166
+ module.set_processor(processor.pop(f"{name}.processor"))
167
+
168
+ for sub_name, child in module.named_children():
169
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
170
+
171
+ for name, module in self.named_children():
172
+ fn_recursive_attn_processor(name, module, processor)
173
+
174
+ def set_default_attn_processor(self):
175
+ """
176
+ Disables custom attention processors and sets the default attention implementation.
177
+ """
178
+ self.set_attn_processor(Tripo2AttnProcessor2_0())
179
+
180
+ def enable_slicing(self, slicing_length: int = 1) -> None:
181
+ r"""
182
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
183
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
184
+ """
185
+ self.use_slicing = True
186
+ self.slicing_length = slicing_length
187
+
188
+ def disable_slicing(self) -> None:
189
+ r"""
190
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
191
+ decoding in one step.
192
+ """
193
+ self.use_slicing = False
194
+
195
+ def _sample_features(
196
+ self, x: torch.Tensor, num_tokens: int = 128, seed: Optional[int] = None
197
+ ):
198
+ """
199
+ Sample points from features of the input point cloud.
200
+
201
+ Args:
202
+ x (torch.Tensor): The input point cloud. shape: (B, N, C)
203
+ num_tokens (int, optional): The number of points to sample. Defaults to 128.
204
+ seed (Optional[int], optional): The random seed. Defaults to None.
205
+ """
206
+ rng = np.random.default_rng(seed)
207
+ indices = rng.choice(
208
+ x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
209
+ )
210
+ selected_points = x[:, indices]
211
+
212
+ batch_size, num_points, num_channels = selected_points.shape
213
+ flattened_points = selected_points.view(batch_size * num_points, num_channels)
214
+ batch_indices = (
215
+ torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
216
+ )
217
+
218
+ # fps sampling
219
+ sampling_ratio = 1.0 / 4
220
+ sampled_indices = fps(
221
+ flattened_points[:, :3],
222
+ batch_indices,
223
+ ratio=sampling_ratio,
224
+ random_start=self.training,
225
+ )
226
+ sampled_points = flattened_points[sampled_indices].view(
227
+ batch_size, -1, num_channels
228
+ )
229
+
230
+ return sampled_points
231
+
232
+ def get_qkv(self, x: torch.Tensor, num_tokens: int = 128, seed: Optional[int] = None, not_get_q: bool=False):
233
+ positions, features = x[..., :3], x[..., 3:]
234
+ x_kv = torch.cat([self.embedder(positions), features], dim=-1)
235
+
236
+ if not_get_q:
237
+ x_q = torch.zeros((x.shape[0], num_tokens, x.shape[-1]), dtype=x.dtype, device=x.device)
238
+ else:
239
+ sampled_x = self._sample_features(x, num_tokens, seed)
240
+ positions, features = (
241
+ sampled_x[..., :3],
242
+ sampled_x[..., 3:],
243
+ )
244
+ x_q = torch.cat([self.embedder(positions), features], dim=-1)
245
+ return x_q, x_kv
246
+
247
+ def _encode(
248
+ self, x: torch.Tensor|None, cond: torch.Tensor|None, num_tokens: int = 128, cond_tokens: int = 128, seed: Optional[int] = None,
249
+ return_z: bool=True, return_cond: bool=True,
250
+ ):
251
+ position_channels = 3
252
+ if return_z:
253
+ assert x is not None
254
+ x_q, x_kv = self.get_qkv(x, num_tokens, seed, not_get_q=self.is_learned_queries)
255
+ x = self.encoder(x_q, x_kv)
256
+ x = self.quant(x)
257
+ else:
258
+ x = None
259
+
260
+ if return_cond:
261
+ assert cond is not None
262
+ cond_q, cond_kv = self.get_qkv(cond, cond_tokens, seed)
263
+ cond_embed = self.cond_encoder(cond_q, cond_kv)
264
+ cond = self.cond_quant(cond_embed)
265
+ else:
266
+ cond = None
267
+
268
+ return x, cond
269
+
270
+ def _decode(
271
+ self, z: torch.Tensor,
272
+ cond: torch.Tensor,
273
+ sampled_points: torch.Tensor,
274
+ num_chunks: Optional[int] = None,
275
+ ) -> torch.Tensor:
276
+ xyz_samples = sampled_points
277
+ z = self.post_quant(torch.cat([z, cond], dim=1))
278
+
279
+ num_points = xyz_samples.shape[1]
280
+ if num_chunks is None:
281
+ num_chunks = num_points
282
+
283
+ queries = sampled_points.to(z.device, dtype=z.dtype)
284
+ positions, features = (
285
+ queries[..., :3],
286
+ queries[..., 3:],
287
+ )
288
+
289
+ kv_cache = None
290
+ dec = []
291
+ for i in range(0, num_points, num_chunks):
292
+ queries = torch.cat([self.embedder(positions[:, i:i + num_chunks, :]), features[:, i:i + num_chunks, :]], dim=-1)
293
+ # after the first chunk, kv_cache is filled and reused, so only the final cross-attention runs per chunk
+ logits, kv_cache = self.decoder(z, queries, kv_cache)
294
+ dec.append(logits)
295
+
296
+ return torch.cat(dec, dim=1)
297
+
298
+ def compile_model(self):
299
+ self.encoder = torch.compile(self.encoder)
300
+ self.cond_encoder = torch.compile(self.cond_encoder)
301
+ self.decoder = torch.compile(self.decoder)
302
+
303
+ def forward(self, x: torch.Tensor):
304
+ pass
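+
+ # Illustrative flow note (not part of the original file; forward() is intentionally left empty):
+ # _encode() turns the skin point cloud and the condition cloud into latent tokens via
+ # FPS-sampled (or learned) queries and the quant/cond_quant projections; those tokens are
+ # quantized by self.FSQ (FSQ or SimVQ1D); _decode() projects [z, cond] through post_quant,
+ # self-attends over them once, then each chunk of embedded query points cross-attends into the
+ # cached result and the sigmoid head yields per-point skin weights.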
src/model/skin_vae/autoencoders/vae.py ADDED
@@ -0,0 +1,73 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers.utils.torch_utils import randn_tensor
6
+
7
+
8
+ class DiagonalGaussianDistribution(object):
9
+ def __init__(
10
+ self,
11
+ parameters: torch.Tensor,
12
+ deterministic: bool = False,
13
+ feature_dim: int = 1,
14
+ ):
15
+ self.parameters = parameters
16
+ self.feature_dim = feature_dim
17
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim)
18
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
19
+ self.deterministic = deterministic
20
+ self.std = torch.exp(0.5 * self.logvar)
21
+ self.var = torch.exp(self.logvar)
22
+ if self.deterministic:
23
+ self.var = self.std = torch.zeros_like(
24
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
25
+ )
26
+
27
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
28
+ # make sure sample is on the same device as the parameters and has same dtype
29
+ sample = randn_tensor(
30
+ self.mean.shape,
31
+ generator=generator,
32
+ device=self.parameters.device,
33
+ dtype=self.parameters.dtype,
34
+ )
35
+ x = self.mean + self.std * sample
36
+ return x
37
+
38
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
39
+ if self.deterministic:
40
+ return torch.Tensor([0.0])
41
+ else:
42
+ if other is None:
43
+ return 0.5 * torch.mean(
44
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
45
+ )
46
+ elif isinstance(other, DiagonalGaussianDistribution):
47
+ return 0.5 * torch.mean(
48
+ torch.pow(self.mean - other.mean, 2) / other.var
49
+ + self.var / other.var
50
+ - 1.0
51
+ - self.logvar
52
+ + other.logvar,
53
+ )
54
+ elif isinstance(other, torch.Tensor):
55
+ return 0.5 * torch.mean(
56
+ torch.pow(self.mean - other, 2) + self.var - 1.0 - self.logvar,
57
+ )
58
+ else:
59
+ raise ValueError("Other must be a DiagonalGaussianDistribution or torch.Tensor")
60
+
61
+ def nll(
62
+ self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)
63
+ ) -> torch.Tensor:
64
+ if self.deterministic:
65
+ return torch.Tensor([0.0])
66
+ logtwopi = np.log(2.0 * np.pi)
67
+ return 0.5 * torch.sum(
68
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
69
+ dim=dims,
70
+ )
71
+
72
+ def mode(self) -> torch.Tensor:
73
+ return self.mean
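+
+
+ # Illustrative sketch (not part of the original file): `parameters` stacks mean and logvar along
+ # `feature_dim`; sample() reparameterizes, kl() returns the mean KL against a standard normal.
+ if __name__ == "__main__":
+     params = torch.randn(2, 2 * 64, 16)
+     dist = DiagonalGaussianDistribution(params, feature_dim=1)
+     print(dist.sample().shape)  # torch.Size([2, 64, 16])
+     print(dist.kl().item())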
src/model/skin_vae/embeddings.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class FrequencyPositionalEmbedding(nn.Module):
6
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
7
+ each feature dimension of `x[..., i]` into:
8
+ [
9
+ sin(x[..., i]),
10
+ sin(f_1*x[..., i]),
11
+ sin(f_2*x[..., i]),
12
+ ...
13
+ sin(f_N * x[..., i]),
14
+ cos(x[..., i]),
15
+ cos(f_1*x[..., i]),
16
+ cos(f_2*x[..., i]),
17
+ ...
18
+ cos(f_N * x[..., i]),
19
+ x[..., i] # only present if include_input is True.
20
+ ], here f_i is the frequency.
21
+
22
+ If logspace is True, the frequencies are powers of two: [2^0, 2^1, ..., 2^(num_freqs - 1)];
23
+ otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1).
24
+ If include_pi is True, every frequency is additionally multiplied by pi.
25
+
26
+ Args:
27
+ num_freqs (int): the number of frequencies, default is 6;
28
+ logspace (bool): if True, the frequencies are powers of two [2^0, ..., 2^(num_freqs - 1)],
29
+ otherwise they are linearly spaced between 1.0 and 2^(num_freqs - 1);
30
+ input_dim (int): the input dimension, default is 3;
31
+ include_input (bool): include the input tensor or not, default is True.
32
+
33
+ Attributes:
34
+ frequencies (torch.Tensor): [2^0, ..., 2^(num_freqs - 1)] if logspace is True,
35
+ otherwise linearly spaced between 1.0 and 2^(num_freqs - 1);
36
+
37
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
38
+ otherwise, it is input_dim * num_freqs * 2.
39
+
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ num_freqs: int = 6,
45
+ logspace: bool = True,
46
+ input_dim: int = 3,
47
+ include_input: bool = True,
48
+ include_pi: bool = True,
49
+ use_pmpe: bool = False,
50
+ ) -> None:
51
+ """The initialization"""
52
+
53
+ super().__init__()
54
+
55
+ if logspace:
56
+ frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
57
+ else:
58
+ frequencies = torch.linspace(
59
+ 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
60
+ )
61
+
62
+ if include_pi:
63
+ frequencies *= torch.pi
64
+
65
+ self.register_buffer("frequencies", frequencies, persistent=False)
66
+ self.include_input = include_input
67
+ self.num_freqs = num_freqs
68
+ self.use_pmpe = use_pmpe
69
+ if use_pmpe:
70
+ phase = torch.arange(num_freqs, dtype=torch.float32)
71
+ for i in range(num_freqs):
72
+ phase[i] = torch.pow(torch.tensor(num_freqs), 1.0-(i+1)/num_freqs)+(i+1)/num_freqs
73
+ phase *= torch.pi*2
74
+ self.register_buffer("phase", phase, persistent=False)
75
+
76
+ self.out_dim = self.get_dims(input_dim)
77
+
78
+ def get_dims(self, input_dim):
79
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
80
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
81
+
82
+ return out_dim
83
+
84
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
85
+ """Forward process.
86
+
87
+ Args:
88
+ x: tensor of shape [..., dim]
89
+
90
+ Returns:
91
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
92
+ where temp is 1 if include_input is True and 0 otherwise.
93
+ """
94
+
95
+ if self.num_freqs > 0:
96
+ embed = (x[..., None].contiguous() * self.frequencies).view(
97
+ *x.shape[:-1], -1
98
+ )
99
+ if self.use_pmpe:
100
+ phase = (x[..., None].contiguous()*torch.pi*0.5 + self.phase).view(
101
+ *x.shape[:-1], -1
102
+ )
103
+ res = torch.cat((embed.sin()+phase.sin(), embed.cos()+phase.cos()), dim=-1)
104
+ else:
105
+ res = torch.cat((embed.sin(), embed.cos()), dim=-1)
106
+ if self.include_input:
107
+ return torch.cat((x, res), dim=-1)
108
+ else:
109
+ return res
110
+ else:
111
+ return x
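+
+
+ # Illustrative sketch (not part of the original file): with num_freqs=8 and include_input=True,
+ # each coordinate expands to 2*8 sin/cos terms plus the raw value, so out_dim = 3 * (2*8 + 1) = 51.
+ if __name__ == "__main__":
+     emb = FrequencyPositionalEmbedding(num_freqs=8, input_dim=3, include_pi=False)
+     pts = torch.rand(2, 1024, 3)
+     print(emb.out_dim, emb(pts).shape)  # 51 torch.Size([2, 1024, 51])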
src/model/skin_vae/transformers/__init__.py ADDED
@@ -0,0 +1,41 @@
1
+ from typing import Callable, Optional
2
+
3
+ from .tripo2_transformer import Tripo2DiTModel
4
+
5
+
6
+ def default_set_attn_proc_func(
7
+ name: str,
8
+ hidden_size: int,
9
+ cross_attention_dim: Optional[int],
10
+ ori_attn_proc: object,
11
+ ) -> object:
12
+ return ori_attn_proc
13
+
14
+
15
+ def set_transformer_attn_processor(
16
+ transformer: Tripo2DiTModel,
17
+ set_self_attn_proc_func: Callable = default_set_attn_proc_func,
18
+ set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
19
+ ) -> None:
20
+ attn_procs = {}
21
+ for name, attn_processor in transformer.attn_processors.items():
22
+ hidden_size = transformer.config.width
23
+ if name.endswith("attn1.processor"):
24
+ # self attention
25
+ attn_procs[name] = set_self_attn_proc_func(
26
+ name, hidden_size, None, attn_processor
27
+ )
28
+ elif name.endswith("attn2.processor"):
29
+ # cross attention
30
+ cross_attention_dim = transformer.config.cross_attention_dim
31
+ attn_procs[name] = set_cross_attn_proc_func(
32
+ name, hidden_size, cross_attention_dim, attn_processor
33
+ )
34
+ elif name.endswith("attn2_2.processor"):
35
+ # cross attention 2
36
+ cross_attention_dim = transformer.config.cross_attention_2_dim
37
+ attn_procs[name] = set_cross_attn_proc_func(
38
+ name, hidden_size, cross_attention_dim, attn_processor
39
+ )
40
+
41
+ transformer.set_attn_processor(attn_procs)
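+
+
+ # Illustrative sketch (not part of the original file): keep the existing processors but log which
+ # attention layers would be touched; `transformer` is assumed to be a constructed Tripo2DiTModel.
+ #
+ # def log_attn_proc(name, hidden_size, cross_attention_dim, ori_attn_proc):
+ #     print(f"{name}: hidden={hidden_size}, cross_dim={cross_attention_dim}")
+ #     return ori_attn_proc
+ #
+ # set_transformer_attn_processor(transformer, log_attn_proc, log_attn_proc)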
src/model/skin_vae/transformers/modeling_outputs.py ADDED
@@ -0,0 +1,8 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+
6
+ @dataclass
7
+ class Transformer1DModelOutput:
8
+ sample: torch.FloatTensor