Spaces:
Running
on
Zero
Running
on
Zero
daidedou
commited on
Commit
·
458efe2
1
Parent(s):
cf4ac70
first_try
Browse files- app.py +272 -0
- notebook_helpers.py +88 -0
- requirements.txt +45 -0
- shape_models/__init__.py +3 -0
- shape_models/encoder.py +24 -0
- shape_models/fmap.py +186 -0
- shape_models/geometry.py +897 -0
- shape_models/layers.py +453 -0
- shape_models/utils.py +123 -0
- utils/__init__.py +3 -0
- utils/descriptors.py +250 -0
- utils/eval.py +25 -0
- utils/fmap.py +121 -0
- utils/geometry.py +951 -0
- utils/io.py +78 -0
- utils/layers.py +430 -0
- utils/mesh.py +214 -0
- utils/meshplot.py +67 -0
- utils/misc.py +122 -0
- utils/pickle_stuff.py +38 -0
- utils/surfaces.py +1377 -0
- utils/torch_fmap.py +77 -0
- utils/utils_func.py +123 -0
- utils/utils_legacy.py +130 -0
- zero_shot.py +402 -0
app.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple Gradio app for two-mesh initialization and run phases.
|
| 3 |
+
- Upload two meshes (.ply, .obj, .off)
|
| 4 |
+
- (Optional) upload a YAML config to override defaults
|
| 5 |
+
- Adjust a few numeric settings (sane ranges). Defaults pulled from the provided YAML when present.
|
| 6 |
+
- Click **Init** to generate "initialization maps" (here: position/normal-based vertex colors) for both meshes.
|
| 7 |
+
- Click **Run** to simulate an iterative evolution with a progress bar, then output another pair of colored meshes.
|
| 8 |
+
|
| 9 |
+
Replace the bodies of `make_initialization_maps` and `run_evolution` with your real pipeline as needed.
|
| 10 |
+
|
| 11 |
+
Tested with: gradio >= 4.0, trimesh, pyyaml, numpy.
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
import os
|
| 15 |
+
import io
|
| 16 |
+
import time
|
| 17 |
+
import json
|
| 18 |
+
import tempfile
|
| 19 |
+
from typing import Dict, Tuple, Optional
|
| 20 |
+
from omegaconf import OmegaConf
|
| 21 |
+
import gradio as gr
|
| 22 |
+
import numpy as np
|
| 23 |
+
import trimesh
|
| 24 |
+
import zero_shot
|
| 25 |
+
import yaml
|
| 26 |
+
from utils.surfaces import Surface
|
| 27 |
+
import notebook_helpers as helper
|
| 28 |
+
from utils.meshplot import visu_pts
|
| 29 |
+
from utils.torch_fmap import extract_p2p_torch_fmap, torch_zoomout
|
| 30 |
+
import torch
|
| 31 |
+
import argparse
|
| 32 |
+
# -----------------------------
|
| 33 |
+
# Utils
|
| 34 |
+
# -----------------------------
|
| 35 |
+
SUPPORTED_EXTS = {".ply", ".obj", ".off", ".stl", ".glb", ".gltf"}
|
| 36 |
+
|
| 37 |
+
def _safe_ext(path: str) -> str:
|
| 38 |
+
for ext in SUPPORTED_EXTS:
|
| 39 |
+
if path.lower().endswith(ext):
|
| 40 |
+
return ext
|
| 41 |
+
return os.path.splitext(path)[1].lower()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def normalize_vertices(vertices: np.ndarray) -> np.ndarray:
|
| 45 |
+
v = vertices.astype(np.float64)
|
| 46 |
+
v = v - v.mean(axis=0, keepdims=True)
|
| 47 |
+
scale = np.linalg.norm(v, axis=1).max()
|
| 48 |
+
if scale == 0:
|
| 49 |
+
scale = 1.0
|
| 50 |
+
v = v / scale
|
| 51 |
+
return v.astype(np.float32)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def ensure_vertex_colors(mesh: trimesh.Trimesh, colors: np.ndarray) -> trimesh.Trimesh:
|
| 56 |
+
out = mesh.copy()
|
| 57 |
+
if colors.shape[1] == 3:
|
| 58 |
+
rgba = np.concatenate([colors, 255*np.ones((colors.shape[0],1), dtype=np.uint8)], axis=1)
|
| 59 |
+
else:
|
| 60 |
+
rgba = colors
|
| 61 |
+
out.visual.vertex_colors = rgba
|
| 62 |
+
return out
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def export_for_view(surf: Surface, colors: np.ndarray, basename: str, outdir: str) -> Tuple[str, str]:
|
| 66 |
+
"""Export to PLY (with vertex colors) and GLB for Model3D preview."""
|
| 67 |
+
glb_path = os.path.join(outdir, f"{basename}.glb")
|
| 68 |
+
mesh = trimesh.Trimesh(surf.vertices, surf.faces, process=False)
|
| 69 |
+
colored_mesh = ensure_vertex_colors(mesh, colors)
|
| 70 |
+
colored_mesh.export(glb_path)
|
| 71 |
+
return glb_path
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# -----------------------------
|
| 75 |
+
# Algorithm placeholders (replace with your real pipeline)
|
| 76 |
+
# -----------------------------
|
| 77 |
+
DEFAULT_SETTINGS = {
|
| 78 |
+
"deepfeat_conf.fmap.lambda_": 1,
|
| 79 |
+
"sds_conf.zoomout": 40.0,
|
| 80 |
+
"diffusion.time": 1.0,
|
| 81 |
+
"opt.n_loop": 300,
|
| 82 |
+
"loss.sds": 1.0,
|
| 83 |
+
"loss.proper": 1.0,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
FLOAT_SLIDERS = {
|
| 87 |
+
# name: (min, max, step)
|
| 88 |
+
"deepfeat_conf.fmap.lambda_": (1e-3, 10.0, 1e-3),
|
| 89 |
+
"sds_conf.zoomout": (1e-3, 10.0, 1e-3),
|
| 90 |
+
"diffusion.time": (1e-3, 10.0, 1e-3),
|
| 91 |
+
"loss.sds": (1e-3, 10.0, 1e-3),
|
| 92 |
+
"loss.proper": (1e-3, 10.0, 1e-3),
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
INT_SLIDERS = {
|
| 96 |
+
"opt.n_loop": (1, 5000, 1),
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def flatten_yaml_floats(d: Dict, prefix: str = "") -> Dict[str, float]:
|
| 101 |
+
flat = {}
|
| 102 |
+
for k, v in d.items():
|
| 103 |
+
key = f"{prefix}.{k}" if prefix else str(k)
|
| 104 |
+
if isinstance(v, dict):
|
| 105 |
+
flat.update(flatten_yaml_floats(v, key))
|
| 106 |
+
elif isinstance(v, (int, float)):
|
| 107 |
+
flat[key] = float(v)
|
| 108 |
+
return flat
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def read_yaml_defaults(yaml_path: Optional[str]) -> Dict[str, float]:
|
| 112 |
+
if yaml_path and os.path.exists(yaml_path):
|
| 113 |
+
with open(yaml_path, "r") as f:
|
| 114 |
+
data = yaml.safe_load(f)
|
| 115 |
+
flat = flatten_yaml_floats(data)
|
| 116 |
+
# Only keep known keys we expose as controls
|
| 117 |
+
defaults = DEFAULT_SETTINGS.copy()
|
| 118 |
+
for k in list(DEFAULT_SETTINGS.keys()):
|
| 119 |
+
if k in flat:
|
| 120 |
+
defaults[k] = float(flat[k])
|
| 121 |
+
return defaults
|
| 122 |
+
return DEFAULT_SETTINGS.copy()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class Datadicts:
|
| 128 |
+
def __init__(self, shape_path, target_path):
|
| 129 |
+
self.shape_path = shape_path
|
| 130 |
+
basename_1 = os.path.basename(shape_path)
|
| 131 |
+
self.shape_dict, _ = helper.load_data(shape_path, "tmp/" + os.path.splitext(basename_1)[0]+".npz", "source", make_cache=True)
|
| 132 |
+
self.shape_surf = Surface(filename=shape_path)
|
| 133 |
+
self.target_path = target_path
|
| 134 |
+
basename_2 = os.path.basename(target_path)
|
| 135 |
+
self.target_dict, _ = helper.load_data(target_path, "tmp/" + os.path.splitext(basename_2)[0]+".npz", "target", make_cache=True)
|
| 136 |
+
self.target_surf = Surface(filename=target_path)
|
| 137 |
+
self.cmap1 = visu_pts(self.shape_surf)
|
| 138 |
+
|
| 139 |
+
# -----------------------------
|
| 140 |
+
# Gradio UI
|
| 141 |
+
# -----------------------------
|
| 142 |
+
TMP_ROOT = tempfile.mkdtemp(prefix="meshapp_")
|
| 143 |
+
|
| 144 |
+
def save_array_txt(arr):
|
| 145 |
+
# Create a temporary file with .txt suffix
|
| 146 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w") as f:
|
| 147 |
+
np.savetxt(f, arr.astype(int), fmt="%d") # save as text
|
| 148 |
+
return f.name
|
| 149 |
+
|
| 150 |
+
def build_outputs(surf_a: Surface, surf_b: Surface, cmap_a: np.ndarray, p2p: np.ndarray, tag: str) -> Tuple[str, str, str, str]:
|
| 151 |
+
outdir = os.path.join(TMP_ROOT, tag)
|
| 152 |
+
os.makedirs(outdir, exist_ok=True)
|
| 153 |
+
glb_a = export_for_view(surf_a, cmap_a, f"A_{tag}", outdir)
|
| 154 |
+
cmap_b = cmap_a[p2p]
|
| 155 |
+
glb_b = export_for_view(surf_b, cmap_b, f"B_{tag}", outdir)
|
| 156 |
+
out_file = save_array_txt(p2p)
|
| 157 |
+
return glb_a, glb_b, out_file
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def init_clicked(mesh1_path, mesh2_path,
|
| 161 |
+
lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val):
|
| 162 |
+
cfg.deepfeat_conf.fmap.lambda_ = lambda_val
|
| 163 |
+
cfg.sds_conf.zoomout = zoomout_val
|
| 164 |
+
cfg.deepfeat_conf.fmap.diffusion.time = time_val
|
| 165 |
+
cfg.opt.n_loop = nloop_val
|
| 166 |
+
cfg.loss.sds = sds_val
|
| 167 |
+
cfg.loss.proper = proper_val
|
| 168 |
+
matcher.reconf(cfg)
|
| 169 |
+
if not mesh1_path or not mesh2_path:
|
| 170 |
+
raise gr.Error("Please upload both meshes.")
|
| 171 |
+
matcher._init()
|
| 172 |
+
global datadicts
|
| 173 |
+
datadicts = Datadicts(mesh1_path, mesh2_path)
|
| 174 |
+
|
| 175 |
+
C12_pred_init, C21_pred_init, feat1, feat2, evecs_trans1, evecs_trans2 = matcher.fmap_model({"shape1": datadicts.shape_dict, "shape2": datadicts.target_dict}, diff_model=matcher.diffusion_model, scale=matcher.fmap_cfg.diffusion.time)
|
| 176 |
+
C12_pred, C12_obj, mask_12 = C12_pred_init
|
| 177 |
+
p2p_init, _ = extract_p2p_torch_fmap(C12_obj, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
|
| 178 |
+
return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_init, tag="init")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def run_clicked(mesh1_path, mesh2_path, yaml_path, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val, progress=gr.Progress(track_tqdm=True)):
|
| 182 |
+
if not mesh1_path or not mesh2_path:
|
| 183 |
+
raise gr.Error("Please upload both meshes.")
|
| 184 |
+
|
| 185 |
+
cfg.deepfeat_conf.fmap.lambda_ = lambda_val
|
| 186 |
+
cfg.sds_conf.zoomout = zoomout_val
|
| 187 |
+
cfg.deepfeat_conf.fmap.diffusion.time = time_val
|
| 188 |
+
cfg.opt.n_loop = nloop_val
|
| 189 |
+
cfg.loss.sds = sds_val
|
| 190 |
+
cfg.loss.proper = proper_val
|
| 191 |
+
matcher.reconf(cfg)
|
| 192 |
+
if not mesh1_path or not mesh2_path:
|
| 193 |
+
raise gr.Error("Please upload both meshes.")
|
| 194 |
+
matcher._init()
|
| 195 |
+
global datadicts
|
| 196 |
+
if datadicts is None:
|
| 197 |
+
datadicts = Datadicts(mesh1_path, mesh2_path)
|
| 198 |
+
elif datadicts is not None:
|
| 199 |
+
if not (datadicts.shape_path == mesh1_path and datadicts.target_path == mesh2_path):
|
| 200 |
+
datadicts = Datadicts(mesh1_path, mesh2_path)
|
| 201 |
+
|
| 202 |
+
target_normals = torch.from_numpy(datadicts.target_surf.surfel/np.linalg.norm(datadicts.target_surf.surfel, axis=-1, keepdims=True)).float().to("cuda")
|
| 203 |
+
|
| 204 |
+
C12_new, p2p, p2p_init, _, loss_save = matcher.optimize(datadicts.shape_dict, datadicts.target_dict, target_normals)
|
| 205 |
+
evecs1, evecs2 = datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"]
|
| 206 |
+
evecs_2trans = evecs2.t() @ torch.diag(datadicts.target_dict["mass"])
|
| 207 |
+
C12_end_zo = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze()[:15, :15], 150)# matcher.cfg.sds_conf.zoomout)
|
| 208 |
+
p2p_zo, _ = extract_p2p_torch_fmap(C12_end_zo, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
|
| 209 |
+
return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_zo, tag="run")
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
with gr.Blocks(title="DiffuMatch demo") as demo:
|
| 213 |
+
text_in = "Upload two meshes and try our ICCV zero-shot method DiffuMatch! \n"
|
| 214 |
+
text_in += "*Init* will give you a rough correspondence, and you can click on *Run* to see if our method is able to match the two shapes! \n"
|
| 215 |
+
text_in += "*Recommended*: The method requires that the meshes are aligned (rotation-wise) to work well. Also might not work with scans (but try it out!)."
|
| 216 |
+
gr.Markdown(text_in)
|
| 217 |
+
with gr.Row():
|
| 218 |
+
mesh1 = gr.File(label="Mesh A (.ply/.obj/.off)")
|
| 219 |
+
mesh2 = gr.File(label="Mesh B (.ply/.obj/.off)")
|
| 220 |
+
yaml_file = gr.File(label="Optional YAML config", file_types=[".yaml", ".yml"], visible=True)
|
| 221 |
+
# except Exception:
|
| 222 |
+
with gr.Accordion("Settings", open=True):
|
| 223 |
+
with gr.Row():
|
| 224 |
+
lambda_val = gr.Slider(minimum=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][0], maximum=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][1], step=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][2], value=1, label="deepfeat_conf.fmap.lambda_")
|
| 225 |
+
zoomout_val = gr.Slider(minimum=FLOAT_SLIDERS["sds_conf.zoomout"][0], maximum=FLOAT_SLIDERS["sds_conf.zoomout"][1], step=FLOAT_SLIDERS["sds_conf.zoomout"][2], value=40, label="sds_conf.zoomout")
|
| 226 |
+
time_val = gr.Slider(minimum=FLOAT_SLIDERS["diffusion.time"][0], maximum=FLOAT_SLIDERS["diffusion.time"][1], step=FLOAT_SLIDERS["diffusion.time"][2], value=1, label="diffusion.time")
|
| 227 |
+
with gr.Row():
|
| 228 |
+
nloop_val = gr.Slider(minimum=INT_SLIDERS["opt.n_loop"][0], maximum=INT_SLIDERS["opt.n_loop"][1], step=INT_SLIDERS["opt.n_loop"][2], value=300, label="opt.n_loop")
|
| 229 |
+
sds_val = gr.Slider(minimum=FLOAT_SLIDERS["loss.sds"][0], maximum=FLOAT_SLIDERS["loss.sds"][1], step=FLOAT_SLIDERS["loss.sds"][2], value=1, label="loss.sds")
|
| 230 |
+
proper_val = gr.Slider(minimum=FLOAT_SLIDERS["loss.proper"][0], maximum=FLOAT_SLIDERS["loss.proper"][1], step=FLOAT_SLIDERS["loss.proper"][2], value=1, label="loss.proper")
|
| 231 |
+
|
| 232 |
+
with gr.Row():
|
| 233 |
+
init_btn = gr.Button("Init", variant="primary")
|
| 234 |
+
run_btn = gr.Button("Run", variant="secondary")
|
| 235 |
+
|
| 236 |
+
gr.Markdown("### Outputs\nEach stage exports both **GLB** (preview below) and **PLY** (download links) with per‑vertex colors.")
|
| 237 |
+
with gr.Tab("Init"):
|
| 238 |
+
with gr.Row():
|
| 239 |
+
init_view_a = gr.Model3D(label="Shape")
|
| 240 |
+
init_view_b = gr.Model3D(label="Target correspondence (init)")
|
| 241 |
+
with gr.Row():
|
| 242 |
+
out_file_init = gr.File(label="Download correspondences TXT")
|
| 243 |
+
with gr.Tab("Run"):
|
| 244 |
+
with gr.Row():
|
| 245 |
+
run_view_a = gr.Model3D(label="Shape")
|
| 246 |
+
run_view_b = gr.Model3D(label="Target correspondence (run)")
|
| 247 |
+
with gr.Row():
|
| 248 |
+
out_file = gr.File(label="Download correspondences TXT")
|
| 249 |
+
|
| 250 |
+
init_btn.click(
|
| 251 |
+
fn=init_clicked,
|
| 252 |
+
inputs=[mesh1, mesh2, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val],
|
| 253 |
+
outputs=[init_view_a, init_view_b, out_file_init],
|
| 254 |
+
api_name="init",
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
run_btn.click(
|
| 258 |
+
fn=run_clicked,
|
| 259 |
+
inputs=[mesh1, mesh2, yaml_file, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val],
|
| 260 |
+
outputs=[run_view_a, run_view_b, out_file],
|
| 261 |
+
api_name="run",
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
|
| 265 |
+
parser = argparse.ArgumentParser(description="Launch the gradio demo")
|
| 266 |
+
parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
|
| 267 |
+
parser.add_argument('--share', action="store_true")
|
| 268 |
+
args = parser.parse_args()
|
| 269 |
+
cfg = OmegaConf.load(args.config)
|
| 270 |
+
matcher = zero_shot.Matcher(cfg)
|
| 271 |
+
datadicts = None
|
| 272 |
+
demo.launch(share=args.share)
|
notebook_helpers.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils.mesh import load_mesh
|
| 2 |
+
from utils.geometry import get_operators, load_operators
|
| 3 |
+
import os
|
| 4 |
+
from utils.utils_func import convert_dict
|
| 5 |
+
from utils.surfaces import Surface
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
device = "cuda:0"
|
| 10 |
+
|
| 11 |
+
def load_data(file, cache_path, name, num_evecs=128, make_cache=False, factor=None):
|
| 12 |
+
if factor is None:
|
| 13 |
+
verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True)
|
| 14 |
+
else:
|
| 15 |
+
verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True, scale=False)
|
| 16 |
+
verts_shape = verts_shape/factor
|
| 17 |
+
area_shape /= factor**2
|
| 18 |
+
# print("Cache is: ", cache_path)
|
| 19 |
+
if not os.path.exists(cache_path) or make_cache:
|
| 20 |
+
print("Computing operators ...")
|
| 21 |
+
get_operators(verts_shape, faces, num_evecs, cache_path, vnormals)
|
| 22 |
+
data_dict = load_operators(cache_path)
|
| 23 |
+
data_dict['name'] = name
|
| 24 |
+
data_dict['normals'] = vnormals
|
| 25 |
+
data_dict['vertices'] = verts_shape
|
| 26 |
+
data_dict_torch = convert_dict(data_dict, device)
|
| 27 |
+
#batchify_dict(data_dict_torch)
|
| 28 |
+
return data_dict_torch, area_shape
|
| 29 |
+
|
| 30 |
+
def get_map_info(file_1, file_2, dict_1, dict_2, dataset):
|
| 31 |
+
shape_dict, target_dict = dict_1, dict_2
|
| 32 |
+
name_1, name_2 = shape_dict["name"], target_dict["name"]
|
| 33 |
+
if dataset is None:
|
| 34 |
+
vts_1, vts_2 = np.arange(shape_dict['vertices'].shape[0]), np.arange(target_dict['vertices'].shape[0])
|
| 35 |
+
map_info = (file_1, file_2, vts_1, vts_2)
|
| 36 |
+
file_vts_1 = file_1.replace("shapes", "correspondences")[:-4] + ".vts"
|
| 37 |
+
vts_1 = np.loadtxt(file_vts_1).astype(np.int32) - 1
|
| 38 |
+
|
| 39 |
+
file_vts_2 = file_2.replace("shapes", "correspondences")[:-4] + ".vts"
|
| 40 |
+
vts_2 = np.loadtxt(file_vts_2).astype(np.int32) - 1
|
| 41 |
+
if "DT4D" in dataset:
|
| 42 |
+
file_vts_1 = file_1.replace("shapes", "correspondences")[:-4] + ".vts"
|
| 43 |
+
vts_1 = np.loadtxt(file_vts_1).astype(np.int32) - 1
|
| 44 |
+
|
| 45 |
+
file_vts_2 = file_2.replace("shapes", "correspondences")[:-4] + ".vts"
|
| 46 |
+
vts_2 = np.loadtxt(file_vts_2).astype(np.int32) - 1
|
| 47 |
+
if name_1 == name_2:
|
| 48 |
+
map_info = (file_1, file_2, vts_1, vts_2)
|
| 49 |
+
elif ("crypto" in name_1) or ("crypto" in name_2):
|
| 50 |
+
name_cat_1, name_cat_2 = name_1.split(os.sep)[0], name_2.split(os.sep)[0]
|
| 51 |
+
data_path = os.path.dirname(os.path.dirname(os.path.dirname(file_1)))
|
| 52 |
+
map_file = os.path.join(data_path, "correspondences/cross_category_corres", f"{name_cat_1}_{name_cat_2}.vts")
|
| 53 |
+
if os.path.exists(map_file):
|
| 54 |
+
map_idx = np.loadtxt(map_file).astype(np.int32) - 1
|
| 55 |
+
map_info = (file_1, file_2, vts_1, vts_2[map_idx])
|
| 56 |
+
else:
|
| 57 |
+
print("NO GROUND TRUTH PAIRS")
|
| 58 |
+
else:
|
| 59 |
+
map_info = (file_1, file_2, vts_1, vts_2)
|
| 60 |
+
return map_info
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_pair(cache, id_1, id_2, name_1, name_2, dataset):
|
| 64 |
+
if "SCAPE" in dataset:
|
| 65 |
+
os.makedirs(cache, exist_ok=True)
|
| 66 |
+
cache_file = os.path.join(cache, f"mesh{id_1:03d}_mesh_256k_0n.npz")
|
| 67 |
+
shape_file = f"data/{dataset}/shapes/mesh{id_1:03d}.ply"
|
| 68 |
+
shape_surf = Surface(filename=shape_file)
|
| 69 |
+
shape_dict, _ = load_data(shape_file, cache_file, str(id_1))
|
| 70 |
+
|
| 71 |
+
cache_file = os.path.join(cache, f"mesh{id_2:03d}_mesh_256k_0n.npz")
|
| 72 |
+
target_file = f"data/{dataset}/shapes/mesh{id_2:03d}.ply"
|
| 73 |
+
target_surf = Surface(filename=target_file)
|
| 74 |
+
target_dict, _ = load_data(target_file, cache_file, str(id_2))
|
| 75 |
+
map_info = get_map_info(shape_file, target_file, shape_dict, target_dict, "SCAPE")
|
| 76 |
+
elif "DT4D" in dataset:
|
| 77 |
+
cache_file = os.path.join(cache, f"{name_1}_mesh_256k_0n.npz")
|
| 78 |
+
shape_file = f"data/DT4D_r_ori/shapes/{name_1}.ply"
|
| 79 |
+
shape_surf = Surface(filename=shape_file)
|
| 80 |
+
shape_dict, _ = load_data(shape_file, cache_file, name_1)
|
| 81 |
+
cache_file = os.path.join(cache, f"{name_2}_mesh_256k_0n.npz")
|
| 82 |
+
# cache_file = f"../nonrigiddiff/cache/snk/{name_2}.npz"
|
| 83 |
+
cache_file = f"../nonrigiddiff/cache/attentive/{name_2}_mesh_256k_0n.npz"
|
| 84 |
+
target_file = f"data/DT4D_r_ori/shapes/{name_2}.ply"
|
| 85 |
+
target_surf = Surface(filename=target_file)
|
| 86 |
+
target_dict, _ = load_data(target_file, cache_file, name_2)
|
| 87 |
+
map_info = get_map_info(shape_file, target_file, shape_dict, target_dict, "DT4D")
|
| 88 |
+
return shape_surf, target_surf, shape_dict, target_dict, map_info
|
requirements.txt
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
comm==0.2.1
|
| 2 |
+
configer
|
| 3 |
+
configparser==6.0.0
|
| 4 |
+
contourpy==1.1.1
|
| 5 |
+
cycler==0.12.1
|
| 6 |
+
fonttools==4.47.2
|
| 7 |
+
freetype-py==2.4.0
|
| 8 |
+
imageio==2.33.1
|
| 9 |
+
ipygany==0.5.0
|
| 10 |
+
ipywidgets==8.1.7
|
| 11 |
+
nbformat==5.10.4
|
| 12 |
+
jupyter
|
| 13 |
+
kiwisolver==1.4.5
|
| 14 |
+
lxml==5.1.0
|
| 15 |
+
matplotlib==3.7.4
|
| 16 |
+
networkx==3.1
|
| 17 |
+
open3d>=0.18.0
|
| 18 |
+
pyglet==2.0.10
|
| 19 |
+
pyopengl==3.1.0
|
| 20 |
+
pyparsing==3.1.1
|
| 21 |
+
pyrender==0.1.45
|
| 22 |
+
scipy>=1.10.1
|
| 23 |
+
shapely==2.0.2
|
| 24 |
+
smplx[all]
|
| 25 |
+
tqdm==4.66.1
|
| 26 |
+
trimesh==4.0.10
|
| 27 |
+
vtk==9.3.0
|
| 28 |
+
timm==1.0.11
|
| 29 |
+
potpourri3d
|
| 30 |
+
pythreejs==2.4.2
|
| 31 |
+
lapy
|
| 32 |
+
roma
|
| 33 |
+
webdataset
|
| 34 |
+
lmdb
|
| 35 |
+
robust_laplacian
|
| 36 |
+
termcolor
|
| 37 |
+
omegaconf
|
| 38 |
+
pykeops==2.2.3
|
| 39 |
+
scikit-image
|
| 40 |
+
pyfmaps
|
| 41 |
+
wandb
|
| 42 |
+
gradio==4.44.1
|
| 43 |
+
pydantic==2.10.6
|
| 44 |
+
traitlets==5.7.1
|
| 45 |
+
https://github.com/skoch9/meshplot/archive/0.4.0.tar.gz
|
shape_models/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .utils import *
|
| 2 |
+
from .geometry import *
|
| 3 |
+
from .layers import *
|
shape_models/encoder.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .layers import DiffusionNet
|
| 4 |
+
|
| 5 |
+
class Encoder(nn.Module):
|
| 6 |
+
def __init__(self, with_grad=True, key_verts="vertices"):
|
| 7 |
+
super(Encoder, self).__init__()
|
| 8 |
+
self.diff_net = DiffusionNet(
|
| 9 |
+
C_in=3,
|
| 10 |
+
C_out=512,
|
| 11 |
+
C_width=128,
|
| 12 |
+
N_block=4,
|
| 13 |
+
dropout=True,
|
| 14 |
+
with_gradient_features=with_grad,
|
| 15 |
+
with_gradient_rotations=with_grad,
|
| 16 |
+
)
|
| 17 |
+
self.key_verts = key_verts
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def forward(self, shape_dict):
|
| 21 |
+
feats = self.diff_net(shape_dict[self.key_verts], shape_dict["mass"], shape_dict["L"], evals=shape_dict["evals"],
|
| 22 |
+
evecs=shape_dict["evecs"], gradX=shape_dict["gradX"], gradY=shape_dict["gradY"], faces=shape_dict["faces"])
|
| 23 |
+
x_out = torch.max(feats, dim=0).values
|
| 24 |
+
return x_out
|
shape_models/fmap.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
# feature extractor
|
| 7 |
+
from .layers import DiffusionNet
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_mask(evals1, evals2, gamma=0.5, device="cpu"):
|
| 12 |
+
scaling_factor = max(torch.max(evals1), torch.max(evals2))
|
| 13 |
+
evals1, evals2 = evals1.to(device) / scaling_factor, evals2.to(device) / scaling_factor
|
| 14 |
+
evals_gamma1, evals_gamma2 = (evals1 ** gamma)[None, :], (evals2 ** gamma)[:, None]
|
| 15 |
+
|
| 16 |
+
M_re = evals_gamma2 / (evals_gamma2.square() + 1) - evals_gamma1 / (evals_gamma1.square() + 1)
|
| 17 |
+
M_im = 1 / (evals_gamma2.square() + 1) - 1 / (evals_gamma1.square() + 1)
|
| 18 |
+
return M_re.square() + M_im.square()
|
| 19 |
+
|
| 20 |
+
def get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y):
|
| 21 |
+
# compute linear operator matrix representation C1 and C2
|
| 22 |
+
|
| 23 |
+
F_hat = torch.bmm(evecs_trans_x, feat_x)
|
| 24 |
+
G_hat = torch.bmm(evecs_trans_y, feat_y)
|
| 25 |
+
A, B = F_hat, G_hat
|
| 26 |
+
|
| 27 |
+
A_t, B_t = A.transpose(1, 2), B.transpose(1, 2)
|
| 28 |
+
A_A_t, B_B_t = torch.bmm(A, A_t), torch.bmm(B, B_t)
|
| 29 |
+
B_A_t, A_B_t = torch.bmm(B, A_t), torch.bmm(A, B_t)
|
| 30 |
+
|
| 31 |
+
C12 = torch.bmm(B_A_t, torch.inverse(A_A_t))
|
| 32 |
+
|
| 33 |
+
C21 = torch.bmm(A_B_t, torch.inverse(B_B_t))
|
| 34 |
+
return [C12, C21]
|
| 35 |
+
|
| 36 |
+
def get_mask_noise(C, diff_model, scale=1, N_est=200, device="cuda", normalize=True, absolute=False):
|
| 37 |
+
with torch.no_grad():
|
| 38 |
+
sig = torch.ones([1, 1, 1, 1], device=device) * scale
|
| 39 |
+
|
| 40 |
+
noise = torch.randn((N_est, 1, 30, 30), device=device)
|
| 41 |
+
if absolute:
|
| 42 |
+
#noisy_new = torch.abs(C[None, :, :] + noise *scale)
|
| 43 |
+
noisy_new = torch.abs(torch.abs(C[None, :, :]) + noise * scale)
|
| 44 |
+
else:
|
| 45 |
+
noisy_new = C[None, :, :] + noise *scale
|
| 46 |
+
denoised = diff_model.net(noisy_new, sig)
|
| 47 |
+
mask_squared = torch.mean(torch.abs(noisy_new - denoised)/(2*scale), dim=0)/torch.mean(torch.abs(noisy_new), dim=0)
|
| 48 |
+
mask_median = torch.median(mask_squared).item()
|
| 49 |
+
if normalize:
|
| 50 |
+
mask_median = torch.median(mask_squared).item()
|
| 51 |
+
M_denoised = torch.sqrt(torch.clamp(mask_squared-mask_median/2, 0, mask_median*2))
|
| 52 |
+
else:
|
| 53 |
+
M_denoised = torch.sqrt(mask_squared)
|
| 54 |
+
return M_denoised-M_denoised.min()#, mask_median*2)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class RegularizedFMNet(nn.Module):
|
| 58 |
+
"""Compute the functional map matrix representation."""
|
| 59 |
+
|
| 60 |
+
def __init__(self, lambda_=1e-3, resolvant_gamma=0.5, use_resolvent=False):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.lambda_ = lambda_
|
| 63 |
+
self.use_resolvent = use_resolvent
|
| 64 |
+
self.resolvant_gamma = resolvant_gamma
|
| 65 |
+
|
| 66 |
+
def forward(self, feat_x, feat_y, evals_x, evals_y, evecs_trans_x, evecs_trans_y, diff_conf=None):
|
| 67 |
+
# compute linear operator matrix representation C1 and C2
|
| 68 |
+
evecs_trans_x, evecs_trans_y = evecs_trans_x.unsqueeze(0), evecs_trans_y.unsqueeze(0)
|
| 69 |
+
evals_x, evals_y = evals_x.unsqueeze(0), evals_y.unsqueeze(0)
|
| 70 |
+
|
| 71 |
+
F_hat = torch.bmm(evecs_trans_x, feat_x)
|
| 72 |
+
G_hat = torch.bmm(evecs_trans_y, feat_y)
|
| 73 |
+
A, B = F_hat, G_hat
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if diff_conf is not None:
|
| 78 |
+
diff_model, scale, normalize, absolute, N_est = diff_conf
|
| 79 |
+
C12_raw, C21_raw = get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y)
|
| 80 |
+
D12 = get_mask_noise(C12_raw, diff_model, scale=scale, N_est=N_est, normalize=normalize, absolute=absolute)
|
| 81 |
+
D21 = get_mask_noise(C21_raw, diff_model, scale=scale, N_est=N_est, normalize=normalize, absolute=absolute)
|
| 82 |
+
elif self.use_resolvent:
|
| 83 |
+
D12 = get_mask(evals_x.flatten(), evals_y.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
|
| 84 |
+
D21 = get_mask(evals_y.flatten(), evals_x.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
|
| 85 |
+
else:
|
| 86 |
+
D12 = (torch.unsqueeze(evals_y, 2) - torch.unsqueeze(evals_x, 1))**2
|
| 87 |
+
D21 = (torch.unsqueeze(evals_x, 2) - torch.unsqueeze(evals_y, 1))**2
|
| 88 |
+
|
| 89 |
+
A_t, B_t = A.transpose(1, 2), B.transpose(1, 2)
|
| 90 |
+
A_A_t, B_B_t = torch.bmm(A, A_t), torch.bmm(B, B_t)
|
| 91 |
+
B_A_t, A_B_t = torch.bmm(B, A_t), torch.bmm(A, B_t)
|
| 92 |
+
|
| 93 |
+
C12_i = []
|
| 94 |
+
for i in range(evals_x.size(1)):
|
| 95 |
+
D12_i = torch.cat([torch.diag(D12[bs, i, :].flatten()).unsqueeze(0) for bs in range(evals_x.size(0))], dim=0)
|
| 96 |
+
C12 = torch.bmm(torch.inverse(A_A_t + self.lambda_ * D12_i), B_A_t[:, i, :].unsqueeze(1).transpose(1, 2))
|
| 97 |
+
C12_i.append(C12.transpose(1, 2))
|
| 98 |
+
C12 = torch.cat(C12_i, dim=1)
|
| 99 |
+
|
| 100 |
+
C21_i = []
|
| 101 |
+
for i in range(evals_y.size(1)):
|
| 102 |
+
D21_i = torch.cat([torch.diag(D21[bs, i, :].flatten()).unsqueeze(0) for bs in range(evals_y.size(0))], dim=0)
|
| 103 |
+
C21 = torch.bmm(torch.inverse(B_B_t + self.lambda_ * D21_i), A_B_t[:, i, :].unsqueeze(1).transpose(1, 2))
|
| 104 |
+
C21_i.append(C21.transpose(1, 2))
|
| 105 |
+
C21 = torch.cat(C21_i, dim=1)
|
| 106 |
+
if diff_conf is not None:
|
| 107 |
+
return [C12_raw, C12, D12], [C21_raw, C21, D21]
|
| 108 |
+
else:
|
| 109 |
+
return [C12, C21]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class DFMNet(nn.Module):
|
| 113 |
+
"""
|
| 114 |
+
Compilation of the global model :
|
| 115 |
+
- diffusion net as feature extractor
|
| 116 |
+
- fmap + q-fmap
|
| 117 |
+
- unsupervised loss
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(self, cfg):
|
| 121 |
+
super().__init__()
|
| 122 |
+
|
| 123 |
+
# feature extractor #
|
| 124 |
+
with_grad=True
|
| 125 |
+
self.feat = cfg["feat"]
|
| 126 |
+
|
| 127 |
+
self.feature_extractor = DiffusionNet(
|
| 128 |
+
C_in=cfg["C_in"],
|
| 129 |
+
C_out=cfg["n_feat"],
|
| 130 |
+
C_width=128,
|
| 131 |
+
N_block=4,
|
| 132 |
+
dropout=True,
|
| 133 |
+
with_gradient_features=with_grad,
|
| 134 |
+
with_gradient_rotations=with_grad,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# regularized fmap
|
| 138 |
+
self.fmreg_net = RegularizedFMNet(lambda_=cfg["lambda_"],
|
| 139 |
+
resolvant_gamma=cfg.get("resolvent_gamma", 0.5), use_resolvent=cfg.get("use_resolvent", False))
|
| 140 |
+
# parameters
|
| 141 |
+
self.n_fmap = cfg["n_fmap"]
|
| 142 |
+
if cfg.get("diffusion", None) is not None:
|
| 143 |
+
self.normalize = cfg["diffusion"]["normalize"]
|
| 144 |
+
self.abs = cfg["diffusion"]["abs"]
|
| 145 |
+
self.N_est = cfg.diffusion.get("batch_mask", 100)
|
| 146 |
+
|
| 147 |
+
def forward(self, batch, diff_model=None, scale=1):
|
| 148 |
+
if self.feat == "xyz":
|
| 149 |
+
feat_1, feat_2 = batch["shape1"]["vertices"], batch["shape2"]["vertices"]
|
| 150 |
+
elif self.feat == "wks":
|
| 151 |
+
feat_1, feat_2 = batch["shape1"]["wks"], batch["shape2"]["wks"]
|
| 152 |
+
elif self.feat == "hks":
|
| 153 |
+
feat_1, feat_2 = batch["shape1"]["hks"], batch["shape2"]["hks"]
|
| 154 |
+
else:
|
| 155 |
+
raise Exception("Unknow Feature")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
verts1, faces1, mass1, L1, evals1, evecs1, gradX1, gradY1 = (feat_1, batch["shape1"]["faces"],
|
| 159 |
+
batch["shape1"]["mass"], batch["shape1"]["L"],
|
| 160 |
+
batch["shape1"]["evals"], batch["shape1"]["evecs"],
|
| 161 |
+
batch["shape1"]["gradX"], batch["shape1"]["gradY"])
|
| 162 |
+
verts2, faces2, mass2, L2, evals2, evecs2, gradX2, gradY2 = (feat_2, batch["shape2"]["faces"],
|
| 163 |
+
batch["shape2"]["mass"], batch["shape2"]["L"],
|
| 164 |
+
batch["shape2"]["evals"], batch["shape2"]["evecs"],
|
| 165 |
+
batch["shape2"]["gradX"], batch["shape2"]["gradY"])
|
| 166 |
+
|
| 167 |
+
# set features to vertices
|
| 168 |
+
features1, features2 = verts1, verts2
|
| 169 |
+
# print(features1.shape, features2.shape)
|
| 170 |
+
|
| 171 |
+
feat1 = self.feature_extractor(features1, mass1, L=L1, evals=evals1, evecs=evecs1,
|
| 172 |
+
gradX=gradX1, gradY=gradY1, faces=faces1).unsqueeze(0)
|
| 173 |
+
feat2 = self.feature_extractor(features2, mass2, L=L2, evals=evals2, evecs=evecs2,
|
| 174 |
+
gradX=gradX2, gradY=gradY2, faces=faces2).unsqueeze(0)
|
| 175 |
+
|
| 176 |
+
evecs_trans1, evecs_trans2 = evecs1.t()[:self.n_fmap] @ torch.diag(mass1.squeeze()), evecs2.squeeze().t()[:self.n_fmap].squeeze() @ torch.diag(mass2.squeeze())
|
| 177 |
+
evals1, evals2 = evals1[:self.n_fmap], evals2[:self.n_fmap]
|
| 178 |
+
|
| 179 |
+
#
|
| 180 |
+
if diff_model is not None:
|
| 181 |
+
C12_pred, C21_pred = self.fmreg_net(feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2,
|
| 182 |
+
diff_conf=[diff_model, scale, self.normalize, self.abs, self.N_est])
|
| 183 |
+
else:
|
| 184 |
+
C12_pred, C21_pred = self.fmreg_net(feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2)
|
| 185 |
+
|
| 186 |
+
return C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2
|
shape_models/geometry.py
ADDED
|
@@ -0,0 +1,897 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import scipy
|
| 2 |
+
import scipy.sparse.linalg as sla
|
| 3 |
+
# ^^^ we NEED to import scipy before torch, or it crashes :(
|
| 4 |
+
# (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
|
| 5 |
+
|
| 6 |
+
import os.path
|
| 7 |
+
import sys
|
| 8 |
+
import random
|
| 9 |
+
from multiprocessing import Pool
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import scipy.spatial
|
| 13 |
+
import torch
|
| 14 |
+
import sklearn.neighbors
|
| 15 |
+
|
| 16 |
+
import robust_laplacian
|
| 17 |
+
import potpourri3d as pp3d
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def norm(x, highdim=False):
|
| 21 |
+
"""
|
| 22 |
+
Computes norm of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
|
| 23 |
+
"""
|
| 24 |
+
return torch.norm(x, dim=len(x.shape) - 1)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def norm2(x, highdim=False):
|
| 28 |
+
"""
|
| 29 |
+
Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
|
| 30 |
+
"""
|
| 31 |
+
return dot(x, x)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def normalize(x, divide_eps=1e-6, highdim=False):
|
| 35 |
+
"""
|
| 36 |
+
Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
|
| 37 |
+
"""
|
| 38 |
+
if (len(x.shape) == 1):
|
| 39 |
+
raise ValueError("called normalize() on single vector of dim " + str(x.shape) + " are you sure?")
|
| 40 |
+
if (not highdim and x.shape[-1] > 4):
|
| 41 |
+
raise ValueError("called normalize() with large last dimension " + str(x.shape) + " are you sure?")
|
| 42 |
+
return x / (norm(x, highdim=highdim) + divide_eps).unsqueeze(-1)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def face_coords(verts, faces):
|
| 46 |
+
coords = verts[faces]
|
| 47 |
+
return coords
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def cross(vec_A, vec_B):
|
| 51 |
+
return torch.cross(vec_A, vec_B, dim=-1)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def dot(vec_A, vec_B):
|
| 55 |
+
return torch.sum(vec_A * vec_B, dim=-1)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Given (..., 3) vectors and normals, projects out any components of vecs
|
| 59 |
+
# which lies in the direction of normals. Normals are assumed to be unit.
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def project_to_tangent(vecs, unit_normals):
|
| 63 |
+
dots = dot(vecs, unit_normals)
|
| 64 |
+
return vecs - unit_normals * dots.unsqueeze(-1)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def face_area(verts, faces):
|
| 68 |
+
coords = face_coords(verts, faces)
|
| 69 |
+
vec_A = coords[:, 1, :] - coords[:, 0, :]
|
| 70 |
+
vec_B = coords[:, 2, :] - coords[:, 0, :]
|
| 71 |
+
|
| 72 |
+
raw_normal = cross(vec_A, vec_B)
|
| 73 |
+
return 0.5 * norm(raw_normal)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def face_normals(verts, faces, normalized=True):
|
| 77 |
+
coords = face_coords(verts, faces)
|
| 78 |
+
vec_A = coords[:, 1, :] - coords[:, 0, :]
|
| 79 |
+
vec_B = coords[:, 2, :] - coords[:, 0, :]
|
| 80 |
+
|
| 81 |
+
raw_normal = cross(vec_A, vec_B)
|
| 82 |
+
|
| 83 |
+
if normalized:
|
| 84 |
+
return normalize(raw_normal)
|
| 85 |
+
|
| 86 |
+
return raw_normal
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def neighborhood_normal(points):
|
| 90 |
+
# points: (N, K, 3) array of neighborhood psoitions
|
| 91 |
+
# points should be centered at origin
|
| 92 |
+
# out: (N,3) array of normals
|
| 93 |
+
# numpy in, numpy out
|
| 94 |
+
(u, s, vh) = np.linalg.svd(points, full_matrices=False)
|
| 95 |
+
normal = vh[:, 2, :]
|
| 96 |
+
return normal / np.linalg.norm(normal, axis=-1, keepdims=True)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def mesh_vertex_normals(verts, faces):
|
| 100 |
+
# numpy in / out
|
| 101 |
+
face_n = toNP(face_normals(torch.tensor(verts), torch.tensor(faces))) # ugly torch <---> numpy
|
| 102 |
+
|
| 103 |
+
vertex_normals = np.zeros(verts.shape)
|
| 104 |
+
for i in range(3):
|
| 105 |
+
np.add.at(vertex_normals, faces[:, i], face_n)
|
| 106 |
+
|
| 107 |
+
vertex_normals = vertex_normals / np.linalg.norm(vertex_normals, axis=-1, keepdims=True)
|
| 108 |
+
|
| 109 |
+
return vertex_normals
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def vertex_normals(verts, faces, n_neighbors_cloud=30):
|
| 113 |
+
verts_np = toNP(verts)
|
| 114 |
+
|
| 115 |
+
if faces.numel() == 0: # point cloud
|
| 116 |
+
|
| 117 |
+
_, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
|
| 118 |
+
neigh_points = verts_np[neigh_inds, :]
|
| 119 |
+
neigh_points = neigh_points - verts_np[:, np.newaxis, :]
|
| 120 |
+
normals = neighborhood_normal(neigh_points)
|
| 121 |
+
|
| 122 |
+
else: # mesh
|
| 123 |
+
|
| 124 |
+
normals = mesh_vertex_normals(verts_np, toNP(faces))
|
| 125 |
+
|
| 126 |
+
# if any are NaN, wiggle slightly and recompute
|
| 127 |
+
bad_normals_mask = np.isnan(normals).any(axis=1, keepdims=True)
|
| 128 |
+
if bad_normals_mask.any():
|
| 129 |
+
bbox = np.amax(verts_np, axis=0) - np.amin(verts_np, axis=0)
|
| 130 |
+
scale = np.linalg.norm(bbox) * 1e-4
|
| 131 |
+
wiggle = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5) * scale
|
| 132 |
+
wiggle_verts = verts_np + bad_normals_mask * wiggle
|
| 133 |
+
normals = mesh_vertex_normals(wiggle_verts, toNP(faces))
|
| 134 |
+
|
| 135 |
+
# if still NaN assign random normals (probably means unreferenced verts in mesh)
|
| 136 |
+
bad_normals_mask = np.isnan(normals).any(axis=1)
|
| 137 |
+
if bad_normals_mask.any():
|
| 138 |
+
normals[bad_normals_mask, :] = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5)[bad_normals_mask, :]
|
| 139 |
+
normals = normals / np.linalg.norm(normals, axis=-1)[:, np.newaxis]
|
| 140 |
+
|
| 141 |
+
normals = torch.from_numpy(normals).to(device=verts.device, dtype=verts.dtype)
|
| 142 |
+
|
| 143 |
+
if torch.any(torch.isnan(normals)):
|
| 144 |
+
raise ValueError("NaN normals :(")
|
| 145 |
+
|
| 146 |
+
return normals
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def build_tangent_frames(verts, faces, normals=None):
|
| 150 |
+
|
| 151 |
+
V = verts.shape[0]
|
| 152 |
+
dtype = verts.dtype
|
| 153 |
+
device = verts.device
|
| 154 |
+
|
| 155 |
+
if normals == None:
|
| 156 |
+
vert_normals = vertex_normals(verts, faces) # (V,3)
|
| 157 |
+
else:
|
| 158 |
+
vert_normals = normals
|
| 159 |
+
|
| 160 |
+
# = find an orthogonal basis
|
| 161 |
+
|
| 162 |
+
basis_cand1 = torch.tensor([1, 0, 0]).to(device=device, dtype=dtype).expand(V, -1)
|
| 163 |
+
basis_cand2 = torch.tensor([0, 1, 0]).to(device=device, dtype=dtype).expand(V, -1)
|
| 164 |
+
|
| 165 |
+
basisX = torch.where((torch.abs(dot(vert_normals, basis_cand1)) < 0.9).unsqueeze(-1), basis_cand1, basis_cand2)
|
| 166 |
+
basisX = project_to_tangent(basisX, vert_normals)
|
| 167 |
+
basisX = normalize(basisX)
|
| 168 |
+
basisY = cross(vert_normals, basisX)
|
| 169 |
+
frames = torch.stack((basisX, basisY, vert_normals), dim=-2)
|
| 170 |
+
|
| 171 |
+
if torch.any(torch.isnan(frames)):
|
| 172 |
+
raise ValueError("NaN coordinate frame! Must be very degenerate")
|
| 173 |
+
|
| 174 |
+
return frames
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def build_grad_point_cloud(verts, frames, n_neighbors_cloud=30):
|
| 178 |
+
|
| 179 |
+
verts_np = toNP(verts)
|
| 180 |
+
frames_np = toNP(frames)
|
| 181 |
+
|
| 182 |
+
_, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
|
| 183 |
+
neigh_points = verts_np[neigh_inds, :]
|
| 184 |
+
neigh_vecs = neigh_points - verts_np[:, np.newaxis, :]
|
| 185 |
+
|
| 186 |
+
# TODO this could easily be way faster. For instance we could avoid the weird edges format and the corresponding pure-python loop via some numpy broadcasting of the same logic. The way it works right now is just to share code with the mesh version. But its low priority since its preprocessing code.
|
| 187 |
+
|
| 188 |
+
edge_inds_from = np.repeat(np.arange(verts.shape[0]), n_neighbors_cloud)
|
| 189 |
+
edges = np.stack((edge_inds_from, neigh_inds.flatten()))
|
| 190 |
+
edge_tangent_vecs = edge_tangent_vectors(verts, frames, edges)
|
| 191 |
+
|
| 192 |
+
return build_grad(verts_np, torch.tensor(edges), edge_tangent_vecs)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def edge_tangent_vectors(verts, frames, edges):
|
| 196 |
+
edge_vecs = verts[edges[1, :], :] - verts[edges[0, :], :]
|
| 197 |
+
basisX = frames[edges[0, :], 0, :]
|
| 198 |
+
basisY = frames[edges[0, :], 1, :]
|
| 199 |
+
|
| 200 |
+
compX = dot(edge_vecs, basisX)
|
| 201 |
+
compY = dot(edge_vecs, basisY)
|
| 202 |
+
edge_tangent = torch.stack((compX, compY), dim=-1)
|
| 203 |
+
|
| 204 |
+
return edge_tangent
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def build_grad(verts, edges, edge_tangent_vectors):
|
| 208 |
+
"""
|
| 209 |
+
Build a (V, V) complex sparse matrix grad operator. Given real inputs at vertices, produces a complex (vector value) at vertices giving the gradient. All values pointwise.
|
| 210 |
+
- edges: (2, E)
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
edges_np = toNP(edges)
|
| 214 |
+
edge_tangent_vectors_np = toNP(edge_tangent_vectors)
|
| 215 |
+
|
| 216 |
+
# TODO find a way to do this in pure numpy?
|
| 217 |
+
|
| 218 |
+
# Build outgoing neighbor lists
|
| 219 |
+
N = verts.shape[0]
|
| 220 |
+
vert_edge_outgoing = [[] for i in range(N)]
|
| 221 |
+
for iE in range(edges_np.shape[1]):
|
| 222 |
+
tail_ind = edges_np[0, iE]
|
| 223 |
+
tip_ind = edges_np[1, iE]
|
| 224 |
+
if tip_ind != tail_ind:
|
| 225 |
+
vert_edge_outgoing[tail_ind].append(iE)
|
| 226 |
+
|
| 227 |
+
# Build local inversion matrix for each vertex
|
| 228 |
+
row_inds = []
|
| 229 |
+
col_inds = []
|
| 230 |
+
data_vals = []
|
| 231 |
+
eps_reg = 1e-5
|
| 232 |
+
for iV in range(N):
|
| 233 |
+
n_neigh = len(vert_edge_outgoing[iV])
|
| 234 |
+
|
| 235 |
+
lhs_mat = np.zeros((n_neigh, 2))
|
| 236 |
+
rhs_mat = np.zeros((n_neigh, n_neigh + 1))
|
| 237 |
+
ind_lookup = [iV]
|
| 238 |
+
for i_neigh in range(n_neigh):
|
| 239 |
+
iE = vert_edge_outgoing[iV][i_neigh]
|
| 240 |
+
jV = edges_np[1, iE]
|
| 241 |
+
ind_lookup.append(jV)
|
| 242 |
+
|
| 243 |
+
edge_vec = edge_tangent_vectors[iE][:]
|
| 244 |
+
w_e = 1.
|
| 245 |
+
|
| 246 |
+
lhs_mat[i_neigh][:] = w_e * edge_vec
|
| 247 |
+
rhs_mat[i_neigh][0] = w_e * (-1)
|
| 248 |
+
rhs_mat[i_neigh][i_neigh + 1] = w_e * 1
|
| 249 |
+
|
| 250 |
+
lhs_T = lhs_mat.T
|
| 251 |
+
lhs_inv = np.linalg.inv(lhs_T @ lhs_mat + eps_reg * np.identity(2)) @ lhs_T
|
| 252 |
+
|
| 253 |
+
sol_mat = lhs_inv @ rhs_mat
|
| 254 |
+
sol_coefs = (sol_mat[0, :] + 1j * sol_mat[1, :]).T
|
| 255 |
+
|
| 256 |
+
for i_neigh in range(n_neigh + 1):
|
| 257 |
+
i_glob = ind_lookup[i_neigh]
|
| 258 |
+
|
| 259 |
+
row_inds.append(iV)
|
| 260 |
+
col_inds.append(i_glob)
|
| 261 |
+
data_vals.append(sol_coefs[i_neigh])
|
| 262 |
+
|
| 263 |
+
# build the sparse matrix
|
| 264 |
+
row_inds = np.array(row_inds)
|
| 265 |
+
col_inds = np.array(col_inds)
|
| 266 |
+
data_vals = np.array(data_vals)
|
| 267 |
+
mat = scipy.sparse.coo_matrix((data_vals, (row_inds, col_inds)), shape=(N, N)).tocsc()
|
| 268 |
+
|
| 269 |
+
return mat
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def compute_operators(verts, faces, k_eig, normals=None):
|
| 273 |
+
"""
|
| 274 |
+
Builds spectral operators for a mesh/point cloud. Constructs mass matrix, eigenvalues/vectors for Laplacian, and gradient matrix.
|
| 275 |
+
|
| 276 |
+
See get_operators() for a similar routine that wraps this one with a layer of caching.
|
| 277 |
+
|
| 278 |
+
Torch in / torch out.
|
| 279 |
+
|
| 280 |
+
Arguments:
|
| 281 |
+
- vertices: (V,3) vertex positions
|
| 282 |
+
- faces: (F,3) list of triangular faces. If empty, assumed to be a point cloud.
|
| 283 |
+
- k_eig: number of eigenvectors to use
|
| 284 |
+
|
| 285 |
+
Returns:
|
| 286 |
+
- frames: (V,3,3) X/Y/Z coordinate frame at each vertex. Z coordinate is normal (e.g. [:,2,:] for normals)
|
| 287 |
+
- massvec: (V) real diagonal of lumped mass matrix
|
| 288 |
+
- L: (VxV) real sparse matrix of (weak) Laplacian
|
| 289 |
+
- evals: (k) list of eigenvalues of the Laplacian
|
| 290 |
+
- evecs: (V,k) list of eigenvectors of the Laplacian
|
| 291 |
+
- gradX: (VxV) sparse matrix which gives X-component of gradient in the local basis at the vertex
|
| 292 |
+
- gradY: same as gradX but for Y-component of gradient
|
| 293 |
+
|
| 294 |
+
PyTorch doesn't seem to like complex sparse matrices, so we store the "real" and "imaginary" (aka X and Y) gradient matrices separately, rather than as one complex sparse matrix.
|
| 295 |
+
|
| 296 |
+
Note: for a generalized eigenvalue problem, the mass matrix matters! The eigenvectors are only othrthonormal with respect to the mass matrix, like v^H M v, so the mass (given as the diagonal vector massvec) needs to be used in projections, etc.
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
device = verts.device
|
| 300 |
+
dtype = verts.dtype
|
| 301 |
+
V = verts.shape[0]
|
| 302 |
+
is_cloud = faces.numel() == 0
|
| 303 |
+
|
| 304 |
+
eps = 1e-8
|
| 305 |
+
|
| 306 |
+
verts_np = toNP(verts).astype(np.float64)
|
| 307 |
+
faces_np = toNP(faces)
|
| 308 |
+
frames = build_tangent_frames(verts, faces, normals=normals)
|
| 309 |
+
frames_np = toNP(frames)
|
| 310 |
+
|
| 311 |
+
# Build the scalar Laplacian
|
| 312 |
+
if is_cloud:
|
| 313 |
+
L, M = robust_laplacian.point_cloud_laplacian(verts_np)
|
| 314 |
+
massvec_np = M.diagonal()
|
| 315 |
+
else:
|
| 316 |
+
# L, M = robust_laplacian.mesh_laplacian(verts_np, faces_np)
|
| 317 |
+
# massvec_np = M.diagonal()
|
| 318 |
+
L = pp3d.cotan_laplacian(verts_np, faces_np, denom_eps=1e-10)
|
| 319 |
+
massvec_np = pp3d.vertex_areas(verts_np, faces_np)
|
| 320 |
+
massvec_np += eps * np.mean(massvec_np)
|
| 321 |
+
|
| 322 |
+
if (np.isnan(L.data).any()):
|
| 323 |
+
raise RuntimeError("NaN Laplace matrix")
|
| 324 |
+
if (np.isnan(massvec_np).any()):
|
| 325 |
+
raise RuntimeError("NaN mass matrix")
|
| 326 |
+
|
| 327 |
+
# Read off neighbors & rotations from the Laplacian
|
| 328 |
+
L_coo = L.tocoo()
|
| 329 |
+
inds_row = L_coo.row
|
| 330 |
+
inds_col = L_coo.col
|
| 331 |
+
|
| 332 |
+
# === Compute the eigenbasis
|
| 333 |
+
if k_eig > 0:
|
| 334 |
+
|
| 335 |
+
# Prepare matrices
|
| 336 |
+
L_eigsh = (L + scipy.sparse.identity(L.shape[0]) * eps).tocsc()
|
| 337 |
+
massvec_eigsh = massvec_np
|
| 338 |
+
Mmat = scipy.sparse.diags(massvec_eigsh)
|
| 339 |
+
eigs_sigma = eps
|
| 340 |
+
|
| 341 |
+
failcount = 0
|
| 342 |
+
while True:
|
| 343 |
+
try:
|
| 344 |
+
# We would be happy here to lower tol or maxiter since we don't need these to be super precise, but for some reason those parameters seem to have no effect
|
| 345 |
+
evals_np, evecs_np = sla.eigsh(L_eigsh, k=k_eig, M=Mmat, sigma=eigs_sigma)
|
| 346 |
+
|
| 347 |
+
# Clip off any eigenvalues that end up slightly negative due to numerical weirdness
|
| 348 |
+
evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))
|
| 349 |
+
|
| 350 |
+
break
|
| 351 |
+
except Exception as e:
|
| 352 |
+
print(e)
|
| 353 |
+
if (failcount > 3):
|
| 354 |
+
raise ValueError("failed to compute eigendecomp")
|
| 355 |
+
failcount += 1
|
| 356 |
+
print("--- decomp failed; adding eps ===> count: " + str(failcount))
|
| 357 |
+
L_eigsh = L_eigsh + scipy.sparse.identity(L.shape[0]) * (eps * 10**failcount)
|
| 358 |
+
|
| 359 |
+
else: #k_eig == 0
|
| 360 |
+
evals_np = np.zeros((0))
|
| 361 |
+
evecs_np = np.zeros((verts.shape[0], 0))
|
| 362 |
+
|
| 363 |
+
# == Build gradient matrices
|
| 364 |
+
|
| 365 |
+
# For meshes, we use the same edges as were used to build the Laplacian. For point clouds, use a whole local neighborhood
|
| 366 |
+
if is_cloud:
|
| 367 |
+
grad_mat_np = build_grad_point_cloud(verts, frames)
|
| 368 |
+
else:
|
| 369 |
+
edges = torch.tensor(np.stack((inds_row, inds_col), axis=0), device=device, dtype=faces.dtype)
|
| 370 |
+
edge_vecs = edge_tangent_vectors(verts, frames, edges)
|
| 371 |
+
grad_mat_np = build_grad(verts.cpu(), edges.cpu(), edge_vecs.cpu())
|
| 372 |
+
|
| 373 |
+
# Split complex gradient in to two real sparse mats (torch doesn't like complex sparse matrices)
|
| 374 |
+
gradX_np = np.real(grad_mat_np)
|
| 375 |
+
gradY_np = np.imag(grad_mat_np)
|
| 376 |
+
|
| 377 |
+
# === Convert back to torch
|
| 378 |
+
massvec = torch.from_numpy(massvec_np).to(device=device, dtype=dtype)
|
| 379 |
+
L = utils.sparse_np_to_torch(L).to(device=device, dtype=dtype)
|
| 380 |
+
evals = torch.from_numpy(evals_np).to(device=device, dtype=dtype)
|
| 381 |
+
evecs = torch.from_numpy(evecs_np).to(device=device, dtype=dtype)
|
| 382 |
+
gradX = utils.sparse_np_to_torch(gradX_np).to(device=device, dtype=dtype)
|
| 383 |
+
gradY = utils.sparse_np_to_torch(gradY_np).to(device=device, dtype=dtype)
|
| 384 |
+
|
| 385 |
+
return frames, massvec, L, evals, evecs, gradX, gradY
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def get_all_operators(verts_list, faces_list, k_eig, op_cache_dir=None, normals=None):
|
| 389 |
+
N = len(verts_list)
|
| 390 |
+
|
| 391 |
+
frames = [None] * N
|
| 392 |
+
massvec = [None] * N
|
| 393 |
+
L = [None] * N
|
| 394 |
+
evals = [None] * N
|
| 395 |
+
evecs = [None] * N
|
| 396 |
+
gradX = [None] * N
|
| 397 |
+
gradY = [None] * N
|
| 398 |
+
|
| 399 |
+
inds = [i for i in range(N)]
|
| 400 |
+
# process in random order
|
| 401 |
+
# random.shuffle(inds)
|
| 402 |
+
|
| 403 |
+
for num, i in enumerate(inds):
|
| 404 |
+
print("get_all_operators() processing {} / {} {:.3f}%".format(num, N, num / N * 100))
|
| 405 |
+
if normals is None:
|
| 406 |
+
outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir)
|
| 407 |
+
else:
|
| 408 |
+
outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir, normals=normals[i])
|
| 409 |
+
frames[i] = outputs[0]
|
| 410 |
+
massvec[i] = outputs[1]
|
| 411 |
+
L[i] = outputs[2]
|
| 412 |
+
evals[i] = outputs[3]
|
| 413 |
+
evecs[i] = outputs[4]
|
| 414 |
+
gradX[i] = outputs[5]
|
| 415 |
+
gradY[i] = outputs[6]
|
| 416 |
+
|
| 417 |
+
return frames, massvec, L, evals, evecs, gradX, gradY
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def get_operators(verts, faces, k_eig=128, op_cache_dir=None, normals=None, overwrite_cache=False):
|
| 421 |
+
"""
|
| 422 |
+
See documentation for compute_operators(). This essentailly just wraps a call to compute_operators, using a cache if possible.
|
| 423 |
+
All arrays are always computed using double precision for stability, then truncated to single precision floats to store on disk, and finally returned as a tensor with dtype/device matching the `verts` input.
|
| 424 |
+
"""
|
| 425 |
+
|
| 426 |
+
device = verts.device
|
| 427 |
+
dtype = verts.dtype
|
| 428 |
+
verts_np = toNP(verts)
|
| 429 |
+
faces_np = toNP(faces)
|
| 430 |
+
is_cloud = faces.numel() == 0
|
| 431 |
+
|
| 432 |
+
if (np.isnan(verts_np).any()):
|
| 433 |
+
raise RuntimeError("tried to construct operators from NaN verts")
|
| 434 |
+
|
| 435 |
+
# Check the cache directory
|
| 436 |
+
# Note 1: Collisions here are exceptionally unlikely, so we could probably just use the hash...
|
| 437 |
+
# but for good measure we check values nonetheless.
|
| 438 |
+
# Note 2: There is a small possibility for race conditions to lead to bucket gaps or duplicate
|
| 439 |
+
# entries in this cache. The good news is that that is totally fine, and at most slightly
|
| 440 |
+
# slows performance with rare extra cache misses.
|
| 441 |
+
found = False
|
| 442 |
+
if op_cache_dir is not None:
|
| 443 |
+
utils.ensure_dir_exists(op_cache_dir)
|
| 444 |
+
hash_key_str = str(utils.hash_arrays((verts_np, faces_np)))
|
| 445 |
+
# print("Building operators for input with hash: " + hash_key_str)
|
| 446 |
+
|
| 447 |
+
# Search through buckets with matching hashes. When the loop exits, this
|
| 448 |
+
# is the bucket index of the file we should write to.
|
| 449 |
+
i_cache_search = 0
|
| 450 |
+
while True:
|
| 451 |
+
|
| 452 |
+
# Form the name of the file to check
|
| 453 |
+
search_path = os.path.join(op_cache_dir, hash_key_str + "_" + str(i_cache_search) + ".npz")
|
| 454 |
+
|
| 455 |
+
try:
|
| 456 |
+
# print('loading path: ' + str(search_path))
|
| 457 |
+
npzfile = np.load(search_path, allow_pickle=True)
|
| 458 |
+
cache_verts = npzfile["verts"]
|
| 459 |
+
cache_faces = npzfile["faces"]
|
| 460 |
+
cache_k_eig = npzfile["k_eig"].item()
|
| 461 |
+
|
| 462 |
+
# If the cache doesn't match, keep looking
|
| 463 |
+
if (not np.array_equal(verts, cache_verts)) or (not np.array_equal(faces, cache_faces)):
|
| 464 |
+
i_cache_search += 1
|
| 465 |
+
print("hash collision! searching next.")
|
| 466 |
+
continue
|
| 467 |
+
|
| 468 |
+
# print(" cache hit!")
|
| 469 |
+
|
| 470 |
+
# If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
|
| 471 |
+
# entry below more eigenvalues
|
| 472 |
+
if overwrite_cache:
|
| 473 |
+
print(" overwriting cache by request")
|
| 474 |
+
os.remove(search_path)
|
| 475 |
+
break
|
| 476 |
+
|
| 477 |
+
if cache_k_eig < k_eig:
|
| 478 |
+
print(" overwriting cache --- not enough eigenvalues")
|
| 479 |
+
os.remove(search_path)
|
| 480 |
+
break
|
| 481 |
+
|
| 482 |
+
if "L_data" not in npzfile:
|
| 483 |
+
print(" overwriting cache --- entries are absent")
|
| 484 |
+
os.remove(search_path)
|
| 485 |
+
break
|
| 486 |
+
|
| 487 |
+
def read_sp_mat(prefix):
|
| 488 |
+
data = npzfile[prefix + "_data"]
|
| 489 |
+
indices = npzfile[prefix + "_indices"]
|
| 490 |
+
indptr = npzfile[prefix + "_indptr"]
|
| 491 |
+
shape = npzfile[prefix + "_shape"]
|
| 492 |
+
mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
|
| 493 |
+
return mat
|
| 494 |
+
|
| 495 |
+
# This entry matches! Return it.
|
| 496 |
+
frames = npzfile["frames"]
|
| 497 |
+
mass = npzfile["mass"]
|
| 498 |
+
L = read_sp_mat("L")
|
| 499 |
+
evals = npzfile["evals"][:k_eig]
|
| 500 |
+
evecs = npzfile["evecs"][:, :k_eig]
|
| 501 |
+
gradX = read_sp_mat("gradX")
|
| 502 |
+
gradY = read_sp_mat("gradY")
|
| 503 |
+
|
| 504 |
+
frames = torch.from_numpy(frames).to(device=device, dtype=dtype)
|
| 505 |
+
mass = torch.from_numpy(mass).to(device=device, dtype=dtype)
|
| 506 |
+
L = utils.sparse_np_to_torch(L).to(device=device, dtype=dtype)
|
| 507 |
+
evals = torch.from_numpy(evals).to(device=device, dtype=dtype)
|
| 508 |
+
evecs = torch.from_numpy(evecs).to(device=device, dtype=dtype)
|
| 509 |
+
gradX = utils.sparse_np_to_torch(gradX).to(device=device, dtype=dtype)
|
| 510 |
+
gradY = utils.sparse_np_to_torch(gradY).to(device=device, dtype=dtype)
|
| 511 |
+
|
| 512 |
+
found = True
|
| 513 |
+
|
| 514 |
+
break
|
| 515 |
+
|
| 516 |
+
except FileNotFoundError:
|
| 517 |
+
print(" cache miss -- constructing operators")
|
| 518 |
+
break
|
| 519 |
+
|
| 520 |
+
except Exception as E:
|
| 521 |
+
print("unexpected error loading file: " + str(E))
|
| 522 |
+
print("-- constructing operators")
|
| 523 |
+
break
|
| 524 |
+
|
| 525 |
+
if not found:
|
| 526 |
+
|
| 527 |
+
# No matching entry found; recompute.
|
| 528 |
+
frames, mass, L, evals, evecs, gradX, gradY = compute_operators(verts, faces, k_eig, normals=normals)
|
| 529 |
+
|
| 530 |
+
dtype_np = np.float32
|
| 531 |
+
|
| 532 |
+
# Store it in the cache
|
| 533 |
+
if op_cache_dir is not None:
|
| 534 |
+
|
| 535 |
+
L_np = utils.sparse_torch_to_np(L).astype(dtype_np)
|
| 536 |
+
gradX_np = utils.sparse_torch_to_np(gradX).astype(dtype_np)
|
| 537 |
+
gradY_np = utils.sparse_torch_to_np(gradY).astype(dtype_np)
|
| 538 |
+
|
| 539 |
+
np.savez(
|
| 540 |
+
search_path,
|
| 541 |
+
verts=verts_np.astype(dtype_np),
|
| 542 |
+
frames=toNP(frames).astype(dtype_np),
|
| 543 |
+
faces=faces_np,
|
| 544 |
+
k_eig=k_eig,
|
| 545 |
+
mass=toNP(mass).astype(dtype_np),
|
| 546 |
+
L_data=L_np.data.astype(dtype_np),
|
| 547 |
+
L_indices=L_np.indices,
|
| 548 |
+
L_indptr=L_np.indptr,
|
| 549 |
+
L_shape=L_np.shape,
|
| 550 |
+
evals=toNP(evals).astype(dtype_np),
|
| 551 |
+
evecs=toNP(evecs).astype(dtype_np),
|
| 552 |
+
gradX_data=gradX_np.data.astype(dtype_np),
|
| 553 |
+
gradX_indices=gradX_np.indices,
|
| 554 |
+
gradX_indptr=gradX_np.indptr,
|
| 555 |
+
gradX_shape=gradX_np.shape,
|
| 556 |
+
gradY_data=gradY_np.data.astype(dtype_np),
|
| 557 |
+
gradY_indices=gradY_np.indices,
|
| 558 |
+
gradY_indptr=gradY_np.indptr,
|
| 559 |
+
gradY_shape=gradY_np.shape,
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
return frames, mass, L, evals, evecs, gradX, gradY
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def to_basis(values, basis, massvec):
|
| 566 |
+
"""
|
| 567 |
+
Transform data in to an orthonormal basis (where orthonormal is wrt to massvec)
|
| 568 |
+
Inputs:
|
| 569 |
+
- values: (B,V,D)
|
| 570 |
+
- basis: (B,V,K)
|
| 571 |
+
- massvec: (B,V)
|
| 572 |
+
Outputs:
|
| 573 |
+
- (B,K,D) transformed values
|
| 574 |
+
"""
|
| 575 |
+
basisT = basis.transpose(-2, -1)
|
| 576 |
+
return torch.matmul(basisT, values * massvec.unsqueeze(-1))
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def from_basis(values, basis):
|
| 580 |
+
"""
|
| 581 |
+
Transform data out of an orthonormal basis
|
| 582 |
+
Inputs:
|
| 583 |
+
- values: (K,D)
|
| 584 |
+
- basis: (V,K)
|
| 585 |
+
Outputs:
|
| 586 |
+
- (V,D) reconstructed values
|
| 587 |
+
"""
|
| 588 |
+
if values.is_complex() or basis.is_complex():
|
| 589 |
+
return utils.cmatmul(utils.ensure_complex(basis), utils.ensure_complex(values))
|
| 590 |
+
else:
|
| 591 |
+
return torch.matmul(basis, values)
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def compute_hks(evals, evecs, scales):
|
| 595 |
+
"""
|
| 596 |
+
Inputs:
|
| 597 |
+
- evals: (K) eigenvalues
|
| 598 |
+
- evecs: (V,K) values
|
| 599 |
+
- scales: (S) times
|
| 600 |
+
Outputs:
|
| 601 |
+
- (V,S) hks values
|
| 602 |
+
"""
|
| 603 |
+
|
| 604 |
+
# expand batch
|
| 605 |
+
if len(evals.shape) == 1:
|
| 606 |
+
expand_batch = True
|
| 607 |
+
evals = evals.unsqueeze(0)
|
| 608 |
+
evecs = evecs.unsqueeze(0)
|
| 609 |
+
scales = scales.unsqueeze(0)
|
| 610 |
+
else:
|
| 611 |
+
expand_batch = False
|
| 612 |
+
|
| 613 |
+
# TODO could be a matmul
|
| 614 |
+
power_coefs = torch.exp(-evals.unsqueeze(1) * scales.unsqueeze(-1)).unsqueeze(1) # (B,1,S,K)
|
| 615 |
+
terms = power_coefs * (evecs * evecs).unsqueeze(2) # (B,V,S,K)
|
| 616 |
+
|
| 617 |
+
out = torch.sum(terms, dim=-1) # (B,V,S)
|
| 618 |
+
|
| 619 |
+
if expand_batch:
|
| 620 |
+
return out.squeeze(0)
|
| 621 |
+
else:
|
| 622 |
+
return out
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def compute_hks_autoscale(evals, evecs, count):
|
| 626 |
+
# these scales roughly approximate those suggested in the hks paper
|
| 627 |
+
scales = torch.logspace(-2, 0., steps=count, device=evals.device, dtype=evals.dtype)
|
| 628 |
+
return compute_hks(evals, evecs, scales)
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def normalize_positions(pos, faces=None, method='mean', scale_method='max_rad'):
|
| 632 |
+
# center and unit-scale positions
|
| 633 |
+
|
| 634 |
+
if method == 'mean':
|
| 635 |
+
# center using the average point position
|
| 636 |
+
pos = (pos - torch.mean(pos, dim=-2, keepdim=True))
|
| 637 |
+
elif method == 'bbox':
|
| 638 |
+
# center via the middle of the axis-aligned bounding box
|
| 639 |
+
bbox_min = torch.min(pos, dim=-2).values
|
| 640 |
+
bbox_max = torch.max(pos, dim=-2).values
|
| 641 |
+
center = (bbox_max + bbox_min) / 2.
|
| 642 |
+
pos -= center.unsqueeze(-2)
|
| 643 |
+
else:
|
| 644 |
+
raise ValueError("unrecognized method")
|
| 645 |
+
|
| 646 |
+
if scale_method == 'max_rad':
|
| 647 |
+
scale = torch.max(norm(pos), dim=-1, keepdim=True).values.unsqueeze(-1)
|
| 648 |
+
pos = pos / scale
|
| 649 |
+
elif scale_method == 'area':
|
| 650 |
+
if faces is None:
|
| 651 |
+
raise ValueError("must pass faces for area normalization")
|
| 652 |
+
coords = pos[faces]
|
| 653 |
+
vec_A = coords[:, 1, :] - coords[:, 0, :]
|
| 654 |
+
vec_B = coords[:, 2, :] - coords[:, 0, :]
|
| 655 |
+
face_areas = torch.norm(torch.cross(vec_A, vec_B, dim=-1), dim=1) * 0.5
|
| 656 |
+
total_area = torch.sum(face_areas)
|
| 657 |
+
scale = (1. / torch.sqrt(total_area))
|
| 658 |
+
pos = pos * scale
|
| 659 |
+
else:
|
| 660 |
+
raise ValueError("unrecognized scale method")
|
| 661 |
+
return pos
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
# Finds the k nearest neighbors of source on target.
|
| 665 |
+
# Return is two tensors (distances, indices). Returned points will be sorted in increasing order of distance.
|
| 666 |
+
def find_knn(points_source, points_target, k, largest=False, omit_diagonal=False, method='brute'):
|
| 667 |
+
|
| 668 |
+
if omit_diagonal and points_source.shape[0] != points_target.shape[0]:
|
| 669 |
+
raise ValueError("omit_diagonal can only be used when source and target are same shape")
|
| 670 |
+
|
| 671 |
+
if method != 'cpu_kd' and points_source.shape[0] * points_target.shape[0] > 1e8:
|
| 672 |
+
method = 'cpu_kd'
|
| 673 |
+
print("switching to cpu_kd knn")
|
| 674 |
+
|
| 675 |
+
if method == 'brute':
|
| 676 |
+
|
| 677 |
+
# Expand so both are NxMx3 tensor
|
| 678 |
+
points_source_expand = points_source.unsqueeze(1)
|
| 679 |
+
points_source_expand = points_source_expand.expand(-1, points_target.shape[0], -1)
|
| 680 |
+
points_target_expand = points_target.unsqueeze(0)
|
| 681 |
+
points_target_expand = points_target_expand.expand(points_source.shape[0], -1, -1)
|
| 682 |
+
|
| 683 |
+
diff_mat = points_source_expand - points_target_expand
|
| 684 |
+
dist_mat = norm(diff_mat)
|
| 685 |
+
|
| 686 |
+
if omit_diagonal:
|
| 687 |
+
torch.diagonal(dist_mat)[:] = float('inf')
|
| 688 |
+
|
| 689 |
+
result = torch.topk(dist_mat, k=k, largest=largest, sorted=True)
|
| 690 |
+
return result
|
| 691 |
+
|
| 692 |
+
elif method == 'cpu_kd':
|
| 693 |
+
|
| 694 |
+
if largest:
|
| 695 |
+
raise ValueError("can't do largest with cpu_kd")
|
| 696 |
+
|
| 697 |
+
points_source_np = toNP(points_source)
|
| 698 |
+
points_target_np = toNP(points_target)
|
| 699 |
+
|
| 700 |
+
# Build the tree
|
| 701 |
+
kd_tree = sklearn.neighbors.KDTree(points_target_np)
|
| 702 |
+
|
| 703 |
+
k_search = k + 1 if omit_diagonal else k
|
| 704 |
+
_, neighbors = kd_tree.query(points_source_np, k=k_search)
|
| 705 |
+
|
| 706 |
+
if omit_diagonal:
|
| 707 |
+
# Mask out self element
|
| 708 |
+
mask = neighbors != np.arange(neighbors.shape[0])[:, np.newaxis]
|
| 709 |
+
|
| 710 |
+
# make sure we mask out exactly one element in each row, in rare case of many duplicate points
|
| 711 |
+
mask[np.sum(mask, axis=1) == mask.shape[1], -1] = False
|
| 712 |
+
|
| 713 |
+
neighbors = neighbors[mask].reshape((neighbors.shape[0], neighbors.shape[1] - 1))
|
| 714 |
+
|
| 715 |
+
inds = torch.tensor(neighbors, device=points_source.device, dtype=torch.int64)
|
| 716 |
+
dists = norm(points_source.unsqueeze(1).expand(-1, k, -1) - points_target[inds])
|
| 717 |
+
|
| 718 |
+
return dists, inds
|
| 719 |
+
|
| 720 |
+
else:
|
| 721 |
+
raise ValueError("unrecognized method")
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def farthest_point_sampling(points, n_sample):
|
| 725 |
+
# Torch in, torch out. Returns a |V| mask with n_sample elements set to true.
|
| 726 |
+
|
| 727 |
+
N = points.shape[0]
|
| 728 |
+
if (n_sample > N):
|
| 729 |
+
raise ValueError("not enough points to sample")
|
| 730 |
+
|
| 731 |
+
chosen_mask = torch.zeros(N, dtype=torch.bool, device=points.device)
|
| 732 |
+
min_dists = torch.ones(N, dtype=points.dtype, device=points.device) * float('inf')
|
| 733 |
+
|
| 734 |
+
# pick the centermost first point
|
| 735 |
+
points = normalize_positions(points)
|
| 736 |
+
i = torch.min(norm2(points), dim=0).indices
|
| 737 |
+
chosen_mask[i] = True
|
| 738 |
+
|
| 739 |
+
for _ in range(n_sample - 1):
|
| 740 |
+
|
| 741 |
+
# update distance
|
| 742 |
+
dists = norm2(points[i, :].unsqueeze(0) - points)
|
| 743 |
+
min_dists = torch.minimum(dists, min_dists)
|
| 744 |
+
|
| 745 |
+
# take the farthest
|
| 746 |
+
i = torch.max(min_dists, dim=0).indices.item()
|
| 747 |
+
chosen_mask[i] = True
|
| 748 |
+
|
| 749 |
+
return chosen_mask
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def geodesic_label_errors(target_verts,
|
| 753 |
+
target_faces,
|
| 754 |
+
pred_labels,
|
| 755 |
+
gt_labels,
|
| 756 |
+
normalization='diameter',
|
| 757 |
+
geodesic_cache_dir=None):
|
| 758 |
+
"""
|
| 759 |
+
Return a vector of distances between predicted and ground-truth lables (normalized by geodesic diameter or area)
|
| 760 |
+
|
| 761 |
+
This method is SLOW when it needs to recompute geodesic distances.
|
| 762 |
+
"""
|
| 763 |
+
|
| 764 |
+
# move all to numpy cpu
|
| 765 |
+
target_verts = toNP(target_verts)
|
| 766 |
+
target_faces = toNP(target_faces)
|
| 767 |
+
|
| 768 |
+
pred_labels = toNP(pred_labels)
|
| 769 |
+
gt_labels = toNP(gt_labels)
|
| 770 |
+
|
| 771 |
+
dists = get_all_pairs_geodesic_distance(target_verts, target_faces, geodesic_cache_dir)
|
| 772 |
+
|
| 773 |
+
result_dists = dists[pred_labels, gt_labels]
|
| 774 |
+
|
| 775 |
+
if normalization == 'diameter':
|
| 776 |
+
geodesic_diameter = np.max(dists)
|
| 777 |
+
normalized_result_dists = result_dists / geodesic_diameter
|
| 778 |
+
elif normalization == 'area':
|
| 779 |
+
total_area = torch.sum(face_area(torch.tensor(target_verts), torch.tensor(target_faces)))
|
| 780 |
+
normalized_result_dists = result_dists / torch.sqrt(total_area)
|
| 781 |
+
else:
|
| 782 |
+
raise ValueError('unrecognized normalization')
|
| 783 |
+
|
| 784 |
+
return normalized_result_dists
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
# This function and the helper class below are to support parallel computation of all-pairs geodesic distance
|
| 788 |
+
def all_pairs_geodesic_worker(verts, faces, i):
|
| 789 |
+
import igl
|
| 790 |
+
|
| 791 |
+
N = verts.shape[0]
|
| 792 |
+
|
| 793 |
+
# TODO: this re-does a ton of work, since it is called independently each time. Some custom C++ code could surely make it faster.
|
| 794 |
+
sources = np.array([i])[:, np.newaxis]
|
| 795 |
+
targets = np.arange(N)[:, np.newaxis]
|
| 796 |
+
dist_vec = igl.exact_geodesic(verts, faces, sources, targets)
|
| 797 |
+
|
| 798 |
+
return dist_vec
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
class AllPairsGeodesicEngine(object):
|
| 802 |
+
|
| 803 |
+
def __init__(self, verts, faces):
|
| 804 |
+
self.verts = verts
|
| 805 |
+
self.faces = faces
|
| 806 |
+
|
| 807 |
+
def __call__(self, i):
|
| 808 |
+
return all_pairs_geodesic_worker(self.verts, self.faces, i)
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
def get_all_pairs_geodesic_distance(verts_np, faces_np, geodesic_cache_dir=None):
|
| 812 |
+
"""
|
| 813 |
+
Return a gigantic VxV dense matrix containing the all-pairs geodesic distance matrix. Internally caches, recomputing only if necessary.
|
| 814 |
+
|
| 815 |
+
(numpy in, numpy out)
|
| 816 |
+
"""
|
| 817 |
+
|
| 818 |
+
# need libigl for geodesic call
|
| 819 |
+
try:
|
| 820 |
+
import igl
|
| 821 |
+
except ImportError as e:
|
| 822 |
+
raise ImportError("Must have python libigl installed for all-pairs geodesics. `conda install -c conda-forge igl`")
|
| 823 |
+
|
| 824 |
+
# Check the cache
|
| 825 |
+
found = False
|
| 826 |
+
if geodesic_cache_dir is not None:
|
| 827 |
+
utils.ensure_dir_exists(geodesic_cache_dir)
|
| 828 |
+
hash_key_str = str(utils.hash_arrays((verts_np, faces_np)))
|
| 829 |
+
# print("Building operators for input with hash: " + hash_key_str)
|
| 830 |
+
|
| 831 |
+
# Search through buckets with matching hashes. When the loop exits, this
|
| 832 |
+
# is the bucket index of the file we should write to.
|
| 833 |
+
i_cache_search = 0
|
| 834 |
+
while True:
|
| 835 |
+
|
| 836 |
+
# Form the name of the file to check
|
| 837 |
+
search_path = os.path.join(geodesic_cache_dir, hash_key_str + "_" + str(i_cache_search) + ".npz")
|
| 838 |
+
|
| 839 |
+
try:
|
| 840 |
+
npzfile = np.load(search_path, allow_pickle=True)
|
| 841 |
+
cache_verts = npzfile["verts"]
|
| 842 |
+
cache_faces = npzfile["faces"]
|
| 843 |
+
|
| 844 |
+
# If the cache doesn't match, keep looking
|
| 845 |
+
if (not np.array_equal(verts_np, cache_verts)) or (not np.array_equal(faces_np, cache_faces)):
|
| 846 |
+
i_cache_search += 1
|
| 847 |
+
continue
|
| 848 |
+
|
| 849 |
+
# This entry matches! Return it.
|
| 850 |
+
found = True
|
| 851 |
+
result_dists = npzfile["dist"]
|
| 852 |
+
break
|
| 853 |
+
|
| 854 |
+
except FileNotFoundError:
|
| 855 |
+
break
|
| 856 |
+
|
| 857 |
+
if not found:
|
| 858 |
+
|
| 859 |
+
print("Computing all-pairs geodesic distance (warning: SLOW!)")
|
| 860 |
+
|
| 861 |
+
# Not found, compute from scratch
|
| 862 |
+
# warning: slowwwwwww
|
| 863 |
+
|
| 864 |
+
N = verts_np.shape[0]
|
| 865 |
+
|
| 866 |
+
try:
|
| 867 |
+
pool = Pool(None) # on 8 processors
|
| 868 |
+
engine = AllPairsGeodesicEngine(verts_np, faces_np)
|
| 869 |
+
outputs = pool.map(engine, range(N))
|
| 870 |
+
finally: # To make sure processes are closed in the end, even if errors happen
|
| 871 |
+
pool.close()
|
| 872 |
+
pool.join()
|
| 873 |
+
|
| 874 |
+
result_dists = np.array(outputs)
|
| 875 |
+
|
| 876 |
+
# replace any failed values with nan
|
| 877 |
+
result_dists = np.nan_to_num(result_dists, nan=np.nan, posinf=np.nan, neginf=np.nan)
|
| 878 |
+
|
| 879 |
+
# we expect that this should be a symmetric matrix, but it might not be. Take the min of the symmetric values to make it symmetric
|
| 880 |
+
result_dists = np.fmin(result_dists, np.transpose(result_dists))
|
| 881 |
+
|
| 882 |
+
# on rare occaisions MMP fails, yielding nan/inf; set it to the largest non-failed value if so
|
| 883 |
+
max_dist = np.nanmax(result_dists)
|
| 884 |
+
result_dists = np.nan_to_num(result_dists, nan=max_dist, posinf=max_dist, neginf=max_dist)
|
| 885 |
+
|
| 886 |
+
print("...finished computing all-pairs geodesic distance")
|
| 887 |
+
|
| 888 |
+
# put it in the cache if possible
|
| 889 |
+
if geodesic_cache_dir is not None:
|
| 890 |
+
|
| 891 |
+
print("saving geodesic distances to cache: " + str(geodesic_cache_dir))
|
| 892 |
+
|
| 893 |
+
# TODO we're potentially saving a double precision but only using a single
|
| 894 |
+
# precision here; could save storage by always saving as floats
|
| 895 |
+
np.savez(search_path, verts=verts_np, faces=faces_np, dist=result_dists)
|
| 896 |
+
|
| 897 |
+
return result_dists
|
shape_models/layers.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import os.path
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import scipy
|
| 7 |
+
import scipy.sparse.linalg as sla
|
| 8 |
+
# ^^^ we NEED to import scipy before torch, or it crashes :(
|
| 9 |
+
# (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def to_basis(values, basis, massvec):
|
| 18 |
+
"""
|
| 19 |
+
Transform data in to an orthonormal basis (where orthonormal is wrt to massvec)
|
| 20 |
+
Inputs:
|
| 21 |
+
- values: (B,V,D)
|
| 22 |
+
- basis: (B,V,K)
|
| 23 |
+
- massvec: (B,V)
|
| 24 |
+
Outputs:
|
| 25 |
+
- (B,K,D) transformed values
|
| 26 |
+
"""
|
| 27 |
+
basisT = basis.transpose(-2, -1)
|
| 28 |
+
return torch.matmul(basisT, values * massvec.unsqueeze(-1))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def from_basis(values, basis):
|
| 32 |
+
"""
|
| 33 |
+
Transform data out of an orthonormal basis
|
| 34 |
+
Inputs:
|
| 35 |
+
- values: (K,D)
|
| 36 |
+
- basis: (V,K)
|
| 37 |
+
Outputs:
|
| 38 |
+
- (V,D) reconstructed values
|
| 39 |
+
"""
|
| 40 |
+
if values.is_complex() or basis.is_complex():
|
| 41 |
+
return utils.cmatmul(utils.ensure_complex(basis), utils.ensure_complex(values))
|
| 42 |
+
else:
|
| 43 |
+
return torch.matmul(basis, values)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class LearnedTimeDiffusion(nn.Module):
|
| 47 |
+
"""
|
| 48 |
+
Applies diffusion with learned per-channel t.
|
| 49 |
+
|
| 50 |
+
In the spectral domain this becomes
|
| 51 |
+
f_out = e ^ (lambda_i t) f_in
|
| 52 |
+
|
| 53 |
+
Inputs:
|
| 54 |
+
- values: (V,C) in the spectral domain
|
| 55 |
+
- L: (V,V) sparse laplacian
|
| 56 |
+
- evals: (K) eigenvalues
|
| 57 |
+
- mass: (V) mass matrix diagonal
|
| 58 |
+
|
| 59 |
+
(note: L/evals may be omitted as None depending on method)
|
| 60 |
+
Outputs:
|
| 61 |
+
- (V,C) diffused values
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, C_inout, method='spectral'):
|
| 65 |
+
super(LearnedTimeDiffusion, self).__init__()
|
| 66 |
+
self.C_inout = C_inout
|
| 67 |
+
self.diffusion_time = nn.Parameter(torch.Tensor(C_inout)) # (C)
|
| 68 |
+
self.method = method # one of ['spectral', 'implicit_dense']
|
| 69 |
+
|
| 70 |
+
nn.init.constant_(self.diffusion_time, 0.0)
|
| 71 |
+
|
| 72 |
+
def forward(self, x, L, mass, evals, evecs):
|
| 73 |
+
|
| 74 |
+
# project times to the positive halfspace
|
| 75 |
+
# (and away from 0 in the incredibly rare chance that they get stuck)
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
self.diffusion_time.data = torch.clamp(self.diffusion_time, min=1e-8)
|
| 78 |
+
|
| 79 |
+
if x.shape[-1] != self.C_inout:
|
| 80 |
+
raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
|
| 81 |
+
x.shape, self.C_inout))
|
| 82 |
+
|
| 83 |
+
if self.method == 'spectral':
|
| 84 |
+
|
| 85 |
+
# Transform to spectral
|
| 86 |
+
x_spec = to_basis(x, evecs, mass)
|
| 87 |
+
|
| 88 |
+
# Diffuse
|
| 89 |
+
time = self.diffusion_time
|
| 90 |
+
diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * time.unsqueeze(0))
|
| 91 |
+
x_diffuse_spec = diffusion_coefs * x_spec
|
| 92 |
+
|
| 93 |
+
# Transform back to per-vertex
|
| 94 |
+
x_diffuse = from_basis(x_diffuse_spec, evecs)
|
| 95 |
+
|
| 96 |
+
elif self.method == 'implicit_dense':
|
| 97 |
+
V = x.shape[-2]
|
| 98 |
+
|
| 99 |
+
# Form the dense matrices (M + tL) with dims (B,C,V,V)
|
| 100 |
+
mat_dense = L.to_dense().unsqueeze(1).expand(-1, self.C_inout, V, V).clone()
|
| 101 |
+
mat_dense *= self.diffusion_time.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
| 102 |
+
mat_dense += torch.diag_embed(mass).unsqueeze(1)
|
| 103 |
+
|
| 104 |
+
# Factor the system
|
| 105 |
+
cholesky_factors = torch.linalg.cholesky(mat_dense)
|
| 106 |
+
|
| 107 |
+
# Solve the system
|
| 108 |
+
rhs = x * mass.unsqueeze(-1)
|
| 109 |
+
rhsT = torch.transpose(rhs, 1, 2).unsqueeze(-1)
|
| 110 |
+
sols = torch.cholesky_solve(rhsT, cholesky_factors)
|
| 111 |
+
x_diffuse = torch.transpose(sols.squeeze(-1), 1, 2)
|
| 112 |
+
|
| 113 |
+
else:
|
| 114 |
+
raise ValueError("unrecognized method")
|
| 115 |
+
|
| 116 |
+
return x_diffuse
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class SpatialGradientFeatures(nn.Module):
|
| 120 |
+
"""
|
| 121 |
+
Compute dot-products between input vectors. Uses a learned complex-linear layer to keep dimension down.
|
| 122 |
+
|
| 123 |
+
Input:
|
| 124 |
+
- vectors: (V,C,2)
|
| 125 |
+
Output:
|
| 126 |
+
- dots: (V,C) dots
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(self, C_inout, with_gradient_rotations=True):
|
| 130 |
+
super(SpatialGradientFeatures, self).__init__()
|
| 131 |
+
|
| 132 |
+
self.C_inout = C_inout
|
| 133 |
+
self.with_gradient_rotations = with_gradient_rotations
|
| 134 |
+
|
| 135 |
+
if (self.with_gradient_rotations):
|
| 136 |
+
self.A_re = nn.Linear(self.C_inout, self.C_inout, bias=False)
|
| 137 |
+
self.A_im = nn.Linear(self.C_inout, self.C_inout, bias=False)
|
| 138 |
+
else:
|
| 139 |
+
self.A = nn.Linear(self.C_inout, self.C_inout, bias=False)
|
| 140 |
+
|
| 141 |
+
# self.norm = nn.InstanceNorm1d(C_inout)
|
| 142 |
+
|
| 143 |
+
def forward(self, vectors):
|
| 144 |
+
|
| 145 |
+
vectorsA = vectors # (V,C)
|
| 146 |
+
|
| 147 |
+
if self.with_gradient_rotations:
|
| 148 |
+
vectorsBreal = self.A_re(vectors[..., 0]) - self.A_im(vectors[..., 1])
|
| 149 |
+
vectorsBimag = self.A_re(vectors[..., 1]) + self.A_im(vectors[..., 0])
|
| 150 |
+
else:
|
| 151 |
+
vectorsBreal = self.A(vectors[..., 0])
|
| 152 |
+
vectorsBimag = self.A(vectors[..., 1])
|
| 153 |
+
|
| 154 |
+
dots = vectorsA[..., 0] * vectorsBreal + vectorsA[..., 1] * vectorsBimag
|
| 155 |
+
|
| 156 |
+
return torch.tanh(dots)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class MiniMLP(nn.Sequential):
|
| 160 |
+
'''
|
| 161 |
+
A simple MLP with configurable hidden layer sizes.
|
| 162 |
+
'''
|
| 163 |
+
|
| 164 |
+
def __init__(self, layer_sizes, dropout=False, activation=nn.ReLU, name="miniMLP"):
|
| 165 |
+
super(MiniMLP, self).__init__()
|
| 166 |
+
|
| 167 |
+
for i in range(len(layer_sizes) - 1):
|
| 168 |
+
is_last = (i + 2 == len(layer_sizes))
|
| 169 |
+
|
| 170 |
+
if dropout and i > 0:
|
| 171 |
+
self.add_module(name + "_mlp_layer_dropout_{:03d}".format(i), nn.Dropout(p=.5))
|
| 172 |
+
|
| 173 |
+
# Affine map
|
| 174 |
+
self.add_module(
|
| 175 |
+
name + "_mlp_layer_{:03d}".format(i),
|
| 176 |
+
nn.Linear(
|
| 177 |
+
layer_sizes[i],
|
| 178 |
+
layer_sizes[i + 1],
|
| 179 |
+
),
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# Nonlinearity
|
| 183 |
+
# (but not on the last layer)
|
| 184 |
+
if not is_last:
|
| 185 |
+
self.add_module(name + "_mlp_act_{:03d}".format(i), activation())
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class DiffusionNetBlock(nn.Module):
|
| 189 |
+
"""
|
| 190 |
+
Inputs and outputs are defined at vertices
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
def __init__(self,
|
| 194 |
+
C_width,
|
| 195 |
+
mlp_hidden_dims,
|
| 196 |
+
dropout=True,
|
| 197 |
+
diffusion_method='spectral',
|
| 198 |
+
with_gradient_features=True,
|
| 199 |
+
with_gradient_rotations=True):
|
| 200 |
+
super(DiffusionNetBlock, self).__init__()
|
| 201 |
+
|
| 202 |
+
# Specified dimensions
|
| 203 |
+
self.C_width = C_width
|
| 204 |
+
self.mlp_hidden_dims = mlp_hidden_dims
|
| 205 |
+
|
| 206 |
+
self.dropout = dropout
|
| 207 |
+
self.with_gradient_features = with_gradient_features
|
| 208 |
+
self.with_gradient_rotations = with_gradient_rotations
|
| 209 |
+
|
| 210 |
+
# Diffusion block
|
| 211 |
+
self.diffusion = LearnedTimeDiffusion(self.C_width, method=diffusion_method)
|
| 212 |
+
|
| 213 |
+
self.MLP_C = 2 * self.C_width
|
| 214 |
+
|
| 215 |
+
if self.with_gradient_features:
|
| 216 |
+
self.gradient_features = SpatialGradientFeatures(self.C_width, with_gradient_rotations=self.with_gradient_rotations)
|
| 217 |
+
self.MLP_C += self.C_width
|
| 218 |
+
|
| 219 |
+
# MLPs
|
| 220 |
+
self.mlp = MiniMLP([self.MLP_C] + self.mlp_hidden_dims + [self.C_width], dropout=self.dropout)
|
| 221 |
+
|
| 222 |
+
def forward(self, x_in, mass, L, evals, evecs, gradX, gradY):
|
| 223 |
+
|
| 224 |
+
# Manage dimensions
|
| 225 |
+
B = x_in.shape[0] # batch dimension
|
| 226 |
+
if x_in.shape[-1] != self.C_width:
|
| 227 |
+
raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
|
| 228 |
+
x_in.shape, self.C_width))
|
| 229 |
+
|
| 230 |
+
# Diffusion block
|
| 231 |
+
x_diffuse = self.diffusion(x_in, L, mass, evals, evecs)
|
| 232 |
+
|
| 233 |
+
# Compute gradient features, if using
|
| 234 |
+
if self.with_gradient_features:
|
| 235 |
+
|
| 236 |
+
# Compute gradients
|
| 237 |
+
x_grads = [
|
| 238 |
+
] # Manually loop over the batch (if there is a batch dimension) since torch.mm() doesn't support batching
|
| 239 |
+
for b in range(B):
|
| 240 |
+
# gradient after diffusion
|
| 241 |
+
x_gradX = torch.mm(gradX[b, ...], x_diffuse[b, ...])
|
| 242 |
+
x_gradY = torch.mm(gradY[b, ...], x_diffuse[b, ...])
|
| 243 |
+
|
| 244 |
+
x_grads.append(torch.stack((x_gradX, x_gradY), dim=-1))
|
| 245 |
+
x_grad = torch.stack(x_grads, dim=0)
|
| 246 |
+
|
| 247 |
+
# Evaluate gradient features
|
| 248 |
+
x_grad_features = self.gradient_features(x_grad)
|
| 249 |
+
|
| 250 |
+
# Stack inputs to mlp
|
| 251 |
+
feature_combined = torch.cat((x_in, x_diffuse, x_grad_features), dim=-1)
|
| 252 |
+
else:
|
| 253 |
+
# Stack inputs to mlp
|
| 254 |
+
feature_combined = torch.cat((x_in, x_diffuse), dim=-1)
|
| 255 |
+
|
| 256 |
+
# Apply the mlp
|
| 257 |
+
x0_out = self.mlp(feature_combined)
|
| 258 |
+
|
| 259 |
+
# Skip connection
|
| 260 |
+
x0_out = x0_out + x_in
|
| 261 |
+
|
| 262 |
+
return x0_out
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class DiffusionNet(nn.Module):
|
| 266 |
+
|
| 267 |
+
def __init__(self,
|
| 268 |
+
C_in,
|
| 269 |
+
C_out,
|
| 270 |
+
C_width=128,
|
| 271 |
+
N_block=4,
|
| 272 |
+
last_activation=None,
|
| 273 |
+
outputs_at='vertices',
|
| 274 |
+
mlp_hidden_dims=None,
|
| 275 |
+
dropout=True,
|
| 276 |
+
with_gradient_features=True,
|
| 277 |
+
with_gradient_rotations=True,
|
| 278 |
+
diffusion_method='spectral',
|
| 279 |
+
num_eigenbasis=128):
|
| 280 |
+
"""
|
| 281 |
+
Construct a DiffusionNet.
|
| 282 |
+
|
| 283 |
+
Parameters:
|
| 284 |
+
C_in (int): input dimension
|
| 285 |
+
C_out (int): output dimension
|
| 286 |
+
last_activation (func) a function to apply to the final outputs of the network, such as torch.nn.functional.log_softmax (default: None)
|
| 287 |
+
outputs_at (string) produce outputs at various mesh elements by averaging from vertices. One of ['vertices', 'edges', 'faces', 'global_mean']. (default 'vertices', aka points for a point cloud)
|
| 288 |
+
C_width (int): dimension of internal DiffusionNet blocks (default: 128)
|
| 289 |
+
N_block (int): number of DiffusionNet blocks (default: 4)
|
| 290 |
+
mlp_hidden_dims (list of int): a list of hidden layer sizes for MLPs (default: [C_width, C_width])
|
| 291 |
+
dropout (bool): if True, internal MLPs use dropout (default: True)
|
| 292 |
+
diffusion_method (string): how to evaluate diffusion, one of ['spectral', 'implicit_dense']. If implicit_dense is used, can set k_eig=0, saving precompute.
|
| 293 |
+
with_gradient_features (bool): if True, use gradient features (default: True)
|
| 294 |
+
with_gradient_rotations (bool): if True, use gradient also learn a rotation of each gradient. Set to True if your surface has consistently oriented normals, and False otherwise (default: True)
|
| 295 |
+
num_eigenbasis (int): for trunking the eigenvalues eigenvectors
|
| 296 |
+
"""
|
| 297 |
+
|
| 298 |
+
super(DiffusionNet, self).__init__()
|
| 299 |
+
|
| 300 |
+
## Store parameters
|
| 301 |
+
|
| 302 |
+
# Basic parameters
|
| 303 |
+
self.C_in = C_in
|
| 304 |
+
self.C_out = C_out
|
| 305 |
+
self.C_width = C_width
|
| 306 |
+
self.N_block = N_block
|
| 307 |
+
|
| 308 |
+
# Outputs
|
| 309 |
+
self.last_activation = last_activation
|
| 310 |
+
self.outputs_at = outputs_at
|
| 311 |
+
if outputs_at not in ['vertices', 'edges', 'faces', 'global_mean']:
|
| 312 |
+
raise ValueError("invalid setting for outputs_at")
|
| 313 |
+
|
| 314 |
+
# MLP options
|
| 315 |
+
if mlp_hidden_dims == None:
|
| 316 |
+
mlp_hidden_dims = [C_width, C_width]
|
| 317 |
+
self.mlp_hidden_dims = mlp_hidden_dims
|
| 318 |
+
self.dropout = dropout
|
| 319 |
+
|
| 320 |
+
# Diffusion
|
| 321 |
+
self.diffusion_method = diffusion_method
|
| 322 |
+
if diffusion_method not in ['spectral', 'implicit_dense']:
|
| 323 |
+
raise ValueError("invalid setting for diffusion_method")
|
| 324 |
+
self.num_eigenbasis = num_eigenbasis
|
| 325 |
+
|
| 326 |
+
# Gradient features
|
| 327 |
+
self.with_gradient_features = with_gradient_features
|
| 328 |
+
self.with_gradient_rotations = with_gradient_rotations
|
| 329 |
+
|
| 330 |
+
## Set up the network
|
| 331 |
+
|
| 332 |
+
# First and last affine layers
|
| 333 |
+
self.first_lin = nn.Linear(C_in, C_width)
|
| 334 |
+
self.last_lin = nn.Linear(C_width, C_out)
|
| 335 |
+
|
| 336 |
+
# DiffusionNet blocks
|
| 337 |
+
self.blocks = []
|
| 338 |
+
for i_block in range(self.N_block):
|
| 339 |
+
block = DiffusionNetBlock(C_width=C_width,
|
| 340 |
+
mlp_hidden_dims=mlp_hidden_dims,
|
| 341 |
+
dropout=dropout,
|
| 342 |
+
diffusion_method=diffusion_method,
|
| 343 |
+
with_gradient_features=with_gradient_features,
|
| 344 |
+
with_gradient_rotations=with_gradient_rotations)
|
| 345 |
+
|
| 346 |
+
self.blocks.append(block)
|
| 347 |
+
self.add_module("block_" + str(i_block), self.blocks[-1])
|
| 348 |
+
|
| 349 |
+
def forward(self, x_in, mass, L=None, evals=None, evecs=None, gradX=None, gradY=None, edges=None, faces=None):
|
| 350 |
+
"""
|
| 351 |
+
A forward pass on the DiffusionNet.
|
| 352 |
+
|
| 353 |
+
In the notation below, dimension are:
|
| 354 |
+
- C is the input channel dimension (C_in on construction)
|
| 355 |
+
- C_OUT is the output channel dimension (C_out on construction)
|
| 356 |
+
- N is the number of vertices/points, which CAN be different for each forward pass
|
| 357 |
+
- B is an OPTIONAL batch dimension
|
| 358 |
+
- K_EIG is the number of eigenvalues used for spectral acceleration
|
| 359 |
+
Generally, our data layout it is [N,C] or [B,N,C].
|
| 360 |
+
|
| 361 |
+
Call get_operators() to generate geometric quantities mass/L/evals/evecs/gradX/gradY. Note that depending on the options for the DiffusionNet, not all are strictly necessary.
|
| 362 |
+
|
| 363 |
+
Parameters:
|
| 364 |
+
x_in (tensor): Input features, dimension [N,C] or [B,N,C]
|
| 365 |
+
mass (tensor): Mass vector, dimension [N] or [B,N]
|
| 366 |
+
L (tensor): Laplace matrix, sparse tensor with dimension [N,N] or [B,N,N]
|
| 367 |
+
evals (tensor): Eigenvalues of Laplace matrix, dimension [K_EIG] or [B,K_EIG]
|
| 368 |
+
evecs (tensor): Eigenvectors of Laplace matrix, dimension [N,K_EIG] or [B,N,K_EIG]
|
| 369 |
+
gradX (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
|
| 370 |
+
gradY (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
x_out (tensor): Output with dimension [N,C_out] or [B,N,C_out]
|
| 374 |
+
"""
|
| 375 |
+
|
| 376 |
+
## Check dimensions, and append batch dimension if not given
|
| 377 |
+
if x_in.shape[-1] != self.C_in:
|
| 378 |
+
raise ValueError("DiffusionNet was constructed with C_in={}, but x_in has last dim={}".format(
|
| 379 |
+
self.C_in, x_in.shape[-1]))
|
| 380 |
+
N = x_in.shape[-2]
|
| 381 |
+
if len(x_in.shape) == 2:
|
| 382 |
+
appended_batch_dim = True
|
| 383 |
+
|
| 384 |
+
# add a batch dim to all inputs
|
| 385 |
+
x_in = x_in.unsqueeze(0)
|
| 386 |
+
mass = mass.unsqueeze(0)
|
| 387 |
+
if L != None:
|
| 388 |
+
L = L.unsqueeze(0)
|
| 389 |
+
if evals != None:
|
| 390 |
+
evals = evals.unsqueeze(0)
|
| 391 |
+
if evecs != None:
|
| 392 |
+
evecs = evecs.unsqueeze(0)
|
| 393 |
+
if gradX != None:
|
| 394 |
+
gradX = gradX.unsqueeze(0)
|
| 395 |
+
if gradY != None:
|
| 396 |
+
gradY = gradY.unsqueeze(0)
|
| 397 |
+
if edges != None:
|
| 398 |
+
edges = edges.unsqueeze(0)
|
| 399 |
+
if faces != None:
|
| 400 |
+
faces = faces.unsqueeze(0)
|
| 401 |
+
|
| 402 |
+
elif len(x_in.shape) == 3:
|
| 403 |
+
appended_batch_dim = False
|
| 404 |
+
|
| 405 |
+
else:
|
| 406 |
+
raise ValueError("x_in should be tensor with shape [N,C] or [B,N,C]")
|
| 407 |
+
|
| 408 |
+
evals = evals[..., :self.num_eigenbasis]
|
| 409 |
+
evecs = evecs[..., :self.num_eigenbasis]
|
| 410 |
+
|
| 411 |
+
# Apply the first linear layer
|
| 412 |
+
x = self.first_lin(x_in)
|
| 413 |
+
|
| 414 |
+
# Apply each of the blocks
|
| 415 |
+
for b in self.blocks:
|
| 416 |
+
x = b(x, mass, L, evals, evecs, gradX, gradY)
|
| 417 |
+
|
| 418 |
+
# Apply the last linear layer
|
| 419 |
+
x = self.last_lin(x)
|
| 420 |
+
|
| 421 |
+
# Remap output to faces/edges if requested
|
| 422 |
+
if self.outputs_at == 'vertices':
|
| 423 |
+
x_out = x
|
| 424 |
+
|
| 425 |
+
elif self.outputs_at == 'edges':
|
| 426 |
+
# Remap to edges
|
| 427 |
+
x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 2)
|
| 428 |
+
edges_gather = edges.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
|
| 429 |
+
xe = torch.gather(x_gather, 1, edges_gather)
|
| 430 |
+
x_out = torch.mean(xe, dim=-1)
|
| 431 |
+
|
| 432 |
+
elif self.outputs_at == 'faces':
|
| 433 |
+
# Remap to faces
|
| 434 |
+
x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 3)
|
| 435 |
+
faces_gather = faces.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
|
| 436 |
+
xf = torch.gather(x_gather, 1, faces_gather)
|
| 437 |
+
x_out = torch.mean(xf, dim=-1)
|
| 438 |
+
|
| 439 |
+
elif self.outputs_at == 'global_mean':
|
| 440 |
+
# Produce a single global mean ouput.
|
| 441 |
+
# Using a weighted mean according to the point mass/area is discretization-invariant.
|
| 442 |
+
# (A naive mean is not discretization-invariant; it could be affected by sampling a region more densely)
|
| 443 |
+
x_out = torch.sum(x * mass.unsqueeze(-1), dim=-2) / torch.sum(mass, dim=-1, keepdim=True)
|
| 444 |
+
|
| 445 |
+
# Apply last nonlinearity if specified
|
| 446 |
+
if self.last_activation != None:
|
| 447 |
+
x_out = self.last_activation(x_out)
|
| 448 |
+
|
| 449 |
+
# Remove batch dim if we added it
|
| 450 |
+
if appended_batch_dim:
|
| 451 |
+
x_out = x_out.squeeze(0)
|
| 452 |
+
|
| 453 |
+
return x_out
|
shape_models/utils.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import hashlib
|
| 7 |
+
import numpy as np
|
| 8 |
+
import scipy
|
| 9 |
+
|
| 10 |
+
# == Pytorch things
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def toNP(x):
|
| 14 |
+
"""
|
| 15 |
+
Really, definitely convert a torch tensor to a numpy array
|
| 16 |
+
"""
|
| 17 |
+
return x.detach().to(torch.device('cpu')).numpy()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def label_smoothing_log_loss(pred, labels, smoothing=0.0):
|
| 21 |
+
n_class = pred.shape[-1]
|
| 22 |
+
one_hot = torch.zeros_like(pred)
|
| 23 |
+
one_hot[labels] = 1.
|
| 24 |
+
one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
|
| 25 |
+
loss = -(one_hot * pred).sum(dim=-1).mean()
|
| 26 |
+
return loss
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Randomly rotate points.
|
| 30 |
+
# Torch in, torch out
|
| 31 |
+
# Note fornow, builds rotation matrix on CPU.
|
| 32 |
+
def random_rotate_points(pts, randgen=None):
|
| 33 |
+
R = random_rotation_matrix(randgen)
|
| 34 |
+
R = torch.from_numpy(R).to(device=pts.device, dtype=pts.dtype)
|
| 35 |
+
return torch.matmul(pts, R)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def random_rotate_points_y(pts):
|
| 39 |
+
angles = torch.rand(1, device=pts.device, dtype=pts.dtype) * (2. * np.pi)
|
| 40 |
+
rot_mats = torch.zeros(3, 3, device=pts.device, dtype=pts.dtype)
|
| 41 |
+
rot_mats[0, 0] = torch.cos(angles)
|
| 42 |
+
rot_mats[0, 2] = torch.sin(angles)
|
| 43 |
+
rot_mats[2, 0] = -torch.sin(angles)
|
| 44 |
+
rot_mats[2, 2] = torch.cos(angles)
|
| 45 |
+
rot_mats[1, 1] = 1.
|
| 46 |
+
|
| 47 |
+
pts = torch.matmul(pts, rot_mats)
|
| 48 |
+
return pts
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Numpy things
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Numpy sparse matrix to pytorch
|
| 55 |
+
def sparse_np_to_torch(A):
|
| 56 |
+
Acoo = A.tocoo()
|
| 57 |
+
values = Acoo.data
|
| 58 |
+
indices = np.vstack((Acoo.row, Acoo.col))
|
| 59 |
+
shape = Acoo.shape
|
| 60 |
+
return torch.sparse.FloatTensor(torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape)).coalesce()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Pytorch sparse to numpy csc matrix
|
| 64 |
+
def sparse_torch_to_np(A):
|
| 65 |
+
if len(A.shape) != 2:
|
| 66 |
+
raise RuntimeError("should be a matrix-shaped type; dim is : " + str(A.shape))
|
| 67 |
+
|
| 68 |
+
indices = toNP(A.indices())
|
| 69 |
+
values = toNP(A.values())
|
| 70 |
+
|
| 71 |
+
mat = scipy.sparse.coo_matrix((values, indices), shape=A.shape).tocsc()
|
| 72 |
+
|
| 73 |
+
return mat
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Hash a list of numpy arrays
|
| 77 |
+
def hash_arrays(arrs):
|
| 78 |
+
running_hash = hashlib.sha1()
|
| 79 |
+
for arr in arrs:
|
| 80 |
+
binarr = arr.view(np.uint8)
|
| 81 |
+
running_hash.update(binarr)
|
| 82 |
+
return running_hash.hexdigest()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def random_rotation_matrix(randgen=None):
|
| 86 |
+
"""
|
| 87 |
+
Creates a random rotation matrix.
|
| 88 |
+
randgen: if given, a np.random.RandomState instance used for random numbers (for reproducibility)
|
| 89 |
+
"""
|
| 90 |
+
# adapted from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c
|
| 91 |
+
|
| 92 |
+
if randgen is None:
|
| 93 |
+
randgen = np.random.RandomState()
|
| 94 |
+
|
| 95 |
+
theta, phi, z = tuple(randgen.rand(3).tolist())
|
| 96 |
+
|
| 97 |
+
theta = theta * 2.0 * np.pi # Rotation about the pole (Z).
|
| 98 |
+
phi = phi * 2.0 * np.pi # For direction of pole deflection.
|
| 99 |
+
z = z * 2.0 # For magnitude of pole deflection.
|
| 100 |
+
|
| 101 |
+
# Compute a vector V used for distributing points over the sphere
|
| 102 |
+
# via the reflection I - V Transpose(V). This formulation of V
|
| 103 |
+
# will guarantee that if x[1] and x[2] are uniformly distributed,
|
| 104 |
+
# the reflected points will be uniform on the sphere. Note that V
|
| 105 |
+
# has length sqrt(2) to eliminate the 2 in the Householder matrix.
|
| 106 |
+
|
| 107 |
+
r = np.sqrt(z)
|
| 108 |
+
Vx, Vy, Vz = V = (np.sin(phi) * r, np.cos(phi) * r, np.sqrt(2.0 - z))
|
| 109 |
+
|
| 110 |
+
st = np.sin(theta)
|
| 111 |
+
ct = np.cos(theta)
|
| 112 |
+
|
| 113 |
+
R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1)))
|
| 114 |
+
# Construct the rotation matrix ( V Transpose(V) - I ) R.
|
| 115 |
+
|
| 116 |
+
M = (np.outer(V, V) - np.eye(3)).dot(R)
|
| 117 |
+
return M
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Python string/file utilities
|
| 121 |
+
def ensure_dir_exists(d):
|
| 122 |
+
if not os.path.exists(d):
|
| 123 |
+
os.makedirs(d)
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from .utils_legacy import *
|
| 2 |
+
# from .geometry import *
|
| 3 |
+
# from .layers import *
|
utils/descriptors.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
# https://github.com/RobinMagnet/SimplifiedFmapsLearning/blob/main/learn_zo/data/utils.py
|
| 4 |
+
# https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/HKS_functions.py
|
| 5 |
+
def HKS(evals, evects, time_list,scaled=False):
|
| 6 |
+
"""
|
| 7 |
+
Returns the Heat Kernel Signature for num_T different values.
|
| 8 |
+
The values of the time are interpolated in logscale between the limits
|
| 9 |
+
given in the HKS paper. These limits only depends on the eigenvalues.
|
| 10 |
+
|
| 11 |
+
Parameters
|
| 12 |
+
------------------------
|
| 13 |
+
evals : (K,) array of the K eigenvalues
|
| 14 |
+
evecs : (N,K) array with the K eigenvectors
|
| 15 |
+
time_list : (num_T,) Time values to use
|
| 16 |
+
scaled : (bool) whether to scale for each time value
|
| 17 |
+
|
| 18 |
+
Output
|
| 19 |
+
------------------------
|
| 20 |
+
HKS : (N,num_T) array where each line is the HKS for a given t
|
| 21 |
+
"""
|
| 22 |
+
evals_s = np.asarray(evals).flatten()
|
| 23 |
+
t_list = np.asarray(time_list).flatten()
|
| 24 |
+
|
| 25 |
+
coefs = np.exp(-np.outer(t_list, evals_s)) # (num_T,K)
|
| 26 |
+
# weighted_evects = evects[None, :, :] * coefs[:, None,:] # (num_T,N,K)
|
| 27 |
+
# natural_HKS = np.einsum('tnk,nk->nt', weighted_evects, evects)
|
| 28 |
+
natural_HKS = np.einsum('tk,nk,nk->nt', coefs, evects, evects)
|
| 29 |
+
|
| 30 |
+
if scaled:
|
| 31 |
+
inv_scaling = coefs.sum(1) # (num_T)
|
| 32 |
+
return (1/inv_scaling)[None,:] * natural_HKS
|
| 33 |
+
|
| 34 |
+
return natural_HKS
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def lm_HKS(evals, evects, landmarks, time_list, scaled=False):
|
| 38 |
+
"""
|
| 39 |
+
Returns the Heat Kernel Signature for some landmarks and time values.
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
Parameters
|
| 43 |
+
------------------------
|
| 44 |
+
evects : (N,K) array with the K eigenvectors of the Laplace Beltrami operator
|
| 45 |
+
evals : (K,) array of the K corresponding eigenvalues
|
| 46 |
+
landmarks : (p,) indices of landmarks to compute
|
| 47 |
+
time_list : (num_T,) values of t to use
|
| 48 |
+
|
| 49 |
+
Output
|
| 50 |
+
------------------------
|
| 51 |
+
landmarks_HKS : (N,num_E*p) array where each column is the HKS for a given t for some landmark
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
evals_s = np.asarray(evals).flatten()
|
| 55 |
+
t_list = np.asarray(time_list).flatten()
|
| 56 |
+
|
| 57 |
+
coefs = np.exp(-np.outer(t_list, evals_s)) # (num_T,K)
|
| 58 |
+
weighted_evects = evects[None, landmarks, :] * coefs[:,None,:] # (num_T,p,K)
|
| 59 |
+
|
| 60 |
+
landmarks_HKS = np.einsum('tpk,nk->ptn', weighted_evects, evects) # (p,num_T,N)
|
| 61 |
+
landmarks_HKS = np.einsum('tk,pk,nk->ptn', coefs, evects[landmarks, :], evects) # (p,num_T,N)
|
| 62 |
+
|
| 63 |
+
if scaled:
|
| 64 |
+
inv_scaling = coefs.sum(1) # (num_T,)
|
| 65 |
+
landmarks_HKS = (1/inv_scaling)[None,:,None] * landmarks_HKS # (p,num_T,N)
|
| 66 |
+
|
| 67 |
+
return rearrange(landmarks_HKS, 'p T N -> N (p T)')
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def auto_HKS(evals, evects, num_T, landmarks=None, scaled=True):
|
| 71 |
+
"""
|
| 72 |
+
Compute HKS with an automatic choice of tile values
|
| 73 |
+
|
| 74 |
+
Parameters
|
| 75 |
+
------------------------
|
| 76 |
+
evals : (K,) array of K eigenvalues
|
| 77 |
+
evects : (N,K) array with K eigenvectors
|
| 78 |
+
landmarks : (p,) if not None, indices of landmarks to compute.
|
| 79 |
+
num_T : (int) number values of t to use
|
| 80 |
+
Output
|
| 81 |
+
------------------------
|
| 82 |
+
HKS or lm_HKS : (N,num_E) or (N,p*num_E) array where each column is the WKS for a given e
|
| 83 |
+
for some landmark
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
abs_ev = sorted(np.abs(evals))
|
| 87 |
+
t_list = np.geomspace(4*np.log(10)/abs_ev[-1], 4*np.log(10)/abs_ev[1], num_T)
|
| 88 |
+
|
| 89 |
+
if landmarks is None:
|
| 90 |
+
return HKS(abs_ev, evects, t_list, scaled=scaled)
|
| 91 |
+
else:
|
| 92 |
+
return lm_HKS(abs_ev, evects, landmarks, t_list, scaled=scaled)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/WKS_functions.py
|
| 96 |
+
def WKS(evals, evects, energy_list, sigma, scaled=False):
|
| 97 |
+
"""
|
| 98 |
+
Returns the Wave Kernel Signature for some energy values.
|
| 99 |
+
|
| 100 |
+
Parameters
|
| 101 |
+
------------------------
|
| 102 |
+
evects : (N,K) array with the K eigenvectors of the Laplace Beltrami operator
|
| 103 |
+
evals : (K,) array of the K corresponding eigenvalues
|
| 104 |
+
energy_list : (num_E,) values of e to use
|
| 105 |
+
sigma : (float) [positive] standard deviation to use
|
| 106 |
+
scaled : (bool) Whether to scale each energy level
|
| 107 |
+
|
| 108 |
+
Output
|
| 109 |
+
------------------------
|
| 110 |
+
WKS : (N,num_E) array where each column is the WKS for a given e
|
| 111 |
+
"""
|
| 112 |
+
assert sigma > 0, f"Sigma should be positive ! Given value : {sigma}"
|
| 113 |
+
|
| 114 |
+
evals = np.asarray(evals).flatten()
|
| 115 |
+
indices = np.where(evals > 1e-5)[0].flatten()
|
| 116 |
+
evals = evals[indices]
|
| 117 |
+
evects = evects[:, indices]
|
| 118 |
+
|
| 119 |
+
e_list = np.asarray(energy_list)
|
| 120 |
+
coefs = np.exp(-np.square(e_list[:,None] - np.log(np.abs(evals))[None,:])/(2*sigma**2)) # (num_E,K)
|
| 121 |
+
|
| 122 |
+
# weighted_evects = evects[None, :, :] * coefs[:,None, :] # (num_E,N,K)
|
| 123 |
+
|
| 124 |
+
# natural_WKS = np.einsum('tnk,nk->nt', weighted_evects, evects) # (N,num_E)
|
| 125 |
+
natural_WKS = np.einsum('tk,nk,nk->nt', coefs, evects, evects)
|
| 126 |
+
|
| 127 |
+
if scaled:
|
| 128 |
+
inv_scaling = coefs.sum(1) # (num_E)
|
| 129 |
+
return (1/inv_scaling)[None,:] * natural_WKS
|
| 130 |
+
|
| 131 |
+
return natural_WKS
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def lm_WKS(evals, evects, landmarks, energy_list, sigma, scaled=False):
|
| 135 |
+
"""
|
| 136 |
+
Returns the Wave Kernel Signature for some landmarks and energy values.
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
Parameters
|
| 140 |
+
------------------------
|
| 141 |
+
evects : (N,K) array with the K eigenvectors of the Laplace Beltrami operator
|
| 142 |
+
evals : (K,) array of the K corresponding eigenvalues
|
| 143 |
+
landmarks : (p,) indices of landmarks to compute
|
| 144 |
+
energy_list : (num_E,) values of e to use
|
| 145 |
+
sigma : int - standard deviation
|
| 146 |
+
|
| 147 |
+
Output
|
| 148 |
+
------------------------
|
| 149 |
+
landmarks_WKS : (N,num_E*p) array where each column is the WKS for a given e for some landmark
|
| 150 |
+
"""
|
| 151 |
+
assert sigma > 0, f"Sigma should be positive ! Given value : {sigma}"
|
| 152 |
+
|
| 153 |
+
evals = np.asarray(evals).flatten()
|
| 154 |
+
indices = np.where(evals > 1e-2)[0].flatten()
|
| 155 |
+
evals = evals[indices]
|
| 156 |
+
evects = evects[:,indices]
|
| 157 |
+
|
| 158 |
+
e_list = np.asarray(energy_list)
|
| 159 |
+
coefs = np.exp(-np.square(e_list[:, None] - np.log(np.abs(evals))[None, :]) / (2*sigma**2)) # (num_E,K)
|
| 160 |
+
# weighted_evects = evects[None, landmarks, :] * coefs[:,None,:] # (num_E,p,K)
|
| 161 |
+
|
| 162 |
+
# landmarks_WKS = np.einsum('tpk,nk->ptn', weighted_evects, evects) # (p,num_E,N)
|
| 163 |
+
landmarks_WKS = np.einsum('tk,pk,nk->ptn', coefs, evects[landmarks, :], evects) # (p,num_E,N)
|
| 164 |
+
|
| 165 |
+
if scaled:
|
| 166 |
+
inv_scaling = coefs.sum(1) # (num_E,)
|
| 167 |
+
landmarks_WKS = (1/inv_scaling)[None,:,None] * landmarks_WKS
|
| 168 |
+
|
| 169 |
+
# return landmarks_WKS.reshape(-1,evects.shape[0]).T # (N,p*num_E)
|
| 170 |
+
return rearrange(landmarks_WKS, 'p T N -> N (p T)')
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def auto_WKS(evals, evects, num_E, landmarks=None, scaled=True):
|
| 174 |
+
"""
|
| 175 |
+
Compute WKS with an automatic choice of scale and energy
|
| 176 |
+
|
| 177 |
+
Parameters
|
| 178 |
+
------------------------
|
| 179 |
+
evals : (K,) array of K eigenvalues
|
| 180 |
+
evects : (N,K) array with K eigenvectors
|
| 181 |
+
landmarks : (p,) If not None, indices of landmarks to compute.
|
| 182 |
+
num_E : (int) number values of e to use
|
| 183 |
+
Output
|
| 184 |
+
------------------------
|
| 185 |
+
WKS or lm_WKS : (N,num_E) or (N,p*num_E) array where each column is the WKS for a given e
|
| 186 |
+
and possibly for some landmarks
|
| 187 |
+
"""
|
| 188 |
+
abs_ev = sorted(np.abs(evals))
|
| 189 |
+
|
| 190 |
+
e_min,e_max = np.log(abs_ev[1]),np.log(abs_ev[-1])
|
| 191 |
+
sigma = 7*(e_max-e_min)/num_E
|
| 192 |
+
|
| 193 |
+
e_min += 2*sigma
|
| 194 |
+
e_max -= 2*sigma
|
| 195 |
+
|
| 196 |
+
energy_list = np.linspace(e_min,e_max,num_E)
|
| 197 |
+
|
| 198 |
+
if landmarks is None:
|
| 199 |
+
return WKS(abs_ev, evects, energy_list, sigma, scaled=scaled)
|
| 200 |
+
else:
|
| 201 |
+
return lm_WKS(abs_ev, evects, landmarks, energy_list, sigma, scaled=scaled)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def compute_hks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35, normalize=True):
|
| 205 |
+
"""
|
| 206 |
+
Compute Heat Kernel Signature (HKS) descriptors.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
evecs: (N, K) eigenvectors of the Laplace-Beltrami operator
|
| 210 |
+
evals: (K,) eigenvalues of the Laplace-Beltrami operator
|
| 211 |
+
mass: (N,) vertex masses
|
| 212 |
+
n_descr: (int) number of descriptors
|
| 213 |
+
subsample_step: (int) subsampling step
|
| 214 |
+
n_eig: (int) number of eigenvectors to use
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
feats: (N, n_descr) HKS descriptors
|
| 218 |
+
"""
|
| 219 |
+
feats = auto_HKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
|
| 220 |
+
feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
|
| 221 |
+
if normalize:
|
| 222 |
+
feats_norm2 = np.einsum('np,n->p', feats**2, mass).flatten()
|
| 223 |
+
feats /= np.sqrt(feats_norm2)[None, :]
|
| 224 |
+
return feats.astype(np.float32)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def compute_wks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35, normalize=True):
|
| 228 |
+
"""
|
| 229 |
+
Compute Wave Kernel Signature (WKS) descriptors.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
evecs: (N, K) eigenvectors of the Laplace-Beltrami operator
|
| 233 |
+
evals: (K,) eigenvalues of the Laplace-Beltrami operator
|
| 234 |
+
mass: (N,) vertex masses
|
| 235 |
+
n_descr: (int) number of descriptors
|
| 236 |
+
subsample_step: (int) subsampling step
|
| 237 |
+
n_eig: (int) number of eigenvectors to use
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
feats: (N, n_descr) WKS descriptors
|
| 241 |
+
"""
|
| 242 |
+
feats = auto_WKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
|
| 243 |
+
feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
|
| 244 |
+
# print("wks_shape",feats.shape, mass.shape)
|
| 245 |
+
if normalize:
|
| 246 |
+
feats_norm2 = np.einsum('np,n->p', feats**2, mass).flatten()
|
| 247 |
+
feats /= np.sqrt(feats_norm2)[None, :]
|
| 248 |
+
# feats_norm2 = np.einsum('np,n->p', feats**2, mass).flatten()
|
| 249 |
+
# feats /= np.sqrt(feats_norm2)[None, :]
|
| 250 |
+
return feats.astype(np.float32)
|
utils/eval.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def accuracy(p2p, gt_p2p, D1_geod, return_all=False, sqrt_area=None):
|
| 2 |
+
"""
|
| 3 |
+
Computes the geodesic accuracy of a vertex to vertex map. The map goes from
|
| 4 |
+
the target shape to the source shape.
|
| 5 |
+
Borrowed from Robin.
|
| 6 |
+
Parameters
|
| 7 |
+
----------------------
|
| 8 |
+
p2p : (n2,) - vertex to vertex map giving the index of the matched vertex on the source shape
|
| 9 |
+
for each vertex on the target shape (from a functional map point of view)
|
| 10 |
+
gt_p2p : (n2,) - ground truth mapping between the pairs
|
| 11 |
+
D1_geod : (n1,n1) - geodesic distance between pairs of vertices on the source mesh
|
| 12 |
+
return_all : bool - whether to return all the distances or only the average geodesic distance
|
| 13 |
+
|
| 14 |
+
Output
|
| 15 |
+
-----------------------
|
| 16 |
+
acc : float - average accuracy of the vertex to vertex map
|
| 17 |
+
dists : (n2,) - if return_all is True, returns all the pairwise distances
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
dists = D1_geod[(p2p, gt_p2p)]
|
| 21 |
+
if sqrt_area is not None:
|
| 22 |
+
dists /= sqrt_area
|
| 23 |
+
if return_all:
|
| 24 |
+
return dists.mean(), dists
|
| 25 |
+
return dists.mean()
|
utils/fmap.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import sys
|
| 4 |
+
import numpy as np
|
| 5 |
+
import scipy.linalg
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
|
| 9 |
+
if ROOT_DIR not in sys.path:
|
| 10 |
+
sys.path.append(ROOT_DIR)
|
| 11 |
+
|
| 12 |
+
from utils_fmaps.misc import KNNSearch
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import pynndescent
|
| 16 |
+
index = pynndescent.NNDescent(np.random.random((100, 3)), n_jobs=2)
|
| 17 |
+
del index
|
| 18 |
+
ANN = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
ANN = False
|
| 21 |
+
|
| 22 |
+
# https://github.com/RobinMagnet/pyFM
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def FM_to_p2p(FM, eigvects1, eigvects2, use_ANN=False):
|
| 26 |
+
if use_ANN and not ANN:
|
| 27 |
+
raise ValueError('Please install pydescent to achieve Approximate Nearest Neighbor')
|
| 28 |
+
|
| 29 |
+
k2, k1 = FM.shape
|
| 30 |
+
|
| 31 |
+
assert k1 <= eigvects1.shape[1], \
|
| 32 |
+
f'At least {k1} should be provided, here only {eigvects1.shape[1]} are given'
|
| 33 |
+
assert k2 <= eigvects2.shape[1], \
|
| 34 |
+
f'At least {k2} should be provided, here only {eigvects2.shape[1]} are given'
|
| 35 |
+
|
| 36 |
+
if use_ANN:
|
| 37 |
+
index = pynndescent.NNDescent(eigvects1[:, :k1] @ FM.T, n_jobs=8)
|
| 38 |
+
matches, _ = index.query(eigvects2[:, :k2], k=1)
|
| 39 |
+
matches = matches.flatten()
|
| 40 |
+
else:
|
| 41 |
+
tree = KNNSearch(eigvects1[:, :k1] @ FM.T)
|
| 42 |
+
matches = tree.query(eigvects2[:, :k2], k=1).flatten()
|
| 43 |
+
|
| 44 |
+
return matches
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def p2p_to_FM(p2p, eigvects1, eigvects2, A2=None):
|
| 48 |
+
if A2 is not None:
|
| 49 |
+
if A2.shape[0] != eigvects2.shape[0]:
|
| 50 |
+
raise ValueError("Can't compute pseudo inverse with subsampled eigenvectors")
|
| 51 |
+
|
| 52 |
+
if len(A2.shape) == 1:
|
| 53 |
+
return eigvects2.T @ (A2[:, None] * eigvects1[p2p, :])
|
| 54 |
+
|
| 55 |
+
return eigvects2.T @ A2 @ eigvects1[p2p, :]
|
| 56 |
+
|
| 57 |
+
return scipy.linalg.lstsq(eigvects2, eigvects1[p2p, :])[0]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def zoomout_iteration(eigvects1, eigvects2, FM, step=1, A2=None, use_ANN=False):
|
| 61 |
+
k2, k1 = FM.shape
|
| 62 |
+
try:
|
| 63 |
+
step1, step2 = step
|
| 64 |
+
except TypeError:
|
| 65 |
+
step1 = step
|
| 66 |
+
step2 = step
|
| 67 |
+
new_k1, new_k2 = k1 + step1, k2 + step2
|
| 68 |
+
|
| 69 |
+
p2p = FM_to_p2p(FM, eigvects1, eigvects2, use_ANN=use_ANN)
|
| 70 |
+
FM_zo = p2p_to_FM(p2p, eigvects1[:, :new_k1], eigvects2[:, :new_k2], A2=A2)
|
| 71 |
+
|
| 72 |
+
return FM_zo
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def zoomout_refine(eigvects1,
|
| 76 |
+
eigvects2,
|
| 77 |
+
FM,
|
| 78 |
+
nit=10,
|
| 79 |
+
step=1,
|
| 80 |
+
A2=None,
|
| 81 |
+
subsample=None,
|
| 82 |
+
use_ANN=False,
|
| 83 |
+
return_p2p=False,
|
| 84 |
+
verbose=False):
|
| 85 |
+
k2_0, k1_0 = FM.shape
|
| 86 |
+
try:
|
| 87 |
+
step1, step2 = step
|
| 88 |
+
except TypeError:
|
| 89 |
+
step1 = step
|
| 90 |
+
step2 = step
|
| 91 |
+
|
| 92 |
+
assert k1_0 + nit*step1 <= eigvects1.shape[1], \
|
| 93 |
+
f"Not enough eigenvectors on source : \
|
| 94 |
+
{k1_0 + nit*step1} are needed when {eigvects1.shape[1]} are provided"
|
| 95 |
+
assert k2_0 + nit*step2 <= eigvects2.shape[1], \
|
| 96 |
+
f"Not enough eigenvectors on target : \
|
| 97 |
+
{k2_0 + nit*step2} are needed when {eigvects2.shape[1]} are provided"
|
| 98 |
+
|
| 99 |
+
use_subsample = False
|
| 100 |
+
if subsample is not None:
|
| 101 |
+
use_subsample = True
|
| 102 |
+
sub1, sub2 = subsample
|
| 103 |
+
|
| 104 |
+
FM_zo = FM.copy()
|
| 105 |
+
|
| 106 |
+
ANN_adventage = False
|
| 107 |
+
iterable = range(nit) if not verbose else tqdm(range(nit))
|
| 108 |
+
for it in iterable:
|
| 109 |
+
ANN_adventage = use_ANN and (FM_zo.shape[0] > 90) and (FM_zo.shape[1] > 90)
|
| 110 |
+
|
| 111 |
+
if use_subsample:
|
| 112 |
+
FM_zo = zoomout_iteration(eigvects1[sub1], eigvects2[sub2], FM_zo, A2=None, step=step, use_ANN=ANN_adventage)
|
| 113 |
+
|
| 114 |
+
else:
|
| 115 |
+
FM_zo = zoomout_iteration(eigvects1, eigvects2, FM_zo, A2=A2, step=step, use_ANN=ANN_adventage)
|
| 116 |
+
|
| 117 |
+
if return_p2p:
|
| 118 |
+
p2p_zo = FM_to_p2p(FM_zo, eigvects1, eigvects2, use_ANN=False)
|
| 119 |
+
return FM_zo, p2p_zo
|
| 120 |
+
|
| 121 |
+
return FM_zo
|
utils/geometry.py
ADDED
|
@@ -0,0 +1,951 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import scipy
|
| 2 |
+
import scipy.sparse.linalg as sla
|
| 3 |
+
# ^^^ we NEED to import scipy before torch, or it crashes :(
|
| 4 |
+
# (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
|
| 5 |
+
|
| 6 |
+
import os.path
|
| 7 |
+
import sys
|
| 8 |
+
import random
|
| 9 |
+
from multiprocessing import Pool
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import scipy.spatial
|
| 13 |
+
import torch
|
| 14 |
+
import sklearn.neighbors
|
| 15 |
+
|
| 16 |
+
import robust_laplacian
|
| 17 |
+
import potpourri3d as pp3d
|
| 18 |
+
|
| 19 |
+
from utils.utils_legacy import toNP, ensure_dir_exists, sparse_np_to_torch, sparse_torch_to_np
|
| 20 |
+
from utils.descriptors import compute_hks, compute_wks
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def norm(x, highdim=False):
|
| 24 |
+
"""
|
| 25 |
+
Computes norm of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
|
| 26 |
+
"""
|
| 27 |
+
return torch.norm(x, dim=len(x.shape) - 1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def norm2(x, highdim=False):
|
| 31 |
+
"""
|
| 32 |
+
Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
|
| 33 |
+
"""
|
| 34 |
+
return dot(x, x)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def normalize(x, divide_eps=1e-6, highdim=False):
|
| 38 |
+
"""
|
| 39 |
+
Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
|
| 40 |
+
"""
|
| 41 |
+
if (len(x.shape) == 1):
|
| 42 |
+
raise ValueError("called normalize() on single vector of dim " + str(x.shape) + " are you sure?")
|
| 43 |
+
if (not highdim and x.shape[-1] > 4):
|
| 44 |
+
raise ValueError("called normalize() with large last dimension " + str(x.shape) + " are you sure?")
|
| 45 |
+
return x / (norm(x, highdim=highdim) + divide_eps).unsqueeze(-1)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def face_coords(verts, faces):
|
| 49 |
+
coords = verts[faces]
|
| 50 |
+
return coords
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def cross(vec_A, vec_B):
|
| 54 |
+
return torch.cross(vec_A, vec_B, dim=-1)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def dot(vec_A, vec_B):
|
| 58 |
+
return torch.sum(vec_A * vec_B, dim=-1)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Given (..., 3) vectors and normals, projects out any components of vecs
|
| 62 |
+
# which lies in the direction of normals. Normals are assumed to be unit.
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def project_to_tangent(vecs, unit_normals):
|
| 66 |
+
dots = dot(vecs, unit_normals)
|
| 67 |
+
return vecs - unit_normals * dots.unsqueeze(-1)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def face_area(verts, faces):
|
| 71 |
+
coords = face_coords(verts, faces)
|
| 72 |
+
vec_A = coords[:, 1, :] - coords[:, 0, :]
|
| 73 |
+
vec_B = coords[:, 2, :] - coords[:, 0, :]
|
| 74 |
+
|
| 75 |
+
raw_normal = cross(vec_A, vec_B)
|
| 76 |
+
return 0.5 * norm(raw_normal)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def face_normals(verts, faces, normalized=True):
|
| 80 |
+
coords = face_coords(verts, faces)
|
| 81 |
+
vec_A = coords[:, 1, :] - coords[:, 0, :]
|
| 82 |
+
vec_B = coords[:, 2, :] - coords[:, 0, :]
|
| 83 |
+
|
| 84 |
+
raw_normal = cross(vec_A, vec_B)
|
| 85 |
+
|
| 86 |
+
if normalized:
|
| 87 |
+
return normalize(raw_normal)
|
| 88 |
+
|
| 89 |
+
return raw_normal
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def neighborhood_normal(points):
|
| 93 |
+
# points: (N, K, 3) array of neighborhood psoitions
|
| 94 |
+
# points should be centered at origin
|
| 95 |
+
# out: (N,3) array of normals
|
| 96 |
+
# numpy in, numpy out
|
| 97 |
+
(u, s, vh) = np.linalg.svd(points, full_matrices=False)
|
| 98 |
+
normal = vh[:, 2, :]
|
| 99 |
+
return normal / np.linalg.norm(normal, axis=-1, keepdims=True)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def mesh_vertex_normals(verts, faces):
|
| 103 |
+
# numpy in / out
|
| 104 |
+
face_n = toNP(face_normals(torch.tensor(verts), torch.tensor(faces))) # ugly torch <---> numpy
|
| 105 |
+
eps = 1e-3
|
| 106 |
+
vertex_normals = np.zeros(verts.shape)
|
| 107 |
+
for i in range(3):
|
| 108 |
+
np.add.at(vertex_normals, faces[:, i], face_n)
|
| 109 |
+
|
| 110 |
+
vertex_normals = vertex_normals / (eps + np.linalg.norm(vertex_normals, axis=-1, keepdims=True))
|
| 111 |
+
|
| 112 |
+
return vertex_normals
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def vertex_normals(verts, faces, n_neighbors_cloud=30):
|
| 116 |
+
verts_np = toNP(verts)
|
| 117 |
+
|
| 118 |
+
if faces.numel() == 0: # point cloud
|
| 119 |
+
|
| 120 |
+
_, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
|
| 121 |
+
neigh_points = verts_np[neigh_inds, :]
|
| 122 |
+
neigh_points = neigh_points - verts_np[:, np.newaxis, :]
|
| 123 |
+
normals = neighborhood_normal(neigh_points)
|
| 124 |
+
|
| 125 |
+
else: # mesh
|
| 126 |
+
|
| 127 |
+
normals = mesh_vertex_normals(verts_np, toNP(faces))
|
| 128 |
+
|
| 129 |
+
# if any are NaN, wiggle slightly and recompute
|
| 130 |
+
bad_normals_mask = np.isnan(normals).any(axis=1, keepdims=True)
|
| 131 |
+
if bad_normals_mask.any():
|
| 132 |
+
bbox = np.amax(verts_np, axis=0) - np.amin(verts_np, axis=0)
|
| 133 |
+
scale = np.linalg.norm(bbox) * 1e-4
|
| 134 |
+
wiggle = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5) * scale
|
| 135 |
+
wiggle_verts = verts_np + bad_normals_mask * wiggle
|
| 136 |
+
normals = mesh_vertex_normals(wiggle_verts, toNP(faces))
|
| 137 |
+
|
| 138 |
+
# if still NaN assign random normals (probably means unreferenced verts in mesh)
|
| 139 |
+
bad_normals_mask = np.isnan(normals).any(axis=1)
|
| 140 |
+
if bad_normals_mask.any():
|
| 141 |
+
normals[bad_normals_mask, :] = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5)[bad_normals_mask, :]
|
| 142 |
+
normals = normals / np.linalg.norm(normals, axis=-1)[:, np.newaxis]
|
| 143 |
+
|
| 144 |
+
normals = torch.from_numpy(normals).to(device=verts.device, dtype=verts.dtype)
|
| 145 |
+
|
| 146 |
+
if torch.any(torch.isnan(normals)):
|
| 147 |
+
raise ValueError("NaN normals :(")
|
| 148 |
+
|
| 149 |
+
return normals
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def build_tangent_frames(verts, faces, normals=None):
|
| 153 |
+
|
| 154 |
+
V = verts.shape[0]
|
| 155 |
+
dtype = verts.dtype
|
| 156 |
+
device = verts.device
|
| 157 |
+
|
| 158 |
+
if normals is None:
|
| 159 |
+
vert_normals = vertex_normals(verts, faces) # (V,3)
|
| 160 |
+
elif isinstance(normals, np.ndarray):
|
| 161 |
+
vert_normals = torch.from_numpy(normals).to(dtype=dtype, device=device)
|
| 162 |
+
else:
|
| 163 |
+
vert_normals = normals
|
| 164 |
+
|
| 165 |
+
# = find an orthogonal basis
|
| 166 |
+
|
| 167 |
+
basis_cand1 = torch.tensor([1, 0, 0]).to(device=device, dtype=dtype).expand(V, -1)
|
| 168 |
+
basis_cand2 = torch.tensor([0, 1, 0]).to(device=device, dtype=dtype).expand(V, -1)
|
| 169 |
+
|
| 170 |
+
basisX = torch.where((torch.abs(dot(vert_normals, basis_cand1)) < 0.9).unsqueeze(-1), basis_cand1, basis_cand2)
|
| 171 |
+
basisX = project_to_tangent(basisX, vert_normals)
|
| 172 |
+
basisX = normalize(basisX)
|
| 173 |
+
basisY = cross(vert_normals, basisX)
|
| 174 |
+
frames = torch.stack((basisX, basisY, vert_normals), dim=-2)
|
| 175 |
+
|
| 176 |
+
if torch.any(torch.isnan(frames)):
|
| 177 |
+
raise ValueError("NaN coordinate frame! Must be very degenerate")
|
| 178 |
+
|
| 179 |
+
return frames
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def build_grad_point_cloud(verts, frames, n_neighbors_cloud=30):
|
| 183 |
+
|
| 184 |
+
verts_np = toNP(verts)
|
| 185 |
+
frames_np = toNP(frames)
|
| 186 |
+
|
| 187 |
+
_, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
|
| 188 |
+
neigh_points = verts_np[neigh_inds, :]
|
| 189 |
+
neigh_vecs = neigh_points - verts_np[:, np.newaxis, :]
|
| 190 |
+
|
| 191 |
+
# TODO this could easily be way faster. For instance we could avoid the weird edges format and the corresponding pure-python loop via some numpy broadcasting of the same logic. The way it works right now is just to share code with the mesh version. But its low priority since its preprocessing code.
|
| 192 |
+
|
| 193 |
+
edge_inds_from = np.repeat(np.arange(verts.shape[0]), n_neighbors_cloud)
|
| 194 |
+
edges = np.stack((edge_inds_from, neigh_inds.flatten()))
|
| 195 |
+
edge_tangent_vecs = edge_tangent_vectors(verts, frames, edges)
|
| 196 |
+
|
| 197 |
+
return build_grad(verts_np, torch.tensor(edges), edge_tangent_vecs)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def edge_tangent_vectors(verts, frames, edges):
|
| 201 |
+
edge_vecs = verts[edges[1, :], :] - verts[edges[0, :], :]
|
| 202 |
+
basisX = frames[edges[0, :], 0, :]
|
| 203 |
+
basisY = frames[edges[0, :], 1, :]
|
| 204 |
+
|
| 205 |
+
compX = dot(edge_vecs, basisX)
|
| 206 |
+
compY = dot(edge_vecs, basisY)
|
| 207 |
+
edge_tangent = torch.stack((compX, compY), dim=-1)
|
| 208 |
+
|
| 209 |
+
return edge_tangent
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def build_grad(verts, edges, edge_tangent_vectors):
|
| 213 |
+
"""
|
| 214 |
+
Build a (V, V) complex sparse matrix grad operator. Given real inputs at vertices, produces a complex (vector value) at vertices giving the gradient. All values pointwise.
|
| 215 |
+
- edges: (2, E)
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
edges_np = toNP(edges)
|
| 219 |
+
edge_tangent_vectors_np = toNP(edge_tangent_vectors)
|
| 220 |
+
|
| 221 |
+
# TODO find a way to do this in pure numpy?
|
| 222 |
+
|
| 223 |
+
# Build outgoing neighbor lists
|
| 224 |
+
N = verts.shape[0]
|
| 225 |
+
vert_edge_outgoing = [[] for i in range(N)]
|
| 226 |
+
for iE in range(edges_np.shape[1]):
|
| 227 |
+
tail_ind = edges_np[0, iE]
|
| 228 |
+
tip_ind = edges_np[1, iE]
|
| 229 |
+
if tip_ind != tail_ind:
|
| 230 |
+
vert_edge_outgoing[tail_ind].append(iE)
|
| 231 |
+
|
| 232 |
+
# Build local inversion matrix for each vertex
|
| 233 |
+
row_inds = []
|
| 234 |
+
col_inds = []
|
| 235 |
+
data_vals = []
|
| 236 |
+
eps_reg = 1e-5
|
| 237 |
+
for iV in range(N):
|
| 238 |
+
n_neigh = len(vert_edge_outgoing[iV])
|
| 239 |
+
|
| 240 |
+
lhs_mat = np.zeros((n_neigh, 2))
|
| 241 |
+
rhs_mat = np.zeros((n_neigh, n_neigh + 1))
|
| 242 |
+
ind_lookup = [iV]
|
| 243 |
+
for i_neigh in range(n_neigh):
|
| 244 |
+
iE = vert_edge_outgoing[iV][i_neigh]
|
| 245 |
+
jV = edges_np[1, iE]
|
| 246 |
+
ind_lookup.append(jV)
|
| 247 |
+
|
| 248 |
+
edge_vec = edge_tangent_vectors[iE][:]
|
| 249 |
+
w_e = 1.
|
| 250 |
+
|
| 251 |
+
lhs_mat[i_neigh][:] = w_e * edge_vec
|
| 252 |
+
rhs_mat[i_neigh][0] = w_e * (-1)
|
| 253 |
+
rhs_mat[i_neigh][i_neigh + 1] = w_e * 1
|
| 254 |
+
|
| 255 |
+
lhs_T = lhs_mat.T
|
| 256 |
+
lhs_inv = np.linalg.inv(lhs_T @ lhs_mat + eps_reg * np.identity(2)) @ lhs_T
|
| 257 |
+
|
| 258 |
+
sol_mat = lhs_inv @ rhs_mat
|
| 259 |
+
sol_coefs = (sol_mat[0, :] + 1j * sol_mat[1, :]).T
|
| 260 |
+
|
| 261 |
+
for i_neigh in range(n_neigh + 1):
|
| 262 |
+
i_glob = ind_lookup[i_neigh]
|
| 263 |
+
|
| 264 |
+
row_inds.append(iV)
|
| 265 |
+
col_inds.append(i_glob)
|
| 266 |
+
data_vals.append(sol_coefs[i_neigh])
|
| 267 |
+
|
| 268 |
+
# build the sparse matrix
|
| 269 |
+
row_inds = np.array(row_inds)
|
| 270 |
+
col_inds = np.array(col_inds)
|
| 271 |
+
data_vals = np.array(data_vals)
|
| 272 |
+
mat = scipy.sparse.coo_matrix((data_vals, (row_inds, col_inds)), shape=(N, N)).tocsc()
|
| 273 |
+
|
| 274 |
+
return mat
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def compute_operators(verts, faces, k_eig, normals=None, normalize_desc=True):
|
| 278 |
+
"""
|
| 279 |
+
Builds spectral operators for a mesh/point cloud. Constructs mass matrix, eigenvalues/vectors for Laplacian, and gradient matrix.
|
| 280 |
+
|
| 281 |
+
See get_operators() for a similar routine that wraps this one with a layer of caching.
|
| 282 |
+
|
| 283 |
+
Torch in / torch out.
|
| 284 |
+
|
| 285 |
+
Arguments:
|
| 286 |
+
- vertices: (V,3) vertex positions
|
| 287 |
+
- faces: (F,3) list of triangular faces. If empty, assumed to be a point cloud.
|
| 288 |
+
- k_eig: number of eigenvectors to use
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
- frames: (V,3,3) X/Y/Z coordinate frame at each vertex. Z coordinate is normal (e.g. [:,2,:] for normals)
|
| 292 |
+
- massvec: (V) real diagonal of lumped mass matrix
|
| 293 |
+
- L: (VxV) real sparse matrix of (weak) Laplacian
|
| 294 |
+
- evals: (k) list of eigenvalues of the Laplacian
|
| 295 |
+
- evecs: (V,k) list of eigenvectors of the Laplacian
|
| 296 |
+
- gradX: (VxV) sparse matrix which gives X-component of gradient in the local basis at the vertex
|
| 297 |
+
- gradY: same as gradX but for Y-component of gradient
|
| 298 |
+
|
| 299 |
+
PyTorch doesn't seem to like complex sparse matrices, so we store the "real" and "imaginary" (aka X and Y) gradient matrices separately, rather than as one complex sparse matrix.
|
| 300 |
+
|
| 301 |
+
Note: for a generalized eigenvalue problem, the mass matrix matters! The eigenvectors are only othrthonormal with respect to the mass matrix, like v^H M v, so the mass (given as the diagonal vector massvec) needs to be used in projections, etc.
|
| 302 |
+
"""
|
| 303 |
+
device = verts.device
|
| 304 |
+
dtype = verts.dtype
|
| 305 |
+
V = verts.shape[0]
|
| 306 |
+
is_cloud = faces.numel() == 0
|
| 307 |
+
|
| 308 |
+
eps = 1e-8
|
| 309 |
+
|
| 310 |
+
verts_np = toNP(verts).astype(np.float64)
|
| 311 |
+
faces_np = toNP(faces)
|
| 312 |
+
frames = build_tangent_frames(verts, faces, normals=normals)
|
| 313 |
+
frames_np = toNP(frames)
|
| 314 |
+
|
| 315 |
+
# Build the scalar Laplacian
|
| 316 |
+
if is_cloud:
|
| 317 |
+
L, M = robust_laplacian.point_cloud_laplacian(verts_np)
|
| 318 |
+
massvec_np = M.diagonal()
|
| 319 |
+
else:
|
| 320 |
+
# L, M = robust_laplacian.mesh_laplacian(verts_np, faces_np)
|
| 321 |
+
# massvec_np = M.diagonal()
|
| 322 |
+
L = pp3d.cotan_laplacian(verts_np, faces_np, denom_eps=1e-10)
|
| 323 |
+
massvec_np = pp3d.vertex_areas(verts_np, faces_np)
|
| 324 |
+
massvec_np += eps * np.mean(massvec_np)
|
| 325 |
+
|
| 326 |
+
if (np.isnan(L.data).any()):
|
| 327 |
+
raise RuntimeError("NaN Laplace matrix")
|
| 328 |
+
if (np.isnan(massvec_np).any()):
|
| 329 |
+
raise RuntimeError("NaN mass matrix")
|
| 330 |
+
|
| 331 |
+
# Read off neighbors & rotations from the Laplacian
|
| 332 |
+
L_coo = L.tocoo()
|
| 333 |
+
inds_row = L_coo.row
|
| 334 |
+
inds_col = L_coo.col
|
| 335 |
+
|
| 336 |
+
# === Compute the eigenbasis
|
| 337 |
+
if k_eig > 0:
|
| 338 |
+
|
| 339 |
+
# Prepare matrices
|
| 340 |
+
L_eigsh = (L + scipy.sparse.identity(L.shape[0]) * eps).tocsc()
|
| 341 |
+
massvec_eigsh = massvec_np
|
| 342 |
+
Mmat = scipy.sparse.diags(massvec_eigsh)
|
| 343 |
+
eigs_sigma = eps
|
| 344 |
+
|
| 345 |
+
failcount = 0
|
| 346 |
+
while True:
|
| 347 |
+
try:
|
| 348 |
+
# We would be happy here to lower tol or maxiter since we don't need these to be super precise, but for some reason those parameters seem to have no effect
|
| 349 |
+
evals_np, evecs_np = sla.eigsh(L_eigsh, k=k_eig, M=Mmat, sigma=eigs_sigma)
|
| 350 |
+
|
| 351 |
+
# Clip off any eigenvalues that end up slightly negative due to numerical weirdness
|
| 352 |
+
evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))
|
| 353 |
+
|
| 354 |
+
break
|
| 355 |
+
except Exception as e:
|
| 356 |
+
print(e)
|
| 357 |
+
if (failcount > 3):
|
| 358 |
+
raise ValueError("failed to compute eigendecomp")
|
| 359 |
+
failcount += 1
|
| 360 |
+
print("--- decomp failed; adding eps ===> count: " + str(failcount))
|
| 361 |
+
L_eigsh = L_eigsh + scipy.sparse.identity(L.shape[0]) * (eps * 10**failcount)
|
| 362 |
+
|
| 363 |
+
else: #k_eig == 0
|
| 364 |
+
evals_np = np.zeros((0))
|
| 365 |
+
evecs_np = np.zeros((verts.shape[0], 0))
|
| 366 |
+
|
| 367 |
+
# == Build gradient matrices
|
| 368 |
+
|
| 369 |
+
# For meshes, we use the same edges as were used to build the Laplacian. For point clouds, use a whole local neighborhood
|
| 370 |
+
if is_cloud:
|
| 371 |
+
grad_mat_np = build_grad_point_cloud(verts, frames)
|
| 372 |
+
else:
|
| 373 |
+
edges = torch.tensor(np.stack((inds_row, inds_col), axis=0), device=device, dtype=faces.dtype)
|
| 374 |
+
edge_vecs = edge_tangent_vectors(verts, frames, edges)
|
| 375 |
+
grad_mat_np = build_grad(verts.cpu(), edges.cpu(), edge_vecs.cpu())
|
| 376 |
+
|
| 377 |
+
# Split complex gradient in to two real sparse mats (torch doesn't like complex sparse matrices)
|
| 378 |
+
gradX_np = np.real(grad_mat_np)
|
| 379 |
+
gradY_np = np.imag(grad_mat_np)
|
| 380 |
+
|
| 381 |
+
# === Convert back to torch
|
| 382 |
+
massvec = torch.from_numpy(massvec_np).to(device=device, dtype=dtype)
|
| 383 |
+
L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
|
| 384 |
+
evals = torch.from_numpy(evals_np).to(device=device, dtype=dtype)
|
| 385 |
+
evecs = torch.from_numpy(evecs_np).to(device=device, dtype=dtype)
|
| 386 |
+
gradX = sparse_np_to_torch(gradX_np).to(device=device, dtype=dtype)
|
| 387 |
+
gradY = sparse_np_to_torch(gradY_np).to(device=device, dtype=dtype)
|
| 388 |
+
|
| 389 |
+
hks = torch.from_numpy(compute_hks(evecs_np, evals_np, massvec_np, n_descr=128, subsample_step=1, n_eig=128, normalize=normalize_desc)).to(device=device, dtype=dtype)
|
| 390 |
+
wks = torch.from_numpy(compute_wks(evecs_np, evals_np, massvec_np, n_descr=128, subsample_step=1, n_eig=128, normalize=normalize_desc)).to(device=device, dtype=dtype)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
return frames, massvec, L, evals, evecs, gradX, gradY, hks, wks
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def get_all_operators(verts_list, faces_list, k_eig, op_cache_dir=None, normals=None):
|
| 397 |
+
N = len(verts_list)
|
| 398 |
+
|
| 399 |
+
frames = [None] * N
|
| 400 |
+
massvec = [None] * N
|
| 401 |
+
L = [None] * N
|
| 402 |
+
evals = [None] * N
|
| 403 |
+
evecs = [None] * N
|
| 404 |
+
gradX = [None] * N
|
| 405 |
+
gradY = [None] * N
|
| 406 |
+
|
| 407 |
+
inds = [i for i in range(N)]
|
| 408 |
+
# process in random order
|
| 409 |
+
# random.shuffle(inds)
|
| 410 |
+
|
| 411 |
+
for num, i in enumerate(inds):
|
| 412 |
+
print("get_all_operators() processing {} / {} {:.3f}%".format(num, N, num / N * 100))
|
| 413 |
+
if normals is None:
|
| 414 |
+
outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir)
|
| 415 |
+
else:
|
| 416 |
+
outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir, normals=normals[i])
|
| 417 |
+
frames[i] = outputs[0]
|
| 418 |
+
massvec[i] = outputs[1]
|
| 419 |
+
L[i] = outputs[2]
|
| 420 |
+
evals[i] = outputs[3]
|
| 421 |
+
evecs[i] = outputs[4]
|
| 422 |
+
gradX[i] = outputs[5]
|
| 423 |
+
gradY[i] = outputs[6]
|
| 424 |
+
|
| 425 |
+
return frames, massvec, L, evals, evecs, gradX, gradY
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def get_operators(verts, faces, k_eig=128, cache_path=None, normals=None, overwrite_cache=False):
|
| 429 |
+
"""
|
| 430 |
+
See documentation for compute_operators(). This essentailly just wraps a call to compute_operators, using a cache if possible.
|
| 431 |
+
All arrays are always computed using double precision for stability, then truncated to single precision floats to store on disk, and finally returned as a tensor with dtype/device matching the `verts` input.
|
| 432 |
+
"""
|
| 433 |
+
if type(verts) == torch.Tensor:
|
| 434 |
+
device = verts.device
|
| 435 |
+
dtype = verts.dtype
|
| 436 |
+
verts_np = toNP(verts)
|
| 437 |
+
else:
|
| 438 |
+
device = "cpu"
|
| 439 |
+
dtype = torch.float32
|
| 440 |
+
verts_np = verts.copy()
|
| 441 |
+
verts = torch.from_numpy(verts).float()
|
| 442 |
+
if type(faces) == torch.Tensor:
|
| 443 |
+
faces_np = toNP(faces)
|
| 444 |
+
else:
|
| 445 |
+
faces_np = faces.copy()
|
| 446 |
+
faces = torch.from_numpy(faces).to(device, dtype=torch.int64)
|
| 447 |
+
is_cloud = faces.numel() == 0
|
| 448 |
+
|
| 449 |
+
if (np.isnan(verts_np).any()):
|
| 450 |
+
raise RuntimeError("tried to construct operators from NaN verts")
|
| 451 |
+
|
| 452 |
+
# Check the cache directory
|
| 453 |
+
# Note 1: Collisions here are exceptionally unlikely, so we could probably just use the hash...
|
| 454 |
+
# but for good measure we check values nonetheless.
|
| 455 |
+
# Note 2: There is a small possibility for race conditions to lead to bucket gaps or duplicate
|
| 456 |
+
# entries in this cache. The good news is that that is totally fine, and at most slightly
|
| 457 |
+
# slows performance with rare extra cache misses.
|
| 458 |
+
found = False
|
| 459 |
+
|
| 460 |
+
if cache_path is not None:
|
| 461 |
+
op_cache_dir = os.path.dirname(cache_path)
|
| 462 |
+
ensure_dir_exists(op_cache_dir)
|
| 463 |
+
# print("Building operators for input with hash: " + hash_key_str)
|
| 464 |
+
|
| 465 |
+
# Search through buckets with matching hashes. When the loop exits, this
|
| 466 |
+
# is the bucket index of the file we should write to.
|
| 467 |
+
i_cache_search = 0
|
| 468 |
+
while True:
|
| 469 |
+
try:
|
| 470 |
+
# print('loading path: ' + str(search_path))
|
| 471 |
+
npzfile = np.load(cache_path, allow_pickle=True)
|
| 472 |
+
cache_verts = npzfile["verts"]
|
| 473 |
+
cache_faces = npzfile["faces"]
|
| 474 |
+
cache_k_eig = npzfile["k_eig"].item()
|
| 475 |
+
|
| 476 |
+
# If the cache doesn't match, keep looking
|
| 477 |
+
if (not np.array_equal(verts, cache_verts)) or (not np.array_equal(faces, cache_faces)):
|
| 478 |
+
i_cache_search += 1
|
| 479 |
+
print("hash collision! searching next.")
|
| 480 |
+
continue
|
| 481 |
+
|
| 482 |
+
# print(" cache hit!")
|
| 483 |
+
|
| 484 |
+
# If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
|
| 485 |
+
# entry below more eigenvalues
|
| 486 |
+
if overwrite_cache:
|
| 487 |
+
print(" overwriting cache by request")
|
| 488 |
+
os.remove(cache_path)
|
| 489 |
+
break
|
| 490 |
+
|
| 491 |
+
if cache_k_eig < k_eig:
|
| 492 |
+
print(" overwriting cache --- not enough eigenvalues")
|
| 493 |
+
os.remove(cache_path)
|
| 494 |
+
break
|
| 495 |
+
|
| 496 |
+
if "L_data" not in npzfile:
|
| 497 |
+
print(" overwriting cache --- entries are absent")
|
| 498 |
+
os.remove(cache_path)
|
| 499 |
+
break
|
| 500 |
+
|
| 501 |
+
def read_sp_mat(prefix):
|
| 502 |
+
data = npzfile[prefix + "_data"]
|
| 503 |
+
indices = npzfile[prefix + "_indices"]
|
| 504 |
+
indptr = npzfile[prefix + "_indptr"]
|
| 505 |
+
shape = npzfile[prefix + "_shape"]
|
| 506 |
+
mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
|
| 507 |
+
return mat
|
| 508 |
+
|
| 509 |
+
# This entry matches! Return it.
|
| 510 |
+
frames = npzfile["frames"]
|
| 511 |
+
mass = npzfile["mass"]
|
| 512 |
+
L = read_sp_mat("L")
|
| 513 |
+
evals = npzfile["evals"][:k_eig]
|
| 514 |
+
evecs = npzfile["evecs"][:, :k_eig]
|
| 515 |
+
gradX = read_sp_mat("gradX")
|
| 516 |
+
gradY = read_sp_mat("gradY")
|
| 517 |
+
frames = npzfile["hks"]
|
| 518 |
+
mass = npzfile["wks"]
|
| 519 |
+
|
| 520 |
+
frames = torch.from_numpy(frames).to(device=device, dtype=dtype)
|
| 521 |
+
mass = torch.from_numpy(mass).to(device=device, dtype=dtype)
|
| 522 |
+
L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
|
| 523 |
+
evals = torch.from_numpy(evals).to(device=device, dtype=dtype)
|
| 524 |
+
evecs = torch.from_numpy(evecs).to(device=device, dtype=dtype)
|
| 525 |
+
gradX = sparse_np_to_torch(gradX).to(device=device, dtype=dtype)
|
| 526 |
+
gradY = sparse_np_to_torch(gradY).to(device=device, dtype=dtype)
|
| 527 |
+
hks = torch.from_numpy(hks).to(device=device, dtype=dtype)
|
| 528 |
+
wks = torch.from_numpy(wks).to(device=device, dtype=dtype)
|
| 529 |
+
|
| 530 |
+
found = True
|
| 531 |
+
|
| 532 |
+
break
|
| 533 |
+
|
| 534 |
+
except FileNotFoundError:
|
| 535 |
+
print(" cache miss -- constructing operators")
|
| 536 |
+
break
|
| 537 |
+
|
| 538 |
+
except Exception as E:
|
| 539 |
+
print("unexpected error loading file: " + str(E))
|
| 540 |
+
print("-- constructing operators")
|
| 541 |
+
break
|
| 542 |
+
|
| 543 |
+
if not found:
|
| 544 |
+
|
| 545 |
+
# No matching entry found; recompute.
|
| 546 |
+
frames, mass, L, evals, evecs, gradX, gradY, hks, wks = compute_operators(verts, faces, k_eig, normals=normals)
|
| 547 |
+
|
| 548 |
+
dtype_np = np.float32
|
| 549 |
+
|
| 550 |
+
# Store it in the cache
|
| 551 |
+
if op_cache_dir is not None:
|
| 552 |
+
|
| 553 |
+
L_np = sparse_torch_to_np(L).astype(dtype_np)
|
| 554 |
+
gradX_np = sparse_torch_to_np(gradX).astype(dtype_np)
|
| 555 |
+
gradY_np = sparse_torch_to_np(gradY).astype(dtype_np)
|
| 556 |
+
|
| 557 |
+
np.savez(
|
| 558 |
+
cache_path,
|
| 559 |
+
verts=verts_np.astype(dtype_np),
|
| 560 |
+
frames=toNP(frames).astype(dtype_np),
|
| 561 |
+
faces=faces_np,
|
| 562 |
+
k_eig=k_eig,
|
| 563 |
+
mass=toNP(mass).astype(dtype_np),
|
| 564 |
+
L_data=L_np.data.astype(dtype_np),
|
| 565 |
+
L_indices=L_np.indices,
|
| 566 |
+
L_indptr=L_np.indptr,
|
| 567 |
+
L_shape=L_np.shape,
|
| 568 |
+
evals=toNP(evals).astype(dtype_np),
|
| 569 |
+
evecs=toNP(evecs).astype(dtype_np),
|
| 570 |
+
gradX_data=gradX_np.data.astype(dtype_np),
|
| 571 |
+
gradX_indices=gradX_np.indices,
|
| 572 |
+
gradX_indptr=gradX_np.indptr,
|
| 573 |
+
gradX_shape=gradX_np.shape,
|
| 574 |
+
gradY_data=gradY_np.data.astype(dtype_np),
|
| 575 |
+
gradY_indices=gradY_np.indices,
|
| 576 |
+
gradY_indptr=gradY_np.indptr,
|
| 577 |
+
gradY_shape=gradY_np.shape,
|
| 578 |
+
hks=hks,
|
| 579 |
+
wks=wks
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
return frames, mass, L, evals, evecs, gradX, gradY, hks, wks
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def load_operators(filepath):
|
| 586 |
+
npzfile = np.load(filepath, allow_pickle=False)
|
| 587 |
+
|
| 588 |
+
def read_sp_mat(prefix):
|
| 589 |
+
data = npzfile[prefix + '_data']
|
| 590 |
+
indices = npzfile[prefix + '_indices']
|
| 591 |
+
indptr = npzfile[prefix + '_indptr']
|
| 592 |
+
shape = npzfile[prefix + '_shape']
|
| 593 |
+
mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
|
| 594 |
+
return mat
|
| 595 |
+
if 'verts' in npzfile:
|
| 596 |
+
keyverts = 'verts'
|
| 597 |
+
else:
|
| 598 |
+
keyverts = 'vertices'
|
| 599 |
+
return dict(
|
| 600 |
+
vertices=npzfile[keyverts],
|
| 601 |
+
faces=npzfile['faces'],
|
| 602 |
+
frames=npzfile['frames'],
|
| 603 |
+
mass=npzfile['mass'],
|
| 604 |
+
L=read_sp_mat('L'),
|
| 605 |
+
evals=npzfile['evals'],
|
| 606 |
+
evecs=npzfile['evecs'],
|
| 607 |
+
gradX=read_sp_mat('gradX'),
|
| 608 |
+
gradY=read_sp_mat('gradY'),
|
| 609 |
+
hks=npzfile['hks'],
|
| 610 |
+
wks=npzfile['wks'],
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def compute_operators_small(verts, faces, k_eig):
|
| 615 |
+
"""
|
| 616 |
+
Builds spectral operators for a mesh/point cloud. Constructs mass matrix, eigenvalues/vectors for Laplacian, and gradient matrix.
|
| 617 |
+
|
| 618 |
+
See get_operators() for a similar routine that wraps this one with a layer of caching.
|
| 619 |
+
|
| 620 |
+
Torch in / torch out.
|
| 621 |
+
|
| 622 |
+
Arguments:
|
| 623 |
+
- vertices: (V,3) vertex positions
|
| 624 |
+
- faces: (F,3) list of triangular faces. If empty, assumed to be a point cloud.
|
| 625 |
+
- k_eig: number of eigenvectors to use
|
| 626 |
+
|
| 627 |
+
Returns:
|
| 628 |
+
- massvec: (V) real diagonal of lumped mass matrix
|
| 629 |
+
- L: (VxV) real sparse matrix of (weak) Laplacian
|
| 630 |
+
- evals: (k) list of eigenvalues of the Laplacian
|
| 631 |
+
- evecs: (V,k) list of eigenvectors of the Laplacian
|
| 632 |
+
- gradX: (VxV) sparse matrix which gives X-component of gradient in the local basis at the vertex
|
| 633 |
+
- gradY: same as gradX but for Y-component of gradient
|
| 634 |
+
|
| 635 |
+
PyTorch doesn't seem to like complex sparse matrices, so we store the "real" and "imaginary" (aka X and Y) gradient matrices separately, rather than as one complex sparse matrix.
|
| 636 |
+
|
| 637 |
+
Note: for a generalized eigenvalue problem, the mass matrix matters! The eigenvectors are only othrthonormal with respect to the mass matrix, like v^H M v, so the mass (given as the diagonal vector massvec) needs to be used in projections, etc.
|
| 638 |
+
"""
|
| 639 |
+
|
| 640 |
+
device = verts.device
|
| 641 |
+
dtype = verts.dtype
|
| 642 |
+
V = verts.shape[0]
|
| 643 |
+
is_cloud = faces.numel() == 0
|
| 644 |
+
|
| 645 |
+
eps = 1e-8
|
| 646 |
+
|
| 647 |
+
verts_np = toNP(verts).astype(np.float64)
|
| 648 |
+
faces_np = toNP(faces)
|
| 649 |
+
# Build the scalar Laplacian
|
| 650 |
+
if is_cloud:
|
| 651 |
+
L, M = robust_laplacian.point_cloud_laplacian(verts_np)
|
| 652 |
+
massvec_np = M.diagonal()
|
| 653 |
+
else:
|
| 654 |
+
# L, M = robust_laplacian.mesh_laplacian(verts_np, faces_np)
|
| 655 |
+
# massvec_np = M.diagonal()
|
| 656 |
+
L = pp3d.cotan_laplacian(verts_np, faces_np, denom_eps=1e-10)
|
| 657 |
+
massvec_np = pp3d.vertex_areas(verts_np, faces_np)
|
| 658 |
+
massvec_np += eps * np.mean(massvec_np)
|
| 659 |
+
|
| 660 |
+
if (np.isnan(L.data).any()):
|
| 661 |
+
raise RuntimeError("NaN Laplace matrix")
|
| 662 |
+
if (np.isnan(massvec_np).any()):
|
| 663 |
+
raise RuntimeError("NaN mass matrix")
|
| 664 |
+
|
| 665 |
+
# Read off neighbors & rotations from the Laplacian
|
| 666 |
+
L_coo = L.tocoo()
|
| 667 |
+
inds_row = L_coo.row
|
| 668 |
+
inds_col = L_coo.col
|
| 669 |
+
|
| 670 |
+
# === Compute the eigenbasis
|
| 671 |
+
if k_eig > 0:
|
| 672 |
+
|
| 673 |
+
# Prepare matrices
|
| 674 |
+
L_eigsh = (L + scipy.sparse.identity(L.shape[0]) * eps).tocsc()
|
| 675 |
+
massvec_eigsh = massvec_np
|
| 676 |
+
Mmat = scipy.sparse.diags(massvec_eigsh)
|
| 677 |
+
eigs_sigma = eps
|
| 678 |
+
|
| 679 |
+
failcount = 0
|
| 680 |
+
while True:
|
| 681 |
+
try:
|
| 682 |
+
# We would be happy here to lower tol or maxiter since we don't need these to be super precise, but for some reason those parameters seem to have no effect
|
| 683 |
+
evals_np, evecs_np = sla.eigsh(L_eigsh, k=k_eig, M=Mmat, sigma=eigs_sigma)
|
| 684 |
+
|
| 685 |
+
# Clip off any eigenvalues that end up slightly negative due to numerical weirdness
|
| 686 |
+
evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))
|
| 687 |
+
|
| 688 |
+
break
|
| 689 |
+
except Exception as e:
|
| 690 |
+
print(e)
|
| 691 |
+
if (failcount > 3):
|
| 692 |
+
raise ValueError("failed to compute eigendecomp")
|
| 693 |
+
failcount += 1
|
| 694 |
+
print("--- decomp failed; adding eps ===> count: " + str(failcount))
|
| 695 |
+
L_eigsh = L_eigsh + scipy.sparse.identity(L.shape[0]) * (eps * 10**failcount)
|
| 696 |
+
|
| 697 |
+
else: #k_eig == 0
|
| 698 |
+
evals_np = np.zeros((0))
|
| 699 |
+
evecs_np = np.zeros((verts.shape[0], 0))
|
| 700 |
+
# Split complex gradient in to two real sparse mats (torch doesn't like complex sparse matrices)
|
| 701 |
+
|
| 702 |
+
# === Convert back to torch
|
| 703 |
+
massvec = torch.from_numpy(massvec_np).to(device=device, dtype=dtype)
|
| 704 |
+
L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
|
| 705 |
+
evals = torch.from_numpy(evals_np).to(device=device, dtype=dtype)
|
| 706 |
+
evecs = torch.from_numpy(evecs_np).to(device=device, dtype=dtype)
|
| 707 |
+
|
| 708 |
+
return massvec, L, evals, evecs
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def get_operators_small(verts, faces, k_eig=128, cache_path=None, overwrite_cache=False):
|
| 712 |
+
"""
|
| 713 |
+
See documentation for compute_operators(). This essentailly just wraps a call to compute_operators, using a cache if possible.
|
| 714 |
+
All arrays are always computed using double precision for stability, then truncated to single precision floats to store on disk, and finally returned as a tensor with dtype/device matching the `verts` input.
|
| 715 |
+
"""
|
| 716 |
+
if type(verts) == torch.Tensor:
|
| 717 |
+
device = verts.device
|
| 718 |
+
dtype = verts.dtype
|
| 719 |
+
verts_np = toNP(verts)
|
| 720 |
+
else:
|
| 721 |
+
device = "cpu"
|
| 722 |
+
dtype = torch.float32
|
| 723 |
+
verts_np = verts.copy()
|
| 724 |
+
verts = torch.from_numpy(verts).float()
|
| 725 |
+
if type(faces) == torch.Tensor:
|
| 726 |
+
faces_np = toNP(faces)
|
| 727 |
+
else:
|
| 728 |
+
faces_np = faces.copy()
|
| 729 |
+
faces = torch.from_numpy(faces).to(device, dtype=torch.int64)
|
| 730 |
+
is_cloud = faces.numel() == 0
|
| 731 |
+
|
| 732 |
+
if (np.isnan(verts_np).any()):
|
| 733 |
+
raise RuntimeError("tried to construct operators from NaN verts")
|
| 734 |
+
|
| 735 |
+
# Check the cache directory
|
| 736 |
+
# Note 1: Collisions here are exceptionally unlikely, so we could probably just use the hash...
|
| 737 |
+
# but for good measure we check values nonetheless.
|
| 738 |
+
# Note 2: There is a small possibility for race conditions to lead to bucket gaps or duplicate
|
| 739 |
+
# entries in this cache. The good news is that that is totally fine, and at most slightly
|
| 740 |
+
# slows performance with rare extra cache misses.
|
| 741 |
+
found = False
|
| 742 |
+
|
| 743 |
+
if cache_path is not None:
|
| 744 |
+
op_cache_dir = os.path.dirname(cache_path)
|
| 745 |
+
ensure_dir_exists(op_cache_dir)
|
| 746 |
+
# print("Building operators for input with hash: " + hash_key_str)
|
| 747 |
+
|
| 748 |
+
# Search through buckets with matching hashes. When the loop exits, this
|
| 749 |
+
# is the bucket index of the file we should write to.
|
| 750 |
+
i_cache_search = 0
|
| 751 |
+
while True:
|
| 752 |
+
try:
|
| 753 |
+
# print('loading path: ' + str(search_path))
|
| 754 |
+
npzfile = np.load(cache_path, allow_pickle=True)
|
| 755 |
+
cache_verts = npzfile["verts"]
|
| 756 |
+
cache_faces = npzfile["faces"]
|
| 757 |
+
cache_k_eig = npzfile["k_eig"].item()
|
| 758 |
+
|
| 759 |
+
# If the cache doesn't match, keep looking
|
| 760 |
+
if (not np.array_equal(verts, cache_verts)) or (not np.array_equal(faces, cache_faces)):
|
| 761 |
+
i_cache_search += 1
|
| 762 |
+
print("hash collision! searching next.")
|
| 763 |
+
continue
|
| 764 |
+
|
| 765 |
+
# print(" cache hit!")
|
| 766 |
+
|
| 767 |
+
# If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
|
| 768 |
+
# entry below more eigenvalues
|
| 769 |
+
if overwrite_cache:
|
| 770 |
+
print(" overwriting cache by request")
|
| 771 |
+
os.remove(cache_path)
|
| 772 |
+
break
|
| 773 |
+
|
| 774 |
+
if cache_k_eig < k_eig:
|
| 775 |
+
print(" overwriting cache --- not enough eigenvalues")
|
| 776 |
+
os.remove(cache_path)
|
| 777 |
+
break
|
| 778 |
+
|
| 779 |
+
if "L_data" not in npzfile:
|
| 780 |
+
print(" overwriting cache --- entries are absent")
|
| 781 |
+
os.remove(cache_path)
|
| 782 |
+
break
|
| 783 |
+
|
| 784 |
+
def read_sp_mat(prefix):
|
| 785 |
+
data = npzfile[prefix + "_data"]
|
| 786 |
+
indices = npzfile[prefix + "_indices"]
|
| 787 |
+
indptr = npzfile[prefix + "_indptr"]
|
| 788 |
+
shape = npzfile[prefix + "_shape"]
|
| 789 |
+
mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
|
| 790 |
+
return mat
|
| 791 |
+
|
| 792 |
+
# This entry matches! Return it.
|
| 793 |
+
mass = npzfile["mass"]
|
| 794 |
+
L = read_sp_mat("L")
|
| 795 |
+
evals = npzfile["evals"][:k_eig]
|
| 796 |
+
evecs = npzfile["evecs"][:, :k_eig]
|
| 797 |
+
mass = torch.from_numpy(mass).to(device=device, dtype=dtype)
|
| 798 |
+
L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
|
| 799 |
+
evals = torch.from_numpy(evals).to(device=device, dtype=dtype)
|
| 800 |
+
evecs = torch.from_numpy(evecs).to(device=device, dtype=dtype)
|
| 801 |
+
|
| 802 |
+
found = True
|
| 803 |
+
|
| 804 |
+
break
|
| 805 |
+
|
| 806 |
+
except FileNotFoundError:
|
| 807 |
+
print(cache_path)
|
| 808 |
+
print(" cache miss -- constructing operators")
|
| 809 |
+
break
|
| 810 |
+
|
| 811 |
+
except Exception as E:
|
| 812 |
+
print("unexpected error loading file: " + str(E))
|
| 813 |
+
print("-- constructing operators")
|
| 814 |
+
break
|
| 815 |
+
|
| 816 |
+
if not found:
|
| 817 |
+
|
| 818 |
+
# No matching entry found; recompute.
|
| 819 |
+
mass, L, evals, evecs, = compute_operators_small(verts, faces, k_eig)
|
| 820 |
+
|
| 821 |
+
dtype_np = np.float32
|
| 822 |
+
|
| 823 |
+
# Store it in the cache
|
| 824 |
+
if op_cache_dir is not None:
|
| 825 |
+
|
| 826 |
+
L_np = sparse_torch_to_np(L).astype(dtype_np)
|
| 827 |
+
|
| 828 |
+
np.savez(
|
| 829 |
+
cache_path,
|
| 830 |
+
verts=verts_np.astype(dtype_np),
|
| 831 |
+
faces=faces_np,
|
| 832 |
+
k_eig=k_eig,
|
| 833 |
+
mass=toNP(mass).astype(dtype_np),
|
| 834 |
+
L_data=L_np.data.astype(dtype_np),
|
| 835 |
+
L_indices=L_np.indices,
|
| 836 |
+
L_indptr=L_np.indptr,
|
| 837 |
+
L_shape=L_np.shape,
|
| 838 |
+
evals=toNP(evals).astype(dtype_np),
|
| 839 |
+
evecs=toNP(evecs).astype(dtype_np),
|
| 840 |
+
)
|
| 841 |
+
|
| 842 |
+
return mass, L, evals, evecs
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
def to_basis(values, basis, massvec):
|
| 846 |
+
"""
|
| 847 |
+
Transform data in to an orthonormal basis (where orthonormal is wrt to massvec)
|
| 848 |
+
Inputs:
|
| 849 |
+
- values: (B,V,D)
|
| 850 |
+
- basis: (B,V,K)
|
| 851 |
+
- massvec: (B,V)
|
| 852 |
+
Outputs:
|
| 853 |
+
- (B,K,D) transformed values
|
| 854 |
+
"""
|
| 855 |
+
basisT = basis.transpose(-2, -1)
|
| 856 |
+
return torch.matmul(basisT, values * massvec.unsqueeze(-1))
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
def normalize_positions(pos, faces=None, method='mean', scale_method='max_rad'):
|
| 861 |
+
# center and unit-scale positions
|
| 862 |
+
|
| 863 |
+
if method == 'mean':
|
| 864 |
+
# center using the average point position
|
| 865 |
+
pos = (pos - torch.mean(pos, dim=-2, keepdim=True))
|
| 866 |
+
elif method == 'bbox':
|
| 867 |
+
# center via the middle of the axis-aligned bounding box
|
| 868 |
+
bbox_min = torch.min(pos, dim=-2).values
|
| 869 |
+
bbox_max = torch.max(pos, dim=-2).values
|
| 870 |
+
center = (bbox_max + bbox_min) / 2.
|
| 871 |
+
pos -= center.unsqueeze(-2)
|
| 872 |
+
else:
|
| 873 |
+
raise ValueError("unrecognized method")
|
| 874 |
+
|
| 875 |
+
if scale_method == 'max_rad':
|
| 876 |
+
scale = torch.max(norm(pos), dim=-1, keepdim=True).values.unsqueeze(-1)
|
| 877 |
+
pos = pos / scale
|
| 878 |
+
elif scale_method == 'area':
|
| 879 |
+
if faces is None:
|
| 880 |
+
raise ValueError("must pass faces for area normalization")
|
| 881 |
+
coords = pos[faces]
|
| 882 |
+
vec_A = coords[:, 1, :] - coords[:, 0, :]
|
| 883 |
+
vec_B = coords[:, 2, :] - coords[:, 0, :]
|
| 884 |
+
face_areas = torch.norm(torch.cross(vec_A, vec_B, dim=-1), dim=1) * 0.5
|
| 885 |
+
total_area = torch.sum(face_areas)
|
| 886 |
+
scale = (1. / torch.sqrt(total_area))
|
| 887 |
+
pos = pos * scale
|
| 888 |
+
else:
|
| 889 |
+
raise ValueError("unrecognized scale method")
|
| 890 |
+
return pos
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
# Finds the k nearest neighbors of source on target.
|
| 894 |
+
# Return is two tensors (distances, indices). Returned points will be sorted in increasing order of distance.
|
| 895 |
+
def find_knn(points_source, points_target, k, largest=False, omit_diagonal=False, method='brute'):
|
| 896 |
+
|
| 897 |
+
if omit_diagonal and points_source.shape[0] != points_target.shape[0]:
|
| 898 |
+
raise ValueError("omit_diagonal can only be used when source and target are same shape")
|
| 899 |
+
|
| 900 |
+
if method != 'cpu_kd' and points_source.shape[0] * points_target.shape[0] > 1e8:
|
| 901 |
+
method = 'cpu_kd'
|
| 902 |
+
print("switching to cpu_kd knn")
|
| 903 |
+
|
| 904 |
+
if method == 'brute':
|
| 905 |
+
|
| 906 |
+
# Expand so both are NxMx3 tensor
|
| 907 |
+
points_source_expand = points_source.unsqueeze(1)
|
| 908 |
+
points_source_expand = points_source_expand.expand(-1, points_target.shape[0], -1)
|
| 909 |
+
points_target_expand = points_target.unsqueeze(0)
|
| 910 |
+
points_target_expand = points_target_expand.expand(points_source.shape[0], -1, -1)
|
| 911 |
+
|
| 912 |
+
diff_mat = points_source_expand - points_target_expand
|
| 913 |
+
dist_mat = norm(diff_mat)
|
| 914 |
+
|
| 915 |
+
if omit_diagonal:
|
| 916 |
+
torch.diagonal(dist_mat)[:] = float('inf')
|
| 917 |
+
|
| 918 |
+
result = torch.topk(dist_mat, k=k, largest=largest, sorted=True)
|
| 919 |
+
return result
|
| 920 |
+
|
| 921 |
+
elif method == 'cpu_kd':
|
| 922 |
+
|
| 923 |
+
if largest:
|
| 924 |
+
raise ValueError("can't do largest with cpu_kd")
|
| 925 |
+
|
| 926 |
+
points_source_np = toNP(points_source)
|
| 927 |
+
points_target_np = toNP(points_target)
|
| 928 |
+
|
| 929 |
+
# Build the tree
|
| 930 |
+
kd_tree = sklearn.neighbors.KDTree(points_target_np)
|
| 931 |
+
|
| 932 |
+
k_search = k + 1 if omit_diagonal else k
|
| 933 |
+
_, neighbors = kd_tree.query(points_source_np, k=k_search)
|
| 934 |
+
|
| 935 |
+
if omit_diagonal:
|
| 936 |
+
# Mask out self element
|
| 937 |
+
mask = neighbors != np.arange(neighbors.shape[0])[:, np.newaxis]
|
| 938 |
+
|
| 939 |
+
# make sure we mask out exactly one element in each row, in rare case of many duplicate points
|
| 940 |
+
mask[np.sum(mask, axis=1) == mask.shape[1], -1] = False
|
| 941 |
+
|
| 942 |
+
neighbors = neighbors[mask].reshape((neighbors.shape[0], neighbors.shape[1] - 1))
|
| 943 |
+
|
| 944 |
+
inds = torch.tensor(neighbors, device=points_source.device, dtype=torch.int64)
|
| 945 |
+
dists = norm(points_source.unsqueeze(1).expand(-1, k, -1) - points_target[inds])
|
| 946 |
+
|
| 947 |
+
return dists, inds
|
| 948 |
+
|
| 949 |
+
else:
|
| 950 |
+
raise ValueError("unrecognized method")
|
| 951 |
+
|
utils/io.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import os.path as osp
|
| 5 |
+
import re
|
| 6 |
+
import shutil
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def is_number(s):
|
| 10 |
+
try:
|
| 11 |
+
float(s)
|
| 12 |
+
return True
|
| 13 |
+
except ValueError:
|
| 14 |
+
return False
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def may_create_folder(folder_path):
|
| 18 |
+
if not osp.exists(folder_path):
|
| 19 |
+
oldmask = os.umask(000)
|
| 20 |
+
os.makedirs(folder_path, mode=0o777)
|
| 21 |
+
os.umask(oldmask)
|
| 22 |
+
return True
|
| 23 |
+
return False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def make_clean_folder(folder_path):
|
| 27 |
+
success = may_create_folder(folder_path)
|
| 28 |
+
if not success:
|
| 29 |
+
shutil.rmtree(folder_path)
|
| 30 |
+
may_create_folder(folder_path)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def sorted_alphanum(file_list_ordered):
|
| 34 |
+
convert = lambda text: int(text) if text.isdigit() else text
|
| 35 |
+
alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key) if len(c) > 0]
|
| 36 |
+
return sorted(file_list_ordered, key=alphanum_key)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def list_files(folder_path, name_filter, alphanum_sort=False):
|
| 40 |
+
file_list = [p.name for p in list(Path(folder_path).glob(name_filter))]
|
| 41 |
+
if alphanum_sort:
|
| 42 |
+
return sorted_alphanum(file_list)
|
| 43 |
+
else:
|
| 44 |
+
return sorted(file_list)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def list_folders(folder_path, name_filter=None, alphanum_sort=False):
|
| 48 |
+
folders = list()
|
| 49 |
+
for subfolder in Path(folder_path).iterdir():
|
| 50 |
+
if subfolder.is_dir() and not subfolder.name.startswith('.'):
|
| 51 |
+
folder_name = subfolder.name
|
| 52 |
+
if name_filter is not None:
|
| 53 |
+
if name_filter in folder_name:
|
| 54 |
+
folders.append(folder_name)
|
| 55 |
+
else:
|
| 56 |
+
folders.append(folder_name)
|
| 57 |
+
if alphanum_sort:
|
| 58 |
+
return sorted_alphanum(folders)
|
| 59 |
+
else:
|
| 60 |
+
return sorted(folders)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def read_lines(file_path):
|
| 64 |
+
with open(file_path, 'r') as fin:
|
| 65 |
+
lines = [line.strip() for line in fin.readlines() if len(line.strip()) > 0]
|
| 66 |
+
return lines
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def read_strings(file_path):
|
| 70 |
+
with open(file_path, 'r') as fin:
|
| 71 |
+
ret = fin.readlines()
|
| 72 |
+
return ''.join(ret)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def read_json(filepath):
|
| 76 |
+
with open(filepath, 'r') as fh:
|
| 77 |
+
ret = json.load(fh)
|
| 78 |
+
return ret
|
utils/layers.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import os.path
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import scipy
|
| 7 |
+
import scipy.sparse.linalg as sla
|
| 8 |
+
# ^^^ we NEED to import scipy before torch, or it crashes :(
|
| 9 |
+
# (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../')
|
| 16 |
+
if ROOT_DIR not in sys.path:
|
| 17 |
+
sys.path.append(ROOT_DIR)
|
| 18 |
+
|
| 19 |
+
from diffusion_net.utils import toNP
|
| 20 |
+
from diffusion_net.geometry import to_basis, from_basis
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LearnedTimeDiffusion(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
Applies diffusion with learned per-channel t.
|
| 26 |
+
|
| 27 |
+
In the spectral domain this becomes
|
| 28 |
+
f_out = e ^ (lambda_i t) f_in
|
| 29 |
+
|
| 30 |
+
Inputs:
|
| 31 |
+
- values: (V,C) in the spectral domain
|
| 32 |
+
- L: (V,V) sparse laplacian
|
| 33 |
+
- evals: (K) eigenvalues
|
| 34 |
+
- mass: (V) mass matrix diagonal
|
| 35 |
+
|
| 36 |
+
(note: L/evals may be omitted as None depending on method)
|
| 37 |
+
Outputs:
|
| 38 |
+
- (V,C) diffused values
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, C_inout, method='spectral'):
|
| 42 |
+
super(LearnedTimeDiffusion, self).__init__()
|
| 43 |
+
self.C_inout = C_inout
|
| 44 |
+
self.diffusion_time = nn.Parameter(torch.Tensor(C_inout)) # (C)
|
| 45 |
+
self.method = method # one of ['spectral', 'implicit_dense']
|
| 46 |
+
|
| 47 |
+
nn.init.constant_(self.diffusion_time, 0.0)
|
| 48 |
+
|
| 49 |
+
def forward(self, x, L, mass, evals, evecs):
|
| 50 |
+
|
| 51 |
+
# project times to the positive halfspace
|
| 52 |
+
# (and away from 0 in the incredibly rare chance that they get stuck)
|
| 53 |
+
with torch.no_grad():
|
| 54 |
+
self.diffusion_time.data = torch.clamp(self.diffusion_time, min=1e-8)
|
| 55 |
+
|
| 56 |
+
if x.shape[-1] != self.C_inout:
|
| 57 |
+
raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
|
| 58 |
+
x.shape, self.C_inout))
|
| 59 |
+
|
| 60 |
+
if self.method == 'spectral':
|
| 61 |
+
|
| 62 |
+
# Transform to spectral
|
| 63 |
+
x_spec = to_basis(x, evecs, mass)
|
| 64 |
+
|
| 65 |
+
# Diffuse
|
| 66 |
+
time = self.diffusion_time
|
| 67 |
+
diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * time.unsqueeze(0))
|
| 68 |
+
x_diffuse_spec = diffusion_coefs * x_spec
|
| 69 |
+
|
| 70 |
+
# Transform back to per-vertex
|
| 71 |
+
x_diffuse = from_basis(x_diffuse_spec, evecs)
|
| 72 |
+
|
| 73 |
+
elif self.method == 'implicit_dense':
|
| 74 |
+
V = x.shape[-2]
|
| 75 |
+
|
| 76 |
+
# Form the dense matrices (M + tL) with dims (B,C,V,V)
|
| 77 |
+
mat_dense = L.to_dense().unsqueeze(1).expand(-1, self.C_inout, V, V).clone()
|
| 78 |
+
mat_dense *= self.diffusion_time.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
| 79 |
+
mat_dense += torch.diag_embed(mass).unsqueeze(1)
|
| 80 |
+
|
| 81 |
+
# Factor the system
|
| 82 |
+
cholesky_factors = torch.linalg.cholesky(mat_dense)
|
| 83 |
+
|
| 84 |
+
# Solve the system
|
| 85 |
+
rhs = x * mass.unsqueeze(-1)
|
| 86 |
+
rhsT = torch.transpose(rhs, 1, 2).unsqueeze(-1)
|
| 87 |
+
sols = torch.cholesky_solve(rhsT, cholesky_factors)
|
| 88 |
+
x_diffuse = torch.transpose(sols.squeeze(-1), 1, 2)
|
| 89 |
+
|
| 90 |
+
else:
|
| 91 |
+
raise ValueError("unrecognized method")
|
| 92 |
+
|
| 93 |
+
return x_diffuse
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SpatialGradientFeatures(nn.Module):
|
| 97 |
+
"""
|
| 98 |
+
Compute dot-products between input vectors. Uses a learned complex-linear layer to keep dimension down.
|
| 99 |
+
|
| 100 |
+
Input:
|
| 101 |
+
- vectors: (V,C,2)
|
| 102 |
+
Output:
|
| 103 |
+
- dots: (V,C) dots
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(self, C_inout, with_gradient_rotations=True):
|
| 107 |
+
super(SpatialGradientFeatures, self).__init__()
|
| 108 |
+
|
| 109 |
+
self.C_inout = C_inout
|
| 110 |
+
self.with_gradient_rotations = with_gradient_rotations
|
| 111 |
+
|
| 112 |
+
if (self.with_gradient_rotations):
|
| 113 |
+
self.A_re = nn.Linear(self.C_inout, self.C_inout, bias=False)
|
| 114 |
+
self.A_im = nn.Linear(self.C_inout, self.C_inout, bias=False)
|
| 115 |
+
else:
|
| 116 |
+
self.A = nn.Linear(self.C_inout, self.C_inout, bias=False)
|
| 117 |
+
|
| 118 |
+
# self.norm = nn.InstanceNorm1d(C_inout)
|
| 119 |
+
|
| 120 |
+
def forward(self, vectors):
|
| 121 |
+
|
| 122 |
+
vectorsA = vectors # (V,C)
|
| 123 |
+
|
| 124 |
+
if self.with_gradient_rotations:
|
| 125 |
+
vectorsBreal = self.A_re(vectors[..., 0]) - self.A_im(vectors[..., 1])
|
| 126 |
+
vectorsBimag = self.A_re(vectors[..., 1]) + self.A_im(vectors[..., 0])
|
| 127 |
+
else:
|
| 128 |
+
vectorsBreal = self.A(vectors[..., 0])
|
| 129 |
+
vectorsBimag = self.A(vectors[..., 1])
|
| 130 |
+
|
| 131 |
+
dots = vectorsA[..., 0] * vectorsBreal + vectorsA[..., 1] * vectorsBimag
|
| 132 |
+
|
| 133 |
+
return torch.tanh(dots)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class MiniMLP(nn.Sequential):
|
| 137 |
+
'''
|
| 138 |
+
A simple MLP with configurable hidden layer sizes.
|
| 139 |
+
'''
|
| 140 |
+
|
| 141 |
+
def __init__(self, layer_sizes, dropout=False, activation=nn.ReLU, name="miniMLP"):
|
| 142 |
+
super(MiniMLP, self).__init__()
|
| 143 |
+
|
| 144 |
+
for i in range(len(layer_sizes) - 1):
|
| 145 |
+
is_last = (i + 2 == len(layer_sizes))
|
| 146 |
+
|
| 147 |
+
if dropout and i > 0:
|
| 148 |
+
self.add_module(name + "_mlp_layer_dropout_{:03d}".format(i), nn.Dropout(p=.5))
|
| 149 |
+
|
| 150 |
+
# Affine map
|
| 151 |
+
self.add_module(
|
| 152 |
+
name + "_mlp_layer_{:03d}".format(i),
|
| 153 |
+
nn.Linear(
|
| 154 |
+
layer_sizes[i],
|
| 155 |
+
layer_sizes[i + 1],
|
| 156 |
+
),
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Nonlinearity
|
| 160 |
+
# (but not on the last layer)
|
| 161 |
+
if not is_last:
|
| 162 |
+
self.add_module(name + "_mlp_act_{:03d}".format(i), activation())
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class DiffusionNetBlock(nn.Module):
|
| 166 |
+
"""
|
| 167 |
+
Inputs and outputs are defined at vertices
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
def __init__(self,
|
| 171 |
+
C_width,
|
| 172 |
+
mlp_hidden_dims,
|
| 173 |
+
dropout=True,
|
| 174 |
+
diffusion_method='spectral',
|
| 175 |
+
with_gradient_features=True,
|
| 176 |
+
with_gradient_rotations=True):
|
| 177 |
+
super(DiffusionNetBlock, self).__init__()
|
| 178 |
+
|
| 179 |
+
# Specified dimensions
|
| 180 |
+
self.C_width = C_width
|
| 181 |
+
self.mlp_hidden_dims = mlp_hidden_dims
|
| 182 |
+
|
| 183 |
+
self.dropout = dropout
|
| 184 |
+
self.with_gradient_features = with_gradient_features
|
| 185 |
+
self.with_gradient_rotations = with_gradient_rotations
|
| 186 |
+
|
| 187 |
+
# Diffusion block
|
| 188 |
+
self.diffusion = LearnedTimeDiffusion(self.C_width, method=diffusion_method)
|
| 189 |
+
|
| 190 |
+
self.MLP_C = 2 * self.C_width
|
| 191 |
+
|
| 192 |
+
if self.with_gradient_features:
|
| 193 |
+
self.gradient_features = SpatialGradientFeatures(self.C_width, with_gradient_rotations=self.with_gradient_rotations)
|
| 194 |
+
self.MLP_C += self.C_width
|
| 195 |
+
|
| 196 |
+
# MLPs
|
| 197 |
+
self.mlp = MiniMLP([self.MLP_C] + self.mlp_hidden_dims + [self.C_width], dropout=self.dropout)
|
| 198 |
+
|
| 199 |
+
def forward(self, x_in, mass, L, evals, evecs, gradX, gradY):
|
| 200 |
+
|
| 201 |
+
# Manage dimensions
|
| 202 |
+
B = x_in.shape[0] # batch dimension
|
| 203 |
+
if x_in.shape[-1] != self.C_width:
|
| 204 |
+
raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
|
| 205 |
+
x_in.shape, self.C_width))
|
| 206 |
+
|
| 207 |
+
# Diffusion block
|
| 208 |
+
x_diffuse = self.diffusion(x_in, L, mass, evals, evecs)
|
| 209 |
+
|
| 210 |
+
# Compute gradient features, if using
|
| 211 |
+
if self.with_gradient_features:
|
| 212 |
+
|
| 213 |
+
# Compute gradients
|
| 214 |
+
x_grads = [
|
| 215 |
+
] # Manually loop over the batch (if there is a batch dimension) since torch.mm() doesn't support batching
|
| 216 |
+
for b in range(B):
|
| 217 |
+
# gradient after diffusion
|
| 218 |
+
x_gradX = torch.mm(gradX[b, ...], x_diffuse[b, ...])
|
| 219 |
+
x_gradY = torch.mm(gradY[b, ...], x_diffuse[b, ...])
|
| 220 |
+
|
| 221 |
+
x_grads.append(torch.stack((x_gradX, x_gradY), dim=-1))
|
| 222 |
+
x_grad = torch.stack(x_grads, dim=0)
|
| 223 |
+
|
| 224 |
+
# Evaluate gradient features
|
| 225 |
+
x_grad_features = self.gradient_features(x_grad)
|
| 226 |
+
|
| 227 |
+
# Stack inputs to mlp
|
| 228 |
+
feature_combined = torch.cat((x_in, x_diffuse, x_grad_features), dim=-1)
|
| 229 |
+
else:
|
| 230 |
+
# Stack inputs to mlp
|
| 231 |
+
feature_combined = torch.cat((x_in, x_diffuse), dim=-1)
|
| 232 |
+
|
| 233 |
+
# Apply the mlp
|
| 234 |
+
x0_out = self.mlp(feature_combined)
|
| 235 |
+
|
| 236 |
+
# Skip connection
|
| 237 |
+
x0_out = x0_out + x_in
|
| 238 |
+
|
| 239 |
+
return x0_out
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class DiffusionNet(nn.Module):
|
| 243 |
+
|
| 244 |
+
def __init__(self,
|
| 245 |
+
C_in,
|
| 246 |
+
C_out,
|
| 247 |
+
C_width=128,
|
| 248 |
+
N_block=4,
|
| 249 |
+
last_activation=None,
|
| 250 |
+
outputs_at='vertices',
|
| 251 |
+
mlp_hidden_dims=None,
|
| 252 |
+
dropout=True,
|
| 253 |
+
with_gradient_features=True,
|
| 254 |
+
with_gradient_rotations=True,
|
| 255 |
+
diffusion_method='spectral',
|
| 256 |
+
num_eigenbasis=128):
|
| 257 |
+
"""
|
| 258 |
+
Construct a DiffusionNet.
|
| 259 |
+
|
| 260 |
+
Parameters:
|
| 261 |
+
C_in (int): input dimension
|
| 262 |
+
C_out (int): output dimension
|
| 263 |
+
last_activation (func) a function to apply to the final outputs of the network, such as torch.nn.functional.log_softmax (default: None)
|
| 264 |
+
outputs_at (string) produce outputs at various mesh elements by averaging from vertices. One of ['vertices', 'edges', 'faces', 'global_mean']. (default 'vertices', aka points for a point cloud)
|
| 265 |
+
C_width (int): dimension of internal DiffusionNet blocks (default: 128)
|
| 266 |
+
N_block (int): number of DiffusionNet blocks (default: 4)
|
| 267 |
+
mlp_hidden_dims (list of int): a list of hidden layer sizes for MLPs (default: [C_width, C_width])
|
| 268 |
+
dropout (bool): if True, internal MLPs use dropout (default: True)
|
| 269 |
+
diffusion_method (string): how to evaluate diffusion, one of ['spectral', 'implicit_dense']. If implicit_dense is used, can set k_eig=0, saving precompute.
|
| 270 |
+
with_gradient_features (bool): if True, use gradient features (default: True)
|
| 271 |
+
with_gradient_rotations (bool): if True, use gradient also learn a rotation of each gradient. Set to True if your surface has consistently oriented normals, and False otherwise (default: True)
|
| 272 |
+
num_eigenbasis (int): for trunking the eigenvalues eigenvectors
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
super(DiffusionNet, self).__init__()
|
| 276 |
+
|
| 277 |
+
## Store parameters
|
| 278 |
+
|
| 279 |
+
# Basic parameters
|
| 280 |
+
self.C_in = C_in
|
| 281 |
+
self.C_out = C_out
|
| 282 |
+
self.C_width = C_width
|
| 283 |
+
self.N_block = N_block
|
| 284 |
+
|
| 285 |
+
# Outputs
|
| 286 |
+
self.last_activation = last_activation
|
| 287 |
+
self.outputs_at = outputs_at
|
| 288 |
+
if outputs_at not in ['vertices', 'edges', 'faces', 'global_mean']:
|
| 289 |
+
raise ValueError("invalid setting for outputs_at")
|
| 290 |
+
|
| 291 |
+
# MLP options
|
| 292 |
+
if mlp_hidden_dims == None:
|
| 293 |
+
mlp_hidden_dims = [C_width, C_width]
|
| 294 |
+
self.mlp_hidden_dims = mlp_hidden_dims
|
| 295 |
+
self.dropout = dropout
|
| 296 |
+
|
| 297 |
+
# Diffusion
|
| 298 |
+
self.diffusion_method = diffusion_method
|
| 299 |
+
if diffusion_method not in ['spectral', 'implicit_dense']:
|
| 300 |
+
raise ValueError("invalid setting for diffusion_method")
|
| 301 |
+
self.num_eigenbasis = num_eigenbasis
|
| 302 |
+
|
| 303 |
+
# Gradient features
|
| 304 |
+
self.with_gradient_features = with_gradient_features
|
| 305 |
+
self.with_gradient_rotations = with_gradient_rotations
|
| 306 |
+
|
| 307 |
+
## Set up the network
|
| 308 |
+
|
| 309 |
+
# First and last affine layers
|
| 310 |
+
self.first_lin = nn.Linear(C_in, C_width)
|
| 311 |
+
self.last_lin = nn.Linear(C_width, C_out)
|
| 312 |
+
|
| 313 |
+
# DiffusionNet blocks
|
| 314 |
+
self.blocks = []
|
| 315 |
+
for i_block in range(self.N_block):
|
| 316 |
+
block = DiffusionNetBlock(C_width=C_width,
|
| 317 |
+
mlp_hidden_dims=mlp_hidden_dims,
|
| 318 |
+
dropout=dropout,
|
| 319 |
+
diffusion_method=diffusion_method,
|
| 320 |
+
with_gradient_features=with_gradient_features,
|
| 321 |
+
with_gradient_rotations=with_gradient_rotations)
|
| 322 |
+
|
| 323 |
+
self.blocks.append(block)
|
| 324 |
+
self.add_module("block_" + str(i_block), self.blocks[-1])
|
| 325 |
+
|
| 326 |
+
def forward(self, x_in, mass, L=None, evals=None, evecs=None, gradX=None, gradY=None, edges=None, faces=None):
|
| 327 |
+
"""
|
| 328 |
+
A forward pass on the DiffusionNet.
|
| 329 |
+
|
| 330 |
+
In the notation below, dimension are:
|
| 331 |
+
- C is the input channel dimension (C_in on construction)
|
| 332 |
+
- C_OUT is the output channel dimension (C_out on construction)
|
| 333 |
+
- N is the number of vertices/points, which CAN be different for each forward pass
|
| 334 |
+
- B is an OPTIONAL batch dimension
|
| 335 |
+
- K_EIG is the number of eigenvalues used for spectral acceleration
|
| 336 |
+
Generally, our data layout it is [N,C] or [B,N,C].
|
| 337 |
+
|
| 338 |
+
Call get_operators() to generate geometric quantities mass/L/evals/evecs/gradX/gradY. Note that depending on the options for the DiffusionNet, not all are strictly necessary.
|
| 339 |
+
|
| 340 |
+
Parameters:
|
| 341 |
+
x_in (tensor): Input features, dimension [N,C] or [B,N,C]
|
| 342 |
+
mass (tensor): Mass vector, dimension [N] or [B,N]
|
| 343 |
+
L (tensor): Laplace matrix, sparse tensor with dimension [N,N] or [B,N,N]
|
| 344 |
+
evals (tensor): Eigenvalues of Laplace matrix, dimension [K_EIG] or [B,K_EIG]
|
| 345 |
+
evecs (tensor): Eigenvectors of Laplace matrix, dimension [N,K_EIG] or [B,N,K_EIG]
|
| 346 |
+
gradX (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
|
| 347 |
+
gradY (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
x_out (tensor): Output with dimension [N,C_out] or [B,N,C_out]
|
| 351 |
+
"""
|
| 352 |
+
|
| 353 |
+
## Check dimensions, and append batch dimension if not given
|
| 354 |
+
if x_in.shape[-1] != self.C_in:
|
| 355 |
+
raise ValueError("DiffusionNet was constructed with C_in={}, but x_in has last dim={}".format(
|
| 356 |
+
self.C_in, x_in.shape[-1]))
|
| 357 |
+
N = x_in.shape[-2]
|
| 358 |
+
if len(x_in.shape) == 2:
|
| 359 |
+
appended_batch_dim = True
|
| 360 |
+
|
| 361 |
+
# add a batch dim to all inputs
|
| 362 |
+
x_in = x_in.unsqueeze(0)
|
| 363 |
+
mass = mass.unsqueeze(0)
|
| 364 |
+
if L != None:
|
| 365 |
+
L = L.unsqueeze(0)
|
| 366 |
+
if evals != None:
|
| 367 |
+
evals = evals.unsqueeze(0)
|
| 368 |
+
if evecs != None:
|
| 369 |
+
evecs = evecs.unsqueeze(0)
|
| 370 |
+
if gradX != None:
|
| 371 |
+
gradX = gradX.unsqueeze(0)
|
| 372 |
+
if gradY != None:
|
| 373 |
+
gradY = gradY.unsqueeze(0)
|
| 374 |
+
if edges != None:
|
| 375 |
+
edges = edges.unsqueeze(0)
|
| 376 |
+
if faces != None:
|
| 377 |
+
faces = faces.unsqueeze(0)
|
| 378 |
+
|
| 379 |
+
elif len(x_in.shape) == 3:
|
| 380 |
+
appended_batch_dim = False
|
| 381 |
+
|
| 382 |
+
else:
|
| 383 |
+
raise ValueError("x_in should be tensor with shape [N,C] or [B,N,C]")
|
| 384 |
+
|
| 385 |
+
evals = evals[..., :self.num_eigenbasis]
|
| 386 |
+
evecs = evecs[..., :self.num_eigenbasis]
|
| 387 |
+
|
| 388 |
+
# Apply the first linear layer
|
| 389 |
+
x = self.first_lin(x_in)
|
| 390 |
+
|
| 391 |
+
# Apply each of the blocks
|
| 392 |
+
for b in self.blocks:
|
| 393 |
+
x = b(x, mass, L, evals, evecs, gradX, gradY)
|
| 394 |
+
|
| 395 |
+
# Apply the last linear layer
|
| 396 |
+
x = self.last_lin(x)
|
| 397 |
+
|
| 398 |
+
# Remap output to faces/edges if requested
|
| 399 |
+
if self.outputs_at == 'vertices':
|
| 400 |
+
x_out = x
|
| 401 |
+
|
| 402 |
+
elif self.outputs_at == 'edges':
|
| 403 |
+
# Remap to edges
|
| 404 |
+
x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 2)
|
| 405 |
+
edges_gather = edges.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
|
| 406 |
+
xe = torch.gather(x_gather, 1, edges_gather)
|
| 407 |
+
x_out = torch.mean(xe, dim=-1)
|
| 408 |
+
|
| 409 |
+
elif self.outputs_at == 'faces':
|
| 410 |
+
# Remap to faces
|
| 411 |
+
x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 3)
|
| 412 |
+
faces_gather = faces.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
|
| 413 |
+
xf = torch.gather(x_gather, 1, faces_gather)
|
| 414 |
+
x_out = torch.mean(xf, dim=-1)
|
| 415 |
+
|
| 416 |
+
elif self.outputs_at == 'global_mean':
|
| 417 |
+
# Produce a single global mean ouput.
|
| 418 |
+
# Using a weighted mean according to the point mass/area is discretization-invariant.
|
| 419 |
+
# (A naive mean is not discretization-invariant; it could be affected by sampling a region more densely)
|
| 420 |
+
x_out = torch.sum(x * mass.unsqueeze(-1), dim=-2) / torch.sum(mass, dim=-1, keepdim=True)
|
| 421 |
+
|
| 422 |
+
# Apply last nonlinearity if specified
|
| 423 |
+
if self.last_activation != None:
|
| 424 |
+
x_out = self.last_activation(x_out)
|
| 425 |
+
|
| 426 |
+
# Remove batch dim if we added it
|
| 427 |
+
if appended_batch_dim:
|
| 428 |
+
x_out = x_out.squeeze(0)
|
| 429 |
+
|
| 430 |
+
return x_out
|
utils/mesh.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import os
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import potpourri3d as pp3d
|
| 6 |
+
import open3d as o3d
|
| 7 |
+
import scipy.io as sio
|
| 8 |
+
import numpy as np
|
| 9 |
+
import re
|
| 10 |
+
import sys
|
| 11 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 12 |
+
from shape_data import get_data_dirs
|
| 13 |
+
|
| 14 |
+
# List of file extensions to consider as "mesh" files.
|
| 15 |
+
# Kudos to chatgpt!
|
| 16 |
+
# Add or remove extensions here as needed.
|
| 17 |
+
MESH_EXTENSIONS = {".ply", ".obj", ".off", ".stl", ".fbx", ".gltf", ".glb"}
|
| 18 |
+
|
| 19 |
+
def sorted_alphanum(file_list_ordered):
|
| 20 |
+
def convert(text):
|
| 21 |
+
return int(text) if text.isdigit() else text
|
| 22 |
+
def alphanum_key(key):
|
| 23 |
+
return [convert(c) for c in re.split('([0-9]+)', str(key)) if len(c) > 0]
|
| 24 |
+
return sorted(file_list_ordered, key=alphanum_key)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def list_files(folder_path, name_filter, alphanum_sort=False):
|
| 28 |
+
file_list = [p.name for p in list(Path(folder_path).glob(name_filter))]
|
| 29 |
+
if alphanum_sort:
|
| 30 |
+
return sorted_alphanum(file_list)
|
| 31 |
+
else:
|
| 32 |
+
return sorted(file_list)
|
| 33 |
+
|
| 34 |
+
def find_mesh_files(directory: Path, extensions: set=MESH_EXTENSIONS, alphanum_sort=False) -> list[Path]:
|
| 35 |
+
"""
|
| 36 |
+
Recursively find all files in 'directory' whose suffix (lowercased) is in 'extensions'.
|
| 37 |
+
Returns a list of Path objects.
|
| 38 |
+
"""
|
| 39 |
+
matches = []
|
| 40 |
+
for path in directory.rglob("*"):
|
| 41 |
+
if path.is_file() and path.suffix.lower() in extensions:
|
| 42 |
+
matches.append(path)
|
| 43 |
+
if alphanum_sort:
|
| 44 |
+
return sorted_alphanum(matches)
|
| 45 |
+
else:
|
| 46 |
+
return sorted(matches)
|
| 47 |
+
|
| 48 |
+
def save_ply(file_name, V, F, Rho=None, color=None):
|
| 49 |
+
"""Save mesh information either as an ASCII ply file.
|
| 50 |
+
https://github.com/emmanuel-hartman/BaRe-ESA/blob/main/utils/input_output.py
|
| 51 |
+
Input:
|
| 52 |
+
- file_name: specified path for saving mesh [string]
|
| 53 |
+
- V: vertices of the triangulated surface [nVx3 numpy ndarray]
|
| 54 |
+
- F: faces of the triangulated surface [nFx3 numpy ndarray]
|
| 55 |
+
- Rho: weights defined on the vertices of the triangulated surface [nVx1 numpy ndarray, default=None]
|
| 56 |
+
- color: colormap [nVx3 numpy ndarray of RGB triples]
|
| 57 |
+
|
| 58 |
+
Output:
|
| 59 |
+
- file_name.mat or file_name.ply file containing mesh information
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# Save as .ply file
|
| 63 |
+
nV = V.shape[0]
|
| 64 |
+
nF = F.shape[0]
|
| 65 |
+
if not ".ply" in file_name:
|
| 66 |
+
file_name += ".ply"
|
| 67 |
+
file = open(file_name, "w")
|
| 68 |
+
lines = ("ply","\n","format ascii 1.0","\n", "element vertex {}".format(nV),"\n","property float x","\n","property float y","\n","property float z","\n")
|
| 69 |
+
|
| 70 |
+
if color is not None:
|
| 71 |
+
lines += ("property uchar red","\n","property uchar green","\n","property uchar blue","\n")
|
| 72 |
+
if Rho is not None:
|
| 73 |
+
lines += ("property uchar alpha","\n")
|
| 74 |
+
|
| 75 |
+
lines += ("element face {}".format(nF),"\n","property list uchar int vertex_index","\n","end_header","\n")
|
| 76 |
+
|
| 77 |
+
file.writelines(lines)
|
| 78 |
+
lines = []
|
| 79 |
+
for i in range(0,nV):
|
| 80 |
+
for j in range(0,3):
|
| 81 |
+
lines.append(str(V[i][j]))
|
| 82 |
+
lines.append(" ")
|
| 83 |
+
if color is not None:
|
| 84 |
+
for j in range(0,3):
|
| 85 |
+
lines.append(str(color[i][j]))
|
| 86 |
+
lines.append(" ")
|
| 87 |
+
if Rho is not None:
|
| 88 |
+
lines.append(str(Rho[i]))
|
| 89 |
+
lines.append(" ")
|
| 90 |
+
|
| 91 |
+
lines.append("\n")
|
| 92 |
+
for i in range(0,nF):
|
| 93 |
+
l = len(F[i,:])
|
| 94 |
+
lines.append(str(l))
|
| 95 |
+
lines.append(" ")
|
| 96 |
+
|
| 97 |
+
for j in range(0,l):
|
| 98 |
+
lines.append(str(F[i,j]))
|
| 99 |
+
lines.append(" ")
|
| 100 |
+
lines.append("\n")
|
| 101 |
+
|
| 102 |
+
file.writelines(lines)
|
| 103 |
+
file.close()
|
| 104 |
+
|
| 105 |
+
def numpy_to_open3d_mesh(V, F):
|
| 106 |
+
# Create an empty TriangleMesh object
|
| 107 |
+
mesh = o3d.geometry.TriangleMesh()
|
| 108 |
+
# Set vertices
|
| 109 |
+
mesh.vertices = o3d.utility.Vector3dVector(V)
|
| 110 |
+
# Set triangles
|
| 111 |
+
mesh.triangles = o3d.utility.Vector3iVector(F)
|
| 112 |
+
return mesh
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def load_mesh(filepath, scale=True, return_vnormals=False):
|
| 117 |
+
V, F = pp3d.read_mesh(filepath)
|
| 118 |
+
mesh = numpy_to_open3d_mesh(V, F)
|
| 119 |
+
|
| 120 |
+
tmat = np.identity(4, dtype=np.float32)
|
| 121 |
+
center = mesh.get_center()
|
| 122 |
+
tmat[:3, 3] = -center
|
| 123 |
+
area = mesh.get_surface_area()
|
| 124 |
+
if scale:
|
| 125 |
+
smat = np.identity(4, dtype=np.float32)
|
| 126 |
+
|
| 127 |
+
smat[:3, :3] = np.identity(3, dtype=np.float32) / np.sqrt(area)
|
| 128 |
+
tmat = smat @ tmat
|
| 129 |
+
mesh.transform(tmat)
|
| 130 |
+
|
| 131 |
+
vertices = np.asarray(mesh.vertices, dtype=np.float32)
|
| 132 |
+
faces = np.asarray(mesh.triangles, dtype=np.int32)
|
| 133 |
+
if return_vnormals:
|
| 134 |
+
mesh.compute_vertex_normals()
|
| 135 |
+
vnormals = np.asarray(mesh.vertex_normals, dtype=np.float32)
|
| 136 |
+
if scale:
|
| 137 |
+
return vertices, faces, vnormals, area, center
|
| 138 |
+
return vertices, faces, vnormals, area, center
|
| 139 |
+
else:
|
| 140 |
+
return vertices, faces
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def mesh_geod_matrix(vertices, faces, do_tqdm=False, verbose=False):
|
| 144 |
+
if verbose:
|
| 145 |
+
print("Setting Geodesic matrix bw vertices")
|
| 146 |
+
n_vertices = vertices.shape[0]
|
| 147 |
+
distmat = np.zeros((n_vertices, n_vertices))
|
| 148 |
+
solver = pp3d.MeshHeatMethodDistanceSolver(vertices, faces)
|
| 149 |
+
if do_tqdm:
|
| 150 |
+
iterable = tqdm(range(n_vertices))
|
| 151 |
+
else:
|
| 152 |
+
iterable = range(n_vertices)
|
| 153 |
+
for vertind in iterable:
|
| 154 |
+
distmat[vertind] = np.maximum(solver.compute_distance(vertind), 0)
|
| 155 |
+
geod_mat = distmat
|
| 156 |
+
return geod_mat
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def prepare_geod_mats(shapes_folder, out, basename=None):
|
| 160 |
+
if basename is None:
|
| 161 |
+
basename = os.path.basename(os.path.dirname(shapes_folder)) #+ "_" + os.path.basename(shapes_folder)
|
| 162 |
+
case = basename
|
| 163 |
+
case_folder_out = os.path.join(out, case)
|
| 164 |
+
os.makedirs(case_folder_out, exist_ok=True)
|
| 165 |
+
all_shapes = [f for f in os.listdir(shapes_folder) if (".ply" in f) or (".off" in f) or ('.obj' in f)]
|
| 166 |
+
for shape in tqdm(all_shapes, "Processing " + os.path.basename(shapes_folder)):
|
| 167 |
+
vertices, faces = pp3d.read_mesh(os.path.join(shapes_folder, shape))
|
| 168 |
+
areas = pp3d.face_areas(vertices, faces)
|
| 169 |
+
mat = mesh_geod_matrix(vertices, faces, verbose=False)
|
| 170 |
+
dict_save = {
|
| 171 |
+
'geod_dist': mat,
|
| 172 |
+
'areas_f': areas
|
| 173 |
+
}
|
| 174 |
+
sio.savemat(os.path.join(case_folder_out, shape[:-4]+'.mat'), dict_save)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
|
| 178 |
+
parser = argparse.ArgumentParser(description="What to do.")
|
| 179 |
+
parser.add_argument('--make_geods', required=True, type=int, default=0, help='launch computation of geod matrices')
|
| 180 |
+
parser.add_argument('--data', required=False, type=str, default=None)
|
| 181 |
+
parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
|
| 182 |
+
parser.add_argument('--basename', required=False, type=str, default=None)
|
| 183 |
+
args = parser.parse_args()
|
| 184 |
+
if args.make_geods:
|
| 185 |
+
# from config import get_geod_path, get_dataset_path, get_template_path
|
| 186 |
+
# output = get_geod_path()
|
| 187 |
+
output = os.path.join(args.datadir, "geomats")
|
| 188 |
+
if args.data == "humans":
|
| 189 |
+
all_folders = [get_data_dirs(args.datadir, "faust", 'test')[0], get_data_dirs(args.datadir, "scape", 'test')[0], get_data_dirs(args.datadir, "shrec19", 'test')[0]]
|
| 190 |
+
for folder in all_folders:
|
| 191 |
+
prepare_geod_mats(folder, output)
|
| 192 |
+
elif args.data == "dt4d":
|
| 193 |
+
data_dir, _, corr_dir = get_data_dirs(args.datadir, args.data, 'test')
|
| 194 |
+
all_folders = sorted([f for f in os.listdir(data_dir) if "cross" not in f])
|
| 195 |
+
for folder in all_folders:
|
| 196 |
+
prepare_geod_mats(os.path.join(data_dir, folder), os.path.join(output, "DT4D"), basename=folder)
|
| 197 |
+
elif args.data is not None:
|
| 198 |
+
data_dir, _, corr_dir = get_data_dirs(args.datadir, args.data, 'test')
|
| 199 |
+
prepare_geod_mats(data_dir, output, args.basename)
|
| 200 |
+
# parser = argparse.ArgumentParser(description="Find all mesh files in a folder and list their paths.")
|
| 201 |
+
# parser.add_argument("--folder", type=Path, help="Path to the folder to search (will search recursively).")
|
| 202 |
+
# args = parser.parse_args()
|
| 203 |
+
|
| 204 |
+
# search_folder = args.folder
|
| 205 |
+
# if not search_folder.is_dir():
|
| 206 |
+
# print(f"Error: '{search_folder}' is not a valid directory.")
|
| 207 |
+
# else:
|
| 208 |
+
# # Find all matching files
|
| 209 |
+
# mesh_files = find_mesh_files(search_folder)
|
| 210 |
+
|
| 211 |
+
# # Sort the results for consistency
|
| 212 |
+
# mesh_files.sort()
|
| 213 |
+
# for p in mesh_files:
|
| 214 |
+
# print(p.resolve())
|
utils/meshplot.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import potpourri3d as pp3d
|
| 3 |
+
import torch
|
| 4 |
+
import meshplot as mp
|
| 5 |
+
from ipygany import PolyMesh, Scene
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
colors = np.array([[255, 119, 0], #orange
|
| 10 |
+
[163, 51, 219], #violet
|
| 11 |
+
[0, 246, 205], #bleu clair
|
| 12 |
+
[0, 131, 246], #bleu floncé
|
| 13 |
+
[246, 234, 0], #jaune
|
| 14 |
+
[143, 188, 143], #rouge
|
| 15 |
+
[255, 0, 0]])
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def double_plot_surf(surf_1, surf_2,cmap1=None,cmap2=None):
|
| 19 |
+
d = mp.subplot(surf_1.vertices, surf_1.faces, c=cmap1, s=[2, 2, 0])
|
| 20 |
+
mp.subplot(surf_2.vertices, surf_2.faces, c=cmap2, s=[2, 2, 1], data=d)
|
| 21 |
+
|
| 22 |
+
def visu_pts(surf, colors=colors, idx_points=None, n_kpts=5):
|
| 23 |
+
areas = np.linalg.norm(surf.surfel, axis=-1, keepdims=True)
|
| 24 |
+
area = np.sqrt(areas.sum()/2)
|
| 25 |
+
if idx_points is None:
|
| 26 |
+
center = (surf.centers*(areas)).sum(axis=0)/areas.sum()
|
| 27 |
+
surf.updateVertices((surf.vertices - center)/area)
|
| 28 |
+
surf.cotanLaplacian()
|
| 29 |
+
idx_points = surf.get_keypoints(n_points=n_kpts)
|
| 30 |
+
solver = pp3d.MeshHeatMethodDistanceSolver(surf.vertices, surf.faces)
|
| 31 |
+
norm_center = solver.compute_distance(idx_points[-1])
|
| 32 |
+
color_array = np.zeros(surf.vertices.shape)
|
| 33 |
+
for i in range(len(idx_points)):
|
| 34 |
+
coeff = 0.1 + 1.5*norm_center[idx_points[i]]
|
| 35 |
+
i_v = idx_points[i]
|
| 36 |
+
dist = solver.compute_distance(i_v)*2
|
| 37 |
+
#color_array += np.clip(1-dist, 0, np.inf)[:, None]*colors[i][None, :]
|
| 38 |
+
color_array += np.exp(-dist**2/coeff)[:, None]*colors[i][None, :]
|
| 39 |
+
color_array = np.clip(color_array, 0, 255.)
|
| 40 |
+
return color_array
|
| 41 |
+
|
| 42 |
+
def toNP(tens):
|
| 43 |
+
if isinstance(tens, torch.Tensor):
|
| 44 |
+
return tens.detach().squeeze().cpu().numpy()
|
| 45 |
+
return tens
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def overlay_surf(shape_vertices, shape_faces, target_vertices, target_faces, colors=["tomato", "darksalmon"]):
|
| 49 |
+
shape_vertices = toNP(shape_vertices)
|
| 50 |
+
target_vertices = toNP(target_vertices)
|
| 51 |
+
mesh_1 = PolyMesh(
|
| 52 |
+
vertices=shape_vertices,
|
| 53 |
+
triangle_indices=shape_faces
|
| 54 |
+
)
|
| 55 |
+
mesh_1.default_color = colors[0]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
mesh_2 = PolyMesh(
|
| 59 |
+
vertices=target_vertices,
|
| 60 |
+
triangle_indices=target_faces
|
| 61 |
+
)
|
| 62 |
+
mesh_2.default_color = colors[1]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
scene = Scene([mesh_1, mesh_2])
|
| 66 |
+
#scene = Scene([mesh_5])
|
| 67 |
+
return scene, [mesh_1, mesh_2]
|
utils/misc.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import yaml
|
| 6 |
+
import omegaconf
|
| 7 |
+
from omegaconf import OmegaConf
|
| 8 |
+
from scipy.spatial import cKDTree
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class KNNSearch(object):
|
| 13 |
+
DTYPE = np.float32
|
| 14 |
+
NJOBS = 4
|
| 15 |
+
|
| 16 |
+
def __init__(self, data):
|
| 17 |
+
self.data = np.asarray(data, dtype=self.DTYPE)
|
| 18 |
+
self.kdtree = cKDTree(self.data)
|
| 19 |
+
|
| 20 |
+
def query(self, kpts, k, return_dists=False):
|
| 21 |
+
kpts = np.asarray(kpts, dtype=self.DTYPE)
|
| 22 |
+
nndists, nnindices = self.kdtree.query(kpts, k=k, workers=self.NJOBS)
|
| 23 |
+
if return_dists:
|
| 24 |
+
return nnindices, nndists
|
| 25 |
+
else:
|
| 26 |
+
return nnindices
|
| 27 |
+
|
| 28 |
+
def query_ball(self, kpt, radius):
|
| 29 |
+
kpt = np.asarray(kpt, dtype=self.DTYPE)
|
| 30 |
+
assert kpt.ndim == 1
|
| 31 |
+
nnindices = self.kdtree.query_ball_point(kpt, radius, n_jobs=self.NJOBS)
|
| 32 |
+
return nnindices
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def validate_str(x):
|
| 36 |
+
return x is not None and x != ''
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def hashing(arr, M):
|
| 40 |
+
assert isinstance(arr, np.ndarray) and arr.ndim == 2
|
| 41 |
+
N, D = arr.shape
|
| 42 |
+
|
| 43 |
+
hash_vec = np.zeros(N, dtype=np.int64)
|
| 44 |
+
for d in range(D):
|
| 45 |
+
hash_vec += arr[:, d] * M**d
|
| 46 |
+
return hash_vec
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def omegaconf_to_dotdict(hparams):
|
| 50 |
+
|
| 51 |
+
def _to_dot_dict(cfg):
|
| 52 |
+
res = {}
|
| 53 |
+
for k, v in cfg.items():
|
| 54 |
+
if v is None:
|
| 55 |
+
res[k] = v
|
| 56 |
+
elif isinstance(v, omegaconf.DictConfig):
|
| 57 |
+
res.update({k + "." + subk: subv for subk, subv in _to_dot_dict(v).items()})
|
| 58 |
+
elif isinstance(v, (str, int, float, bool)):
|
| 59 |
+
res[k] = v
|
| 60 |
+
elif isinstance(v, omegaconf.ListConfig):
|
| 61 |
+
res[k] = omegaconf.OmegaConf.to_container(v, resolve=True)
|
| 62 |
+
else:
|
| 63 |
+
raise RuntimeError('The type of {} is not supported.'.format(type(v)))
|
| 64 |
+
return res
|
| 65 |
+
|
| 66 |
+
return _to_dot_dict(hparams)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def incrange(start, end, step):
|
| 70 |
+
assert step > 0
|
| 71 |
+
res = [start]
|
| 72 |
+
if start <= end:
|
| 73 |
+
while res[-1] + step <= end:
|
| 74 |
+
res.append(res[-1] + step)
|
| 75 |
+
else:
|
| 76 |
+
while res[-1] - step >= end:
|
| 77 |
+
res.append(res[-1] - step)
|
| 78 |
+
return res
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def seeding(seed=0):
|
| 82 |
+
torch.manual_seed(seed)
|
| 83 |
+
torch.cuda.manual_seed_all(seed)
|
| 84 |
+
np.random.seed(seed)
|
| 85 |
+
random.seed(seed)
|
| 86 |
+
|
| 87 |
+
torch.backends.cudnn.enabled = True
|
| 88 |
+
torch.backends.cudnn.benchmark = True
|
| 89 |
+
torch.backends.cudnn.deterministic = True
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def run_trainer(trainer_cls):
|
| 93 |
+
cfg_cli = OmegaConf.from_cli()
|
| 94 |
+
|
| 95 |
+
assert cfg_cli.run_mode is not None
|
| 96 |
+
if cfg_cli.run_mode == 'train':
|
| 97 |
+
assert cfg_cli.run_cfg is not None
|
| 98 |
+
cfg = OmegaConf.merge(
|
| 99 |
+
OmegaConf.load(cfg_cli.run_cfg),
|
| 100 |
+
cfg_cli,
|
| 101 |
+
)
|
| 102 |
+
OmegaConf.resolve(cfg)
|
| 103 |
+
cfg = omegaconf_to_dotdict(cfg)
|
| 104 |
+
seeding(cfg['seed'])
|
| 105 |
+
trainer = trainer_cls(cfg)
|
| 106 |
+
trainer.train()
|
| 107 |
+
trainer.test()
|
| 108 |
+
elif cfg_cli.run_mode == 'test':
|
| 109 |
+
assert cfg_cli.run_ckpt is not None
|
| 110 |
+
log_dir = str(Path(cfg_cli.run_ckpt).parent)
|
| 111 |
+
cfg = OmegaConf.merge(
|
| 112 |
+
OmegaConf.load(osp.join(log_dir, 'config.yml')),
|
| 113 |
+
cfg_cli,
|
| 114 |
+
)
|
| 115 |
+
OmegaConf.resolve(cfg)
|
| 116 |
+
cfg = omegaconf_to_dotdict(cfg)
|
| 117 |
+
cfg['test_ckpt'] = cfg_cli.run_ckpt
|
| 118 |
+
seeding(cfg['seed'])
|
| 119 |
+
trainer = trainer_cls(cfg)
|
| 120 |
+
trainer.test()
|
| 121 |
+
else:
|
| 122 |
+
raise RuntimeError(f'Mode {cfg_cli.run_mode} is not supported.')
|
utils/pickle_stuff.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import io
|
| 3 |
+
import importlib
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Mapping of old module names to new module names
|
| 10 |
+
MODULE_RENAME_MAP = {
|
| 11 |
+
'module': 'diffu_models',
|
| 12 |
+
'module.model': 'diffu_models.precond',
|
| 13 |
+
'module.dit_models': 'diffu_models.dit_models',
|
| 14 |
+
'module.model': 'diffu_models.precond',
|
| 15 |
+
# add more as needed
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
class RenameUnpickler(pickle.Unpickler):
|
| 19 |
+
def find_class(self, module, name):
|
| 20 |
+
if module in MODULE_RENAME_MAP:
|
| 21 |
+
module = MODULE_RENAME_MAP[module]
|
| 22 |
+
try:
|
| 23 |
+
return super().find_class(module, name)
|
| 24 |
+
except ModuleNotFoundError as e:
|
| 25 |
+
raise ModuleNotFoundError(f"Could not find module '{module}'. You may need to update MODULE_RENAME_MAP.") from e
|
| 26 |
+
|
| 27 |
+
# Usage
|
| 28 |
+
def load_renamed_pickle(file_path):
|
| 29 |
+
with open(file_path, 'rb') as f:
|
| 30 |
+
return RenameUnpickler(f).load()
|
| 31 |
+
|
| 32 |
+
def safe_load_with_fallback(file_path):
|
| 33 |
+
try:
|
| 34 |
+
with open(file_path, 'rb') as f:
|
| 35 |
+
return pickle.load(f)
|
| 36 |
+
except ModuleNotFoundError:
|
| 37 |
+
with open(file_path, 'rb') as f:
|
| 38 |
+
return RenameUnpickler(f).load()
|
utils/surfaces.py
ADDED
|
@@ -0,0 +1,1377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import scipy as sp
|
| 3 |
+
import scipy.linalg as spLA
|
| 4 |
+
import scipy.sparse.linalg as sla
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
| 7 |
+
import copy
|
| 8 |
+
import potpourri3d as pp3d
|
| 9 |
+
from sklearn.cluster import KMeans#, DBSCAN, SpectralClustering
|
| 10 |
+
from scipy.spatial import cKDTree
|
| 11 |
+
import trimesh as tm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from vtk import *
|
| 16 |
+
import vtk.util.numpy_support as v2n
|
| 17 |
+
|
| 18 |
+
gotVTK = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
print('could not import VTK functions')
|
| 21 |
+
gotVTK = False
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# import kernelFunctions as kfun
|
| 25 |
+
|
| 26 |
+
class vtkFields:
|
| 27 |
+
def __init__(self):
|
| 28 |
+
self.scalars = []
|
| 29 |
+
self.vectors = []
|
| 30 |
+
self.normals = []
|
| 31 |
+
self.tensors = []
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# General surface class, fibers possible. The fiber is a vector field of norm 1, defined on each vertex.
|
| 35 |
+
# It must be an array of the same size (number of vertices)x3.
|
| 36 |
+
|
| 37 |
+
class Surface:
|
| 38 |
+
|
| 39 |
+
# Fibers: a list of vectors, with i-th element corresponding to the value of the vector field at vertex i
|
| 40 |
+
# Contains as object:
|
| 41 |
+
# vertices : all the vertices
|
| 42 |
+
# centers: the centers of each face
|
| 43 |
+
# faces: faces along with the id of the faces
|
| 44 |
+
# surfel: surface element of each face (area*normal)
|
| 45 |
+
# List of methods:
|
| 46 |
+
# read : from filename type, call readFILENAME and set all surface attributes
|
| 47 |
+
# updateVertices: update the whole surface after a modification of the vertices
|
| 48 |
+
# computeVertexArea and Normals
|
| 49 |
+
# getEdges
|
| 50 |
+
# LocalSignedDistance distance function in the neighborhood of a shape
|
| 51 |
+
# toPolyData: convert the surface to a polydata vtk object
|
| 52 |
+
# fromPolyDate: guess what
|
| 53 |
+
# Simplify : simplify the meshes
|
| 54 |
+
# flipfaces: invert the indices of faces from [a, b, c] to [a, c, b]
|
| 55 |
+
# smooth: get a smoother surface
|
| 56 |
+
# Isosurface: compute isosurface
|
| 57 |
+
# edgeRecove: ensure that orientation is correct
|
| 58 |
+
# remove isolated: if isolated vertice, remove it
|
| 59 |
+
# laplacianMatrix: compute Laplacian Matrix of the surface graph
|
| 60 |
+
# graphLaplacianMatrix: ???
|
| 61 |
+
# laplacianSegmentation: segment the surface using laplacian properties
|
| 62 |
+
# surfVolume: compute volume inscribed in the surface (+ inscribed infinitesimal volume for each face)
|
| 63 |
+
# surfCenter: compute surface Center
|
| 64 |
+
# surfMoments: compute surface second order moments
|
| 65 |
+
# surfU: compute the informations for surface rigid alignement
|
| 66 |
+
# surfEllipsoid: compute ellipsoid representing the surface
|
| 67 |
+
# savebyu of vitk or vtk2: save the surface in a file
|
| 68 |
+
# concatenate: concatenate to another surface
|
| 69 |
+
|
| 70 |
+
def __init__(self, surf=None, filename=None, FV=None):
|
| 71 |
+
if surf == None:
|
| 72 |
+
if FV == None:
|
| 73 |
+
if filename == None:
|
| 74 |
+
self.vertices = np.empty(0)
|
| 75 |
+
self.centers = np.empty(0)
|
| 76 |
+
self.faces = np.empty(0)
|
| 77 |
+
self.surfel = np.empty(0)
|
| 78 |
+
else:
|
| 79 |
+
if type(filename) is list:
|
| 80 |
+
fvl = []
|
| 81 |
+
for name in filename:
|
| 82 |
+
fvl.append(Surface(filename=name))
|
| 83 |
+
self.concatenate(fvl)
|
| 84 |
+
else:
|
| 85 |
+
self.read(filename)
|
| 86 |
+
else:
|
| 87 |
+
self.vertices = np.copy(FV[1])
|
| 88 |
+
self.faces = np.int_(FV[0])
|
| 89 |
+
self.computeCentersAreas()
|
| 90 |
+
|
| 91 |
+
else:
|
| 92 |
+
self.vertices = np.copy(surf.vertices)
|
| 93 |
+
self.faces = np.copy(surf.faces)
|
| 94 |
+
self.surfel = np.copy(surf.surfel)
|
| 95 |
+
self.centers = np.copy(surf.centers)
|
| 96 |
+
self.computeCentersAreas()
|
| 97 |
+
self.volume, self.vols = self.surfVolume()
|
| 98 |
+
self.center = self.surfCenter()
|
| 99 |
+
self.cotanLaplacian()
|
| 100 |
+
|
| 101 |
+
def read(self, filename):
|
| 102 |
+
(mainPart, ext) = os.path.splitext(filename)
|
| 103 |
+
if ext == '.byu':
|
| 104 |
+
self.readbyu(filename)
|
| 105 |
+
elif ext == '.off':
|
| 106 |
+
self.readOFF(filename)
|
| 107 |
+
elif ext == '.vtk':
|
| 108 |
+
self.readVTK(filename)
|
| 109 |
+
elif ext == '.obj':
|
| 110 |
+
self.readOBJ(filename)
|
| 111 |
+
elif ext == '.ply':
|
| 112 |
+
self.readPLY(filename)
|
| 113 |
+
elif ext == '.tri' or ext == ".ntri":
|
| 114 |
+
self.readTRI(filename)
|
| 115 |
+
else:
|
| 116 |
+
print('Unknown Surface Extension:', ext)
|
| 117 |
+
self.vertices = np.empty(0)
|
| 118 |
+
self.centers = np.empty(0)
|
| 119 |
+
self.faces = np.empty(0)
|
| 120 |
+
self.surfel = np.empty(0)
|
| 121 |
+
|
| 122 |
+
# face centers and area weighted normal
|
| 123 |
+
def computeCentersAreas(self):
|
| 124 |
+
xDef1 = self.vertices[self.faces[:, 0], :]
|
| 125 |
+
xDef2 = self.vertices[self.faces[:, 1], :]
|
| 126 |
+
xDef3 = self.vertices[self.faces[:, 2], :]
|
| 127 |
+
self.centers = (xDef1 + xDef2 + xDef3) / 3
|
| 128 |
+
self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
|
| 129 |
+
|
| 130 |
+
# modify vertices without toplogical change
|
| 131 |
+
def updateVertices(self, x0):
|
| 132 |
+
self.vertices = np.copy(x0)
|
| 133 |
+
xDef1 = self.vertices[self.faces[:, 0], :]
|
| 134 |
+
xDef2 = self.vertices[self.faces[:, 1], :]
|
| 135 |
+
xDef3 = self.vertices[self.faces[:, 2], :]
|
| 136 |
+
self.centers = (xDef1 + xDef2 + xDef3) / 3
|
| 137 |
+
self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
|
| 138 |
+
self.volume, self.vols = self.surfVolume()
|
| 139 |
+
|
| 140 |
+
def computeVertexArea(self):
|
| 141 |
+
# compute areas of faces and vertices
|
| 142 |
+
V = self.vertices
|
| 143 |
+
F = self.faces
|
| 144 |
+
nv = V.shape[0]
|
| 145 |
+
nf = F.shape[0]
|
| 146 |
+
AF = np.zeros([nf, 1])
|
| 147 |
+
AV = np.zeros([nv, 1])
|
| 148 |
+
for k in range(nf):
|
| 149 |
+
# determining if face is obtuse
|
| 150 |
+
x12 = V[F[k, 1], :] - V[F[k, 0], :]
|
| 151 |
+
x13 = V[F[k, 2], :] - V[F[k, 0], :]
|
| 152 |
+
n12 = np.sqrt((x12 ** 2).sum())
|
| 153 |
+
n13 = np.sqrt((x13 ** 2).sum())
|
| 154 |
+
c1 = (x12 * x13).sum() / (n12 * n13)
|
| 155 |
+
x23 = V[F[k, 2], :] - V[F[k, 1], :]
|
| 156 |
+
n23 = np.sqrt((x23 ** 2).sum())
|
| 157 |
+
# n23 = norm(x23) ;
|
| 158 |
+
c2 = -(x12 * x23).sum() / (n12 * n23)
|
| 159 |
+
c3 = (x13 * x23).sum() / (n13 * n23)
|
| 160 |
+
AF[k] = np.sqrt((np.cross(x12, x13) ** 2).sum()) / 2
|
| 161 |
+
if (c1 < 0):
|
| 162 |
+
# face obtuse at vertex 1
|
| 163 |
+
AV[F[k, 0]] += AF[k] / 2
|
| 164 |
+
AV[F[k, 1]] += AF[k] / 4
|
| 165 |
+
AV[F[k, 2]] += AF[k] / 4
|
| 166 |
+
elif (c2 < 0):
|
| 167 |
+
# face obuse at vertex 2
|
| 168 |
+
AV[F[k, 0]] += AF[k] / 4
|
| 169 |
+
AV[F[k, 1]] += AF[k] / 2
|
| 170 |
+
AV[F[k, 2]] += AF[k] / 4
|
| 171 |
+
elif (c3 < 0):
|
| 172 |
+
# face obtuse at vertex 3
|
| 173 |
+
AV[F[k, 0]] += AF[k] / 4
|
| 174 |
+
AV[F[k, 1]] += AF[k] / 4
|
| 175 |
+
AV[F[k, 2]] += AF[k] / 2
|
| 176 |
+
else:
|
| 177 |
+
# non obtuse face
|
| 178 |
+
cot1 = c1 / np.sqrt(1 - c1 ** 2)
|
| 179 |
+
cot2 = c2 / np.sqrt(1 - c2 ** 2)
|
| 180 |
+
cot3 = c3 / np.sqrt(1 - c3 ** 2)
|
| 181 |
+
AV[F[k, 0]] += ((x12 ** 2).sum() * cot3 + (x13 ** 2).sum() * cot2) / 8
|
| 182 |
+
AV[F[k, 1]] += ((x12 ** 2).sum() * cot3 + (x23 ** 2).sum() * cot1) / 8
|
| 183 |
+
AV[F[k, 2]] += ((x13 ** 2).sum() * cot2 + (x23 ** 2).sum() * cot1) / 8
|
| 184 |
+
|
| 185 |
+
for k in range(nv):
|
| 186 |
+
if (np.fabs(AV[k]) < 1e-10):
|
| 187 |
+
print('Warning: vertex ', k, 'has no face; use removeIsolated')
|
| 188 |
+
# print('sum check area:', AF.sum(), AV.sum()
|
| 189 |
+
return AV, AF
|
| 190 |
+
|
| 191 |
+
def computeVertexNormals(self):
|
| 192 |
+
self.computeCentersAreas()
|
| 193 |
+
normals = np.zeros(self.vertices.shape)
|
| 194 |
+
F = self.faces
|
| 195 |
+
for k in range(F.shape[0]):
|
| 196 |
+
normals[F[k, 0]] += self.surfel[k]
|
| 197 |
+
normals[F[k, 1]] += self.surfel[k]
|
| 198 |
+
normals[F[k, 2]] += self.surfel[k]
|
| 199 |
+
af = np.sqrt((normals ** 2).sum(axis=1))
|
| 200 |
+
# logging.info('min area = %.4f'%(af.min()))
|
| 201 |
+
normals /= af.reshape([self.vertices.shape[0], 1])
|
| 202 |
+
|
| 203 |
+
return normals
|
| 204 |
+
|
| 205 |
+
# Computes edges from vertices/faces
|
| 206 |
+
def getEdges(self):
|
| 207 |
+
self.edges = []
|
| 208 |
+
|
| 209 |
+
for k in range(self.faces.shape[0]):
|
| 210 |
+
for kj in (0, 1, 2):
|
| 211 |
+
u = [self.faces[k, kj], self.faces[k, (kj + 1) % 3]]
|
| 212 |
+
|
| 213 |
+
if (u not in self.edges) & (u.reverse() not in self.edges):
|
| 214 |
+
self.edges.append(u)
|
| 215 |
+
|
| 216 |
+
self.edgeFaces = []
|
| 217 |
+
|
| 218 |
+
for u in self.edges:
|
| 219 |
+
self.edgeFaces.append([])
|
| 220 |
+
|
| 221 |
+
for k in range(self.faces.shape[0]):
|
| 222 |
+
for kj in (0, 1, 2):
|
| 223 |
+
u = [self.faces[k, kj], self.faces[k, (kj + 1) % 3]]
|
| 224 |
+
|
| 225 |
+
if u in self.edges:
|
| 226 |
+
kk = self.edges.index(u)
|
| 227 |
+
else:
|
| 228 |
+
u.reverse()
|
| 229 |
+
kk = self.edges.index(u)
|
| 230 |
+
|
| 231 |
+
self.edgeFaces[kk].append(k)
|
| 232 |
+
|
| 233 |
+
self.edges = np.int_(np.array(self.edges))
|
| 234 |
+
self.bdry = np.int_(np.zeros(self.edges.shape[0]))
|
| 235 |
+
|
| 236 |
+
for k in range(self.edges.shape[0]):
|
| 237 |
+
if len(self.edgeFaces[k]) < 2:
|
| 238 |
+
self.bdry[k] = 1
|
| 239 |
+
|
| 240 |
+
# computes the signed distance function in a small neighborhood of a shape
|
| 241 |
+
def LocalSignedDistance(self, data, value):
|
| 242 |
+
d2 = 2 * np.array(data >= value) - 1
|
| 243 |
+
c2 = np.cumsum(d2, axis=0)
|
| 244 |
+
for j in range(2):
|
| 245 |
+
c2 = np.cumsum(c2, axis=j + 1)
|
| 246 |
+
(n0, n1, n2) = c2.shape
|
| 247 |
+
|
| 248 |
+
rad = 3
|
| 249 |
+
diam = 2 * rad + 1
|
| 250 |
+
(x, y, z) = np.mgrid[-rad:rad + 1, -rad:rad + 1, -rad:rad + 1]
|
| 251 |
+
cube = (x ** 2 + y ** 2 + z ** 2)
|
| 252 |
+
maxval = (diam) ** 3
|
| 253 |
+
s = 3.0 * rad ** 2
|
| 254 |
+
res = d2 * s
|
| 255 |
+
u = maxval * np.ones(c2.shape)
|
| 256 |
+
u[rad + 1:n0 - rad, rad + 1:n1 - rad, rad + 1:n2 - rad] = (c2[diam:n0, diam:n1, diam:n2]
|
| 257 |
+
- c2[0:n0 - diam, diam:n1, diam:n2] - c2[diam:n0,
|
| 258 |
+
0:n1 - diam,
|
| 259 |
+
diam:n2] - c2[
|
| 260 |
+
diam:n0,
|
| 261 |
+
diam:n1,
|
| 262 |
+
0:n2 - diam]
|
| 263 |
+
+ c2[0:n0 - diam, 0:n1 - diam, diam:n2] + c2[diam:n0,
|
| 264 |
+
0:n1 - diam,
|
| 265 |
+
0:n2 - diam] + c2[
|
| 266 |
+
0:n0 - diam,
|
| 267 |
+
diam:n1,
|
| 268 |
+
0:n2 - diam]
|
| 269 |
+
- c2[0:n0 - diam, 0:n1 - diam, 0:n2 - diam])
|
| 270 |
+
|
| 271 |
+
I = np.nonzero(np.fabs(u) < maxval)
|
| 272 |
+
# print(len(I[0]))
|
| 273 |
+
|
| 274 |
+
for k in range(len(I[0])):
|
| 275 |
+
p = np.array((I[0][k], I[1][k], I[2][k]))
|
| 276 |
+
bmin = p - rad
|
| 277 |
+
bmax = p + rad + 1
|
| 278 |
+
# print(p, bmin, bmax)
|
| 279 |
+
if (d2[p[0], p[1], p[2]] > 0):
|
| 280 |
+
# print(u[p[0],p[1], p[2]])
|
| 281 |
+
# print(d2[bmin[0]:bmax[0], bmin[1]:bmax[1], bmin[2]:bmax[2]].sum())
|
| 282 |
+
res[p[0], p[1], p[2]] = min(
|
| 283 |
+
cube[np.nonzero(d2[bmin[0]:bmax[0], bmin[1]:bmax[1], bmin[2]:bmax[2]] < 0)]) - .25
|
| 284 |
+
else:
|
| 285 |
+
res[p[0], p[1], p[2]] = - min(
|
| 286 |
+
cube[np.nonzero(d2[bmin[0]:bmax[0], bmin[1]:bmax[1], bmin[2]:bmax[2]] > 0)]) - .25
|
| 287 |
+
|
| 288 |
+
return res
|
| 289 |
+
|
| 290 |
+
def toPolyData(self):
|
| 291 |
+
if gotVTK:
|
| 292 |
+
points = vtkPoints()
|
| 293 |
+
for k in range(self.vertices.shape[0]):
|
| 294 |
+
points.InsertNextPoint(self.vertices[k, 0], self.vertices[k, 1], self.vertices[k, 2])
|
| 295 |
+
polys = vtkCellArray()
|
| 296 |
+
for k in range(self.faces.shape[0]):
|
| 297 |
+
polys.InsertNextCell(3)
|
| 298 |
+
for kk in range(3):
|
| 299 |
+
polys.InsertCellPoint(self.faces[k, kk])
|
| 300 |
+
polydata = vtkPolyData()
|
| 301 |
+
polydata.SetPoints(points)
|
| 302 |
+
polydata.SetPolys(polys)
|
| 303 |
+
return polydata
|
| 304 |
+
else:
|
| 305 |
+
raise Exception('Cannot run toPolyData without VTK')
|
| 306 |
+
|
| 307 |
+
def fromPolyData(self, g, scales=[1., 1., 1.]):
|
| 308 |
+
npoints = int(g.GetNumberOfPoints())
|
| 309 |
+
nfaces = int(g.GetNumberOfPolys())
|
| 310 |
+
logging.info('Dimensions: %d %d %d' % (npoints, nfaces, g.GetNumberOfCells()))
|
| 311 |
+
V = np.zeros([npoints, 3])
|
| 312 |
+
for kk in range(npoints):
|
| 313 |
+
V[kk, :] = np.array(g.GetPoint(kk))
|
| 314 |
+
# print(kk, V[kk])
|
| 315 |
+
# print(kk, np.array(g.GetPoint(kk)))
|
| 316 |
+
F = np.zeros([nfaces, 3])
|
| 317 |
+
gf = 0
|
| 318 |
+
for kk in range(g.GetNumberOfCells()):
|
| 319 |
+
c = g.GetCell(kk)
|
| 320 |
+
if (c.GetNumberOfPoints() == 3):
|
| 321 |
+
for ll in range(3):
|
| 322 |
+
F[gf, ll] = c.GetPointId(ll)
|
| 323 |
+
# print(kk, gf, F[gf])
|
| 324 |
+
gf += 1
|
| 325 |
+
|
| 326 |
+
# self.vertices = np.multiply(data.shape-V-1, scales)
|
| 327 |
+
self.vertices = np.multiply(V, scales)
|
| 328 |
+
self.faces = np.int_(F[0:gf, :])
|
| 329 |
+
self.computeCentersAreas()
|
| 330 |
+
|
| 331 |
+
def Simplify(self, target=1000.0):
|
| 332 |
+
if gotVTK:
|
| 333 |
+
polydata = self.toPolyData()
|
| 334 |
+
dc = vtkQuadricDecimation()
|
| 335 |
+
red = 1 - min(np.float(target) / polydata.GetNumberOfPoints(), 1)
|
| 336 |
+
dc.SetTargetReduction(red)
|
| 337 |
+
dc.SetInput(polydata)
|
| 338 |
+
dc.Update()
|
| 339 |
+
g = dc.GetOutput()
|
| 340 |
+
self.fromPolyData(g)
|
| 341 |
+
z = self.surfVolume()
|
| 342 |
+
if (z > 0):
|
| 343 |
+
self.flipFaces()
|
| 344 |
+
print('flipping volume', z, self.surfVolume())
|
| 345 |
+
else:
|
| 346 |
+
raise Exception('Cannot run Simplify without VTK')
|
| 347 |
+
|
| 348 |
+
def flipFaces(self):
|
| 349 |
+
self.faces = self.faces[:, [0, 2, 1]]
|
| 350 |
+
self.computeCentersAreas()
|
| 351 |
+
|
| 352 |
+
def smooth(self, n=30, smooth=0.1):
|
| 353 |
+
if gotVTK:
|
| 354 |
+
g = self.toPolyData()
|
| 355 |
+
print(g)
|
| 356 |
+
smoother = vtkWindowedSincPolyDataFilter()
|
| 357 |
+
smoother.SetInput(g)
|
| 358 |
+
smoother.SetNumberOfIterations(n)
|
| 359 |
+
smoother.SetPassBand(smooth)
|
| 360 |
+
smoother.NonManifoldSmoothingOn()
|
| 361 |
+
smoother.NormalizeCoordinatesOn()
|
| 362 |
+
smoother.GenerateErrorScalarsOn()
|
| 363 |
+
# smoother.GenerateErrorVectorsOn()
|
| 364 |
+
smoother.Update()
|
| 365 |
+
g = smoother.GetOutput()
|
| 366 |
+
self.fromPolyData(g)
|
| 367 |
+
else:
|
| 368 |
+
raise Exception('Cannot run smooth without VTK')
|
| 369 |
+
|
| 370 |
+
# Computes isosurfaces using vtk
|
| 371 |
+
def Isosurface(self, data, value=0.5, target=1000.0, scales=[1., 1., 1.], smooth=0.1, fill_holes=1.):
|
| 372 |
+
if gotVTK:
|
| 373 |
+
# data = self.LocalSignedDistance(data0, value)
|
| 374 |
+
if isinstance(data, vtkImageData):
|
| 375 |
+
img = data
|
| 376 |
+
else:
|
| 377 |
+
img = vtkImageData()
|
| 378 |
+
img.SetDimensions(data.shape)
|
| 379 |
+
img.SetOrigin(0, 0, 0)
|
| 380 |
+
if vtkVersion.GetVTKMajorVersion() >= 6:
|
| 381 |
+
img.AllocateScalars(VTK_FLOAT, 1)
|
| 382 |
+
else:
|
| 383 |
+
img.SetNumberOfScalarComponents(1)
|
| 384 |
+
v = vtkDoubleArray()
|
| 385 |
+
v.SetNumberOfValues(data.size)
|
| 386 |
+
v.SetNumberOfComponents(1)
|
| 387 |
+
for ii, tmp in enumerate(np.ravel(data, order='F')):
|
| 388 |
+
v.SetValue(ii, tmp)
|
| 389 |
+
img.GetPointData().SetScalars(v)
|
| 390 |
+
|
| 391 |
+
cf = vtkContourFilter()
|
| 392 |
+
if vtkVersion.GetVTKMajorVersion() >= 6:
|
| 393 |
+
cf.SetInputData(img)
|
| 394 |
+
else:
|
| 395 |
+
cf.SetInput(img)
|
| 396 |
+
cf.SetValue(0, value)
|
| 397 |
+
cf.SetNumberOfContours(1)
|
| 398 |
+
cf.Update()
|
| 399 |
+
# print(cf
|
| 400 |
+
connectivity = vtkPolyDataConnectivityFilter()
|
| 401 |
+
connectivity.ScalarConnectivityOff()
|
| 402 |
+
connectivity.SetExtractionModeToLargestRegion()
|
| 403 |
+
if vtkVersion.GetVTKMajorVersion() >= 6:
|
| 404 |
+
connectivity.SetInputData(cf.GetOutput())
|
| 405 |
+
else:
|
| 406 |
+
connectivity.SetInput(cf.GetOutput())
|
| 407 |
+
connectivity.Update()
|
| 408 |
+
g = connectivity.GetOutput()
|
| 409 |
+
|
| 410 |
+
if smooth > 0:
|
| 411 |
+
smoother = vtkWindowedSincPolyDataFilter()
|
| 412 |
+
if vtkVersion.GetVTKMajorVersion() >= 6:
|
| 413 |
+
smoother.SetInputData(g)
|
| 414 |
+
else:
|
| 415 |
+
smoother.SetInput(g)
|
| 416 |
+
# else:
|
| 417 |
+
# smoother.SetInputConnection(contour.GetOutputPort())
|
| 418 |
+
smoother.SetNumberOfIterations(30)
|
| 419 |
+
# this has little effect on the error!
|
| 420 |
+
# smoother.BoundarySmoothingOff()
|
| 421 |
+
# smoother.FeatureEdgeSmoothingOff()
|
| 422 |
+
# smoother.SetFeatureAngle(120.0)
|
| 423 |
+
smoother.SetPassBand(smooth) # this increases the error a lot!
|
| 424 |
+
smoother.NonManifoldSmoothingOn()
|
| 425 |
+
# smoother.NormalizeCoordinatesOn()
|
| 426 |
+
# smoother.GenerateErrorScalarsOn()
|
| 427 |
+
# smoother.GenerateErrorVectorsOn()
|
| 428 |
+
smoother.Update()
|
| 429 |
+
g = smoother.GetOutput()
|
| 430 |
+
|
| 431 |
+
# dc = vtkDecimatePro()
|
| 432 |
+
red = 1 - min(np.float(target) / g.GetNumberOfPoints(), 1)
|
| 433 |
+
# print('Reduction: ', red)
|
| 434 |
+
dc = vtkQuadricDecimation()
|
| 435 |
+
dc.SetTargetReduction(red)
|
| 436 |
+
# dc.AttributeErrorMetricOn()
|
| 437 |
+
# dc.SetDegree(10)
|
| 438 |
+
# dc.SetSplitting(0)
|
| 439 |
+
if vtkVersion.GetVTKMajorVersion() >= 6:
|
| 440 |
+
dc.SetInputData(g)
|
| 441 |
+
else:
|
| 442 |
+
dc.SetInput(g)
|
| 443 |
+
# dc.SetInput(g)
|
| 444 |
+
# print(dc)
|
| 445 |
+
dc.Update()
|
| 446 |
+
g = dc.GetOutput()
|
| 447 |
+
# print('points:', g.GetNumberOfPoints())
|
| 448 |
+
cp = vtkCleanPolyData()
|
| 449 |
+
if vtkVersion.GetVTKMajorVersion() >= 6:
|
| 450 |
+
cp.SetInputData(dc.GetOutput())
|
| 451 |
+
else:
|
| 452 |
+
cp.SetInput(dc.GetOutput())
|
| 453 |
+
# cp.SetInput(dc.GetOutput())
|
| 454 |
+
# cp.SetPointMerging(1)
|
| 455 |
+
cp.ConvertPolysToLinesOn()
|
| 456 |
+
cp.SetAbsoluteTolerance(1e-5)
|
| 457 |
+
cp.Update()
|
| 458 |
+
g = cp.GetOutput()
|
| 459 |
+
self.fromPolyData(g, scales)
|
| 460 |
+
z = self.surfVolume()
|
| 461 |
+
if (z > 0):
|
| 462 |
+
self.flipFaces()
|
| 463 |
+
# print('flipping volume', z, self.surfVolume())
|
| 464 |
+
logging.info('flipping volume %.2f %.2f' % (z, self.surfVolume()))
|
| 465 |
+
|
| 466 |
+
# print(g)
|
| 467 |
+
# npoints = int(g.GetNumberOfPoints())
|
| 468 |
+
# nfaces = int(g.GetNumberOfPolys())
|
| 469 |
+
# print('Dimensions:', npoints, nfaces, g.GetNumberOfCells())
|
| 470 |
+
# V = np.zeros([npoints, 3])
|
| 471 |
+
# for kk in range(npoints):
|
| 472 |
+
# V[kk, :] = np.array(g.GetPoint(kk))
|
| 473 |
+
# #print(kk, V[kk])
|
| 474 |
+
# #print(kk, np.array(g.GetPoint(kk)))
|
| 475 |
+
# F = np.zeros([nfaces, 3])
|
| 476 |
+
# gf = 0
|
| 477 |
+
# for kk in range(g.GetNumberOfCells()):
|
| 478 |
+
# c = g.GetCell(kk)
|
| 479 |
+
# if(c.GetNumberOfPoints() == 3):
|
| 480 |
+
# for ll in range(3):
|
| 481 |
+
# F[gf,ll] = c.GetPointId(ll)
|
| 482 |
+
# #print(kk, gf, F[gf])
|
| 483 |
+
# gf += 1
|
| 484 |
+
|
| 485 |
+
# #self.vertices = np.multiply(data.shape-V-1, scales)
|
| 486 |
+
# self.vertices = np.multiply(V, scales)
|
| 487 |
+
# self.faces = np.int_(F[0:gf, :])
|
| 488 |
+
# self.computeCentersAreas()
|
| 489 |
+
else:
|
| 490 |
+
raise Exception('Cannot run Isosurface without VTK')
|
| 491 |
+
|
| 492 |
+
# Ensures that orientation is correct
|
| 493 |
+
def edgeRecover(self):
|
| 494 |
+
v = self.vertices
|
| 495 |
+
f = self.faces
|
| 496 |
+
nv = v.shape[0]
|
| 497 |
+
nf = f.shape[0]
|
| 498 |
+
# faces containing each oriented edge
|
| 499 |
+
edg0 = np.int_(np.zeros((nv, nv)))
|
| 500 |
+
# number of edges between each vertex
|
| 501 |
+
edg = np.int_(np.zeros((nv, nv)))
|
| 502 |
+
# contiguous faces
|
| 503 |
+
edgF = np.int_(np.zeros((nf, nf)))
|
| 504 |
+
for (kf, c) in enumerate(f):
|
| 505 |
+
if (edg0[c[0], c[1]] > 0):
|
| 506 |
+
edg0[c[1], c[0]] = kf + 1
|
| 507 |
+
else:
|
| 508 |
+
edg0[c[0], c[1]] = kf + 1
|
| 509 |
+
|
| 510 |
+
if (edg0[c[1], c[2]] > 0):
|
| 511 |
+
edg0[c[2], c[1]] = kf + 1
|
| 512 |
+
else:
|
| 513 |
+
edg0[c[1], c[2]] = kf + 1
|
| 514 |
+
|
| 515 |
+
if (edg0[c[2], c[0]] > 0):
|
| 516 |
+
edg0[c[0], c[2]] = kf + 1
|
| 517 |
+
else:
|
| 518 |
+
edg0[c[2], c[0]] = kf + 1
|
| 519 |
+
|
| 520 |
+
edg[c[0], c[1]] += 1
|
| 521 |
+
edg[c[1], c[2]] += 1
|
| 522 |
+
edg[c[2], c[0]] += 1
|
| 523 |
+
|
| 524 |
+
for kv in range(nv):
|
| 525 |
+
I2 = np.nonzero(edg0[kv, :])
|
| 526 |
+
for kkv in I2[0].tolist():
|
| 527 |
+
edgF[edg0[kkv, kv] - 1, edg0[kv, kkv] - 1] = kv + 1
|
| 528 |
+
|
| 529 |
+
isOriented = np.int_(np.zeros(f.shape[0]))
|
| 530 |
+
isActive = np.int_(np.zeros(f.shape[0]))
|
| 531 |
+
I = np.nonzero(np.squeeze(edgF[0, :]))
|
| 532 |
+
# list of faces to be oriented
|
| 533 |
+
activeList = [0] + I[0].tolist()
|
| 534 |
+
lastOriented = 0
|
| 535 |
+
isOriented[0] = True
|
| 536 |
+
for k in activeList:
|
| 537 |
+
isActive[k] = True
|
| 538 |
+
|
| 539 |
+
while lastOriented < len(activeList) - 1:
|
| 540 |
+
i = activeList[lastOriented]
|
| 541 |
+
j = activeList[lastOriented + 1]
|
| 542 |
+
I = np.nonzero(edgF[j, :])
|
| 543 |
+
foundOne = False
|
| 544 |
+
for kk in I[0].tolist():
|
| 545 |
+
if (foundOne == False) & (isOriented[kk]):
|
| 546 |
+
foundOne = True
|
| 547 |
+
u1 = edgF[j, kk] - 1
|
| 548 |
+
u2 = edgF[kk, j] - 1
|
| 549 |
+
if not ((edg[u1, u2] == 1) & (edg[u2, u1] == 1)):
|
| 550 |
+
# reorient face j
|
| 551 |
+
edg[f[j, 0], f[j, 1]] -= 1
|
| 552 |
+
edg[f[j, 1], f[j, 2]] -= 1
|
| 553 |
+
edg[f[j, 2], f[j, 0]] -= 1
|
| 554 |
+
a = f[j, 1]
|
| 555 |
+
f[j, 1] = f[j, 2]
|
| 556 |
+
f[j, 2] = a
|
| 557 |
+
edg[f[j, 0], f[j, 1]] += 1
|
| 558 |
+
edg[f[j, 1], f[j, 2]] += 1
|
| 559 |
+
edg[f[j, 2], f[j, 0]] += 1
|
| 560 |
+
elif (not isActive[kk]):
|
| 561 |
+
activeList.append(kk)
|
| 562 |
+
isActive[kk] = True
|
| 563 |
+
if foundOne:
|
| 564 |
+
lastOriented = lastOriented + 1
|
| 565 |
+
isOriented[j] = True
|
| 566 |
+
# print('oriented face', j, lastOriented, 'out of', nf, '; total active', len(activeList))
|
| 567 |
+
else:
|
| 568 |
+
print('Unable to orient face', j)
|
| 569 |
+
return
|
| 570 |
+
self.vertices = v;
|
| 571 |
+
self.faces = f;
|
| 572 |
+
|
| 573 |
+
z, _ = self.surfVolume()
|
| 574 |
+
if (z > 0):
|
| 575 |
+
self.flipFaces()
|
| 576 |
+
|
| 577 |
+
def removeIsolated(self):
|
| 578 |
+
N = self.vertices.shape[0]
|
| 579 |
+
inFace = np.int_(np.zeros(N))
|
| 580 |
+
for k in range(3):
|
| 581 |
+
inFace[self.faces[:, k]] = 1
|
| 582 |
+
J = np.nonzero(inFace)
|
| 583 |
+
self.vertices = self.vertices[J[0], :]
|
| 584 |
+
logging.info('Found %d isolated vertices' % (J[0].shape[0]))
|
| 585 |
+
Q = -np.ones(N)
|
| 586 |
+
for k, j in enumerate(J[0]):
|
| 587 |
+
Q[j] = k
|
| 588 |
+
self.faces = np.int_(Q[self.faces])
|
| 589 |
+
|
| 590 |
+
def laplacianMatrix(self):
|
| 591 |
+
F = self.faces
|
| 592 |
+
V = self.vertices;
|
| 593 |
+
nf = F.shape[0]
|
| 594 |
+
nv = V.shape[0]
|
| 595 |
+
|
| 596 |
+
AV, AF = self.computeVertexArea()
|
| 597 |
+
|
| 598 |
+
# compute edges and detect boundary
|
| 599 |
+
# edm = sp.lil_matrix((nv,nv))
|
| 600 |
+
edm = -np.ones([nv, nv]).astype(np.int32)
|
| 601 |
+
E = np.zeros([3 * nf, 2]).astype(np.int32)
|
| 602 |
+
j = 0
|
| 603 |
+
for k in range(nf):
|
| 604 |
+
if (edm[F[k, 0], F[k, 1]] == -1):
|
| 605 |
+
edm[F[k, 0], F[k, 1]] = j
|
| 606 |
+
edm[F[k, 1], F[k, 0]] = j
|
| 607 |
+
E[j, :] = [F[k, 0], F[k, 1]]
|
| 608 |
+
j = j + 1
|
| 609 |
+
if (edm[F[k, 1], F[k, 2]] == -1):
|
| 610 |
+
edm[F[k, 1], F[k, 2]] = j
|
| 611 |
+
edm[F[k, 2], F[k, 1]] = j
|
| 612 |
+
E[j, :] = [F[k, 1], F[k, 2]]
|
| 613 |
+
j = j + 1
|
| 614 |
+
if (edm[F[k, 0], F[k, 2]] == -1):
|
| 615 |
+
edm[F[k, 2], F[k, 0]] = j
|
| 616 |
+
edm[F[k, 0], F[k, 2]] = j
|
| 617 |
+
E[j, :] = [F[k, 2], F[k, 0]]
|
| 618 |
+
j = j + 1
|
| 619 |
+
E = E[0:j, :]
|
| 620 |
+
|
| 621 |
+
edgeFace = np.zeros([j, nf])
|
| 622 |
+
ne = j
|
| 623 |
+
# print(E)
|
| 624 |
+
for k in range(nf):
|
| 625 |
+
edgeFace[edm[F[k, 0], F[k, 1]], k] = 1
|
| 626 |
+
edgeFace[edm[F[k, 1], F[k, 2]], k] = 1
|
| 627 |
+
edgeFace[edm[F[k, 2], F[k, 0]], k] = 1
|
| 628 |
+
|
| 629 |
+
bEdge = np.zeros([ne, 1])
|
| 630 |
+
bVert = np.zeros([nv, 1])
|
| 631 |
+
edgeAngles = np.zeros([ne, 2])
|
| 632 |
+
for k in range(ne):
|
| 633 |
+
I = np.flatnonzero(edgeFace[k, :])
|
| 634 |
+
# print('I=', I, F[I, :], E.shape)
|
| 635 |
+
# print('E[k, :]=', k, E[k, :])
|
| 636 |
+
# print(k, edgeFace[k, :])
|
| 637 |
+
for u in range(len(I)):
|
| 638 |
+
f = I[u]
|
| 639 |
+
i1l = np.flatnonzero(F[f, :] == E[k, 0])
|
| 640 |
+
i2l = np.flatnonzero(F[f, :] == E[k, 1])
|
| 641 |
+
# print(f, F[f, :])
|
| 642 |
+
# print(i1l, i2l)
|
| 643 |
+
i1 = i1l[0]
|
| 644 |
+
i2 = i2l[0]
|
| 645 |
+
s = i1 + i2
|
| 646 |
+
if s == 1:
|
| 647 |
+
i3 = 2
|
| 648 |
+
elif s == 2:
|
| 649 |
+
i3 = 1
|
| 650 |
+
elif s == 3:
|
| 651 |
+
i3 = 0
|
| 652 |
+
x1 = V[F[f, i1], :] - V[F[f, i3], :]
|
| 653 |
+
x2 = V[F[f, i2], :] - V[F[f, i3], :]
|
| 654 |
+
a = (np.cross(x1, x2) * np.cross(V[F[f, 1], :] - V[F[f, 0], :], V[F[f, 2], :] - V[F[f, 0], :])).sum()
|
| 655 |
+
b = (x1 * x2).sum()
|
| 656 |
+
if (a > 0):
|
| 657 |
+
edgeAngles[k, u] = b / np.sqrt(a)
|
| 658 |
+
else:
|
| 659 |
+
edgeAngles[k, u] = b / np.sqrt(-a)
|
| 660 |
+
if (len(I) == 1):
|
| 661 |
+
# boundary edge
|
| 662 |
+
bEdge[k] = 1
|
| 663 |
+
bVert[E[k, 0]] = 1
|
| 664 |
+
bVert[E[k, 1]] = 1
|
| 665 |
+
edgeAngles[k, 1] = 0
|
| 666 |
+
|
| 667 |
+
# Compute Laplacian matrix
|
| 668 |
+
L = np.zeros([nv, nv])
|
| 669 |
+
|
| 670 |
+
for k in range(ne):
|
| 671 |
+
L[E[k, 0], E[k, 1]] = (edgeAngles[k, 0] + edgeAngles[k, 1]) / 2
|
| 672 |
+
L[E[k, 1], E[k, 0]] = L[E[k, 0], E[k, 1]]
|
| 673 |
+
|
| 674 |
+
for k in range(nv):
|
| 675 |
+
L[k, k] = - L[k, :].sum()
|
| 676 |
+
|
| 677 |
+
A = np.zeros([nv, nv])
|
| 678 |
+
for k in range(nv):
|
| 679 |
+
A[k, k] = AV[k]
|
| 680 |
+
|
| 681 |
+
return L, A
|
| 682 |
+
|
| 683 |
+
def graphLaplacianMatrix(self):
|
| 684 |
+
F = self.faces
|
| 685 |
+
V = self.vertices
|
| 686 |
+
nf = F.shape[0]
|
| 687 |
+
nv = V.shape[0]
|
| 688 |
+
|
| 689 |
+
# compute edges and detect boundary
|
| 690 |
+
# edm = sp.lil_matrix((nv,nv))
|
| 691 |
+
edm = -np.ones([nv, nv])
|
| 692 |
+
E = np.zeros([3 * nf, 2])
|
| 693 |
+
j = 0
|
| 694 |
+
for k in range(nf):
|
| 695 |
+
if (edm[F[k, 0], F[k, 1]] == -1):
|
| 696 |
+
edm[F[k, 0], F[k, 1]] = j
|
| 697 |
+
edm[F[k, 1], F[k, 0]] = j
|
| 698 |
+
E[j, :] = [F[k, 0], F[k, 1]]
|
| 699 |
+
j = j + 1
|
| 700 |
+
if (edm[F[k, 1], F[k, 2]] == -1):
|
| 701 |
+
edm[F[k, 1], F[k, 2]] = j
|
| 702 |
+
edm[F[k, 2], F[k, 1]] = j
|
| 703 |
+
E[j, :] = [F[k, 1], F[k, 2]]
|
| 704 |
+
j = j + 1
|
| 705 |
+
if (edm[F[k, 0], F[k, 2]] == -1):
|
| 706 |
+
edm[F[k, 2], F[k, 0]] = j
|
| 707 |
+
edm[F[k, 0], F[k, 2]] = j
|
| 708 |
+
E[j, :] = [F[k, 2], F[k, 0]]
|
| 709 |
+
j = j + 1
|
| 710 |
+
E = E[0:j, :]
|
| 711 |
+
|
| 712 |
+
edgeFace = np.zeros([j, nf])
|
| 713 |
+
ne = j
|
| 714 |
+
# print(E)
|
| 715 |
+
|
| 716 |
+
# Compute Laplacian matrix
|
| 717 |
+
L = np.zeros([nv, nv])
|
| 718 |
+
|
| 719 |
+
for k in range(ne):
|
| 720 |
+
L[E[k, 0], E[k, 1]] = 1
|
| 721 |
+
L[E[k, 1], E[k, 0]] = 1
|
| 722 |
+
|
| 723 |
+
for k in range(nv):
|
| 724 |
+
L[k, k] = - L[k, :].sum()
|
| 725 |
+
|
| 726 |
+
return L
|
| 727 |
+
|
| 728 |
+
def cotanLaplacian(self, eps=1e-8):
|
| 729 |
+
L = pp3d.cotan_laplacian(self.vertices, self.faces, denom_eps=1e-10)
|
| 730 |
+
massvec_np = pp3d.vertex_areas(self.vertices, self.faces)
|
| 731 |
+
massvec_np += eps * np.mean(massvec_np)
|
| 732 |
+
|
| 733 |
+
if (np.isnan(L.data).any()):
|
| 734 |
+
raise RuntimeError("NaN Laplace matrix")
|
| 735 |
+
if (np.isnan(massvec_np).any()):
|
| 736 |
+
raise RuntimeError("NaN mass matrix")
|
| 737 |
+
self.L = L
|
| 738 |
+
self.massvec_np = massvec_np
|
| 739 |
+
return L, massvec_np
|
| 740 |
+
|
| 741 |
+
def lapEigen(self, n_eig=10, eps=1e-8):
|
| 742 |
+
L_eigsh = (self.L + sp.sparse.identity(self.L.shape[0]) * eps).tocsc()
|
| 743 |
+
massvec_eigsh = self.massvec_np
|
| 744 |
+
Mmat = sp.sparse.diags(massvec_eigsh)
|
| 745 |
+
eigs_sigma = eps#-0.01
|
| 746 |
+
evals, evecs = sla.eigsh(L_eigsh, k=n_eig, M=Mmat, sigma=eigs_sigma)
|
| 747 |
+
return evals, evecs
|
| 748 |
+
|
| 749 |
+
def laplacianSegmentation(self, k, verbose=False):
|
| 750 |
+
# (L, AA) = self.laplacianMatrix()
|
| 751 |
+
# # print((L.shape[0]-k-1, L.shape[0]-2))
|
| 752 |
+
# (D, y) = spLA.eigh(L, AA, eigvals=(L.shape[0] - k, L.shape[0] - 1))
|
| 753 |
+
evals, evecs = self.lapEigen(k*2)
|
| 754 |
+
y = evecs[:, 1:] * np.exp(-2*0.1 * evals[1:]) #**2
|
| 755 |
+
# N = y.shape[0]
|
| 756 |
+
# d = y.shape[1]
|
| 757 |
+
# I = np.argsort(y.sum(axis=1))
|
| 758 |
+
# I0 = np.floor((N - 1) * sp.linspace(0, 1, num=k)).astype(int)
|
| 759 |
+
# # print(y.shape, L.shape, N, k, d)
|
| 760 |
+
# C = y[I0, :].copy()
|
| 761 |
+
#
|
| 762 |
+
# eps = 1e-20
|
| 763 |
+
# Cold = C.copy()
|
| 764 |
+
# u = ((C.reshape([k, 1, d]) - y.reshape([1, N, d])) ** 2).sum(axis=2)
|
| 765 |
+
# T = u.min(axis=0).sum() / (N)
|
| 766 |
+
# # print(T)
|
| 767 |
+
# j = 0
|
| 768 |
+
# while j < 5000:
|
| 769 |
+
# u0 = u - u.min(axis=0).reshape([1, N])
|
| 770 |
+
# w = np.exp(-u0 / T);
|
| 771 |
+
# w = w / (eps + w.sum(axis=0).reshape([1, N]))
|
| 772 |
+
# # print(w.min(), w.max())
|
| 773 |
+
# cost = (u * w).sum() + T * (w * np.log(w + eps)).sum()
|
| 774 |
+
# C = np.dot(w, y) / (eps + w.sum(axis=1).reshape([k, 1]))
|
| 775 |
+
# # print(j, 'cost0 ', cost)
|
| 776 |
+
#
|
| 777 |
+
# u = ((C.reshape([k, 1, d]) - y.reshape([1, N, d])) ** 2).sum(axis=2)
|
| 778 |
+
# cost = (u * w).sum() + T * (w * np.log(w + eps)).sum()
|
| 779 |
+
# err = np.sqrt(((C - Cold) ** 2).sum(axis=1)).sum()
|
| 780 |
+
# # print(j, 'cost ', cost, err, T)
|
| 781 |
+
# if (j > 100) & (err < 1e-4):
|
| 782 |
+
# break
|
| 783 |
+
# j = j + 1
|
| 784 |
+
# Cold = C.copy()
|
| 785 |
+
# T = T * 0.99
|
| 786 |
+
#
|
| 787 |
+
# # print(k, d, C.shape)
|
| 788 |
+
# dst = ((C.reshape([k, 1, d]) - y.reshape([1, N, d])) ** 2).sum(axis=2)
|
| 789 |
+
# md = dst.min(axis=0)
|
| 790 |
+
# idx = np.zeros(N).astype(int)
|
| 791 |
+
# for j in range(N):
|
| 792 |
+
# I = np.flatnonzero(dst[:, j] < md[j] + 1e-10)
|
| 793 |
+
# idx[j] = I[0]
|
| 794 |
+
# I = -np.ones(k).astype(int)
|
| 795 |
+
# kk = 0
|
| 796 |
+
# for j in range(k):
|
| 797 |
+
# if True in (idx == j):
|
| 798 |
+
# I[j] = kk
|
| 799 |
+
# kk += 1
|
| 800 |
+
# idx = I[idx]
|
| 801 |
+
# if idx.max() < (k - 1):
|
| 802 |
+
# print('Warning: kmeans convergence with %d clusters instead of %d' % (idx.max(), k))
|
| 803 |
+
# # ml = w.sum(axis=1)/N
|
| 804 |
+
kmeans = KMeans(n_clusters=k, random_state=0, n_init="auto").fit(y)
|
| 805 |
+
# kmeans = DBSCAN().fit(y)
|
| 806 |
+
idx = kmeans.labels_
|
| 807 |
+
nc = idx.max() + 1
|
| 808 |
+
C = np.zeros([nc, self.vertices.shape[1]])
|
| 809 |
+
a, foo = self.computeVertexArea()
|
| 810 |
+
for k in range(nc):
|
| 811 |
+
I = np.flatnonzero(idx == k)
|
| 812 |
+
nI = len(I)
|
| 813 |
+
# print(a.shape, nI)
|
| 814 |
+
aI = a[I]
|
| 815 |
+
ak = aI.sum()
|
| 816 |
+
C[k, :] = (self.vertices[I, :] * aI).sum(axis=0) / ak;
|
| 817 |
+
mean_eigen = (y*a).sum(axis=0)/a
|
| 818 |
+
tree = cKDTree(y)
|
| 819 |
+
_, indices = tree.query(mean_eigen, k=1)
|
| 820 |
+
index_center = indices[0]
|
| 821 |
+
_, indices = tree.query(C, k=1)
|
| 822 |
+
index_C = indices[0]
|
| 823 |
+
|
| 824 |
+
mesh = tm.Trimesh(vertices=self.vertices, faces=self.faces)
|
| 825 |
+
nv = self.computeVertexNormals()
|
| 826 |
+
rori = self.vertices[index] + 0.001 * (-nv[index, :])
|
| 827 |
+
rdir = -nv[index, :]
|
| 828 |
+
|
| 829 |
+
locations, index_ray, index_tri = mesh.ray.intersects_location(ray_origins=rori[None, :], ray_directions=rdir[None, :])
|
| 830 |
+
if verbose:
|
| 831 |
+
print(locations, index_ray, index_tri)
|
| 832 |
+
print(index_center, index_C)
|
| 833 |
+
return idx, C, index_center, locations[0]
|
| 834 |
+
|
| 835 |
+
def get_keypoints(self, n_eig=30, n_points=5, sym=False):
|
| 836 |
+
evals, evecs = self.lapEigen(n_eig+1)
|
| 837 |
+
pts_evecs = evecs[:, 1:] * np.exp(-2*0.1 * evals[1:]) # approximate geod distance
|
| 838 |
+
tree = cKDTree(pts_evecs)
|
| 839 |
+
_, center_index = tree.query(np.zeros(pts_evecs.shape[-1]), k=1)
|
| 840 |
+
indices = fps_gpt_geod(self.vertices, self.faces, 10, center_index)[1:]
|
| 841 |
+
areas = np.linalg.norm(self.surfel, axis=-1, keepdims=True)
|
| 842 |
+
area = np.sqrt(areas.sum()/2)
|
| 843 |
+
print(area)
|
| 844 |
+
## Looking for center + provide distances to the center
|
| 845 |
+
|
| 846 |
+
print(center_index)
|
| 847 |
+
solver = pp3d.MeshHeatMethodDistanceSolver(self.vertices, self.faces)
|
| 848 |
+
norm_center = solver.compute_distance(center_index)
|
| 849 |
+
|
| 850 |
+
## Dist matrix just between selected points
|
| 851 |
+
dist_mat_indices = np.zeros((len(indices), len(indices)))
|
| 852 |
+
all_ii = np.arange(len(indices)) # just to easy code
|
| 853 |
+
for ii, index in enumerate(indices):
|
| 854 |
+
dist_mat_indices[ii, ii!=all_ii] = solver.compute_distance(index)[indices[indices != index]]
|
| 855 |
+
|
| 856 |
+
## Select points which : farthest from the center, delete their neighbors with dist < 0.5*area
|
| 857 |
+
keep = []
|
| 858 |
+
max_norm = norm_center.max()
|
| 859 |
+
for ii, index in enumerate(indices[:n_points]):
|
| 860 |
+
# distances_compar = (dist_mat_indices[ii, ii!=all_ii] + norm_center[ii]) / norm_center[indices[ii!=all_ii]]
|
| 861 |
+
# min_dist = np.amin(np.abs(distances_compar - 1))
|
| 862 |
+
# print(min_dist/norm_center[ii], min_dist)
|
| 863 |
+
# if min_dist > 0.3:
|
| 864 |
+
# #print(ii, min_dist/norm_center[ii], norm_center[ii], dist_mat_indices[ii, ii!=all_ii], norm_center[indices[indices != index]], indices[7])
|
| 865 |
+
keep.append(index)
|
| 866 |
+
## Add the index of the center
|
| 867 |
+
## Add the index of the center
|
| 868 |
+
if sym:
|
| 869 |
+
mesh = tm.base.Trimesh(self.vertices, self.faces, process=False) # Load your mesh
|
| 870 |
+
# Given data
|
| 871 |
+
index = 123 # Your starting point index
|
| 872 |
+
direction = -mesh.vertex_normals[center_index]
|
| 873 |
+
|
| 874 |
+
# Get the starting vertex position
|
| 875 |
+
start_point = mesh.vertices[center_index]
|
| 876 |
+
|
| 877 |
+
# Perform ray intersection
|
| 878 |
+
locations, index_ray, index_tri = mesh.ray.intersects_location(
|
| 879 |
+
ray_origins=start_point[None, :], # Starting point
|
| 880 |
+
ray_directions=direction[None, :] # Direction of travel
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
# If intersections exist
|
| 884 |
+
if len(locations) > 0:
|
| 885 |
+
# Sort intersections by distance from start_point
|
| 886 |
+
distances = np.linalg.norm(locations - start_point, axis=1)
|
| 887 |
+
sorted_indices = np.argsort(distances)
|
| 888 |
+
|
| 889 |
+
# Get the first valid intersection point beyond the start point
|
| 890 |
+
#next_intersection = locations[sorted_indices]
|
| 891 |
+
#print(f"Next intersection at: {next_intersection}")
|
| 892 |
+
intersected_triangle = index_tri[sorted_indices[1]]
|
| 893 |
+
print(f"Intersected Triangle Index: {mesh.faces[intersected_triangle]}")
|
| 894 |
+
center_index = mesh.faces[intersected_triangle][0]
|
| 895 |
+
else:
|
| 896 |
+
print("No intersection found in the given direction.")
|
| 897 |
+
keep.append(center_index)
|
| 898 |
+
return keep
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
# Computes surface volume
|
| 902 |
+
def surfVolume(self):
|
| 903 |
+
f = self.faces
|
| 904 |
+
v = self.vertices
|
| 905 |
+
t = v[f, :]
|
| 906 |
+
vols = np.linalg.det(t) / 6
|
| 907 |
+
return vols.sum(), vols
|
| 908 |
+
|
| 909 |
+
def surfCenter(self):
|
| 910 |
+
f = self.faces
|
| 911 |
+
v = self.vertices
|
| 912 |
+
center_infs = (v[f, :].sum(axis=1) / 4) * self.vols[:, np.newaxis]
|
| 913 |
+
center = center_infs.sum(axis=0)
|
| 914 |
+
return center
|
| 915 |
+
|
| 916 |
+
def surfMoments(self):
|
| 917 |
+
f = self.faces
|
| 918 |
+
v = self.vertices
|
| 919 |
+
vec_0 = v[f[:, 0], :] + v[f[:, 1], :]
|
| 920 |
+
s_0 = vec_0[:, :, np.newaxis] * vec_0[:, np.newaxis, :]
|
| 921 |
+
vec_1 = v[f[:, 0], :] + v[f[:, 2], :]
|
| 922 |
+
s_1 = vec_1[:, :, np.newaxis] * vec_1[:, np.newaxis, :]
|
| 923 |
+
vec_2 = v[f[:, 1], :] + v[f[:, 2], :]
|
| 924 |
+
s_2 = vec_2[:, :, np.newaxis] * vec_2[:, np.newaxis, :]
|
| 925 |
+
moments_inf = self.vols[:, np.newaxis, np.newaxis] * (1. / 20) * (s_0 + s_1 + s_2)
|
| 926 |
+
return moments_inf.sum(axis=0)
|
| 927 |
+
|
| 928 |
+
def surfF(self):
|
| 929 |
+
f = self.faces
|
| 930 |
+
v = self.vertices
|
| 931 |
+
cent = (v[f, :].sum(axis=1)) / 4.
|
| 932 |
+
F = self.vols[:, np.newaxis] * np.sign(cent) * (cent ** 2)
|
| 933 |
+
return F.sum(axis=0)
|
| 934 |
+
|
| 935 |
+
def surfU(self):
|
| 936 |
+
vol = self.volume
|
| 937 |
+
vertices = self.vertices / pow(vol, 1. / 3)
|
| 938 |
+
vertices -= self.center
|
| 939 |
+
self.updateVertices(vertices)
|
| 940 |
+
moments = self.surfMoments()
|
| 941 |
+
u, s, vh = np.linalg.svd(moments)
|
| 942 |
+
F = self.surfF()
|
| 943 |
+
return np.diag(np.sign(F)) @ u, s
|
| 944 |
+
|
| 945 |
+
def surfEllipsoid(self, u, s, moments):
|
| 946 |
+
coeff = pow(4 * np.pi / 15, 1. / 5) * pow(np.linalg.det(moments), -1. / 10)
|
| 947 |
+
A = coeff * ((u * np.sqrt(s)) @ u.T)
|
| 948 |
+
return u, A
|
| 949 |
+
|
| 950 |
+
# Reads from .off file
|
| 951 |
+
def readOFF(self, offfile):
|
| 952 |
+
with open(offfile, 'r') as f:
|
| 953 |
+
all_lines = f.readlines()
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
n_vertices = int(all_lines[1].split()[0])
|
| 957 |
+
n_faces = int(all_lines[1].split()[1])
|
| 958 |
+
|
| 959 |
+
vertices_list = []
|
| 960 |
+
for i in range(n_vertices):
|
| 961 |
+
vertices_list.append([float(x) for x in all_lines[2+i].split()[:3]])
|
| 962 |
+
|
| 963 |
+
faces_list = []
|
| 964 |
+
for i in range(n_faces):
|
| 965 |
+
# Be careful to convert to int. Otherwise, you can use np.array(faces_list).astype(np.int32)
|
| 966 |
+
faces_list.append([int(x) for x in all_lines[2+i+n_vertices].split()[1:4]])
|
| 967 |
+
self.faces = np.array(faces_list)
|
| 968 |
+
self.vertices = np.array(vertices_list)
|
| 969 |
+
self.computeCentersAreas() # Reads from .byu file
|
| 970 |
+
def readbyu(self, byufile):
|
| 971 |
+
with open(byufile, 'r') as fbyu:
|
| 972 |
+
ln0 = fbyu.readline()
|
| 973 |
+
ln = ln0.split()
|
| 974 |
+
# read header
|
| 975 |
+
ncomponents = int(ln[0]) # number of components
|
| 976 |
+
npoints = int(ln[1]) # number of vertices
|
| 977 |
+
nfaces = int(ln[2]) # number of faces
|
| 978 |
+
# fscanf(fbyu,'%d',1); % number of edges
|
| 979 |
+
# %ntest = fscanf(fbyu,'%d',1); % number of edges
|
| 980 |
+
for k in range(ncomponents):
|
| 981 |
+
fbyu.readline() # components (ignored)
|
| 982 |
+
# read data
|
| 983 |
+
self.vertices = np.empty([npoints, 3])
|
| 984 |
+
k = -1
|
| 985 |
+
while k < npoints - 1:
|
| 986 |
+
ln = fbyu.readline().split()
|
| 987 |
+
k = k + 1;
|
| 988 |
+
self.vertices[k, 0] = float(ln[0])
|
| 989 |
+
self.vertices[k, 1] = float(ln[1])
|
| 990 |
+
self.vertices[k, 2] = float(ln[2])
|
| 991 |
+
if len(ln) > 3:
|
| 992 |
+
k = k + 1;
|
| 993 |
+
self.vertices[k, 0] = float(ln[3])
|
| 994 |
+
self.vertices[k, 1] = float(ln[4])
|
| 995 |
+
self.vertices[k, 2] = float(ln[5])
|
| 996 |
+
|
| 997 |
+
self.faces = np.empty([nfaces, 3])
|
| 998 |
+
ln = fbyu.readline().split()
|
| 999 |
+
kf = 0
|
| 1000 |
+
j = 0
|
| 1001 |
+
while ln:
|
| 1002 |
+
if kf >= nfaces:
|
| 1003 |
+
break
|
| 1004 |
+
# print(nfaces, kf, ln)
|
| 1005 |
+
for s in ln:
|
| 1006 |
+
self.faces[kf, j] = int(sp.fabs(int(s)))
|
| 1007 |
+
j = j + 1
|
| 1008 |
+
if j == 3:
|
| 1009 |
+
kf = kf + 1
|
| 1010 |
+
j = 0
|
| 1011 |
+
ln = fbyu.readline().split()
|
| 1012 |
+
self.faces = np.int_(self.faces) - 1
|
| 1013 |
+
xDef1 = self.vertices[self.faces[:, 0], :]
|
| 1014 |
+
xDef2 = self.vertices[self.faces[:, 1], :]
|
| 1015 |
+
xDef3 = self.vertices[self.faces[:, 2], :]
|
| 1016 |
+
self.centers = (xDef1 + xDef2 + xDef3) / 3
|
| 1017 |
+
self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
|
| 1018 |
+
|
| 1019 |
+
# Saves in .byu format
|
| 1020 |
+
def savebyu(self, byufile):
|
| 1021 |
+
# FV = readbyu(byufile)
|
| 1022 |
+
# reads from a .byu file into matlab's face vertex structure FV
|
| 1023 |
+
|
| 1024 |
+
with open(byufile, 'w') as fbyu:
|
| 1025 |
+
# copy header
|
| 1026 |
+
ncomponents = 1 # number of components
|
| 1027 |
+
npoints = self.vertices.shape[0] # number of vertices
|
| 1028 |
+
nfaces = self.faces.shape[0] # number of faces
|
| 1029 |
+
nedges = 3 * nfaces # number of edges
|
| 1030 |
+
|
| 1031 |
+
str = '{0: d} {1: d} {2: d} {3: d} 0\n'.format(ncomponents, npoints, nfaces, nedges)
|
| 1032 |
+
fbyu.write(str)
|
| 1033 |
+
str = '1 {0: d}\n'.format(nfaces)
|
| 1034 |
+
fbyu.write(str)
|
| 1035 |
+
|
| 1036 |
+
k = -1
|
| 1037 |
+
while k < (npoints - 1):
|
| 1038 |
+
k = k + 1
|
| 1039 |
+
str = '{0: f} {1: f} {2: f} '.format(self.vertices[k, 0], self.vertices[k, 1], self.vertices[k, 2])
|
| 1040 |
+
fbyu.write(str)
|
| 1041 |
+
if k < (npoints - 1):
|
| 1042 |
+
k = k + 1
|
| 1043 |
+
str = '{0: f} {1: f} {2: f}\n'.format(self.vertices[k, 0], self.vertices[k, 1], self.vertices[k, 2])
|
| 1044 |
+
fbyu.write(str)
|
| 1045 |
+
else:
|
| 1046 |
+
fbyu.write('\n')
|
| 1047 |
+
|
| 1048 |
+
j = 0
|
| 1049 |
+
for k in range(nfaces):
|
| 1050 |
+
for kk in (0, 1):
|
| 1051 |
+
fbyu.write('{0: d} '.format(self.faces[k, kk] + 1))
|
| 1052 |
+
j = j + 1
|
| 1053 |
+
if j == 16:
|
| 1054 |
+
fbyu.write('\n')
|
| 1055 |
+
j = 0
|
| 1056 |
+
|
| 1057 |
+
fbyu.write('{0: d} '.format(-self.faces[k, 2] - 1))
|
| 1058 |
+
j = j + 1
|
| 1059 |
+
if j == 16:
|
| 1060 |
+
fbyu.write('\n')
|
| 1061 |
+
j = 0
|
| 1062 |
+
|
| 1063 |
+
def saveVTK(self, fileName, scalars=None, normals=None, tensors=None, scal_name='scalars', vectors=None,
|
| 1064 |
+
vect_name='vectors'):
|
| 1065 |
+
vf = vtkFields()
|
| 1066 |
+
# print(scalars)
|
| 1067 |
+
if not (scalars == None):
|
| 1068 |
+
vf.scalars.append(scal_name)
|
| 1069 |
+
vf.scalars.append(scalars)
|
| 1070 |
+
if not (vectors == None):
|
| 1071 |
+
vf.vectors.append(vect_name)
|
| 1072 |
+
vf.vectors.append(vectors)
|
| 1073 |
+
if not (normals == None):
|
| 1074 |
+
vf.normals.append('normals')
|
| 1075 |
+
vf.normals.append(normals)
|
| 1076 |
+
if not (tensors == None):
|
| 1077 |
+
vf.tensors.append('tensors')
|
| 1078 |
+
vf.tensors.append(tensors)
|
| 1079 |
+
self.saveVTK2(fileName, vf)
|
| 1080 |
+
|
| 1081 |
+
# Saves in .vtk format
|
| 1082 |
+
def saveVTK2(self, fileName, vtkFields=None):
|
| 1083 |
+
F = self.faces;
|
| 1084 |
+
V = self.vertices;
|
| 1085 |
+
|
| 1086 |
+
with open(fileName, 'w') as fvtkout:
|
| 1087 |
+
fvtkout.write('# vtk DataFile Version 3.0\nSurface Data\nASCII\nDATASET POLYDATA\n')
|
| 1088 |
+
fvtkout.write('\nPOINTS {0: d} float'.format(V.shape[0]))
|
| 1089 |
+
for ll in range(V.shape[0]):
|
| 1090 |
+
fvtkout.write('\n{0: f} {1: f} {2: f}'.format(V[ll, 0], V[ll, 1], V[ll, 2]))
|
| 1091 |
+
fvtkout.write('\nPOLYGONS {0:d} {1:d}'.format(F.shape[0], 4 * F.shape[0]))
|
| 1092 |
+
for ll in range(F.shape[0]):
|
| 1093 |
+
fvtkout.write('\n3 {0: d} {1: d} {2: d}'.format(F[ll, 0], F[ll, 1], F[ll, 2]))
|
| 1094 |
+
if not (vtkFields == None):
|
| 1095 |
+
wrote_pd_hdr = False
|
| 1096 |
+
if len(vtkFields.scalars) > 0:
|
| 1097 |
+
if not wrote_pd_hdr:
|
| 1098 |
+
fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
|
| 1099 |
+
wrote_pd_hdr = True
|
| 1100 |
+
nf = len(vtkFields.scalars) / 2
|
| 1101 |
+
for k in range(nf):
|
| 1102 |
+
fvtkout.write('\nSCALARS ' + vtkFields.scalars[2 * k] + ' float 1\nLOOKUP_TABLE default')
|
| 1103 |
+
for ll in range(V.shape[0]):
|
| 1104 |
+
# print(scalars[ll])
|
| 1105 |
+
fvtkout.write('\n {0: .5f}'.format(vtkFields.scalars[2 * k + 1][ll]))
|
| 1106 |
+
if len(vtkFields.vectors) > 0:
|
| 1107 |
+
if not wrote_pd_hdr:
|
| 1108 |
+
fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
|
| 1109 |
+
wrote_pd_hdr = True
|
| 1110 |
+
nf = len(vtkFields.vectors) / 2
|
| 1111 |
+
for k in range(nf):
|
| 1112 |
+
fvtkout.write('\nVECTORS ' + vtkFields.vectors[2 * k] + ' float')
|
| 1113 |
+
vectors = vtkFields.vectors[2 * k + 1]
|
| 1114 |
+
for ll in range(V.shape[0]):
|
| 1115 |
+
fvtkout.write(
|
| 1116 |
+
'\n {0: .5f} {1: .5f} {2: .5f}'.format(vectors[ll, 0], vectors[ll, 1], vectors[ll, 2]))
|
| 1117 |
+
if len(vtkFields.normals) > 0:
|
| 1118 |
+
if not wrote_pd_hdr:
|
| 1119 |
+
fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
|
| 1120 |
+
wrote_pd_hdr = True
|
| 1121 |
+
nf = len(vtkFields.normals) / 2
|
| 1122 |
+
for k in range(nf):
|
| 1123 |
+
fvtkout.write('\nNORMALS ' + vtkFields.normals[2 * k] + ' float')
|
| 1124 |
+
vectors = vtkFields.normals[2 * k + 1]
|
| 1125 |
+
for ll in range(V.shape[0]):
|
| 1126 |
+
fvtkout.write(
|
| 1127 |
+
'\n {0: .5f} {1: .5f} {2: .5f}'.format(vectors[ll, 0], vectors[ll, 1], vectors[ll, 2]))
|
| 1128 |
+
if len(vtkFields.tensors) > 0:
|
| 1129 |
+
if not wrote_pd_hdr:
|
| 1130 |
+
fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
|
| 1131 |
+
wrote_pd_hdr = True
|
| 1132 |
+
nf = len(vtkFields.tensors) / 2
|
| 1133 |
+
for k in range(nf):
|
| 1134 |
+
fvtkout.write('\nTENSORS ' + vtkFields.tensors[2 * k] + ' float')
|
| 1135 |
+
tensors = vtkFields.tensors[2 * k + 1]
|
| 1136 |
+
for ll in range(V.shape[0]):
|
| 1137 |
+
for kk in range(2):
|
| 1138 |
+
fvtkout.write(
|
| 1139 |
+
'\n {0: .5f} {1: .5f} {2: .5f}'.format(tensors[ll, kk, 0], tensors[ll, kk, 1],
|
| 1140 |
+
tensors[ll, kk, 2]))
|
| 1141 |
+
fvtkout.write('\n')
|
| 1142 |
+
|
| 1143 |
+
# Reads .vtk file
|
| 1144 |
+
def readVTK(self, fileName):
|
| 1145 |
+
if gotVTK:
|
| 1146 |
+
u = vtkPolyDataReader()
|
| 1147 |
+
u.SetFileName(fileName)
|
| 1148 |
+
u.Update()
|
| 1149 |
+
v = u.GetOutput()
|
| 1150 |
+
# print(v)
|
| 1151 |
+
npoints = int(v.GetNumberOfPoints())
|
| 1152 |
+
nfaces = int(v.GetNumberOfPolys())
|
| 1153 |
+
V = np.zeros([npoints, 3])
|
| 1154 |
+
for kk in range(npoints):
|
| 1155 |
+
V[kk, :] = np.array(v.GetPoint(kk))
|
| 1156 |
+
|
| 1157 |
+
F = np.zeros([nfaces, 3])
|
| 1158 |
+
for kk in range(nfaces):
|
| 1159 |
+
c = v.GetCell(kk)
|
| 1160 |
+
for ll in range(3):
|
| 1161 |
+
F[kk, ll] = c.GetPointId(ll)
|
| 1162 |
+
|
| 1163 |
+
self.vertices = V
|
| 1164 |
+
self.faces = np.int_(F)
|
| 1165 |
+
xDef1 = self.vertices[self.faces[:, 0], :]
|
| 1166 |
+
xDef2 = self.vertices[self.faces[:, 1], :]
|
| 1167 |
+
xDef3 = self.vertices[self.faces[:, 2], :]
|
| 1168 |
+
self.centers = (xDef1 + xDef2 + xDef3) / 3
|
| 1169 |
+
self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
|
| 1170 |
+
else:
|
| 1171 |
+
raise Exception('Cannot run readVTK without VTK')
|
| 1172 |
+
|
| 1173 |
+
# Reads .obj file
|
| 1174 |
+
def readOBJ(self, fileName):
|
| 1175 |
+
if gotVTK:
|
| 1176 |
+
u = vtkOBJReader()
|
| 1177 |
+
u.SetFileName(fileName)
|
| 1178 |
+
u.Update()
|
| 1179 |
+
v = u.GetOutput()
|
| 1180 |
+
# print(v)
|
| 1181 |
+
npoints = int(v.GetNumberOfPoints())
|
| 1182 |
+
nfaces = int(v.GetNumberOfPolys())
|
| 1183 |
+
V = np.zeros([npoints, 3])
|
| 1184 |
+
for kk in range(npoints):
|
| 1185 |
+
V[kk, :] = np.array(v.GetPoint(kk))
|
| 1186 |
+
|
| 1187 |
+
F = np.zeros([nfaces, 3])
|
| 1188 |
+
for kk in range(nfaces):
|
| 1189 |
+
c = v.GetCell(kk)
|
| 1190 |
+
for ll in range(3):
|
| 1191 |
+
F[kk, ll] = c.GetPointId(ll)
|
| 1192 |
+
|
| 1193 |
+
self.vertices = V
|
| 1194 |
+
self.faces = np.int_(F)
|
| 1195 |
+
xDef1 = self.vertices[self.faces[:, 0], :]
|
| 1196 |
+
xDef2 = self.vertices[self.faces[:, 1], :]
|
| 1197 |
+
xDef3 = self.vertices[self.faces[:, 2], :]
|
| 1198 |
+
self.centers = (xDef1 + xDef2 + xDef3) / 3
|
| 1199 |
+
self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
|
| 1200 |
+
else:
|
| 1201 |
+
raise Exception('Cannot run readOBJ without VTK')
|
| 1202 |
+
|
| 1203 |
+
def readPLY(self, fileName):
|
| 1204 |
+
if gotVTK:
|
| 1205 |
+
u = vtkPLYReader()
|
| 1206 |
+
u.SetFileName(fileName)
|
| 1207 |
+
u.Update()
|
| 1208 |
+
v = u.GetOutput()
|
| 1209 |
+
# print(v)
|
| 1210 |
+
npoints = int(v.GetNumberOfPoints())
|
| 1211 |
+
nfaces = int(v.GetNumberOfPolys())
|
| 1212 |
+
V = np.zeros([npoints, 3])
|
| 1213 |
+
for kk in range(npoints):
|
| 1214 |
+
V[kk, :] = np.array(v.GetPoint(kk))
|
| 1215 |
+
|
| 1216 |
+
F = np.zeros([nfaces, 3])
|
| 1217 |
+
for kk in range(nfaces):
|
| 1218 |
+
c = v.GetCell(kk)
|
| 1219 |
+
for ll in range(3):
|
| 1220 |
+
F[kk, ll] = c.GetPointId(ll)
|
| 1221 |
+
|
| 1222 |
+
self.vertices = V
|
| 1223 |
+
self.faces = np.int_(F)
|
| 1224 |
+
xDef1 = self.vertices[self.faces[:, 0], :]
|
| 1225 |
+
xDef2 = self.vertices[self.faces[:, 1], :]
|
| 1226 |
+
xDef3 = self.vertices[self.faces[:, 2], :]
|
| 1227 |
+
self.centers = (xDef1 + xDef2 + xDef3) / 3
|
| 1228 |
+
self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
|
| 1229 |
+
else:
|
| 1230 |
+
raise Exception('Cannot run readPLY without VTK')
|
| 1231 |
+
|
| 1232 |
+
def readTRI(self, fileName):
|
| 1233 |
+
vertices = []
|
| 1234 |
+
faces = []
|
| 1235 |
+
with open(fileName, "r") as f:
|
| 1236 |
+
pithc = f.readlines()
|
| 1237 |
+
line_1 = pithc[0]
|
| 1238 |
+
n_pts, n_tri = line_1.split(' ')
|
| 1239 |
+
for i, line in enumerate(pithc[1:]):
|
| 1240 |
+
pouetage = line.split(' ')
|
| 1241 |
+
if i <= int(n_pts):
|
| 1242 |
+
vertices.append([float(poupout) for poupout in pouetage[:3]])
|
| 1243 |
+
# vertices.append([float(poupout) for poupout in pouetage[4:7]])
|
| 1244 |
+
else:
|
| 1245 |
+
faces.append([int(poupout.split("\n")[0]) for poupout in pouetage[1:]])
|
| 1246 |
+
self.vertices = np.array(vertices)
|
| 1247 |
+
self.faces = np.array(faces)
|
| 1248 |
+
xDef1 = self.vertices[self.faces[:, 0], :]
|
| 1249 |
+
xDef2 = self.vertices[self.faces[:, 1], :]
|
| 1250 |
+
xDef3 = self.vertices[self.faces[:, 2], :]
|
| 1251 |
+
self.centers = (xDef1 + xDef2 + xDef3) / 3
|
| 1252 |
+
self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
|
| 1253 |
+
|
| 1254 |
+
def concatenate(self, fvl):
|
| 1255 |
+
nv = 0
|
| 1256 |
+
nf = 0
|
| 1257 |
+
for fv in fvl:
|
| 1258 |
+
nv += fv.vertices.shape[0]
|
| 1259 |
+
nf += fv.faces.shape[0]
|
| 1260 |
+
self.vertices = np.zeros([nv, 3])
|
| 1261 |
+
self.faces = np.zeros([nf, 3], dtype='int')
|
| 1262 |
+
|
| 1263 |
+
nv0 = 0
|
| 1264 |
+
nf0 = 0
|
| 1265 |
+
for fv in fvl:
|
| 1266 |
+
nv = nv0 + fv.vertices.shape[0]
|
| 1267 |
+
nf = nf0 + fv.faces.shape[0]
|
| 1268 |
+
self.vertices[nv0:nv, :] = fv.vertices
|
| 1269 |
+
self.faces[nf0:nf, :] = fv.faces + nv0
|
| 1270 |
+
nv0 = nv
|
| 1271 |
+
nf0 = nf
|
| 1272 |
+
self.computeCentersAreas()
|
| 1273 |
+
|
| 1274 |
+
def get_surf_area(surf):
|
| 1275 |
+
areas = np.linalg.norm(surf.surfel, axis=-1)
|
| 1276 |
+
return areas.sum()/2
|
| 1277 |
+
|
| 1278 |
+
def centroid(surf):
|
| 1279 |
+
areas = np.linalg.norm(surf.surfel, axis=-1, keepdims=True)
|
| 1280 |
+
center = (surf.centers * areas).sum(axis=0) / areas.sum()
|
| 1281 |
+
return center, np.sqrt(areas.sum()/2)
|
| 1282 |
+
|
| 1283 |
+
|
| 1284 |
+
def do_bbox_vertices(vertices):
|
| 1285 |
+
new_verts = np.zeros(vertices.shape)
|
| 1286 |
+
new_verts[:, 0] = (vertices[:, 0] - np.amin(vertices[:, 0]))/(np.amax(vertices[:, 0]) - np.amin(vertices[:, 0]))
|
| 1287 |
+
new_verts[:, 1] = (vertices[:, 1] - np.amin(vertices[:, 1]))/(np.amax(vertices[:, 1]) - np.amin(vertices[:, 1]))
|
| 1288 |
+
new_verts[:, 2] = (vertices[:, 2] - np.amin(vertices[:, 2]))/(np.amax(vertices[:, 1]) - np.amin(vertices[:, 2]))*0.5
|
| 1289 |
+
return new_verts
|
| 1290 |
+
|
| 1291 |
+
def opt_rot_surf(surf_1, surf_2, areas_c):
|
| 1292 |
+
center_1, _ = centroid(surf_1)
|
| 1293 |
+
pts_c1 = surf_1.centers - center_1
|
| 1294 |
+
center_2, _ = centroid(surf_2)
|
| 1295 |
+
pts_c2 = surf_2.centers - center_2
|
| 1296 |
+
to_sum = pts_c1[:, :, np.newaxis] * pts_c2[:, np.newaxis, :]
|
| 1297 |
+
A = (to_sum * areas_c[:, :, np.newaxis]).sum(axis=0)
|
| 1298 |
+
#A = to_sum.sum(axis=0)
|
| 1299 |
+
u, _, v = np.linalg.svd(A)
|
| 1300 |
+
a = np.array([[1, 0, 0], [0, 1, 0], [0, 0, np.sign(np.linalg.det(A))]])
|
| 1301 |
+
O = u @ a @ v
|
| 1302 |
+
return O.T
|
| 1303 |
+
|
| 1304 |
+
def srnf(surf):
|
| 1305 |
+
areas = np.linalg.norm(surf.surfel, axis=-1, keepdims=True) + 1e-8
|
| 1306 |
+
return surf.surfel/np.sqrt(areas)
|
| 1307 |
+
|
| 1308 |
+
def fps_gpt(points, num_samples):
|
| 1309 |
+
"""
|
| 1310 |
+
Selects a subset of points using the Farthest Point Sampling algorithm.
|
| 1311 |
+
|
| 1312 |
+
:param points: numpy array of shape (N, D), where N is the number of points and D is the dimension of each point
|
| 1313 |
+
:param num_samples: number of points to select
|
| 1314 |
+
:return: numpy array of indices of the selected points
|
| 1315 |
+
"""
|
| 1316 |
+
if num_samples > len(points):
|
| 1317 |
+
raise ValueError("num_samples must be less than or equal to the number of points")
|
| 1318 |
+
|
| 1319 |
+
# Initialize an array to hold the indices of the selected points
|
| 1320 |
+
selected_indices = np.zeros(num_samples, dtype=int)
|
| 1321 |
+
|
| 1322 |
+
center = np.mean(points, axis=0)
|
| 1323 |
+
|
| 1324 |
+
# Array to store the shortest distance of each point to a selected poin
|
| 1325 |
+
shortest_distances = np.full(len(points), np.inf)
|
| 1326 |
+
last_selected_point = center
|
| 1327 |
+
for i in range(1, num_samples):
|
| 1328 |
+
|
| 1329 |
+
|
| 1330 |
+
# Update the shortest distance of each point to the selected points
|
| 1331 |
+
for j, point in enumerate(points):
|
| 1332 |
+
distance = np.linalg.norm(point - last_selected_point)
|
| 1333 |
+
shortest_distances[j] = min(shortest_distances[j], distance)
|
| 1334 |
+
|
| 1335 |
+
# Select the point that is farthest from all selected points
|
| 1336 |
+
selected_indices[i] = np.argmax(shortest_distances)
|
| 1337 |
+
last_selected_point = points[selected_indices[i]]
|
| 1338 |
+
return selected_indices
|
| 1339 |
+
|
| 1340 |
+
|
| 1341 |
+
def fps_gpt_geod(points, faces, num_samples, index_start):
|
| 1342 |
+
"""
|
| 1343 |
+
Selects a subset of points using the Farthest Point Sampling algorithm.
|
| 1344 |
+
|
| 1345 |
+
:param points: numpy array of shape (N, D), where N is the number of points and D is the dimension of each point
|
| 1346 |
+
:param num_samples: number of points to select
|
| 1347 |
+
:return: numpy array of indices of the selected points
|
| 1348 |
+
"""
|
| 1349 |
+
if num_samples > len(points):
|
| 1350 |
+
raise ValueError("num_samples must be less than or equal to the number of points")
|
| 1351 |
+
|
| 1352 |
+
# Initialize an array to hold the indices of the selected points
|
| 1353 |
+
selected_indices = np.zeros(num_samples, dtype=int)
|
| 1354 |
+
|
| 1355 |
+
# Array to store the shortest distance of each point to a selected poin
|
| 1356 |
+
shortest_distances = np.full(len(points), np.inf)
|
| 1357 |
+
selected_indices[0] = index_start
|
| 1358 |
+
distances = np.zeros((num_samples, len(points)))
|
| 1359 |
+
solver = pp3d.MeshHeatMethodDistanceSolver(points, faces)
|
| 1360 |
+
for i in range(1, num_samples):
|
| 1361 |
+
distances[i] = (solver.compute_distance(selected_indices[i-1]))
|
| 1362 |
+
|
| 1363 |
+
# Update the shortest distance of each point to the selected points
|
| 1364 |
+
for j, point in enumerate(points):
|
| 1365 |
+
distance = distances[i][j]
|
| 1366 |
+
shortest_distances[j] = min(shortest_distances[j], distance)
|
| 1367 |
+
|
| 1368 |
+
# Select the point that is farthest from all selected points
|
| 1369 |
+
selected_indices[i] = np.argmax(shortest_distances)
|
| 1370 |
+
return selected_indices
|
| 1371 |
+
|
| 1372 |
+
# class MultiSurface:
|
| 1373 |
+
# def __init__(self, pattern):
|
| 1374 |
+
# self.surf = []
|
| 1375 |
+
# files = glob.glob(pattern)
|
| 1376 |
+
# for f in files:
|
| 1377 |
+
# self.surf.append(Surface(filename=f))
|
utils/torch_fmap.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from pykeops.torch import LazyTensor
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def euclidean_dist(x, y):
|
| 7 |
+
"""
|
| 8 |
+
Args:
|
| 9 |
+
x: pytorch Variable, with shape [m, d]
|
| 10 |
+
y: pytorch Variable, with shape [n, d]
|
| 11 |
+
Returns:
|
| 12 |
+
dist: pytorch Variable, with shape [m, n]
|
| 13 |
+
"""
|
| 14 |
+
#bs, m, n = x.size(0), x.size(1), y.size(1)
|
| 15 |
+
xx = torch.pow(x.squeeze(), 2).sum(1, keepdim=True)
|
| 16 |
+
yy = torch.pow(y.squeeze(), 2).sum(1, keepdim=True).t()
|
| 17 |
+
dist = xx + yy - 2 * torch.inner(x.squeeze(), y.squeeze())
|
| 18 |
+
dist = dist.clamp(min=1e-12).sqrt()
|
| 19 |
+
return dist
|
| 20 |
+
|
| 21 |
+
def knnsearch(x, y, alpha=1./0.07, prod=False):
|
| 22 |
+
if prod:
|
| 23 |
+
prods = torch.inner(x.squeeze(), y.squeeze())#/( torch.norm(x.squeeze(), dim=-1)[:, None]*torch.norm(y.squeeze(), dim=-1)[None, :])
|
| 24 |
+
output = F.softmax(alpha*prods, dim=1)
|
| 25 |
+
else:
|
| 26 |
+
distance = euclidean_dist(x, y[None,:])
|
| 27 |
+
output = F.softmax(-alpha*distance, dim=1)
|
| 28 |
+
return output.squeeze()
|
| 29 |
+
|
| 30 |
+
def extract_p2p_torch(reps_shape, reps_template):
|
| 31 |
+
n_ev = reps_shape.shape[-1]
|
| 32 |
+
with torch.no_grad():
|
| 33 |
+
# print((evecs0_dzo @ fmap01_final.squeeze().T).shape)
|
| 34 |
+
# print(evecs1_dzo.shape)
|
| 35 |
+
reps_shape_torch = torch.from_numpy(reps_shape).float().cuda()
|
| 36 |
+
G_i = LazyTensor(reps_shape_torch[:, None, :].contiguous()) # (M**2, 1, 2)
|
| 37 |
+
reps_template_torch = torch.from_numpy(reps_template).float().cuda()
|
| 38 |
+
X_j = LazyTensor(reps_template_torch[None, :, :n_ev].contiguous()) # (1, N, 2)
|
| 39 |
+
D_ij = ((G_i - X_j) ** 2).sum(-1) # (M**2, N) symbolic matrix of squared distances
|
| 40 |
+
indKNN = D_ij.argKmin(1, dim=0).squeeze() # Grid <-> Samples, (M**2, K) integer tensor
|
| 41 |
+
# pmap10_ref = FM_to_p2p(fmap01_final.detach().squeeze().cpu().numpy(), s_dict['evecs'], template_dict['evecs'])
|
| 42 |
+
# print(indKNN[:10], pmap10_ref[:10])
|
| 43 |
+
indKNN_2 = D_ij.argKmin(1, dim=1).squeeze()
|
| 44 |
+
return indKNN.detach().cpu().numpy(), indKNN_2.detach().cpu().numpy()
|
| 45 |
+
|
| 46 |
+
def extract_p2p_torch_fmap(fmap_shape_template, evecs_shape, evecs_template):
|
| 47 |
+
n_ev = fmap_shape_template.shape[-1]
|
| 48 |
+
with torch.no_grad():
|
| 49 |
+
# print((evecs0_dzo @ fmap01_final.squeeze().T).shape)
|
| 50 |
+
# print(evecs1_dzo.shape)
|
| 51 |
+
G_i = LazyTensor((evecs_shape[:, :n_ev] @ fmap_shape_template.squeeze().T)[:, None, :].contiguous()) # (M**2, 1, 2)
|
| 52 |
+
X_j = LazyTensor(evecs_template[None, :, :n_ev].contiguous()) # (1, N, 2)
|
| 53 |
+
D_ij = ((G_i - X_j) ** 2).sum(-1) # (M**2, N) symbolic matrix of squared distances
|
| 54 |
+
indKNN = D_ij.argKmin(1, dim=0).squeeze() # Grid <-> Samples, (M**2, K) integer tensor
|
| 55 |
+
# pmap10_ref = FM_to_p2p(fmap01_final.detach().squeeze().cpu().numpy(), s_dict['evecs'], template_dict['evecs'])
|
| 56 |
+
# print(indKNN[:10], pmap10_ref[:10])
|
| 57 |
+
indKNN_2 = D_ij.argKmin(1, dim=1).squeeze()
|
| 58 |
+
return indKNN.detach().cpu().numpy(), indKNN_2.detach().cpu().numpy()
|
| 59 |
+
|
| 60 |
+
def wlstsq(A, B, w):
|
| 61 |
+
if w is None:
|
| 62 |
+
return torch.linalg.lstsq(A, B).solution
|
| 63 |
+
else:
|
| 64 |
+
assert w.dim() + 1 == A.dim() and w.shape[-1] == A.shape[-2]
|
| 65 |
+
W = torch.diag_embed(w)
|
| 66 |
+
return torch.linalg.lstsq(W @ A, W @ B).solution
|
| 67 |
+
|
| 68 |
+
def torch_zoomout(evecs0, evecs1, evecs_1_trans, fmap01, target_size, step=1):
|
| 69 |
+
assert fmap01.shape[-2] == fmap01.shape[-1], f"square fmap needed, got {fmap01.shape[-2]} and {fmap01.shape[-1]}"
|
| 70 |
+
fs = fmap01.shape[0]
|
| 71 |
+
for i in range(fs, target_size+1, step):
|
| 72 |
+
indKNN, _ = extract_p2p_torch_fmap(fmap01, evecs0, evecs1)
|
| 73 |
+
#fmap01 = wlstsq(evecs1[..., :i], evecs0[indKNN, :i], None)
|
| 74 |
+
fmap01 = evecs_1_trans[:i, :] @ evecs0[indKNN, :i]
|
| 75 |
+
if fmap01.shape[0] < target_size:
|
| 76 |
+
fmap01 = evecs_1_trans[:target_size, :] @ evecs0[indKNN, :target_size]
|
| 77 |
+
return fmap01
|
utils/utils_func.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import scipy.sparse as sp
|
| 4 |
+
import shutil
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import re
|
| 9 |
+
import os
|
| 10 |
+
import requests
|
| 11 |
+
|
| 12 |
+
def ensure_pretrained_file(hf_url: str, save_dir: str = "pretrained", filename: str = "pretrained.pkl", token: str = None):
|
| 13 |
+
"""
|
| 14 |
+
Ensure that a pretrained file exists locally.
|
| 15 |
+
If the folder is empty, download from Hugging Face.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
hf_url (str): Hugging Face file URL (resolve/main/...).
|
| 19 |
+
save_dir (str): Directory to store pretrained file.
|
| 20 |
+
filename (str): Name of file to save.
|
| 21 |
+
token (str): Optional Hugging Face token (for gated/private repos).
|
| 22 |
+
"""
|
| 23 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 24 |
+
save_path = os.path.join(save_dir, filename)
|
| 25 |
+
|
| 26 |
+
if os.path.exists(save_path):
|
| 27 |
+
print(f"✅ Found pretrained file: {save_path}")
|
| 28 |
+
return save_path
|
| 29 |
+
|
| 30 |
+
headers = {"Authorization": f"Bearer {token}"} if token else {}
|
| 31 |
+
|
| 32 |
+
print(f"⬇️ Downloading pretrained file from {hf_url} to {save_path} ...")
|
| 33 |
+
response = requests.get(hf_url, headers=headers, stream=True)
|
| 34 |
+
response.raise_for_status()
|
| 35 |
+
|
| 36 |
+
with open(save_path, "wb") as f:
|
| 37 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 38 |
+
f.write(chunk)
|
| 39 |
+
|
| 40 |
+
print("✅ Download complete.")
|
| 41 |
+
return save_path
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def may_create_folder(folder_path):
|
| 45 |
+
if not osp.exists(folder_path):
|
| 46 |
+
oldmask = os.umask(000)
|
| 47 |
+
os.makedirs(folder_path, mode=0o777)
|
| 48 |
+
os.umask(oldmask)
|
| 49 |
+
return True
|
| 50 |
+
return False
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def make_clean_folder(folder_path):
|
| 54 |
+
success = may_create_folder(folder_path)
|
| 55 |
+
if not success:
|
| 56 |
+
shutil.rmtree(folder_path)
|
| 57 |
+
may_create_folder(folder_path)
|
| 58 |
+
|
| 59 |
+
def str_delta(delta):
|
| 60 |
+
s = delta.total_seconds()
|
| 61 |
+
hours, remainder = divmod(s, 3600)
|
| 62 |
+
minutes, seconds = divmod(remainder, 60)
|
| 63 |
+
return '{:02}:{:02}:{:02}'.format(int(hours), int(minutes), int(seconds))
|
| 64 |
+
|
| 65 |
+
def desc_from_config(config, world_size):
|
| 66 |
+
cond_str = 'uncond' #'cond' if c.dataset_kwargs.use_labels else 'uncond'
|
| 67 |
+
dtype_str = 'fp16' if config["perfs"]["fp16"] else 'fp32'
|
| 68 |
+
name_data = config["data"]["name"]
|
| 69 |
+
if "abs" in config["data"]:
|
| 70 |
+
name_data += "abs" if config["data"]["abs"] else ""
|
| 71 |
+
desc = (f'{name_data:s}-{cond_str:s}-{config["architecture"]["model"]:s}-'
|
| 72 |
+
f'gpus{world_size:d}-batch{config["hyper_params"]["batch_size"]:d}-{dtype_str:s}')
|
| 73 |
+
if config["misc"]["precond"]:
|
| 74 |
+
desc += f'-{config["misc"]["precond"]:s}'
|
| 75 |
+
|
| 76 |
+
if "desc" in config["misc"]:
|
| 77 |
+
if config["misc"]["desc"] is not None:
|
| 78 |
+
desc += f'-{config["misc"]["desc"]}'
|
| 79 |
+
return desc
|
| 80 |
+
|
| 81 |
+
def get_dataset_path(config):
|
| 82 |
+
name_exp = config["data"]["name"]
|
| 83 |
+
if "add_name" in config:
|
| 84 |
+
if config["add_name"]["do"]:
|
| 85 |
+
name_exp += "_" + config["add_name"]["name"]
|
| 86 |
+
dataset_path = os.path.join(config["data"]["root_dir"], name_exp)
|
| 87 |
+
return name_exp, dataset_path
|
| 88 |
+
|
| 89 |
+
def sparse_np_to_torch(A):
|
| 90 |
+
Acoo = A.tocoo()
|
| 91 |
+
values = Acoo.data
|
| 92 |
+
indices = np.vstack((Acoo.row, Acoo.col))
|
| 93 |
+
shape = Acoo.shape
|
| 94 |
+
return torch.sparse_coo_tensor(torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape), dtype=torch.float32).coalesce()
|
| 95 |
+
|
| 96 |
+
def convert_dict(np_dict, device="cpu"):
|
| 97 |
+
torch_dict = {}
|
| 98 |
+
for k, value in np_dict.items():
|
| 99 |
+
if sp.issparse(value):
|
| 100 |
+
torch_dict[k] = sparse_np_to_torch(value).to(device)
|
| 101 |
+
if torch_dict[k].dtype == torch.int32:
|
| 102 |
+
torch_dict[k] = torch_dict[k].long().to(device)
|
| 103 |
+
elif isinstance(value, np.ndarray):
|
| 104 |
+
torch_dict[k] = torch.from_numpy(value).to(device)
|
| 105 |
+
if torch_dict[k].dtype == torch.int32:
|
| 106 |
+
torch_dict[k] = torch_dict[k].squeeze().long().to(device)
|
| 107 |
+
else:
|
| 108 |
+
torch_dict[k] = value
|
| 109 |
+
return torch_dict
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def convert_dict_torch(in_dict, device="cpu"):
|
| 113 |
+
torch_dict = {}
|
| 114 |
+
for k, value in in_dict.items():
|
| 115 |
+
if isinstance(value, torch.Tensor):
|
| 116 |
+
torch_dict[k] = in_dict[k].to(device)
|
| 117 |
+
return torch_dict
|
| 118 |
+
|
| 119 |
+
def batchify_dict(torch_dict):
|
| 120 |
+
for k, value in torch_dict.items():
|
| 121 |
+
if isinstance(value, torch.Tensor):
|
| 122 |
+
if torch_dict[k].dtype != torch.int64:
|
| 123 |
+
torch_dict[k] = torch_dict[k].unsqueeze(0)
|
utils/utils_legacy.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import hashlib
|
| 7 |
+
import numpy as np
|
| 8 |
+
import scipy
|
| 9 |
+
|
| 10 |
+
def read_lines(file_path):
|
| 11 |
+
with open(file_path, 'r') as fin:
|
| 12 |
+
lines = [line.strip() for line in fin.readlines() if len(line.strip()) > 0]
|
| 13 |
+
return lines
|
| 14 |
+
|
| 15 |
+
# == Pytorch things
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def toNP(x):
|
| 19 |
+
"""
|
| 20 |
+
Really, definitely convert a torch tensor to a numpy array
|
| 21 |
+
"""
|
| 22 |
+
return x.detach().to(torch.device('cpu')).numpy()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def label_smoothing_log_loss(pred, labels, smoothing=0.0):
|
| 26 |
+
n_class = pred.shape[-1]
|
| 27 |
+
one_hot = torch.zeros_like(pred)
|
| 28 |
+
one_hot[labels] = 1.
|
| 29 |
+
one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
|
| 30 |
+
loss = -(one_hot * pred).sum(dim=-1).mean()
|
| 31 |
+
return loss
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Randomly rotate points.
|
| 35 |
+
# Torch in, torch out
|
| 36 |
+
# Note fornow, builds rotation matrix on CPU.
|
| 37 |
+
def random_rotate_points(pts, randgen=None):
|
| 38 |
+
R = random_rotation_matrix(randgen)
|
| 39 |
+
R = torch.from_numpy(R).to(device=pts.device, dtype=pts.dtype)
|
| 40 |
+
return torch.matmul(pts, R)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def random_rotate_points_y(pts):
|
| 44 |
+
angles = torch.rand(1, device=pts.device, dtype=pts.dtype) * (2. * np.pi)
|
| 45 |
+
rot_mats = torch.zeros(3, 3, device=pts.device, dtype=pts.dtype)
|
| 46 |
+
rot_mats[0, 0] = torch.cos(angles)
|
| 47 |
+
rot_mats[0, 2] = torch.sin(angles)
|
| 48 |
+
rot_mats[2, 0] = -torch.sin(angles)
|
| 49 |
+
rot_mats[2, 2] = torch.cos(angles)
|
| 50 |
+
rot_mats[1, 1] = 1.
|
| 51 |
+
|
| 52 |
+
pts = torch.matmul(pts, rot_mats)
|
| 53 |
+
return pts
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Numpy things
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Numpy sparse matrix to pytorch
|
| 60 |
+
def sparse_np_to_torch(A):
|
| 61 |
+
Acoo = A.tocoo()
|
| 62 |
+
values = Acoo.data
|
| 63 |
+
indices = np.vstack((Acoo.row, Acoo.col))
|
| 64 |
+
shape = Acoo.shape
|
| 65 |
+
return torch.sparse_coo_tensor(torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape), dtype=torch.float32).coalesce()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# Pytorch sparse to numpy csc matrix
|
| 69 |
+
def sparse_torch_to_np(A):
|
| 70 |
+
if len(A.shape) != 2:
|
| 71 |
+
raise RuntimeError("should be a matrix-shaped type; dim is : " + str(A.shape))
|
| 72 |
+
|
| 73 |
+
indices = toNP(A.indices())
|
| 74 |
+
values = toNP(A.values())
|
| 75 |
+
|
| 76 |
+
mat = scipy.sparse.coo_matrix((values, indices), shape=A.shape).tocsc()
|
| 77 |
+
|
| 78 |
+
return mat
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Hash a list of numpy arrays
|
| 82 |
+
def hash_arrays(arrs):
|
| 83 |
+
running_hash = hashlib.sha1()
|
| 84 |
+
for arr in arrs:
|
| 85 |
+
binarr = arr.view(np.uint8)
|
| 86 |
+
running_hash.update(binarr)
|
| 87 |
+
return running_hash.hexdigest()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def random_rotation_matrix(randgen=None):
|
| 91 |
+
"""
|
| 92 |
+
Creates a random rotation matrix.
|
| 93 |
+
randgen: if given, a np.random.RandomState instance used for random numbers (for reproducibility)
|
| 94 |
+
"""
|
| 95 |
+
# adapted from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c
|
| 96 |
+
|
| 97 |
+
if randgen is None:
|
| 98 |
+
randgen = np.random.RandomState()
|
| 99 |
+
|
| 100 |
+
theta, phi, z = tuple(randgen.rand(3).tolist())
|
| 101 |
+
|
| 102 |
+
theta = theta * 2.0 * np.pi # Rotation about the pole (Z).
|
| 103 |
+
phi = phi * 2.0 * np.pi # For direction of pole deflection.
|
| 104 |
+
z = z * 2.0 # For magnitude of pole deflection.
|
| 105 |
+
|
| 106 |
+
# Compute a vector V used for distributing points over the sphere
|
| 107 |
+
# via the reflection I - V Transpose(V). This formulation of V
|
| 108 |
+
# will guarantee that if x[1] and x[2] are uniformly distributed,
|
| 109 |
+
# the reflected points will be uniform on the sphere. Note that V
|
| 110 |
+
# has length sqrt(2) to eliminate the 2 in the Householder matrix.
|
| 111 |
+
|
| 112 |
+
r = np.sqrt(z)
|
| 113 |
+
Vx, Vy, Vz = V = (np.sin(phi) * r, np.cos(phi) * r, np.sqrt(2.0 - z))
|
| 114 |
+
|
| 115 |
+
st = np.sin(theta)
|
| 116 |
+
ct = np.cos(theta)
|
| 117 |
+
|
| 118 |
+
R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1)))
|
| 119 |
+
# Construct the rotation matrix ( V Transpose(V) - I ) R.
|
| 120 |
+
|
| 121 |
+
M = (np.outer(V, V) - np.eye(3)).dot(R)
|
| 122 |
+
return M
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Python string/file utilities
|
| 126 |
+
def ensure_dir_exists(d):
|
| 127 |
+
if not os.path.exists(d):
|
| 128 |
+
os.makedirs(d)
|
| 129 |
+
|
| 130 |
+
|
zero_shot.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import scipy
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import random
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import time
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
import importlib
|
| 12 |
+
import json
|
| 13 |
+
import argparse
|
| 14 |
+
from omegaconf import OmegaConf
|
| 15 |
+
from snk.loss import PrismRegularizationLoss
|
| 16 |
+
from snk.prism_decoder import PrismDecoder
|
| 17 |
+
from shape_models.fmap import DFMNet
|
| 18 |
+
from shape_models.encoder import Encoder
|
| 19 |
+
from diffu_models.losses import VELoss, VPLoss, EDMLoss
|
| 20 |
+
from diffu_models.sds import guidance_grad
|
| 21 |
+
from utils.torch_fmap import torch_zoomout, knnsearch, extract_p2p_torch_fmap
|
| 22 |
+
from utils.utils_func import convert_dict, str_delta, ensure_pretrained_file
|
| 23 |
+
from utils.eval import accuracy
|
| 24 |
+
from utils.mesh import save_ply, load_mesh
|
| 25 |
+
from shape_data import get_data_dirs
|
| 26 |
+
from utils.pickle_stuff import safe_load_with_fallback
|
| 27 |
+
from utils.geometry import compute_operators, load_operators
|
| 28 |
+
from utils.surfaces import Surface
|
| 29 |
+
import sys
|
| 30 |
+
try:
|
| 31 |
+
import google.colab
|
| 32 |
+
print("Running Colab")
|
| 33 |
+
from tqdm import tqdm
|
| 34 |
+
except ImportError:
|
| 35 |
+
print("Running local")
|
| 36 |
+
from tqdm.auto import tqdm
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def seed_everything(seed=42):
|
| 40 |
+
random.seed(seed)
|
| 41 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 42 |
+
np.random.seed(seed)
|
| 43 |
+
torch.manual_seed(seed)
|
| 44 |
+
torch.backends.cudnn.deterministic = True
|
| 45 |
+
torch.backends.cudnn.benchmark = False
|
| 46 |
+
|
| 47 |
+
seed_everything()
|
| 48 |
+
|
| 49 |
+
class Tee:
|
| 50 |
+
def __init__(self, *outputs):
|
| 51 |
+
self.outputs = outputs
|
| 52 |
+
|
| 53 |
+
def write(self, message):
|
| 54 |
+
for output in self.outputs:
|
| 55 |
+
output.write(message)
|
| 56 |
+
output.flush() # ensure it's written immediately
|
| 57 |
+
|
| 58 |
+
def flush(self):
|
| 59 |
+
for output in self.outputs:
|
| 60 |
+
output.flush()
|
| 61 |
+
|
| 62 |
+
class DiffModel:
|
| 63 |
+
|
| 64 |
+
def __init__(self, cfg, device="cuda:0"):
|
| 65 |
+
if cfg["train_dir"] == "pretrained":
|
| 66 |
+
url = "https://huggingface.co/daidedou/diffumatch_model/resolve/main/network-snapshot-041216.pkl"
|
| 67 |
+
network_pkl = ensure_pretrained_file(url, "pretrained")
|
| 68 |
+
url_json = "https://huggingface.co/daidedou/diffumatch_model/resolve/main/training_options.json"
|
| 69 |
+
json_filename = ensure_pretrained_file(url_json, "pretrained", filename="training_options.json")
|
| 70 |
+
train_cfg = json.load(open(json_filename))
|
| 71 |
+
else:
|
| 72 |
+
num_exp = cfg["diff_num_exp"]
|
| 73 |
+
files = os.listdir(cfg["train_dir"])
|
| 74 |
+
for file in files:
|
| 75 |
+
if file[:5] == f"{num_exp:05d}":
|
| 76 |
+
netdir = os.path.join(cfg["train_dir"], file)
|
| 77 |
+
train_cfg = json.load(open(os.path.join(netdir, "training_options.json")))
|
| 78 |
+
pkls = [f for f in os.listdir(netdir) if ".pkl" in f]
|
| 79 |
+
nice_pkls = sorted(pkls, key=lambda x: int(x.split(".")[0].split("-")[-1]))
|
| 80 |
+
chosen_pkl = nice_pkls[-1]
|
| 81 |
+
network_pkl = os.path.join(netdir, chosen_pkl)
|
| 82 |
+
print(f'Loading network from "{network_pkl}"...')
|
| 83 |
+
self.net = safe_load_with_fallback(network_pkl)['ema'].to(device)
|
| 84 |
+
|
| 85 |
+
print('Done!')
|
| 86 |
+
loss_name = train_cfg['hyper_params']['loss_name']
|
| 87 |
+
self.loss_sde = None
|
| 88 |
+
if loss_name == "EDMLoss":
|
| 89 |
+
self.loss_sde = EDMLoss()
|
| 90 |
+
elif loss_name == "VPLoss":
|
| 91 |
+
self.loss_sde = VPLoss()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Matcher(object):
|
| 95 |
+
|
| 96 |
+
def __init__(self, cfg):
|
| 97 |
+
self.cfg = cfg
|
| 98 |
+
self.device = torch.device(f'cuda:{cfg["gpu"]}' if torch.cuda.is_available() else 'cpu')
|
| 99 |
+
self.diffusion_model = None
|
| 100 |
+
if self.cfg.get("sds", False):
|
| 101 |
+
self.diffusion_model = DiffModel(cfg["sds_conf"])
|
| 102 |
+
self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
|
| 103 |
+
self.n_loop = 0
|
| 104 |
+
if self.cfg.get("optimize", False):
|
| 105 |
+
self.n_loop = self.cfg.opt.get("n_loop", 0)
|
| 106 |
+
self.snk = self.cfg.get("snk", False)
|
| 107 |
+
self.fmap_cfg = self.cfg.deepfeat_conf.fmap
|
| 108 |
+
self.dataloaders = dict()
|
| 109 |
+
|
| 110 |
+
def reconf(self, cfg):
|
| 111 |
+
self.cfg = cfg
|
| 112 |
+
self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
|
| 113 |
+
self.n_loop = 0
|
| 114 |
+
if self.cfg.get("optimize", False):
|
| 115 |
+
self.n_loop = self.cfg.opt.get("n_loop", 0)
|
| 116 |
+
self.fmap_cfg = self.cfg.deepfeat_conf.fmap
|
| 117 |
+
self.dataloaders = dict()
|
| 118 |
+
|
| 119 |
+
def _init(self):
|
| 120 |
+
cfg = self.cfg
|
| 121 |
+
self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).to(self.device)
|
| 122 |
+
if self.snk:
|
| 123 |
+
self.encoder = Encoder().to(self.device)
|
| 124 |
+
self.decoder = PrismDecoder(dim_in=515).to(self.device)
|
| 125 |
+
self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
|
| 126 |
+
self.soft_p2p = True
|
| 127 |
+
params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(self.decoder.parameters())
|
| 128 |
+
else:
|
| 129 |
+
params_to_opt = self.fmap_model.parameters()
|
| 130 |
+
self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
|
| 131 |
+
self.eye = torch.eye(self.n_fmap).float().to(self.device)
|
| 132 |
+
self.eye.requires_grad = False
|
| 133 |
+
|
| 134 |
+
def fmap(self, shape_dict, target_dict):
|
| 135 |
+
if self.fmap_cfg.get("use_diff", False):
|
| 136 |
+
C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model({"shape1": shape_dict, "shape2": target_dict}, diff_model=self.diffusion_model, scale=self.fmap_cfg.diffusion.time)
|
| 137 |
+
C12_pred, C12_obj, mask_12 = C12_pred
|
| 138 |
+
C21_pred, C21_obj, mask_21 = C21_pred
|
| 139 |
+
else:
|
| 140 |
+
C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model({"shape1": shape_dict, "shape2": target_dict})
|
| 141 |
+
C12_obj, C21_obj = C12_pred, C21_pred
|
| 142 |
+
mask_12, mask_21 = None, None
|
| 143 |
+
return C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, mask_12, mask_21
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def zo_shot(self, shape_dict, target_dict):
|
| 147 |
+
self._init()
|
| 148 |
+
evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
|
| 149 |
+
_, C12_mask_init, _, _, _, _, _ , _, _, _ = self.fmap(shape_dict, target_dict)
|
| 150 |
+
evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
|
| 151 |
+
new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_mask_init.squeeze(), self.cfg["zo_shot"])
|
| 152 |
+
indKNN_new, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
|
| 153 |
+
return new_FM, indKNN_new
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def optimize(self, shape_dict, target_dict, target_normals):
|
| 157 |
+
self._init()
|
| 158 |
+
evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
|
| 159 |
+
C12_pred_init, _, _, _ , _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
|
| 160 |
+
evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
|
| 161 |
+
evecs_1trans = evecs1.t() @ torch.diag(shape_dict["mass"])
|
| 162 |
+
n_verts_target = target_dict["vertices"].shape[-2]
|
| 163 |
+
|
| 164 |
+
loss_save = {"cycle": [], "fmap": [], "mse": [], "prism": [], "bij": [], "ortho": [], "sds": [], "lap": [], "proper": []}
|
| 165 |
+
snk_rec = None
|
| 166 |
+
for i in tqdm(range(self.n_loop), "Optimizing matching " + shape_dict['name'] + " " + target_dict['name']):
|
| 167 |
+
C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
|
| 168 |
+
if self.cfg.opt.soft_p2p:
|
| 169 |
+
### A la SNK
|
| 170 |
+
## P2P 2 -> 1
|
| 171 |
+
soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_pred.squeeze(), evecs1[:, :self.n_fmap], prod=True)
|
| 172 |
+
C12_new = evecs_trans2[:self.n_fmap, :] @ soft_p2p_21 @ evecs1[:, :self.n_fmap]
|
| 173 |
+
soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_new.squeeze(), evecs1[:, :self.n_fmap], prod=True)
|
| 174 |
+
|
| 175 |
+
## P2P 1 -> 2
|
| 176 |
+
soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_pred.squeeze(), evecs2[:, :self.n_fmap], prod=True)
|
| 177 |
+
C21_new = evecs_trans1[:self.n_fmap, :] @ soft_p2p_12 @ evecs2[:, :self.n_fmap]
|
| 178 |
+
soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_new.squeeze(), evecs2[:, :self.n_fmap], prod=True)
|
| 179 |
+
|
| 180 |
+
l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(dim=-1).mean()
|
| 181 |
+
else:
|
| 182 |
+
C12_new, C21_new = C12_pred, C21_pred
|
| 183 |
+
|
| 184 |
+
l_ortho = ((C12_new.squeeze() @ C12_new.squeeze().T - self.eye)**2).mean() + ((C21_new.squeeze() @ C21_new.squeeze().T - self.eye)**2).mean()
|
| 185 |
+
l_bij = ((C12_new.squeeze() @ C21_new.squeeze() - self.eye)**2).mean() + ((C21_new.squeeze() @ C12_new.squeeze() - self.eye)**2).mean()
|
| 186 |
+
l_lap = ((C12_new @ torch.diag(shape_dict["evals"][:self.n_fmap]) - torch.diag(target_dict["evals"][:self.n_fmap]) @ C12_new)**2).mean()
|
| 187 |
+
l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(shape_dict["evals"][:self.n_fmap]) @ C21_new)**2).mean()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
|
| 191 |
+
if self.snk:
|
| 192 |
+
# Latent vector
|
| 193 |
+
latents = self.encoder(shape_dict)
|
| 194 |
+
latents_duplicate = latents[None, :].repeat(n_verts_target, 1)
|
| 195 |
+
|
| 196 |
+
# Prism decoder
|
| 197 |
+
feats_decode = torch.cat((target_dict["vertices"], latents_duplicate), dim=1)
|
| 198 |
+
snk_rec, prism, rots = self.decoder(target_dict, feats_decode)
|
| 199 |
+
l_prism = self.loss_prism(prism, rots, target_dict["vertices"], target_dict["faces"], target_normals)
|
| 200 |
+
l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec)**2).sum(dim=-1).mean()
|
| 201 |
+
l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(dim=-1).mean()
|
| 202 |
+
l_sds, l_proper = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
|
| 203 |
+
if self.fmap_cfg.get("use_diff", False):
|
| 204 |
+
if self.fmap_cfg.diffusion.get("abs", False):
|
| 205 |
+
C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
|
| 206 |
+
else:
|
| 207 |
+
C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
|
| 208 |
+
grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1, batch_size=self.fmap_cfg.diffusion.batch_sds,
|
| 209 |
+
scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
|
| 210 |
+
with torch.no_grad():
|
| 211 |
+
denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
|
| 212 |
+
targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
|
| 213 |
+
|
| 214 |
+
l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
|
| 215 |
+
|
| 216 |
+
grad_21, _ = guidance_grad(C21_in, self.diffusion_model.net, grad_scale=1, batch_size=self.fmap_cfg.diffusion.batch_sds,
|
| 217 |
+
scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
|
| 218 |
+
#denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
| 219 |
+
with torch.no_grad():
|
| 220 |
+
denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
| 221 |
+
targets_21 = torch_zoomout(evecs2, evecs1, evecs_1trans, C21_obj.squeeze(), self.cfg.sds_conf.zoomout)#, step=10)
|
| 222 |
+
l_proper_21 = ((C21_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_21.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
|
| 223 |
+
l_proper = l_proper_12 + l_proper_21
|
| 224 |
+
|
| 225 |
+
l_sds = ((torch.abs(C12_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_12.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
|
| 226 |
+
l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
|
| 227 |
+
loss = torch.as_tensor(0.).float().to(self.device)
|
| 228 |
+
if self.cfg.loss.get("ortho", 0) > 0:
|
| 229 |
+
loss += self.cfg.loss.get("ortho", 0) * l_ortho
|
| 230 |
+
if self.cfg.loss.get("bij", 0) > 0:
|
| 231 |
+
loss += self.cfg.loss.get("bij", 0) * l_bij
|
| 232 |
+
if self.cfg.loss.get("lap", 0) > 0:
|
| 233 |
+
loss += self.cfg.loss.get("lap", 0) * l_lap
|
| 234 |
+
if self.cfg.loss.get("cycle", 0) > 0:
|
| 235 |
+
loss += self.cfg.loss.get("cycle", 0) * l_cycle
|
| 236 |
+
if self.cfg.loss.get("mse_rec", 0) > 0:
|
| 237 |
+
loss += self.cfg.loss.get("mse_rec", 0) * l_mse
|
| 238 |
+
if self.cfg.loss.get("prism_rec", 0) > 0:
|
| 239 |
+
loss += self.cfg.loss.get("prism_rec", 0) * l_prism
|
| 240 |
+
if self.cfg.loss.get("sds", 0) > 0 and self.fmap_cfg.get("use_diff", False):
|
| 241 |
+
loss += self.cfg.loss.get("sds", 0) * l_sds
|
| 242 |
+
if self.cfg.loss.get("proper", 0) > 0 and self.fmap_cfg.get("use_diff", False):
|
| 243 |
+
loss += self.cfg.loss.get("proper", 0) * l_proper
|
| 244 |
+
|
| 245 |
+
loss.backward()
|
| 246 |
+
self.optim.step()
|
| 247 |
+
self.optim.zero_grad()
|
| 248 |
+
|
| 249 |
+
loss_save["cycle"].append(l_cycle.item())
|
| 250 |
+
loss_save["ortho"].append(l_ortho.item())
|
| 251 |
+
loss_save["bij"].append(l_bij.item())
|
| 252 |
+
loss_save["sds"].append(l_sds.item())
|
| 253 |
+
loss_save["proper"].append(l_proper.item())
|
| 254 |
+
loss_save["mse"].append(l_mse.item())
|
| 255 |
+
loss_save["prism"].append(l_prism.item())
|
| 256 |
+
indKNN_new_init, _ = extract_p2p_torch_fmap(C12_pred_init, evecs1, evecs2)
|
| 257 |
+
indKNN_new, _ = extract_p2p_torch_fmap(C12_new, evecs1, evecs2)
|
| 258 |
+
return C12_new, indKNN_new, indKNN_new_init, snk_rec, loss_save
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def match(self, pair_batch, output_pair, geod_path, refine=True, eval=False):
|
| 263 |
+
shape_dict, _, target_dict, _, target_normals, mapinfo = pair_batch
|
| 264 |
+
shape_dict_device = convert_dict(shape_dict, self.device)
|
| 265 |
+
target_dict_device = convert_dict(target_dict, self.device)
|
| 266 |
+
print(shape_dict_device["vertices"].device)
|
| 267 |
+
os.makedirs(output_pair, exist_ok=True)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
if self.cfg["optimize"]:
|
| 271 |
+
C12_new, p2p, p2p_init, snk_rec, loss_save = self.optimize(shape_dict_device, target_dict_device, target_normals.to(self.device))
|
| 272 |
+
np.save(os.path.join(output_pair, "p2p_init.npy"), p2p_init)
|
| 273 |
+
np.save(os.path.join(output_pair, "losses.npy"), loss_save)
|
| 274 |
+
else:
|
| 275 |
+
C12_new, p2p = self.zo_shot(shape_dict_device, target_dict_device)
|
| 276 |
+
snk_rec, loss_save = None, None
|
| 277 |
+
np.save(os.path.join(output_pair, "fmap.npy"), C12_new.detach().squeeze().cpu().numpy())
|
| 278 |
+
np.save(os.path.join(output_pair, "p2p.npy"), p2p)
|
| 279 |
+
if snk_rec is not None:
|
| 280 |
+
save_ply(os.path.join(output_pair, "rec.ply"), snk_rec.detach().squeeze().cpu().numpy(), target_dict["faces"])
|
| 281 |
+
|
| 282 |
+
if refine:
|
| 283 |
+
evecs1, evecs2 = shape_dict_device["evecs"], target_dict_device["evecs"]
|
| 284 |
+
evecs_2trans = evecs2.t() @ torch.diag(target_dict_device["mass"])
|
| 285 |
+
new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze(), 128)#, step=10)
|
| 286 |
+
p2p_refined_zo, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
|
| 287 |
+
np.save(os.path.join(output_pair, "p2p_zo.npy"), p2p)
|
| 288 |
+
if eval:
|
| 289 |
+
file_i, vts_1, vts_2 = mapinfo
|
| 290 |
+
mat_loaded = scipy.io.loadmat(os.path.join(geod_path, file_i + ".mat"))
|
| 291 |
+
A_geod, sqrt_area = mat_loaded['geod_dist'], np.sqrt(mat_loaded['areas_f'].sum())
|
| 292 |
+
_, dist = accuracy(p2p[vts_2], vts_1, A_geod,
|
| 293 |
+
sqrt_area=sqrt_area,
|
| 294 |
+
return_all=True)
|
| 295 |
+
if refine:
|
| 296 |
+
_, dist_zo = accuracy(p2p_refined_zo[vts_2], vts_1, A_geod,
|
| 297 |
+
sqrt_area=sqrt_area,
|
| 298 |
+
return_all=True)
|
| 299 |
+
np.savetxt(os.path.join(output_pair, "dists.txt"), (dist.mean(), dist_zo.mean()))
|
| 300 |
+
return p2p, p2p_refined_zo, loss_save, dist.mean(), dist_zo.mean()
|
| 301 |
+
return p2p, loss_save, dist.mean()
|
| 302 |
+
return p2p, loss_save
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _dataset_epoch(self, dataset, name_dataset, save_dir, data_dir):
|
| 308 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 309 |
+
# dloader = DataLoader(dataset, collate_fn=collate_default, batch_size=1)
|
| 310 |
+
num_pairs = len(dataset)
|
| 311 |
+
id_pair = 0
|
| 312 |
+
all_accs = []
|
| 313 |
+
all_accs_zo = []
|
| 314 |
+
t1 = datetime.now()
|
| 315 |
+
save_txt = os.path.join(save_dir, "log.txt")
|
| 316 |
+
# Open a file for writing
|
| 317 |
+
log_file = open(save_txt, 'w')
|
| 318 |
+
# Replace sys.stdout with Tee that writes to both console and file
|
| 319 |
+
sys.stdout = Tee(sys.__stdout__, log_file)
|
| 320 |
+
|
| 321 |
+
for batch in dset:
|
| 322 |
+
shape_dict, _, target_dict, _, _, _ = batch
|
| 323 |
+
print("Pair: " + shape_dict['name'] + " " + target_dict['name'])
|
| 324 |
+
name_exp = os.path.join(save_dir, shape_dict['name'], target_dict['name'])
|
| 325 |
+
if self.cfg.get("refine", False):
|
| 326 |
+
_, _, _, dist, dist_zo = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True, refine=True)
|
| 327 |
+
else:
|
| 328 |
+
_, _, dist = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True, refine=False)
|
| 329 |
+
delta = datetime.now() - t1
|
| 330 |
+
fm_delta = str_delta(delta)
|
| 331 |
+
remains = ((delta/(id_pair+1))*num_pairs) - delta
|
| 332 |
+
fm_remains = str_delta(remains)
|
| 333 |
+
all_accs.append(dist)
|
| 334 |
+
accs_mean = np.mean(all_accs)
|
| 335 |
+
if self.cfg.get("refine", False):
|
| 336 |
+
all_accs_zo.append(dist_zo)
|
| 337 |
+
accs_zo = np.mean(all_accs_zo)
|
| 338 |
+
print(f"error: {dist}, zo: {dist_zo}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, mean zo: {accs_zo}, full time: {fm_delta}, remains: {fm_remains}")
|
| 339 |
+
else:
|
| 340 |
+
print(f"error: {dist}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, full time: {fm_delta}, remains: {fm_remains}")
|
| 341 |
+
id_pair += 1
|
| 342 |
+
if self.cfg.get("refine", False):
|
| 343 |
+
print(f"mean error : {np.mean(all_accs)}, mean error refined: {np.mean(all_accs_zo)}")
|
| 344 |
+
else:
|
| 345 |
+
print(f"mean error : {np.mean(all_accs)}")
|
| 346 |
+
sys.stdout = sys.__stdout__
|
| 347 |
+
|
| 348 |
+
def load_data(self, file, num_evecs=200, make_cache=False, factor=None):
|
| 349 |
+
name = os.path.basename(os.path.splitext(file)[0])
|
| 350 |
+
cache_file = "single_" + name + ".npz"
|
| 351 |
+
verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True)
|
| 352 |
+
cache_path = os.path.join(self.cfg.cache, cache_file)
|
| 353 |
+
print("Cache is: ", cache_path)
|
| 354 |
+
if not os.path.exists(cache_path) or make_cache:
|
| 355 |
+
print("Computing operators ...")
|
| 356 |
+
compute_operators(verts_shape, faces, vnormals, num_evecs, cache_path, force_save=make_cache)
|
| 357 |
+
data_dict = load_operators(cache_path)
|
| 358 |
+
data_dict['name'] = name
|
| 359 |
+
data_dict_torch = convert_dict(data_dict, self.device)
|
| 360 |
+
#batchify_dict(data_dict_torch)
|
| 361 |
+
return data_dict_torch, area_shape
|
| 362 |
+
|
| 363 |
+
def match_files(self, file_shape, file_target):
|
| 364 |
+
batch_shape, _ = self.load_data(file_shape)
|
| 365 |
+
batch_target, _ = self.load_data(file_target)
|
| 366 |
+
target_surf = Surface(filename=file_target)
|
| 367 |
+
target_normals = torch.from_numpy(target_surf.surfel/np.linalg.norm(target_surf.surfel, axis=-1, keepdims=True)).float().to(self.device)
|
| 368 |
+
batch = batch_shape, None, batch_target, target_normals, None, None
|
| 369 |
+
output_folder = os.path.join(self.cfg.output, batch_shape["name"] + "_" + batch_shape["target"])
|
| 370 |
+
p2p, _ = self.match(batch, output_folder, None)
|
| 371 |
+
return batch_shape, batch_target, p2p
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
if __name__ == '__main__':
|
| 377 |
+
parser = argparse.ArgumentParser(description="Launch the SDS demo over datasets")
|
| 378 |
+
parser.add_argument('--dataset', type=str, default="SCAPE", help='name of the dataset')
|
| 379 |
+
parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
|
| 380 |
+
parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
|
| 381 |
+
parser.add_argument('--output', type=str, default="results", help="where to store experience results")
|
| 382 |
+
args = parser.parse_args()
|
| 383 |
+
|
| 384 |
+
arg_cfg = OmegaConf.from_dotlist(
|
| 385 |
+
[f"{k}={v}" for k, v in vars(args).items() if v is not None]
|
| 386 |
+
)
|
| 387 |
+
yaml_cfg = OmegaConf.load(args.config)
|
| 388 |
+
cfg = OmegaConf.merge(yaml_cfg, arg_cfg)
|
| 389 |
+
dataset_name = args.dataset.lower()
|
| 390 |
+
if cfg.get("oriented", False):
|
| 391 |
+
dataset_name += "_ori"
|
| 392 |
+
shape_cls = getattr(importlib.import_module(f'shape_data.{args.dataset.lower()}'), 'ShapeDataset')
|
| 393 |
+
pair_cls = getattr(importlib.import_module(f'shape_data.{args.dataset.lower()}'), 'ShapePairDataset')
|
| 394 |
+
data_dir, name_data_geo, corr_dir = get_data_dirs(args.datadir, dataset_name, 'test')
|
| 395 |
+
name_data_geo = "_".join(name_data_geo.split("_")[:2])
|
| 396 |
+
dset_shape = shape_cls(data_dir, "cache/fmaps", "test", oriented=cfg.get("oriented", False))
|
| 397 |
+
print("Preprocessing shapes done.")
|
| 398 |
+
dset = pair_cls(corr_dir, 'test', dset_shape, rotate=cfg.get("rotate", False))
|
| 399 |
+
exp_time = time.strftime('%y-%m-%d_%H-%M-%S')
|
| 400 |
+
output_logs = os.path.join(args.output, name_data_geo, exp_time)
|
| 401 |
+
matcher = Matcher(cfg)
|
| 402 |
+
matcher._dataset_epoch(dset, name_data_geo, output_logs, args.datadir)
|