diff --git a/.gitattributes b/.gitattributes new file mode 100755 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..a2417d77aa770b665f61eed4aa262659fe76414d --- /dev/null +++ b/.gitignore @@ -0,0 +1,67 @@ +# igonore all pychace +**/__pycache__/ +*.py[cod] +*$py.class + +# ignore tmp & output files +_data/ +ckpt/ +tmp/ +tmp_fusion/ +tmp_vae/ +tmp_video/ +*.glb +*.ply +*.obj +*.fbx +*.npz +*.blend +*.blend1 +*.blend2 + +# ignore logs +wandb/ +lightning_logs/ +*.log + +# ignore experiments +experiments/ +results/ +dataset_clean/ +logs/ +datalist/ +dataset_inference/ +dataset_inference_clean/ +feature_viz/ + +# Distribution / packaging +dist/ +build/ +*.egg-info/ +*.egg +*.whl + +# Virtual environments +venv/ +env/ +.env/ +.venv/ + +# IDE specific files +.idea/ +.vscode/ +*.swp +*.swo +.DS_Store + +# Jupyter Notebook +.ipynb_checkpoints +*.ipynb + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +coverage.xml +*.cover diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..484d9bbef8720dc2797444a18c972a4e6245a07e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 VAST-AI-Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100755 index 0000000000000000000000000000000000000000..c6caed841aaea28e69109e77652ef9fc4443998d --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: SkinTokens +emoji: ๐ŸŒ– +colorFrom: green +colorTo: pink +sdk: gradio +sdk_version: 6.12.0 +python_version: 3.12.12 +app_file: demo.py +pinned: false +license: mit +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl b/bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl new file mode 100644 index 0000000000000000000000000000000000000000..3b7a16a029a67968f70ef44cbddd166a948a015f --- /dev/null +++ b/bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39df5f78dc95d1fae6058a3134a40f645303c6e96540f87e9ee4c0fd436def1d +size 346159222 diff --git a/bpy_server.py b/bpy_server.py new file mode 100755 index 0000000000000000000000000000000000000000..c22455ca1d27ca554b067592e75ea87c3eb17b6b --- /dev/null +++ b/bpy_server.py @@ -0,0 +1,7 @@ +from src.server.bpy_server import run + +def main(): + run() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/configs/skeleton/mixamo.yaml b/configs/skeleton/mixamo.yaml new file mode 100755 index 0000000000000000000000000000000000000000..6271e7b9f59eec59ae08ce77138d1a054be382c8 --- /dev/null +++ b/configs/skeleton/mixamo.yaml @@ -0,0 +1,59 @@ +parts_order: [body, hand] + +parts: + body: [ + mixamorig:Hips, + mixamorig:Spine, + mixamorig:Spine1, + mixamorig:Spine2, + mixamorig:Neck, + mixamorig:Head, + mixamorig:LeftShoulder, + mixamorig:LeftArm, + mixamorig:LeftForeArm, + mixamorig:LeftHand, + mixamorig:RightShoulder, + mixamorig:RightArm, + mixamorig:RightForeArm, + mixamorig:RightHand, + mixamorig:LeftUpLeg, + mixamorig:LeftLeg, + mixamorig:LeftFoot, + mixamorig:LeftToeBase, + mixamorig:RightUpLeg, + mixamorig:RightLeg, + mixamorig:RightFoot, + mixamorig:RightToeBase, + ] + hand: [ + mixamorig:LeftHandThumb1, + mixamorig:LeftHandThumb2, + mixamorig:LeftHandThumb3, + mixamorig:LeftHandIndex1, + mixamorig:LeftHandIndex2, + mixamorig:LeftHandIndex3, + mixamorig:LeftHandMiddle1, + mixamorig:LeftHandMiddle2, + mixamorig:LeftHandMiddle3, + mixamorig:LeftHandRing1, + mixamorig:LeftHandRing2, + mixamorig:LeftHandRing3, + mixamorig:LeftHandPinky1, + mixamorig:LeftHandPinky2, + mixamorig:LeftHandPinky3, + mixamorig:RightHandIndex1, + mixamorig:RightHandIndex2, + mixamorig:RightHandIndex3, + mixamorig:RightHandThumb1, + mixamorig:RightHandThumb2, + mixamorig:RightHandThumb3, + mixamorig:RightHandMiddle1, + mixamorig:RightHandMiddle2, + mixamorig:RightHandMiddle3, + mixamorig:RightHandRing1, + mixamorig:RightHandRing2, + mixamorig:RightHandRing3, + mixamorig:RightHandPinky1, + mixamorig:RightHandPinky2, + mixamorig:RightHandPinky3, + ] \ No newline at end of file diff --git a/configs/skeleton/vroid.yaml b/configs/skeleton/vroid.yaml new file mode 100755 index 
0000000000000000000000000000000000000000..1d6f066e78686d4f3e7d9f50dae8b9ab73922aa9 --- /dev/null +++ b/configs/skeleton/vroid.yaml @@ -0,0 +1,59 @@ +parts_order: [body, hand] + +parts: + body: [ + J_Bip_C_Hips, + J_Bip_C_Spine, + J_Bip_C_Chest, + J_Bip_C_UpperChest, + J_Bip_C_Neck, + J_Bip_C_Head, + J_Bip_L_Shoulder, + J_Bip_L_UpperArm, + J_Bip_L_LowerArm, + J_Bip_L_Hand, + J_Bip_R_Shoulder, + J_Bip_R_UpperArm, + J_Bip_R_LowerArm, + J_Bip_R_Hand, + J_Bip_L_UpperLeg, + J_Bip_L_LowerLeg, + J_Bip_L_Foot, + J_Bip_L_ToeBase, + J_Bip_R_UpperLeg, + J_Bip_R_LowerLeg, + J_Bip_R_Foot, + J_Bip_R_ToeBase, + ] + hand: [ + J_Bip_L_Thumb1, + J_Bip_L_Thumb2, + J_Bip_L_Thumb3, + J_Bip_L_Index1, + J_Bip_L_Index2, + J_Bip_L_Index3, + J_Bip_L_Middle1, + J_Bip_L_Middle2, + J_Bip_L_Middle3, + J_Bip_L_Ring1, + J_Bip_L_Ring2, + J_Bip_L_Ring3, + J_Bip_L_Little1, + J_Bip_L_Little2, + J_Bip_L_Little3, + J_Bip_R_Index1, + J_Bip_R_Index2, + J_Bip_R_Index3, + J_Bip_R_Thumb1, + J_Bip_R_Thumb2, + J_Bip_R_Thumb3, + J_Bip_R_Middle1, + J_Bip_R_Middle2, + J_Bip_R_Middle3, + J_Bip_R_Ring1, + J_Bip_R_Ring2, + J_Bip_R_Ring3, + J_Bip_R_Little1, + J_Bip_R_Little2, + J_Bip_R_Little3, + ] \ No newline at end of file diff --git a/demo.py b/demo.py new file mode 100755 index 0000000000000000000000000000000000000000..8868490cf9f0306011db18d845f483977fbd1391 --- /dev/null +++ b/demo.py @@ -0,0 +1,764 @@ +import argparse +import atexit +import importlib +import os +import signal +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import List, Optional, Tuple + +import gradio as gr +import requests +from torch import Tensor +from tqdm import tqdm + +# --------------------------------------------------------------------------- +# ZeroGPU compatibility shim. The hosted HF Space provides the `spaces` +# package; running locally we substitute a no-op. +# --------------------------------------------------------------------------- +try: + spaces = importlib.import_module("spaces") +except Exception: + class _SpacesCompat: + @staticmethod + def GPU(*args, **kwargs): + if len(args) == 1 and callable(args[0]) and not kwargs: + return args[0] + + def _decorator(fn): + return fn + + return _decorator + + spaces = _SpacesCompat() + +os.environ.setdefault("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "1") +gr.TEMP_DIR = "tmp_gradio" + + +# --------------------------------------------------------------------------- +# Install the bundled `bpy` wheel at runtime if it isn't already importable. +# +# Why this is non-trivial: +# - Putting the wheel in requirements.txt fails: HF Spaces' Docker build +# mounts only requirements.txt BEFORE the repo COPY, so the wheel path +# doesn't exist at pip-install time. +# - PyPI doesn't ship a bpy wheel matching this exact build (rc0 / cp312 / +# manylinux_2_39). +# - The `bpy-*.whl` committed in this repo gets auto-tracked by HF's LFS +# layer (Hub auto-LFS for blobs > ~10 MB even when .gitattributes doesn't +# list `*.whl`). The container's COPY-from-repo only carries the LFS +# *pointer* file โ€” a ~150-byte text stub โ€” not the actual wheel binary. +# So `pip install ` and `zipfile.ZipFile()` both fail with +# "is not a zip file" / "Wheel is invalid". +# +# So: we detect the LFS-pointer case and re-fetch the real wheel from the +# HF Hub at runtime (where the API resolves LFS server-side), then extract +# it directly into site-packages. 
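Editor's illustration (not part of the patch): a minimal standalone sketch of the distinction described in the comment above, assuming a hypothetical local `wheel_path`. A real wheel is a ZIP archive whose first bytes are the `PK` magic, while a Git LFS pointer is a ~150-byte text stub whose first line reads `version https://git-lfs.github.com/spec/v1` (exactly the form of the committed `bpy-*.whl` blob earlier in this diff).

```python
from pathlib import Path

def is_lfs_pointer(wheel_path: str) -> bool:
    """Return True if wheel_path holds a Git LFS pointer stub instead of a real wheel."""
    head = Path(wheel_path).read_bytes()[:64]
    # A genuine .whl is a ZIP archive and begins with the b"PK" magic bytes.
    if head.startswith(b"PK"):
        return False
    # An LFS pointer is a short text file starting with the LFS spec line.
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")
```

`_ensure_bpy_installed()` below applies the same `PK` test before deciding whether to re-fetch the real wheel from the Hub.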
+# --------------------------------------------------------------------------- +def _ensure_bpy_installed(): + try: + import bpy # noqa: F401 + return + except Exception: + pass + + import glob + import sysconfig + import zipfile + + here = os.path.dirname(os.path.abspath(__file__)) + wheels = sorted(glob.glob(os.path.join(here, "bpy-*.whl"))) + if not wheels: + print("[demo] WARNING: bpy not importable and no bundled wheel found", flush=True) + return + + wheel = wheels[-1] + wheel_name = os.path.basename(wheel) + + # Detect LFS pointer (text stub starting with "version https://git-lfs..."). + is_real_zip = False + try: + with open(wheel, "rb") as f: + is_real_zip = f.read(4).startswith(b"PK") + except Exception: + pass + + if not is_real_zip: + print( + f"[demo] {wheel_name} on disk is an LFS pointer ({os.path.getsize(wheel)} B); " + f"fetching real wheel from HF Hub...", + flush=True, + ) + from huggingface_hub import hf_hub_download + + space_id = os.environ.get("SPACE_ID", "VAST-AI/SkinTokens") + token = os.environ.get("HF_TOKEN") # set as a Space secret for private repos + wheel = hf_hub_download( + repo_id=space_id, + repo_type="space", + filename=wheel_name, + token=token, + ) + print(f"[demo] fetched -> {wheel} ({os.path.getsize(wheel)} B)", flush=True) + + site = sysconfig.get_paths()["purelib"] + print(f"[demo] Extracting {wheel_name} into {site}", flush=True) + with zipfile.ZipFile(wheel) as z: + z.extractall(site) + print("[demo] bpy wheel extracted.", flush=True) + + +_ensure_bpy_installed() + + +# --------------------------------------------------------------------------- +# Download model checkpoints (TokenRig + SkinTokens FSQ-CVAE) and the Qwen3 +# tokenizer/config on first cold-start. +# +# These live in the *model* repo `VAST-AI/SkinTokens` (private), separate +# from this Space repo, so they aren't COPYed into the container. Re-uses +# `HF_TOKEN` from the Space secrets. 
+# --------------------------------------------------------------------------- +def _ensure_models_downloaded(): + here = os.path.dirname(os.path.abspath(__file__)) + needed_ckpts = [ + "experiments/skin_vae_2_10_32768/last.ckpt", + "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt", + ] + qwen_dir = os.path.join(here, "models", "Qwen3-0.6B") + + all_present = ( + all(os.path.exists(os.path.join(here, p)) for p in needed_ckpts) + and os.path.exists(os.path.join(qwen_dir, "tokenizer.json")) + ) + if all_present: + return + + from huggingface_hub import hf_hub_download, snapshot_download + + token = os.environ.get("HF_TOKEN") + + for rel in needed_ckpts: + target = os.path.join(here, rel) + if os.path.exists(target): + continue + print(f"[demo] Downloading checkpoint: {rel}", flush=True) + hf_hub_download( + repo_id="VAST-AI/SkinTokens", + filename=rel, + local_dir=here, + token=token, + ) + + if not os.path.exists(os.path.join(qwen_dir, "tokenizer.json")): + print("[demo] Downloading Qwen3-0.6B tokenizer/config", flush=True) + snapshot_download( + repo_id="Qwen/Qwen3-0.6B", + local_dir=qwen_dir, + ignore_patterns=["*.bin", "*.safetensors"], + ) + + print("[demo] All checkpoints ready.", flush=True) + + +_ensure_models_downloaded() + + +from src.data.dataset import DatasetConfig, RigDatasetModule +from src.data.transform import Transform +from src.model.tokenrig import TokenRigResult +from src.tokenizer.parse import get_tokenizer +from src.server.spec import ( + BPY_SERVER, + get_model, + object_to_bytes, + bytes_to_object, +) +from src.data.vertex_group import voxel_skin + + +# --------------------------------------------------------------------------- +# Pre-warm `bpy_server` in the main (Gradio) process at module load. +# +# Why this is necessary on ZeroGPU: each user request runs inside a fresh +# `@spaces.GPU` worker process with a hard time budget (โ‰ˆ60 s on free tier). +# Importing the Blender shared object inside that budget burns 30โ€“60 s, so +# the worker is killed *during* bpy import โ€” manifesting as +# "GPU task aborted" before any model code runs. +# +# We start `bpy_server.py` here, in the always-running main process, so the +# slow bpy import happens exactly once at Space boot. Workers then just hit +# `localhost:59876` over HTTP โ€” sub-millisecond, no startup cost. +# --------------------------------------------------------------------------- + +MODEL_CKPTS = [ + "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt", +] + +HF_PATHS = [ + "None", +] + + +def get_dataloader_workers() -> int: + if os.getenv("SPACE_ID"): + return 0 + return 1 + + +# --------------------------------------------------------------------------- +# bpy_server lifecycle โ€” lazy start so the heavy import doesn't fight ZeroGPU +# during module load. 
+# --------------------------------------------------------------------------- +_BPY_SERVER_PROC = None + + +def is_bpy_server_alive(timeout: float = 1.0) -> bool: + try: + resp = requests.get(f"{BPY_SERVER}/ping", timeout=timeout) + return resp.status_code == 200 + except Exception: + return False + + +def start_bpy_server(): + proc = subprocess.Popen( + [sys.executable, "bpy_server.py"], + stdout=None, + stderr=None, + preexec_fn=os.setsid, + ) + print(f"[Main] bpy_server.py started (pid={proc.pid})") + + def cleanup(): + print(f"[Main] Terminating bpy_server.py (pid={proc.pid})") + try: + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + except ProcessLookupError: + pass + + atexit.register(cleanup) + return proc + + +def wait_for_bpy_server(timeout: float = 120): + """Wait for bpy_server.py to come up. The first start of bpy_server is + slow because importing the Blender `.so` (~200 MB shared object) takes + 30โ€“60 s on a cold container. We allow up to 120 s.""" + t0 = time.time() + last_log = 0.0 + while True: + try: + requests.get(f"{BPY_SERVER}/ping", timeout=1) + print(f"[Main] bpy_server is ready (after {time.time() - t0:.1f}s)") + return + except Exception: + now = time.time() + if now - t0 > timeout: + raise RuntimeError( + f"bpy_server failed to start after {timeout:.0f}s" + ) + if now - last_log > 10: # progress every 10s + print(f"[Main] still waiting for bpy_server ({now - t0:.0f}s elapsed)") + last_log = now + time.sleep(0.5) + + +def ensure_bpy_server_started(): + global _BPY_SERVER_PROC + if is_bpy_server_alive(): + return + if _BPY_SERVER_PROC is not None and _BPY_SERVER_PROC.poll() is None: + return + _BPY_SERVER_PROC = start_bpy_server() + wait_for_bpy_server() + + +# --------------------------------------------------------------------------- +# Lazy model loading. +# --------------------------------------------------------------------------- +model = None +tokenizer = None +transform = None +CURRENT_MODEL_CKPT: Optional[str] = None +CURRENT_HF_PATH: Optional[str] = None + + +def load_model(model_ckpt: str, hf_path: Optional[str]) -> Tuple[str, str]: + global model, tokenizer, transform, CURRENT_MODEL_CKPT, CURRENT_HF_PATH + if hf_path == "None": + hf_path = None + if model is not None and model_ckpt == CURRENT_MODEL_CKPT and hf_path == CURRENT_HF_PATH: + return ("Model already loaded.", model_ckpt) + + if not model_ckpt: + raise RuntimeError("model_ckpt is empty. Please select a checkpoint.") + + print(f"Loading model: {model_ckpt}, hf_path={hf_path}") + model = get_model(model_ckpt, hf_path=hf_path) + assert model.tokenizer_config is not None + tokenizer = get_tokenizer(**model.tokenizer_config) + transform = Transform.parse(**model.transform_config["predict_transform"]) + CURRENT_MODEL_CKPT = model_ckpt + CURRENT_HF_PATH = hf_path + return ("Model loaded.", model_ckpt) + + +# --------------------------------------------------------------------------- +# File utilities (CLI-side). 
+# --------------------------------------------------------------------------- +SUPPORTED_EXT = {".obj", ".fbx", ".glb"} + + +def collect_files(input_path: Path) -> List[Path]: + if input_path.is_file(): + return [input_path] + + files = [] + for p in input_path.rglob("*"): + if p.suffix.lower() in SUPPORTED_EXT: + files.append(p) + return files + + +def map_output_path(in_path: Path, input_root: Path, output_root: Path) -> Path: + rel = in_path.relative_to(input_root) + return (output_root / rel).with_suffix(".glb") + + +# --------------------------------------------------------------------------- +# Core inference (shared by CLI and Gradio). +# --------------------------------------------------------------------------- +def run_rig( + filepaths: List[Path], + top_k: int, + top_p: float, + temperature: float, + repetition_penalty: float, + num_beams: int, + use_skeleton: bool, + use_transfer: bool, + use_postprocess: bool, + output_paths: List[Path], + model_ckpt: str, + hf_path: Optional[str], +): + assert len(filepaths) == len(output_paths) + ensure_bpy_server_started() + load_model(model_ckpt, hf_path) + + datapath = { + "data_name": None, + "loader": "bpy_server", + "filepaths": {"articulation": [str(p) for p in filepaths]}, + } + + dataset_config = DatasetConfig.parse( + shuffle=False, + batch_size=1, + num_workers=get_dataloader_workers(), + pin_memory=get_dataloader_workers() > 0, + persistent_workers=False, + datapath=datapath, + ).split_by_cls() + + module = RigDatasetModule( + predict_dataset_config=dataset_config, + predict_transform=transform, + tokenizer=tokenizer, + process_fn=model._process_fn, + ) + + dataloader = module.predict_dataloader()["articulation"] + + results_out = [] + infer_device = model.device if model is not None else "cuda" + + for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)): + batch = { + k: v.to(infer_device) if isinstance(v, Tensor) else v + for k, v in batch.items() + } + + if not use_skeleton: + batch.pop("skeleton_tokens", None) + batch.pop("skeleton_mask", None) + + batch["generate_kwargs"] = dict( + max_length=2048, + top_k=int(top_k), + top_p=float(top_p), + temperature=float(temperature), + repetition_penalty=float(repetition_penalty), + num_return_sequences=1, + num_beams=int(num_beams), + do_sample=True, + ) + + if "skeleton_tokens" in batch and "skeleton_mask" in batch: + mask = batch["skeleton_mask"][0] == 1 + skeleton_tokens = batch["skeleton_tokens"][0][mask].cpu().numpy() + else: + skeleton_tokens = None + + preds: List[TokenRigResult] = model.predict_step( + batch, + skeleton_tokens=[skeleton_tokens] if skeleton_tokens is not None else None, + make_asset=True, + )["results"] + + asset = preds[0].asset + assert asset is not None + + if use_postprocess: + voxel = asset.voxel(resolution=196) + asset.skin *= voxel_skin( + grid=0, + grid_coords=voxel.coords, + joints=asset.joints, + vertices=asset.vertices, + faces=asset.faces, + mode="square", + voxel_size=voxel.voxel_size, + ) + asset.normalize_skin() + + out_path = output_paths[i] + out_path.parent.mkdir(parents=True, exist_ok=True) + + if use_transfer: + payload = dict( + source_asset=asset, + target_path=asset.path, + export_path=str(out_path), + group_per_vertex=4, + ) + res = bytes_to_object( + requests.post( + f"{BPY_SERVER}/transfer", + data=object_to_bytes(payload), + ).content + ) + else: + payload = dict( + asset=asset, + filepath=str(out_path), + group_per_vertex=4, + ) + res = bytes_to_object( + requests.post( + f"{BPY_SERVER}/export", + 
data=object_to_bytes(payload), + ).content + ) + + if res != "ok": + print(f"[Error] {res}") + else: + print(f"[OK] Exported: {out_path}") + + results_out.append(out_path) + + return results_out + + +# --------------------------------------------------------------------------- +# CLI entry point. +# --------------------------------------------------------------------------- +def run_cli(args): + input_path = Path(args.input).resolve() + output_path = Path(args.output).resolve() + + files = collect_files(input_path) + if not files: + raise RuntimeError("No valid 3D files found.") + + if len(files) == 1 and output_path.suffix: + outputs = [output_path] + else: + outputs = [map_output_path(f, input_path, output_path) for f in files] + + run_rig( + files, + args.top_k, + args.top_p, + args.temperature, + args.repetition_penalty, + args.num_beams, + args.use_skeleton, + args.use_transfer, + args.use_postprocess, + outputs, + args.model_ckpt, + args.hf_path, + ) + + +# --------------------------------------------------------------------------- +# Gradio wrapper (with ZeroGPU duration estimator). +# --------------------------------------------------------------------------- +TOT = 0 + + +def _gpu_duration( + files, + top_k, + top_p, + temperature, + repetition_penalty, + num_beams, + use_skeleton, + use_transfer, + use_postprocess, + model_ckpt, + hf_path, +): + # Cold workers spend ~30โ€“60 s importing bpy + loading the model before + # any GPU work. Give every request a generous 240 s floor. + file_count = len(files) if files is not None else 1 + return min(900, max(240, 240 + 60 * file_count)) + + +@spaces.GPU(duration=_gpu_duration) +def run_gradio( + files, + top_k, + top_p, + temperature, + repetition_penalty, + num_beams, + use_skeleton, + use_transfer, + use_postprocess, + model_ckpt, + hf_path, +): + if not files: + return "Please upload at least one 3D model.", None + + tmp_out = Path(tempfile.mkdtemp(prefix="tokenrig_")) + filepaths = [Path(f.name) for f in files] + global TOT + outputs = [] + for filepath in filepaths: + TOT += 1 + outputs.append(tmp_out / f"res_{TOT}.glb") + + run_rig( + filepaths, + top_k, + top_p, + temperature, + repetition_penalty, + num_beams, + use_skeleton, + use_transfer, + use_postprocess, + outputs, + model_ckpt, + hf_path, + ) + + return f"Processed {len(outputs)} models.", [str(p) for p in outputs] + + +# --------------------------------------------------------------------------- +# Gradio UI. +# --------------------------------------------------------------------------- +def build_gradio_app(): + model_ckpts = MODEL_CKPTS + hf_paths = HF_PATHS + default_ckpt = model_ckpts[0] if model_ckpts else "" + default_hf = hf_paths[0] if hf_paths else "None" + + with gr.Blocks(title="SkinTokens ยท TokenRig Demo") as app: + gr.Markdown( + """ + ## ๐Ÿฆด Mesh to Rig with [SkinTokens](https://zjp-shadow.github.io/works/SkinTokens/) ยท TokenRig + + Automated **skeleton generation + skinning weight prediction** for any 3D mesh, via a unified + autoregressive model over learned *SkinTokens*. Successor to + [UniRig](https://github.com/VAST-AI-Research/UniRig) (SIGGRAPH '25). + + * Upload one or more meshes โ†’ click **Run** โ†’ download a rigged `.glb`. + * **Paper**: [arXiv 2602.04805](https://arxiv.org/abs/2602.04805)  ยท  + **Code**: [VAST-AI-Research/SkinTokens](https://github.com/VAST-AI-Research/SkinTokens)  ยท  + **Weights**: [๐Ÿค— VAST-AI/SkinTokens](https://huggingface.co/VAST-AI/SkinTokens) + * Looking for **image โ†’ rigged 3D** instead? 
Try our sibling Space + [🤗 VAST-AI/AniGen](https://huggingface.co/spaces/VAST-AI/AniGen). + * Want a full AI-powered 3D workspace? → [Tripo](https://www.tripo3d.ai) + """ + ) + + gr.HTML( + """ +
+ 💡 Tips  + Defaults work well for most meshes. +  • If your mesh already has a skeleton and you only want skinning, enable + Use existing skeleton below. +  • To keep your original textures and world scale, enable Preserve original texture & scale. +
+""" + ) + + with gr.Row(): + with gr.Column(scale=1): + files = gr.File( + label="3D Models ( .obj / .fbx / .glb, up to a few at a time )", + file_count="multiple", + file_types=[".obj", ".fbx", ".glb"], + ) + + with gr.Accordion("โš™๏ธ Generation Settings", open=False): + model_ckpt = gr.Dropdown( + choices=model_ckpts, + value=default_ckpt, + label="Model checkpoint", + info="TokenRig autoregressive rigging model. The default is the GRPO-refined checkpoint recommended for most assets.", + interactive=True, + ) + # Keep the hf_path component for callback compatibility, but hide it + # from the UI since it currently only exposes the default ("None") option. + hf_path = gr.Dropdown( + choices=hf_paths, + value=default_hf, + label="HF path (advanced)", + visible=False, + ) + + gr.Markdown("**Sampling parameters** โ€” control autoregressive decoding of the rig.") + top_k = gr.Slider( + 1, 200, value=5, step=1, + label="top_k", + info="Sample from the K most likely next tokens at each step. Lower = more deterministic output.", + ) + top_p = gr.Slider( + 0.1, 1.0, value=0.95, step=0.01, + label="top_p (nucleus)", + info="Sample from the smallest set of tokens whose cumulative probability โ‰ฅ p.", + ) + temperature = gr.Slider( + 0.1, 2.0, value=1.0, step=0.1, + label="temperature", + info="Softmax temperature. <1 sharpens the distribution (more conservative), >1 makes it flatter (more diverse).", + ) + repetition_penalty = gr.Slider( + 0.5, 3.0, value=2.0, step=0.1, + label="repetition_penalty", + info="Multiplicative penalty on tokens that have already been generated. 1.0 = no penalty.", + ) + num_beams = gr.Slider( + 1, 20, value=10, step=1, + label="num_beams", + info="Beam-search width. Larger = higher quality but slower; 1 disables beam search.", + ) + + gr.Markdown("**Pipeline toggles**") + use_skeleton = gr.Checkbox( + False, + label="Use existing skeleton (predict skinning only)", + info="If the uploaded file already contains a skeleton, keep it and only predict per-vertex skinning weights.", + ) + use_transfer = gr.Checkbox( + False, + label="Preserve original texture & scale", + info="Transfer the predicted rig back onto the original (unprocessed) mesh, so textures and world units are preserved.", + ) + use_postprocess = gr.Checkbox( + False, + label="Voxel skin post-processing", + info="Apply a voxel-based mask to the predicted skin weights before normalization. Slower.", + ) + + run_btn = gr.Button("๐Ÿš€ Run", variant="primary") + + with gr.Column(scale=1): + log = gr.Textbox(label="Status", lines=2, interactive=False) + output = gr.File(label="Rigged GLB output", interactive=False) + gr.Markdown( + """ + **Notes** + - The output `.glb` contains the predicted **skeleton + skinning weights**. Import it in Blender (File โ†’ Import โ†’ glTF 2.0) or any DCC tool that reads glTF. + - In Blender, if you see a `glTF_not_exported` placeholder node, you can safely remove it. + - On busy moments Zero-GPU may queue your request for ~10โ€“30 s before inference starts โ€” the status box will update once the GPU is attached. + - Please do **not** upload confidential or NSFW content. See the + [project page](https://zjp-shadow.github.io/works/SkinTokens/) for paper-accurate results and the + [code repo](https://github.com/VAST-AI-Research/SkinTokens) for local / batch inference. 
+ """ + ) + + run_btn.click( + run_gradio, + inputs=[ + files, + top_k, + top_p, + temperature, + repetition_penalty, + num_beams, + use_skeleton, + use_transfer, + use_postprocess, + model_ckpt, + hf_path, + ], + outputs=[log, output], + ) + + return app + + +demo = build_gradio_app() + + +# Note: we do NOT pre-warm `bpy_server` in the main process. `bpy_server.py` +# transitively imports `src.model.michelangelo.utils.misc`, whose +# module-level `use_flash3 = FLASH3()` calls `torch.cuda.get_device_name(0)` +# at import time. That call fails ("RuntimeError: No CUDA GPUs are +# available") in the main Gradio process on ZeroGPU, where the GPU is only +# attached inside `@spaces.GPU`-decorated workers. So the bpy_server boot +# happens on first request, inside the worker. + + +# --------------------------------------------------------------------------- +# Entry point. +# --------------------------------------------------------------------------- +if __name__ == "__main__": + parser = argparse.ArgumentParser("TokenRig Demo") + parser.add_argument("--input", help="Input file or directory") + parser.add_argument("--output", help="Output file or directory") + + parser.add_argument("--top_k", type=int, default=5) + parser.add_argument("--top_p", type=float, default=0.95) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition_penalty", type=float, default=2.0) + parser.add_argument("--num_beams", type=int, default=10) + + parser.add_argument("--use_skeleton", action="store_true") + parser.add_argument("--use_transfer", action="store_true") + parser.add_argument("--use_postprocess", action="store_true") + + parser.add_argument("--model_ckpt", default=MODEL_CKPTS[0] if MODEL_CKPTS else "") + parser.add_argument("--hf_path", default=None) + + parser.add_argument("--gradio", action="store_true") + + args = parser.parse_args() + + if args.gradio or not args.input: + demo.queue() + demo.launch(ssr_mode=False) + else: + ensure_bpy_server_started() + run_cli(args) diff --git a/download.py b/download.py new file mode 100644 index 0000000000000000000000000000000000000000..8091bc370ab204f972cce09a6670ec04c72217de --- /dev/null +++ b/download.py @@ -0,0 +1,72 @@ +from huggingface_hub import hf_hub_download, snapshot_download + +import argparse + +REPO_ID = "VAST-AI/SkinTokens" + +MODELS = [ + "experiments/skin_vae_2_10_32768/last.ckpt", + "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt", +] + +DATASETS = [ + "rignet.zip", + "articulation.zip", +] + +LLM_REPO = "Qwen/Qwen3-0.6B" +LLM_LOCAL_DIR = "models/Qwen3-0.6B" + + +def download_model(name: str): + local_path = hf_hub_download( + repo_id=REPO_ID, + filename=name, + local_dir=".", + ) + print(f"[MODEL] {name} downloaded to: {local_path}") + + +def download_llm(): + local_path = snapshot_download( + repo_id=LLM_REPO, + local_dir=LLM_LOCAL_DIR, + ignore_patterns=["*.bin", "*.safetensors"], + ) + print(f"[LLM] Config downloaded to: {local_path}") + + +def download_data(name: str): + local_path = hf_hub_download( + repo_id=REPO_ID, + filename=f"dataset_clean/{name}", + local_dir=".", + ) + name = name.removesuffix(".zip") + local_path = snapshot_download( + repo_id=REPO_ID, + allow_patterns=[f"datalist/{name}/*"], + local_dir=".", + ) + print(f"[DATA] {name} downloaded to: {local_path}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", action="store_true", help="Download model checkpoints") + parser.add_argument("--data", action="store_true", help="Download 
datasets") + args = parser.parse_args() + if not args.model and not args.data: + print("Please specify --model or --data") + return + if args.model: + for model in MODELS: + download_model(model) + download_llm() + if args.data: + for data in DATASETS: + download_data(data) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..a2aacf405c75ad120cf2ce9841f992046cc62835 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,37 @@ +--extra-index-url https://download.pytorch.org/whl/cu128 +--extra-index-url https://pypi.org/simple + +# Pinned to match the flash-attn wheel below (cu12torch2.9cxx11abiTRUE). +# Don't bump torch without also bumping the flash-attn URL โ€” they must agree on +# the (cu12 / torch 2.9 / cxx11-abi=TRUE / cp312) tuple. +torch==2.9.1 +torchvision==0.24.1 +torchaudio==2.9.1 +https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.9cxx11abiTRUE-cp312-cp312-linux_x86_64.whl +transformers==4.57.0 +diffusers==0.36.0 +python-box +einops +omegaconf +pytorch_lightning +lightning +addict +timm +fast-simplification +trimesh +open3d +pyrender +# bpy is NOT listed here. The bpy wheel committed in this repo +# (`bpy-4.5.4rc0-cp312-cp312-manylinux_2_39_x86_64.whl`) is installed at +# runtime by `demo.py` โ€” see `_ensure_bpy_installed()`. Reason: HF Spaces' +# Docker build mounts only `requirements.txt` BEFORE the repo COPY, so the +# wheel path doesn't exist at pip-install time. Public PyPI also has no bpy +# wheel matching this exact build (`rc0` / manylinux_2_39 / cp312). +huggingface_hub +spaces +wandb +numpy==2.2.6 +gradio +bottle +tornado +cython \ No newline at end of file diff --git a/runtime.txt b/runtime.txt new file mode 100755 index 0000000000000000000000000000000000000000..08f311892edad977549d139203cf4c721918062d --- /dev/null +++ b/runtime.txt @@ -0,0 +1 @@ +python-3.11 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/augment.py b/src/data/augment.py new file mode 100755 index 0000000000000000000000000000000000000000..0909d5dfb01b5c40a0213987c92b6052bdbf6e89 --- /dev/null +++ b/src/data/augment.py @@ -0,0 +1,706 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Tuple, Union, List, Optional, Dict +from numpy import ndarray +from abc import ABC, abstractmethod +from scipy.spatial.transform import Rotation as R + +import numpy as np +import random + +from .spec import ConfigSpec + +from ..rig_package.utils import axis_angle_to_matrix +from ..rig_package.info.asset import Asset + +@dataclass(frozen=True) +class Augment(ConfigSpec): + + @classmethod + @abstractmethod + def parse(cls, **kwags) -> 'Augment': + pass + + @abstractmethod + def transform(self, asset: Asset, **kwargs): + pass + +@dataclass(frozen=True) +class AugmentTrim(Augment): + """randomly delete joints and vertices""" + + @classmethod + def parse(cls, **kwargs) -> 'AugmentTrim': + cls.check_keys(kwargs) + return AugmentTrim() + + def transform(self, asset: Asset, **kwargs): + asset.trim_skeleton() + +@dataclass(frozen=True) +class AugmentDelete(Augment): + """randomly delete joints and vertices""" + + # probability + p: float + + # how much to keep + rate: float + + @classmethod + def parse(cls, **kwargs) -> 'AugmentDelete': + cls.check_keys(kwargs) + return AugmentDelete( + 
p=kwargs.get('p', 0.), + rate=kwargs.get('rate', 0.5), + ) + + def transform(self, asset: Asset, **kwargs): + if asset.skin is None: + raise ValueError("do not have skin") + if asset.parents is None: + raise ValueError("do not have parents") + asset.normalize_skin() + def select_k(arr: List, k: int): + if len(arr) <= k: + return arr + else: + rest_indices = list(range(1, len(arr))) + selected_indices = sorted(random.sample(rest_indices, k)) + return [arr[i] for i in selected_indices] + if np.random.rand() >= self.p: + return + ids = select_k([i for i in range(asset.J)], max(int(asset.J * (1 - np.random.rand() * self.rate)), 1)) + if len(ids) == 0: + return + # keep bones with no skin + keep = {} + for id in ids: + keep[id] = True + for id in range(asset.J): + if np.all(asset.skin[:, id] < 0.1): + keep[id] = True + keep[asset.root] = True + + vertices_to_delete = np.zeros(asset.N, dtype=bool) + for id in range(asset.J): + if id not in keep: + dominant = asset.skin.argmax(axis=1) == id + x = (asset.skin[:, id] > 0.1) & dominant + if np.all(~x) or x.sum() * asset.J < asset.N: # avoid collapsing + keep[id] = 1 + continue + vertices_to_delete[x] = True + if np.all(vertices_to_delete): + return + if asset.faces is not None: + indices = np.where(~vertices_to_delete)[0] + face_mask = np.all(np.isin(asset.faces, indices), axis=1) + if np.all(~face_mask): + return + + joints_to_delete: List[int|str] = [i for i in range(asset.J) if i not in keep] + asset.delete_joints(joints_to_delete) + asset.delete_vertices(np.arange(asset.N)[vertices_to_delete]) + +@dataclass(frozen=True) +class AugmentDropPart(Augment): + """randomly drop subtrees and their vertices""" + + # probability + p: float + + # drop rate + rate: float + + @classmethod + def parse(cls, **kwargs) -> 'AugmentDropPart': + cls.check_keys(kwargs) + return AugmentDropPart( + p=kwargs.get('p', 0.), + rate=kwargs.get('rate', 0.5), + ) + + def transform(self, asset: Asset, **kwargs): + if np.random.rand() >= self.p: + return + if asset.parents is None: + raise ValueError("do not have parents") + if asset.skin is None: + raise ValueError("do not have skin") + keep = [] + for id in range(asset.J): + if np.random.rand() < self.rate: + keep.append(id) + if len(keep) == 0: + return + for id in reversed(asset.dfs_order): + p = asset.parents[id] + if p == -1: + continue + if id in keep and p not in keep: + keep.append(p) + + mask = np.zeros(asset.N, dtype=bool) + for id in keep: + mask[asset.skin[:, id] > 1e-5] = True + vertices_to_delete = ~mask + if np.all(vertices_to_delete): + return + if asset.faces is not None: + indices = np.where(~vertices_to_delete)[0] + face_mask = np.all(np.isin(asset.faces, indices), axis=1) + if np.all(~face_mask): + return + + joints_to_delete: List[int|str] = [i for i in range(asset.J) if i not in keep] + asset.delete_joints(joints_to_delete) + asset.delete_vertices(np.arange(asset.N)[vertices_to_delete]) + + def inverse(self, asset: Asset): + pass + +@dataclass(frozen=True) +class AugmentCollapse(Augment): + """randomly merge joints""" + + # collapse the skeleton with probability p + p: float + + # probability to merge the bone + rate: float + + # max bones + max_bones: int + + @classmethod + def parse(cls, **kwargs) -> 'AugmentCollapse': + cls.check_keys(kwargs) + return AugmentCollapse( + p=kwargs.get('p', 0.), + rate=kwargs.get('rate', 0.), + max_bones=kwargs.get('max_bones', 2147483647), + ) + + def transform(self, asset: Asset, **kwargs): + def select_k(arr: List, k: int): + if len(arr) <= k: + return arr + else: + 
rest_indices = list(range(1, len(arr))) + selected_indices = sorted(random.sample(rest_indices, k)) + return [arr[i] for i in selected_indices] + + root = asset.root + if np.random.rand() < self.p: + ids = [] + for id in range(asset.J): + if np.random.rand() >= self.rate: + ids.append(id) + if root not in ids: + ids.append(root) + keep: List[int|str] = select_k([i for i in range(asset.J) if i in ids], self.max_bones) + if root not in keep: + keep[0] = root + asset.set_order(new_orders=keep) + elif asset.J > self.max_bones: + ids = select_k([i for i in range(asset.J)], k=self.max_bones) + if root not in ids: + ids[0] = root + keep: List[int|str] = [i for i in range(asset.J) if i in ids] + asset.set_order(new_orders=keep) + +@dataclass(frozen=True) +class AugmentJointDiscrete(Augment): + # perturb the skeleton with probability p + p: float + + # num of discretized coord + discrete: int + + # continuous range + continuous_range: Tuple[float, float] + + @classmethod + def parse(cls, **kwargs) -> 'AugmentJointDiscrete': + cls.check_keys(kwargs) + return AugmentJointDiscrete( + p=kwargs.get('p', 0.), + discrete=kwargs.get('discrete', 256), + continuous_range=kwargs.get('continuous_range', [-1., 1.]), + ) + + def _discretize( + self, + t: ndarray, + continuous_range: Tuple[float, float], + num_discrete: int, + ) -> ndarray: + lo, hi = continuous_range + assert hi >= lo + t = (t - lo) / (hi - lo) + t *= num_discrete + return np.clip(t.round(), 0, num_discrete - 1).astype(np.int64) + + def _undiscretize( + self, + t: ndarray, + continuous_range: Tuple[float, float], + num_discrete: int, + ) -> ndarray: + lo, hi = continuous_range + assert hi >= lo + t = t.astype(np.float32) + 0.5 + t /= num_discrete + return t * (hi - lo) + lo + + def transform(self, asset: Asset, **kwargs): + if np.random.rand() < self.p: + joints = asset.joints + if joints is not None and asset.matrix_local is not None: + joints = self._undiscretize(self._discretize( + joints, + self.continuous_range, + self.discrete, + ), + self.continuous_range, + self.discrete, + ) + asset.matrix_local[:, :3, 3] = joints + +@dataclass(frozen=True) +class AugmentJointPerturb(Augment): + # perturb the skeleton with probability p + p: float + + # jitter sigma on joints + sigma: float + + # jitter clip on joints + clip: float + + @classmethod + def parse(cls, **kwargs) -> 'AugmentJointPerturb': + cls.check_keys(kwargs) + return AugmentJointPerturb( + p=kwargs.get('p', 0.), + sigma=kwargs.get('sigma', 0.), + clip=kwargs.get('clip', 0.), + ) + + def transform(self, asset: Asset, **kwargs): + if np.random.rand() < self.p and asset.matrix_local is not None: + asset.matrix_local[:, :3] += np.clip( + np.random.normal(0, self.sigma, (asset.J, 3)), + -self.clip, + self.clip, + ) + +@dataclass(frozen=True) +class AugmentLBS(Augment): + # apply a random pose with probability p + random_pose_p: float + + # random pose angle range + random_pose_angle: float + + # random scale + random_scale_range: Tuple[float, float] + + @classmethod + def parse(cls, **kwargs) -> 'AugmentLBS': + cls.check_keys(kwargs) + return AugmentLBS( + random_pose_p=kwargs.get('random_pose_p', 0.), + random_pose_angle=kwargs.get('random_pose_angle', 0.), + random_scale_range=kwargs.get('random_scale_range', (1., 1.)), + ) + + def _apply(self, v: ndarray, trans: ndarray) -> ndarray: + return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3] + + def transform(self, asset: Asset, **kwargs): + def get_matrix_basis(angle: float): + matrix = axis_angle_to_matrix((np.random.rand(asset.J, 
3) - 0.5) * angle / 180 * np.pi * 2).astype(np.float32) + return matrix + + if np.random.rand() < self.random_pose_p and asset.joints is not None: + matrix_basis = get_matrix_basis(self.random_pose_angle) + max_offset = (asset.joints.max(axis=0) - asset.joints.min(axis=0)).max() + matrix_basis[:, :3, :3] *= np.tile(np.random.uniform(low=self.random_scale_range[0], high=self.random_scale_range[1], size=(asset.J, 1, 1)), (1, 3, 3)) + asset.vertices_with_pose(matrix_basis=matrix_basis, inplace=True) + +@dataclass(frozen=True) +class AugmentLinear(Augment): + # apply random rotation with probability p + random_rotate_p: float + + # random rotation angle(degree) + random_rotate_angle: float + + # swap x with probability p + random_flip_x_p: float + + # swap y with probability p + random_flip_y_p: float + + # swap z with probability p + random_flip_z_p: float + + # probability to pick an angle in static_rotate_x + static_rotate_x_p: float + + # rotate around x axis among given angles(degrees) + static_rotate_x: List[float] + + # probability to pick an angle in static_rotate_y + static_rotate_y_p: float + + # rotate around y axis among given angles(degrees) + static_rotate_y: List[float] + + # probability to pick an angle in static_rotate_z + static_rotate_z_p: float + + # rotate around z axis among given angles(degrees) + static_rotate_z: List[float] + + # apply random scaling with probability p + random_scale_p: float + + # random scaling xyz axis + random_scale: Tuple[float, float] + + # randomly change xyz orientation + random_transpose: float + + @classmethod + def parse(cls, **kwargs) -> 'AugmentLinear': + if kwargs.get('random_flip_x_p', 0) > 0 or kwargs.get('random_flip_y_p', 0) > 0 or kwargs.get('random_flip_z_p', 0) > 0: + print("\033[31mWARNING: random flip is enabled and is very likely to confuse ar model !\033[0m") + cls.check_keys(kwargs) + return AugmentLinear( + random_rotate_p=kwargs.get('random_rotate_p', 0.), + random_rotate_angle=kwargs.get('random_rotate_angle', 0.), + random_flip_x_p=kwargs.get('random_flip_x_p', 0.), + random_flip_y_p=kwargs.get('random_flip_y_p', 0.), + random_flip_z_p=kwargs.get('random_flip_z_p', 0.), + static_rotate_x_p=kwargs.get('static_rotate_x_p', 0.), + static_rotate_x=kwargs.get('static_rotate_x', []), + static_rotate_y_p=kwargs.get('static_rotate_y_p', 0.), + static_rotate_y=kwargs.get('static_rotate_y', []), + static_rotate_z_p=kwargs.get('static_rotate_z_p', 0.), + static_rotate_z=kwargs.get('static_rotate_z', []), + random_scale_p=kwargs.get('random_scale_p', 0.), + random_scale=kwargs.get('random_scale', [1.0, 1.0]), + random_transpose=kwargs.get('random_transpose', 0.), + ) + + def _apply(self, v: ndarray, trans: ndarray) -> ndarray: + return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3] + + def transform(self, asset: Asset, **kwargs): + trans_vertex = np.eye(4, dtype=np.float32) + r = np.eye(4, dtype=np.float32) + if np.random.rand() < self.random_rotate_p: + angle = self.random_rotate_angle + axis_angle = (np.random.rand(3) - 0.5) * angle / 180 * np.pi * 2 + r = R.from_rotvec(axis_angle).as_matrix() + r = np.pad(r, ((0, 1), (0, 1)), 'constant', constant_values=0.) + r[3, 3] = 1. 
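# Editorial note: each enabled step below -- the axis flips, static X/Y/Z
# rotations, random per-axis scaling, and the random axis permutation -- is
# accumulated into the single homogeneous matrix `r`, which is then folded
# into `trans_vertex` and applied to the asset exactly once at the end of
# transform() via `asset.transform(trans=trans_vertex)`.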
+ + if np.random.uniform(0, 1) < self.random_flip_x_p: + r @= np.array([ + [-1.0, 0.0, 0.0, 0.0], + [ 0.0, 1.0, 0.0, 0.0], + [ 0.0, 0.0, 1.0, 0.0], + [ 0.0, 0.0, 0.0, 1.0], + ]) + + if np.random.uniform(0, 1) < self.random_flip_y_p: + r @= np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ]) + + if np.random.uniform(0, 1) < self.random_flip_z_p: + r @= np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, -1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ]) + + if np.random.uniform(0, 1) < self.static_rotate_x_p: + assert len(self.static_rotate_x) > 0, "static rotation of x is enabled, but static_rotate_x is empty" + angle = np.random.choice(self.static_rotate_x) / 180 * np.pi + c = np.cos(angle) + s = np.sin(angle) + r @= np.array([ + [ 1.0, 0.0, 0.0, 0.0], + [ 0.0, c, s, 0.0], + [ 0.0, -s, c, 0.0], + [ 0.0, 0.0, 0.0, 1.0], + ]) + + if np.random.uniform(0, 1) < self.static_rotate_y_p: + assert len(self.static_rotate_y) > 0, "static rotation of y is enabled, but static_rotate_y is empty" + angle = np.random.choice(self.static_rotate_y) / 180 * np.pi + c = np.cos(angle) + s = np.sin(angle) + r @= np.array([ + [ c, 0.0, -s, 0.0], + [ 0.0, 1.0, 0.0, 0.0], + [ s, 0.0, c, 0.0], + [ 0.0, 0.0, 0.0, 1.0], + ]) + + if np.random.uniform(0, 1) < self.static_rotate_z_p: + assert len(self.static_rotate_z) > 0, "static rotation of z is enabled, but static_rotate_z is empty" + angle = np.random.choice(self.static_rotate_z) / 180 * np.pi + c = np.cos(angle) + s = np.sin(angle) + r @= np.array([ + [ c, s, 0.0, 0.0], + [ -s, c, 0.0, 0.0], + [ 0.0, 0.0, 1.0, 0.0], + [ 0.0, 0.0, 0.0, 1.0], + ]) + + if np.random.uniform(0, 1) < self.random_scale_p: + scale_x = np.random.uniform(self.random_scale[0], self.random_scale[1]) + scale_y = np.random.uniform(self.random_scale[0], self.random_scale[1]) + scale_z = np.random.uniform(self.random_scale[0], self.random_scale[1]) + r @= np.array([ + [scale_x, 0.0, 0.0, 0.0], + [0.0, scale_y, 0.0, 0.0], + [0.0, 0.0, scale_z, 0.0], + [0.0, 0.0, 0.0, 1.0], + ]) + + if np.random.uniform(0, 1) < self.random_transpose: + permutations = [ + (0, 1, 2), # x, y, z + (0, 2, 1), # x, z, y + (1, 0, 2), # y, x, z + (1, 2, 0), # y, z, x + (2, 0, 1), # z, x, y + (2, 1, 0), # z, y, x + ] + direction_signs = [ + (1, 1, 1), + (1, 1, -1), + (1, -1, 1), + (1, -1, -1), + (-1, 1, 1), + (-1, 1, -1), + (-1, -1, 1), + (-1, -1, -1), + ] + perm = permutations[np.random.randint(0, 6)] + sign = direction_signs[np.random.randint(0, 8)] + m = np.zeros((4, 4)) + for i in range(3): + m[i, perm[i]] = sign[i] + m[3, 3] = 1.0 + r = m @ r + + trans_vertex = r @ trans_vertex + + # apply transform here + asset.transform(trans=trans_vertex) + +@dataclass(frozen=True) +class AugmentAffine(Augment): + # final normalization cube + normalize_into: Tuple[float, float] + + # randomly scale coordinates with probability p + random_scale_p: float + + # scale range (lower, upper) + random_scale: Tuple[float, float] + + # randomly shift coordinates with probability p + random_shift_p: float + + # shift range (lower, upper) + random_shift: Tuple[float, float] + + @classmethod + def parse(cls, **kwargs) -> 'AugmentAffine': + cls.check_keys(kwargs) + return AugmentAffine( + normalize_into=kwargs.get('normalize_into', [-1.0, 1.0]), + random_scale_p=kwargs.get('random_scale_p', 0.), + random_scale=kwargs.get('random_scale', [1., 1.]), + random_shift_p=kwargs.get('random_shift_p', 0.), + random_shift=kwargs.get('random_shift', [0., 0.]), + ) + + def transform(self, 
asset: Asset, **kwargs): + if asset.vertices is None: + raise ValueError("do not have vertices") + bound_min = asset.vertices.min(axis=0) + bound_max = asset.vertices.max(axis=0) + if asset.joints is not None: + joints_bound_min = asset.joints.min(axis=0) + joints_bound_max = asset.joints.max(axis=0) + bound_min = np.minimum(bound_min, joints_bound_min) + bound_max = np.maximum(bound_max, joints_bound_max) + + trans_vertex = np.eye(4, dtype=np.float32) + + trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex + + if self.normalize_into is not None: + # scale into the cube + normalize_into = self.normalize_into + scale = np.max((bound_max - bound_min) / (normalize_into[1] - normalize_into[0])) + trans_vertex = _scale_to_m(1. / scale) @ trans_vertex + + bias = (normalize_into[0] + normalize_into[1]) / 2 + trans_vertex = _trans_to_m(np.array([bias, bias, bias], dtype=np.float32)) @ trans_vertex + + if np.random.rand() < self.random_scale_p: + scale = _scale_to_m(np.random.uniform(self.random_scale[0], self.random_scale[1])) + trans_vertex = scale @ trans_vertex + + if np.random.rand() < self.random_shift_p: + l, r = self.random_shift + shift_vals = np.array([ + np.random.uniform(l, r), + np.random.uniform(l, r), + np.random.uniform(l, r), + ], dtype=np.float32) + if self.normalize_into is not None: + def _apply(v: ndarray, trans: ndarray) -> ndarray: + return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3] + lo, hi = self.normalize_into + pts_min = _apply(bound_min[None, :], trans_vertex)[0] + pts_max = _apply(bound_max[None, :], trans_vertex)[0] + low_allowed = lo - pts_min + high_allowed = hi - pts_max + shift_vals = np.array([ + np.random.uniform(low_allowed[0], high_allowed[0]), + np.random.uniform(low_allowed[1], high_allowed[1]), + np.random.uniform(low_allowed[2], high_allowed[2]), + ], dtype=np.float32) + shift = _trans_to_m(shift_vals.astype(np.float32)) + trans_vertex = shift @ trans_vertex + asset.transform(trans=trans_vertex) + +@dataclass(frozen=True) +class AugmentJitter(Augment): + # probability + p: float + + # jitter sigma on vertices + vertex_sigma: float + + # jitter clip on vertices + vertex_clip: float + + # jitter sigma on normals + normal_sigma: float + + # jitter clip on normals + normal_clip: float + + @classmethod + def parse(cls, **kwargs) -> 'AugmentJitter': + cls.check_keys(kwargs) + return AugmentJitter( + p=kwargs.get('p', 0.5), + vertex_sigma=kwargs.get('vertex_sigma', 0.), + vertex_clip=kwargs.get('vertex_clip', 0.), + normal_sigma=kwargs.get('normal_sigma', 0.), + normal_clip=kwargs.get('normal_clip', 0.), + ) + + def transform(self, asset: Asset, **kwargs): + vertex_sigma = self.vertex_sigma + vertex_clip = self.vertex_clip + normal_sigma = self.normal_sigma + normal_clip = self.normal_clip + + if np.random.rand() < self.p: + scale = np.random.rand() + 1e-6 + vertex_sigma *= scale + vertex_clip *= scale + scale = np.random.rand() + 1e-6 + normal_sigma *= scale + normal_clip *= scale + if vertex_sigma > 0 and asset.vertices is not None: + noise = np.clip(np.random.randn(*asset.vertices.shape) * vertex_sigma, -vertex_clip, vertex_clip).astype(np.float32) + asset.vertices += noise + + if normal_sigma > 0: + if asset.vertex_normals is not None: + noise = np.clip(np.random.randn(*asset.vertex_normals.shape) * normal_sigma, -normal_clip, normal_clip).astype(np.float32) + asset.vertex_normals += noise + + if asset.face_normals is not None: + noise = np.clip(np.random.randn(*asset.face_normals.shape) * normal_sigma, -normal_clip, 
normal_clip).astype(np.float32) + asset.face_normals += noise + +@dataclass(frozen=True) +class AugmentNormalize(Augment): + + @classmethod + def parse(cls, **kwargs) -> 'AugmentNormalize': + cls.check_keys(kwargs) + return AugmentNormalize() + + def transform(self, asset: Asset, **kwargs): + epsilon = 1e-10 + if asset.vertex_normals is not None: + vertex_norms = np.linalg.norm(asset.vertex_normals, axis=1, keepdims=True) + vertex_norms = np.maximum(vertex_norms, epsilon) + asset.vertex_normals = asset.vertex_normals / vertex_norms + asset.vertex_normals = np.nan_to_num(asset.vertex_normals, nan=0., posinf=0., neginf=0.) # type: ignore + + if asset.face_normals is not None: + face_norms = np.linalg.norm(asset.face_normals, axis=1, keepdims=True) + face_norms = np.maximum(face_norms, epsilon) + asset.face_normals = asset.face_normals / face_norms + asset.face_normals = np.nan_to_num(asset.face_normals, nan=0., posinf=0., neginf=0.) # type: ignore + +def _trans_to_m(v: ndarray): + m = np.eye(4, dtype=np.float32) + m[0:3, 3] = v + return m + +def _scale_to_m(r: ndarray|float): + m = np.zeros((4, 4), dtype=np.float32) + m[0, 0] = r + m[1, 1] = r + m[2, 2] = r + m[3, 3] = 1. + return m + +def get_augments(*args) -> List[Augment]: + MAP = { + 'trim': AugmentTrim, + 'delete': AugmentDelete, + 'drop_part': AugmentDropPart, + 'collapse': AugmentCollapse, + 'lbs': AugmentLBS, + 'linear': AugmentLinear, + 'affine': AugmentAffine, + 'jitter': AugmentJitter, + 'joint_perturb': AugmentJointPerturb, + 'joint_discrete': AugmentJointDiscrete, + 'normalize': AugmentNormalize, + } + MAP: Dict[str, type[Augment]] + augments = [] + for (i, config) in enumerate(args): + __target__ = config.get('__target__') + assert __target__ is not None, f"do not find `__target__` in augment of position {i}" + c = deepcopy(config) + del c['__target__'] + augments.append(MAP[__target__].parse(**c)) + return augments \ No newline at end of file diff --git a/src/data/datapath.py b/src/data/datapath.py new file mode 100755 index 0000000000000000000000000000000000000000..268123aa11032cdf0ca25ebbc7bf3d933e5c4199 --- /dev/null +++ b/src/data/datapath.py @@ -0,0 +1,344 @@ +from abc import abstractmethod, ABC +from collections import defaultdict +from dataclasses import dataclass, field +from numpy import ndarray +from random import shuffle +from typing import Dict, List, Optional + +import numpy as np +import requests +import os + +from ..rig_package.info.asset import Asset +from ..server.spec import BPY_SERVER, bytes_to_object, object_to_bytes +from .spec import ConfigSpec + +@dataclass +class LazyAsset(ABC): + """store datapath and load upon requiring""" + path: str + + cls: Optional[str]=None + + @abstractmethod + def load(self) -> 'Asset': + raise NotImplementedError() + +@dataclass +class BpyLazyAsset(LazyAsset): + + def load(self) -> 'Asset': + from ..rig_package.parser.bpy import BpyParser + asset = BpyParser.load(filepath=self.path) + asset.cls = self.cls + asset.path = self.path + return asset + +@dataclass +class BpyServerLazyAsset(LazyAsset): + """workaround while bpy is working in multiple threads""" + def load(self) -> 'Asset': + try: + asset = bytes_to_object(requests.get(f"{BPY_SERVER}/load", data=object_to_bytes(self.path)).content) + if isinstance(asset, str): + raise RuntimeError(f"bpy server failed: {asset}") + assert isinstance(asset, Asset) + asset.cls = self.cls + asset.path = self.path + return asset + except Exception as e: + raise RuntimeError(f"bpy server failed: {str(e)}") + +@dataclass +class 
NpzLazyAsset(LazyAsset): + + def load(self) -> 'Asset': + d = np.load(self.path, allow_pickle=True) + asset = Asset( + vertices=d['vertices'], + faces=d['faces'], + mesh_names=d.get('mesh_names', None), + joint_names=d.get('joint_names', None), + parents=d.get('parents', None), + lengths=d.get('lengths', None), + matrix_world=d.get('matrix_world', None), + matrix_local=d.get('matrix_local', None), + armature_name=d.get('armature_name', None), + skin=d.get('skin', None), + cls=self.cls, + path=self.path + ) + asset.cls = self.cls + asset.path = self.path + return asset + +@dataclass +class UniRigLazyAsset(LazyAsset): + """map unirig's data correctly""" + + def load(self) -> 'Asset': + def bn(x): + if isinstance(x, ndarray) and x.ndim==0: + return x.item() + return x + + d = np.load(self.path, allow_pickle=True) + parents = bn(d.get('parents', None)) + if parents is not None: + parents = [-1 if x is None else x for x in parents] + parents = np.array(parents) + matrix_local = bn(d.get('matrix_local', None)) + joints = bn(d.get('joints', None)) + if matrix_local is not None and matrix_local.ndim != 3 and joints is not None: + matrix_local = np.zeros((joints.shape[0], 4, 4)) + matrix_local[...] = np.eye(4) + matrix_local[:, :3, 3] = joints + asset = Asset( + vertices=d['vertices'], + faces=d['faces'], + joint_names=bn(d.get('names', None)), + parents=parents, # type: ignore + lengths=bn(d.get('lengths', None)), + matrix_world=bn(d.get('matrix_world', None)), + matrix_local=matrix_local, + armature_name=bn(d.get('armature_name', None)), + skin=bn(d.get('skin', None)), + cls=self.cls, + path=self.path + ).change_dtype(float_dtype=np.float32, int_dtype=np.int32) + asset.cls = self.cls + asset.path = self.path + return asset + +@dataclass +class Datapath(ConfigSpec): + """handle input data paths""" + + # all filepaths + filepaths: List[str] + + # root to add to prefix + input_dataset_dir: str='' + + # name of class + cls_name: Optional[List[str]]=None + + # bias in a single class + cls_bias: Optional[List[int]]=None + + # num of files in a single class + cls_length: Optional[List[int]]=None + + # how many files to return when using data sampling + num_files: Optional[int]=None + + # use proportion data sampling + use_prob: bool=False + + # weight + cls_weight: Optional[List[float]]=None + + # use bpy loader + loader: type[LazyAsset]=BpyLazyAsset + + # data name + data_name: Optional[str]=None + + # check if path exists + ignore_check: bool=False + + ################################################################# + # other vertex groups + vertex_groups: Dict[str, ndarray]=field(default_factory=dict) + + # sampled vertices + sampled_vertices: Optional[ndarray]=None + + # sampled normals + sampled_normals: Optional[ndarray]=None + + # sampled vertex groups + sampled_vertex_groups: Optional[Dict[str, ndarray]]=None + + @classmethod + def parse(cls, **kwargs) -> 'Datapath': + MAP = { + None: BpyLazyAsset, + 'bpy': BpyLazyAsset, + 'bpy_server': BpyServerLazyAsset, + 'npz': NpzLazyAsset, + 'unirig': UniRigLazyAsset, + } + input_dataset_dir = kwargs.get('input_dataset_dir', '') + num_files = kwargs.get('num_files', None) + use_prob = kwargs.get('use_prob', False) + data_name = kwargs.get('data_name', 'raw_data.npz') + data_path = kwargs.get('data_path', None) + loader_cls = MAP[kwargs.get('loader', None)] + ignore_check = kwargs.get('ignore_check', False) + + if data_path is not None: + filepaths = [] + if isinstance(data_path, dict): + cls_name = [] + cls_bias = [] + cls_length = [] + cls_weight = [] + 
for name, v in data_path.items(): + assert isinstance(v, list), "items in the dict must be a list of data list paths" + for item in v: + if isinstance(item, str): + datalist_path = item + weight = 1.0 + else: + datalist_path = item[0] + weight = item[1] + cls_name.append(name) + lines = [x.strip() for x in open(datalist_path, "r").readlines()] + ok_lines = [] + missing = 0 + for line in lines: + if ignore_check: + ok_lines.append(line) + elif os.path.exists(os.path.join(input_dataset_dir, line, data_name)): + ok_lines.append(line) + else: + missing += 1 + if missing != 0: + print(f"\033[31m{datalist_path}: {missing} missing files\033[0m") + cls_bias.append(len(filepaths)) + cls_length.append(len(ok_lines)) + cls_weight.append(weight) + filepaths.extend(ok_lines) + else: + raise NotImplementedError() + else: + _filepaths = kwargs['filepaths'] + if isinstance(_filepaths, list): + filepaths = _filepaths + cls_name = None + cls_bias = None + cls_length = None + cls_weight = None + elif isinstance(_filepaths, dict): + filepaths = [] + cls_name = [] + cls_bias = [] + cls_length = [] + cls_weight = [] + for k, v in _filepaths.items(): + assert isinstance(v, list), "items in the dict must be a list of paths" + cls_name.append(k) + cls_bias.append(len(filepaths)) + cls_length.append(len(v)) + cls_weight.append(1.0) + filepaths.extend(v) + else: + raise NotImplementedError() + if cls_weight is not None: + total = sum(cls_weight) + cls_weight = [x/total for x in cls_weight] + return Datapath( + filepaths=filepaths, + input_dataset_dir=input_dataset_dir, + cls_name=cls_name, + cls_bias=cls_bias, + cls_length=cls_length, + num_files=num_files, + use_prob=use_prob, + cls_weight=cls_weight, + loader=loader_cls, + data_name=data_name, + ignore_check=ignore_check, + ) + + def make(self, path: str, cls: str|None) -> LazyAsset: + return self.loader(path=path, cls=cls) + + def __getitem__(self, index: int) -> LazyAsset: + if self.use_prob and self.cls_weight is not None: + if self.cls_bias is None: + raise ValueError("do not have cls_bias") + if self.cls_length is None: + raise ValueError("do not have cls_length") + if not hasattr(self, "perms"): + self.perms = [] + self.current_bias = [] + for i in range(len(self.cls_weight)): + self.perms.append([x for x in range(self.cls_length[i])]) + self.current_bias.append(0) + idx = np.random.choice(len(self.cls_weight), p=self.cls_weight) + i = self.perms[idx][self.current_bias[idx]] + self.current_bias[idx] += 1 + if self.current_bias[idx] >= self.cls_length[idx]: + shuffle(self.perms[idx]) + self.current_bias[idx] = 0 + if self.cls_name is None: + name = None + else: + name = self.cls_name[idx] + path = os.path.join(self.input_dataset_dir, self.filepaths[i+self.cls_bias[idx]]) + if self.data_name is not None: + path = os.path.join(path, self.data_name) + return self.make(path=path, cls=name) + else: + if self.cls_name is None or self.cls_bias is None or self.cls_length is None: + name = None + else: + name = None + for i in range(len(self.cls_bias)): + start = self.cls_bias[i] + end = start + self.cls_length[i] + if start <= index < end: + name = self.cls_name[i] + break + path = os.path.join(self.input_dataset_dir, self.filepaths[index]) + if self.data_name is not None: + path = os.path.join(path, self.data_name) + return self.make(path=path, cls=name) + + def get_data(self) -> List[LazyAsset]: + return [self[i] for i in range(len(self))] + + def split_by_cls(self) -> Dict[str|None, 'Datapath']: + res: Dict[str|None, Datapath] = {} + if self.cls_name is None: + 
res[None] = self + return res + if self.cls_bias is None: + raise ValueError("do not have cls_bias") + if self.cls_length is None: + raise ValueError("do not have cls_length") + d_filepaths = defaultdict(list) + d_length = defaultdict(int) + d_weight = defaultdict(list) + for (i, cls) in enumerate(self.cls_name): + s = slice(self.cls_bias[i], self.cls_bias[i]+self.cls_length[i]) + d_filepaths[cls].extend(self.filepaths[s].copy()) + d_length[cls] += self.cls_length[i] + if self.cls_weight is not None: + d_weight[cls].append(self.cls_weight[i]) + for cls in d_filepaths: + cls_weight = None if self.cls_weight is None else d_weight[cls] + if cls_weight is not None: + total = sum(cls_weight) + cls_weight = [x/total for x in cls_weight] + res[cls] = Datapath( + filepaths=d_filepaths[cls], + input_dataset_dir=self.input_dataset_dir, + cls_name=[cls], + cls_bias=[0], + cls_length=[len(d_filepaths[cls])], + num_files=self.num_files, + use_prob=self.use_prob, + cls_weight=cls_weight, + loader=self.loader, + data_name=self.data_name, + ) + return res + + def __len__(self): + if self.use_prob: + assert self.num_files is not None, 'num_files is not specified' + return self.num_files + return len(self.filepaths) \ No newline at end of file diff --git a/src/data/dataset.py b/src/data/dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..6d8de18b42738811cf678a601136d20f8599ab1c --- /dev/null +++ b/src/data/dataset.py @@ -0,0 +1,319 @@ +from copy import deepcopy +from dataclasses import dataclass +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from numpy import ndarray +from torch import Tensor +from torch.utils import data +from torch.utils.data import DataLoader, Dataset +from typing import Dict, List, Tuple, Callable, Optional + +import os +import lightning.pytorch as pl +import numpy as np +import torch + +from .datapath import Datapath, LazyAsset +from .spec import ConfigSpec +from .transform import Transform + +from ..model.spec import ModelInput +from ..rig_package.info.asset import Asset +from ..tokenizer.spec import Tokenizer, TokenizeInput + +@dataclass +class DatasetConfig(ConfigSpec): + shuffle: bool + batch_size: int + num_workers: int + datapath: Datapath + pin_memory: bool=True + persistent_workers: bool=True + + @classmethod + def parse(cls, **kwargs) -> 'DatasetConfig': + cls.check_keys(kwargs) + return DatasetConfig( + shuffle=kwargs.get('shuffle', False), + batch_size=kwargs.get('batch_size', 1), + num_workers=kwargs.get('num_workers', 1), + pin_memory=kwargs.get('pin_memory', True), + persistent_workers=kwargs.get('persistent_workers', True), + datapath=Datapath.parse(**kwargs.get('datapath')), + ) + + def split_by_cls(self) -> Dict[str|None, 'DatasetConfig']: + res: Dict[str|None, DatasetConfig] = {} + datapath_dict = self.datapath.split_by_cls() + for cls, v in datapath_dict.items(): + res[cls] = DatasetConfig( + shuffle=self.shuffle, + batch_size=self.batch_size, + num_workers=self.num_workers, + datapath=v, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + ) + return res + +class RigDatasetModule(pl.LightningDataModule): + def __init__( + self, + process_fn: Optional[Callable[[List[ModelInput]], List[Dict]]]=None, + train_dataset_config: Optional[DatasetConfig]=None, + validate_dataset_config: Optional[Dict[str|None, DatasetConfig]]=None, + predict_dataset_config: Optional[Dict[str|None, DatasetConfig]]=None, + train_transform: Optional[Transform]=None, + validate_transform: 
Optional[Transform]=None, + predict_transform: Optional[Transform]=None, + tokenizer: Optional[Tokenizer]=None, + debug: bool=False, + ): + super().__init__() + self.process_fn = process_fn + self.train_dataset_config = train_dataset_config + self.validate_dataset_config = validate_dataset_config + self.predict_dataset_config = predict_dataset_config + self.train_transform = train_transform + self.validate_transform = validate_transform + self.predict_transform = predict_transform + self.tokenizer = tokenizer + self.debug = debug + + if debug: + print("\033[31mWARNING: debug mode, dataloader will be extremely slow !!!\033[0m") + + # build train datapath + if self.train_dataset_config is not None: + self.train_datapath = self.train_dataset_config.datapath + else: + self.train_datapath = None + + # build validate datapath + if self.validate_dataset_config is not None: + self.validate_datapath = { + cls: self.validate_dataset_config[cls].datapath + for cls in self.validate_dataset_config + } + else: + self.validate_datapath = None + + # build predict datapath + if self.predict_dataset_config is not None: + self.predict_datapath = { + cls: self.predict_dataset_config[cls].datapath + for cls in self.predict_dataset_config + } + else: + self.predict_datapath = None + + self.tokenizer = tokenizer + + def prepare_data(self): + pass + + def train_dataloader(self) -> TRAIN_DATALOADERS: + if self.train_dataset_config is None: + raise ValueError("do not have train_dataset_config") + if self.train_transform is None: + raise ValueError("do not have train_transform") + if self.train_datapath is not None: + self._train_ds = RigDataset( + process_fn=self.process_fn, + data=self.train_datapath.get_data(), + name="train", + tokenizer=self.tokenizer, + transform=self.train_transform, + debug=self.debug, + ) + else: + return None + return self._create_dataloader( + dataset=self._train_ds, + config=self.train_dataset_config, + is_train=True, + drop_last=False, + ) + + def val_dataloader(self) -> EVAL_DATALOADERS: + if self.validate_dataset_config is None: + raise ValueError("do not have validate_dataset_config") + if self.validate_transform is None: + raise ValueError("do not have validate_transform") + if self.validate_datapath is not None: + self._validation_ds = {} + for cls in self.validate_datapath: + self._validation_ds[cls] = RigDataset( + process_fn=self.process_fn, + data=self.validate_datapath[cls].get_data(), + name=f"validate-{cls}", + tokenizer=self.tokenizer, + transform=self.validate_transform, + debug=self.debug, + ) + else: + return None + return self._create_dataloader( + dataset=self._validation_ds, + config=self.validate_dataset_config, + is_train=False, + drop_last=False, + ) + + def predict_dataloader(self): + if self.predict_dataset_config is None: + raise ValueError("do not have predict_dataset_config") + if self.predict_transform is None: + raise ValueError("do not have predict_transform") + if self.predict_datapath is not None: + self._predict_ds = {} + for cls in self.predict_datapath: + self._predict_ds[cls] = RigDataset( + process_fn=self.process_fn, + data=self.predict_datapath[cls].get_data(), + name=f"predict-{cls}", + tokenizer=self.tokenizer, + transform=self.predict_transform, + debug=self.debug, + ) + else: + return None + return self._create_dataloader( + dataset=self._predict_ds, + config=self.predict_dataset_config, + is_train=False, + drop_last=False, + ) + + def _create_dataloader( + self, + dataset: Dataset|Dict[str, Dataset], + config: DatasetConfig|Dict[str|None, 
DatasetConfig], + is_train: bool, + **kwargs, + ) -> DataLoader|Dict[str, DataLoader]: + def create_single_dataloader(dataset, config: DatasetConfig, **kwargs): + return DataLoader( + dataset, + batch_size=config.batch_size, + shuffle=config.shuffle, + num_workers=config.num_workers, + pin_memory=config.pin_memory, + persistent_workers=config.persistent_workers, + collate_fn=dataset.collate_fn, + **kwargs, + ) + if isinstance(dataset, Dict): + assert isinstance(config, dict) + return {k: create_single_dataloader(v, config[k], **kwargs) for k, v in dataset.items()} + else: + assert isinstance(config, DatasetConfig) + return create_single_dataloader(dataset, config, **kwargs) + +class RigDataset(Dataset): + def __init__( + self, + data: List[LazyAsset], + transform: Transform, + name: Optional[str]=None, + process_fn: Optional[Callable[[List[ModelInput]], List[Dict]]]=None, + tokenizer: Optional[Tokenizer]=None, + debug: bool=False, + ) -> None: + super().__init__() + + self.data = data + self.name = name + self.process_fn = process_fn + self.tokenizer = tokenizer + self.transform = transform + self.debug = debug + + if not debug: + assert self.process_fn is not None, 'missing data processing function' + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx) -> ModelInput: + lazy_asset = self.data[idx] + asset = lazy_asset.load() + self.transform.apply(asset=asset) + if self.tokenizer is not None and asset.parents is not None: + x = TokenizeInput( + joints=asset.joints, + parents=asset.parents, + cls=asset.cls, + joint_names=asset.joint_names, + ) + tokens = self.tokenizer.tokenize(input=x) + else: + tokens = None + return ModelInput(asset=asset, tokens=tokens) + + def _collate_fn_debug(self, batch): + return batch + + def _collate_fn(self, batch): + processed_batch = self.process_fn(batch) # type: ignore + processed_batch: List[Dict] + + tensors_stack = {} + tensors_cat = {} + non_tensors = {} + vis = {} + def check(x): + assert x not in vis, f"multiple keys found: {x}" + vis[x] = True + + for k, v in processed_batch[0].items(): + if k == "cat": + assert isinstance(v, dict) + for k1 in v.keys(): + check(k1) + tensors_cat[k1] = [] + for i in range(len(processed_batch)): + v1 = processed_batch[i]['cat'][k1] + if isinstance(v1, ndarray): + v1 = torch.from_numpy(v1) + elif isinstance(v1, Tensor): + v1 = v1 + else: + raise ValueError(f"cannot concatenate non-tensor type of key {k1}, type: {type(v1)}") + tensors_cat[k1].append(v1) + elif k == "non": + assert isinstance(v, dict) + for k1 in v.keys(): + check(k1) + non_tensors[k1] = [] + for i in range(len(processed_batch)): + v1 = processed_batch[i]['non'][k1] + if isinstance(v1, ndarray): + v1 = torch.from_numpy(v1) + non_tensors[k1].append(v1) + else: + check(k) + tensors_stack[k] = [] + for i in range(len(processed_batch)): + v1 = processed_batch[i][k] + if isinstance(v1, ndarray): + v1 = torch.from_numpy(v1) + elif isinstance(v1, Tensor): + v1 = v1 + else: + raise ValueError(f"cannot stack type of key {k}, type: {type(v1)}") + tensors_stack[k].append(v1) + + collated_stack = {k: torch.stack(v) for k, v in tensors_stack.items()} + collated_cat = {k: torch.concat(v, dim=1) for k, v in tensors_cat.items()} + + collated_batch = { + **collated_stack, + **collated_cat, + **non_tensors, + } + return collated_batch + + def collate_fn(self, batch): + if self.debug: + return self._collate_fn_debug(batch) + return self._collate_fn(batch) \ No newline at end of file diff --git a/src/data/order.py b/src/data/order.py new file 
mode 100755 index 0000000000000000000000000000000000000000..2bde3911fe4fd758f3c3fc57c544c1ab9e1e49b3 --- /dev/null +++ b/src/data/order.py @@ -0,0 +1,132 @@ + +from collections import defaultdict +from dataclasses import dataclass +from numpy import ndarray +from omegaconf import OmegaConf +from typing import Dict, List, Tuple, Optional + +from .spec import ConfigSpec + +@dataclass +class Order(ConfigSpec): + + # {part_name: [bone_name_1, bone_name_2, ...]} + parts: Dict[str, Dict[str, List[str]]] + + # parts of bones to be arranged in [part_name_1, part_name_2, ...] + parts_order: Dict[str, List[str]] + + # {skeleton_name: path} + skeleton_path: Optional[Dict[str, str]]=None + + sort_by_xyz: bool=False + + @classmethod + def parse(cls, **kwargs) -> 'Order': + cls.check_keys(kwargs) + skeleton_path = kwargs.get('skeleton_path', None) + if skeleton_path is not None: + parts = {} + parts_order = {} + for (cls, path) in skeleton_path.items(): + assert cls not in parts, 'cls conflicts' + d = OmegaConf.load(path) + parts[cls] = d.parts + parts_order[cls] = d.parts_order + else: + parts = kwargs.get('parts') + parts_order = kwargs.get('parts_order') + assert parts is not None + assert parts_order is not None + return Order( + skeleton_path=skeleton_path, + parts=parts, + parts_order=parts_order, + sort_by_xyz=kwargs.get('sort_by_xyz', False), + ) + + def part_exists(self, cls: str, part: str, names: List[str]) -> bool: + ''' + Check if part exists. + ''' + if part not in self.parts[cls]: + return False + for name in self.parts[cls][part]: + if name not in names: + return False + return True + + def make_names(self, cls: str|None, parts: List[str|None], num_bones: int) -> List[str]: + ''' + Get names for specified cls. + ''' + names = [] + for part in parts: + if part is None: # spring + continue + if cls in self.parts and part in self.parts[cls]: + names.extend(self.parts[cls][part]) + assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones" + for i in range(len(names), num_bones): + names.append(f"bone_{i}") + return names + + def arrange_names(self, cls: str|None, names: List[str], parents: List[int], joints: Optional[ndarray]=None) -> Tuple[List[str], Dict[int, str|None]]: + ''' + Arrange names according to required parts order. 
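+ Returns the reordered joint names together with a dict that maps a start index in the new list to the part beginning there; a value of None marks where a spring token should be inserted.
+ For example, with a hypothetical parts_order of [spine, arms], the spine bones are emitted first, then the arm bones, and any bones not covered by a part are appended afterwards in their original order.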
+ ''' + def sort_by_xyz(joints): + return sorted(joints, key=lambda joint: (joint[1][2], joint[1][0], joint[1][1])) + + if self.sort_by_xyz: + assert joints is not None + new_names = [] + root = -1 + son = defaultdict(list) + not_root = {} + for (i, p) in enumerate(parents): + if p != -1: + son[p].append(i) + not_root[i] = True + for i in range(len(parents)): + if not_root.get(i, False) == False: + root = i + break + Q = [root] + while Q: + u = Q.pop(0) + new_names.append(names[u]) + wait = [] + for v in son[u]: + wait.append((v, joints[v])) + wait_sorted = sort_by_xyz(wait) + new_wait = [v for v, _ in wait_sorted] + Q = new_wait + Q + return new_names, {} + if cls not in self.parts_order: + return names, {0: None} # add a spring token + vis = defaultdict(bool) + name_to_id = {name: i for (i, name) in enumerate(names)} + new_names = [] + parts_bias = {} + for part in self.parts_order[cls]: + if self.part_exists(cls=cls, part=part, names=names): + for name in self.parts[cls][part]: + vis[name] = True + flag = False + for name in self.parts[cls][part]: + pid = parents[name_to_id[name]] + if pid==-1: + continue + if not vis[names[pid]]: + flag = True + break + if flag: # incorrect parts order and should immediately add a spring token + break + parts_bias[len(new_names)] = part + new_names.extend(self.parts[cls][part]) + parts_bias[len(new_names)] = None # add a spring token + for name in names: + if name not in new_names: + new_names.append(name) + return new_names, parts_bias \ No newline at end of file diff --git a/src/data/sampler.py b/src/data/sampler.py new file mode 100755 index 0000000000000000000000000000000000000000..ef7bc0d8fccd2bd4e0b2bd13887c1c96b2caa6dc --- /dev/null +++ b/src/data/sampler.py @@ -0,0 +1,189 @@ +from dataclasses import dataclass +from abc import ABC, abstractmethod +from numpy import ndarray +from scipy.spatial import cKDTree # type: ignore +from typing import Dict, Optional + +import numpy as np +import random + +from ..rig_package.info.asset import Asset +from ..rig_package.utils import sample_vertex_groups +from .spec import ConfigSpec + +@dataclass +class SamplerResult(): + sampled_vertices: Optional[ndarray]=None + sampled_normals: Optional[ndarray]=None + sampled_vertex_groups: Optional[Dict[str, ndarray]]=None + + # number of sampled skin + skin_samples: Optional[int]=None + +class Sampler(ABC): + @abstractmethod + def sample( + self, + asset: Asset, + ) -> SamplerResult: + ''' + Return sampled vertices, sampled normals and vertex groups. 
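+ The SamplerResult carries the sampled vertex positions, the matching normals when they are available, one sampled array per vertex group, and skin_samples, the number of dense skin-targeted samples drawn (None when dense skin sampling is not configured).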
+ ''' + pass + + @classmethod + @abstractmethod + def parse(cls, **kwargs) -> 'Sampler': + pass + +@dataclass +class SamplerMix(Sampler, ConfigSpec): + num_samples: int + num_vertex_samples: int + num_skin_samples: Optional[int]=None + replace: bool=True + all_skeleton: Optional[bool]=None + max_distance: float=0.1 + rate_distance: float=0.1 + + @classmethod + def parse(cls, **kwargs) -> 'SamplerMix': + cls.check_keys(kwargs) + return SamplerMix( + num_samples=kwargs.get('num_samples', 0), + num_vertex_samples=kwargs.get('num_vertex_samples', 0), + num_skin_samples=kwargs.get('num_skin_samples', None), + replace=kwargs.get('replace', True), + all_skeleton=kwargs.get('all_skeleton', None), + max_distance=kwargs.get('max_distance', 0.1), + rate_distance=kwargs.get('rate_distance', 0.1), + ) + + def sample_on_skin( + self, + skin: ndarray, + vertices: ndarray, + faces: ndarray, + ): + face_has_skin = np.any(skin[faces] > 0, axis=-1) + if face_has_skin.sum() == 0: + face_has_skin = np.ones_like(face_has_skin) + elif self.max_distance < 1e-5: + return face_has_skin + else: + # sample near points + p = np.unique(faces[face_has_skin].reshape(-1)) + tree = cKDTree(vertices[p]) + dis, _ = tree.query(vertices, k=1) + dis_skin = np.sqrt(((np.max(vertices[p], axis=0) - np.min(vertices[p], axis=0))**2).sum()) + mask_face_near = np.any(dis[faces] < min(self.max_distance, dis_skin * self.rate_distance), axis=-1) + face_has_skin |= mask_face_near + return face_has_skin + + def sample( + self, + asset: Asset, + ) -> SamplerResult: + if asset.vertices is None: + raise ValueError("do not have vertices") + if asset.faces is None: + raise ValueError("do not have faces") + vertex_groups = [] + mapping = {} + tot = 0 + for k, v in asset.vertex_groups.items(): + if v.ndim == 1: + v = v[:, None] + elif v.ndim != 2: + raise ValueError(f"ndim of key {k} is {v.ndim}") + s = tot + e = tot + v.shape[1] + mapping[k] = slice(s,e) + vertex_groups.append(v) + if len(vertex_groups) > 0: + vertex_groups = np.concatenate(vertex_groups, axis=1) + else: + vertex_groups = None + final_sampled_vertices, final_sampled_normals, sampled_vertex_groups = sample_vertex_groups( + vertices=asset.vertices, + faces=asset.faces, + num_samples=self.num_samples, + vertex_normals=asset.vertex_normals, + face_normals=asset.face_normals, + vertex_groups=vertex_groups, + face_mask=None, + shuffle=True, + same=True, + ) + if vertex_groups is not None: + final_sampled_vertices = final_sampled_vertices[:, 0] + if final_sampled_normals is not None: + final_sampled_normals = final_sampled_normals[:, 0] + final_sampled_vertex_groups = {} + if sampled_vertex_groups is not None: + for k, s in mapping.items(): + final_sampled_vertex_groups[k] = sampled_vertex_groups[:, s] # (N, k) + if vertex_groups is not None and self.num_skin_samples is not None: + dense_vertices = [] + dense_normals = [] + dense_skin = [] + if 'skin' not in mapping: + raise ValueError("do not have skin") + if self.all_skeleton: + dense_indices = [i for i in range(asset.J)] + else: + dense_indices = [random.randint(0, asset.J-1)] + for indice in dense_indices: + _s = asset.vertex_groups['skin'][:, indice] + face_has_skin = self.sample_on_skin( + skin=_s, + vertices=asset.vertices, + faces=asset.faces, + ) + sampled_vertices, sampled_normals, sampled_skin = sample_vertex_groups( + vertices=asset.vertices, + faces=asset.faces, + vertex_normals=asset.vertex_normals, + face_normals=asset.face_normals, + vertex_groups=_s, + num_samples=self.num_skin_samples, + 
num_vertex_samples=self.num_vertex_samples, + face_mask=face_has_skin, + shuffle=True, + same=True, + ) + assert sampled_skin is not None + assert sampled_skin.ndim == 2 + dense_vertices.append(sampled_vertices[:, 0]) + if sampled_normals is not None: + dense_normals.append(sampled_normals[:, 0]) + dense_skin.append(sampled_skin[:, 0]) + dense_vertices = np.stack(dense_vertices, axis=0) # (J, m, 3) + if len(dense_normals) > 0: + dense_normals = np.stack(dense_normals, axis=0) # (J, m, 3) + else: + dense_normals = None + dense_skin = np.stack(dense_skin, axis=0) # (J, m, 1) + final_sampled_vertex_groups['skin'] = final_sampled_vertex_groups['skin'][:, dense_indices] + if asset.meta is None: + asset.meta = {} + asset.meta['dense_vertices'] = dense_vertices + asset.meta['dense_normals'] = dense_normals + asset.meta['dense_skin'] = dense_skin + asset.meta['dense_indices'] = dense_indices + return SamplerResult( + sampled_vertices=final_sampled_vertices, + sampled_normals=final_sampled_normals if final_sampled_normals is not None else None, + sampled_vertex_groups=final_sampled_vertex_groups, + skin_samples=self.num_skin_samples, + ) + +def get_sampler(**kwargs) -> Sampler: + __target__ = kwargs.get('__target__') + assert __target__ is not None + del kwargs['__target__'] + if __target__ == 'mix': + sampler = SamplerMix.parse(**kwargs) + else: + raise ValueError(f"sampler method {__target__} not supported") + return sampler \ No newline at end of file diff --git a/src/data/spec.py b/src/data/spec.py new file mode 100755 index 0000000000000000000000000000000000000000..3719d80052549e8d4aef3e77dab5dca8f741a73f --- /dev/null +++ b/src/data/spec.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from dataclasses import fields + +class ConfigSpec(ABC): + @classmethod + def check_keys(cls, config, expect=None): + if expect is None: + expect = [field.name for field in fields(cls)] # type: ignore + for key in config.keys(): + if key not in expect: + raise ValueError(f"expect names {expect} in {cls.__name__}, found {key}") + + @classmethod + @abstractmethod + def parse(cls, **kwargs) -> 'ConfigSpec': + raise NotImplementedError() \ No newline at end of file diff --git a/src/data/transform.py b/src/data/transform.py new file mode 100755 index 0000000000000000000000000000000000000000..3601d755766932490a584aa495848402950a0e19 --- /dev/null +++ b/src/data/transform.py @@ -0,0 +1,70 @@ +from dataclasses import dataclass +from typing import List, Optional + +from ..rig_package.info.asset import Asset +from .augment import Augment, get_augments +from .order import Order +from .sampler import Sampler, get_sampler +from .spec import ConfigSpec +from .vertex_group import VertexGroup, get_vertex_groups + +@dataclass +class Transform(ConfigSpec): + + order: Optional[Order]=None + + vertex_groups: Optional[List[VertexGroup]]=None + + augments: Optional[List[Augment]]=None + + sampler: Optional[Sampler]=None + + @classmethod + def parse(cls, **kwargs) -> 'Transform': + cls.check_keys(kwargs) + order_config = kwargs.get('order') + vertex_groups_config = kwargs.get('vertex_groups') + augments_config = kwargs.get('augments') + sampler_config = kwargs.get('sampler') + + d = {} + if order_config is not None: + d['order'] = Order.parse(**order_config) + if vertex_groups_config is not None: + d['vertex_groups'] = get_vertex_groups(*vertex_groups_config) + if augments_config is not None: + d['augments'] = get_augments(*augments_config) + if sampler_config is not None: + d['sampler'] = get_sampler(**sampler_config) + 
return Transform(**d) + + def apply(self, asset: Asset, **kwargs): + + # 1. arrange bones + if self.order is not None: + if asset.joint_names is not None and asset.parents is not None: + new_names, _ = self.order.arrange_names(cls=asset.cls, names=asset.joint_names, parents=asset.parents.tolist()) + asset.set_order(new_orders=new_names) # type: ignore + + # 2. collapse must perform first + if self.augments is not None: + kwargs = {} + for augment in self.augments: + augment.transform(asset=asset, **kwargs) + + # 3. get vertex groups + if self.vertex_groups is not None: + d = {} + for v in self.vertex_groups: + d.update(v.get_vertex_group(asset=asset)) + asset.vertex_groups = d + else: + asset.vertex_groups = {} + + # 4. sample + if self.sampler is not None: + res = self.sampler.sample(asset=asset) + asset.sampled_vertices = res.sampled_vertices + asset.sampled_normals = res.sampled_normals + asset.sampled_vertex_groups = res.sampled_vertex_groups + asset.skin_samples = res.skin_samples \ No newline at end of file diff --git a/src/data/vertex_group.py b/src/data/vertex_group.py new file mode 100755 index 0000000000000000000000000000000000000000..04c0c174fcff5f1eadf451e7b11016304f6f1a88 --- /dev/null +++ b/src/data/vertex_group.py @@ -0,0 +1,257 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass +from numpy import ndarray +from scipy.spatial import cKDTree # type: ignore +from scipy.sparse import csr_matrix +from scipy.sparse.csgraph import shortest_path, connected_components +from typing import Dict, List, Optional, Literal + +import numpy as np + +from ..rig_package.info.asset import Asset + +@dataclass(frozen=True) +class VertexGroup(ABC): + + @classmethod + @abstractmethod + def parse(cls, **kwargs) -> 'VertexGroup': + pass + + @abstractmethod + def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]: + pass + +@dataclass(frozen=True) +class VertexGroupSkin(VertexGroup): + """capture skin""" + + normalize: bool=True + + @classmethod + def parse(cls, **kwargs) -> 'VertexGroupSkin': + return VertexGroupSkin(normalize=kwargs.get('normalize', True)) + + def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]: + if asset.skin is None: + raise ValueError("do not have skin") + if self.normalize: + asset.normalize_skin() + return {'skin': asset.skin.copy()} + +@dataclass(frozen=True) +class VertexGroupVoxelSkin(VertexGroup): + """capture voxel skin""" + + grid: int + alpha: float + link_dis: float + grid_query: int + vertex_query: int + grid_weight: float + mode: Literal['square', 'exp'] + + @classmethod + def parse(cls, **kwargs) -> 'VertexGroupVoxelSkin': + return VertexGroupVoxelSkin( + grid=kwargs.get('grid', 64), + alpha=kwargs.get('alpha', 0.5), + link_dis=kwargs.get('link_dis', 0.00001), + grid_query=kwargs.get('grid_query', 27), + vertex_query=kwargs.get('vertex_query', 27), + grid_weight=kwargs.get('grid_weight', 3.0), + mode=kwargs.get('mode', 'square'), + ) + + def get_vertex_group(self, asset: Asset) -> Dict[str, ndarray]: + if asset.vertices is None: + raise ValueError("do not have vertices") + if asset.faces is None: + raise ValueError("do not have faces") + if asset.joints is None: + raise ValueError("do not have joints") + # normalize into [-1, 1] first + min_vals = np.min(asset.vertices, axis=0) + max_vals = np.max(asset.vertices, axis=0) + + center = (min_vals + max_vals) / 2 + + scale = np.max(max_vals - min_vals) / 2 + + normalized_vertices = (asset.vertices - center) / scale 
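+ # the joints are mapped with the same center and scale as the vertices so that
+ # joints and mesh vertices share one consistent [-1, 1] frame before the
+ # voxel graph distances are computed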
+ normalized_joints = (asset.joints - center) / scale + + grid_coords = asset.voxel().coords + skin = voxel_skin( + grid=self.grid, + grid_coords=grid_coords, + joints=normalized_joints, + vertices=normalized_vertices, + faces=asset.faces, + alpha=self.alpha, + link_dis=self.link_dis, + grid_query=self.grid_query, + vertex_query=self.vertex_query, + grid_weight=self.grid_weight, + mode=self.mode, + ) + skin = np.nan_to_num(skin, nan=0., posinf=0., neginf=0.) + return {'voxel_skin': skin,} + +def voxel_skin( + grid: int, + grid_coords: ndarray, # (M, 3) + joints: ndarray, # (J, 3) + vertices: ndarray, # (N, 3) + faces: ndarray, # (F, 3) + alpha: float=0.5, + link_dis: float=0.00001, + grid_query: int=27, + vertex_query: int=27, + grid_weight: float=3.0, + voxel_size: Optional[float]=None, + mode: str='square', + parents: Optional[ndarray]=None, +): + # modified from https://dl.acm.org/doi/pdf/10.1145/2485895.2485919 + assert mode in ['square', 'exp'] + J = joints.shape[0] + M = grid_coords.shape[0] + N = vertices.shape[0] + + if voxel_size is None: + _range = 2/grid*1.74 + else: + _range = voxel_size*1.74 + + grid_tree = cKDTree(grid_coords) + vertex_tree = cKDTree(vertices) + if parents is not None: + son = defaultdict(list) + for i, p in enumerate(parents): + if i == -1: + continue + son[p].append(i) + divide_joints = [] + joints_map = [] + for u in range(len(parents)): + if len(son[u]) != 1: + divide_joints.append(joints[u]) + joints_map.append(u) + else: + pu = joints[u] + pv = joints[son[u][0]] + seg = 10 + for i in range(seg+1): + p = (pu*i + pv*(seg-i)) / seg + divide_joints.append(p) + joints_map.append(u) + divide_joints = np.stack(divide_joints) + joints_map = np.array(joints_map) + else: + divide_joints = joints + joints_map = np.arange(joints.shape[0]) + joint_tree = cKDTree(divide_joints) + + # make combined vertices + # 0 ~ N-1: mesh vertices + # N ~ N+M-1: grid vertices + combined_vertices = np.concatenate([vertices, grid_coords], axis=0) + + # link adjacent grids + dist, idx = grid_tree.query(grid_coords, grid_query) # 3*3*3 + dist = dist[:, 1:] + idx = idx[:, 1:] + mask = (0 < dist) & (dist < _range) + source_grid2grid = np.repeat(np.arange(M), grid_query-1)[mask.ravel()] + N + to_grid2grid = idx[mask] + N + weight_grid2grid = dist[mask] * grid_weight + + # link very close vertices + dist, idx = vertex_tree.query(vertices, 4) + dist = dist[:, 1:] + idx = idx[:, 1:] + mask = (0 < dist) & (dist < link_dis) + source_close = np.repeat(np.arange(N), 3)[mask.ravel()] + to_close = idx[mask] + weight_close = dist[mask] + + # link grids to mesh vertices + dist, idx = vertex_tree.query(grid_coords, vertex_query) + mask = (0 < dist) & (dist < _range) # sqrt(3) + source_grid2vertex = np.repeat(np.arange(M), vertex_query)[mask.ravel()] + N + to_grid2vertex = idx[mask] + weight_grid2vertex = dist[mask] + + # build combined vertices tree + combined_tree = cKDTree(combined_vertices) + # link bones to the neartest vertices + _, joint_indices = combined_tree.query(divide_joints) + + # build graph + source_vertex2vertex = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]], axis=0) + to_vertex2vertex = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]], axis=0) + weight_vertex2vertex = np.sqrt(((vertices[source_vertex2vertex] - vertices[to_vertex2vertex])**2).sum(axis=-1)) + graph = csr_matrix( + (np.concatenate([weight_close, weight_vertex2vertex, weight_grid2grid, weight_grid2vertex]), + ( + np.concatenate([source_close, source_vertex2vertex, source_grid2grid, source_grid2vertex], 
axis=0), + np.concatenate([to_close, to_vertex2vertex, to_grid2grid, to_grid2vertex], axis=0)), + ), + shape=(N+M, N+M), + ) + + # get shortest path (J, N+M) + dist_matrix = shortest_path(graph, method='D', directed=False, indices=joint_indices) + + # (sum_J, N) + dis_vertex2bone = dist_matrix[:, :N] + unreachable = np.isinf(dis_vertex2bone).all(axis=0) + k = min(J, 3) + dist, idx = joint_tree.query(vertices[unreachable], k) + + # make sure at least one value in dis is not inf + unreachable_indices = np.where(unreachable)[0] + row_indices = idx + col_indices = np.repeat(unreachable_indices, k).reshape(-1, k) + dis_vertex2bone[row_indices, col_indices] = dist + + finite_vals = dis_vertex2bone[np.isfinite(dis_vertex2bone)] + max_dis = np.max(finite_vals) + dis_vertex2bone = np.nan_to_num(dis_vertex2bone, nan=max_dis, posinf=max_dis, neginf=max_dis) + dis_vertex2bone = np.maximum(dis_vertex2bone, 1e-6) + + # turn dis2bone to dis2vertex + dis_vertex2joint = np.full((joints.shape[0], vertices.shape[0]), max_dis) + for i in range(len(dis_vertex2bone)): + dis_vertex2joint[joints_map[i]] = np.minimum(dis_vertex2bone[i], dis_vertex2joint[joints_map[i]]) + + # (J, N) + if mode == 'exp': + skin = np.exp(-dis_vertex2joint / max_dis * 20.0) + elif mode == 'square': + skin = (1./((1-alpha)*dis_vertex2joint + alpha*dis_vertex2joint**2))**2 + else: + assert False, f'invalid mode: {mode}' + skin = skin / skin.sum(axis=0) + # (N, J) + skin = skin.transpose() + return skin + +def get_vertex_groups(*args) -> List[VertexGroup]: + vertex_groups = [] + MAP = { + 'skin': VertexGroupSkin, + 'voxel_skin': VertexGroupVoxelSkin, + } + MAP: Dict[str, type[VertexGroup]] + for (i, c) in enumerate(args): + __target__ = c.get('__target__') + assert __target__ is not None, f"do not find `__target__` in config of vertex_groups of position {i}" + assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}" + c = deepcopy(c) + del c['__target__'] + vertex_groups.append(MAP[__target__].parse(**c)) + return vertex_groups \ No newline at end of file diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/model/michelangelo/__init__.py b/src/model/michelangelo/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/src/model/michelangelo/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/src/model/michelangelo/get_model.py b/src/model/michelangelo/get_model.py new file mode 100755 index 0000000000000000000000000000000000000000..4247b93ccb89c31dd8a08a731bc2b4c4a722dd9b --- /dev/null +++ b/src/model/michelangelo/get_model.py @@ -0,0 +1,30 @@ +import torch + +from .models.tsal.sal_perceiver import AlignedShapeLatentPerceiver, ShapeAsLatentPerceiverEncoder + +def get_encoder( + pretrained_path: str=None, + freeze_decoder: bool=False, + **kwargs +) -> AlignedShapeLatentPerceiver: + model = AlignedShapeLatentPerceiver(**kwargs) + if pretrained_path is not None: + state_dict = torch.load(pretrained_path, weights_only=True) + model.load_state_dict(state_dict) + if freeze_decoder: + model.geo_decoder.requires_grad_(False) + model.encoder.query.requires_grad_(False) + model.pre_kl.requires_grad_(False) + model.post_kl.requires_grad_(False) + model.transformer.requires_grad_(False) + return model + +def get_encoder_simplified( + pretrained_path: str=None, + **kwargs +) -> 
ShapeAsLatentPerceiverEncoder: + model = ShapeAsLatentPerceiverEncoder(**kwargs) + if pretrained_path is not None: + state_dict = torch.load(pretrained_path, weights_only=True) + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git a/src/model/michelangelo/models/__init__.py b/src/model/michelangelo/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/src/model/michelangelo/models/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/src/model/michelangelo/models/modules/__init__.py b/src/model/michelangelo/models/modules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..0729b49eadf964584d3524d9c0f6adec3f04a6a9 --- /dev/null +++ b/src/model/michelangelo/models/modules/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +from .checkpoint import checkpoint diff --git a/src/model/michelangelo/models/modules/checkpoint.py b/src/model/michelangelo/models/modules/checkpoint.py new file mode 100755 index 0000000000000000000000000000000000000000..c54807df8ceded8c30a9a2f2a6586228f3d59817 --- /dev/null +++ b/src/model/michelangelo/models/modules/checkpoint.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +""" +Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 +""" + +import torch +from typing import Callable, Iterable, Sequence, Union + + +def checkpoint( + func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], + inputs: Sequence[torch.Tensor], + params: Iterable[torch.Tensor], + flag: bool, + use_deepspeed: bool = False +): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + :param use_deepspeed: if True, use deepspeed + """ + if flag: + if use_deepspeed: + import deepspeed + return deepspeed.checkpointing.checkpoint(func, *inputs) + + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type='cuda') + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + @torch.amp.custom_bwd(device_type='cuda') + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
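+ # view_as(x) shares x's storage but produces a non-leaf autograd node, so an
+ # in-place first op inside run_function is allowed while gradients still flow
+ # back to the detached leaves in ctx.input_tensors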
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/src/model/michelangelo/models/modules/distributions.py b/src/model/michelangelo/models/modules/distributions.py new file mode 100755 index 0000000000000000000000000000000000000000..cf1cdcd53f1eb534b55d92ae1bd0b9854f6b890c --- /dev/null +++ b/src/model/michelangelo/models/modules/distributions.py @@ -0,0 +1,100 @@ +import torch +import numpy as np +from typing import Union, List + + +class AbstractDistribution(object): + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): + self.feat_dim = feat_dim + self.parameters = parameters + + if isinstance(parameters, list): + self.mean = parameters[0] + self.logvar = parameters[1] + else: + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + def sample(self): + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.mean(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=dims) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=dims) + + def nll(self, sample, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
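+ # closed form computed below: KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
+ #   = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)**2 * exp(-logvar2) - 1)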
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/src/model/michelangelo/models/modules/embedder.py b/src/model/michelangelo/models/modules/embedder.py new file mode 100755 index 0000000000000000000000000000000000000000..223de828f44903a3ce96b59d1cc5621e0989b535 --- /dev/null +++ b/src/model/michelangelo/models/modules/embedder.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- + +import numpy as np +import torch +import torch.nn as nn +import math + +VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. 
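+ For example, with the defaults (num_freqs=6, input_dim=3, include_input=True) a [..., 3] input is mapped to a [..., 3 * (6 * 2 + 1)] = [..., 39] embedding.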
+ """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class LearnedFourierEmbedder(nn.Module): + """ following @crowsonkb "s lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, in_channels, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + per_channel_dim = half_dim // in_channels + self.weights = nn.Parameter(torch.randn(per_channel_dim)) + + def forward(self, x): + """ + + Args: + x (torch.FloatTensor): [..., c] + + Returns: + x (torch.FloatTensor): [..., d] + """ + + # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] + freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) + fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) + return fouriered + + +class TriplaneLearnedFourierEmbedder(nn.Module): + def __init__(self, in_channels, dim): + super().__init__() + + self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + + self.out_dim = in_channels + dim + + def forward(self, x): + + yz_embed = self.yz_plane_embedder(x) + xz_embed = self.xz_plane_embedder(x) + xy_embed = self.xy_plane_embedder(x) + + embed = yz_embed + xz_embed + xy_embed + + return embed + + +def sequential_pos_embed(num_len, embed_dim): + assert embed_dim % 2 == 0 + + pos = torch.arange(num_len, dtype=torch.float32) + omega = torch.arange(embed_dim // 2, dtype=torch.float32) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + + return embeddings + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].to(timesteps.dtype) * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, + num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, + log2_hashmap_size=19, desired_resolution=None): + if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): + return nn.Identity(), input_dim + + elif embed_type == "fourier": + embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, + logspace=True, include_input=True) + return embedder_obj, embedder_obj.out_dim + + elif embed_type == "hashgrid": + raise NotImplementedError + + elif embed_type == "sphere_harmonic": + raise NotImplementedError + + else: + raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") diff --git a/src/model/michelangelo/models/modules/transformer_blocks.py b/src/model/michelangelo/models/modules/transformer_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..ba3ecffd0fd3e05708a9c9cc29b7ff45591a9daf --- /dev/null +++ b/src/model/michelangelo/models/modules/transformer_blocks.py @@ -0,0 +1,327 @@ +# -*- coding: utf-8 -*- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional +import os + +from .checkpoint import checkpoint +from ...utils.misc import use_flash3 + + +if use_flash3.is_use: + from flash_attn_interface import flash_attn_func + print("use flash attention 3.") +else: + print("use flash attention 2.") + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + +def flash_attention(q, k, v): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + if use_flash3.is_use: + out, _ = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous()) + # out = flash_attn_func(q, k, v) + + # q_ = q.transpose(1, 2) + # k_ = k.transpose(1, 2) + # v_ = v.transpose(1, 2) + + # # print(q.shape, k.shape, v.shape) + # out_ = F.scaled_dot_product_attention(q_, k_, v_) + # out_ = out_.transpose(1, 2) + + # # print(torch.abs(out - out_).mean()) + # assert torch.abs(out - out_).mean() < 1e-2, f"the error {torch.abs(out - out_).mean()} is too large" + + # out = out_ + + # print("use flash_atten 3") + else: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + out = F.scaled_dot_product_attention(q, k, v) + out = out.transpose(1, 2) + # print("use flash atten 2") + + return out + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool, + flash: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + 
x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), False) + x = self.c_proj(x) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + self.flash = flash + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + if self.flash: + out = flash_attention(q, k, v) + out = out.reshape(out.shape[0], out.shape[1], -1) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float = 1.0, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) + + def _forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def forward(self, x: torch.Tensor): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool = True, + flash: bool = False, + n_data: Optional[int] = None, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash + ) + init_linear(self.c_q, init_scale) + init_linear(self.c_kv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), False) + x = self.c_proj(x) + return x + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, + flash: bool = False, n_data: Optional[int] = None): + + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_data = n_data + self.flash = flash + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = 
q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + if self.flash: + out = flash_attention(q, k, v) + out = out.reshape(out.shape[0], out.shape[1], -1) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_data: Optional[int] = None, + width: int, + heads: int, + data_width: Optional[int] = None, + mlp_width_scale: int = 4, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, hidden_width_scale=mlp_width_scale, init_scale=init_scale) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class MLP(nn.Module): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + width: int, + hidden_width_scale: int = 4, + init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * hidden_width_scale, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * hidden_width_scale, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class Transformer(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/src/model/michelangelo/models/tsal/__init__.py b/src/model/michelangelo/models/tsal/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/src/model/michelangelo/models/tsal/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/src/model/michelangelo/models/tsal/loss.py b/src/model/michelangelo/models/tsal/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..a49fbacf64799505078fd43664fc944bddf34a42 --- /dev/null +++ b/src/model/michelangelo/models/tsal/loss.py @@ -0,0 
+1,454 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Dict + +from ..modules.distributions import DiagonalGaussianDistribution +from ...utils.eval import compute_psnr +from ...utils import misc +import numpy as np +from copy import deepcopy + + +def logits_to_sdf(logits): + return torch.sigmoid(logits) * 2 - 1 + +class KLNearFar(nn.Module): + def __init__(self, + near_weight: float = 0.1, + kl_weight: float = 1.0, + num_near_samples: Optional[int] = None): + + super().__init__() + + self.near_weight = near_weight + self.kl_weight = kl_weight + self.num_near_samples = num_near_samples + self.geo_criterion = nn.BCEWithLogitsLoss() + + def forward(self, + posteriors: Optional[DiagonalGaussianDistribution], + logits: torch.FloatTensor, + labels: torch.FloatTensor, + split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: + + """ + + Args: + posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): + logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; + labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; + split (str): + **kwargs: + + Returns: + loss (torch.Tensor): (,) + log (dict): + + """ + + if self.num_near_samples is None: + num_vol = logits.shape[1] // 2 + else: + num_vol = logits.shape[1] - self.num_near_samples + + vol_logits = logits[:, 0:num_vol] + vol_labels = labels[:, 0:num_vol] + + near_logits = logits[:, num_vol:] + near_labels = labels[:, num_vol:] + + # occupancy loss + # vol_bce = self.geo_criterion(vol_logits, vol_labels) + # near_bce = self.geo_criterion(near_logits, near_labels) + vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + + with torch.no_grad(): + preds = logits >= 0 + accuracy = (preds == labels).float() + accuracy = accuracy.mean() + pos_ratio = torch.mean(labels) + + log = { + "{}/total_loss".format(split): loss.clone().detach(), + "{}/near".format(split): near_bce.detach(), + "{}/far".format(split): vol_bce.detach(), + "{}/kl".format(split): kl_loss.detach(), + "{}/accuracy".format(split): accuracy, + "{}/pos_ratio".format(split): pos_ratio + } + + if posteriors is not None: + log[f"{split}/mean"] = posteriors.mean.mean().detach() + log[f"{split}/std_mean"] = posteriors.std.mean().detach() + log[f"{split}/std_max"] = posteriors.std.max().detach() + + return loss, log + + +class KLNearFarColor(nn.Module): + def __init__(self, + near_weight: float = 0.1, + kl_weight: float = 1.0, + color_weight: float = 1.0, + color_criterion: str = "mse", + num_near_samples: Optional[int] = None): + + super().__init__() + + self.color_weight = color_weight + self.near_weight = near_weight + self.kl_weight = kl_weight + self.num_near_samples = num_near_samples + + if color_criterion == "mse": + self.color_criterion = nn.MSELoss() + + elif color_criterion == "l1": + self.color_criterion = nn.L1Loss() + + else: + raise ValueError(f"{color_criterion} must be [`mse`, `l1`].") + + self.geo_criterion = nn.BCEWithLogitsLoss() + + def forward(self, + posteriors: 
Optional[DiagonalGaussianDistribution], + logits: torch.FloatTensor, + labels: torch.FloatTensor, + pred_colors: torch.FloatTensor, + gt_colors: torch.FloatTensor, + split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: + + """ + + Args: + posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): + logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; + labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; + pred_colors (torch.FloatTensor): [B, M, 3] + gt_colors (torch.FloatTensor): [B, M, 3] + split (str): + **kwargs: + + Returns: + loss (torch.Tensor): (,) + log (dict): + + """ + + if self.num_near_samples is None: + num_vol = logits.shape[1] // 2 + else: + num_vol = logits.shape[1] - self.num_near_samples + + vol_logits = logits[:, 0:num_vol] + vol_labels = labels[:, 0:num_vol] + + near_logits = logits[:, num_vol:] + near_labels = labels[:, num_vol:] + + # occupancy loss + # vol_bce = self.geo_criterion(vol_logits, vol_labels) + # near_bce = self.geo_criterion(near_logits, near_labels) + vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) + + # surface color loss + color = self.color_criterion(pred_colors, gt_colors) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight + + with torch.no_grad(): + preds = logits >= 0 + accuracy = (preds == labels).float() + accuracy = accuracy.mean() + psnr = compute_psnr(pred_colors, gt_colors) + + log = { + "{}/total_loss".format(split): loss.clone().detach(), + "{}/near".format(split): near_bce.detach(), + "{}/far".format(split): vol_bce.detach(), + "{}/color".format(split): color.detach(), + "{}/kl".format(split): kl_loss.detach(), + "{}/psnr".format(split): psnr.detach(), + "{}/accuracy".format(split): accuracy + } + + return loss, log + + +class ContrastKLNearFar(nn.Module): + def __init__(self, + contrast_weight: float = 1.0, + near_weight: float = 0.1, + kl_weight: float = 1.0, + normal_weight: float = 0.0, + surface_weight: float = 0.0, + eikonal_weight: float = 0.0, + sdf_bce_weight: float = 0.0, + sdf_l1l2_weight: float = 1.0, + num_near_samples: Optional[int] = None, + sdf_trunc_val: float = 0.05, + gt_sdf_soft: bool = False, + normal_supervision_type: str = "cosine", + supervision_type: str = 'occupancy'): + + super().__init__() + + self.labels = None + self.last_local_batch_size = None + self.supervision_type = supervision_type + + assert normal_supervision_type in ["l1", "l2", "cosine", "l1_cosine", "l2_cosine", "von_mises"] + self.normal_supervision_type = normal_supervision_type + + self.contrast_weight = contrast_weight + self.near_weight = near_weight + self.kl_weight = kl_weight + self.normal_weight = normal_weight + self.surface_weight = surface_weight + self.eikonal_weight = eikonal_weight + self.sdf_bce_weight = sdf_bce_weight # only used in sigmoid-sdf + self.sdf_l1l2_weight = sdf_l1l2_weight # only used in sigmoid-sdf + self.sdf_trunc_val = sdf_trunc_val + self.gt_sdf_soft = gt_sdf_soft + self.num_near_samples = num_near_samples + self.geo_criterion = nn.BCEWithLogitsLoss() + self.geo_criterion_sdf = nn.MSELoss() + + def sdf_loss(self, pred_sdf, 
gt_sdf): + scaled_sdf = gt_sdf / self.sdf_trunc_val + greater_mask = scaled_sdf > 1. + smaller_mask = scaled_sdf < -1. + inside_mask = 1. - greater_mask - smaller_mask + greater_loss = F.smooth_l1_loss(F.relu(1. - pred_sdf), torch.zeros_like(pred_sdf), reduction="none") * greater_mask + smaller_loss = F.smooth_l1_loss(F.relu(pred_sdf + 1.), torch.zeros_like(pred_sdf), reduction="none") * smaller_mask + inside_loss = F.smooth_l1_loss(pred_sdf, gt_sdf, beta=1e-2, reduction="none") * inside_mask + loss = (greater_loss + smaller_loss + inside_loss).mean() + return loss + + def von_mises(self, x, y, k=1): + cos = F.cosine_similarity(x, y, dim=-1) + exp = torch.exp(k * (cos - 1)) + return 1 - exp + + def forward(self, + shape_embed: torch.FloatTensor, + text_embed: torch.FloatTensor, + image_embed: torch.FloatTensor, + logit_scale: torch.FloatTensor, + posteriors: Optional[DiagonalGaussianDistribution], + latents: torch.FloatTensor, + shape_logits: torch.FloatTensor, + shape_labels: torch.FloatTensor, + surface_logits: Optional[torch.FloatTensor], + surface_normals: Optional[torch.FloatTensor], + gt_surface_normals: Optional[torch.FloatTensor], + split: Optional[str] = "train", **kwargs): + if self.supervision_type == 'occupancy': + shape_logits = shape_logits.squeeze(-1) + shape_labels[shape_labels>=0] = 1 + shape_labels[shape_labels<0] = 0 + + elif self.supervision_type == 'occupancy-shapenet': + shape_logits = shape_logits.squeeze(-1) + + elif self.supervision_type == 'occupancy-w-surface': + shape_logits = shape_logits.squeeze(-1) + shape_labels[shape_labels==10] = 0 + shape_labels[shape_labels>0] = 1 + shape_labels[shape_labels<0] = 0 + + elif 'sdf' in self.supervision_type: + shape_logits = shape_logits.squeeze(-1) + if self.gt_sdf_soft: + shape_labels_sdf = torch.tanh(shape_labels / self.sdf_trunc_val)# * self.sdf_trunc_val + else: + shape_labels_sdf = torch.clamp(shape_labels, min=-self.sdf_trunc_val, max=self.sdf_trunc_val) / self.sdf_trunc_val + else: + raise ValueError(f"Invalid supervision_type {self.supervision_type}") + + local_batch_size = shape_embed.size(0) + + if local_batch_size != self.last_local_batch_size: + self.labels = local_batch_size * misc.get_rank() + torch.arange( + local_batch_size, device=shape_embed.device + ).long() + self.last_local_batch_size = local_batch_size + + + if text_embed is not None and image_embed is not None: + # normalized features + shape_embed = F.normalize(shape_embed, dim=-1, p=2) + text_embed = F.normalize(text_embed, dim=-1, p=2) + image_embed = F.normalize(image_embed, dim=-1, p=2) + + # gather features from all GPUs + shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch( + [shape_embed, text_embed, image_embed] + ) + + # cosine similarity as logits + logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t() + logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t() + logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t() + logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t() + contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) + + F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \ + (F.cross_entropy(logits_per_shape_image, self.labels) + + F.cross_entropy(logits_per_image_shape, self.labels)) / 2 + else: + contrast_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device) + + # shape reconstruction + if self.num_near_samples is None: + num_vol = shape_logits.shape[1] // 2 + else: + num_vol = 
shape_logits.shape[1] - self.num_near_samples + + # occupancy/sdf loss + if self.supervision_type == 'occupancy' or self.supervision_type == 'occupancy-shapenet': + vol_logits = shape_logits[:, 0:num_vol] + vol_labels = shape_labels[:, 0:num_vol] + + near_logits = shape_logits[:, num_vol:] + near_labels = shape_labels[:, num_vol:] + + vol_loss = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_loss = self.geo_criterion(near_logits.float(), near_labels.float()) + + elif 'sdf' in self.supervision_type: + if self.supervision_type == "sigmoid-sdf": + shape_sdfs = logits_to_sdf(shape_logits) + else: + shape_sdfs = shape_logits + + vol_logits = shape_logits[:, 0:num_vol] + vol_sdfs = shape_sdfs[:, 0:num_vol] + vol_labels_sdf = shape_labels_sdf[:, 0:num_vol] + + near_logits= shape_logits[:, num_vol:] + near_sdfs = shape_sdfs[:, num_vol:] + near_labels_sdf = shape_labels_sdf[:, num_vol:] + + # use both sdf loss and occupancy loss + vol_loss = torch.mean(torch.abs(vol_sdfs - vol_labels_sdf)) + torch.mean((vol_sdfs - vol_labels_sdf) ** 2) #+ self.geo_criterion(vol_logits_sdf, vol_labels) + near_loss = torch.mean(torch.abs(near_sdfs - near_labels_sdf)) + torch.mean((near_sdfs - near_labels_sdf) ** 2) #+ self.geo_criterion(near_logits_sdf, near_labels) + + if self.supervision_type == "sigmoid-sdf": + vol_labels = (vol_labels_sdf + 1) / 2 + near_labels = (near_labels_sdf + 1) / 2 + vol_loss = self.sdf_l1l2_weight * vol_loss + self.sdf_bce_weight * self.geo_criterion(vol_logits, vol_labels) + near_loss = self.sdf_l1l2_weight * near_loss + self.sdf_bce_weight * self.geo_criterion(near_logits, near_labels) + # print(vol_loss, self.sdf_bce_weight * self.geo_criterion(vol_logits, vol_labels)) + + # surface loss + if "sdf" in self.supervision_type and surface_logits is not None: + if self.supervision_type == "sigmoid-sdf": + surface_sdfs = logits_to_sdf(surface_logits) + else: + surface_sdfs = surface_logits + surface_loss = torch.mean(surface_sdfs ** 2) + else: + surface_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device) + + if surface_normals is not None and gt_surface_normals is not None and "sdf" in self.supervision_type: + + valid_mask = surface_sdfs.squeeze(-1) < (self.sdf_trunc_val * 0.8) + + if valid_mask is not None: + surface_normals = surface_normals[valid_mask] + gt_surface_normals = gt_surface_normals[valid_mask] + + # eikonal loss + surface_normals_norm = torch.norm(surface_normals, dim=-1) + eikonal_loss = F.mse_loss(surface_normals_norm * self.sdf_trunc_val, surface_normals_norm.new_ones(surface_normals_norm.shape), reduction="mean") + + # surface normal loss + # surface_normals = F.normalize(surface_normals, dim=-1) + surface_normals = surface_normals * self.sdf_trunc_val + gt_surface_normals = F.normalize(gt_surface_normals, dim=-1) + + if self.normal_supervision_type == "cosine": + # use cosine similarity loss + normal_loss = 1 - F.cosine_similarity(F.normalize(surface_normals, dim=-1), gt_surface_normals, dim=-1).mean() + elif self.normal_supervision_type == "l1": + # use l1 loss + normal_loss = F.l1_loss(surface_normals, gt_surface_normals) + elif self.normal_supervision_type == "l2": + normal_loss = F.mse_loss(surface_normals, gt_surface_normals) + elif self.normal_supervision_type == "von_mises": + normal_loss = self.von_mises(surface_normals, gt_surface_normals).mean() + elif self.normal_supervision_type == "l1_cosine": + normal_loss_cos = 1 - F.cosine_similarity(F.normalize(surface_normals, dim=-1), gt_surface_normals, dim=-1).mean() + 
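+                # The cosine term above constrains only the normal direction; the L1 term
+                # below additionally penalizes the magnitude of the sdf_trunc_val-scaled gradient.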
normal_loss_l1 = F.l1_loss(surface_normals, gt_surface_normals) + normal_loss = normal_loss_cos + normal_loss_l1 + elif self.normal_supervision_type == "l2_cosine": + normal_loss_cos = 1 - F.cosine_similarity(F.normalize(surface_normals, dim=-1), gt_surface_normals, dim=-1).mean() + normal_loss_l2 = F.mse_loss(surface_normals, gt_surface_normals) + normal_loss = normal_loss_cos + normal_loss_l2 + else: + raise NotImplementedError + else: + normal_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device) + eikonal_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device) + surface_normals_norm = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=shape_logits.dtype, device=shape_logits.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_loss + near_loss * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight + normal_loss * self.normal_weight + self.eikonal_weight * eikonal_loss + self.surface_weight * surface_loss + + # compute accuracy + with torch.no_grad(): + if "sdf" in self.supervision_type: + preds = shape_sdfs >= 0 + sdf_labels = shape_labels_sdf >= 0 + accuracy = (preds == sdf_labels).float() + else: + preds = shape_logits >= 0 + accuracy = (preds == shape_labels).float() + accuracy = accuracy.mean() + + log = { + # "{}/contrast".format(split): contrast_loss.clone().detach(), + "{}/near".format(split): near_loss.detach(), + "{}/far".format(split): vol_loss.detach(), + "{}/normal".format(split): normal_loss.detach(), + "{}/surface".format(split): surface_loss.detach(), + "{}/eikonal".format(split): eikonal_loss.detach(), + "{}/kl".format(split): kl_loss.detach(), + "{}/surface_grad_norm".format(split): surface_normals_norm.mean().detach(), + # "{}/shape_text_acc".format(split): shape_text_acc, + # "{}/shape_image_acc".format(split): shape_image_acc, + "{}/total_loss".format(split): loss.clone().detach(), + "{}/accuracy".format(split): accuracy, + } + + if posteriors is not None: + log[f"{split}/posteriors_mean"] = posteriors.mean.mean().detach() + log[f"{split}/posteriors_std_mean"] = posteriors.std.mean().detach() + log[f"{split}/posteriors_std_max"] = posteriors.std.max().detach() + + return loss, log, near_loss diff --git a/src/model/michelangelo/models/tsal/sal_perceiver.py b/src/model/michelangelo/models/tsal/sal_perceiver.py new file mode 100755 index 0000000000000000000000000000000000000000..ded43f16368ea0b4df43de45db60c300efa05286 --- /dev/null +++ b/src/model/michelangelo/models/tsal/sal_perceiver.py @@ -0,0 +1,723 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from typing import Optional, Union +from einops import repeat +import math +import random +import time +import numpy as np + +from ..modules import checkpoint +from ..modules.embedder import FourierEmbedder +from ..modules.distributions import DiagonalGaussianDistribution +from ..modules.transformer_blocks import ( + ResidualCrossAttentionBlock, + Transformer +) +from ...utils.misc import use_flash3 + +from .tsal_base import ShapeAsLatentModule +from .loss import logits_to_sdf + +from ....utils import fps + +class CrossAttentionEncoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + fourier_embedder: FourierEmbedder, + point_feats: int, + width: int, + heads: int, + layers: int, + init_scale: float = 0.25, + qkv_bias: 
bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False, + query_method: bool = False, + use_full_input: bool = True, + token_num: int = 256, + no_query: bool=False): + + super().__init__() + + self.query_method = query_method + self.token_num = token_num + self.use_full_input = use_full_input + + self.use_checkpoint = use_checkpoint + self.num_latents = num_latents + + if no_query: + self.query = None + else: + self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02) + + self.fourier_embedder = fourier_embedder + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) + self.cross_attn = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + + self.self_attn = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=False + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) + else: + self.ln_post = None + + def _forward(self, pc, feats): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + + """ + if self.query_method: + token_num = self.num_latents + bs = pc.shape[0] #pc [10, 204800, 3] + data = self.fourier_embedder(pc) #[10, 204800, 51] + if feats is not None: #[10, 204800, 3] + data = torch.cat([data, feats], dim=-1) + data = self.input_proj(data) #[10, 204800, 768] + + query = repeat(self.query, "m c -> b m c", b=bs) #[10, 257, 768] + + latents = self.cross_attn(query, data) + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + pre_pc = None + else: + + if isinstance(self.token_num, int): + token_num = self.token_num + else: + token_num = random.choice(self.token_num) + # print(token_num,'-----------------------', flush=True) + + if self.training: + rng = np.random.default_rng() + else: + rng = np.random.default_rng(seed=0) + ind = rng.choice(pc.shape[1], token_num * 4, replace=token_num * 4 > pc.shape[1]) + + pre_pc = pc[:,ind,:] + pre_feats = feats[:,ind,:] + + + B, N, D = pre_pc.shape #[10, 204800, 3] + C = pre_feats.shape[-1] + ###### fps + pos = pre_pc.view(B*N, D) + pos_feats = pre_feats.view(B*N, C) + batch = torch.arange(B).to(pc.device) + batch = torch.repeat_interleave(batch, N) + + # ratio = 1.0 * token_num / N + idx = fps(pos, batch, ratio=1. 
/ 4, random_start=self.training) + + sampled_pc = pos[idx] + sampled_pc = sampled_pc.view(B, -1, 3) + + sampled_feats = pos_feats[idx] + sampled_feats = sampled_feats.view(B, -1, C) + + ###### + if self.use_full_input: + data = self.fourier_embedder(pc) #[B, 20480, 51] + else: + data = self.fourier_embedder(pre_pc) # [B, 4 * token_num, 51] + + if feats is not None: #[10, 204800, 3] + if not self.use_full_input: + feats = pre_feats + data = torch.cat([data, feats], dim=-1) #[10, 204800, 54] + data = self.input_proj(data) #[10, 204800, 768] + + # print(data.shape) + + sampled_data = self.fourier_embedder(sampled_pc) #[10, 256, 51] + if feats is not None: #[10, 256, 3] + sampled_data = torch.cat([sampled_data, sampled_feats], dim=-1) #[10, 256, 54] + sampled_data = self.input_proj(sampled_data) #[10, 256, 768] + + latents = self.cross_attn(sampled_data, data) #[10, 256, 768] + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + pre_pc = torch.cat([pre_pc, pre_feats], dim=-1) + + return latents, pc, token_num, pre_pc + + def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + dict + """ + + return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) + + +class CrossAttentionDecoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False, + mlp_width_scale: int = 4, + supervision_type: str = 'occupancy'): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.fourier_embedder = fourier_embedder + self.supervision_type = supervision_type + + self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) + + self.cross_attn_decoder = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + n_data=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + mlp_width_scale=mlp_width_scale, + ) + + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) + if self.supervision_type == 'occupancy-sdf': + self.output_proj_sdf = nn.Linear(width, out_channels, device=device, dtype=dtype) + + + + def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + if next(self.query_proj.parameters()).dtype == torch.float16: + queries = queries.half() + latents = latents.half() + # print(f"queries: {queries.dtype}, {queries.device}") + # print(f"latents: {latents.dtype}, {latents.device}"z) + queries = self.query_proj(self.fourier_embedder(queries)) + x = self.cross_attn_decoder(queries, latents) + x = self.ln_post(x) + x_1 = self.output_proj(x) + if self.supervision_type == 'occupancy-sdf': + x_2 = self.output_proj_sdf(x) + return x_1, x_2 + else: + return x_1 + + def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) + + +class ShapeAsLatentPerceiver(ShapeAsLatentModule): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + 
num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + decoder_width: Optional[int] = None, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False, + supervision_type: str = 'occupancy', + query_method: bool = False, + token_num: int = 256, + grad_type: str = "numerical", + grad_interval: float = 0.005, + use_full_input: bool = True, + freeze_encoder: bool = False, + decoder_mlp_width_scale: int = 4, + residual_kl: bool = False, + ): + + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.num_latents = num_latents + assert grad_type in ["numerical", "analytical"] + self.grad_type = grad_type + self.grad_interval = grad_interval + self.supervision_type = supervision_type + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + init_scale = init_scale * math.sqrt(1.0 / width) + self.encoder = CrossAttentionEncoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + num_latents=num_latents, + point_feats=point_feats, + width=width, + heads=heads, + layers=num_encoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint, + query_method=query_method, + use_full_input=use_full_input, + token_num=token_num + ) + + self.embed_dim = embed_dim + self.residual_kl = residual_kl + if decoder_width is None: + decoder_width = width + if embed_dim > 0: + # VAE embed + self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype) + self.post_kl = nn.Linear(embed_dim, decoder_width, device=device, dtype=dtype) + self.latent_shape = (num_latents, embed_dim) + if self.residual_kl: + assert self.post_kl.out_features % self.post_kl.in_features == 0 + assert self.pre_kl.in_features % self.pre_kl.out_features == 0 + else: + self.latent_shape = (num_latents, width) + + print("decoder width = ", decoder_width) + + self.transformer = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=decoder_width, + layers=num_decoder_layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + + # geometry decoder + self.geo_decoder = CrossAttentionDecoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + out_channels=1, + num_latents=num_latents, + width=decoder_width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint, + supervision_type=supervision_type, + mlp_width_scale=decoder_mlp_width_scale + ) + + if freeze_encoder: + for p in self.encoder.parameters(): + p.requires_grad = False + for p in self.pre_kl.parameters(): + p.requires_grad = False + print("freeze encoder and pre kl") + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + sample_posterior (bool): + + Returns: + latents (torch.FloatTensor) + center_pos (torch.FloatTensor or None): + posterior (DiagonalGaussianDistribution or None): + """ + + latents, center_pos = self.encoder(pc, feats) + + posterior = None + if self.embed_dim > 0: + moments = self.pre_kl(latents) + if self.residual_kl: + B, N = latents.shape[:2] + moments = moments + latents.view(B, N, -1, self.pre_kl.out_features).mean(dim=-2) + posterior = 
DiagonalGaussianDistribution(moments, feat_dim=-1) + + if sample_posterior: + latents = posterior.sample() + else: + latents = posterior.mode() + + return latents, center_pos, posterior + + def decode(self, latents: torch.FloatTensor): + if self.residual_kl: + latents = latents.repeat_interleave(self.post_kl.out_features // self.post_kl.in_features, dim=-1) + self.post_kl(latents) + else: + latents = self.post_kl(latents) + + return self.transformer(latents) + + def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor, grad: bool = False): + # logits = self.geo_decoder(queries, latents).squeeze(-1) + if grad: + # with torch.autocast(device_type="cuda", dtype=torch.float32): + if self.grad_type == "numerical": + raise NotImplementedError + interval = self.grad_interval + # print('grad interval = ', interval) + grad_value = [] + for offset in [(interval, 0, 0), (0, interval, 0), (0, 0, interval)]: + offset_tensor = torch.tensor(offset, device=queries.device)[None, :] + res_p = self.geo_decoder(queries + offset_tensor, latents)[..., 0] + res_n = self.geo_decoder(queries - offset_tensor, latents)[..., 0] + grad_value.append((res_p - res_n) / (2 * interval)) + grad_value = torch.stack(grad_value, dim=-1) + else: + # print("auto grad") + queries_d = torch.clone(queries) + queries_d.requires_grad = True + with torch.enable_grad(): + with use_flash3.disable_flash3(): + logits = self.geo_decoder(queries_d, latents) + if self.supervision_type == "sigmoid-sdf": + sdfs = logits_to_sdf(logits) + grad_value = torch.autograd.grad(sdfs, [queries_d], + grad_outputs=torch.ones_like(sdfs), + create_graph=self.geo_decoder.training)[0] + else: + logits = self.geo_decoder(queries, latents) + grad_value = None + + return logits, grad_value + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + volume_queries (torch.FloatTensor): [B, P, 3] + sample_posterior (bool): + + Returns: + logits (torch.FloatTensor): [B, P] + center_pos (torch.FloatTensor): [B, M, 3] + posterior (DiagonalGaussianDistribution or None). 
+ + """ + + latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(latents) + logits = self.query_geometry(volume_queries, latents) + + return logits, center_pos, posterior + + +class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[str], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + decoder_width: Optional[int] = None, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False, + supervision_type: str = 'occupancy', + grad_type: str = "numerical", + grad_interval: float = 0.005, + query_method: bool = False, + use_full_input: bool = True, + token_num: int = 256, + freeze_encoder: bool = False, + decoder_mlp_width_scale: int = 4, + residual_kl: bool = False, + ): + + MAP_DTYPE = { + 'float32': torch.float32, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, + } + if dtype is not None: + dtype = MAP_DTYPE[dtype] + super().__init__( + device=device, + dtype=dtype, + num_latents=1 + num_latents, + point_feats=point_feats, + embed_dim=embed_dim, + num_freqs=num_freqs, + include_pi=include_pi, + width=width, + decoder_width=decoder_width, + heads=heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint, + supervision_type=supervision_type, + grad_type=grad_type, + grad_interval=grad_interval, + query_method=query_method, + token_num=token_num, + use_full_input=use_full_input, + freeze_encoder=freeze_encoder, + decoder_mlp_width_scale=decoder_mlp_width_scale, + residual_kl=residual_kl, + ) + + self.width = width + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True, + only_shape: bool=False): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, c] + sample_posterior (bool): + + Returns: + shape_embed (torch.FloatTensor) + kl_embed (torch.FloatTensor): + posterior (DiagonalGaussianDistribution or None): + """ + + shape_embed, latents, token_num, pre_pc = self.encode_latents(pc, feats) + if only_shape: + return shape_embed + kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior) + + return shape_embed, kl_embed, posterior, token_num, pre_pc + + def encode_latents(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None): + + x, _, token_num, pre_pc = self.encoder(pc, feats) + + shape_embed = x[:, 0] + # latents = x[:, 1:] + # use all tokens + latents = x + + return shape_embed, latents, token_num, pre_pc + + def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): + posterior = None + if self.embed_dim > 0: + moments = self.pre_kl(latents) + if self.residual_kl: + B, N = latents.shape[:2] + moments = moments + latents.view(B, N, -1, self.pre_kl.out_features).mean(dim=-2) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + + if sample_posterior: + kl_embed = posterior.sample() + else: + kl_embed = posterior.mode() + else: + kl_embed = latents + + return kl_embed, posterior + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + 
sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + volume_queries (torch.FloatTensor): [B, P, 3] + sample_posterior (bool): + + Returns: + shape_embed (torch.FloatTensor): [B, projection_dim] + logits (torch.FloatTensor): [B, M] + posterior (DiagonalGaussianDistribution or None). + + """ + + shape_embed, kl_embed, posterior, token_num, pre_pc = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(kl_embed) + logits, grad = self.query_geometry(volume_queries, latents) + + return shape_embed, logits, posterior, token_num, pre_pc, grad + +##################################################### +# a simplified verstion of perceiver encoder +##################################################### + +class ShapeAsLatentPerceiverEncoder(ShapeAsLatentModule): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[Union[torch.dtype, str]], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False, + supervision_type: str = 'occupancy', + query_method: bool = False, + token_num: int = 256, + grad_type: str = "numerical", + grad_interval: float = 0.005, + use_full_input: bool = True, + freeze_encoder: bool = False, + residual_kl: bool = False, + ): + + super().__init__() + + + MAP_DTYPE = { + 'float32': torch.float32, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, + } + + if dtype is not None and isinstance(dtype, str): + dtype = MAP_DTYPE[dtype] + + self.use_checkpoint = use_checkpoint + + self.num_latents = num_latents + assert grad_type in ["numerical", "analytical"] + self.grad_type = grad_type + self.grad_interval = grad_interval + self.supervision_type = supervision_type + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + init_scale = init_scale * math.sqrt(1.0 / width) + self.encoder = CrossAttentionEncoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + num_latents=num_latents, + point_feats=point_feats, + width=width, + heads=heads, + layers=num_encoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint, + query_method=query_method, + use_full_input=use_full_input, + token_num=token_num, + no_query=True, + ) + + self.embed_dim = embed_dim + self.residual_kl = residual_kl + if freeze_encoder: + for p in self.encoder.parameters(): + p.requires_grad = False + print("freeze encoder") + self.width = width + + def encode_latents(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None): + + x, _, token_num, pre_pc = self.encoder(pc, feats) + + shape_embed = x[:, 0] + latents = x + + return shape_embed, latents, token_num, pre_pc + + def forward(self): + raise NotImplementedError() \ No newline at end of file diff --git a/src/model/michelangelo/models/tsal/tsal_base.py b/src/model/michelangelo/models/tsal/tsal_base.py new file mode 100755 index 0000000000000000000000000000000000000000..0b39830f5d45c7e87ece1c24b940a2dea6074431 --- /dev/null +++ b/src/model/michelangelo/models/tsal/tsal_base.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- + +import torch.nn as nn +from typing import Tuple, List, Optional +import lightning.pytorch as pl + + +class 
Point2MeshOutput(object): + def __init__(self): + self.mesh_v = None + self.mesh_f = None + self.center = None + self.pc = None + + +class Latent2MeshOutput(object): + + def __init__(self): + self.mesh_v = None + self.mesh_f = None + + +class AlignedMeshOutput(object): + + def __init__(self): + self.mesh_v = None + self.mesh_f = None + self.surface = None + self.image = None + self.text: Optional[str] = None + self.shape_text_similarity: Optional[float] = None + self.shape_image_similarity: Optional[float] = None + + +class ShapeAsLatentPLModule(pl.LightningModule): + latent_shape: Tuple[int] + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +class ShapeAsLatentModule(nn.Module): + latent_shape: Tuple[int, int] + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + +class AlignedShapeAsLatentPLModule(pl.LightningModule): + latent_shape: Tuple[int] + + def set_shape_model_only(self): + raise NotImplementedError + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +class AlignedShapeAsLatentModule(nn.Module): + shape_model: ShapeAsLatentModule + latent_shape: Tuple[int, int] + + def __init__(self, *args, **kwargs): + super().__init__() + + def set_shape_model_only(self): + raise NotImplementedError + + def encode_image_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_text_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_shape_embed(self, *args, **kwargs): + raise NotImplementedError + + +class TexturedShapeAsLatentModule(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + def query_color(self, *args, **kwargs): + raise NotImplementedError diff --git a/src/model/michelangelo/utils/__init__.py b/src/model/michelangelo/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..76d2dd39781034eaa33293a2243ebee3b3c982c6 --- /dev/null +++ b/src/model/michelangelo/utils/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from .misc import get_config_from_file +from .misc import instantiate_from_config diff --git a/src/model/michelangelo/utils/eval.py b/src/model/michelangelo/utils/eval.py new file mode 100755 index 0000000000000000000000000000000000000000..954b9ae2643c8adb6c9af6141ede2b38a329db22 --- /dev/null +++ b/src/model/michelangelo/utils/eval.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +import torch + + +def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7): + + mse = torch.mean((x - y) ** 2) + psnr = 10 * torch.log10(data_range / (mse + eps)) + + return 
psnr + diff --git a/src/model/michelangelo/utils/misc.py b/src/model/michelangelo/utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..45dc1141660bbb5f544e6a8ac8f8daa9f1f42d27 --- /dev/null +++ b/src/model/michelangelo/utils/misc.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- + +import importlib +from omegaconf import OmegaConf, DictConfig, ListConfig +import time +import torch +import torch.distributed as dist +from typing import Union, Any, Optional +from collections import defaultdict +from torch.optim import lr_scheduler +import os +from dataclasses import dataclass, field +from contextlib import contextmanager + +import logging +logger = logging.getLogger(__name__) + + + +def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8): + return int(num_data / (num_nodes * num_cards * batch_size)) * max_epochs + + +OmegaConf.register_new_resolver("calc_num_train_steps", calc_num_train_steps) +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + +@dataclass +class ExperimentConfig: + task: str = "vae" + output_dir: str = "outputs" + resume: Optional[str] = None + + data: dict = field(default_factory=dict) + model: dict = field(default_factory=dict) + + trainer: dict = field(default_factory=dict) + checkpoint: dict = field(default_factory=dict) + + wandb: dict = field(default_factory=dict) + +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: + scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg) + return scfg + +def get_config_from_file(config_file: str, cli_args: list = [], **kwargs) -> Union[DictConfig, ListConfig]: + config_file = OmegaConf.load(config_file) + cli_conf = OmegaConf.from_cli(cli_args) + + if 'base_config' in config_file.keys(): + if config_file['base_config'] == "default_base": + base_config = OmegaConf.create() + # base_config = get_default_config() + elif config_file['base_config'].endswith(".yaml"): + base_config = get_config_from_file(config_file['base_config']) + else: + raise ValueError(f"{config_file} must be `.yaml` file or it contains `base_config` key.") + + config_file = {key: value for key, value in config_file.items() if key != "base_config"} + + cfg = OmegaConf.merge(base_config, config_file, cli_conf, kwargs) + else: + cfg = OmegaConf.merge(config_file, cli_conf, kwargs) + + scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg) + + return scfg + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def get_obj_from_config(config): + if "target" not in config: + raise KeyError("Expected key `target` to instantiate.") + + return get_obj_from_str(config["target"]) + + +def instantiate_from_config(config, **kwargs): + if "target" not in config: + raise KeyError("Expected key `target` to instantiate.") + + cls = get_obj_from_str(config["target"]) + + params = config.get("params", dict()) + # params.update(kwargs) + # instance = cls(**params) + kwargs.update(params) + instance = cls(**kwargs) + + return instance + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return 
dist.get_world_size() + +def get_free_space(path): + fs_stats = os.statvfs(path) + free_space = fs_stats.f_bsize * fs_stats.f_bfree + return free_space + +def get_device_type(): + # Returns an empty string when no CUDA device is available so that + # callers like `FLASH3.__init__` (which only check `"H100" in ...`) can + # be imported safely on CPU-only / ZeroGPU-main processes without + # raising "No CUDA GPUs are available". + try: + if not torch.cuda.is_available(): + return "" + return torch.cuda.get_device_name(0) + except (RuntimeError, AssertionError): + return "" + +def get_hostname(): + import socket + return socket.gethostname() + +def all_gather_batch(tensors): + """ + Performs all_gather operation on the provided tensors. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + for tensor in tensors: + tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather( + tensor_all, + tensor, + async_op=False # performance opt + ) + + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor + +def get_scheduler(name): + if hasattr(lr_scheduler, name): + return getattr(lr_scheduler, name) + else: + raise NotImplementedError + +def parse_scheduler(config, optimizer): + interval = config.get("interval", "epoch") + assert interval in ["epoch", "step"] + if config.name == "SequentialLR": + scheduler = { + "scheduler": lr_scheduler.SequentialLR( + optimizer, + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ], + milestones=config.milestones, + ), + "interval": interval, + } + elif config.name == "ChainedScheduler": + scheduler = { + "scheduler": lr_scheduler.ChainedScheduler( + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ] + ), + "interval": interval, + } + else: + scheduler = { + "scheduler": get_scheduler(config.name)(optimizer, **config.args), + "interval": interval, + } + return scheduler + +class TimeRecorder: + _instance = None + + def __init__(self): + self.items = {} + self.accumulations = defaultdict(list) + self.time_scale = 1000.0 # ms + self.time_unit = "ms" + self.enabled = False + + def __new__(cls): + # singleton + if cls._instance is None: + cls._instance = super(TimeRecorder, cls).__new__(cls) + return cls._instance + + def enable(self, enabled: bool) -> None: + self.enabled = enabled + + def start(self, name: str) -> None: + if not self.enabled: + return + torch.cuda.synchronize() + self.items[name] = time.time() + + def end(self, name: str, accumulate: bool = False) -> float: + if not self.enabled or name not in self.items: + return + torch.cuda.synchronize() + start_time = self.items.pop(name) + delta = time.time() - start_time + if accumulate: + self.accumulations[name].append(delta) + t = delta * self.time_scale + logger.info(f"{name}: {t:.2f}{self.time_unit}") + + def get_accumulation(self, name: str, average: bool = False) -> float: + if not self.enabled or name not in self.accumulations: + return + acc = self.accumulations.pop(name) + total = sum(acc) + if average: + t = total / len(acc) * self.time_scale + else: + t = total * self.time_scale + logger.info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}") + + +### global time recorder +time_recorder = TimeRecorder() + +class FLASH3: + def __init__(self) -> None: + self.available 
= "H100" in get_device_type() + self.use = os.environ.get("USE_FLASH3", False) + + @property + def is_use(self): + return self.available and self.use + + @contextmanager + def disable_flash3(self): + use = self.use + self.set_use(False) + yield + self.set_use(use) + + def set_use(self, use=True): + self.use = use + +use_flash3 = FLASH3() diff --git a/src/model/parse_encoder.py b/src/model/parse_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..00c37098665a08b1fe1d9a9f19a1ebe444c303e8 --- /dev/null +++ b/src/model/parse_encoder.py @@ -0,0 +1,28 @@ +from copy import deepcopy +from dataclasses import dataclass + +from .michelangelo.get_model import get_encoder as get_encoder_michelangelo +from .michelangelo.get_model import AlignedShapeLatentPerceiver +from .michelangelo.get_model import get_encoder_simplified as get_encoder_michelangelo_encoder +from .michelangelo.get_model import ShapeAsLatentPerceiverEncoder +from .skin_vae.autoencoders.autoencoder_kl_tripo2 import Tripo2Encoder + +@dataclass(frozen=True) +class _MAP_MESH_ENCODER: + michelangelo = AlignedShapeLatentPerceiver + michelangelo_encoder = ShapeAsLatentPerceiverEncoder + tripo = Tripo2Encoder + +MAP_MESH_ENCODER = _MAP_MESH_ENCODER() + + +def get_mesh_encoder(**kwargs): + MAP = { + 'michelangelo': get_encoder_michelangelo, + 'michelangelo_encoder': get_encoder_michelangelo_encoder, + 'tripo': Tripo2Encoder, + } + __target__ = kwargs['__target__'] + del kwargs['__target__'] + assert __target__ in MAP, f"expect: [{','.join(MAP.keys())}], found: {__target__}" + return MAP[__target__](**deepcopy(kwargs)) \ No newline at end of file diff --git a/src/model/skin_vae/attention_processor.py b/src/model/skin_vae/attention_processor.py new file mode 100755 index 0000000000000000000000000000000000000000..b28b8997c9811972fa9915a6edcdfab0535a85bd --- /dev/null +++ b/src/model/skin_vae/attention_processor.py @@ -0,0 +1,283 @@ +import inspect +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention +from diffusers.utils import logging +from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available +from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph +from torch import nn + +try: + from flash_attn_interface import flash_attn_func +except Exception as e: + def flash_attn_func(q, k, v): + q = q.permute(0, 2, 1, 3) # (B, H, L, D) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + if q.shape[1] != k.shape[1]: + repeat_factor = q.shape[1] // k.shape[1] + k = k.repeat_interleave(repeat_factor, dim=1) + v = v.repeat_interleave(repeat_factor, dim=1) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v) + return out.permute(0, 2, 1, 3), None # (B, L, H, D) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Tripo2AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Tripo2DiT model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from diffusers.models.embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # NOTE that tripo2 split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim) + # instead of .view(..., 3, attn.heads, dim). So we need to re-split here. 
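+        # Illustration with hypothetical sizes (not from the original code): for heads=16 and
+        # head_dim=64, the concatenated qkv has last dim 16*3*64 = 3072; viewing it as
+        # (B, L, 16, 192) and splitting the last dim into 64-wide chunks recovers per-head
+        # q/k/v in tripo2's layout, whereas the default diffusers ordering would correspond
+        # to a (B, L, 3, 16, 64) view.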
+        if not attn.is_cross_attention:
+            qkv = torch.cat((query, key, value), dim=-1)
+            split_size = qkv.shape[-1] // attn.heads // 3
+            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
+            query, key, value = torch.split(qkv, split_size, dim=-1)
+        else:
+            kv = torch.cat((key, value), dim=-1)
+            split_size = kv.shape[-1] // attn.heads // 2
+            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
+            key, value = torch.split(kv, split_size, dim=-1)
+
+        head_dim = key.shape[-1]
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+
+        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+        hidden_states = flash_attn_func(query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2))
+        if type(hidden_states) == tuple:
+            hidden_states = hidden_states[0]
+        # hidden_states = F.scaled_dot_product_attention(
+        #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        # )
+        #hidden_states =
+        hidden_states = hidden_states.reshape(
+            batch_size, -1, attn.heads * head_dim
+        )
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class FusedTripo2AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with
+    fused projection layers. This is used in the Tripo2DiT model. It applies a normalization layer and rotary
+    embedding to the query and key vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedTripo2AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from diffusers.models.embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + # NOTE that tripo2 split heads first, then split qkv + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // attn.heads // 3 + qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3) + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + query = attn.to_q(hidden_states) + + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // attn.heads // 2 + kv = kv.view(batch_size, -1, attn.heads, split_size * 2) + key, value = torch.split(kv, split_size, dim=-1) + + head_dim = key.shape[-1] + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/src/model/skin_vae/autoencoders/FSQ.py b/src/model/skin_vae/autoencoders/FSQ.py new file mode 100755 index 0000000000000000000000000000000000000000..9f7bcf0e670071a04a19495ea8efa7d3ead0a1fe --- /dev/null 
+++ b/src/model/skin_vae/autoencoders/FSQ.py @@ -0,0 +1,191 @@ +from __future__ import annotations +from functools import wraps, partial +from contextlib import nullcontext +from typing import List, Tuple + +import torch +import torch.nn as nn +from torch.nn import Module +from torch import Tensor, int32 +from torch.amp import autocast + +from einops import rearrange, pack, unpack + +# helper functions + +def exists(v): + return v is not None + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + +def maybe(fn): + @wraps(fn) + def inner(x, *args, **kwargs): + if not exists(x): + return x + return fn(x, *args, **kwargs) + return inner + +def pack_one(t, pattern): + return pack([t], pattern) + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + +# tensor helpers + +def round_ste(z: Tensor) -> Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + +# main class + +class FSQ(Module): + def __init__( + self, + levels: List[int], + dim: int | None = None, + num_codebooks = 1, + keep_num_codebooks_dim: bool | None = None, + scale: float | None = None, + allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64), + channel_first: bool = False, + projection_has_bias: bool = True, + return_indices = True, + force_quantization_f32 = True + ): + super().__init__() + _levels = torch.tensor(levels, dtype=int32) + self.register_buffer("_levels", _levels, persistent = False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) + self.register_buffer("_basis", _basis, persistent = False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + self.channel_first = channel_first + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = projection_has_bias) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = projection_has_bias) if has_projections else nn.Identity() + + self.has_projections = has_projections + + self.return_indices = return_indices + if return_indices: + self.codebook_size: int = self._levels.prod().item() + implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size)) + self.register_buffer("implicit_codebook", implicit_codebook, persistent = False) + + self.allowed_dtypes = allowed_dtypes + self.force_quantization_f32 = force_quantization_f32 + + def bound(self, z, eps: float = 1e-3): + """ Bound `z`, an array of shape (..., d). """ + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z): + """ Quantizes z, returns quantized zhat, same shape as z. """ + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. 
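+ # Worked example (hypothetical levels, not from a config in this repo): levels = [8, 5, 5, 5] gives an implicit codebook of 8 * 5 * 5 * 5 = 1000 codes. + # For an odd level L = 5, bound() squashes that channel into roughly (-2, 2), round_ste() snaps it to {-2, -1, 0, 1, 2}, + # and dividing by half_width = 2 renormalizes the code to {-1, -0.5, 0, 0.5, 1}.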
+ return quantized / half_width + + def _scale_and_shift(self, zhat_normalized): + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat): + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def _indices_to_codes(self, indices): + level_indices = self.indices_to_level_indices(indices) + codes = self._scale_and_shift_inverse(level_indices) + return codes + + def codes_to_indices(self, zhat): + """ Converts a `code` to an index in the codebook. """ + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat) + return (zhat * self._basis).sum(dim=-1).to(int32) + + def indices_to_level_indices(self, indices): + """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """ + indices = rearrange(indices, '... -> ... 1') + codes_non_centered = (indices // self._basis) % self._levels + return codes_non_centered + + def indices_to_codes(self, indices): + """ Inverse of `codes_to_indices`. """ + assert exists(indices) + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + codes = self._indices_to_codes(indices) + if self.keep_num_codebooks_dim: + codes = rearrange(codes, '... c d -> ... (c d)') + codes = self.project_out(codes) + if is_img_or_video or self.channel_first: + codes = rearrange(codes, 'b ... d -> b d ...') + return codes + + def dequantize(self, indices): + codes = self._indices_to_codes(indices) + out = self.project_out(codes) + return out + + def forward(self, z): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension + c - number of codebook dim + """ + assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}' + + z = self.project_in(z) + z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) + + # whether to force quantization step to be full precision or not + force_f32 = self.force_quantization_f32 + quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext + + with quantization_context(): + orig_dtype = z.dtype + if force_f32 and orig_dtype not in self.allowed_dtypes: + z = z.float() + codes = self.quantize(z) + # returning indices could be optional + indices = None + if self.return_indices: + indices = self.codes_to_indices(codes) + codes = rearrange(codes, 'b n c d -> b n (c d)') + codes = codes.type(orig_dtype) + + # project out + out = self.project_out(codes) + + if not self.keep_num_codebooks_dim and self.return_indices: + indices = maybe(rearrange)(indices, '... 1 -> ...') + # return quantized output and indices + return out, indices, None \ No newline at end of file diff --git a/src/model/skin_vae/autoencoders/SimVQ.py b/src/model/skin_vae/autoencoders/SimVQ.py new file mode 100755 index 0000000000000000000000000000000000000000..4e23d3f721d77b7945876a21678ecfdf318eb129 --- /dev/null +++ b/src/model/skin_vae/autoencoders/SimVQ.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import rearrange + +class SimVQ(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + # NOTE: due to a bug the beta term was applied to the wrong term. 
for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta=0.25, remap=None, unknown_index="random", + same_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.codebook_size = self.n_e + nn.init.normal_(self.embedding.weight, mean=0, std=self.e_dim**-0.5) + for p in self.embedding.parameters(): + p.requires_grad = False + + self.embedding_proj = nn.Linear(self.e_dim, self.e_dim) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.same_index_shape = same_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + assert z.shape[-1] == self.e_dim + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + quant_codebook = self.embedding_proj(self.embedding.weight) + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(quant_codebook**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(quant_codebook, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = F.embedding(min_encoding_indices, quant_codebook).view(z.shape) + + # compute loss for embedding + if not self.legacy: + quantization_loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + quantization_loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.same_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, min_encoding_indices, quantization_loss + + def get_codebook_entry(self, indices, shape): + # 
shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0],-1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + def indices_to_codes(self, indices): + return self.get_codebook_entry(indices, None) + +def entropy(prob): + return (-prob * log(prob)).sum(dim=-1) + +class SimVQ1D(SimVQ): + + def __init__(self, n_e, e_dim, dim, beta=0.25, remap=None, unknown_index="random", same_index_shape=True, legacy=True): + super().__init__(n_e, e_dim, beta, remap, unknown_index, same_index_shape, legacy) + + self.project_in = nn.Linear(dim, e_dim) + self.project_out = nn.Linear(e_dim, dim) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + #assert z.shape[-1] == self.e_dim + z = self.project_in(z) + + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + quant_codebook = self.embedding_proj(self.embedding.weight) + + # # Use IBQ + # logits = torch.matmul(z_flattened, quant_codebook.t()) + # Ind_soft = torch.softmax(logits, dim=1) + # indices = torch.argmax(Ind_soft, dim=1) + # Ind_hard = F.one_hot(indices, num_classes=Ind_soft.shape[1]) + # Ind = Ind_hard - Ind_soft.detach() + Ind_soft + # z_q = torch.matmul(Ind, quant_codebook).view(z.shape) + + # if not self.legacy: + # quantization_loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + # torch.mean((z_q - z.detach()) ** 2) + # else: + # quantization_loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + # torch.mean((z_q - z.detach()) ** 2) + + # return z_q, indices, quantization_loss + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(quant_codebook**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(quant_codebook, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = F.embedding(min_encoding_indices, quant_codebook).view(z.shape) + + # compute loss for embedding + if not self.legacy: + quantization_loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + quantization_loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.same_index_shape: + min_encoding_indices = min_encoding_indices.view(z.shape[0], z.shape[1]) + z_q = self.project_out(z_q.view(z.shape)) + + return z_q, min_encoding_indices, quantization_loss + diff --git a/src/model/skin_vae/autoencoders/__init__.py b/src/model/skin_vae/autoencoders/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..93f59dd8fea38fe9a245e1f5eced949c9221fe16 --- /dev/null +++ b/src/model/skin_vae/autoencoders/__init__.py @@ -0,0 +1 @@ +from .skin_fsq_cvae_model import SkinFSQCVAEModel \ No newline at end of file diff --git a/src/model/skin_vae/autoencoders/autoencoder_kl_tripo2.py b/src/model/skin_vae/autoencoders/autoencoder_kl_tripo2.py new file mode 100755 index 
0000000000000000000000000000000000000000..fa21d6f80e38f91cb02a555ef668117918c568b6 --- /dev/null +++ b/src/model/skin_vae/autoencoders/autoencoder_kl_tripo2.py @@ -0,0 +1,254 @@ +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from diffusers.models.normalization import LayerNorm +from diffusers.utils import logging +from einops import repeat +import math + +from ..embeddings import FrequencyPositionalEmbedding +from ..transformers.tripo2_transformer import DiTBlock +from ...utils import fps + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + +class Tripo2Encoder(nn.Module): + def __init__( + self, + in_channels: int = 3, + dim: int = 512, + num_attention_heads: int = 8, + num_layers: int = 8, + is_learned_queries: bool = False, + sample_tokens: int = 32, + embed_frequency: int = 8, + embed_include_pi: bool = False, + fps: bool = False, + is_miche: bool = False, + ): + super().__init__() + + self.fps = fps + if fps and not is_learned_queries: + self.embedder = FrequencyPositionalEmbedding( + num_freqs=embed_frequency, + logspace=True, + input_dim=3, + include_pi=embed_include_pi, + ) + self.proj_k = nn.Linear(3+self.embedder.out_dim, dim, bias=True) + self.proj_in = nn.Linear(in_channels-3+self.embedder.out_dim, dim, bias=True) + else: + self.proj_in = nn.Linear(in_channels, dim, bias=True) + self.output_channels = dim + self.is_miche = is_miche + init_scale = 0.25 * math.sqrt(1.0 / dim) + init_linear(self.proj_in, init_scale) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + dim=dim, + num_attention_heads=num_attention_heads, + use_self_attention=False, + use_cross_attention=True, + cross_attention_dim=dim, + cross_attention_norm_type="layer_norm", + activation_fn="gelu", + norm_type="fp32_layer_norm", + norm_eps=1e-5, + qk_norm=False, + qkv_bias=False, + ) # cross attention + ] + + [ + DiTBlock( + dim=dim, + num_attention_heads=num_attention_heads, + use_self_attention=True, + self_attention_norm_type="fp32_layer_norm", + use_cross_attention=False, + use_cross_attention_2=False, + activation_fn="gelu", + norm_type="fp32_layer_norm", + norm_eps=1e-5, + qk_norm=False, + qkv_bias=False, + ) + for _ in range(num_layers) # self attention + ] + ) + self.norm_out = LayerNorm(dim) + self.is_learned_queries = is_learned_queries + if is_learned_queries: + self.learned_queries = nn.Parameter(torch.randn(sample_tokens, dim) * 0.02) + + def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor, num_tokens: int=1024): + if self.is_learned_queries or not self.fps: + hidden_states = self.proj_in(sample_1) if not self.is_learned_queries else repeat(self.learned_queries[:sample_1.shape[1], :], 'n d -> b n d', b=sample_1.shape[0]) + encoder_hidden_states = self.proj_in(sample_2) + else: + x_q, x_kv = self.get_qkv(x=sample_1, num_tokens=num_tokens) + hidden_states = self.proj_k(x_q) + encoder_hidden_states = self.proj_in(x_kv) + + if not self.is_miche: + for layer, block in enumerate(self.blocks): + if layer == 0: + hidden_states = block( + hidden_states, encoder_hidden_states=encoder_hidden_states + ) + else: + hidden_states = block(hidden_states) + else: + for layer, block in enumerate(self.blocks): + if layer == 0: + hidden_states = block(hidden_states, encoder_hidden_states) + else: + hidden_states = block(hidden_states) + + hidden_states = self.norm_out(hidden_states) + + return 
hidden_states + + def _sample_features( + self, x: torch.Tensor, num_tokens: int = 1024, seed: Optional[int] = None + ): + """ + Sample points from features of the input point cloud. + + Args: + x (torch.Tensor): The input point cloud. shape: (B, N, C) + num_tokens (int, optional): The number of points to sample. Defaults to 1024. + seed (Optional[int], optional): The random seed. Defaults to None. + """ + rng = np.random.default_rng(seed) + indices = rng.choice( + x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1] + ) + selected_points = x[:, indices] + + batch_size, num_points, num_channels = selected_points.shape + flattened_points = selected_points.view(batch_size * num_points, num_channels) + batch_indices = ( + torch.arange(batch_size).to(x.device).repeat_interleave(num_points) + ) + + # fps sampling + sampling_ratio = 1.0 / 4 + sampled_indices = fps( + flattened_points[:, :3], + batch_indices, + ratio=sampling_ratio, + random_start=self.training, + ) + sampled_points = flattened_points[sampled_indices].view( + batch_size, -1, num_channels + ) + + return sampled_points + + def get_qkv(self, x: torch.Tensor, num_tokens: int = 1024, seed: Optional[int] = None): + positions, features = x[..., :3], x[..., 3:] + x_kv = torch.cat([self.embedder(positions), features], dim=-1) + + sampled_x = self._sample_features(x, num_tokens, seed) + positions, features = ( + sampled_x[..., :3], + sampled_x[..., 3:], + ) + x_q = torch.cat([self.embedder(positions), features], dim=-1) + return x_q, x_kv + + +class Tripo2Decoder(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 1, + dim: int = 512, + num_attention_heads: int = 8, + num_layers: int = 16, + grad_type: str = "analytical", + grad_interval: float = 0.001, + is_miche: bool = False, + ): + super().__init__() + + if grad_type not in ["numerical", "analytical"]: + raise ValueError(f"grad_type must be one of ['numerical', 'analytical']") + self.grad_type = grad_type + self.grad_interval = grad_interval + self.is_miche = is_miche + + self.blocks = nn.ModuleList( + [ + DiTBlock( + dim=dim, + num_attention_heads=num_attention_heads, + use_self_attention=True, + self_attention_norm_type="fp32_layer_norm", + use_cross_attention=False, + use_cross_attention_2=False, + activation_fn="gelu", + norm_type="fp32_layer_norm", + norm_eps=1e-5, + qk_norm=False, + qkv_bias=False, + ) + for _ in range(num_layers) # self attention + ] + + [ + DiTBlock( + dim=dim, + num_attention_heads=num_attention_heads, + use_self_attention=False, + use_cross_attention=True, + cross_attention_dim=dim, + cross_attention_norm_type="layer_norm", + activation_fn="gelu", + norm_type="fp32_layer_norm", + norm_eps=1e-5, + qk_norm=False, + qkv_bias=False, + ) # cross attention + ] + ) + self.proj_query = nn.Linear(in_channels, dim, bias=True) + + self.norm_out = LayerNorm(dim) + self.proj_out = nn.Linear(dim, out_channels, bias=True) + self.sigmoid = nn.Sigmoid() + init_scale = 0.25 * math.sqrt(1.0 / dim) + init_linear(self.proj_query, init_scale) + init_linear(self.proj_out, init_scale) + + def forward( + self, + sample: torch.Tensor, + queries: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if kv_cache is None: + hidden_states = sample + for _, block in enumerate(self.blocks[:-1]): + hidden_states = block(hidden_states) + kv_cache = hidden_states + # query grid logits by cross attention + q = self.proj_query(queries) + if self.is_miche: + l = self.blocks[-1](q, kv_cache) + else: + l = 
self.blocks[-1](q, encoder_hidden_states=kv_cache) + logits = self.proj_out(self.norm_out(l)) + + logits = self.sigmoid(logits) + assert kv_cache is not None + return logits, kv_cache \ No newline at end of file diff --git a/src/model/skin_vae/autoencoders/get_model.py b/src/model/skin_vae/autoencoders/get_model.py new file mode 100755 index 0000000000000000000000000000000000000000..6661546b671ab243ffda9d8a7642ebdebfc619c1 --- /dev/null +++ b/src/model/skin_vae/autoencoders/get_model.py @@ -0,0 +1,23 @@ +import torch +from .skin_cvae_model import SkinCVAEModel +from .skin_fsq_cvae_model import SkinFSQCVAEModel + +def get_model_cvae( + pretrained_path: str=None, + **kwargs +) -> SkinCVAEModel: + model = SkinCVAEModel(**kwargs) + if pretrained_path is not None: + state_dict = torch.load(pretrained_path, weights_only=True) + model.load_state_dict(state_dict) + return model + +def get_model_fsq_cvae( + pretrained_path: str=None, + **kwargs +) -> SkinFSQCVAEModel: + model = SkinFSQCVAEModel(**kwargs) + if pretrained_path is not None: + state_dict = torch.load(pretrained_path, weights_only=True) + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git a/src/model/skin_vae/autoencoders/miche_transformer_blocks.py b/src/model/skin_vae/autoencoders/miche_transformer_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..97874e6bcff125d8af3b4b1c413ecc523dfc19b0 --- /dev/null +++ b/src/model/skin_vae/autoencoders/miche_transformer_blocks.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional +import os + +# -*- coding: utf-8 -*- +""" +Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 +""" + +import torch +from typing import Callable, Iterable, Sequence, Union + + +def checkpoint( + func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], + inputs: Sequence[torch.Tensor], + params: Iterable[torch.Tensor], + flag: bool, + use_deepspeed: bool = False +): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing.
+ :param use_deepspeed: if True, use deepspeed + """ + if flag: + if use_deepspeed: + import deepspeed + return deepspeed.checkpointing.checkpoint(func, *inputs) + + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type='cuda') + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + @torch.amp.custom_bwd(device_type='cuda') + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +try: + from flash_attn_interface import flash_attn_func + print("use flash attention 3.") + _use_flash3 = True +except: + print("use flash attention 2.") + _use_flash3 = False + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + +def flash_attention(q, k, v): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + if _use_flash3: + out, _ = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous()) + # out = flash_attn_func(q, k, v) + + # q_ = q.transpose(1, 2) + # k_ = k.transpose(1, 2) + # v_ = v.transpose(1, 2) + + # # print(q.shape, k.shape, v.shape) + # out_ = F.scaled_dot_product_attention(q_, k_, v_) + # out_ = out_.transpose(1, 2) + + # # print(torch.abs(out - out_).mean()) + # assert torch.abs(out - out_).mean() < 1e-2, f"the error {torch.abs(out - out_).mean()} is too large" + + # out = out_ + + # print("use flash_atten 3") + else: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + out = F.scaled_dot_product_attention(q, k, v) + out = out.transpose(1, 2) + # print("use flash atten 2") + + return out + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool, + flash: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), False) + x = self.c_proj(x) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): + super().__init__() + self.device = device + 
self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + self.flash = flash + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + if self.flash: + out = flash_attention(q, k, v) + out = out.reshape(out.shape[0], out.shape[1], -1) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float = 1.0, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) + + def _forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def forward(self, x: torch.Tensor): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool = True, + flash: bool = False, + n_data: Optional[int] = None, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash + ) + init_linear(self.c_q, init_scale) + init_linear(self.c_kv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), False) + x = self.c_proj(x) + return x + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, + flash: bool = False, n_data: Optional[int] = None): + + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_data = n_data + self.flash = flash + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + if self.flash: + out = flash_attention(q, k, v) + out = out.reshape(out.shape[0], out.shape[1], -1) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More 
stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_data: Optional[int] = None, + width: int, + heads: int, + data_width: Optional[int] = None, + mlp_width_scale: int = 4, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, hidden_width_scale=mlp_width_scale, init_scale=init_scale) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class MLP(nn.Module): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + width: int, + hidden_width_scale: int = 4, + init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * hidden_width_scale, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * hidden_width_scale, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class Transformer(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/src/model/skin_vae/autoencoders/skin_fsq_cvae_model.py b/src/model/skin_vae/autoencoders/skin_fsq_cvae_model.py new file mode 100755 index 0000000000000000000000000000000000000000..6b66d9e9486d261f0a5d67f795bbe60499fb26e3 --- /dev/null +++ b/src/model/skin_vae/autoencoders/skin_fsq_cvae_model.py @@ -0,0 +1,304 @@ +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.models.modeling_utils import ModelMixin +from einops import repeat +import math + +from ..attention_processor import Tripo2AttnProcessor2_0 +from ..embeddings import FrequencyPositionalEmbedding +from .autoencoder_kl_tripo2 import Tripo2Encoder, Tripo2Decoder +from .FSQ import FSQ +from .SimVQ import 
SimVQ1D + +from ...utils import fps + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + + +class SkinFSQCVAEModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 4, + cond_channels: int = 3, + latent_channels: int = 64, + num_attention_heads: int = 8, + width_encoder: int = 512, + width_decoder: int = 1024, + num_layers_encoder: int = 8, + num_layers_decoder: int = 16, + embedding_type: str = "frequency", + embed_frequency: int = 8, + embed_include_pi: bool = False, + sample_tokens: int = 32, + **kwargs + ): + super().__init__() + + self.out_channels = 1 + + if embedding_type == "frequency": + self.embedder = FrequencyPositionalEmbedding( + num_freqs=embed_frequency, + logspace=True, + input_dim=3, + include_pi=embed_include_pi, + use_pmpe=kwargs.get('use_pmpe', False), + ) + else: + raise NotImplementedError( + f"Embedding type {embedding_type} is not supported." + ) + + self.is_learned_queries = kwargs['is_learned_queries'] + + is_miche = kwargs.get('is_miche', False) + self.encoder = Tripo2Encoder( + in_channels=in_channels + self.embedder.out_dim, + dim=width_encoder, + num_attention_heads=num_attention_heads, + num_layers=num_layers_encoder, + is_learned_queries=self.is_learned_queries, + sample_tokens=sample_tokens, + is_miche=is_miche, + ) + + self.cond_encoder = Tripo2Encoder( + in_channels=cond_channels + self.embedder.out_dim, + dim=width_encoder, + num_attention_heads=num_attention_heads, + num_layers=num_layers_encoder, + is_miche=is_miche, + ) + + self.decoder = Tripo2Decoder( + in_channels=self.embedder.out_dim + self.cond_channels, + out_channels=self.out_channels, + dim=width_decoder, + num_attention_heads=num_attention_heads, + num_layers=num_layers_decoder, + is_miche=is_miche, + ) + + self.cond_quant = nn.Linear(width_encoder, latent_channels, bias=True) + + self.quant = nn.Linear(width_encoder, latent_channels, bias=True) + self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True) + + init_scale = 0.25 * math.sqrt(1.0 / width_encoder) + init_linear(self.cond_quant, init_scale) + init_linear(self.quant, init_scale) + init_scale = 0.25 * math.sqrt(1.0 / latent_channels) + init_linear(self.post_quant, init_scale) + self.use_slicing = False + self.slicing_length = 1 + if kwargs.get('FSQ_dict', None) is not None: + self.FSQ = FSQ(**kwargs['FSQ_dict']) + else: + self.FSQ = SimVQ1D(**kwargs['SimVQ_dict']) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(Tripo2AttnProcessor2_0()) + + def enable_slicing(self, slicing_length: int = 1) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + self.slicing_length = slicing_length + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _sample_features( + self, x: torch.Tensor, num_tokens: int = 128, seed: Optional[int] = None + ): + """ + Sample points from features of the input point cloud. + + Args: + x (torch.Tensor): The input point cloud. shape: (B, N, C) + num_tokens (int, optional): The number of points to sample. Defaults to 2048. + seed (Optional[int], optional): The random seed. Defaults to None. 
+ """ + rng = np.random.default_rng(seed) + indices = rng.choice( + x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1] + ) + selected_points = x[:, indices] + + batch_size, num_points, num_channels = selected_points.shape + flattened_points = selected_points.view(batch_size * num_points, num_channels) + batch_indices = ( + torch.arange(batch_size).to(x.device).repeat_interleave(num_points) + ) + + # fps sampling + sampling_ratio = 1.0 / 4 + sampled_indices = fps( + flattened_points[:, :3], + batch_indices, + ratio=sampling_ratio, + random_start=self.training, + ) + sampled_points = flattened_points[sampled_indices].view( + batch_size, -1, num_channels + ) + + return sampled_points + + def get_qkv(self, x: torch.Tensor, num_tokens: int = 128, seed: Optional[int] = None, not_get_q: bool=False): + positions, features = x[..., :3], x[..., 3:] + x_kv = torch.cat([self.embedder(positions), features], dim=-1) + + if not_get_q: + x_q = torch.zeros((x.shape[0], num_tokens, x.shape[-1]), dtype=x.dtype, device=x.device) + else: + sampled_x = self._sample_features(x, num_tokens, seed) + positions, features = ( + sampled_x[..., :3], + sampled_x[..., 3:], + ) + x_q = torch.cat([self.embedder(positions), features], dim=-1) + return x_q, x_kv + + def _encode( + self, x: torch.Tensor|None, cond: torch.Tensor|None, num_tokens: int = 128, cond_tokens: int = 128, seed: Optional[int] = None, + return_z: bool=True, return_cond: bool=True, + ): + position_channels = 3 + if return_z: + assert x is not None + x_q, x_kv = self.get_qkv(x, num_tokens, seed, not_get_q=self.is_learned_queries) + x = self.encoder(x_q, x_kv) + x = self.quant(x) + else: + x = None + + if return_cond: + assert cond is not None + cond_q, cond_kv = self.get_qkv(cond, cond_tokens, seed) + cond_embed = self.cond_encoder(cond_q, cond_kv) + cond = self.cond_quant(cond_embed) + else: + cond = None + + return x, cond + + def _decode( + self, z: torch.Tensor, + cond: torch.Tensor, + sampled_points: torch.Tensor, + num_chunks: Optional[int] = None, + ) -> torch.Tensor: + xyz_samples = sampled_points + z = self.post_quant(torch.cat([z, cond], dim=1)) + + num_points = xyz_samples.shape[1] + if num_chunks is None: + num_chunks = num_points + + queries = sampled_points.to(z.device, dtype=z.dtype) + positions, features = ( + queries[..., :3], + queries[..., 3:], + ) + + kv_cache = None + dec = [] + for i in range(0, num_points, num_chunks): + queries = torch.cat([self.embedder(positions[:, i:i + num_chunks, :]), features[:, i:i + num_chunks, :]], dim=-1) + z, kv_cache = self.decoder(z, queries, kv_cache) + dec.append(z) + + return torch.cat(dec, dim=1) + + def compile_model(self): + self.encoder = torch.compile(self.encoder) + self.cond_encoder = torch.compile(self.cond_encoder) + self.decoder = torch.compile(self.decoder) + + def forward(self, x: torch.Tensor): + pass diff --git a/src/model/skin_vae/autoencoders/vae.py b/src/model/skin_vae/autoencoders/vae.py new file mode 100755 index 0000000000000000000000000000000000000000..0e53eada49dd0cbca365829411ef78e0cbec711e --- /dev/null +++ b/src/model/skin_vae/autoencoders/vae.py @@ -0,0 +1,73 @@ +from typing import Optional, Tuple + +import numpy as np +import torch +from diffusers.utils.torch_utils import randn_tensor + + +class DiagonalGaussianDistribution(object): + def __init__( + self, + parameters: torch.Tensor, + deterministic: bool = False, + feature_dim: int = 1, + ): + self.parameters = parameters + self.feature_dim = feature_dim + self.mean, self.logvar = torch.chunk(parameters, 2, 
dim=feature_dim) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + ) + elif isinstance(other, DiagonalGaussianDistribution): + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + ) + elif isinstance(other, torch.Tensor): + return 0.5 * torch.mean( + torch.pow(self.mean - other, 2) + self.var - 1.0 - self.logvar, + ) + else: + raise ValueError("Other must be a DiagonalGaussianDistribution or torch.Tensor") + + def nll( + self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3] + ) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean diff --git a/src/model/skin_vae/embeddings.py b/src/model/skin_vae/embeddings.py new file mode 100755 index 0000000000000000000000000000000000000000..cd253932239850792f53b2b8c7bfd126bbcc26ce --- /dev/null +++ b/src/model/skin_vae/embeddings.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + + +class FrequencyPositionalEmbedding(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. 
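+ + Example: with the defaults (num_freqs=6, input_dim=3, include_input=True), out_dim is + 3 * (6 * 2 + 1) = 39, so an input of shape [B, N, 3] is embedded to [B, N, 39]; with + logspace=True the frequencies are [1, 2, 4, 8, 16, 32] (times pi when include_pi is True).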
+ + """ + + def __init__( + self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True, + use_pmpe: bool = False, + ) -> None: + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32) + else: + frequencies = torch.linspace( + 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + self.use_pmpe = use_pmpe + if use_pmpe: + phase = torch.arange(num_freqs, dtype=torch.float32) + for i in range(num_freqs): + phase[i] = torch.pow(torch.tensor(num_freqs), 1.0-(i+1)/num_freqs)+(i+1)/num_freqs + phase *= torch.pi*2 + self.register_buffer("phase", phase, persistent=False) + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. + """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view( + *x.shape[:-1], -1 + ) + if self.use_pmpe: + phase = (x[..., None].contiguous()*torch.pi*0.5 + self.phase).view( + *x.shape[:-1], -1 + ) + res = torch.cat((embed.sin()+phase.sin(), embed.cos()+phase.cos()), dim=-1) + else: + res = torch.cat((embed.sin(), embed.cos()), dim=-1) + if self.include_input: + return torch.cat((x, res), dim=-1) + else: + return res + else: + return x diff --git a/src/model/skin_vae/transformers/__init__.py b/src/model/skin_vae/transformers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..3a61dd3419b5cc0844e061e583b172c5ad5454f0 --- /dev/null +++ b/src/model/skin_vae/transformers/__init__.py @@ -0,0 +1,41 @@ +from typing import Callable, Optional + +from .tripo2_transformer import Tripo2DiTModel + + +def default_set_attn_proc_func( + name: str, + hidden_size: int, + cross_attention_dim: Optional[int], + ori_attn_proc: object, +) -> object: + return ori_attn_proc + + +def set_transformer_attn_processor( + transformer: Tripo2DiTModel, + set_self_attn_proc_func: Callable = default_set_attn_proc_func, + set_cross_attn_proc_func: Callable = default_set_attn_proc_func, +) -> None: + attn_procs = {} + for name, attn_processor in transformer.attn_processors.items(): + hidden_size = transformer.config.width + if name.endswith("attn1.processor"): + # self attention + attn_procs[name] = set_self_attn_proc_func( + name, hidden_size, None, attn_processor + ) + elif name.endswith("attn2.processor"): + # cross attention + cross_attention_dim = transformer.config.cross_attention_dim + attn_procs[name] = set_cross_attn_proc_func( + name, hidden_size, cross_attention_dim, attn_processor + ) + elif name.endswith("attn2_2.processor"): + # cross attention 2 + cross_attention_dim = transformer.config.cross_attention_2_dim + attn_procs[name] = set_cross_attn_proc_func( + name, hidden_size, cross_attention_dim, attn_processor + ) + + transformer.set_attn_processor(attn_procs) diff --git a/src/model/skin_vae/transformers/modeling_outputs.py b/src/model/skin_vae/transformers/modeling_outputs.py 
new file mode 100755 index 0000000000000000000000000000000000000000..0928fa0ca39275a85f8b7fa49c68af745d4c74c5 --- /dev/null +++ b/src/model/skin_vae/transformers/modeling_outputs.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass + +import torch + + +@dataclass +class Transformer1DModelOutput: + sample: torch.FloatTensor diff --git a/src/model/skin_vae/transformers/tripo2_transformer.py b/src/model/skin_vae/transformers/tripo2_transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..55fd8cf093a9d5c9fccb3494964ee91062d32f32 --- /dev/null +++ b/src/model/skin_vae/transformers/tripo2_transformer.py @@ -0,0 +1,664 @@ +# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import ( + AdaLayerNormContinuous, + FP32LayerNorm, + LayerNorm, +) +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import maybe_allow_in_graph +from torch import nn + +from ..attention_processor import FusedTripo2AttnProcessor2_0, Tripo2AttnProcessor2_0 +from .modeling_outputs import Transformer1DModelOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class DiTBlock(nn.Module): + r""" + Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and + QKNorm + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of headsto use for multi-head attention. + cross_attention_dim (`int`,*optional*): + The size of the encoder_hidden_states vector for cross attention. + dropout(`float`, *optional*, defaults to 0.0): + The dropout probability to use. + activation_fn (`str`,*optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. . + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, *optional*, defaults to 1e-6): + A small constant added to the denominator in normalization layers to prevent division by zero. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): + The size of the hidden layer in the feed-forward block. Defaults to `None`. 
+ ff_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the feed-forward block. + skip (`bool`, *optional*, defaults to `False`): + Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks. + qk_norm (`bool`, *optional*, defaults to `True`): + Whether to use normalization in QK calculation. Defaults to `True`. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + use_self_attention: bool = True, + use_cross_attention: bool = False, + self_attention_norm_type: Optional[str] = None, # ada layer norm + cross_attention_dim: Optional[int] = None, + cross_attention_norm_type: Optional[str] = "fp32_layer_norm", + # parallel second cross attention + use_cross_attention_2: bool = False, + cross_attention_2_dim: Optional[int] = None, + cross_attention_2_norm_type: Optional[str] = None, + dropout=0.0, + activation_fn: str = "gelu", + norm_type: str = "fp32_layer_norm", # TODO + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = False, + ff_inner_dim: Optional[int] = None, # int(dim * 4) if None + ff_bias: bool = True, + skip: bool = False, + skip_concat_front: bool = False, # [x, skip] or [skip, x] + skip_norm_last: bool = False, # this is an error + qk_norm: bool = True, + qkv_bias: bool = True, + ): + super().__init__() + + self.use_self_attention = use_self_attention + self.use_cross_attention = use_cross_attention + self.use_cross_attention_2 = use_cross_attention_2 + self.skip_concat_front = skip_concat_front + self.skip_norm_last = skip_norm_last + # Define 3 blocks. Each block has its own normalization layer. + # NOTE: when new version comes, check norm2 and norm 3 + # 1. Self-Attn + if use_self_attention: + if ( + self_attention_norm_type == "fp32_layer_norm" + or self_attention_norm_type is None + ): + self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + raise NotImplementedError + + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-6, + bias=qkv_bias, + processor=Tripo2AttnProcessor2_0(), + ) + + # 2. Cross-Attn + if use_cross_attention: + assert cross_attention_dim is not None + + self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="rms_norm" if qk_norm else None, + cross_attention_norm=cross_attention_norm_type, + eps=1e-6, + bias=qkv_bias, + processor=Tripo2AttnProcessor2_0(), + ) + + # 2'. Parallel Second Cross-Attn + if use_cross_attention_2: + assert cross_attention_2_dim is not None + + self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2_2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_2_dim, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="rms_norm" if qk_norm else None, + cross_attention_norm=cross_attention_2_norm_type, + eps=1e-6, + bias=qkv_bias, + processor=Tripo2AttnProcessor2_0(), + ) + + # 3. Feed-forward + self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, ### 0.0 + activation_fn=activation_fn, ### approx GeLU + final_dropout=final_dropout, ### 0.0 + inner_dim=ff_inner_dim, ### int(dim * mlp_ratio) + bias=ff_bias, + ) + + # 4. 
Skip Connection + if skip: + self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True) + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + skip: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + # Prepare attention kwargs + attention_kwargs = attention_kwargs or {} + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat( + ( + [skip, hidden_states] + if self.skip_concat_front + else [hidden_states, skip] + ), + dim=-1, + ) + if self.skip_norm_last: + # don't do this + hidden_states = self.skip_linear(cat) + hidden_states = self.skip_norm(hidden_states) + else: + cat = self.skip_norm(cat) + hidden_states = self.skip_linear(cat) + + # 1. Self-Attention + if self.use_self_attention: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1( + norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + hidden_states = hidden_states + attn_output + + # 2. Cross-Attention + if self.use_cross_attention: + if self.use_cross_attention_2: + hidden_states = ( + hidden_states + + self.attn2( + self.norm2(hidden_states), + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + + self.attn2_2( + self.norm2_2(hidden_states), + encoder_hidden_states=encoder_hidden_states_2, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + ) + else: + hidden_states = hidden_states + self.attn2( + self.norm2(hidden_states), + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + + # FFN Layer ### TODO: switch norm2 and norm3 in the state dict + mlp_inputs = self.norm3(hidden_states) + hidden_states = hidden_states + self.ff(mlp_inputs) + + return hidden_states + + +class Tripo2DiTModel(ModelMixin, ConfigMixin): + """ + Tripo2DiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): + The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + patch_size (`int`, *optional*): + The size of the patch to use for the input. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. + sample_size (`int`, *optional*): + The width of the latent images. This is fixed during training since it is used to learn a number of + position embeddings. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. 
+ cross_attention_dim (`int`, *optional*): + The number of dimension in the clip text embedding. + hidden_size (`int`, *optional*): + The size of hidden layer in the conditioning embedding layers. + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden layer size to the input size. + learn_sigma (`bool`, *optional*, defaults to `True`): + Whether to predict variance. + cross_attention_dim_t5 (`int`, *optional*): + The number dimensions in t5 text embedding. + pooled_projection_dim (`int`, *optional*): + The size of the pooled projection. + text_len (`int`, *optional*): + The length of the clip text embedding. + text_len_t5 (`int`, *optional*): + The length of the T5 text embedding. + use_style_cond_and_image_meta_size (`bool`, *optional*): + Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + width: int = 2048, + in_channels: int = 64, + num_layers: int = 21, + cross_attention_dim: int = 768, + cross_attention_2_dim: int = 1024, + ): + super().__init__() + self.out_channels = in_channels + self.num_heads = num_attention_heads + self.inner_dim = width + self.mlp_ratio = 4.0 + + time_embed_dim, timestep_input_dim = self._set_time_proj( + "positional", + inner_dim=self.inner_dim, + flip_sin_to_cos=False, + freq_shift=0, + time_embedding_dim=None, + ) + self.time_proj = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim + ) + self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + use_self_attention=True, + use_cross_attention=True, + self_attention_norm_type="fp32_layer_norm", + cross_attention_dim=self.config.cross_attention_dim, + cross_attention_norm_type=None, + use_cross_attention_2=True, + cross_attention_2_dim=self.config.cross_attention_2_dim, + cross_attention_2_norm_type=None, + activation_fn="gelu", + norm_type="fp32_layer_norm", # TODO + norm_eps=1e-5, + ff_inner_dim=int(self.inner_dim * self.mlp_ratio), + skip=layer > num_layers // 2, + skip_concat_front=True, + skip_norm_last=True, # this is an error + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + qkv_bias=False, + ) + for layer in range(num_layers) + ] + ) + + self.norm_out = LayerNorm(self.inner_dim) + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True) + + def _set_time_proj( + self, + time_embedding_type: str, + inner_dim: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or inner_dim * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." + ) + self.time_embed = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or inner_dim * 4 + + self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + timestep_input_dim = inner_dim + else: + raise ValueError( + f"{time_embedding_type} does not exist. 
Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripo2AttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is ๐Ÿงช experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError( + "`fuse_qkv_projections()` is not supported for models having added KV projections." + ) + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedTripo2AttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is ๐Ÿงช experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(Tripo2AttnProcessor2_0()) + + def forward( + self, + hidden_states: Optional[torch.Tensor], + timestep: Union[int, float, torch.LongTensor], + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + """ + The [`HunyuanDiT2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`): + The input tensor. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. + encoder_hidden_states_2 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. + return_dict: bool + Whether to return a dictionary. + """ + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if ( + attention_kwargs is not None + and attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + _, N, _ = hidden_states.shape + + temb = self.time_embed(timestep).to(hidden_states.dtype) + temb = self.time_proj(temb) + temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states + + hidden_states = self.proj_in(hidden_states) + + # N + 1 token + hidden_states = torch.cat([temb, hidden_states], dim=1) + + skips = [] + for layer, block in enumerate(self.blocks): + if layer <= self.config.num_layers // 2: + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_2=encoder_hidden_states_2, + attention_kwargs=attention_kwargs, + ) # (N, L, D) + else: + skip = skips.pop() + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_2=encoder_hidden_states_2, + skip=skip, + attention_kwargs=attention_kwargs, + ) # (N, L, D) + + if layer < self.config.num_layers // 2: + skips.append(hidden_states) + + # final layer + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states[:, -N:] + hidden_states = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + + return Transformer1DModelOutput(sample=hidden_states) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking( + self, chunk_size: Optional[int] = None, dim: int = 0 + ) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). 
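+
+ Example (an illustrative sketch, not part of the original file; assumes a constructed `Tripo2DiTModel` instance named `model`):
+
+ model.enable_forward_chunking(chunk_size=256, dim=1) # chunk the feed-forward over the sequence axis
+ ... # run training / inference steps
+ model.disable_forward_chunking() # restore the default, unchunked feed-forward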
+ """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward( + module: torch.nn.Module, chunk_size: int, dim: int + ): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward( + module: torch.nn.Module, chunk_size: int, dim: int + ): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) diff --git a/src/model/skin_vae_model.py b/src/model/skin_vae_model.py new file mode 100755 index 0000000000000000000000000000000000000000..a2b4e88b6997883ce5861333bb498dc9593ff80d --- /dev/null +++ b/src/model/skin_vae_model.py @@ -0,0 +1,208 @@ +from dataclasses import asdict, dataclass +from omegaconf import OmegaConf +from scipy.spatial import cKDTree # type: ignore +from torch import nn, Tensor +from typing import Dict, List + +import math +import numpy as np +import random +import torch +import torch.nn.functional as F + +from src.rig_package.info.asset import Asset + +from .spec import ModelSpec, ModelInput, VaeInput +from .skin_vae.autoencoders import SkinFSQCVAEModel + +try: + from flash_attn_interface import flash_attn_func # type: ignore +except Exception as e: + from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func + def flash_attn_func(*args, **kwargs): + res = _flash_attn_func(*args, **kwargs) + return res, None + +class Perceiver(nn.Module): + def __init__(self, channels, out_tokens, num_heads=8): + super().__init__() + self.q_vec = nn.Parameter(torch.randn(out_tokens // num_heads, num_heads, channels) * 0.02) + self.num_heads = num_heads + self.head_dim = channels // num_heads + + self.k_proj = nn.Linear(channels, channels) + self.v_proj = nn.Linear(channels, channels) + self.out_proj = nn.Linear(channels, channels) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + k = self.k_proj(x) # [B, N, C] + v = self.v_proj(x) # [B, N, C] + q_repeated = self.q_vec.repeat(B, 1, 1, 1) + + q = q_repeated.view(B, -1, self.num_heads, self.head_dim).type(torch.bfloat16) + k = k.view(B, -1, self.num_heads, self.head_dim) + v = v.view(B, -1, self.num_heads, self.head_dim) + + hidden_states, _ = flash_attn_func(q, k, v) + hidden_states = hidden_states.view(B, -1, self.num_heads * self.head_dim) # type: ignore + hidden_states = self.out_proj(hidden_states) + return hidden_states + +class SkinVAEModel(ModelSpec): + + def __init__(self, model_config, transform_config, tokenizer_config=None): + super().__init__(model_config, transform_config, tokenizer_config) + + cfg = self.model_config + self.cond_tokens = cfg['sample']['cond_tokens'] + self.compress_tokens = cfg['sample']['compress_tokens'] + self.sample_tokens = cfg['sample']['sample_tokens'] + self.only_dense = cfg['sample'].get('only_dense', False) + self.model_type = cfg.get('type', 'fsqc') + + if self.model_type == 'fsqc': + self.model = 
SkinFSQCVAEModel(**cfg['model'], sample_tokens=self.sample_tokens) + else: + raise NotImplementedError() + if self.sample_tokens != self.compress_tokens: + self.down_perceiver = Perceiver(self.model.latent_channels, self.compress_tokens) + if self.sample_tokens != self.compress_tokens: + self.up_perceiver = Perceiver(self.model.latent_channels, self.sample_tokens) + + def compile_model(self): + self.model.compile_model() + + @property + def vocab_size(self) -> int: + return self.model.FSQ.codebook_size + + @property + def latent_channels(self) -> int: + return self.model.latent_channels + + def encode(self, vae_input: VaeInput, num_tokens: int=4, j: int=0, full: bool=False, encode_repeat: int=4, return_cond: bool=True): + raise NotImplementedError() + + def decode(self, z: Tensor, sampled_cond: Tensor, cond_tokens: Tensor, full: bool=False, encode_repeat: int=4) -> Tensor: + assert z.shape[0] == sampled_cond.shape[0] == cond_tokens.shape[0] + if full: + l = z.shape[0] + s = [] + for i in range(0, l, encode_repeat): + t = min(l,i+encode_repeat) + if self.sample_tokens != self.compress_tokens: + _z = self.up_perceiver(z[i:t]) + else: + _z = z[i:t] + logits = self.model._decode(z=_z, cond=cond_tokens[i:t], sampled_points=sampled_cond[i:t]) + s.append(logits) + return torch.cat(s, dim=0) + else: + if self.sample_tokens != self.compress_tokens: + z = self.up_perceiver(z) + logits = self.model._decode(z=z, cond=cond_tokens, sampled_points=sampled_cond) + return logits + + def get_loss_dict( + self, + skin_pred: Tensor, + skin_gt: Tensor, + ) -> Dict[str, Tensor]: + raise NotImplementedError() + + def get_input(self, batch: Dict) -> VaeInput: + vertices: Tensor = batch['vertices'].float() # (B, N, 3) + normals: Tensor = batch['normals'].float() # (B, N, 3) + uniform_skin: List[Tensor] = batch['uniform_skin'] # [(N, J)] + dense_skin: List[Tensor] = batch['dense_skin'] # [(J, skin_samples)] + dense_vertices: List[Tensor] = batch['dense_vertices'] # [(J, skin_samples, 3)] + dense_normals: List[Tensor] = batch['dense_normals'] # [(J, skin_samples, 3)] + dense_indices: List[List[int]] = batch['dense_indices'] # [List[J]] + + B = vertices.shape[0] + uniform_cond = torch.cat([vertices, normals], dim=-1).float() + dense_cond = [] + for i in range(B): + dense_cond.append(torch.cat([dense_vertices[i], dense_normals[i]], dim=-1).float()) + + uniform_skin = [s.float() for s in uniform_skin] + dense_skin = [s.float() for s in dense_skin] + return VaeInput( + dense_cond=dense_cond, + dense_skin=dense_skin, + dense_indices=dense_indices, + uniform_cond=uniform_cond, + uniform_skin=uniform_skin, + ) + + @torch.autocast(device_type='cuda', dtype=torch.bfloat16) + def training_step(self, batch: Dict) -> Dict: + raise NotImplementedError() + + def process_fn(self, batch: List[ModelInput], is_train: bool = True) -> List[Dict]: + res = [] + for b in batch: + asset = b.asset + assert asset is not None + assert asset.sampled_vertex_groups is not None + assert 'skin' in asset.sampled_vertex_groups + assert asset.meta is not None + assert 'dense_indices' in asset.meta + assert 'dense_skin' in asset.meta + assert 'dense_vertices' in asset.meta + assert 'dense_normals' in asset.meta + _d = { + 'vertices': asset.sampled_vertices, + 'normals': b.asset.sampled_normals, + 'non': { + 'uniform_skin': asset.sampled_vertex_groups['skin'], + 'num_bones': asset.J, + 'skin_samples': asset.skin_samples, + 'dense_indices': asset.meta['dense_indices'], + 'dense_skin': asset.meta['dense_skin'], + 'dense_vertices': 
asset.meta['dense_vertices'], + 'dense_normals': asset.meta['dense_normals'], + } + } + res.append(_d) + return res + + def forward(self, batch: Dict) -> Dict: + return self.training_step(batch=batch) + + @torch.autocast('cuda', dtype=torch.bfloat16) + def predict_step(self, batch: Dict) -> Dict: + vertices: Tensor = batch['vertices'].float() # (B, N, 3) + num_bones: List[int] = batch['num_bones'] + + B = vertices.shape[0] + N = vertices.shape[1] + + vae_input = self.get_input(batch=batch) + num_tokens = 4 + z, cond_tokens, indices, _ = self.encode(vae_input=vae_input, num_tokens=num_tokens, full=True, encode_repeat=8) + assert cond_tokens is not None + + z = self.model.FSQ.indices_to_codes(indices).reshape(z.shape) + _skin_pred = self.decode(z=z, sampled_cond=vae_input.get_flatten_uniform_cond(), cond_tokens=cond_tokens[vae_input.get_flatten_indices()], full=True, encode_repeat=8) + _skin_pred = _skin_pred.squeeze(-1) + + tot = 0 + results = [] + for i in range(B): + asset: Asset = batch['model_input'][i].asset.copy() + skin_pred = torch.zeros((N, num_bones[i]), dtype=vertices.dtype, device=vertices.device) + for j in range(vae_input.get_len(i=i)): + skin_pred[:, vae_input.true_j(i=i, j=j)] = _skin_pred[tot] + tot += 1 + sampled_vertices = vertices[i].detach().float().cpu().numpy() + tree = cKDTree(sampled_vertices) + distances, indices = tree.query(asset.vertices) + sampled_skin = skin_pred.detach().float().cpu().numpy()[indices] + asset.skin = sampled_skin + results.append(asset) + + return { + 'results': results, + } \ No newline at end of file diff --git a/src/model/spec.py b/src/model/spec.py new file mode 100755 index 0000000000000000000000000000000000000000..fbb0e932f5dcfdb45351b63fad325018d62f84b0 --- /dev/null +++ b/src/model/spec.py @@ -0,0 +1,300 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from dataclasses import dataclass +from numpy import ndarray +from omegaconf import OmegaConf +from typing import Dict, List, Optional, final +from torch import Tensor + +import numpy as np +import lightning.pytorch as pl +import torch + +from ..data.transform import Transform +from ..rig_package.info.asset import Asset +from ..tokenizer.spec import DetokenizeOutput + +@dataclass +class ModelInput(): + asset: Asset + tokens: Optional[ndarray]=None + +class ModelSpec(pl.LightningModule, ABC): + + model_config: Dict + transform_config: Dict + tokenizer_config: Dict|None + + @abstractmethod + def __init__(self, model_config, transform_config, tokenizer_config=None): + super().__init__() + if not isinstance(model_config, dict): + model_cfg = OmegaConf.to_container(model_config, resolve=True) + else: + model_cfg = model_config + if not isinstance(transform_config, dict): + transform_cfg = OmegaConf.to_container(transform_config, resolve=True) + else: + transform_cfg = transform_config + if tokenizer_config is not None and not isinstance(tokenizer_config, dict): + tokenizer_cfg = OmegaConf.to_container(tokenizer_config, resolve=True) + else: + tokenizer_cfg = tokenizer_config + self.model_config = model_cfg # type: ignore + self.transform_config = transform_cfg # type: ignore + self.tokenizer_config = tokenizer_cfg # type: ignore + self.save_hyperparameters(model_cfg) + self.save_hyperparameters(transform_cfg) + self.save_hyperparameters(tokenizer_cfg) + + @final + def _process_fn(self, batch: List[ModelInput]) -> List[Dict]: + n_batch = self.process_fn(batch) + if self._trainer is None or not self.trainer.training: + for k in n_batch[0].keys(): + if not 
isinstance(n_batch[0][k], ndarray) and not isinstance(n_batch[0][k], Tensor): + continue + s = n_batch[0][k].shape + for i in range(1, len(n_batch)): + assert n_batch[i][k].shape == s, f"{k} has different shape in batch" + for (i, b) in enumerate(batch): + non = n_batch[i].get('non', {}) + non['model_input'] = deepcopy(b) + n_batch[i]['non'] = non + else: + for b in batch: + del b.asset + return n_batch + + @abstractmethod + def process_fn(self, batch: List[ModelInput]) -> List[Dict]: + """ + Fetch data from dataloader and turn it into Tensor objects. + """ + raise NotImplementedError() + + def compile_model(self): + """ + Compile the model. Do this before training and after loading state dicts. + """ + pass + + @classmethod + def load_from_system_checkpoint(cls, checkpoint_path: str, strict: bool=True, **kwargs): + ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + state_dict = ckpt['state_dict'] + model_config = kwargs.get('model_config', None) + transform_config = kwargs.get('transform_config', None) + tokenizer_config = kwargs.get('tokenizer_config', None) + if model_config is None: + model_config = ckpt['hyper_parameters']['model_config'] + if transform_config is None: + transform_config = ckpt['hyper_parameters']['transform_config'] + if tokenizer_config is None: + tokenizer_config = ckpt['hyper_parameters']['tokenizer_config'] + new_state_dict = {} + for k, v in state_dict.items(): + k = k.replace("_orig_mod.", "") + if k.startswith("model."): + k = k[len("model.") :] + new_state_dict[k] = v + model = cls( + model_config=model_config, + transform_config=transform_config, + tokenizer_config=tokenizer_config, + ) + missing, unexpected = model.load_state_dict(new_state_dict, strict=strict) + if missing or unexpected: + print(f"[Warning] Missing keys: {missing}") + print(f"[Warning] Unexpected keys: {unexpected}") + model.on_load_checkpoint(ckpt) + return model + + def get_train_transform(self) -> Transform|None: + cfg = self.transform_config.get('train_transform', None) + if cfg is None: + return None + return Transform.parse(**cfg) + + def get_validate_transform(self) -> Transform|None: + cfg = self.transform_config.get('validate_transform', None) + if cfg is None: + return None + return Transform.parse(**cfg) + + def get_predict_transform(self) -> Transform|None: + cfg = self.transform_config.get('predict_transform', None) + if cfg is None: + return None + return Transform.parse(**cfg) + + def predict_step(self, batch: Dict, no_cls: bool=False, skeleton_tokens=None) -> Dict: + raise NotImplementedError() + + +@dataclass +class VaeInput(): + dense_cond: List[Tensor] # [(J, skin_samples, 6)] + dense_skin: List[Tensor] # [(J, skin_samples)] + dense_indices: List[List[int]] # [List[J]], corresponding indices of gt + uniform_cond: Tensor # (B, N, 6) + uniform_skin: List[Tensor] # [(N, J)] + + @property + def B(self): + return self.uniform_cond.shape[0] + + @property + def max_J(self): + return max([len(s) for s in self.dense_indices]) + + def get_len(self, i) -> int: + return len(self.dense_indices[i]) + + def _clamp_j(self, i: int, j: int) -> int: + return min(j, len(self.dense_indices[i])-1) + + def get_dense_cond(self, j: int) -> Tensor: + """return (B, skin_samples, 6)""" + return torch.stack([self.dense_cond[i][self._clamp_j(i=i, j=j)] for i in range(self.B)]) + + def get_dense_skin(self, j: int) -> Tensor: + """return (B, skin_samples)""" + return torch.stack([self.dense_skin[i][self._clamp_j(i=i, j=j)] for i in range(self.B)]) + + def get_full_cond(self, 
j: int) -> Tensor: + """return (B, N+skin_samples, 6)""" + return torch.cat([self.uniform_cond, self.get_dense_cond(j=j)], dim=1) + + def get_uniform_skin(self, j: int) -> Tensor: + """return (B, N)""" + return torch.stack([self.uniform_skin[i][:, self._clamp_j(i=i, j=j)] for i in range(self.B)]) + + def get_full_skin(self, j: int) -> Tensor: + """return (B, N+skin_samples)""" + return torch.cat([self.get_uniform_skin(j=j), self.get_dense_skin(j=j)], dim=1) + + def get_flatten_uniform_cond(self) -> Tensor: + """return (sum_J, N, 6)""" + return self.uniform_cond[self.get_flatten_indices()] + + def get_flatten_dense_cond(self) -> Tensor: + """return (sum_J, skin_samples, 6)""" + return torch.cat(self.dense_cond, dim=0) + + def get_flatten_dense_skin(self) -> Tensor: + """return (sum_J, skin_samples)""" + return torch.cat(self.dense_skin, dim=0) + + def get_flatten_full_skin(self) -> Tensor: + """return (sum_J, N+skin_samples)""" + # (sum_J, N) + s = torch.cat(self.uniform_skin, dim=-1).permute(1, 0) + return torch.cat([s, self.get_flatten_dense_skin()], dim=1) + + def get_flatten_full_cond(self) -> Tensor: + """return (sum_J, N+skin_samples, 6)""" + return torch.cat([self.get_flatten_uniform_cond(), self.get_flatten_dense_cond()], dim=1) + + def get_flatten_indices(self) -> List[int]: + """return (sum_J)""" + return [i for i in range(self.B) for _ in range(self.get_len(i=i))] + + def true_j(self, i: int, j: int) -> int: + """return (clamped) corresponding indice in the skeleton""" + return self.dense_indices[i][self._clamp_j(i=i, j=j)] + +@dataclass +class TokenRigResult(): + cond: Optional[Tensor]=None # [vertices, normals] + cond_latents: Optional[Tensor]=None # (len, dim) + input_ids: Optional[Tensor]=None # (l,) + output_ids: Optional[Tensor]=None # (l,) + skin_pred: Optional[Tensor]=None # (N, J) + detokenize_output: Optional[DetokenizeOutput]=None + asset: Optional[Asset]=None + +@dataclass +class BoneVaeInput(): + dense_cond: List[Tensor] # [(J, skin_samples, 6)] + dense_skin: List[Tensor] # [(J, skin_samples)] + dense_indices: List[List[int]] # [List[J]], corresponding indices of gt + bones: List[Tensor] # [(J, 6)] + uniform_cond: Tensor # (B, N, 6) + uniform_skin: List[Tensor] # [(N, J)] + + @property + def total_samples(self) -> int: + return self.dense_cond[0].shape[1] + self.uniform_cond.shape[1] + + @property + def B(self) -> int: + return self.uniform_cond.shape[0] + + @property + def max_J(self) -> int: + return max([len(s) for s in self.dense_indices]) + + def get_len(self, i) -> int: + return len(self.dense_indices[i]) + + def _clamp_j(self, i: int, j: int) -> int: + return min(j, len(self.dense_indices[i])-1) + + def get_dense_cond(self, j: int) -> Tensor: + """return (B, skin_samples, 6)""" + return torch.stack([self.dense_cond[i][self._clamp_j(i=i, j=j)] for i in range(self.B)]) + + def get_dense_skin(self, j: int) -> Tensor: + """return (B, skin_samples)""" + return torch.stack([self.dense_skin[i][self._clamp_j(i=i, j=j)] for i in range(self.B)]) + + def get_full_cond(self, j: int) -> Tensor: + """return (B, N+skin_samples, 6)""" + return torch.cat([self.uniform_cond, self.get_dense_cond(j=j)], dim=1) + + def get_uniform_skin(self, j: int) -> Tensor: + """return (B, N)""" + return torch.stack([self.uniform_skin[i][:, self._clamp_j(i=i, j=j)] for i in range(self.B)]) + + def get_full_skin(self, j: int) -> Tensor: + """return (B, N+skin_samples)""" + return torch.cat([self.get_uniform_skin(j=j), self.get_dense_skin(j=j)], dim=1) + + def get_bones(self, j: int) -> Tensor: + 
"""return (B, 3)""" + return torch.stack([self.bones[i][self._clamp_j(i=i, j=j)] for i in range(self.B)]) + + def get_flatten_bones(self) -> Tensor: + """return (sum_J, 3)""" + return torch.cat([self.bones[i] for i in range(self.B)]) + + def get_flatten_uniform_cond(self) -> Tensor: + """return (sum_J, N, 6)""" + return self.uniform_cond[self.get_flatten_indices()] + + def get_flatten_dense_cond(self) -> Tensor: + """return (sum_J, skin_samples, 6)""" + return torch.cat(self.dense_cond, dim=0) + + def get_flatten_dense_skin(self) -> Tensor: + """return (sum_J, skin_samples)""" + return torch.cat(self.dense_skin, dim=0) + + def get_flatten_full_skin(self) -> Tensor: + """return (sum_J, N+skin_samples)""" + # (sum_J, N) + s = torch.cat(self.uniform_skin, dim=-1).permute(1, 0) + return torch.cat([s, self.get_flatten_dense_skin()], dim=1) + + def get_flatten_full_cond(self) -> Tensor: + """return (sum_J, N+skin_samples, 6)""" + return torch.cat([self.get_flatten_uniform_cond(), self.get_flatten_dense_cond()], dim=1) + + def get_flatten_indices(self) -> List[int]: + """return (sum_J)""" + return [i for i in range(self.B) for _ in range(self.get_len(i=i))] + + def true_j(self, i: int, j: int) -> int: + """return (clamped) corresponding indice in the skeleton""" + return self.dense_indices[i][self._clamp_j(i=i, j=j)] \ No newline at end of file diff --git a/src/model/tokenrig.py b/src/model/tokenrig.py new file mode 100644 index 0000000000000000000000000000000000000000..ab35e3fb8dd850737cba7d1782f03c75846971be --- /dev/null +++ b/src/model/tokenrig.py @@ -0,0 +1,666 @@ +from copy import deepcopy +from pathlib import Path +from torch import nn, Tensor, FloatTensor +from torch.nn.functional import pad +from transformers import AutoModelForCausalLM, AutoConfig, LogitsProcessor, LogitsProcessorList # type: ignore +from typing import Dict, List, Tuple + +import math +import numpy as np +import torch +import torch.nn.functional as F + +LLM_LOCAL_DIR = Path("models/Qwen3-0.6B") + +from .skin_vae_model import SkinVAEModel +from .skin_vae.autoencoders import SkinFSQCVAEModel +from .spec import ModelSpec, ModelInput, VaeInput, TokenRigResult +from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder + +from ..rig_package.info.asset import Asset +from ..tokenizer.spec import Tokenizer +from ..tokenizer.spec import DetokenizeOutput +from ..tokenizer.parse import get_tokenizer + +try: + from flash_attn_interface import flash_attn_func # type: ignore +except Exception as e: + from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func + def flash_attn_func(*args, **kwargs): + res = _flash_attn_func(*args, **kwargs) + return res, None + +class VocabSwitchingLogitsProcessor(LogitsProcessor): + def __init__(self, tokenizer: Tokenizer, switch_token_id, eos_token_id, tokens_per_skin, init): + # make sure all skin tokens > switch_token_id + self.tokenizer = tokenizer + self.switch_token_id = switch_token_id + self.eos_token_id = eos_token_id + self.tokens_per_skin = tokens_per_skin + self.init = init + + def __call__(self, input_ids: Tensor, scores: FloatTensor) -> FloatTensor: + # input_ids shape: (batch_size, seq_len) + for batch_idx, sequence in enumerate(input_ids): + mask = torch.full_like(scores[batch_idx], float('-inf')) + sequence = torch.cat([self.init, sequence]) + length = len(sequence) + if self.switch_token_id in sequence: + mask[self.switch_token_id:] = 0 + where = torch.where(sequence == self.switch_token_id)[0][:1] + J = 
self.tokenizer.bones_in_sequence(ids=sequence.detach().cpu().numpy()) + if (length-where) == J*self.tokens_per_skin: + mask[:] = float('-inf') + mask[self.eos_token_id] = 0 + else: + mask[self.eos_token_id] = float('-inf') + else: + tokens = self.tokenizer.next_posible_token(ids=sequence.detach().cpu().numpy()) + mask[tokens] = 0 + scores[batch_idx] = scores[batch_idx] + mask + return scores + +class TokenRig(ModelSpec): + + def __init__(self, model_config, transform_config, tokenizer_config=None): + assert tokenizer_config is not None + super().__init__(model_config=model_config, transform_config=transform_config, tokenizer_config=tokenizer_config) + + cfg = self.model_config + + self.tokens_per_skin: int = cfg['tokens_per_skin'] + self.tokens_skin_cond: int = cfg['tokens_skin_cond'] + + self.use_rope: bool = cfg.get('use_rope', True) + self.encode_repeat: int = cfg.get('encode_repeat', 4) + + self.skin_warmup_start_epoch: int = cfg.get('skin_warmup_start_epoch', 0) + self.skin_warmup_end_epoch: int = cfg.get('skin_warmup_end_epoch', -1) + + self.vae = SkinVAEModel.load_from_system_checkpoint(cfg['pretrained_vae']).to(torch.bfloat16) + for param in self.vae.parameters(): + param.requires_grad_(False) + self.vae.eval() + + self.mesh_encoder = get_mesh_encoder(**cfg['mesh_encoder']) + + assert ( + isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo) or + isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo_encoder) + ) + self.mesh_encoder = self.mesh_encoder.to(torch.bfloat16) + + self.tokenizer: Tokenizer = get_tokenizer(**tokenizer_config) + # (tokenizer codebook, fsq vae codebook) + self.vocab_size = self.tokenizer.vocab_size + self.vae.vocab_size + 1 + self.eos = self.vocab_size - 1 + + _d = cfg['llm'].copy() + self.hidden_size = _d['hidden_size'] + + _d['vocab_size'] = self.vocab_size + if LLM_LOCAL_DIR.exists(): + _d['pretrained_model_name_or_path'] = str(LLM_LOCAL_DIR) + llm_config = AutoConfig.from_pretrained(**_d) + self.vocab_size = self.tokenizer.vocab_size + self.vae.vocab_size + 1 + llm_config.torch_dtype = torch.bfloat16 + llm_config.pre_norm = True + self.llm_config = llm_config + self.transformer = AutoModelForCausalLM.from_config(config=llm_config, attn_implementation="flash_attention_2").to(torch.bfloat16) + + self.output_proj = nn.Sequential( + nn.Linear(self.mesh_encoder.width, self.hidden_size), + nn.RMSNorm(self.hidden_size), + ).to(torch.bfloat16) + + init_scale = cfg.get('init_scale', None) + if init_scale is not None: + self.initialize_weights(init_scale) + + def compile_model(self): + self.vae.compile_model() + self.transformer = torch.compile(self.transformer, dynamic=False) + self.mesh_encoder = torch.compile(self.mesh_encoder, dynamic=False) + + def initialize_weights(self, s: float): + def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + init_scale = s * math.sqrt(1.0 / self.mesh_encoder.width) + + for m in self.mesh_encoder.modules(): + if isinstance(m, nn.Linear): + init_linear(m, stddev=init_scale) + init_scale = s * math.sqrt(1.0 / self.hidden_size) + for m in self.output_proj.modules(): + if isinstance(m, nn.Linear): + init_linear(m, stddev=init_scale) + + def get_skin_warmup_rate(self, steps_per_epoch: int) -> float: + if self.current_epoch < self.skin_warmup_start_epoch: + return 0. + if self.current_epoch > self.skin_warmup_end_epoch: + return 1. 
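+ # Between the warmup start and end epochs, map the current global step to a linear
+ # progress value `rate` and ease it with a half-cosine: (1 - cos(pi * rate)) / 2
+ # rises smoothly from 0 to 1. The final min/max clamp guards against steps that fall
+ # slightly outside the warmup window.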
+ start_steps = self.skin_warmup_start_epoch * steps_per_epoch + end_steps = (self.skin_warmup_end_epoch+1) * steps_per_epoch + rate = (self.global_step-start_steps) / (end_steps-start_steps) + return min(max((1.0-math.cos(math.pi * rate))/2, 0), 1) + + @torch.autocast(device_type='cuda', dtype=torch.bfloat16) + def training_step(self, batch: Dict) -> Dict: + raise NotImplementedError() + + def make_start_tokens(self, **kwargs) -> List[List[int]]: + skeleton_tokens = kwargs.get('skeleton_tokens', None) + skeleton_mask = kwargs.get('skeleton_mask', None) + num_joints = kwargs.get('num_joints', None) + parents = kwargs.get('parents', None) + cls = kwargs.get('cls', None) + start_tokens_list = [] + + batch_size = 1 + if skeleton_tokens is not None: + batch_size = len(skeleton_tokens) + elif cls is not None: + batch_size = len(cls) + elif num_joints is not None: + batch_size = len(num_joints) + elif parents is not None: + batch_size = len(parents) + else: + assert 0, "must provide one of skeleton_tokens, cls, num_joints, parents" + for i in range(batch_size): + if skeleton_tokens is not None: + _skeleton_tokens = skeleton_tokens[i] + _skeleton_mask = skeleton_mask[i] if skeleton_mask is not None else None + assert _skeleton_tokens[0] == self.tokenizer.bos + if skeleton_mask is not None: + start_tokens = _skeleton_tokens[_skeleton_mask==1] + else: + start_tokens = _skeleton_tokens + else: + start_tokens = [self.tokenizer.bos] + start_tokens += self.tokenizer.make_cls_head( + cls=cls[i] if cls is not None else None, + num_joints=num_joints[i] if num_joints is not None else None, + parents=parents[i] if parents is not None else None, + ) + if isinstance(start_tokens, Tensor): + start_tokens = start_tokens.detach().cpu().numpy().tolist() + start_tokens_list.append(start_tokens) + return start_tokens_list + + @torch.autocast(device_type='cuda', dtype=torch.bfloat16) + def generate( + self, + vertices: Tensor, + normals: Tensor, + cls: str|None=None, + skeleton_tokens: np.ndarray|Tensor|None=None, + only_ids: bool=False, + return_decode_dict: bool=False, + num_joints: int|None=None, + parents: Tensor|None=None, + **kwargs, + ) -> TokenRigResult: + """ + Do not support batch! 
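+
+ Illustrative call (a sketch; the variable names, `max_new_tokens`, and the trained `TokenRig` instance `model` are assumptions, not taken from this file):
+
+ # vertices / normals: (N, 3) tensors for a single mesh
+ result = model.generate(vertices=vertices, normals=normals, cls=None, max_new_tokens=4096)
+ skin = result.skin_pred # (N, J) predicted skinning weights
+ ids = result.output_ids # generated token ids, including the prepended start tokens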
+ """ + assert isinstance(self.vae.model, SkinFSQCVAEModel) + assert vertices.dim() == 2, 'do not support batch' + assert normals.dim() == 2, 'do not support batch' + + if isinstance(skeleton_tokens, np.ndarray): + skeleton_tokens = torch.from_numpy(skeleton_tokens).to(self.device) + + cond = torch.cat([vertices, normals], dim=-1).unsqueeze(0) + _, cond_latents = self.vae.model._encode( + x=None, + cond=cond, + num_tokens=self.tokens_per_skin, + cond_tokens=self.tokens_skin_cond, + return_z=False, + ) + assert cond_latents is not None + # (1, len, dim) + learned_mesh_cond = encode_mesh_cond(self.mesh_encoder, self.output_proj, self.tokens_skin_cond, {'vertices': vertices, 'normals': normals}) + + device = cond.device + start_tokens = torch.tensor(self.make_start_tokens( + device=device, + cls=None if cls is None else [cls], + skeleton_tokens=None if skeleton_tokens is None else [skeleton_tokens], + num_joints=None if num_joints is None else [num_joints], + parents=None if parents is None else [parents], + )[0], device=device).unsqueeze(0) + assert start_tokens.shape[0] == 1 + start_embed = self.transformer.get_input_embeddings()(start_tokens) + inputs_embeds = torch.cat([learned_mesh_cond, start_embed], dim=1) + + results = self.transformer.generate( + inputs_embeds=inputs_embeds, + bos_token_id=self.tokenizer.bos, + eos_token_id=self.eos, + pad_token_id=self.tokenizer.pad, + logits_processor=get_logits_processor( + tokenizer=self.tokenizer, + eos=self.eos, + tokens_per_skin=self.tokens_per_skin, + start_tokens=start_tokens[0], + ), + **kwargs, + ) + + res = TokenRigResult() + output_ids = results[0, :] + for token in reversed(start_tokens[0]): + v = token.item() + output_ids = pad(output_ids, (1, 0), value=v) + res.input_ids = start_tokens[0] + res.output_ids = output_ids + if only_ids: + return res + res.cond = cond[0] + res.cond_latents = cond_latents[0] + if return_decode_dict: + return res + d = decode( + cond=cond[0], + cond_latents=cond_latents[0], + inputs_ids=output_ids, + tokenizer=self.tokenizer, + tokens_per_skin=self.tokens_per_skin, + vae=self.vae, + ) + res.skin_pred = d['skin_pred'] + res.detokenize_output = d['detokenize_output'] + return res + + def _debug_export( + self, + batch: Dict, + cond: Tensor, + cond_latents: Tensor, + inputs_ids: Tensor, + id: int=0, + path: str='res.fbx', + ): + if inputs_ids.dim() == 2: + assert cond_latents.dim() == cond.dim() == 3, f"Expected 3 dimensions, got {cond_latents.dim()}, {cond.dim()}" + cond = cond[id] + cond_latents = cond_latents[id] + inputs_ids = inputs_ids[id] + res = decode( + cond=cond, + cond_latents=cond_latents, + inputs_ids=inputs_ids, + tokenizer=self.tokenizer, + tokens_per_skin=self.tokens_per_skin, + vae=self.vae, + ) + detokenize_output: DetokenizeOutput = res['detokenize_output'] + origin_asset: Asset = batch['model_input'][id].asset + asset = Asset.from_data( + vertices=origin_asset.vertices, + faces=origin_asset.faces, + sampled_vertices=batch['vertices'][id].detach().cpu().numpy(), + sampled_skin=res['skin_pred'].detach().cpu().numpy(), + parents=np.array(detokenize_output.parents), + joint_names=detokenize_output.joint_names, + joints=detokenize_output.joints, + ) + from ..rig_package.parser.bpy import BpyParser + BpyParser.export_asset(asset, filepath=path) + + def process_fn(self, batch: List[ModelInput]) -> List[Dict]: + res = [] + max_length = 0 + for b in batch: + if b.tokens is not None: + max_length = max(max_length, b.tokens.shape[0]) + res = [] + for b in batch: + if b.tokens is not None: + 
skeleton_tokens = np.pad(b.tokens, ((0, max_length-b.tokens.shape[0])), 'constant', constant_values=self.tokenizer.pad) + skeleton_mask = np.pad(np.ones_like(b.tokens), ((0, max_length-b.tokens.shape[0])), 'constant', constant_values=0) + else: + skeleton_tokens = None + skeleton_mask = None + _d = { + 'vertices': torch.from_numpy(b.asset.sampled_vertices).float(), + 'normals': torch.from_numpy(b.asset.sampled_normals).float(), + 'non': { + 'cls': b.asset.cls, + } + } + if skeleton_mask is not None: + _d.update({ + 'skeleton_tokens': skeleton_tokens, + 'skeleton_mask': skeleton_mask, + }) + _d['non'].update({ + 'parents': b.asset.parents, + 'num_bones': b.asset.J, + }) + if b.asset.sampled_vertex_groups is not None and 'skin' in b.asset.sampled_vertex_groups: + assert b.asset.meta is not None + _d['non'].update({ + 'cls': b.asset.cls, + 'uniform_skin': torch.from_numpy(b.asset.sampled_vertex_groups['skin']).float(), + 'skin_samples': b.asset.skin_samples, + 'dense_indices': b.asset.meta['dense_indices'], + 'dense_skin': torch.from_numpy(b.asset.meta['dense_skin']).float(), + 'dense_vertices': torch.from_numpy(b.asset.meta['dense_vertices']).float(), + 'dense_normals': torch.from_numpy(b.asset.meta['dense_normals']).float(), + }) + res.append(_d) + return res + + def predict_step( + self, + batch: Dict, + no_cls: bool=False, + skeleton_tokens=None, + parents=None, + num_joints=None, + make_asset: bool=False, + **kwargs + ) -> Dict: + vertices: Tensor = batch['vertices'] + normals : Tensor = batch['normals'] + cls = batch['cls'] + generate_kwargs = deepcopy(batch['generate_kwargs']) + + if vertices.dim() == 2: + vertices = vertices.unsqueeze(0) + normals = normals.unsqueeze(0) + results = [] + if skeleton_tokens is None: + skeleton_tokens = [None] * vertices.shape[0] + d = {} + for i in range(vertices.shape[0]): + res = self.generate( + vertices=vertices[i], + normals=normals[i], + skeleton_tokens=skeleton_tokens[i], + cls=None if no_cls else cls[i], + parents=None if parents is None else parents[i], + num_joints=None if num_joints is None else num_joints[i], + **generate_kwargs + ) + if make_asset: + assert 'model_input' in batch, "need model_input to make asset (in validate/predict mode)" + assert res.detokenize_output is not None + assert res.skin_pred is not None + asset: Asset = batch['model_input'][i].asset.copy() + res.asset = Asset.from_data( + vertices=asset.vertices, + faces=asset.faces, + sampled_vertices=vertices[i].detach().float().cpu().numpy(), + sampled_skin=res.skin_pred.detach().float().cpu().numpy(), + joints=res.detokenize_output.joints, + parents=np.array(res.detokenize_output.parents), + cls=asset.cls, + path=asset.path, + ) + results.append(res) + d['results'] = results + return d + + def forward(self, batch: Dict) -> Dict[str, Tensor]: + return self.training_step(batch=batch) + +def _check(x: Tensor, s, m=None): + assert isinstance(s, (list, tuple)), "Expected shape must be a list or tuple" + assert x.dim() == len(s), f"Expected {len(s)} dims, got {x.dim()}" + for i, (dim_actual, dim_expected) in enumerate(zip(x.shape, s)): + if dim_expected is not None and dim_expected != -1: + if m is None: + assert dim_actual == dim_expected, f"Shape mismatch at dim {i}: expected {dim_expected}, got {dim_actual}" + else: + assert dim_actual == dim_expected, f"Shape mismatch at dim {i}: expected {dim_expected}, got {dim_actual}. 
Message: {m}" + +def encode_mesh_cond(mesh_encoder, output_proj, tokens_skin_cond, batch: Dict) -> Tensor: + vertices = batch['vertices'] # (B, N, 3) + normals = batch['normals'] # (B, N, 3) + assert isinstance(vertices, Tensor) + assert isinstance(normals, Tensor) + if (len(vertices.shape) == 3): + shape_embed, latents, token_num, pre_pc = mesh_encoder.encode_latents(pc=vertices, feats=normals) # type: ignore + else: + shape_embed, latents, token_num, pre_pc = mesh_encoder.encode_latents(pc=vertices.unsqueeze(0), feats=normals.unsqueeze(0)) # type: ignore + latents = output_proj(latents) + return latents + +@torch.no_grad() +def encode( + tokenizer: Tokenizer, + vae: SkinVAEModel, + vae_input: VaeInput, + encode_repeat: int, + tokens_skin_cond: int, + tokens_per_skin: int, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Returns: + skin_tokens: (B, tokens_per_skin*J) + + cond_latents: (B, tokens_skin_cond, vae.latent_channels) + + skin_mask: (B, tokens_per_skin*J), 1 -> skin, 0 -> pad + """ + device = vae_input.uniform_cond.device + B = vae_input.B + J = vae_input.max_J + _, cond_latents, codes, _ = vae.encode(vae_input=vae_input, num_tokens=tokens_per_skin, full=True, encode_repeat=encode_repeat) + codes = codes[:, :tokens_per_skin] + indices = vae_input.get_flatten_indices() + + skin_tokens = torch.full((B, J * tokens_per_skin), tokenizer.pad, dtype=torch.long, device=device) + skin_mask = torch.zeros_like(skin_tokens, dtype=torch.long) + j_counters = [0 for _ in range(B)] + for idx, batch_id in enumerate(indices): + j = j_counters[batch_id] + s = j * tokens_per_skin + t = s + tokens_per_skin + skin_tokens[batch_id, s:t] = codes[idx] + tokenizer.vocab_size + skin_mask[batch_id, s:t] = 1 + j_counters[batch_id] += 1 + + assert cond_latents is not None + _check(cond_latents, (B, tokens_skin_cond, vae.latent_channels)) + _check(skin_tokens, (B, J * tokens_per_skin)) + _check(skin_mask, (B, J * tokens_per_skin)) + return skin_tokens, cond_latents, skin_mask + +def prepare_llm_tokens( + tokenizer: Tokenizer, + eos: int, + skeleton_tokens: Tensor, + skeleton_mask: Tensor, + skin_tokens: Tensor, + skin_mask: Tensor, + cond_latents: Tensor, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + Args: + skeleton_tokens: (B, n) + + skeleton_mask: (B, n) + + skin_tokens: (B, tokens_per_skin*J) + + skin_mask: (B, tokens_per_skin*J) + + cond_latents: (B, tokens_skin_cond, vae.latent_channels) + + Returns: + llm_tokens: (B, seq_len) + + attention_mask: (B, seq_len), 1 -> attend, 0 -> pad + """ + B = skeleton_tokens.shape[0] + inputs_ids = torch.ones((B, skeleton_tokens.shape[1] + skin_tokens.shape[1] + 1), dtype=torch.long, device=skeleton_tokens.device) * tokenizer.pad + num_skeleton = skeleton_mask.sum(dim=1) + num_skin = skin_mask.sum(dim=1) + attention_mask = torch.ones((B, inputs_ids.shape[1]), dtype=torch.float32, device=skeleton_tokens.device) + llm_skeleton_mask = torch.ones_like(attention_mask, dtype=torch.bool) + llm_skin_mask = torch.ones_like(attention_mask, dtype=torch.bool) + for i in range(B): + length = num_skeleton[i] + num_skin[i] + inputs_ids[i, :num_skeleton[i]] = skeleton_tokens[i, :num_skeleton[i]] + inputs_ids[i, num_skeleton[i]:num_skeleton[i]+num_skin[i]] = skin_tokens[i, :num_skin[i]] + inputs_ids[i, num_skeleton[i]+num_skin[i]] = eos # add an eos + attention_mask[i, length+1:] = 0. 
+ llm_skeleton_mask[i, num_skeleton[i]:] = 0 + llm_skin_mask[i, :num_skeleton[i]] = 0 + llm_skin_mask[i, length+1:] = 0 + + seq_len = inputs_ids.shape[1] + _check(inputs_ids, (B, seq_len)) + _check(attention_mask, (B, seq_len)) + return inputs_ids, attention_mask, llm_skeleton_mask, llm_skin_mask + +def get_logits_processor(tokenizer: Tokenizer, eos: int, tokens_per_skin: int, start_tokens): + processor = VocabSwitchingLogitsProcessor( + tokenizer=tokenizer, + switch_token_id=tokenizer.eos, + eos_token_id=eos, + tokens_per_skin=tokens_per_skin, + init=start_tokens, + ) + return LogitsProcessorList([processor]) + +@torch.no_grad() +def decode( + cond: Tensor, + cond_latents: Tensor, + inputs_ids: Tensor, + tokenizer: Tokenizer, + tokens_per_skin: int, + vae: SkinVAEModel, + encode_repeat: int=1, +) -> Dict: + """ + inputs_ids: (seq_len) + + cond: (N, c) + + cond_latents: (tokens_skin_cond, dim) + """ + assert cond.dim() == 2, 'do not support batch' + assert cond_latents.dim() == 2, 'do not support batch' + + where_eos = torch.where(inputs_ids == tokenizer.eos) + if where_eos[0].shape[0] == 0: + raise ValueError("No EOS token found in inputs_ids") + where_eos = where_eos[0][:1] + skeleton_tokens = inputs_ids[:where_eos+1] + skeleton_tokens = np.array(skeleton_tokens.detach().cpu().numpy()) + detokenize_output = tokenizer.detokenize(ids=skeleton_tokens) + J = detokenize_output.joints.shape[0] + + skin_tokens = inputs_ids[where_eos+1:where_eos+1+J*tokens_per_skin] + if skin_tokens.shape != (J*(tokens_per_skin),): + return { + 'skin_pred': None, + 'detokenize_output': detokenize_output, + } + cond = cond.unsqueeze(0) + cond_latents = cond_latents.unsqueeze(0) + skin = [] + g = tokens_per_skin * encode_repeat + for s in range(0, J*tokens_per_skin, g): + t = min(s+g, J*tokens_per_skin) + indices = skin_tokens[s:t].unsqueeze(0) - tokenizer.vocab_size + # expect: (b, tokens_per_skin, dim) + b = (t-s)//tokens_per_skin + z = vae.model.FSQ.indices_to_codes(indices).reshape(b, tokens_per_skin, -1) + # (b, n, 1) + logits = vae.decode(z=z, sampled_cond=cond.repeat(b, 1, 1), cond_tokens=cond_latents.repeat(b, 1, 1)) + skin_pred = logits.reshape(b, logits.shape[1]).permute(1, 0) + skin.append(skin_pred) + skin = torch.concat(skin, dim=1).float() + return { + 'skin_pred': skin, + 'detokenize_output': detokenize_output, + } + +@torch.no_grad() +def decode_multi( + cond: Tensor, + cond_latents: Tensor, + inputs_ids: List[Tensor], + tokenizer: Tokenizer, + tokens_per_skin: int, + vae: SkinVAEModel, + is_numpy: bool=True, + encode_repeat: int=1, +) -> List[Dict]: + """ + inputs_ids: List[(seq_len)] + + cond: (N, c) + + cond_latents: (tokens_skin_cond, dim) + """ + assert cond.dim() == 2, 'do not support batch' + assert cond_latents.dim() == 2, 'do not support batch' + + B = len(inputs_ids) + res = [{'skin_pred': None, 'detokenize_output': None} for _ in range(B)] + device = cond.device + batch_mapping = [] + skin_tokens_list = [] + oks = [] + oks_J = [] + for i in range(B): + where_eos = torch.where(inputs_ids[i] == tokenizer.eos) + if where_eos[0].shape[0] == 0: + print(f"decode_multi: {i} has bad skeleton") + continue + where_eos = where_eos[0][:1] + skeleton_tokens = inputs_ids[i][:where_eos+1] + skeleton_tokens = np.array(skeleton_tokens.detach().cpu().numpy()) + try: + detokenize_output = tokenizer.detokenize(ids=skeleton_tokens) + except Exception as e: + print(f"decode_multi: error while decoding skeleton: {str(e)}") + continue + J = detokenize_output.joints.shape[0] + res[i]['detokenize_output'] = 
detokenize_output # type: ignore + skin_tokens = inputs_ids[i][where_eos+1:where_eos+1+J*tokens_per_skin] + if skin_tokens.shape != (J*(tokens_per_skin),): + print(f"decode_multi: {i} has bad skin") + continue + batch_mapping.append(torch.full((J,), i, device=device, dtype=torch.long)) + skin_tokens_list.append(skin_tokens) + oks.append(i) + oks_J.append(J) + if len(batch_mapping) == 0: + return res + batch_mapping = torch.cat(batch_mapping, dim=0) + # (1, sum_J*tokens_per_skin) + skin_tokens = torch.cat(skin_tokens_list, dim=0).unsqueeze(0) + cond = cond.unsqueeze(0) + cond_latents = cond_latents.unsqueeze(0) + skin_list = [] + g = tokens_per_skin * encode_repeat + sum_J = batch_mapping.shape[0] + for s in range(0, sum_J*tokens_per_skin, g): + t = min(s+g, sum_J*tokens_per_skin) + # (1, m*tokens_per_skin) + indices = skin_tokens[:, s:t] - tokenizer.vocab_size + # expect: (m, tokens_per_skin, dim) + m = (t-s)//tokens_per_skin + z = vae.model.FSQ.indices_to_codes(indices).reshape(m, tokens_per_skin, -1) + # (m, n, 1) + logits = vae.decode(z=z, sampled_cond=cond.repeat(m, 1, 1), cond_tokens=cond_latents.repeat(m, 1, 1)) + skin_pred = logits.reshape(m, logits.shape[1]).permute(1, 0) + skin_list.append(skin_pred) + skin = torch.concat(skin_list, dim=1).float() + for (i, id) in enumerate(oks): + skin_pred = skin[:, batch_mapping==id].reshape(-1, oks_J[i]) + res[id]['skin_pred'] = skin_pred.detach().cpu().numpy() if is_numpy else skin_pred + return res \ No newline at end of file diff --git a/src/model/utils.py b/src/model/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..8dec6ec71822ef1f46bb5fbb2e4170d8be00f736 --- /dev/null +++ b/src/model/utils.py @@ -0,0 +1,118 @@ +import torch + +def fps( + x: torch.Tensor, + batch: torch.Tensor, + ratio: float, + random_start: bool = False, +) -> torch.Tensor: + """ + Args: + x: (N, C) points. + batch: (N,) batch indices for each point. + ratio: sampling ratio in (0, 1]. + random_start: whether to start from a random point per batch. + + Returns: + 1D tensor of sampled indices in the flattened input space. 
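+ + Example (a minimal usage sketch; the tensor shapes below are illustrative only): + pts = torch.rand(2048, 3) + batch = torch.zeros(2048, dtype=torch.long) + idx = fps(pts, batch, ratio=0.25) # (512,) indices into pts + sampled = pts[idx]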
+ """ + if x.ndim != 2: + raise ValueError(f"Expected x to have shape (N, C), got {tuple(x.shape)}") + if batch.ndim != 1 or batch.shape[0] != x.shape[0]: + raise ValueError("batch must be 1D and aligned with x") + if not (0 < ratio <= 1.0): + raise ValueError(f"ratio must be in (0, 1], got {ratio}") + + sampled_indices = [] + unique_batches = torch.unique(batch) + + for batch_id in unique_batches: + mask = batch == batch_id + points = x[mask] + num_points = points.shape[0] + + if num_points == 0: + continue + + num_samples = max(1, int(round(num_points * ratio))) + num_samples = min(num_samples, num_points) + + if random_start: + farthest = torch.randint(num_points, (1,), device=x.device).item() + else: + farthest = 0 + + distances = torch.full((num_points,), float("inf"), device=x.device) + selected_local = torch.empty(num_samples, dtype=torch.long, device=x.device) + + for i in range(num_samples): + selected_local[i] = farthest + centroid = points[farthest] + dist = torch.sum((points - centroid) ** 2, dim=-1) + distances = torch.minimum(distances, dist) + farthest = torch.argmax(distances).item() + + global_indices = torch.nonzero(mask, as_tuple=False).squeeze(-1)[selected_local] + sampled_indices.append(global_indices) + + if not sampled_indices: + return torch.empty((0,), dtype=torch.long, device=x.device) + return torch.cat(sampled_indices, dim=0) + + +def segment_csr( + src: torch.Tensor, + indptr: torch.Tensor, + reduce: str = "sum", +) -> torch.Tensor: + """ + Args: + src: source tensor with shape (N, ...). + indptr: CSR index pointer with shape (S + 1,). + reduce: one of {"sum", "mean", "min", "max"}. + + Returns: + Reduced tensor with shape (S, ...). + """ + if src.ndim < 1: + raise ValueError(f"Expected src to have at least 1 dim, got {src.ndim}") + if indptr.ndim != 1: + raise ValueError(f"Expected indptr to be 1D, got shape {tuple(indptr.shape)}") + if indptr.numel() < 1: + raise ValueError("indptr must contain at least one element") + if reduce not in {"sum", "mean", "min", "max"}: + raise ValueError(f"Unsupported reduce mode: {reduce}") + + indptr = indptr.to(device=src.device, dtype=torch.long) + segments = indptr.numel() - 1 + out_shape = (segments, *src.shape[1:]) + + if reduce in {"sum", "mean"}: + out = torch.zeros(out_shape, dtype=src.dtype, device=src.device) + elif reduce == "min": + out = torch.full(out_shape, float("inf"), dtype=src.dtype, device=src.device) + else: + out = torch.full(out_shape, float("-inf"), dtype=src.dtype, device=src.device) + + for i in range(segments): + start = indptr[i].item() + end = indptr[i + 1].item() + if end <= start: + continue + + chunk = src[start:end] + if reduce == "sum": + out[i] = chunk.sum(dim=0) + elif reduce == "mean": + out[i] = chunk.mean(dim=0) + elif reduce == "min": + out[i] = chunk.min(dim=0).values + else: + out[i] = chunk.max(dim=0).values + + if reduce == "min": + out = torch.where(torch.isinf(out), torch.zeros_like(out), out) + elif reduce == "max": + out = torch.where(torch.isinf(out), torch.zeros_like(out), out) + + return out \ No newline at end of file diff --git a/src/rig_package/__init__.py b/src/rig_package/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/rig_package/info/__init__.py b/src/rig_package/info/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/rig_package/info/asset.py b/src/rig_package/info/asset.py new file mode 
100755 index 0000000000000000000000000000000000000000..6fc46154036b3041dba8c7be3640ac45e04c87c0 --- /dev/null +++ b/src/rig_package/info/asset.py @@ -0,0 +1,773 @@ +from dataclasses import dataclass, field +from numpy import ndarray +from scipy.spatial import cKDTree # type: ignore +from typing import Dict, List, Optional, Tuple + +import numpy as np +import os +import trimesh + +from ..utils import assert_list, assert_ndarray, linear_blend_skinning, sample_vertex_groups +from .voxel import Voxel + +@dataclass +class Asset(): + + # vertices of merged mesh in edit space, shape (N, 3) + vertices: Optional[ndarray]=None + + # faces of merged mesh, shape (F, 3) + faces: Optional[ndarray]=None + + # vertex normals of merged mesh in edit space, shape (N, 3), calculated by trimesh + vertex_normals: Optional[ndarray]=None + + # face normals of merged mesh in edit space, shape (F, 3), calculated by trimesh + face_normals: Optional[ndarray]=None + + # offset of vertices in each part, shape (P,), + # vertices[vertex_bias[i-1]:vertex_bias[i]] are in the same part (vertex_bias[-1]=0) + vertex_bias: Optional[ndarray]=None + + # offset of faces in each part, shape (P,), + # faces[face_bias[i-1]:face_bias[i]] are in the same part (face_bias[-1]=0) + face_bias: Optional[ndarray]=None + + # name of each mesh part, shape (P,) + mesh_names: Optional[List[str]]=None + + # name of each joint, shape (J,) + joint_names: Optional[List[str]]=None + + # parent index of each joint, shape (J,), -1 for root + parents: Optional[ndarray]=None + + # length of each bone indicating euclidean distance between head and tail(which is proposed in blender), shape (J,) + lengths: Optional[ndarray]=None + + # matrix to convert from edit space(or motion space) to world space, shape (4, 4) + matrix_world: Optional[ndarray]=None + + # local matrix of each joint, shape (J, 4, 4) + matrix_local: Optional[ndarray]=None + + # matrix to convert from edit space to motion space, shape (frames, J, 4, 4) + matrix_basis: Optional[ndarray]=None + + # name of the armature + armature_name: Optional[str]=None + + # skinning weights, shape (N, J) + skin: Optional[ndarray]=None + + ########################################################################### + cls: Optional[str]=None + path: Optional[str]=None + vertex_groups: Dict[str, ndarray]=field(default_factory=dict) + sampled_vertices: Optional[ndarray]=None + sampled_normals: Optional[ndarray]=None + sampled_vertex_groups: Optional[Dict[str, ndarray]]=None + skin_samples: Optional[int]=None + + meta: Optional[Dict]=None + + @property + def dirname(self) -> str: + """return directory name of the asset""" + if self.path is None: + return "" + return os.path.dirname(self.path) + + @property + def N(self) -> int: + """return number of vertices""" + if self.vertices is None: + return 0 + return self.vertices.shape[0] + + @property + def F(self) -> int: + """return number of faces""" + if self.faces is None: + return 0 + return self.faces.shape[0] + + @property + def J(self) -> int: + """return number of joints""" + if self.parents is None: + return 0 + return self.parents.shape[0] + + @property + def P(self) -> int: + """return number of mesh parts""" + self._build_bias() + if self.vertex_bias is None: + return 0 + return self.vertex_bias.shape[0] + + @property + def root(self) -> int: + """return the index of root joint""" + if self.parents is None: + return -1 + for i, p in enumerate(self.parents): + if p == -1: + return i + raise ValueError("no root found") + + @property + def joints(self) -> 
ndarray|None: + """return joints in edit space, shape (J, 3)""" + if self.matrix_local is None: + return None + return self.matrix_local[:, :3, 3] + + @property + def skeleton(self) -> ndarray|None: + """return skeleton where joint is followed by parent, shape (J-1, 6), ignore root""" + if self.joints is None or self.parents is None: + return None + indices = np.linspace(0, self.J-1, num=self.J, dtype=int)[self.parents!=-1] + return np.concatenate([self.joints[indices], self.joints[self.parents[indices]]], axis=1) + + @property + def dfs_order(self) -> List[int]: + """return the dfs order of joints""" + if self.parents is None: + return [] + sons = [[] for _ in range(self.J)] + stack = [] + for i, p in enumerate(self.parents): + if p == -1: + stack.append(i) + continue + sons[p].append(i) + order = [] + while len(stack) > 0: + u = stack.pop() + order.append(u) + for s in reversed(sons[u]): + stack.append(s) + return order + + @property + def tails(self) -> ndarray|None: + """ + Return tails in edit space, shape (J, 3). The bone is extrueded along local Y axis, in accordance with Blender. + """ + joints = self.joints + matrix_local = self.matrix_local + if joints is None or self.lengths is None or matrix_local is None: + return None + + x = np.array([0.0, 1.0, 0.0]) + x = self.lengths * x[:, np.newaxis] + y = np.zeros((self.J, 3)) + for i in range(self.J): + y[i] = matrix_local[i, :3, :3] @ x[:, i] + return joints + y + + def _build_bias(self): + if self.vertex_bias is None and self.vertices is not None: + self.vertex_bias = np.array([self.vertices.shape[0]]) + if self.face_bias is None and self.faces is not None: + self.face_bias = np.array([self.faces.shape[0]]) + + def get_vertex_slice(self, index: int) -> slice: + """return slice of vertices of a specific part""" + self._build_bias() + if self.vertex_bias is None: + return slice(0, 0) + if index == 0: + return slice(0, self.vertex_bias[0]) + return slice(self.vertex_bias[index-1], self.vertex_bias[index]) + + def get_face_slice(self, index: int) -> slice: + """return slice of faces of a specific part""" + self._build_bias() + if self.face_bias is None: + return slice(0, 0) + if index == 0: + return slice(0, self.face_bias[index]) + return slice(self.face_bias[index-1], self.face_bias[index]) + + def names_to_ids(self, arr: List[int|str]) -> List[int]: + for s in arr: + if isinstance(s, str) and (self.joint_names is None or s not in self.joint_names): + raise ValueError(f"do not find {s} in joint_names") + elif not isinstance(s, int) and not isinstance(s, str): + raise ValueError(f"element must be int or str") + if self.joint_names is not None: + _name_to_id = {s: i for (i, s) in enumerate(self.joint_names)} + else: + _name_to_id = {} + return [_name_to_id[s] if isinstance(s, str) else s for s in arr] + + def set_order( + self, + new_orders: List[int|str], + merge_skin: bool=True, + do_not_normalize: bool=False + ): + """ + Rearrange the order of the joints. + Args: + new_orders: A list of int or bone names to indicate orders. + For example, if the first element is 2, then the rearranged + joint will be the second first joint in the current skeleton. + + merge_skin: If True, if some joints are merged, skin will be + added to its nearest ancestor. Otherwise completely removes + skin and finally normalized. + + do_not_normalize: Do not normalize skin. 
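+ + Example (a hypothetical sketch; the bone names below are illustrative and must exist in joint_names): + asset.set_order(new_orders=["Hips", "Spine", "Head"], merge_skin=True) + # joints that are not listed are dropped; with merge_skin=True their skin weights + # are accumulated onto the nearest kept ancestor (or onto the root if none is kept)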
+ """ + if len(np.unique(new_orders)) != len(new_orders): + raise ValueError("multiple values found in new_orders") + _new_orders = self.names_to_ids(arr=new_orders) + ancestors = [] + grandsons = [] + beyond_root = [] + root_id = 0 + if self.parents is not None: + new_positions = [0 for i in range(self.J)] + new_parents = [-1 for i in range(self.J)] + for i, x in enumerate(_new_orders): + new_positions[x] = i + set_new_orders = set(_new_orders) + roots = 0 + for i in self.dfs_order: + if i not in set_new_orders: + if self.parents[i] == -1: + new_positions[i] = -1 + beyond_root.append(i) + else: + new_positions[i] = new_positions[self.parents[i]] + if new_positions[i] == -1: + beyond_root.append(i) + else: + ancestors.append(new_positions[i]) + grandsons.append(i) + else: + if self.parents[i] == -1: + new_parents[i] = -1 + else: + new_parents[i] = new_positions[self.parents[i]] + if new_parents[i] == -1: + roots += 1 + root_id = new_positions[i] + if roots >= 2: + raise ValueError(f"multiple roots found: {self.path} {self.parents} {new_orders}") + self.parents = np.array(new_parents)[_new_orders] + if self.joint_names is not None: + _joint_names = [self.joint_names[u] for u in _new_orders] + self.joint_names = _joint_names + if self.lengths is not None: + self.lengths = self.lengths[_new_orders] + if self.matrix_local is not None: + self.matrix_local = self.matrix_local[_new_orders] + if self.matrix_basis is not None: + self.matrix_basis = self.matrix_basis[:, _new_orders] + if self.skin is not None: + if merge_skin: + skin = self.skin.copy() + self.skin = skin[:, _new_orders] + for x, y in zip(ancestors, grandsons): + self.skin[:, x] += skin[:, y] + self.skin[:, root_id] += skin[:, beyond_root].sum(axis=1) + else: + self.skin = self.skin[:, _new_orders] + if not do_not_normalize: + self.normalize_skin() + + def delete_joints(self, joints_to_remove: List[int|str]): + """ + Delete joints and their corresponding values. + """ + _joints_to_remove = set(self.names_to_ids(arr=joints_to_remove)) + new_orders: List[int|str] = [i for i in range(self.J) if i not in _joints_to_remove] + self.set_order(new_orders=new_orders) + + def delete_vertices(self, vertices_to_remove: List[int]|ndarray): + """ + Delete vertices and their corresponding values. + """ + if self.vertices is None: + return + if isinstance(vertices_to_remove, list): + vertices_to_remove = np.array(vertices_to_remove) + mask = np.ones(self.N, dtype=bool) + mask[vertices_to_remove] = False + indices = np.where(mask)[0] + + # handle vertex bias + if self.vertex_bias is not None: + cumsum_mask = np.cumsum(mask) + self.vertex_bias = cumsum_mask[self.vertex_bias-1] + + N = self.N + self.vertices = self.vertices[indices] + if self.vertex_normals is not None: + self.vertex_normals = self.vertex_normals[indices] + if self.skin is not None: + self.skin = self.skin[indices, :] + if self.faces is not None: # keep faces + face_mask = np.all(np.isin(self.faces, indices), axis=1) + self.faces = self.faces[face_mask] + old_to_new = np.zeros(N, dtype=np.int32) + old_to_new[indices] = np.arange(len(indices)) + self.faces = old_to_new[self.faces] + if self.face_normals is not None: + self.face_normals = self.face_normals[indices] + # handle face bias + if self.face_bias is not None: + cumsum_face_mask = np.cumsum(face_mask) + self.face_bias = cumsum_face_mask[self.face_bias-1] + + self._build_bias() + + def normalize_skin(self) -> 'Asset': + """ + Normalize skin so that add up to 1. 
+ """ + if self.skin is None: + return self + self.skin = self.skin / np.maximum(np.sum(self.skin, axis=1, keepdims=True), 1e-8) + return self + + def build_normals(self): + """ + Build vertex_normals and face_normals using trimesh. + """ + if self.vertices is None: + raise ValueError("do not have vertices") + if self.faces is None: + raise ValueError("do not have faces") + mesh = trimesh.Trimesh(vertices=self.vertices, faces=self.faces, process=False, maintain_order=True) + self.vertex_normals = mesh.vertex_normals.copy() + self.face_normals = mesh.face_normals.copy() + + def normalize_vertices( + self, + range: Optional[Tuple[float, float]]=None, + range_x: Optional[Tuple[float, float]]=None, + range_y: Optional[Tuple[float, float]]=None, + range_z: Optional[Tuple[float, float]]=None, + ): + """ + Normalize vertices into cube in edit space. If range_x/y/z is provided, + use range_x/y/z, otherwise use range by default. + """ + if self.vertices is None: + return + if range is None: + if range_x is None: + raise ValueError("range_x is None, but range is missing") + if range_y is None: + raise ValueError("range_y is None, but range is missing") + if range_z is None: + raise ValueError("range_z is None, but range is missing") + _range_x = range_x + _range_y = range_y + _range_z = range_z + else: + _range_x = range if range_x is None else range_x + _range_y = range if range_y is None else range_y + _range_z = range if range_z is None else range_z + v_min = self.vertices.min(axis=0) + v_max = self.vertices.max(axis=0) + scale_range = (v_max - v_min).max() + # normalize into [0, 1]^3 + v = (self.vertices - v_min) / scale_range + mid_point = (v.min(axis=0) + v.max(axis=0)) / 2 + bias = np.array([0.5, 0.5, 0.5]) - mid_point + v += bias + dx = (_range_x[1] - _range_x[0]) + dy = (_range_y[1] - _range_y[0]) + dz = (_range_z[1] - _range_z[0]) + if self.faces is not None and np.abs(dx-dy) > 1e-3 or np.abs(dy-dz) > 1e-3 or np.abs(dy-dz) > 1e-3: + raise ValueError("do not support non-uniform normalization") + v[:, 0] = v[:, 0] * dx + _range_x[0] + v[:, 1] = v[:, 1] * dy + _range_y[0] + v[:, 2] = v[:, 2] * dz + _range_z[0] + self.vertices = v + if self.matrix_local is not None: + jv = (self.matrix_local[:, :3, 3] - v_min) / scale_range + bias + self.matrix_local[:, 0, 3] = jv[:, 0] * dx + _range_x[0] + self.matrix_local[:, 1, 3] = jv[:, 1] * dy + _range_y[0] + self.matrix_local[:, 2, 3] = jv[:, 2] * dz + _range_z[0] + + def get_matrix( + self, + matrix_basis: ndarray, + ) -> ndarray: + """ + Get pose matrix in motion space using forward kinetics. + """ + J = self.J + parents = self.parents + if parents is None: + raise ValueError("do not have parents") + if self.matrix_local is None: + raise ValueError("do not have matrix_local") + assert_ndarray(matrix_basis, "matrix_basis", (J, 4, 4)) + matrix = np.zeros((J, 4, 4)) + for i in self.dfs_order: + pid = parents[i] + if pid==-1: + matrix[i] = self.matrix_local[i] @ matrix_basis[i] + else: + matrix_parent = matrix[pid] + matrix_local_parent = self.matrix_local[pid] + + matrix[i] = ( + matrix_parent @ + (np.linalg.inv(matrix_local_parent) @ self.matrix_local[i]) @ + matrix_basis[i] + ) + return matrix + + def vertices_with_pose( + self, + matrix_basis: ndarray, + inplace: bool=True, + ) -> ndarray: + """ + Apply pose to vertices and return the deformed vertices. + + Args: + inplace: if True, change vertices and all motion related fileds of the asset. 
+ """ + if self.vertices is None: + raise ValueError("do not have vertices") + if self.matrix_local is None: + raise ValueError("do not have matrix_local") + if self.joints is None: + raise ValueError("do not have joints") + if self.skin is None: + raise ValueError("do not have skin") + matrix = self.get_matrix(matrix_basis=matrix_basis) + vertices = linear_blend_skinning( + vertices=self.vertices, + matrix_local=self.matrix_local, + matrix=matrix, + skin=self.skin, + pad=1, + value=1.0, + ) + if inplace: + self.vertices = vertices + if self.faces is not None: + self.build_normals() + self.matrix_local = matrix + return vertices + + def transform(self, trans: ndarray): + """trans: 4x4 affine matrix""" + def _apply(v: ndarray, trans: ndarray) -> ndarray: + return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3] + + if self.vertices is not None: + self.vertices = _apply(self.vertices, trans) + if self.matrix_local is not None: + self.matrix_local = trans @ self.matrix_local + self.build_normals() + + def trim_skeleton(self): + """remove all leaf bones and coordinate bones""" + if self.skin is None: + return + if self.parents is None: + return + has_skin = self.skin.sum(axis=0) > 1e-6 + if not np.any(has_skin): + return + sons = [[] for _ in range(self.J)] + good_sons = [[] for _ in range(self.J)] + sub_tree_has_skin = [False for _ in range(self.J)] + dfs_order = self.dfs_order + for u in dfs_order: + p = self.parents[u] + if p != -1: + sons[p].append(u) + for u in reversed(dfs_order): + p = self.parents[u] + if has_skin[u]: + sub_tree_has_skin[u] = True + else: + for v in sons[u]: + if sub_tree_has_skin[v]: + sub_tree_has_skin[u] = True + break + keep = [False for _ in range(self.J)] + for u in dfs_order: + for v in sons[u]: + if sub_tree_has_skin[v]: + good_sons[u].append(v) + if has_skin[u]: + keep[u] = True + else: + p = self.parents[u] + if len(good_sons[u]) >= 2: + keep[u] = True + elif len(good_sons[u]) == 1 and p != -1: + if len(good_sons[p]) >= 2: + keep[u] = True + elif len(good_sons[p]) == 1 and good_sons[p][0] != u: + keep[u] = True + joints_to_remove: List[int|str] = [i for i in range(self.J) if not keep[i]] + self.delete_joints(joints_to_remove=joints_to_remove) + + def check_field(self): + + def _check_array(arr, name, shape, dtype=None): + if arr is not None: + assert_ndarray(arr, name=name, shape=shape, dtype=dtype) + + def _check_list(arr, name, dtype=None): + if arr is not None: + assert_list(arr, name=name, dtype=dtype) + + _check_array(self.vertices, name="vertices", shape=(self.N, 3)) + _check_array(self.faces, name="faces", shape=(self.F, 3)) + _check_array(self.vertex_normals, name="vertex_normals", shape=(self.N, 3)) + _check_array(self.face_normals, name="face_normals", shape=(self.F, 3)) + _check_array(self.vertex_bias, name="vertex_bias", shape=(self.P,), dtype=np.integer) + _check_array(self.face_bias, name="face_bias", shape=(self.P,), dtype=np.integer) + _check_list(self.mesh_names, name="mesh_names", dtype=str) + _check_list(self.joint_names, name="joint_names", dtype=str) + _check_array(self.parents, name="parents", shape=(-1,), dtype=np.integer) + _check_array(self.lengths, name="lengths", shape=(-1,)) + _check_array(self.matrix_world, name="matrix_world", shape=(4, 4)) + _check_array(self.matrix_local, name="matrix_local", shape=(self.J, 4, 4)) + _check_array(self.matrix_basis, name="matrix_basis", shape=(self.F, self.J, 4, 4)) + if self.armature_name is not None: + if not isinstance(self.armature_name, str): + raise ValueError(f"armature_name should 
be str") + _check_array(self.skin, name="skin", shape=(self.N, self.J)) + + if self.vertices is not None and self.vertex_normals is not None: + if self.vertices.shape[0] != self.vertex_normals.shape[0]: + raise ValueError(f"shapes of vertices and vertex_normals do not match: {self.vertices.shape} and {self.vertex_normals.shape}") + + if self.faces is not None and self.face_normals is not None: + if self.faces.shape[0] != self.face_normals.shape[0]: + raise ValueError(f"shapes of faces and face_normals do not match: {self.faces.shape} and {self.face_normals.shape}") + + if self.vertex_bias is not None: + if self.vertices is None: + raise ValueError("have vertex_bias, but do not have vertices") + if self.vertex_bias[-1] != self.N: + raise ValueError(f"vertex_bias must end with number of vertices {self.N}") + + if self.face_bias is not None: + if self.faces is None: + raise ValueError("have face_bias, but do not have faces") + if self.face_bias[-1] != self.F: + raise ValueError(f"vertex_bias must end with number of vertices {self.N}") + + if self.matrix_local is not None and self.matrix_basis is not None: + if self.matrix_local.shape[0] != self.matrix_basis.shape[1]: + raise ValueError(f"number of joints do not match in matix_local and matrix_basis: {self.matrix_local.shape[0]} and {self.matrix_basis.shape[1]}") + + if self.joint_names is not None and self.matrix_local is not None: + if len(self.joint_names) != self.matrix_local.shape[0]: + raise ValueError(f"number of joints do not match in joint_names and matrix_local: {len(self.joint_names)} and {self.matrix_local.shape[0]}") + + if self.skin is not None and self.matrix_local is not None: + if self.skin.shape[1] != self.matrix_local.shape[0]: + raise ValueError(f"number of joints do not match in skin and matrix_local: {self.skin.shape[0]} and {self.matrix_local.shape[0]}") + + if self.parents is not None: + if (self.parents==-1).sum() != 1: + raise ValueError(f"no root or multiple roots found, count: {(self.parents==-1).sum()}") + + def voxel(self, resolution: int=128, voxel_size: Optional[float]=None) -> Voxel: + """ + Return a voxel created from mesh. + Args: + resolution: Maximum number of cubes along one axis. + + voxel_size: Forcibly asign length of the cube with this value. + """ + import open3d as o3d + if self.vertices is None: + raise ValueError("do not have vertices") + if self.faces is None: + raise ValueError("do not have faces") + if voxel_size is None: + max_d = (self.vertices.max(axis=1) - self.vertices.min(axis=1)).max() + v = max_d / resolution + else: + v = voxel_size + mesh_o3d = o3d.geometry.TriangleMesh() + mesh_o3d.vertices = o3d.utility.Vector3dVector(self.vertices.copy()) + mesh_o3d.triangles = o3d.utility.Vector3iVector(self.faces) + voxel = o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh_o3d, voxel_size=v) + coords = np.array([pt.grid_index for pt in voxel.get_voxels()]) + return Voxel( + origin=voxel.origin, + voxel_size=v, + coords=coords, + ) + + def sample_pc( + self, + num_samples: int, + num_vertex_samples: Optional[int]=None, + face_mask: Optional[ndarray]=None, + shuffle: bool=True, + ) -> 'Asset': + """ + Return a asset where vertices, normals and skin are sampled. 
+ """ + if self.vertices is None: + raise ValueError("do not have vertices") + if self.faces is None: + raise ValueError("do not have faces") + if self.vertex_normals is None or self.face_normals is None: + self.build_normals() + if face_mask is not None: + assert_ndarray(arr=face_mask, name="face_mask", shape=(self.F,)) + sampled_vertices, sampled_normals, sampled_vertex_groups = sample_vertex_groups( + vertices=self.vertices, + faces=self.faces, + num_samples=num_samples, + num_vertex_samples=num_vertex_samples, + vertex_normals=self.vertex_normals, + face_normals=self.face_normals, + vertex_groups=self.skin, + face_mask=face_mask, + shuffle=shuffle, + same=True, + ) + asset = self.copy() + asset.vertices = sampled_vertices[:, 0] + asset.vertex_normals = sampled_normals[:, 0] # type: ignore + asset.skin = sampled_vertex_groups + asset.vertex_bias = None + asset.faces = None + asset.face_bias = None + asset.face_normals = None + asset._build_bias() + return asset + + def copy(self) -> 'Asset': + def _copy(x): + if isinstance(x, ndarray): + return x.copy() + elif isinstance(x, list): + return x.copy() + elif isinstance(x, str): + return x + else: + return None + return Asset( + vertices=_copy(self.vertices), + faces=_copy(self.faces), + vertex_normals=_copy(self.vertex_normals), # type: ignore + face_normals=_copy(self.face_normals), + vertex_bias=_copy(self.vertex_bias), + face_bias=_copy(self.face_bias), + mesh_names=_copy(self.mesh_names), + joint_names=_copy(self.joint_names), + parents=_copy(self.parents), + lengths=_copy(self.lengths), + matrix_world=_copy(self.matrix_world), + matrix_local=_copy(self.matrix_local), + matrix_basis=_copy(self.matrix_basis), + armature_name=_copy(self.armature_name), # type: ignore + skin=_copy(self.skin), + cls=_copy(self.cls), # type: ignore + path=_copy(self.path), # type: ignore + ) + + def change_dtype(self, float_dtype=np.float32, int_dtype=np.int32) -> 'Asset': + """change dtype""" + def convert(arr): + if arr is None: + return None + if np.issubdtype(arr.dtype, np.floating): + return arr.astype(float_dtype) + elif np.issubdtype(arr.dtype, np.integer): + return arr.astype(int_dtype) + else: + return arr + + self.vertices = convert(self.vertices) + self.faces = convert(self.faces) + self.vertex_normals = convert(self.vertex_normals) + self.face_normals = convert(self.face_normals) + self.vertex_bias = convert(self.vertex_bias) + self.face_bias = convert(self.face_bias) + self.parents = convert(self.parents) + self.lengths = convert(self.lengths) + self.matrix_world = convert(self.matrix_world) + self.matrix_local = convert(self.matrix_local) + self.matrix_basis = convert(self.matrix_basis) + self.skin = convert(self.skin) + return self + + @classmethod + def from_data( + c, + vertices: Optional[ndarray]=None, + faces: Optional[ndarray]=None, + vertex_normals: Optional[ndarray]=None, + face_normals: Optional[ndarray]=None, + vertex_bias: Optional[ndarray]=None, + face_bias: Optional[ndarray]=None, + mesh_names: Optional[List[str]]=None, + joint_names: Optional[List[str]]=None, + parents: Optional[ndarray]=None, + lengths: Optional[ndarray]=None, + matrix_world: Optional[ndarray]=None, + matrix_local: Optional[ndarray]=None, + matrix_basis: Optional[ndarray]=None, + armature_name: Optional[str]=None, + skin: Optional[ndarray]=None, + joints: Optional[ndarray]=None, + sampled_vertices: Optional[ndarray]=None, + sampled_skin: Optional[ndarray]=None, + cls: Optional[str]=None, + path: Optional[str]=None, + ) -> 'Asset': + """ + Return an asset with as 
many fields as possible. + """ + if matrix_local is None and joints is not None: + J = joints.shape[0] + matrix_local = np.zeros((J, 4, 4), dtype=np.float32) + matrix_local[...] = np.eye(4) + matrix_local[:, :3, 3] = joints + if joint_names is None and matrix_local is not None: + joint_names = [f"bone_{i}" for i in range(matrix_local.shape[0])] + + if sampled_vertices is not None and vertices is not None and sampled_skin is not None: + # transfer sampled skin weights to the full-resolution vertices via nearest neighbour + tree = cKDTree(sampled_vertices) + distances, indices = tree.query(vertices) + _s = sampled_skin[indices] + skin = _s + asset = Asset( + vertices=vertices, + faces=faces, + vertex_normals=vertex_normals, + face_normals=face_normals, + vertex_bias=vertex_bias, + face_bias=face_bias, + mesh_names=mesh_names, + joint_names=joint_names, + parents=parents, + lengths=lengths, + matrix_world=matrix_world, + matrix_local=matrix_local, + matrix_basis=matrix_basis, + armature_name=armature_name, + skin=skin, + cls=cls, + path=path, + ) + asset.check_field() + return asset diff --git a/src/rig_package/info/voxel.py b/src/rig_package/info/voxel.py new file mode 100755 index 0000000000000000000000000000000000000000..6b509998d9caab95f5b2e0dbe30819fce0dd6a43 --- /dev/null +++ b/src/rig_package/info/voxel.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from numpy import ndarray +from typing import Optional + +import numpy as np + +@dataclass +class Voxel(): + # integer grid coordinates of occupied voxels + coords: ndarray + + # origin of the voxel grid + origin: ndarray + + # grid size + voxel_size: float + + # a boolean occupancy array + _voxel: Optional[ndarray]=None + + @property + def voxel(self) -> ndarray: + if self._voxel is None: + max_coords = np.max(self.coords, axis=0) + shape = tuple(max_coords + 1) + voxel = np.zeros(shape, dtype=bool) + voxel[tuple(self.coords.T)] = True + self._voxel = voxel + return self._voxel + + @property + def pc(self) -> ndarray: + return self.origin + (self.coords + 0.5) * self.voxel_size + + def projection_fill(self, rigid: bool=True): + """ + Fill up holes in the voxel grid. + """ + grids = np.indices(self.voxel.shape) + x_coord = grids[0, ...] + y_coord = grids[1, ...] + z_coord = grids[2, ...]
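+ # for each axis, find the min/max occupied index along that axis; a cell is filled when it + # lies inside the occupied span on at least 3 axes (rigid=True) or 2 axes (rigid=False)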
+ + INF = 2147483647 + x_tmp = x_coord.copy() + x_tmp[~self.voxel] = INF + x_min = x_tmp.min(axis=0) + + x_tmp[~self.voxel] = -1 + x_max = x_tmp.max(axis=0) + + y_tmp = y_coord.copy() + y_tmp[~self.voxel] = INF + y_min = y_tmp.min(axis=1) + + y_tmp[~self.voxel] = -1 + y_max = y_tmp.max(axis=1) + + z_tmp = z_coord.copy() + z_tmp[~self.voxel] = INF + z_min = z_tmp.min(axis=2) + z_tmp[~self.voxel] = -1 + z_max = z_tmp.max(axis=2) + + in_x = (x_coord >= x_min[None, :, :]) & (x_coord <= x_max[None, :, :]) + in_y = (y_coord >= y_min[:, None, :]) & (y_coord <= y_max[:, None, :]) + in_z = (z_coord >= z_min[:, :, None]) & (z_coord <= z_max[:, :, None]) + + count = in_x.astype(int) + in_y.astype(int) + in_z.astype(int) + fill_mask = count >= (3 if rigid else 2) + self._voxel = self.voxel | fill_mask + x, y, z = np.where(self.voxel) + self.coords = np.stack([x, y, z], axis=1) + + def inside(self, points: ndarray) -> ndarray: + if points.ndim == 1: + points = points[None, :] + points = np.asarray(points) + idx = np.floor((points - self.origin) / self.voxel_size).astype(int) + invalid = ( + (idx < 0).any(axis=1) | + (idx >= np.array(self.voxel.shape)).any(axis=1) + ) + result = np.zeros(len(points), dtype=bool) + valid_idx = np.where(~invalid)[0] + valid_voxel_idx = idx[valid_idx] + result[valid_idx] = self.voxel[valid_voxel_idx[:, 0], valid_voxel_idx[:, 1], valid_voxel_idx[:, 2]] + return result \ No newline at end of file diff --git a/src/rig_package/parser/__init__.py b/src/rig_package/parser/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/rig_package/parser/abstract.py b/src/rig_package/parser/abstract.py new file mode 100755 index 0000000000000000000000000000000000000000..fce33a4f1333708421f125fd763d9e09a9ef4970 --- /dev/null +++ b/src/rig_package/parser/abstract.py @@ -0,0 +1,17 @@ +"""Abstract class for parsers.""" + +from abc import ABC, abstractmethod + +from ..info.asset import Asset + +class AbstractParser(ABC): + """Abstract class for parsers.""" + + @classmethod + @abstractmethod + def load(cls, filepath: str, **kwargs) -> Asset: + pass + + @classmethod + def export(cls, asset: Asset, filepath: str, **kwargs): + raise NotImplementedError("do not implement") \ No newline at end of file diff --git a/src/rig_package/parser/bpy.py b/src/rig_package/parser/bpy.py new file mode 100755 index 0000000000000000000000000000000000000000..fc33b8c66fff4eccbd60bc04f9c0a223e87ceec5 --- /dev/null +++ b/src/rig_package/parser/bpy.py @@ -0,0 +1,825 @@ +from collections import defaultdict +from numpy import ndarray +from typing import Optional, Tuple + +import bpy # type: ignore +import logging +import numpy as np +import os +import trimesh + +from .abstract import AbstractParser +from ..info.asset import Asset +from mathutils import Vector, Matrix # type: ignore + +class BpyParser(AbstractParser): + + @classmethod + def load(cls, filepath: str, **kwargs) -> Asset: + clean_bpy() + load(filepath=filepath, **kwargs) + collection = bpy.data.collections.get("glTF_not_exported") + if collection is not None: + for obj in list(collection.objects): + bpy.data.objects.remove(obj, do_unlink=True) + armature = get_armature() + if armature is None: + bones = None + joint_names = None + parents = None + lengths = None + matrix_world = np.eye(4) + matrix_local = None + matrix_basis = None + armature_name = None + else: + bones = armature.pose.bones # list of PoseBone + joint_names = [b.name for b in bones] + parents = [] + lengths = 
[] + matrix_world = np.array(armature.matrix_world) + obj = armature.parent + while obj is not None: + matrix_world = np.array(obj.matrix_world) @ matrix_world + obj = obj.parent + + matrix_local = [] + for pbone in bones: + matrix_local.append(np.array(pbone.bone.matrix_local)) + parents.append(joint_names.index(pbone.parent.name) if pbone.parent is not None else -1) + lengths.append(pbone.bone.length) + matrix_local = np.stack(matrix_local, axis=0) + parents = np.array(parents, dtype=np.int32) + lengths = np.array(lengths, dtype=np.float32) + + matrix_basis = get_matrix_basis(bones=bones) + armature_name = armature.name + mesh_dict = extract_mesh(bones=bones) + + return Asset( + vertices=mesh_dict['vertices'], + faces=mesh_dict['faces'], + vertex_normals=mesh_dict['vertex_normals'], + face_normals=mesh_dict['face_normals'], + vertex_bias=mesh_dict['vertex_bias'], + face_bias=mesh_dict['face_bias'], + mesh_names=mesh_dict['mesh_names'], + joint_names=joint_names, + parents=parents, + lengths=lengths, + matrix_world=matrix_world, + matrix_local=matrix_local, + matrix_basis=matrix_basis, + armature_name=armature_name, + skin=mesh_dict['skin'], + ) + + @classmethod + def export(cls, asset: Asset, filepath: str, **kwargs): + """ + When exporting .obj, kwargs: + precision: int=6, number of decimal places for vertex coordinates + + Otherwise, export fbx/glb/gltf using bpy, kwargs: + extrude_scale: float=0.5, if the asset has no tails, first compute the average bone length l between parents and children; the length of each leaf bone is then l*extrude_scale. Otherwise this does not affect the final result. + + connect_tail_to_unique_child: bool=False, if True, the tail of a bone with only one child will be exactly at the head of its child. + + extrude_from_parent: bool=False, if True, the orientation of the leaf bone will be the same as its parent. + + group_per_vertex: int=-1, number of the largest weights to keep for each vertex. -1 means keep all. + + add_root: bool=False, if True, add a root bone at (0, 0, 0). + + do_not_normalize: bool=False, if True, do not normalize the skinning weights. + + collection_name: str='new_collection', name of the new collection to store objects. + + add_leaf_bones: bool=False, if True, add a leaf bone at the end of each bone. + """ + ext = os.path.splitext(filepath)[1].lower() + if ext == '.obj': + cls.export_obj(asset, filepath, **kwargs) + elif ext == '.ply': + cls.export_ply(asset, filepath, **kwargs) + else: + cls.export_asset(asset, filepath, **kwargs) + + @classmethod + def export_obj( + cls, + asset: Asset, + filepath: str, + precision: int=6, + use_pc: bool=False, + use_normal: bool=False, + use_skeleton: bool=False, + normal_size: float=0.01, + ): + """ + Export the asset as an .obj file. This will ignore skeleton and skinning.
+ + Args: + use_normal: export normals + + use_skeleton: export skeleton + """ + asset._build_bias() + if asset.vertices is None or asset.vertex_bias is None: + raise ValueError("do not have vertices or vertex_bias") + if use_normal and asset.vertex_normals is None: + raise ValueError("use_normal is True but do not have vertex_normals") + if not filepath.lower().endswith('.obj'): + filepath += ".obj" + faces = asset.faces + mesh_names = asset.mesh_names + if mesh_names is None: + mesh_names = [f"mesh_{i}" for i in range(asset.P)] + cls._safe_make_dir(filepath) + file = open(filepath, 'w') + lines = [] + tot = 0 + if use_skeleton: + raise NotImplementedError() + for i, mesh_name in enumerate(mesh_names): + lines.append(f'o {mesh_name}\n') + if use_normal: + s = asset.get_vertex_slice(i) + for v, n in zip(asset.vertices[s], asset.vertex_normals[s]): # type: ignore + vv = v + n * normal_size + lines.append(f'v {v[0]:.{precision}f} {v[2]:.{precision}f} {-v[1]:.{precision}f}\n') + lines.append(f'v {vv[0]:.{precision}f} {vv[2]:.{precision}f} {-vv[1]:.{precision}f}\n') + lines.append(f'v {vv[0]:.{precision}f} {vv[2]:.{precision}f} {-vv[1]+0.000001:.{precision}f}\n') + lines.append(f"f {tot+1} {tot+2} {tot+3}\n") + tot += 3 + else: + for v in asset.vertices[asset.get_vertex_slice(i)]: + lines.append(f'v {v[0]:.{precision}f} {v[2]:.{precision}f} {-v[1]:.{precision}f}\n') + if faces is not None and use_pc == False: + for f in faces[asset.get_face_slice(i)]: + lines.append(f"f {f[0]+1} {f[1]+1} {f[2]+1}\n") + file.writelines(lines) + file.close() + + @classmethod + def export_ply( + cls, + asset: Asset, + filepath: str, + use_pc: bool=False, + render_skin_id: Optional[int]=None, + ): + """ + Export the asset as an .ply file. This will ignore skeleton and skinning. + """ + import open3d as o3d + asset._build_bias() + if asset.vertices is None or asset.vertex_bias is None: + raise ValueError("do not have vertices or vertex_bias") + if not filepath.lower().endswith('.ply'): + filepath += ".ply" + faces = asset.faces + if use_pc: + faces = None + mesh_names = asset.mesh_names + if mesh_names is None: + mesh_names = [f"mesh_{i}" for i in range(asset.P)] + cls._safe_make_dir(filepath) + + if render_skin_id is not None: + if asset.skin is None: + raise ValueError("render_skin_id is not None, but skin of asset is None") + colors = np.stack([ + asset.skin[:, render_skin_id], + np.zeros(asset.N), + 1-asset.skin[:, render_skin_id], + ], axis=1) + else: + colors = None + if faces is None: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(asset.vertices) + if colors is not None: + pcd.colors = o3d.utility.Vector3dVector(colors) + o3d.io.write_point_cloud(filepath, pcd) + else: + mesh = o3d.geometry.TriangleMesh() + mesh.vertices = o3d.utility.Vector3dVector(asset.vertices) + mesh.triangles = o3d.utility.Vector3iVector(faces) + if colors is not None: + mesh.vertex_colors = o3d.utility.Vector3dVector(colors) + o3d.io.write_triangle_mesh(filepath, mesh) + + @classmethod + def export_asset(cls, asset: Asset, filepath: str, **kwargs): + use_origin = kwargs.pop('use_origin', False) if 'use_origin' in kwargs else False + if not use_origin: + clean_bpy() + make_asset(asset=asset, **kwargs) + cls._safe_make_dir(filepath) + + _, ext = os.path.splitext(filepath) + ext = ext.lower()[1:] + if ext == 'fbx': + if asset.joints is not None and asset.matrix_basis is not None: + logging.warning("Exporting animation, but fbx format is deprecated because the rest pose will not be exported in bpy4.2. 
Use glb/gltf format instead. See: https://blender.stackexchange.com/questions/273398/blender-export-fbx-lose-the-origin-rest-pose.") + bpy.ops.export_scene.fbx(filepath=filepath, check_existing=False, add_leaf_bones=kwargs.get('add_leaf_bones', False), path_mode='COPY', embed_textures=True, mesh_smooth_type="FACE") + elif ext == 'glb' or ext == 'gltf': + bpy.ops.export_scene.gltf(filepath=filepath) + else: + raise ValueError(f"Unsupported format: {ext}") + + @classmethod + def _safe_make_dir(cls, path: str): + if os.path.dirname(path) == '': + return + os.makedirs(os.path.dirname(path), exist_ok=True) + +def clean_bpy(): + """Clean all the data in bpy.""" + bpy.ops.outliner.orphans_purge(do_local_ids=True, do_linked_ids=True, do_recursive=True) + data_types = [ + bpy.data.actions, + bpy.data.armatures, + bpy.data.cameras, + bpy.data.collections, + bpy.data.curves, + bpy.data.lights, + bpy.data.materials, + bpy.data.meshes, + bpy.data.objects, + bpy.data.worlds, + bpy.data.node_groups, + bpy.data.images, + bpy.data.textures, + ] + for data_collection in data_types: + for item in data_collection: + data_collection.remove(item) + +def load(filepath: str, **kwargs): + """Load a 3D file into bpy.""" + _, ext = os.path.splitext(filepath) + ext = ext.lower()[1:] + + if not os.path.exists(filepath): + raise RuntimeError(f"file does not exist: {filepath}") + + if ext == "obj": + bpy.ops.wm.obj_import(filepath=filepath) + elif ext == "fbx": + bpy.ops.import_scene.fbx( + filepath=filepath, + ignore_leaf_bones=kwargs.get('ignore_leaf_bones', False), + use_image_search=kwargs.get('use_image_search', True), + ) + elif ext == "glb" or ext == "gltf": + bpy.ops.import_scene.gltf(filepath=filepath, import_pack_images=kwargs.get('import_pack_images', False)) + elif ext == "dae": + bpy.ops.wm.collada_import(filepath=filepath) + elif ext == "blend": + with bpy.data.libraries.load(filepath) as (data_from, data_to): + data_to.objects = data_from.objects + for obj in data_to.objects: + if obj is not None: + bpy.context.collection.objects.link(obj) + elif ext == "bvh": + bpy.ops.import_anim.bvh(filepath=filepath) + else: + raise ValueError(f"unsupported type: {ext}") + +def get_armature(): + """Get the armature object in the current scene.""" + armatures = [obj for obj in bpy.context.scene.objects if obj.type == 'ARMATURE'] + if len(armatures) == 0: + return None + return armatures[0] + +def extract_mesh(bones=None): + """ + Extract vertices, face_normals, faces and skinning(if possible). 
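+ + Example (a minimal sketch of the returned dict; shapes follow the code below): + armature = get_armature() + mesh_dict = extract_mesh(bones=armature.pose.bones if armature is not None else None) + # mesh_dict['vertices']: (N, 3), mesh_dict['faces']: (F, 3), mesh_dict['skin']: (N, J) or None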
+ """ + meshes = [] + for v in bpy.data.objects: + if v.type == 'MESH': + meshes.append(v) + + index = {} + if bones is not None: + for (id, pbone) in enumerate(bones): + index[pbone.name] = id + total_bones = len(bones) + else: + total_bones = None + + mesh_names_list = [] + vertices_list = [] + faces_list = [] + skin_list = [] + vertex_bias = [] + face_bias = [] + cur_vertex_bias = 0 + cur_face_bias = 0 + for obj in meshes: + # directly apply mesh's transformation because armature operates on the transformed mesh + if obj.parent is not None: + m = np.linalg.inv(np.array(obj.parent.matrix_world)) @ np.array(obj.matrix_world) + else: + m = np.array(obj.matrix_world) + matrix_world_rot = m[:3, :3] + matrix_world_bias = m[:3, 3] + rot = matrix_world_rot + total_vertices = len(obj.data.vertices) + vertices = np.zeros((3, total_vertices)) + if total_bones is not None: + skin_weight = np.zeros((total_vertices, total_bones)) + else: + skin_weight = np.zeros((1, 1)) + obj_verts = obj.data.vertices + obj_group_names = [g.name for g in obj.vertex_groups] + faces = [] + normals = [] + + for polygon in obj.data.polygons: + edges = polygon.edge_keys + nodes = [] + adj = {} + for edge in edges: + if adj.get(edge[0]) is None: + adj[edge[0]] = [] + adj[edge[0]].append(edge[1]) + if adj.get(edge[1]) is None: + adj[edge[1]] = [] + adj[edge[1]].append(edge[0]) + nodes.append(edge[0]) + nodes.append(edge[1]) + normal = polygon.normal + nodes = list(set(sorted(nodes))) + first = nodes[0] + loop = [] + now = first + vis = {} + while True: + loop.append(now) + vis[now] = True + if vis.get(adj[now][0]) is None: + now = adj[now][0] + elif vis.get(adj[now][1]) is None: + now = adj[now][1] + else: + break + for (second, third) in zip(loop[1:], loop[2:]): + faces.append((first, second, third)) + normals.append(rot @ normal) + faces = np.array(faces, dtype=np.int32) + normals = np.array(normals, dtype=np.float32) + + coords = np.array([v.co for v in obj_verts]) + rot_np = np.array(rot) + coords = (rot_np @ coords.T).T + matrix_world_bias + vertices[0:3, :coords.shape[0]] = coords.T + + # extract skin + if bones is not None: + vg_lut = {} + for v in obj_verts: + for g in v.groups: + vg_lut[(v.index, g.group)] = g.weight + + for bone in bones: + if bone.name not in obj_group_names: + continue + gidx = obj.vertex_groups[bone.name].index + col = index[bone.name] + for v in obj_verts: + w = vg_lut.get((v.index, gidx)) + if w is not None: + skin_weight[v.index, col] = w + + vertices = vertices.T + # determine the orientation of the face normal + v0 = vertices[faces[:, 0]] + v1 = vertices[faces[:, 1]] + v2 = vertices[faces[:, 2]] + cross = np.cross(v1-v0, v2-v0) + dot = np.einsum("ij,ij->i", cross, normals) + correct_faces = faces.copy() + mask = dot < 0 + correct_faces[mask, 1], correct_faces[mask, 2] = faces[mask, 2], faces[mask, 1] + + mesh_names_list.append(obj.name) + vertices_list.append(vertices) + faces_list.append(correct_faces+cur_vertex_bias) # add bias to faces + if total_bones is not None: + skin_list.append(skin_weight) + cur_vertex_bias += len(vertices) + cur_face_bias += len(faces) + vertex_bias.append(cur_vertex_bias) + face_bias.append(cur_face_bias) + + vertices = np.vstack(vertices_list) + faces = np.vstack(faces_list) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False, maintain_order=True) + vertex_normals = mesh.vertex_normals + face_normals = mesh.face_normals + + return { + 'mesh_names': np.array(mesh_names_list), + 'vertices': vertices, + 'faces': faces, + 'face_normals': 
face_normals, + 'vertex_normals': vertex_normals, + 'skin': np.vstack(skin_list) if len(skin_list) > 0 else None, + 'vertex_bias': np.array(vertex_bias), + 'face_bias': np.array(face_bias), + } + +def get_matrix_basis(bones=None): + if bones is None: + return None + if bpy.data.actions is not None and len(bpy.data.actions) > 0: + action = bpy.data.actions[0] + frames = int(action.frame_range.y - action.frame_range.x) + else: + return None + + J = len(bones) + matrix_basis = np.zeros((frames, J, 4, 4)) + matrix_basis[...] = np.eye(4) + for frame in range(frames): + bpy.context.scene.frame_set(frame + 1) + for (id, pbone) in enumerate(bones): + matrix_basis[frame, id] = np.array(pbone.matrix_basis) + return matrix_basis + +def make_asset( + asset: Asset, + extrude_scale: float=0.5, + connect_tail_to_unique_child: bool=False, + extrude_from_parent: bool=False, + group_per_vertex: int=-1, + add_root: bool=False, + do_not_normalize: bool=False, + collection_name: str='new_collection', + use_face: bool=True, +): + """ + Args: + + extrude_scale: float=0.5, if there is no tails in asset, first calculate the average length between parents and sons, then the length of leaf bone is l*extrude_scale. Otherwise do not affect final results. + + connect_tail_to_unique_child: bool=False, if True, the tail of a bone with only one child will be exactly at the head of its child. + + extrude_from_parent: bool=False, if True, the orientation of the leaf bone will be the same as its parent. + + group_per_vertex: int=-1, number of the largest weights to keep for each vertex. -1 means keep all. + + add_root: bool=False, if True, add a root bone at (0, 0, 0). + + do_not_normalize: bool=False, if True, do not normalize the skinning weights. + + collection_name: str='new_collection', name of the new collection to store objects. + + use_face: bool=True, if False, do not export faces. + """ + + collection = bpy.data.collections.new(collection_name) + bpy.context.scene.collection.children.link(collection) + + # 1. if there are meshes, make meshes + + objects = [] + mesh_names = [] + for v in bpy.data.objects: + if v.type == 'MESH': + objects.append(v) + mesh_names.append(v.name) + + if len(objects) == 0: + mesh_names = [f"mesh_{i}" for i in range(asset.P)] + if len(objects)==0 and asset.vertices is not None: + + if asset.mesh_names is not None: + mesh_names = asset.mesh_names + + for i in range(asset.P): + mesh = bpy.data.meshes.new(f"data_{mesh_names[i]}") + v = asset.vertices[asset.get_vertex_slice(i)] + if not use_face or (asset.faces is None or asset.face_bias is None or asset.vertex_bias is None): + mesh.from_pydata(v, [], []) + else: + if i == 0: + mesh.from_pydata(v, [], asset.faces[asset.get_face_slice(i)]) + else: + mesh.from_pydata(v, [], asset.faces[asset.get_face_slice(i)]-asset.vertex_bias[i-1]) + mesh.update() + + # make object from mesh + object = bpy.data.objects.new(mesh_names[i], mesh) + objects.append(object) + + # add object to scene collection + collection.objects.link(object) + + # 2. 
if there is armature, process tails and make armature + if len(bpy.data.armatures) > 0: + armature = bpy.data.armatures[0] + armature_name = armature.name + joint_names = [b.name for b in armature.bones] + else: + armature = None + armature_name = 'Armature' + joint_names = asset.joint_names if asset.joint_names is not None else [f"bone_{i}" for i in range(asset.J)] + + if armature is None and asset.joints is not None and asset.parents is not None: + joints = asset.joints + if asset.tails is None: + tails = joints.copy() + connect_tail_to_unique_child = True + extrude_from_parent = True + else: + tails = asset.tails + + root_tail = False + root_id = asset.root + + length_sum = 0. + sons = defaultdict(list) + for i in range(len(asset.parents)): + p = asset.parents[i] + if p == -1: + continue + sons[p].append(i) + length_sum += np.linalg.norm(joints[i] - joints[p]) + if asset.J <= 1: + length = 1.0 + else: + length_avg = length_sum / max(len(asset.parents) - 1, 1) + length = length_avg * extrude_scale + + for i in range(len(asset.parents)): + p = asset.parents[i] + if p == -1: + continue + sons[p].append(i) + d = np.linalg.norm(joints[i] - joints[p]) + if d <= length * 1e-2: + max_d = max(length, 1e-5) + joints[i] += np.random.randn(3) * max_d * 1e-2 + if connect_tail_to_unique_child: + for i in range(len(asset.parents)): + if len(sons[i]) == 1: + child = sons[i][0] + tails[i] = joints[child] + if root_id == i: + root_tail = True + + if extrude_from_parent: + for i in range(len(asset.parents)): + if len(sons[i]) != 1 and asset.parents[i] != -1: + p = asset.parents[i] + d = joints[i] - joints[p] + if np.linalg.norm(d) < 1e-6: + d = np.array([0., 0., 1.]) # in case son.head == parent.head + else: + d = d / np.linalg.norm(d) + tails[i] = joints[i] + d * length + if root_tail is False: + tails[root_id] = joints[root_id] + np.array([0., 0., length]) + bpy.ops.object.armature_add(enter_editmode=True) + armature = bpy.data.armatures.get('Armature') + armature_name = asset.armature_name if asset.armature_name is not None else 'Armature' + + edit_bones = armature.edit_bones + + if add_root: + bone_root = edit_bones.get('Bone') + root_name = 'Root' + x = 0 + while root_name in joint_names: + root_name = f'Root_{x}' + x += 1 + bone_root.name = root_name + bone_root.tail = Vector((joints[0, 0], joints[0, 1], joints[0, 2])) + else: + bone_root = edit_bones.get('Bone') + bone_root.name = joint_names[0] + bone_root.head = Vector((joints[0, 0], joints[0, 1], joints[0, 2])) + bone_root.tail = Vector((tails[0, 0], tails[0, 1], tails[0, 2])) + + def extrude_bone( + edit_bones, + name: str, + parent_name: str, + head: Tuple[float, float, float], + tail: Tuple[float, float, float], + ): + bone = edit_bones.new(name) + bone.head = Vector((head[0], head[1], head[2])) + bone.tail = Vector((tail[0], tail[1], tail[2])) + bone.name = name + parent_bone = edit_bones.get(parent_name) + bone.parent = parent_bone + bone.use_connect = False + assert not np.isnan(head).any(), f"nan found in head of bone {name}" + assert not np.isnan(tail).any(), f"nan found in tail of bone {name}" + + for u in asset.dfs_order: + if add_root is False and u==0: + continue + pname = joint_names[u] if asset.parents[u] == -1 else joint_names[asset.parents[u]] + extrude_bone(edit_bones, joint_names[u], pname, joints[u], tails[u]) + bpy.ops.object.mode_set(mode='OBJECT') + + # 3. 
if there is skin, set vertex groups + if asset.skin is not None and armature is not None and len(objects) > 0: + # must set to object mode to enable parent_set + bpy.ops.object.mode_set(mode='OBJECT') + N = len(objects) + objects = bpy.data.objects + for o in bpy.context.selected_objects: + o.select_set(False) + for i in range(N): + skin = asset.skin[asset.get_vertex_slice(i)] + ob = objects[mesh_names[i]] + armature_b = bpy.data.objects[armature_name] + ob.select_set(True) + armature_b.select_set(True) + bpy.ops.object.parent_set(type='ARMATURE_NAME') + # sparsify + argsorted = np.argsort(-skin, axis=1) + vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted] + group_per_vertex = min(group_per_vertex, skin.shape[1]) + if group_per_vertex == -1: + group_per_vertex = vertex_group_reweight.shape[-1] + if not do_not_normalize: + vertex_group_reweight = vertex_group_reweight / vertex_group_reweight[..., :group_per_vertex].sum(axis=1)[...,None] + # clean vertex groups first in case skin exists + for name in joint_names: + ob.vertex_groups[name].remove(range(990)) + for v, w in enumerate(skin): + for ii in range(group_per_vertex): + j = argsorted[v, ii] + n = joint_names[j] + ob.vertex_groups[n].add([v], vertex_group_reweight[v, ii], 'REPLACE') + + def to_matrix(x: ndarray): + return Matrix((x[0, :], x[1, :], x[2, :], x[3, :])) + + if asset.matrix_world is None: + matrix_world = to_matrix(np.eye(4)) + else: + matrix_world = to_matrix(asset.matrix_world) + if armature is not None: + bpy.data.objects[armature_name].matrix_world = matrix_world + + # 4. if there is animation, set keyframes + if asset.matrix_basis is not None and asset.matrix_local is not None and armature is not None: + matrix_basis = asset.matrix_basis + matrix_local = asset.matrix_local + objects = bpy.data.objects + for o in bpy.context.selected_objects: + o.select_set(False) + armature = bpy.data.objects[armature_name] + armature.select_set(True) + armature.matrix_world = matrix_world + frames = matrix_basis.shape[0] + + # change matrix_local + bpy.context.view_layer.objects.active = armature + bpy.ops.object.mode_set(mode='EDIT') + for (id, name) in enumerate(joint_names): + # matrix_local of pose bone + bpy.context.active_object.data.edit_bones[id].matrix = to_matrix(matrix_local[id]) + bpy.ops.object.mode_set(mode='OBJECT') + for (id, name) in enumerate(joint_names): + pbone = armature.pose.bones.get(name) + for frame in range(frames): + bpy.context.scene.frame_set(frame + 1) + q = to_matrix(matrix_basis[frame, id]) + if pbone.rotation_mode == "QUATERNION": + pbone.rotation_quaternion = q.to_quaternion() + pbone.keyframe_insert(data_path = 'rotation_quaternion') + else: + pbone.rotation_euler = q.to_euler() + pbone.keyframe_insert(data_path = 'rotation_euler') + pbone.location = q.to_translation() + pbone.keyframe_insert(data_path = 'location') + bpy.ops.object.mode_set(mode='OBJECT') + +def _umeyama_similarity(src: ndarray, tgt: ndarray) -> ndarray: + assert src.shape == tgt.shape + n = src.shape[0] + src_mean = src.mean(axis=0) + tgt_mean = tgt.mean(axis=0) + src_c = src - src_mean + tgt_c = tgt - tgt_mean + + # cross-covariance + C = (src_c.T @ tgt_c) / n + U, S, Vt = np.linalg.svd(C) + R = Vt.T @ U.T + if np.linalg.det(R) < 0: + Vt[-1, :] *= -1 + R = Vt.T @ U.T + var_src = (src_c ** 2).sum() / n + scale = S.sum() / var_src + t = tgt_mean - scale * R @ src_mean + T = np.eye(4) + T[:3, :3] = scale * R + T[:3, 3] = t + return T + +def _pca_similarity( + src: ndarray, + tgt: ndarray, + max_points: 
int=4096, +) -> ndarray: + if src.shape[0] > max_points: + src = src[np.random.choice(src.shape[0], max_points, replace=False)] + if tgt.shape[0] > max_points: + tgt = tgt[np.random.choice(tgt.shape[0], max_points, replace=False)] + src_mean = src.mean(axis=0) + tgt_mean = tgt.mean(axis=0) + src_c = src - src_mean + tgt_c = tgt - tgt_mean + U_src, _, _ = np.linalg.svd(src_c.T @ src_c) + U_tgt, _, _ = np.linalg.svd(tgt_c.T @ tgt_c) + R = U_tgt @ U_src.T + if np.linalg.det(R) < 0: + U_tgt[:, -1] *= -1 + R = U_tgt @ U_src.T + scale = np.sqrt((tgt_c ** 2).sum() / (src_c ** 2).sum()) + t = tgt_mean - scale * R @ src_mean + T = np.eye(4) + T[:3, :3] = scale * R + T[:3, 3] = t + return T + +def estimate_similarity_transform( + src: ndarray, + tgt: ndarray, + max_points: int=4096, +) -> ndarray: + """ + src: (N, 3) + tgt: (M, 3) + return: (4, 4) similarity transform matrix + """ + if src.shape[0] == tgt.shape[0]: + return _umeyama_similarity(src, tgt) + return _pca_similarity(src, tgt, max_points) + +def transfer_rigging( + source_asset: Asset, + target_path: str, + export_path: str, + **kwargs, +): + assert source_asset.matrix_local is not None + assert source_asset.parents is not None + + target_asset = BpyParser.load(filepath=target_path) + bpy.ops.outliner.orphans_purge(do_local_ids=True, do_linked_ids=True, do_recursive=True) + data_types = [ + bpy.data.actions, + bpy.data.armatures, + ] + for data_collection in data_types: + for item in data_collection: + data_collection.remove(item) + + source_vertices = source_asset.vertices # (n, 3) + target_vertices = target_asset.vertices # (m, 3) + assert source_vertices is not None and target_vertices is not None + target_asset.matrix_local = source_asset.matrix_local.copy() + target_asset.matrix_basis = source_asset.matrix_basis.copy() if source_asset.matrix_basis is not None else None + + source_joints = source_asset.joints + assert source_joints is not None + + max_points = kwargs.pop('max_points', 4096) if kwargs.get('max_points') is not None else 4096 + T = estimate_similarity_transform(src=source_vertices, tgt=target_vertices, max_points=max_points) + source_joints_h = np.concatenate([ + source_joints, np.ones((len(source_joints), 1)) + ], axis=1) + target_joints = (T @ source_joints_h.T).T[:, :3] + target_asset.matrix_local[:, :3, 3] = target_joints + target_asset.parents = source_asset.parents.copy() + target_asset.lengths = source_asset.lengths.copy() if source_asset.lengths is not None else None + target_asset.joint_names = source_asset.joint_names.copy() if source_asset.joint_names is not None else None + + if source_asset.skin is not None: + from scipy.spatial import cKDTree + source_skin = source_asset.skin # (n, J) + + source_vertices_h = np.concatenate([ + source_vertices, np.ones((len(source_vertices), 1)) + ], axis=1) + source_vertices = (T @ source_vertices_h.T).T[:, :3] + tree = cKDTree(source_vertices) + dists, idx = tree.query(target_vertices, k=1) + target_asset.skin = source_skin[idx] + + BpyParser.export(target_asset, export_path, use_origin=True, **kwargs) + clean_bpy() \ No newline at end of file diff --git a/src/rig_package/utils.py b/src/rig_package/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..80a89f9fa5c3a3e2a6ef0646071ff638e640f3d1 --- /dev/null +++ b/src/rig_package/utils.py @@ -0,0 +1,312 @@ + +from numpy import ndarray +from typing import Optional, Tuple + +import numpy as np +import scipy + +def assert_ndarray(arr, name: str="arr", shape: Optional[Tuple[int, ...]]=None, dtype=None): + 
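+    # Validates that `arr` is an ndarray with the expected rank, per-axis sizes and
+    # dtype. Non-positive entries in `shape` (e.g. -1) act as wildcards for that axis
+    # (see the `exp > 0` check below); `dtype` is matched with np.issubdtype, so e.g.
+    # np.floating accepts any float precision.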
if not isinstance(arr, np.ndarray): + raise ValueError(f"{name} must be a numpy.ndarray or None, got {type(arr)}") + if shape is not None: + # shape may contain None as wildcard + if len(shape) != arr.ndim: + raise ValueError(f"{name}: expected shape length {len(shape)} but array ndim is {arr.ndim}") + for i, (exp, actual) in enumerate(zip(shape, arr.shape)): + if exp > 0 and exp != actual: + raise ValueError(f"{name} shape mismatch at axis {i}: expected {exp}, got {actual}") + if dtype is not None: + if not np.issubdtype(arr.dtype, dtype): + raise ValueError(f"{name} dtype must be {dtype}, got {arr.dtype}") + +def assert_list(arr, name: str="arr", dtype=None): + if not isinstance(arr, list): + raise ValueError(f"found type {type(arr)}, expect a list") + if dtype is not None: + for x in arr: + if not isinstance(x, dtype): + raise ValueError(f"found type {type(x)} in {name}, expect all to be {dtype}") + +def linear_blend_skinning( + vertices: ndarray, + matrix_local: ndarray, + matrix: ndarray, + skin: ndarray, + pad: int=1, + value: float=1.0, +) -> ndarray: + """ + Args: + vertices: (N, 4-pad) + matrix_local: (J, 4, 4) + matrix: (J, 4, 4) + skin: (N, J) + pad: 0 or 1 + value: value to pad + Returns: + (N, 3) vertices using LBS algorithm: Skinning with dual quaternions, Kavan, 2007 + """ + J = matrix_local.shape[0] + N = vertices.shape[0] + assert_ndarray(vertices, name='vertices', shape=(N, 3)) + assert_ndarray(matrix_local, name="matrix_local", shape=(J, 4, 4)) + assert_ndarray(matrix, name="matrix", shape=(J, 4, 4)) + assert_ndarray(skin, name="skin", shape=(N, J)) + assert vertices.shape[-1] + pad == 4 + # (4, N) + padded = np.pad(vertices, ((0, 0), (0, pad)), 'constant', constant_values=(0, value)).T + # (J, 4, 4) + trans = matrix @ np.linalg.inv(matrix_local) + weighted_per_bone_matrix = [] + # (J, N) + mask = (skin > 0).T + for i in range(J): + offset = np.zeros((4, N), dtype=np.float32) + offset[:, mask[i]] = (trans[i] @ padded[:, mask[i]]) * skin.T[i, mask[i]] + weighted_per_bone_matrix.append(offset) + weighted_per_bone_matrix = np.stack(weighted_per_bone_matrix) + g = np.sum(weighted_per_bone_matrix, axis=0) + final = g[:3, :] / (np.sum(skin, axis=1) + 1e-8) + return final.T + +def axis_angle_to_matrix(axis_angle: ndarray) -> ndarray: + """ + Turn axis angle representation to matrix representation. + """ + res = np.pad(scipy.spatial.transform.Rotation.from_rotvec(axis_angle).as_matrix(), ((0, 0), (0, 1), (0, 1)), 'constant', constant_values=((0, 0), (0, 0), (0, 0))) + assert res.ndim == 3 + res[:, -1, -1] = 1 + return res + +def sample_surface( + num_samples: int, + vertices: ndarray, + faces: ndarray, + mask: Optional[ndarray]=None, +) -> Tuple[ndarray, ndarray, ndarray]: + ''' + Randomly pick samples proportional to face area. + + See sample_surface: https://github.com/mikedh/trimesh/blob/main/trimesh/sample.py + + Args: + mask: (num_faces,), only sample points on the faces where value is True. + Return: + vertex_samples: sampled vertices + + original_face_index: on which face is sampled + + random_lengths: sampled vectors on face + ''' + original_face_indices = np.arange(len(faces)) + # sample according to mask + if mask is not None: + assert_ndarray(arr=mask, name="mask", shape=(faces.shape[0],)) + original_face_indices = original_face_indices[mask] + faces = faces[mask] + + # get face area + offset_0 = vertices[faces[:, 1]] - vertices[faces[:, 0]] + offset_1 = vertices[faces[:, 2]] - vertices[faces[:, 0]] + # TODO: change to correct uniform sampling... 
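+    # The cross-product norm below equals twice each triangle's area, so the
+    # cumulative-sum / searchsorted pick selects faces proportionally to area
+    # (i.e. uniformly over the surface); the commented-out squared-norm variant
+    # further down would over-weight large faces.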
+ face_weight = np.linalg.norm(np.cross(offset_0, offset_1, axis=-1), axis=-1) + + weight_cum = np.cumsum(face_weight, axis=0) + face_pick = np.random.rand(num_samples) * weight_cum[-1] + face_index = np.searchsorted(weight_cum, face_pick) + + # face_weight = np.cross(offset_0, offset_1, axis=-1) + # face_weight = (face_weight * face_weight).sum(axis=1) + + # weight_cum = np.cumsum(face_weight, axis=0) + # face_pick = np.random.rand(num_samples) * weight_cum[-1] + # face_index = np.searchsorted(weight_cum, face_pick) + + # map face_index back to original indices + original_face_index = original_face_indices[face_index] + + # pull triangles into the form of an origin + 2 vectors + tri_origins = vertices[faces[:, 0]] + tri_vectors = vertices[faces[:, 1:]] + tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3)) + + # pull the vectors for the faces we are going to sample from + tri_origins = tri_origins[face_index] + tri_vectors = tri_vectors[face_index] + + # randomly generate two 0-1 scalar components to multiply edge vectors b + random_lengths = np.random.rand(len(tri_vectors), 2, 1) + + random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0 + random_lengths[random_test] -= 1.0 + random_lengths = np.abs(random_lengths) + + sample_vector = (tri_vectors * random_lengths).sum(axis=1) + vertex_samples = sample_vector + tri_origins + return vertex_samples, original_face_index, random_lengths + +def sample_barycentric( + vertex_group: ndarray, + faces: ndarray, + face_index: ndarray, + random_lengths: ndarray, +) -> ndarray: + v_origins = vertex_group[faces[face_index, 0]] + v_vectors = vertex_group[faces[face_index, 1:]] + v_vectors -= v_origins[:, np.newaxis, :] + + sample_vector = (v_vectors * random_lengths).sum(axis=1) + v_samples = sample_vector + v_origins + return v_samples + +def sample_vertex_groups( + vertices: ndarray, + faces: ndarray, + num_samples: int, + num_vertex_samples: Optional[int]=None, + vertex_normals: Optional[ndarray]=None, + face_normals: Optional[ndarray]=None, + vertex_groups: Optional[ndarray]=None, + face_mask: Optional[ndarray]=None, + shuffle: bool=True, + same: bool=False, +) -> Tuple[ndarray, ndarray|None, ndarray|None]: + """ + Choose num_samples samples on the mesh and get their positions and normals. + If vertex_group is provided, get its weights using barycentric sampling. + + Return: + sampled_vertices, sampled_normals, sampled_vertex_groups + + Args: + vertices: (N, 3) + + faces: (F, 3) + + num_samples: how many samples + + num_vertex_samples: + At most num_vertex_samples unique vertices to be included, + these points will be concatenated in the last (if shuffle is False). + + vertex_normals: (N, 3), sampled_normals will be None if not provided + + face_normals: (N, 3), sampled_normals will be None if not provided + + vertex_groups: (N, m), sampled_vertex_groups will be None if not provided + + face_mask: + (F,) or (F, m), if shape is (F,), use the same mask across all + vertex groups. Only sample on faces where value is True. + + shuffle: shuffle samples in the end + + same: + Sample on the same locations, only useful when using mutiple + vertex groups and mask is None or shape of (F,). 
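+
+        Note:
+            With an (N, m) vertex_groups, sampled_vertices is stacked per group with
+            shape (num_samples, m, 3); when `same` takes effect only one shared draw
+            is made, so the stacked axis has size 1. Without vertex_groups the shape
+            is (num_samples, 3). sampled_vertex_groups, when returned, is (num_samples, m).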
+ """ + + if num_vertex_samples is None: + num_vertex_samples = 0 + if num_vertex_samples > num_samples: + raise ValueError(f"num_vertex_samples cannot be larger than num_samples, found: {num_vertex_samples} > {num_samples}") + + def get_mask_perm(mask: Optional[ndarray]): + if mask is None: + vertex_mask = np.arange(vertices.shape[0]) + else: + vertex_mask = np.unique(mask) + perm = np.random.permutation(vertex_mask.shape[0]) + return vertex_mask[perm[:num_vertex_samples]] + + if vertex_groups is not None: + if vertex_groups.ndim == 1: + assert_ndarray(arr=vertex_groups, name="vertex_groups", shape=(vertices.shape[0],)) + vertex_groups = vertex_groups[:, None] + else: + assert_ndarray(arr=vertex_groups, name="vertex_groups", shape=(vertices.shape[0], -1)) + vertex_groups = vertex_groups + + if vertex_groups is not None: + if face_mask is not None: + if face_mask.ndim == 1: + assert_ndarray(arr=face_mask, name="mask", shape=(faces.shape[0],)) + else: + assert_ndarray(arr=face_mask, name="mask", shape=(faces.shape[0], vertex_groups.shape[1])) + list_sampled_vertices = [] + list_sampled_normals = [] + list_sampled_vertex_groups = [] + perm = None + _mask = None + same = same and (face_mask is None or (face_mask is not None and face_mask.ndim != 2)) + for i in range(vertex_groups.shape[1]): + if face_mask is not None: + if face_mask.ndim == 1: + perm = get_mask_perm(faces[face_mask]) + _mask = face_mask + else: + perm = get_mask_perm(faces[face_mask[:, i]]) + _mask = face_mask[:, i] + else: + perm = get_mask_perm(None) + _mask = None + _num_samples = num_samples - len(perm) + + face_vertices, face_index, random_lengths = sample_surface( + num_samples=_num_samples, + vertices=vertices, + faces=faces, + mask=_mask, + ) + + list_sampled_vertices.append(np.concatenate([vertices[perm], face_vertices], axis=0)) + if vertex_normals is not None and face_normals is not None: + list_sampled_normals.append(np.concatenate([vertex_normals[perm], face_normals[face_index]], axis=0)) + + if same: + g = sample_barycentric( + vertex_group=vertex_groups, + faces=faces, + face_index=face_index, + random_lengths=random_lengths, + ) + list_sampled_vertex_groups.append(np.concatenate([vertex_groups[perm], g], axis=0)) + break + g = sample_barycentric( + vertex_group=vertex_groups[:, i:i+1], + faces=faces, + face_index=face_index, + random_lengths=random_lengths, + )[:, 0] + list_sampled_vertex_groups.append(np.concatenate([vertex_groups[:, i][perm], g], axis=0)) + sampled_vertices = np.stack(list_sampled_vertices, axis=1) + if len(list_sampled_normals) > 0: + sampled_normals = np.stack(list_sampled_normals, axis=1) + else: + sampled_normals = None + if same: + sampled_vertex_groups = list_sampled_vertex_groups[0] + else: + sampled_vertex_groups = np.stack(list_sampled_vertex_groups, axis=1) + else: # otherwise only sample vertices and normals + if face_mask is not None: + assert_ndarray(arr=face_mask, name="mask", shape=(faces.shape[0],)) + perm = get_mask_perm(faces[face_mask]) + else: + perm = get_mask_perm(None) + num_samples -= len(perm) + n_vertex = vertices[perm] + face_vertices, face_index, random_lengths = sample_surface( + num_samples=num_samples, + vertices=vertices, + faces=faces, + mask=face_mask, + ) + sampled_vertices = np.concatenate([n_vertex, face_vertices], axis=0) + if vertex_normals is not None and face_normals is not None: + sampled_normals = np.concatenate([vertex_normals[perm], face_normals[face_index]], axis=0) + else: + sampled_normals = None + sampled_vertex_groups = None + + return 
sampled_vertices, sampled_normals, sampled_vertex_groups \ No newline at end of file diff --git a/src/server/__init__.py b/src/server/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/server/bpy_server.py b/src/server/bpy_server.py new file mode 100755 index 0000000000000000000000000000000000000000..20c7541d9c192f366477685d5647d724381603dd --- /dev/null +++ b/src/server/bpy_server.py @@ -0,0 +1,68 @@ +from bottle import request, response + +import bottle +import queue +import threading + +from .spec import bytes_to_object, object_to_bytes, BPY_PORT + +from ..rig_package.parser.bpy import BpyParser, transfer_rigging + +def run(): + path_queue = queue.Queue() + result_queue = queue.Queue() + + app = bottle.Bottle() + + @app.route('/load', method='GET') # type: ignore + def load(): + data = request.body.read() # type: ignore + path_queue.put(('load', data)) + res = result_queue.get() + payload = object_to_bytes(res) + response.content_type = 'application/octet-stream' # type: ignore + return payload + + @app.route('/ping', method='GET') # type: ignore + def ping(): + return 'pong' + + @app.route('/export', method='post') # type: ignore + def export(): + data = request.body.read() # type: ignore + path_queue.put(('export', data)) + res = result_queue.get() + payload = object_to_bytes(res) + response.content_type = 'application/octet-stream' # type: ignore + return payload + + @app.route('/transfer', method='post') # type: ignore + def transfer(): + data = request.body.read() # type: ignore + path_queue.put(('transfer', data)) + res = result_queue.get() + payload = object_to_bytes(res) + response.content_type = 'application/octet-stream' # type: ignore + return payload + + def run_server(): bottle.run(app, host='0.0.0.0', port=BPY_PORT, server='tornado') + threading.Thread(target=run_server, daemon=False).start() + + while True: + d = path_queue.get() + op = d[0] + data = bytes_to_object(d[1]) + if op == 'load': + print("[SERVER] received load path:", data) + asset = BpyParser.load(data) + result_queue.put(asset) + elif op == 'export': + print("[SERVER] received export path:", data['filepath']) + BpyParser.export(**data) + result_queue.put('ok') + elif op == 'transfer': + print("[SERVER] received transfer path:", data['target_path']) + transfer_rigging(**data) + result_queue.put('ok') + else: + result_queue.put(f"unsupported op: {str(op)}") \ No newline at end of file diff --git a/src/server/spec.py b/src/server/spec.py new file mode 100755 index 0000000000000000000000000000000000000000..f2ca6d72eb17b346e23f24038216eb103f370fb3 --- /dev/null +++ b/src/server/spec.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass +from torch import Tensor +from typing import Dict, Optional, List, Tuple + +import io +import os +import torch + +from ..rig_package.info.asset import Asset +from ..model.tokenrig import TokenRig + +PORT = 59875 +SERVER = f"http://localhost:{PORT}" +TMP_CKPT_DIR = "./tmp_ckpt" + +BPY_PORT = 59876 +BPY_SERVER = f"http://localhost:{BPY_PORT}" + +@dataclass +class TensorPacket: + """make sure stays on cpu""" + validate: bool=False + know_skeleton: bool=False + learned_mesh_cond: Optional[Tensor]=None + cond_latents: Optional[Tensor]=None + mesh_cond: Optional[Tensor]=None + vertices: Optional[Tensor]=None + assets: Optional[List[Asset]]=None + output_ids: Optional[Tensor]=None + start_embed_list: Optional[List[Tensor]]=None + start_tokens_list: Optional[List[List[int]]]=None + + def 
to_device(self, device): + if self.learned_mesh_cond is not None: + self.learned_mesh_cond = self.learned_mesh_cond.to(device) + if self.cond_latents is not None: + self.cond_latents = self.cond_latents.to(device) + if self.mesh_cond is not None: + self.mesh_cond = self.mesh_cond.to(device) + if self.vertices is not None: + self.vertices = self.vertices.to(device) + if self.output_ids is not None: + self.output_ids = self.output_ids.to(device) + if self.start_embed_list is not None: + self.start_embed_list = [x.to(device) for x in self.start_embed_list] + + @property + def B(self): + assert self.learned_mesh_cond is not None + return self.learned_mesh_cond.shape[0] + + def to_bytes(self): + return object_to_bytes(self) + + @classmethod + def from_bytes(cls, bytes) -> 'TensorPacket': + return bytes_to_object(bytes) + + +def object_to_bytes(t): + buffer = io.BytesIO() + torch.save(t, buffer) + return buffer.getvalue() + +def bytes_to_object(b, map_location=None): + return torch.load(io.BytesIO(b), weights_only=False, map_location=map_location) + +def get_model( + ckpt_path: str, + hf_path: Optional[str]=None, + device='cuda', +) -> TokenRig: + model = TokenRig.load_from_system_checkpoint(checkpoint_path=ckpt_path) + if hf_path is not None: + from transformers import AutoModel + a = AutoModel.from_pretrained( + hf_path, + local_files_only=True, + _attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + ) + model.transformer.model.load_state_dict(a.state_dict()) + + model = model.to(device) + return model diff --git a/src/tokenizer/__init__.py b/src/tokenizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/tokenizer/parse.py b/src/tokenizer/parse.py new file mode 100755 index 0000000000000000000000000000000000000000..c07d958d14fc904022351b0288eeed9d85c2baca --- /dev/null +++ b/src/tokenizer/parse.py @@ -0,0 +1,13 @@ +from copy import deepcopy + +from .spec import Tokenizer +from .tokenizer_part import TokenizerPart + +def get_tokenizer(**kwargs) -> Tokenizer: + __target__ = kwargs.get('__target__') + assert __target__ is not None, "do not find `__target__` in tokenizer config" + del kwargs['__target__'] + MAP = { + 'tokenizer_part': TokenizerPart, + } + return MAP[__target__].parse(**deepcopy(kwargs)) diff --git a/src/tokenizer/spec.py b/src/tokenizer/spec.py new file mode 100755 index 0000000000000000000000000000000000000000..d4cc653519937bb98d6059d765686e2ea53b89dc --- /dev/null +++ b/src/tokenizer/spec.py @@ -0,0 +1,267 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Dict + +import numpy as np +from numpy import ndarray + +from typing import Union, List, Tuple, Optional +from dataclasses import dataclass + +@dataclass +class TokenizeInput(): + # (J, 3) + joints: ndarray + + # (J) + parents: List[Union[None, int]] + + # string of class in tokenizer + cls: Optional[str]=None + + joint_names: Optional[List[str]]=None + + @property + def J(self) -> int: + return self.joints.shape[0] + + @property + def branch(self) -> ndarray: + if not hasattr(self, '_branch'): + branch = [] + last = None + for i in range(self.J): + if i == 0: + branch.append(False) + else: + pid = self.parents[i] + branch.append(pid!=last) + last = i + self._branch = np.array(branch, dtype=bool) + return self._branch + + @property + def bones(self): + _p = self.parents.copy() + _p[0] = 0 + return np.concatenate([self.joints[_p], self.joints], axis=1) + + @property + def 
num_bones(self): + return self.bones.shape[0] + +@dataclass +class DetokenizeOutput(): + # original tokens + tokens: ndarray + + # (J, 6), (parent position, position) + bones: ndarray + + # (J), parent of each bone + parents: List[int] + + # string of class in tokenizer + cls: Optional[str]=None + + # names of joints + joint_names: Optional[List[str]]=None + + continuous_range: Optional[Tuple[float, float]]=None + + @property + def joints(self): + return self.bones[:, 3:] + + @property + def p_joints(self): + return self.bones[:, :3] + + @property + def num_bones(self): + return self.bones.shape[0] + + @property + def J(self): + return self.bones.shape[0] + + def _get_parents(self) -> List[int]: + parents = [] + for (i, bone) in enumerate(self.bones): + p_joint = bone[:3] + dis = 999999 + pid = -1 + for j in reversed(range(i)): + n_dis = ((self.bones[j][3:] - p_joint)**2).sum() + if n_dis < dis: + pid = j + dis = n_dis + parents.append(pid) + return parents + +class Tokenizer(ABC): + """ + Abstract class for tokenizer + """ + + @classmethod + @abstractmethod + def parse(cls, **kwags) -> 'Tokenizer': + pass + + @abstractmethod + def tokenize(self, input: TokenizeInput) -> ndarray: + pass + + @abstractmethod + def detokenize(self, ids: ndarray, **kwargs) -> DetokenizeOutput: + pass + + @property + @abstractmethod + def vocab_size(self) -> int: + """The vocabulary size""" + raise NotImplementedError() + + @property + def pad(self): + raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__)) + + @property + def bos(self): + raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__)) + + @property + def eos(self): + raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__)) + + def cls_name_to_token(self, cls: str) -> int: + raise NotImplementedError() + + def next_posible_token(self, ids: ndarray) -> List[int]: + raise NotImplementedError() + + def bones_in_sequence(self, ids: ndarray) -> int: + raise NotImplementedError() + + def make_cls_head(self, **kwargs) -> List[int]: + raise NotImplementedError() + +def make_skeleton( + joints: ndarray, + p_joints: ndarray, + tails_dict: Dict[int, ndarray], + convert_leaf_bones_to_tails: bool, + extrude_tail_for_leaf: bool, + extrude_tail_for_branch: bool, + extrude_scale: float=0.5, + strict: bool=False, +) -> Tuple[ndarray, ndarray, List[int], List[int]]: + ''' + Args: + joints: heads of bones + + p_joints: parent position of joints + + tails_dict: tail position of the i-th joint + + convert_leaf_bones_to_tails: remove leaf bones and make them tails of their parents + + extrude_tail_for_leaf: add a tail for leaf bone + + extrude_tail_for_branch: add a tail for joint with multiple children + + extrude_scale: length scale of tail offset + + strict: if true, raise error when there are joints in the same location + + Returns: + bones, tails, available_bones_id, parents + ''' + assert (convert_leaf_bones_to_tails & extrude_tail_for_leaf)==False, 'cannot extrude tail for leaf when convert_leaf_bones_to_tails is True' + assert joints.shape[0] == p_joints.shape[0] + # build parents + bones = [] # (parent_position, position) + parents = [] + for (i, joint) in enumerate(joints): + if len(bones) == 0: + bones.append(np.concatenate([joint, joint])) # root + parents.append(-1) + continue + p_joint = p_joints[i] + dis = 999999 + pid = None + for j in reversed(range(i)): + n_dis = ((bones[j][3:] - p_joint)**2).sum() + if n_dis < dis: + pid = j + dis = n_dis + 
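+            # pid now indexes the previously-emitted joint closest to the decoded
+            # parent position; the current joint is attached to it as a bone below.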
bones.append(np.concatenate([joints[pid], joint])) + parents.append(pid) + bones = np.stack(bones) + + children = defaultdict(list) + for (i, pid) in enumerate(parents): + if pid == -1: + continue + children[pid].append(i) + + available_bones_id = [] + if convert_leaf_bones_to_tails: + for (i, pid) in enumerate(parents): + if len(children[i]) != 0: + available_bones_id.append(i) + continue + tails_dict[pid] = bones[i, 3:] + else: + available_bones_id = [i for i in range(bones.shape[0])] + + # tail for leaf + for (i, pid) in enumerate(parents): + if len(children[i]) != 0: + continue + if extrude_tail_for_leaf: + d = bones[i, 3:] - bones[pid, 3:] + length = np.linalg.norm(d) + if strict: + assert length > 1e-9, 'two joints in the same point found' + elif length <= 1e-9: + d = np.array([0., 0., 1.]) + tails_dict[i] = bones[i, 3:] + d * extrude_scale + else: + tails_dict[i] = bones[i, 3:] + + # tail for branch + for (i, pid) in enumerate(parents): + if len(children[i]) <= 1: + continue + if extrude_tail_for_branch: + if pid == -1: # root + av_len = 0 + for child in children[i]: + av_len += np.linalg.norm(bones[i, 3:] - bones[child, 3:]) + av_len /= len(children[i]) + d = bones[i, 3:] + np.array([0., 0., extrude_scale * av_len]) + else: + d = bones[i, 3:] - bones[pid, 3:] + length = np.linalg.norm(d) + if strict: + assert length > 1e-9, 'two joints in the same point found' + elif length <= 1e-9: + d = np.array([0., 0., 1.]) + tails_dict[i] = bones[i, 3:] + d * extrude_scale + else: + tails_dict[i] = bones[i, 3:] + + # assign new tail + for (i, pid) in enumerate(parents): + if len(children[i]) != 1: + continue + child = children[i][0] + tails_dict[i] = bones[child, 3:] + + tails = [] + for i in range(bones.shape[0]): + tails.append(tails_dict[i]) + tails = np.stack(tails) + return bones, tails, available_bones_id, parents \ No newline at end of file diff --git a/src/tokenizer/tokenizer_part.py b/src/tokenizer/tokenizer_part.py new file mode 100755 index 0000000000000000000000000000000000000000..5e91e3c37be61653dec2be3263bfd2b53f90e5fc --- /dev/null +++ b/src/tokenizer/tokenizer_part.py @@ -0,0 +1,410 @@ +from dataclasses import dataclass, field +from numpy import ndarray +from typing import Dict, Tuple, Union, List, Optional + +import numpy as np + +from .spec import Tokenizer, TokenizeInput, DetokenizeOutput +from .spec import make_skeleton +from ..data.order import Order + +@dataclass +class TokenizerPart(Tokenizer): + + # cls token id + cls_token_id: Dict[str, int] + + # parts token id + parts_token_id: Dict[str, int] + part_token_to_name: Dict[int, str] + cls_token_to_name: Dict[int, str] + + parts_token_id_name: List[str] + + # normalization range + continuous_range: Tuple[float, float] + + # coordinate discrete + num_discrete: int + + token_id_branch: int + token_id_bos: int + token_id_eos: int + token_id_pad: int + token_id_spring: int + token_id_cls_none: int + + _vocab_size: int + + order: Optional[Order]=None + + @classmethod + def parse( + cls, + **kwargs, + ): + num_discrete = kwargs.pop('num_discrete') + continuous_range = kwargs.pop('continuous_range') + cls_token_id = kwargs.pop('cls_token_id') + parts_token_id = kwargs.pop('parts_token_id') + order = kwargs.get('order') + if order is not None: + assert isinstance(order, Order) + _offset = num_discrete + + token_id_branch = _offset + 0 + token_id_bos = _offset + 1 + token_id_eos = _offset + 2 + token_id_pad = _offset + 3 + _offset += 4 + + token_id_spring = _offset + 0 + _offset += 1 + + assert None not in parts_token_id + for i 
in parts_token_id: + parts_token_id[i] += _offset + _offset += len(parts_token_id) + + token_id_cls_none = _offset + 0 + _offset += 1 + + for i in cls_token_id: + cls_token_id[i] += _offset + _offset += len(cls_token_id) + + _vocab_size = _offset + + parts_token_id_name = [x for x in parts_token_id] + + part_token_to_name = {v: k for k, v in parts_token_id.items()} + assert len(part_token_to_name) == len(parts_token_id), 'names with same token found in parts_token_id' + part_token_to_name[token_id_spring] = None + + cls_token_to_name = {v: k for k, v in cls_token_id.items()} + assert len(cls_token_to_name) == len(cls_token_id), 'names with same token found in cls_token_id' + return TokenizerPart( + num_discrete=num_discrete, + continuous_range=continuous_range, + cls_token_id=cls_token_id, + parts_token_id=parts_token_id, + order=order, + token_id_branch=token_id_branch, + token_id_bos=token_id_bos, + token_id_eos=token_id_eos, + token_id_pad=token_id_pad, + token_id_spring=token_id_spring, + token_id_cls_none=token_id_cls_none, + parts_token_id_name=parts_token_id_name, + part_token_to_name=part_token_to_name, + cls_token_to_name=cls_token_to_name, + _vocab_size=_vocab_size, + ) + + def make_cls_head(self, **kwargs) -> List[int]: + cls = kwargs.get('cls', None) + if cls is not None: + return [self.cls_name_to_token(cls=cls)] + return [self.token_id_cls_none] + + def cls_name_to_token(self, cls: str) -> int: + if cls not in self.cls_token_id: + return self.token_id_cls_none + return self.cls_token_id[cls] + + def part_name_to_token(self, part: str) -> int: + assert part in self.parts_token_id, f"do not find part name `{part}` in tokenizer" + return self.parts_token_id[part] + + def next_posible_token(self, ids: ndarray) -> List[int]: + if ids.shape[0] == 0 or ids.ndim == 0: + return [self.token_id_bos] + assert ids.ndim == 1, "expect an array" + state = 'expect_bos' + for id in ids: + if state == 'expect_bos': + assert id == self.token_id_bos, 'ids do not start with bos' + state = 'expect_cls_or_part_or_joint' + elif state == 'expect_cls_or_part_or_joint': + if id < self.num_discrete: + state = 'expect_joint_2' + elif id == self.token_id_cls_none or id in self.cls_token_id.values(): + state = 'expect_part_or_joint' + else: # a part + state = 'expect_joint' + elif state == 'expect_part_or_joint': + if id < self.num_discrete: + state = 'expect_joint_2' + else: + state = 'expect_part_or_joint' + elif state == 'expect_joint_2': + state = 'expect_joint_3' + elif state == 'expect_joint_3': + state = 'expect_branch_or_part_or_joint' + elif state == 'expect_branch_or_part_or_joint': + if id == self.token_id_branch: + state = 'expect_joint' + elif id < self.num_discrete: + state = 'expect_joint_2' + else: # find a part + state = 'expect_joint' + elif state == 'expect_joint': + state = 'expect_joint_2' + else: + assert 0, state + s = [] + def add_cls(): + s.append(self.token_id_cls_none) + for v in self.cls_token_id.values(): + s.append(v) + def add_part(): + s.append(self.token_id_spring) + for v in self.parts_token_id.values(): + s.append(v) + def add_joint(): + for i in range(self.num_discrete): + s.append(i) + def add_branch(): + s.append(self.token_id_branch) + def add_eos(): + s.append(self.token_id_eos) + def add_bos(): + s.append(self.token_id_bos) + if state == 'expect_bos': + add_bos() + elif state == 'expect_cls_or_part_or_joint': + add_cls() + add_part() + add_joint() + elif state == 'expect_cls': + add_cls() + elif state == 'expect_part_or_joint': + add_part() + add_joint() + add_eos() 
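+        # Remaining states: inside a coordinate triple only discretized coordinate
+        # tokens are legal; once the third coordinate of a bone has been read, a
+        # branch marker, a part token, another joint, or eos may follow.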
+ elif state == 'expect_joint_2': + add_joint() + elif state == 'expect_joint_3': + add_joint() + elif state == 'expect_branch_or_part_or_joint': + add_joint() + add_part() + add_branch() + add_eos() + elif state == 'expect_joint': + add_joint() + else: + assert 0, state + return s + + def bones_in_sequence(self, ids: ndarray): + assert ids.ndim == 1, "expect an array" + s = 0 + is_branch = False + state = 'expect_bos' + for id in ids: + if state == 'expect_bos': + assert id == self.token_id_bos, 'ids do not start with bos' + state = 'expect_cls_or_part_or_joint' + elif state == 'expect_cls_or_part_or_joint': + if id < self.num_discrete: + state = 'expect_joint_2' + elif id == self.token_id_cls_none or id in self.cls_token_id.values(): + state = 'expect_part_or_joint' + else: # a part + state = 'expect_joint' + elif state == 'expect_part_or_joint': + if id < self.num_discrete: + state = 'expect_joint_2' + else: + state = 'expect_part_or_joint' + elif state == 'expect_joint_2': + state = 'expect_joint_3' + elif state == 'expect_joint_3': + if not is_branch: + s += 1 + is_branch = False + state = 'expect_branch_or_part_or_joint' + elif state == 'expect_branch_or_part_or_joint': + if id == self.token_id_branch: + state = 'expect_joint' + is_branch = True + elif id < self.num_discrete: + state = 'expect_joint_2' + else: # find a part + state = 'expect_joint' + elif state == 'expect_joint': + state = 'expect_joint_2' + else: + assert 0, state + if id == self.token_id_eos: + break + return s + + def tokenize(self, input: TokenizeInput) -> ndarray: + num_bones = input.num_bones + bones = discretize(t=input.bones, continuous_range=self.continuous_range, num_discrete=self.num_discrete) + + branch = input.branch + + tokens = [self.token_id_bos] + if input.cls is None or input.cls not in self.cls_token_id: + tokens.append(self.token_id_cls_none) + else: + tokens.append(self.cls_token_id[input.cls]) + if self.order is not None and input.cls is not None and input.joint_names is not None: + _, parts_bias = self.order.arrange_names(cls=input.cls, names=input.joint_names, parents=input.parents) + else: + parts_bias = [] + for i in range(num_bones): + # add parts token id + if i in parts_bias: + part = parts_bias[i] + if part is None: + tokens.append(self.token_id_spring) + else: + assert part in self.parts_token_id, f"do not find part name {part} in tokenizer {self.__class__}" + tokens.append(self.parts_token_id[part]) + if branch[i]: + tokens.append(self.token_id_branch) + tokens.append(bones[i, 0]) + tokens.append(bones[i, 1]) + tokens.append(bones[i, 2]) + tokens.append(bones[i, 3]) + tokens.append(bones[i, 4]) + tokens.append(bones[i, 5]) + else: + tokens.append(bones[i, 3]) + tokens.append(bones[i, 4]) + tokens.append(bones[i, 5]) + tokens.append(self.token_id_eos) + return np.array(tokens, dtype=np.int64) + + def detokenize(self, ids: ndarray, **kwargs) -> DetokenizeOutput: + assert isinstance(ids, ndarray), 'expect ids to be ndarray' + if ids[0] != self.token_id_bos: + raise ValueError(f"first token is not bos") + trailing_pad = 0 + while trailing_pad < ids.shape[0] and ids[-trailing_pad-1] == self.token_id_pad: + trailing_pad += 1 + if ids[-1-trailing_pad] != self.token_id_eos: + raise ValueError(f"last token is not eos") + ids = ids[1:-1-trailing_pad] + joints = [] + p_joints = [] + tails_dict = {} + parts = [] + i = 0 + is_branch = False + last_joint = None + num_bones = 0 + cls = None + while i < len(ids): + if ids[i] < self.num_discrete: + if is_branch: + p_joint = undiscretize(t=ids[i:i+3], 
continuous_range=self.continuous_range, num_discrete=self.num_discrete) + current_joint = undiscretize(t=ids[i+3:i+6], continuous_range=self.continuous_range, num_discrete=self.num_discrete) + joints.append(current_joint) + p_joints.append(p_joint) + i += 6 + else: + current_joint = undiscretize(t=ids[i:i+3], continuous_range=self.continuous_range, num_discrete=self.num_discrete) + joints.append(current_joint) + if len(p_joints) == 0: # root + p_joints.append(current_joint) + p_joint = current_joint + else: + assert last_joint is not None + p_joints.append(last_joint) + p_joint = last_joint + i += 3 + if last_joint is not None: + tails_dict[num_bones-1] = current_joint + last_joint = current_joint + num_bones += 1 + is_branch = False + elif ids[i]==self.token_id_branch: + is_branch = True + last_joint = None + i += 1 + elif ids[i]==self.token_id_spring or ids[i] in self.parts_token_id.values(): + parts.append(self.part_token_to_name[ids[i]]) + i += 1 + elif ids[i] in self.cls_token_id.values(): + cls = ids[i] + i += 1 + elif ids[i] == self.token_id_cls_none: + cls = None + i += 1 + else: + raise ValueError(f"unexpected token found: {ids[i]}") + joints = np.stack(joints) + p_joints = np.stack(p_joints) + # leaf is ignored in this tokenizer so need to extrude tails for leaf and branch + bones, tails, available_bones_id, parents = make_skeleton( + joints=joints, + p_joints=p_joints, + tails_dict=tails_dict, + convert_leaf_bones_to_tails=False, + extrude_tail_for_leaf=True, + extrude_tail_for_branch=True, + ) + bones = bones[available_bones_id] + tails = tails[available_bones_id] + if cls in self.cls_token_to_name: + cls = self.cls_token_to_name[cls] + else: + cls = None + if self.order is not None: + joint_names = self.order.make_names(cls=cls, parts=parts, num_bones=num_bones) + else: + joint_names = [f"bone_{i}" for i in range(num_bones)] + return DetokenizeOutput( + tokens=ids, + bones=bones, + parents=parents, + cls=cls, + joint_names=joint_names, + continuous_range=self.continuous_range, + ) + + def get_require_parts(self) -> List[str]: + return self.parts_token_id_name + + @property + def vocab_size(self): + return self._vocab_size + + @property + def pad(self): + return self.token_id_pad + + @property + def bos(self): + return self.token_id_bos + + @property + def eos(self): + return self.token_id_eos + +def discretize( + t: ndarray, + continuous_range: Tuple[float, float], + num_discrete: int, +) -> ndarray: + lo, hi = continuous_range + assert hi >= lo + t = (t - lo) / (hi - lo) + t *= num_discrete + return np.clip(t.round(), 0, num_discrete - 1).astype(np.int64) + +def undiscretize( + t: ndarray, + continuous_range: Tuple[float, float], + num_discrete: int, +) -> ndarray: + lo, hi = continuous_range + assert hi >= lo + t = t.astype(np.float32) + 0.5 + t /= num_discrete + return t * (hi - lo) + lo
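+
+# Example round trip with illustrative values (assumed range (-1.0, 1.0), 256 bins):
+# undiscretize returns lo + (id + 0.5) / num_discrete * (hi - lo), so for in-range
+# inputs the reconstruction error is at most one bin width, (hi - lo) / num_discrete.
+#
+#     discretize(np.array([0.1234]), continuous_range=(-1.0, 1.0), num_discrete=256)
+#     # -> array([144])
+#     undiscretize(np.array([144]), continuous_range=(-1.0, 1.0), num_discrete=256)
+#     # -> approximately array([0.1289])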