daidedou commited on
Commit
458efe2
·
1 Parent(s): cf4ac70
app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple Gradio app for two-mesh initialization and run phases.
3
+ - Upload two meshes (.ply, .obj, .off)
4
+ - (Optional) upload a YAML config to override defaults
5
+ - Adjust a few numeric settings (sane ranges). Defaults pulled from the provided YAML when present.
6
+ - Click **Init** to generate "initialization maps" (here: position/normal-based vertex colors) for both meshes.
7
+ - Click **Run** to simulate an iterative evolution with a progress bar, then output another pair of colored meshes.
8
+
9
+ Replace the bodies of `make_initialization_maps` and `run_evolution` with your real pipeline as needed.
10
+
11
+ Tested with: gradio >= 4.0, trimesh, pyyaml, numpy.
12
+ """
13
+ from __future__ import annotations
14
+ import os
15
+ import io
16
+ import time
17
+ import json
18
+ import tempfile
19
+ from typing import Dict, Tuple, Optional
20
+ from omegaconf import OmegaConf
21
+ import gradio as gr
22
+ import numpy as np
23
+ import trimesh
24
+ import zero_shot
25
+ import yaml
26
+ from utils.surfaces import Surface
27
+ import notebook_helpers as helper
28
+ from utils.meshplot import visu_pts
29
+ from utils.torch_fmap import extract_p2p_torch_fmap, torch_zoomout
30
+ import torch
31
+ import argparse
32
+ # -----------------------------
33
+ # Utils
34
+ # -----------------------------
35
+ SUPPORTED_EXTS = {".ply", ".obj", ".off", ".stl", ".glb", ".gltf"}
36
+
37
+ def _safe_ext(path: str) -> str:
38
+ for ext in SUPPORTED_EXTS:
39
+ if path.lower().endswith(ext):
40
+ return ext
41
+ return os.path.splitext(path)[1].lower()
42
+
43
+
44
+ def normalize_vertices(vertices: np.ndarray) -> np.ndarray:
45
+ v = vertices.astype(np.float64)
46
+ v = v - v.mean(axis=0, keepdims=True)
47
+ scale = np.linalg.norm(v, axis=1).max()
48
+ if scale == 0:
49
+ scale = 1.0
50
+ v = v / scale
51
+ return v.astype(np.float32)
52
+
53
+
54
+
55
+ def ensure_vertex_colors(mesh: trimesh.Trimesh, colors: np.ndarray) -> trimesh.Trimesh:
56
+ out = mesh.copy()
57
+ if colors.shape[1] == 3:
58
+ rgba = np.concatenate([colors, 255*np.ones((colors.shape[0],1), dtype=np.uint8)], axis=1)
59
+ else:
60
+ rgba = colors
61
+ out.visual.vertex_colors = rgba
62
+ return out
63
+
64
+
65
+ def export_for_view(surf: Surface, colors: np.ndarray, basename: str, outdir: str) -> Tuple[str, str]:
66
+ """Export to PLY (with vertex colors) and GLB for Model3D preview."""
67
+ glb_path = os.path.join(outdir, f"{basename}.glb")
68
+ mesh = trimesh.Trimesh(surf.vertices, surf.faces, process=False)
69
+ colored_mesh = ensure_vertex_colors(mesh, colors)
70
+ colored_mesh.export(glb_path)
71
+ return glb_path
72
+
73
+
74
+ # -----------------------------
75
+ # Algorithm placeholders (replace with your real pipeline)
76
+ # -----------------------------
77
+ DEFAULT_SETTINGS = {
78
+ "deepfeat_conf.fmap.lambda_": 1,
79
+ "sds_conf.zoomout": 40.0,
80
+ "diffusion.time": 1.0,
81
+ "opt.n_loop": 300,
82
+ "loss.sds": 1.0,
83
+ "loss.proper": 1.0,
84
+ }
85
+
86
+ FLOAT_SLIDERS = {
87
+ # name: (min, max, step)
88
+ "deepfeat_conf.fmap.lambda_": (1e-3, 10.0, 1e-3),
89
+ "sds_conf.zoomout": (1e-3, 10.0, 1e-3),
90
+ "diffusion.time": (1e-3, 10.0, 1e-3),
91
+ "loss.sds": (1e-3, 10.0, 1e-3),
92
+ "loss.proper": (1e-3, 10.0, 1e-3),
93
+ }
94
+
95
+ INT_SLIDERS = {
96
+ "opt.n_loop": (1, 5000, 1),
97
+ }
98
+
99
+
100
+ def flatten_yaml_floats(d: Dict, prefix: str = "") -> Dict[str, float]:
101
+ flat = {}
102
+ for k, v in d.items():
103
+ key = f"{prefix}.{k}" if prefix else str(k)
104
+ if isinstance(v, dict):
105
+ flat.update(flatten_yaml_floats(v, key))
106
+ elif isinstance(v, (int, float)):
107
+ flat[key] = float(v)
108
+ return flat
109
+
110
+
111
+ def read_yaml_defaults(yaml_path: Optional[str]) -> Dict[str, float]:
112
+ if yaml_path and os.path.exists(yaml_path):
113
+ with open(yaml_path, "r") as f:
114
+ data = yaml.safe_load(f)
115
+ flat = flatten_yaml_floats(data)
116
+ # Only keep known keys we expose as controls
117
+ defaults = DEFAULT_SETTINGS.copy()
118
+ for k in list(DEFAULT_SETTINGS.keys()):
119
+ if k in flat:
120
+ defaults[k] = float(flat[k])
121
+ return defaults
122
+ return DEFAULT_SETTINGS.copy()
123
+
124
+
125
+
126
+
127
+ class Datadicts:
128
+ def __init__(self, shape_path, target_path):
129
+ self.shape_path = shape_path
130
+ basename_1 = os.path.basename(shape_path)
131
+ self.shape_dict, _ = helper.load_data(shape_path, "tmp/" + os.path.splitext(basename_1)[0]+".npz", "source", make_cache=True)
132
+ self.shape_surf = Surface(filename=shape_path)
133
+ self.target_path = target_path
134
+ basename_2 = os.path.basename(target_path)
135
+ self.target_dict, _ = helper.load_data(target_path, "tmp/" + os.path.splitext(basename_2)[0]+".npz", "target", make_cache=True)
136
+ self.target_surf = Surface(filename=target_path)
137
+ self.cmap1 = visu_pts(self.shape_surf)
138
+
139
+ # -----------------------------
140
+ # Gradio UI
141
+ # -----------------------------
142
+ TMP_ROOT = tempfile.mkdtemp(prefix="meshapp_")
143
+
144
+ def save_array_txt(arr):
145
+ # Create a temporary file with .txt suffix
146
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w") as f:
147
+ np.savetxt(f, arr.astype(int), fmt="%d") # save as text
148
+ return f.name
149
+
150
+ def build_outputs(surf_a: Surface, surf_b: Surface, cmap_a: np.ndarray, p2p: np.ndarray, tag: str) -> Tuple[str, str, str, str]:
151
+ outdir = os.path.join(TMP_ROOT, tag)
152
+ os.makedirs(outdir, exist_ok=True)
153
+ glb_a = export_for_view(surf_a, cmap_a, f"A_{tag}", outdir)
154
+ cmap_b = cmap_a[p2p]
155
+ glb_b = export_for_view(surf_b, cmap_b, f"B_{tag}", outdir)
156
+ out_file = save_array_txt(p2p)
157
+ return glb_a, glb_b, out_file
158
+
159
+
160
+ def init_clicked(mesh1_path, mesh2_path,
161
+ lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val):
162
+ cfg.deepfeat_conf.fmap.lambda_ = lambda_val
163
+ cfg.sds_conf.zoomout = zoomout_val
164
+ cfg.deepfeat_conf.fmap.diffusion.time = time_val
165
+ cfg.opt.n_loop = nloop_val
166
+ cfg.loss.sds = sds_val
167
+ cfg.loss.proper = proper_val
168
+ matcher.reconf(cfg)
169
+ if not mesh1_path or not mesh2_path:
170
+ raise gr.Error("Please upload both meshes.")
171
+ matcher._init()
172
+ global datadicts
173
+ datadicts = Datadicts(mesh1_path, mesh2_path)
174
+
175
+ C12_pred_init, C21_pred_init, feat1, feat2, evecs_trans1, evecs_trans2 = matcher.fmap_model({"shape1": datadicts.shape_dict, "shape2": datadicts.target_dict}, diff_model=matcher.diffusion_model, scale=matcher.fmap_cfg.diffusion.time)
176
+ C12_pred, C12_obj, mask_12 = C12_pred_init
177
+ p2p_init, _ = extract_p2p_torch_fmap(C12_obj, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
178
+ return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_init, tag="init")
179
+
180
+
181
+ def run_clicked(mesh1_path, mesh2_path, yaml_path, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val, progress=gr.Progress(track_tqdm=True)):
182
+ if not mesh1_path or not mesh2_path:
183
+ raise gr.Error("Please upload both meshes.")
184
+
185
+ cfg.deepfeat_conf.fmap.lambda_ = lambda_val
186
+ cfg.sds_conf.zoomout = zoomout_val
187
+ cfg.deepfeat_conf.fmap.diffusion.time = time_val
188
+ cfg.opt.n_loop = nloop_val
189
+ cfg.loss.sds = sds_val
190
+ cfg.loss.proper = proper_val
191
+ matcher.reconf(cfg)
192
+ if not mesh1_path or not mesh2_path:
193
+ raise gr.Error("Please upload both meshes.")
194
+ matcher._init()
195
+ global datadicts
196
+ if datadicts is None:
197
+ datadicts = Datadicts(mesh1_path, mesh2_path)
198
+ elif datadicts is not None:
199
+ if not (datadicts.shape_path == mesh1_path and datadicts.target_path == mesh2_path):
200
+ datadicts = Datadicts(mesh1_path, mesh2_path)
201
+
202
+ target_normals = torch.from_numpy(datadicts.target_surf.surfel/np.linalg.norm(datadicts.target_surf.surfel, axis=-1, keepdims=True)).float().to("cuda")
203
+
204
+ C12_new, p2p, p2p_init, _, loss_save = matcher.optimize(datadicts.shape_dict, datadicts.target_dict, target_normals)
205
+ evecs1, evecs2 = datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"]
206
+ evecs_2trans = evecs2.t() @ torch.diag(datadicts.target_dict["mass"])
207
+ C12_end_zo = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze()[:15, :15], 150)# matcher.cfg.sds_conf.zoomout)
208
+ p2p_zo, _ = extract_p2p_torch_fmap(C12_end_zo, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
209
+ return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_zo, tag="run")
210
+
211
+
212
+ with gr.Blocks(title="DiffuMatch demo") as demo:
213
+ text_in = "Upload two meshes and try our ICCV zero-shot method DiffuMatch! \n"
214
+ text_in += "*Init* will give you a rough correspondence, and you can click on *Run* to see if our method is able to match the two shapes! \n"
215
+ text_in += "*Recommended*: The method requires that the meshes are aligned (rotation-wise) to work well. Also might not work with scans (but try it out!)."
216
+ gr.Markdown(text_in)
217
+ with gr.Row():
218
+ mesh1 = gr.File(label="Mesh A (.ply/.obj/.off)")
219
+ mesh2 = gr.File(label="Mesh B (.ply/.obj/.off)")
220
+ yaml_file = gr.File(label="Optional YAML config", file_types=[".yaml", ".yml"], visible=True)
221
+ # except Exception:
222
+ with gr.Accordion("Settings", open=True):
223
+ with gr.Row():
224
+ lambda_val = gr.Slider(minimum=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][0], maximum=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][1], step=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][2], value=1, label="deepfeat_conf.fmap.lambda_")
225
+ zoomout_val = gr.Slider(minimum=FLOAT_SLIDERS["sds_conf.zoomout"][0], maximum=FLOAT_SLIDERS["sds_conf.zoomout"][1], step=FLOAT_SLIDERS["sds_conf.zoomout"][2], value=40, label="sds_conf.zoomout")
226
+ time_val = gr.Slider(minimum=FLOAT_SLIDERS["diffusion.time"][0], maximum=FLOAT_SLIDERS["diffusion.time"][1], step=FLOAT_SLIDERS["diffusion.time"][2], value=1, label="diffusion.time")
227
+ with gr.Row():
228
+ nloop_val = gr.Slider(minimum=INT_SLIDERS["opt.n_loop"][0], maximum=INT_SLIDERS["opt.n_loop"][1], step=INT_SLIDERS["opt.n_loop"][2], value=300, label="opt.n_loop")
229
+ sds_val = gr.Slider(minimum=FLOAT_SLIDERS["loss.sds"][0], maximum=FLOAT_SLIDERS["loss.sds"][1], step=FLOAT_SLIDERS["loss.sds"][2], value=1, label="loss.sds")
230
+ proper_val = gr.Slider(minimum=FLOAT_SLIDERS["loss.proper"][0], maximum=FLOAT_SLIDERS["loss.proper"][1], step=FLOAT_SLIDERS["loss.proper"][2], value=1, label="loss.proper")
231
+
232
+ with gr.Row():
233
+ init_btn = gr.Button("Init", variant="primary")
234
+ run_btn = gr.Button("Run", variant="secondary")
235
+
236
+ gr.Markdown("### Outputs\nEach stage exports both **GLB** (preview below) and **PLY** (download links) with per‑vertex colors.")
237
+ with gr.Tab("Init"):
238
+ with gr.Row():
239
+ init_view_a = gr.Model3D(label="Shape")
240
+ init_view_b = gr.Model3D(label="Target correspondence (init)")
241
+ with gr.Row():
242
+ out_file_init = gr.File(label="Download correspondences TXT")
243
+ with gr.Tab("Run"):
244
+ with gr.Row():
245
+ run_view_a = gr.Model3D(label="Shape")
246
+ run_view_b = gr.Model3D(label="Target correspondence (run)")
247
+ with gr.Row():
248
+ out_file = gr.File(label="Download correspondences TXT")
249
+
250
+ init_btn.click(
251
+ fn=init_clicked,
252
+ inputs=[mesh1, mesh2, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val],
253
+ outputs=[init_view_a, init_view_b, out_file_init],
254
+ api_name="init",
255
+ )
256
+
257
+ run_btn.click(
258
+ fn=run_clicked,
259
+ inputs=[mesh1, mesh2, yaml_file, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val],
260
+ outputs=[run_view_a, run_view_b, out_file],
261
+ api_name="run",
262
+ )
263
+
264
+ if __name__ == "__main__":
265
+ parser = argparse.ArgumentParser(description="Launch the gradio demo")
266
+ parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
267
+ parser.add_argument('--share', action="store_true")
268
+ args = parser.parse_args()
269
+ cfg = OmegaConf.load(args.config)
270
+ matcher = zero_shot.Matcher(cfg)
271
+ datadicts = None
272
+ demo.launch(share=args.share)
notebook_helpers.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.mesh import load_mesh
2
+ from utils.geometry import get_operators, load_operators
3
+ import os
4
+ from utils.utils_func import convert_dict
5
+ from utils.surfaces import Surface
6
+ import numpy as np
7
+
8
+
9
+ device = "cuda:0"
10
+
11
+ def load_data(file, cache_path, name, num_evecs=128, make_cache=False, factor=None):
12
+ if factor is None:
13
+ verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True)
14
+ else:
15
+ verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True, scale=False)
16
+ verts_shape = verts_shape/factor
17
+ area_shape /= factor**2
18
+ # print("Cache is: ", cache_path)
19
+ if not os.path.exists(cache_path) or make_cache:
20
+ print("Computing operators ...")
21
+ get_operators(verts_shape, faces, num_evecs, cache_path, vnormals)
22
+ data_dict = load_operators(cache_path)
23
+ data_dict['name'] = name
24
+ data_dict['normals'] = vnormals
25
+ data_dict['vertices'] = verts_shape
26
+ data_dict_torch = convert_dict(data_dict, device)
27
+ #batchify_dict(data_dict_torch)
28
+ return data_dict_torch, area_shape
29
+
30
+ def get_map_info(file_1, file_2, dict_1, dict_2, dataset):
31
+ shape_dict, target_dict = dict_1, dict_2
32
+ name_1, name_2 = shape_dict["name"], target_dict["name"]
33
+ if dataset is None:
34
+ vts_1, vts_2 = np.arange(shape_dict['vertices'].shape[0]), np.arange(target_dict['vertices'].shape[0])
35
+ map_info = (file_1, file_2, vts_1, vts_2)
36
+ file_vts_1 = file_1.replace("shapes", "correspondences")[:-4] + ".vts"
37
+ vts_1 = np.loadtxt(file_vts_1).astype(np.int32) - 1
38
+
39
+ file_vts_2 = file_2.replace("shapes", "correspondences")[:-4] + ".vts"
40
+ vts_2 = np.loadtxt(file_vts_2).astype(np.int32) - 1
41
+ if "DT4D" in dataset:
42
+ file_vts_1 = file_1.replace("shapes", "correspondences")[:-4] + ".vts"
43
+ vts_1 = np.loadtxt(file_vts_1).astype(np.int32) - 1
44
+
45
+ file_vts_2 = file_2.replace("shapes", "correspondences")[:-4] + ".vts"
46
+ vts_2 = np.loadtxt(file_vts_2).astype(np.int32) - 1
47
+ if name_1 == name_2:
48
+ map_info = (file_1, file_2, vts_1, vts_2)
49
+ elif ("crypto" in name_1) or ("crypto" in name_2):
50
+ name_cat_1, name_cat_2 = name_1.split(os.sep)[0], name_2.split(os.sep)[0]
51
+ data_path = os.path.dirname(os.path.dirname(os.path.dirname(file_1)))
52
+ map_file = os.path.join(data_path, "correspondences/cross_category_corres", f"{name_cat_1}_{name_cat_2}.vts")
53
+ if os.path.exists(map_file):
54
+ map_idx = np.loadtxt(map_file).astype(np.int32) - 1
55
+ map_info = (file_1, file_2, vts_1, vts_2[map_idx])
56
+ else:
57
+ print("NO GROUND TRUTH PAIRS")
58
+ else:
59
+ map_info = (file_1, file_2, vts_1, vts_2)
60
+ return map_info
61
+
62
+
63
+ def load_pair(cache, id_1, id_2, name_1, name_2, dataset):
64
+ if "SCAPE" in dataset:
65
+ os.makedirs(cache, exist_ok=True)
66
+ cache_file = os.path.join(cache, f"mesh{id_1:03d}_mesh_256k_0n.npz")
67
+ shape_file = f"data/{dataset}/shapes/mesh{id_1:03d}.ply"
68
+ shape_surf = Surface(filename=shape_file)
69
+ shape_dict, _ = load_data(shape_file, cache_file, str(id_1))
70
+
71
+ cache_file = os.path.join(cache, f"mesh{id_2:03d}_mesh_256k_0n.npz")
72
+ target_file = f"data/{dataset}/shapes/mesh{id_2:03d}.ply"
73
+ target_surf = Surface(filename=target_file)
74
+ target_dict, _ = load_data(target_file, cache_file, str(id_2))
75
+ map_info = get_map_info(shape_file, target_file, shape_dict, target_dict, "SCAPE")
76
+ elif "DT4D" in dataset:
77
+ cache_file = os.path.join(cache, f"{name_1}_mesh_256k_0n.npz")
78
+ shape_file = f"data/DT4D_r_ori/shapes/{name_1}.ply"
79
+ shape_surf = Surface(filename=shape_file)
80
+ shape_dict, _ = load_data(shape_file, cache_file, name_1)
81
+ cache_file = os.path.join(cache, f"{name_2}_mesh_256k_0n.npz")
82
+ # cache_file = f"../nonrigiddiff/cache/snk/{name_2}.npz"
83
+ cache_file = f"../nonrigiddiff/cache/attentive/{name_2}_mesh_256k_0n.npz"
84
+ target_file = f"data/DT4D_r_ori/shapes/{name_2}.ply"
85
+ target_surf = Surface(filename=target_file)
86
+ target_dict, _ = load_data(target_file, cache_file, name_2)
87
+ map_info = get_map_info(shape_file, target_file, shape_dict, target_dict, "DT4D")
88
+ return shape_surf, target_surf, shape_dict, target_dict, map_info
requirements.txt ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ comm==0.2.1
2
+ configer
3
+ configparser==6.0.0
4
+ contourpy==1.1.1
5
+ cycler==0.12.1
6
+ fonttools==4.47.2
7
+ freetype-py==2.4.0
8
+ imageio==2.33.1
9
+ ipygany==0.5.0
10
+ ipywidgets==8.1.7
11
+ nbformat==5.10.4
12
+ jupyter
13
+ kiwisolver==1.4.5
14
+ lxml==5.1.0
15
+ matplotlib==3.7.4
16
+ networkx==3.1
17
+ open3d>=0.18.0
18
+ pyglet==2.0.10
19
+ pyopengl==3.1.0
20
+ pyparsing==3.1.1
21
+ pyrender==0.1.45
22
+ scipy>=1.10.1
23
+ shapely==2.0.2
24
+ smplx[all]
25
+ tqdm==4.66.1
26
+ trimesh==4.0.10
27
+ vtk==9.3.0
28
+ timm==1.0.11
29
+ potpourri3d
30
+ pythreejs==2.4.2
31
+ lapy
32
+ roma
33
+ webdataset
34
+ lmdb
35
+ robust_laplacian
36
+ termcolor
37
+ omegaconf
38
+ pykeops==2.2.3
39
+ scikit-image
40
+ pyfmaps
41
+ wandb
42
+ gradio==4.44.1
43
+ pydantic==2.10.6
44
+ traitlets==5.7.1
45
+ https://github.com/skoch9/meshplot/archive/0.4.0.tar.gz
shape_models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .utils import *
2
+ from .geometry import *
3
+ from .layers import *
shape_models/encoder.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .layers import DiffusionNet
4
+
5
+ class Encoder(nn.Module):
6
+ def __init__(self, with_grad=True, key_verts="vertices"):
7
+ super(Encoder, self).__init__()
8
+ self.diff_net = DiffusionNet(
9
+ C_in=3,
10
+ C_out=512,
11
+ C_width=128,
12
+ N_block=4,
13
+ dropout=True,
14
+ with_gradient_features=with_grad,
15
+ with_gradient_rotations=with_grad,
16
+ )
17
+ self.key_verts = key_verts
18
+
19
+
20
+ def forward(self, shape_dict):
21
+ feats = self.diff_net(shape_dict[self.key_verts], shape_dict["mass"], shape_dict["L"], evals=shape_dict["evals"],
22
+ evecs=shape_dict["evecs"], gradX=shape_dict["gradX"], gradY=shape_dict["gradY"], faces=shape_dict["faces"])
23
+ x_out = torch.max(feats, dim=0).values
24
+ return x_out
shape_models/fmap.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ # feature extractor
7
+ from .layers import DiffusionNet
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ def get_mask(evals1, evals2, gamma=0.5, device="cpu"):
12
+ scaling_factor = max(torch.max(evals1), torch.max(evals2))
13
+ evals1, evals2 = evals1.to(device) / scaling_factor, evals2.to(device) / scaling_factor
14
+ evals_gamma1, evals_gamma2 = (evals1 ** gamma)[None, :], (evals2 ** gamma)[:, None]
15
+
16
+ M_re = evals_gamma2 / (evals_gamma2.square() + 1) - evals_gamma1 / (evals_gamma1.square() + 1)
17
+ M_im = 1 / (evals_gamma2.square() + 1) - 1 / (evals_gamma1.square() + 1)
18
+ return M_re.square() + M_im.square()
19
+
20
+ def get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y):
21
+ # compute linear operator matrix representation C1 and C2
22
+
23
+ F_hat = torch.bmm(evecs_trans_x, feat_x)
24
+ G_hat = torch.bmm(evecs_trans_y, feat_y)
25
+ A, B = F_hat, G_hat
26
+
27
+ A_t, B_t = A.transpose(1, 2), B.transpose(1, 2)
28
+ A_A_t, B_B_t = torch.bmm(A, A_t), torch.bmm(B, B_t)
29
+ B_A_t, A_B_t = torch.bmm(B, A_t), torch.bmm(A, B_t)
30
+
31
+ C12 = torch.bmm(B_A_t, torch.inverse(A_A_t))
32
+
33
+ C21 = torch.bmm(A_B_t, torch.inverse(B_B_t))
34
+ return [C12, C21]
35
+
36
+ def get_mask_noise(C, diff_model, scale=1, N_est=200, device="cuda", normalize=True, absolute=False):
37
+ with torch.no_grad():
38
+ sig = torch.ones([1, 1, 1, 1], device=device) * scale
39
+
40
+ noise = torch.randn((N_est, 1, 30, 30), device=device)
41
+ if absolute:
42
+ #noisy_new = torch.abs(C[None, :, :] + noise *scale)
43
+ noisy_new = torch.abs(torch.abs(C[None, :, :]) + noise * scale)
44
+ else:
45
+ noisy_new = C[None, :, :] + noise *scale
46
+ denoised = diff_model.net(noisy_new, sig)
47
+ mask_squared = torch.mean(torch.abs(noisy_new - denoised)/(2*scale), dim=0)/torch.mean(torch.abs(noisy_new), dim=0)
48
+ mask_median = torch.median(mask_squared).item()
49
+ if normalize:
50
+ mask_median = torch.median(mask_squared).item()
51
+ M_denoised = torch.sqrt(torch.clamp(mask_squared-mask_median/2, 0, mask_median*2))
52
+ else:
53
+ M_denoised = torch.sqrt(mask_squared)
54
+ return M_denoised-M_denoised.min()#, mask_median*2)
55
+
56
+
57
+ class RegularizedFMNet(nn.Module):
58
+ """Compute the functional map matrix representation."""
59
+
60
+ def __init__(self, lambda_=1e-3, resolvant_gamma=0.5, use_resolvent=False):
61
+ super().__init__()
62
+ self.lambda_ = lambda_
63
+ self.use_resolvent = use_resolvent
64
+ self.resolvant_gamma = resolvant_gamma
65
+
66
+ def forward(self, feat_x, feat_y, evals_x, evals_y, evecs_trans_x, evecs_trans_y, diff_conf=None):
67
+ # compute linear operator matrix representation C1 and C2
68
+ evecs_trans_x, evecs_trans_y = evecs_trans_x.unsqueeze(0), evecs_trans_y.unsqueeze(0)
69
+ evals_x, evals_y = evals_x.unsqueeze(0), evals_y.unsqueeze(0)
70
+
71
+ F_hat = torch.bmm(evecs_trans_x, feat_x)
72
+ G_hat = torch.bmm(evecs_trans_y, feat_y)
73
+ A, B = F_hat, G_hat
74
+
75
+
76
+
77
+ if diff_conf is not None:
78
+ diff_model, scale, normalize, absolute, N_est = diff_conf
79
+ C12_raw, C21_raw = get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y)
80
+ D12 = get_mask_noise(C12_raw, diff_model, scale=scale, N_est=N_est, normalize=normalize, absolute=absolute)
81
+ D21 = get_mask_noise(C21_raw, diff_model, scale=scale, N_est=N_est, normalize=normalize, absolute=absolute)
82
+ elif self.use_resolvent:
83
+ D12 = get_mask(evals_x.flatten(), evals_y.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
84
+ D21 = get_mask(evals_y.flatten(), evals_x.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
85
+ else:
86
+ D12 = (torch.unsqueeze(evals_y, 2) - torch.unsqueeze(evals_x, 1))**2
87
+ D21 = (torch.unsqueeze(evals_x, 2) - torch.unsqueeze(evals_y, 1))**2
88
+
89
+ A_t, B_t = A.transpose(1, 2), B.transpose(1, 2)
90
+ A_A_t, B_B_t = torch.bmm(A, A_t), torch.bmm(B, B_t)
91
+ B_A_t, A_B_t = torch.bmm(B, A_t), torch.bmm(A, B_t)
92
+
93
+ C12_i = []
94
+ for i in range(evals_x.size(1)):
95
+ D12_i = torch.cat([torch.diag(D12[bs, i, :].flatten()).unsqueeze(0) for bs in range(evals_x.size(0))], dim=0)
96
+ C12 = torch.bmm(torch.inverse(A_A_t + self.lambda_ * D12_i), B_A_t[:, i, :].unsqueeze(1).transpose(1, 2))
97
+ C12_i.append(C12.transpose(1, 2))
98
+ C12 = torch.cat(C12_i, dim=1)
99
+
100
+ C21_i = []
101
+ for i in range(evals_y.size(1)):
102
+ D21_i = torch.cat([torch.diag(D21[bs, i, :].flatten()).unsqueeze(0) for bs in range(evals_y.size(0))], dim=0)
103
+ C21 = torch.bmm(torch.inverse(B_B_t + self.lambda_ * D21_i), A_B_t[:, i, :].unsqueeze(1).transpose(1, 2))
104
+ C21_i.append(C21.transpose(1, 2))
105
+ C21 = torch.cat(C21_i, dim=1)
106
+ if diff_conf is not None:
107
+ return [C12_raw, C12, D12], [C21_raw, C21, D21]
108
+ else:
109
+ return [C12, C21]
110
+
111
+
112
+ class DFMNet(nn.Module):
113
+ """
114
+ Compilation of the global model :
115
+ - diffusion net as feature extractor
116
+ - fmap + q-fmap
117
+ - unsupervised loss
118
+ """
119
+
120
+ def __init__(self, cfg):
121
+ super().__init__()
122
+
123
+ # feature extractor #
124
+ with_grad=True
125
+ self.feat = cfg["feat"]
126
+
127
+ self.feature_extractor = DiffusionNet(
128
+ C_in=cfg["C_in"],
129
+ C_out=cfg["n_feat"],
130
+ C_width=128,
131
+ N_block=4,
132
+ dropout=True,
133
+ with_gradient_features=with_grad,
134
+ with_gradient_rotations=with_grad,
135
+ )
136
+
137
+ # regularized fmap
138
+ self.fmreg_net = RegularizedFMNet(lambda_=cfg["lambda_"],
139
+ resolvant_gamma=cfg.get("resolvent_gamma", 0.5), use_resolvent=cfg.get("use_resolvent", False))
140
+ # parameters
141
+ self.n_fmap = cfg["n_fmap"]
142
+ if cfg.get("diffusion", None) is not None:
143
+ self.normalize = cfg["diffusion"]["normalize"]
144
+ self.abs = cfg["diffusion"]["abs"]
145
+ self.N_est = cfg.diffusion.get("batch_mask", 100)
146
+
147
+ def forward(self, batch, diff_model=None, scale=1):
148
+ if self.feat == "xyz":
149
+ feat_1, feat_2 = batch["shape1"]["vertices"], batch["shape2"]["vertices"]
150
+ elif self.feat == "wks":
151
+ feat_1, feat_2 = batch["shape1"]["wks"], batch["shape2"]["wks"]
152
+ elif self.feat == "hks":
153
+ feat_1, feat_2 = batch["shape1"]["hks"], batch["shape2"]["hks"]
154
+ else:
155
+ raise Exception("Unknow Feature")
156
+
157
+
158
+ verts1, faces1, mass1, L1, evals1, evecs1, gradX1, gradY1 = (feat_1, batch["shape1"]["faces"],
159
+ batch["shape1"]["mass"], batch["shape1"]["L"],
160
+ batch["shape1"]["evals"], batch["shape1"]["evecs"],
161
+ batch["shape1"]["gradX"], batch["shape1"]["gradY"])
162
+ verts2, faces2, mass2, L2, evals2, evecs2, gradX2, gradY2 = (feat_2, batch["shape2"]["faces"],
163
+ batch["shape2"]["mass"], batch["shape2"]["L"],
164
+ batch["shape2"]["evals"], batch["shape2"]["evecs"],
165
+ batch["shape2"]["gradX"], batch["shape2"]["gradY"])
166
+
167
+ # set features to vertices
168
+ features1, features2 = verts1, verts2
169
+ # print(features1.shape, features2.shape)
170
+
171
+ feat1 = self.feature_extractor(features1, mass1, L=L1, evals=evals1, evecs=evecs1,
172
+ gradX=gradX1, gradY=gradY1, faces=faces1).unsqueeze(0)
173
+ feat2 = self.feature_extractor(features2, mass2, L=L2, evals=evals2, evecs=evecs2,
174
+ gradX=gradX2, gradY=gradY2, faces=faces2).unsqueeze(0)
175
+
176
+ evecs_trans1, evecs_trans2 = evecs1.t()[:self.n_fmap] @ torch.diag(mass1.squeeze()), evecs2.squeeze().t()[:self.n_fmap].squeeze() @ torch.diag(mass2.squeeze())
177
+ evals1, evals2 = evals1[:self.n_fmap], evals2[:self.n_fmap]
178
+
179
+ #
180
+ if diff_model is not None:
181
+ C12_pred, C21_pred = self.fmreg_net(feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2,
182
+ diff_conf=[diff_model, scale, self.normalize, self.abs, self.N_est])
183
+ else:
184
+ C12_pred, C21_pred = self.fmreg_net(feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2)
185
+
186
+ return C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2
shape_models/geometry.py ADDED
@@ -0,0 +1,897 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import scipy
2
+ import scipy.sparse.linalg as sla
3
+ # ^^^ we NEED to import scipy before torch, or it crashes :(
4
+ # (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
5
+
6
+ import os.path
7
+ import sys
8
+ import random
9
+ from multiprocessing import Pool
10
+
11
+ import numpy as np
12
+ import scipy.spatial
13
+ import torch
14
+ import sklearn.neighbors
15
+
16
+ import robust_laplacian
17
+ import potpourri3d as pp3d
18
+
19
+
20
+ def norm(x, highdim=False):
21
+ """
22
+ Computes norm of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
23
+ """
24
+ return torch.norm(x, dim=len(x.shape) - 1)
25
+
26
+
27
+ def norm2(x, highdim=False):
28
+ """
29
+ Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
30
+ """
31
+ return dot(x, x)
32
+
33
+
34
+ def normalize(x, divide_eps=1e-6, highdim=False):
35
+ """
36
+ Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
37
+ """
38
+ if (len(x.shape) == 1):
39
+ raise ValueError("called normalize() on single vector of dim " + str(x.shape) + " are you sure?")
40
+ if (not highdim and x.shape[-1] > 4):
41
+ raise ValueError("called normalize() with large last dimension " + str(x.shape) + " are you sure?")
42
+ return x / (norm(x, highdim=highdim) + divide_eps).unsqueeze(-1)
43
+
44
+
45
+ def face_coords(verts, faces):
46
+ coords = verts[faces]
47
+ return coords
48
+
49
+
50
+ def cross(vec_A, vec_B):
51
+ return torch.cross(vec_A, vec_B, dim=-1)
52
+
53
+
54
+ def dot(vec_A, vec_B):
55
+ return torch.sum(vec_A * vec_B, dim=-1)
56
+
57
+
58
+ # Given (..., 3) vectors and normals, projects out any components of vecs
59
+ # which lies in the direction of normals. Normals are assumed to be unit.
60
+
61
+
62
+ def project_to_tangent(vecs, unit_normals):
63
+ dots = dot(vecs, unit_normals)
64
+ return vecs - unit_normals * dots.unsqueeze(-1)
65
+
66
+
67
+ def face_area(verts, faces):
68
+ coords = face_coords(verts, faces)
69
+ vec_A = coords[:, 1, :] - coords[:, 0, :]
70
+ vec_B = coords[:, 2, :] - coords[:, 0, :]
71
+
72
+ raw_normal = cross(vec_A, vec_B)
73
+ return 0.5 * norm(raw_normal)
74
+
75
+
76
+ def face_normals(verts, faces, normalized=True):
77
+ coords = face_coords(verts, faces)
78
+ vec_A = coords[:, 1, :] - coords[:, 0, :]
79
+ vec_B = coords[:, 2, :] - coords[:, 0, :]
80
+
81
+ raw_normal = cross(vec_A, vec_B)
82
+
83
+ if normalized:
84
+ return normalize(raw_normal)
85
+
86
+ return raw_normal
87
+
88
+
89
+ def neighborhood_normal(points):
90
+ # points: (N, K, 3) array of neighborhood psoitions
91
+ # points should be centered at origin
92
+ # out: (N,3) array of normals
93
+ # numpy in, numpy out
94
+ (u, s, vh) = np.linalg.svd(points, full_matrices=False)
95
+ normal = vh[:, 2, :]
96
+ return normal / np.linalg.norm(normal, axis=-1, keepdims=True)
97
+
98
+
99
+ def mesh_vertex_normals(verts, faces):
100
+ # numpy in / out
101
+ face_n = toNP(face_normals(torch.tensor(verts), torch.tensor(faces))) # ugly torch <---> numpy
102
+
103
+ vertex_normals = np.zeros(verts.shape)
104
+ for i in range(3):
105
+ np.add.at(vertex_normals, faces[:, i], face_n)
106
+
107
+ vertex_normals = vertex_normals / np.linalg.norm(vertex_normals, axis=-1, keepdims=True)
108
+
109
+ return vertex_normals
110
+
111
+
112
+ def vertex_normals(verts, faces, n_neighbors_cloud=30):
113
+ verts_np = toNP(verts)
114
+
115
+ if faces.numel() == 0: # point cloud
116
+
117
+ _, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
118
+ neigh_points = verts_np[neigh_inds, :]
119
+ neigh_points = neigh_points - verts_np[:, np.newaxis, :]
120
+ normals = neighborhood_normal(neigh_points)
121
+
122
+ else: # mesh
123
+
124
+ normals = mesh_vertex_normals(verts_np, toNP(faces))
125
+
126
+ # if any are NaN, wiggle slightly and recompute
127
+ bad_normals_mask = np.isnan(normals).any(axis=1, keepdims=True)
128
+ if bad_normals_mask.any():
129
+ bbox = np.amax(verts_np, axis=0) - np.amin(verts_np, axis=0)
130
+ scale = np.linalg.norm(bbox) * 1e-4
131
+ wiggle = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5) * scale
132
+ wiggle_verts = verts_np + bad_normals_mask * wiggle
133
+ normals = mesh_vertex_normals(wiggle_verts, toNP(faces))
134
+
135
+ # if still NaN assign random normals (probably means unreferenced verts in mesh)
136
+ bad_normals_mask = np.isnan(normals).any(axis=1)
137
+ if bad_normals_mask.any():
138
+ normals[bad_normals_mask, :] = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5)[bad_normals_mask, :]
139
+ normals = normals / np.linalg.norm(normals, axis=-1)[:, np.newaxis]
140
+
141
+ normals = torch.from_numpy(normals).to(device=verts.device, dtype=verts.dtype)
142
+
143
+ if torch.any(torch.isnan(normals)):
144
+ raise ValueError("NaN normals :(")
145
+
146
+ return normals
147
+
148
+
149
+ def build_tangent_frames(verts, faces, normals=None):
150
+
151
+ V = verts.shape[0]
152
+ dtype = verts.dtype
153
+ device = verts.device
154
+
155
+ if normals == None:
156
+ vert_normals = vertex_normals(verts, faces) # (V,3)
157
+ else:
158
+ vert_normals = normals
159
+
160
+ # = find an orthogonal basis
161
+
162
+ basis_cand1 = torch.tensor([1, 0, 0]).to(device=device, dtype=dtype).expand(V, -1)
163
+ basis_cand2 = torch.tensor([0, 1, 0]).to(device=device, dtype=dtype).expand(V, -1)
164
+
165
+ basisX = torch.where((torch.abs(dot(vert_normals, basis_cand1)) < 0.9).unsqueeze(-1), basis_cand1, basis_cand2)
166
+ basisX = project_to_tangent(basisX, vert_normals)
167
+ basisX = normalize(basisX)
168
+ basisY = cross(vert_normals, basisX)
169
+ frames = torch.stack((basisX, basisY, vert_normals), dim=-2)
170
+
171
+ if torch.any(torch.isnan(frames)):
172
+ raise ValueError("NaN coordinate frame! Must be very degenerate")
173
+
174
+ return frames
175
+
176
+
177
+ def build_grad_point_cloud(verts, frames, n_neighbors_cloud=30):
178
+
179
+ verts_np = toNP(verts)
180
+ frames_np = toNP(frames)
181
+
182
+ _, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
183
+ neigh_points = verts_np[neigh_inds, :]
184
+ neigh_vecs = neigh_points - verts_np[:, np.newaxis, :]
185
+
186
+ # TODO this could easily be way faster. For instance we could avoid the weird edges format and the corresponding pure-python loop via some numpy broadcasting of the same logic. The way it works right now is just to share code with the mesh version. But its low priority since its preprocessing code.
187
+
188
+ edge_inds_from = np.repeat(np.arange(verts.shape[0]), n_neighbors_cloud)
189
+ edges = np.stack((edge_inds_from, neigh_inds.flatten()))
190
+ edge_tangent_vecs = edge_tangent_vectors(verts, frames, edges)
191
+
192
+ return build_grad(verts_np, torch.tensor(edges), edge_tangent_vecs)
193
+
194
+
195
+ def edge_tangent_vectors(verts, frames, edges):
196
+ edge_vecs = verts[edges[1, :], :] - verts[edges[0, :], :]
197
+ basisX = frames[edges[0, :], 0, :]
198
+ basisY = frames[edges[0, :], 1, :]
199
+
200
+ compX = dot(edge_vecs, basisX)
201
+ compY = dot(edge_vecs, basisY)
202
+ edge_tangent = torch.stack((compX, compY), dim=-1)
203
+
204
+ return edge_tangent
205
+
206
+
207
+ def build_grad(verts, edges, edge_tangent_vectors):
208
+ """
209
+ Build a (V, V) complex sparse matrix grad operator. Given real inputs at vertices, produces a complex (vector value) at vertices giving the gradient. All values pointwise.
210
+ - edges: (2, E)
211
+ """
212
+
213
+ edges_np = toNP(edges)
214
+ edge_tangent_vectors_np = toNP(edge_tangent_vectors)
215
+
216
+ # TODO find a way to do this in pure numpy?
217
+
218
+ # Build outgoing neighbor lists
219
+ N = verts.shape[0]
220
+ vert_edge_outgoing = [[] for i in range(N)]
221
+ for iE in range(edges_np.shape[1]):
222
+ tail_ind = edges_np[0, iE]
223
+ tip_ind = edges_np[1, iE]
224
+ if tip_ind != tail_ind:
225
+ vert_edge_outgoing[tail_ind].append(iE)
226
+
227
+ # Build local inversion matrix for each vertex
228
+ row_inds = []
229
+ col_inds = []
230
+ data_vals = []
231
+ eps_reg = 1e-5
232
+ for iV in range(N):
233
+ n_neigh = len(vert_edge_outgoing[iV])
234
+
235
+ lhs_mat = np.zeros((n_neigh, 2))
236
+ rhs_mat = np.zeros((n_neigh, n_neigh + 1))
237
+ ind_lookup = [iV]
238
+ for i_neigh in range(n_neigh):
239
+ iE = vert_edge_outgoing[iV][i_neigh]
240
+ jV = edges_np[1, iE]
241
+ ind_lookup.append(jV)
242
+
243
+ edge_vec = edge_tangent_vectors[iE][:]
244
+ w_e = 1.
245
+
246
+ lhs_mat[i_neigh][:] = w_e * edge_vec
247
+ rhs_mat[i_neigh][0] = w_e * (-1)
248
+ rhs_mat[i_neigh][i_neigh + 1] = w_e * 1
249
+
250
+ lhs_T = lhs_mat.T
251
+ lhs_inv = np.linalg.inv(lhs_T @ lhs_mat + eps_reg * np.identity(2)) @ lhs_T
252
+
253
+ sol_mat = lhs_inv @ rhs_mat
254
+ sol_coefs = (sol_mat[0, :] + 1j * sol_mat[1, :]).T
255
+
256
+ for i_neigh in range(n_neigh + 1):
257
+ i_glob = ind_lookup[i_neigh]
258
+
259
+ row_inds.append(iV)
260
+ col_inds.append(i_glob)
261
+ data_vals.append(sol_coefs[i_neigh])
262
+
263
+ # build the sparse matrix
264
+ row_inds = np.array(row_inds)
265
+ col_inds = np.array(col_inds)
266
+ data_vals = np.array(data_vals)
267
+ mat = scipy.sparse.coo_matrix((data_vals, (row_inds, col_inds)), shape=(N, N)).tocsc()
268
+
269
+ return mat
270
+
271
+
272
+ def compute_operators(verts, faces, k_eig, normals=None):
273
+ """
274
+ Builds spectral operators for a mesh/point cloud. Constructs mass matrix, eigenvalues/vectors for Laplacian, and gradient matrix.
275
+
276
+ See get_operators() for a similar routine that wraps this one with a layer of caching.
277
+
278
+ Torch in / torch out.
279
+
280
+ Arguments:
281
+ - vertices: (V,3) vertex positions
282
+ - faces: (F,3) list of triangular faces. If empty, assumed to be a point cloud.
283
+ - k_eig: number of eigenvectors to use
284
+
285
+ Returns:
286
+ - frames: (V,3,3) X/Y/Z coordinate frame at each vertex. Z coordinate is normal (e.g. [:,2,:] for normals)
287
+ - massvec: (V) real diagonal of lumped mass matrix
288
+ - L: (VxV) real sparse matrix of (weak) Laplacian
289
+ - evals: (k) list of eigenvalues of the Laplacian
290
+ - evecs: (V,k) list of eigenvectors of the Laplacian
291
+ - gradX: (VxV) sparse matrix which gives X-component of gradient in the local basis at the vertex
292
+ - gradY: same as gradX but for Y-component of gradient
293
+
294
+ PyTorch doesn't seem to like complex sparse matrices, so we store the "real" and "imaginary" (aka X and Y) gradient matrices separately, rather than as one complex sparse matrix.
295
+
296
+ Note: for a generalized eigenvalue problem, the mass matrix matters! The eigenvectors are only othrthonormal with respect to the mass matrix, like v^H M v, so the mass (given as the diagonal vector massvec) needs to be used in projections, etc.
297
+ """
298
+
299
+ device = verts.device
300
+ dtype = verts.dtype
301
+ V = verts.shape[0]
302
+ is_cloud = faces.numel() == 0
303
+
304
+ eps = 1e-8
305
+
306
+ verts_np = toNP(verts).astype(np.float64)
307
+ faces_np = toNP(faces)
308
+ frames = build_tangent_frames(verts, faces, normals=normals)
309
+ frames_np = toNP(frames)
310
+
311
+ # Build the scalar Laplacian
312
+ if is_cloud:
313
+ L, M = robust_laplacian.point_cloud_laplacian(verts_np)
314
+ massvec_np = M.diagonal()
315
+ else:
316
+ # L, M = robust_laplacian.mesh_laplacian(verts_np, faces_np)
317
+ # massvec_np = M.diagonal()
318
+ L = pp3d.cotan_laplacian(verts_np, faces_np, denom_eps=1e-10)
319
+ massvec_np = pp3d.vertex_areas(verts_np, faces_np)
320
+ massvec_np += eps * np.mean(massvec_np)
321
+
322
+ if (np.isnan(L.data).any()):
323
+ raise RuntimeError("NaN Laplace matrix")
324
+ if (np.isnan(massvec_np).any()):
325
+ raise RuntimeError("NaN mass matrix")
326
+
327
+ # Read off neighbors & rotations from the Laplacian
328
+ L_coo = L.tocoo()
329
+ inds_row = L_coo.row
330
+ inds_col = L_coo.col
331
+
332
+ # === Compute the eigenbasis
333
+ if k_eig > 0:
334
+
335
+ # Prepare matrices
336
+ L_eigsh = (L + scipy.sparse.identity(L.shape[0]) * eps).tocsc()
337
+ massvec_eigsh = massvec_np
338
+ Mmat = scipy.sparse.diags(massvec_eigsh)
339
+ eigs_sigma = eps
340
+
341
+ failcount = 0
342
+ while True:
343
+ try:
344
+ # We would be happy here to lower tol or maxiter since we don't need these to be super precise, but for some reason those parameters seem to have no effect
345
+ evals_np, evecs_np = sla.eigsh(L_eigsh, k=k_eig, M=Mmat, sigma=eigs_sigma)
346
+
347
+ # Clip off any eigenvalues that end up slightly negative due to numerical weirdness
348
+ evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))
349
+
350
+ break
351
+ except Exception as e:
352
+ print(e)
353
+ if (failcount > 3):
354
+ raise ValueError("failed to compute eigendecomp")
355
+ failcount += 1
356
+ print("--- decomp failed; adding eps ===> count: " + str(failcount))
357
+ L_eigsh = L_eigsh + scipy.sparse.identity(L.shape[0]) * (eps * 10**failcount)
358
+
359
+ else: #k_eig == 0
360
+ evals_np = np.zeros((0))
361
+ evecs_np = np.zeros((verts.shape[0], 0))
362
+
363
+ # == Build gradient matrices
364
+
365
+ # For meshes, we use the same edges as were used to build the Laplacian. For point clouds, use a whole local neighborhood
366
+ if is_cloud:
367
+ grad_mat_np = build_grad_point_cloud(verts, frames)
368
+ else:
369
+ edges = torch.tensor(np.stack((inds_row, inds_col), axis=0), device=device, dtype=faces.dtype)
370
+ edge_vecs = edge_tangent_vectors(verts, frames, edges)
371
+ grad_mat_np = build_grad(verts.cpu(), edges.cpu(), edge_vecs.cpu())
372
+
373
+ # Split complex gradient in to two real sparse mats (torch doesn't like complex sparse matrices)
374
+ gradX_np = np.real(grad_mat_np)
375
+ gradY_np = np.imag(grad_mat_np)
376
+
377
+ # === Convert back to torch
378
+ massvec = torch.from_numpy(massvec_np).to(device=device, dtype=dtype)
379
+ L = utils.sparse_np_to_torch(L).to(device=device, dtype=dtype)
380
+ evals = torch.from_numpy(evals_np).to(device=device, dtype=dtype)
381
+ evecs = torch.from_numpy(evecs_np).to(device=device, dtype=dtype)
382
+ gradX = utils.sparse_np_to_torch(gradX_np).to(device=device, dtype=dtype)
383
+ gradY = utils.sparse_np_to_torch(gradY_np).to(device=device, dtype=dtype)
384
+
385
+ return frames, massvec, L, evals, evecs, gradX, gradY
386
+
387
+
388
+ def get_all_operators(verts_list, faces_list, k_eig, op_cache_dir=None, normals=None):
389
+ N = len(verts_list)
390
+
391
+ frames = [None] * N
392
+ massvec = [None] * N
393
+ L = [None] * N
394
+ evals = [None] * N
395
+ evecs = [None] * N
396
+ gradX = [None] * N
397
+ gradY = [None] * N
398
+
399
+ inds = [i for i in range(N)]
400
+ # process in random order
401
+ # random.shuffle(inds)
402
+
403
+ for num, i in enumerate(inds):
404
+ print("get_all_operators() processing {} / {} {:.3f}%".format(num, N, num / N * 100))
405
+ if normals is None:
406
+ outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir)
407
+ else:
408
+ outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir, normals=normals[i])
409
+ frames[i] = outputs[0]
410
+ massvec[i] = outputs[1]
411
+ L[i] = outputs[2]
412
+ evals[i] = outputs[3]
413
+ evecs[i] = outputs[4]
414
+ gradX[i] = outputs[5]
415
+ gradY[i] = outputs[6]
416
+
417
+ return frames, massvec, L, evals, evecs, gradX, gradY
418
+
419
+
420
+ def get_operators(verts, faces, k_eig=128, op_cache_dir=None, normals=None, overwrite_cache=False):
421
+ """
422
+ See documentation for compute_operators(). This essentailly just wraps a call to compute_operators, using a cache if possible.
423
+ All arrays are always computed using double precision for stability, then truncated to single precision floats to store on disk, and finally returned as a tensor with dtype/device matching the `verts` input.
424
+ """
425
+
426
+ device = verts.device
427
+ dtype = verts.dtype
428
+ verts_np = toNP(verts)
429
+ faces_np = toNP(faces)
430
+ is_cloud = faces.numel() == 0
431
+
432
+ if (np.isnan(verts_np).any()):
433
+ raise RuntimeError("tried to construct operators from NaN verts")
434
+
435
+ # Check the cache directory
436
+ # Note 1: Collisions here are exceptionally unlikely, so we could probably just use the hash...
437
+ # but for good measure we check values nonetheless.
438
+ # Note 2: There is a small possibility for race conditions to lead to bucket gaps or duplicate
439
+ # entries in this cache. The good news is that that is totally fine, and at most slightly
440
+ # slows performance with rare extra cache misses.
441
+ found = False
442
+ if op_cache_dir is not None:
443
+ utils.ensure_dir_exists(op_cache_dir)
444
+ hash_key_str = str(utils.hash_arrays((verts_np, faces_np)))
445
+ # print("Building operators for input with hash: " + hash_key_str)
446
+
447
+ # Search through buckets with matching hashes. When the loop exits, this
448
+ # is the bucket index of the file we should write to.
449
+ i_cache_search = 0
450
+ while True:
451
+
452
+ # Form the name of the file to check
453
+ search_path = os.path.join(op_cache_dir, hash_key_str + "_" + str(i_cache_search) + ".npz")
454
+
455
+ try:
456
+ # print('loading path: ' + str(search_path))
457
+ npzfile = np.load(search_path, allow_pickle=True)
458
+ cache_verts = npzfile["verts"]
459
+ cache_faces = npzfile["faces"]
460
+ cache_k_eig = npzfile["k_eig"].item()
461
+
462
+ # If the cache doesn't match, keep looking
463
+ if (not np.array_equal(verts, cache_verts)) or (not np.array_equal(faces, cache_faces)):
464
+ i_cache_search += 1
465
+ print("hash collision! searching next.")
466
+ continue
467
+
468
+ # print(" cache hit!")
469
+
470
+ # If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
471
+ # entry below more eigenvalues
472
+ if overwrite_cache:
473
+ print(" overwriting cache by request")
474
+ os.remove(search_path)
475
+ break
476
+
477
+ if cache_k_eig < k_eig:
478
+ print(" overwriting cache --- not enough eigenvalues")
479
+ os.remove(search_path)
480
+ break
481
+
482
+ if "L_data" not in npzfile:
483
+ print(" overwriting cache --- entries are absent")
484
+ os.remove(search_path)
485
+ break
486
+
487
+ def read_sp_mat(prefix):
488
+ data = npzfile[prefix + "_data"]
489
+ indices = npzfile[prefix + "_indices"]
490
+ indptr = npzfile[prefix + "_indptr"]
491
+ shape = npzfile[prefix + "_shape"]
492
+ mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
493
+ return mat
494
+
495
+ # This entry matches! Return it.
496
+ frames = npzfile["frames"]
497
+ mass = npzfile["mass"]
498
+ L = read_sp_mat("L")
499
+ evals = npzfile["evals"][:k_eig]
500
+ evecs = npzfile["evecs"][:, :k_eig]
501
+ gradX = read_sp_mat("gradX")
502
+ gradY = read_sp_mat("gradY")
503
+
504
+ frames = torch.from_numpy(frames).to(device=device, dtype=dtype)
505
+ mass = torch.from_numpy(mass).to(device=device, dtype=dtype)
506
+ L = utils.sparse_np_to_torch(L).to(device=device, dtype=dtype)
507
+ evals = torch.from_numpy(evals).to(device=device, dtype=dtype)
508
+ evecs = torch.from_numpy(evecs).to(device=device, dtype=dtype)
509
+ gradX = utils.sparse_np_to_torch(gradX).to(device=device, dtype=dtype)
510
+ gradY = utils.sparse_np_to_torch(gradY).to(device=device, dtype=dtype)
511
+
512
+ found = True
513
+
514
+ break
515
+
516
+ except FileNotFoundError:
517
+ print(" cache miss -- constructing operators")
518
+ break
519
+
520
+ except Exception as E:
521
+ print("unexpected error loading file: " + str(E))
522
+ print("-- constructing operators")
523
+ break
524
+
525
+ if not found:
526
+
527
+ # No matching entry found; recompute.
528
+ frames, mass, L, evals, evecs, gradX, gradY = compute_operators(verts, faces, k_eig, normals=normals)
529
+
530
+ dtype_np = np.float32
531
+
532
+ # Store it in the cache
533
+ if op_cache_dir is not None:
534
+
535
+ L_np = utils.sparse_torch_to_np(L).astype(dtype_np)
536
+ gradX_np = utils.sparse_torch_to_np(gradX).astype(dtype_np)
537
+ gradY_np = utils.sparse_torch_to_np(gradY).astype(dtype_np)
538
+
539
+ np.savez(
540
+ search_path,
541
+ verts=verts_np.astype(dtype_np),
542
+ frames=toNP(frames).astype(dtype_np),
543
+ faces=faces_np,
544
+ k_eig=k_eig,
545
+ mass=toNP(mass).astype(dtype_np),
546
+ L_data=L_np.data.astype(dtype_np),
547
+ L_indices=L_np.indices,
548
+ L_indptr=L_np.indptr,
549
+ L_shape=L_np.shape,
550
+ evals=toNP(evals).astype(dtype_np),
551
+ evecs=toNP(evecs).astype(dtype_np),
552
+ gradX_data=gradX_np.data.astype(dtype_np),
553
+ gradX_indices=gradX_np.indices,
554
+ gradX_indptr=gradX_np.indptr,
555
+ gradX_shape=gradX_np.shape,
556
+ gradY_data=gradY_np.data.astype(dtype_np),
557
+ gradY_indices=gradY_np.indices,
558
+ gradY_indptr=gradY_np.indptr,
559
+ gradY_shape=gradY_np.shape,
560
+ )
561
+
562
+ return frames, mass, L, evals, evecs, gradX, gradY
563
+
564
+
565
+ def to_basis(values, basis, massvec):
566
+ """
567
+ Transform data in to an orthonormal basis (where orthonormal is wrt to massvec)
568
+ Inputs:
569
+ - values: (B,V,D)
570
+ - basis: (B,V,K)
571
+ - massvec: (B,V)
572
+ Outputs:
573
+ - (B,K,D) transformed values
574
+ """
575
+ basisT = basis.transpose(-2, -1)
576
+ return torch.matmul(basisT, values * massvec.unsqueeze(-1))
577
+
578
+
579
+ def from_basis(values, basis):
580
+ """
581
+ Transform data out of an orthonormal basis
582
+ Inputs:
583
+ - values: (K,D)
584
+ - basis: (V,K)
585
+ Outputs:
586
+ - (V,D) reconstructed values
587
+ """
588
+ if values.is_complex() or basis.is_complex():
589
+ return utils.cmatmul(utils.ensure_complex(basis), utils.ensure_complex(values))
590
+ else:
591
+ return torch.matmul(basis, values)
592
+
593
+
594
+ def compute_hks(evals, evecs, scales):
595
+ """
596
+ Inputs:
597
+ - evals: (K) eigenvalues
598
+ - evecs: (V,K) values
599
+ - scales: (S) times
600
+ Outputs:
601
+ - (V,S) hks values
602
+ """
603
+
604
+ # expand batch
605
+ if len(evals.shape) == 1:
606
+ expand_batch = True
607
+ evals = evals.unsqueeze(0)
608
+ evecs = evecs.unsqueeze(0)
609
+ scales = scales.unsqueeze(0)
610
+ else:
611
+ expand_batch = False
612
+
613
+ # TODO could be a matmul
614
+ power_coefs = torch.exp(-evals.unsqueeze(1) * scales.unsqueeze(-1)).unsqueeze(1) # (B,1,S,K)
615
+ terms = power_coefs * (evecs * evecs).unsqueeze(2) # (B,V,S,K)
616
+
617
+ out = torch.sum(terms, dim=-1) # (B,V,S)
618
+
619
+ if expand_batch:
620
+ return out.squeeze(0)
621
+ else:
622
+ return out
623
+
624
+
625
+ def compute_hks_autoscale(evals, evecs, count):
626
+ # these scales roughly approximate those suggested in the hks paper
627
+ scales = torch.logspace(-2, 0., steps=count, device=evals.device, dtype=evals.dtype)
628
+ return compute_hks(evals, evecs, scales)
629
+
630
+
631
+ def normalize_positions(pos, faces=None, method='mean', scale_method='max_rad'):
632
+ # center and unit-scale positions
633
+
634
+ if method == 'mean':
635
+ # center using the average point position
636
+ pos = (pos - torch.mean(pos, dim=-2, keepdim=True))
637
+ elif method == 'bbox':
638
+ # center via the middle of the axis-aligned bounding box
639
+ bbox_min = torch.min(pos, dim=-2).values
640
+ bbox_max = torch.max(pos, dim=-2).values
641
+ center = (bbox_max + bbox_min) / 2.
642
+ pos -= center.unsqueeze(-2)
643
+ else:
644
+ raise ValueError("unrecognized method")
645
+
646
+ if scale_method == 'max_rad':
647
+ scale = torch.max(norm(pos), dim=-1, keepdim=True).values.unsqueeze(-1)
648
+ pos = pos / scale
649
+ elif scale_method == 'area':
650
+ if faces is None:
651
+ raise ValueError("must pass faces for area normalization")
652
+ coords = pos[faces]
653
+ vec_A = coords[:, 1, :] - coords[:, 0, :]
654
+ vec_B = coords[:, 2, :] - coords[:, 0, :]
655
+ face_areas = torch.norm(torch.cross(vec_A, vec_B, dim=-1), dim=1) * 0.5
656
+ total_area = torch.sum(face_areas)
657
+ scale = (1. / torch.sqrt(total_area))
658
+ pos = pos * scale
659
+ else:
660
+ raise ValueError("unrecognized scale method")
661
+ return pos
662
+
663
+
664
+ # Finds the k nearest neighbors of source on target.
665
+ # Return is two tensors (distances, indices). Returned points will be sorted in increasing order of distance.
666
+ def find_knn(points_source, points_target, k, largest=False, omit_diagonal=False, method='brute'):
667
+
668
+ if omit_diagonal and points_source.shape[0] != points_target.shape[0]:
669
+ raise ValueError("omit_diagonal can only be used when source and target are same shape")
670
+
671
+ if method != 'cpu_kd' and points_source.shape[0] * points_target.shape[0] > 1e8:
672
+ method = 'cpu_kd'
673
+ print("switching to cpu_kd knn")
674
+
675
+ if method == 'brute':
676
+
677
+ # Expand so both are NxMx3 tensor
678
+ points_source_expand = points_source.unsqueeze(1)
679
+ points_source_expand = points_source_expand.expand(-1, points_target.shape[0], -1)
680
+ points_target_expand = points_target.unsqueeze(0)
681
+ points_target_expand = points_target_expand.expand(points_source.shape[0], -1, -1)
682
+
683
+ diff_mat = points_source_expand - points_target_expand
684
+ dist_mat = norm(diff_mat)
685
+
686
+ if omit_diagonal:
687
+ torch.diagonal(dist_mat)[:] = float('inf')
688
+
689
+ result = torch.topk(dist_mat, k=k, largest=largest, sorted=True)
690
+ return result
691
+
692
+ elif method == 'cpu_kd':
693
+
694
+ if largest:
695
+ raise ValueError("can't do largest with cpu_kd")
696
+
697
+ points_source_np = toNP(points_source)
698
+ points_target_np = toNP(points_target)
699
+
700
+ # Build the tree
701
+ kd_tree = sklearn.neighbors.KDTree(points_target_np)
702
+
703
+ k_search = k + 1 if omit_diagonal else k
704
+ _, neighbors = kd_tree.query(points_source_np, k=k_search)
705
+
706
+ if omit_diagonal:
707
+ # Mask out self element
708
+ mask = neighbors != np.arange(neighbors.shape[0])[:, np.newaxis]
709
+
710
+ # make sure we mask out exactly one element in each row, in rare case of many duplicate points
711
+ mask[np.sum(mask, axis=1) == mask.shape[1], -1] = False
712
+
713
+ neighbors = neighbors[mask].reshape((neighbors.shape[0], neighbors.shape[1] - 1))
714
+
715
+ inds = torch.tensor(neighbors, device=points_source.device, dtype=torch.int64)
716
+ dists = norm(points_source.unsqueeze(1).expand(-1, k, -1) - points_target[inds])
717
+
718
+ return dists, inds
719
+
720
+ else:
721
+ raise ValueError("unrecognized method")
722
+
723
+
724
+ def farthest_point_sampling(points, n_sample):
725
+ # Torch in, torch out. Returns a |V| mask with n_sample elements set to true.
726
+
727
+ N = points.shape[0]
728
+ if (n_sample > N):
729
+ raise ValueError("not enough points to sample")
730
+
731
+ chosen_mask = torch.zeros(N, dtype=torch.bool, device=points.device)
732
+ min_dists = torch.ones(N, dtype=points.dtype, device=points.device) * float('inf')
733
+
734
+ # pick the centermost first point
735
+ points = normalize_positions(points)
736
+ i = torch.min(norm2(points), dim=0).indices
737
+ chosen_mask[i] = True
738
+
739
+ for _ in range(n_sample - 1):
740
+
741
+ # update distance
742
+ dists = norm2(points[i, :].unsqueeze(0) - points)
743
+ min_dists = torch.minimum(dists, min_dists)
744
+
745
+ # take the farthest
746
+ i = torch.max(min_dists, dim=0).indices.item()
747
+ chosen_mask[i] = True
748
+
749
+ return chosen_mask
750
+
751
+
752
+ def geodesic_label_errors(target_verts,
753
+ target_faces,
754
+ pred_labels,
755
+ gt_labels,
756
+ normalization='diameter',
757
+ geodesic_cache_dir=None):
758
+ """
759
+ Return a vector of distances between predicted and ground-truth lables (normalized by geodesic diameter or area)
760
+
761
+ This method is SLOW when it needs to recompute geodesic distances.
762
+ """
763
+
764
+ # move all to numpy cpu
765
+ target_verts = toNP(target_verts)
766
+ target_faces = toNP(target_faces)
767
+
768
+ pred_labels = toNP(pred_labels)
769
+ gt_labels = toNP(gt_labels)
770
+
771
+ dists = get_all_pairs_geodesic_distance(target_verts, target_faces, geodesic_cache_dir)
772
+
773
+ result_dists = dists[pred_labels, gt_labels]
774
+
775
+ if normalization == 'diameter':
776
+ geodesic_diameter = np.max(dists)
777
+ normalized_result_dists = result_dists / geodesic_diameter
778
+ elif normalization == 'area':
779
+ total_area = torch.sum(face_area(torch.tensor(target_verts), torch.tensor(target_faces)))
780
+ normalized_result_dists = result_dists / torch.sqrt(total_area)
781
+ else:
782
+ raise ValueError('unrecognized normalization')
783
+
784
+ return normalized_result_dists
785
+
786
+
787
+ # This function and the helper class below are to support parallel computation of all-pairs geodesic distance
788
+ def all_pairs_geodesic_worker(verts, faces, i):
789
+ import igl
790
+
791
+ N = verts.shape[0]
792
+
793
+ # TODO: this re-does a ton of work, since it is called independently each time. Some custom C++ code could surely make it faster.
794
+ sources = np.array([i])[:, np.newaxis]
795
+ targets = np.arange(N)[:, np.newaxis]
796
+ dist_vec = igl.exact_geodesic(verts, faces, sources, targets)
797
+
798
+ return dist_vec
799
+
800
+
801
+ class AllPairsGeodesicEngine(object):
802
+
803
+ def __init__(self, verts, faces):
804
+ self.verts = verts
805
+ self.faces = faces
806
+
807
+ def __call__(self, i):
808
+ return all_pairs_geodesic_worker(self.verts, self.faces, i)
809
+
810
+
811
+ def get_all_pairs_geodesic_distance(verts_np, faces_np, geodesic_cache_dir=None):
812
+ """
813
+ Return a gigantic VxV dense matrix containing the all-pairs geodesic distance matrix. Internally caches, recomputing only if necessary.
814
+
815
+ (numpy in, numpy out)
816
+ """
817
+
818
+ # need libigl for geodesic call
819
+ try:
820
+ import igl
821
+ except ImportError as e:
822
+ raise ImportError("Must have python libigl installed for all-pairs geodesics. `conda install -c conda-forge igl`")
823
+
824
+ # Check the cache
825
+ found = False
826
+ if geodesic_cache_dir is not None:
827
+ utils.ensure_dir_exists(geodesic_cache_dir)
828
+ hash_key_str = str(utils.hash_arrays((verts_np, faces_np)))
829
+ # print("Building operators for input with hash: " + hash_key_str)
830
+
831
+ # Search through buckets with matching hashes. When the loop exits, this
832
+ # is the bucket index of the file we should write to.
833
+ i_cache_search = 0
834
+ while True:
835
+
836
+ # Form the name of the file to check
837
+ search_path = os.path.join(geodesic_cache_dir, hash_key_str + "_" + str(i_cache_search) + ".npz")
838
+
839
+ try:
840
+ npzfile = np.load(search_path, allow_pickle=True)
841
+ cache_verts = npzfile["verts"]
842
+ cache_faces = npzfile["faces"]
843
+
844
+ # If the cache doesn't match, keep looking
845
+ if (not np.array_equal(verts_np, cache_verts)) or (not np.array_equal(faces_np, cache_faces)):
846
+ i_cache_search += 1
847
+ continue
848
+
849
+ # This entry matches! Return it.
850
+ found = True
851
+ result_dists = npzfile["dist"]
852
+ break
853
+
854
+ except FileNotFoundError:
855
+ break
856
+
857
+ if not found:
858
+
859
+ print("Computing all-pairs geodesic distance (warning: SLOW!)")
860
+
861
+ # Not found, compute from scratch
862
+ # warning: slowwwwwww
863
+
864
+ N = verts_np.shape[0]
865
+
866
+ try:
867
+ pool = Pool(None) # on 8 processors
868
+ engine = AllPairsGeodesicEngine(verts_np, faces_np)
869
+ outputs = pool.map(engine, range(N))
870
+ finally: # To make sure processes are closed in the end, even if errors happen
871
+ pool.close()
872
+ pool.join()
873
+
874
+ result_dists = np.array(outputs)
875
+
876
+ # replace any failed values with nan
877
+ result_dists = np.nan_to_num(result_dists, nan=np.nan, posinf=np.nan, neginf=np.nan)
878
+
879
+ # we expect that this should be a symmetric matrix, but it might not be. Take the min of the symmetric values to make it symmetric
880
+ result_dists = np.fmin(result_dists, np.transpose(result_dists))
881
+
882
+ # on rare occaisions MMP fails, yielding nan/inf; set it to the largest non-failed value if so
883
+ max_dist = np.nanmax(result_dists)
884
+ result_dists = np.nan_to_num(result_dists, nan=max_dist, posinf=max_dist, neginf=max_dist)
885
+
886
+ print("...finished computing all-pairs geodesic distance")
887
+
888
+ # put it in the cache if possible
889
+ if geodesic_cache_dir is not None:
890
+
891
+ print("saving geodesic distances to cache: " + str(geodesic_cache_dir))
892
+
893
+ # TODO we're potentially saving a double precision but only using a single
894
+ # precision here; could save storage by always saving as floats
895
+ np.savez(search_path, verts=verts_np, faces=faces_np, dist=result_dists)
896
+
897
+ return result_dists
shape_models/layers.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import os.path
4
+ import random
5
+
6
+ import scipy
7
+ import scipy.sparse.linalg as sla
8
+ # ^^^ we NEED to import scipy before torch, or it crashes :(
9
+ # (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+
16
+
17
+ def to_basis(values, basis, massvec):
18
+ """
19
+ Transform data in to an orthonormal basis (where orthonormal is wrt to massvec)
20
+ Inputs:
21
+ - values: (B,V,D)
22
+ - basis: (B,V,K)
23
+ - massvec: (B,V)
24
+ Outputs:
25
+ - (B,K,D) transformed values
26
+ """
27
+ basisT = basis.transpose(-2, -1)
28
+ return torch.matmul(basisT, values * massvec.unsqueeze(-1))
29
+
30
+
31
+ def from_basis(values, basis):
32
+ """
33
+ Transform data out of an orthonormal basis
34
+ Inputs:
35
+ - values: (K,D)
36
+ - basis: (V,K)
37
+ Outputs:
38
+ - (V,D) reconstructed values
39
+ """
40
+ if values.is_complex() or basis.is_complex():
41
+ return utils.cmatmul(utils.ensure_complex(basis), utils.ensure_complex(values))
42
+ else:
43
+ return torch.matmul(basis, values)
44
+
45
+
46
+ class LearnedTimeDiffusion(nn.Module):
47
+ """
48
+ Applies diffusion with learned per-channel t.
49
+
50
+ In the spectral domain this becomes
51
+ f_out = e ^ (lambda_i t) f_in
52
+
53
+ Inputs:
54
+ - values: (V,C) in the spectral domain
55
+ - L: (V,V) sparse laplacian
56
+ - evals: (K) eigenvalues
57
+ - mass: (V) mass matrix diagonal
58
+
59
+ (note: L/evals may be omitted as None depending on method)
60
+ Outputs:
61
+ - (V,C) diffused values
62
+ """
63
+
64
+ def __init__(self, C_inout, method='spectral'):
65
+ super(LearnedTimeDiffusion, self).__init__()
66
+ self.C_inout = C_inout
67
+ self.diffusion_time = nn.Parameter(torch.Tensor(C_inout)) # (C)
68
+ self.method = method # one of ['spectral', 'implicit_dense']
69
+
70
+ nn.init.constant_(self.diffusion_time, 0.0)
71
+
72
+ def forward(self, x, L, mass, evals, evecs):
73
+
74
+ # project times to the positive halfspace
75
+ # (and away from 0 in the incredibly rare chance that they get stuck)
76
+ with torch.no_grad():
77
+ self.diffusion_time.data = torch.clamp(self.diffusion_time, min=1e-8)
78
+
79
+ if x.shape[-1] != self.C_inout:
80
+ raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
81
+ x.shape, self.C_inout))
82
+
83
+ if self.method == 'spectral':
84
+
85
+ # Transform to spectral
86
+ x_spec = to_basis(x, evecs, mass)
87
+
88
+ # Diffuse
89
+ time = self.diffusion_time
90
+ diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * time.unsqueeze(0))
91
+ x_diffuse_spec = diffusion_coefs * x_spec
92
+
93
+ # Transform back to per-vertex
94
+ x_diffuse = from_basis(x_diffuse_spec, evecs)
95
+
96
+ elif self.method == 'implicit_dense':
97
+ V = x.shape[-2]
98
+
99
+ # Form the dense matrices (M + tL) with dims (B,C,V,V)
100
+ mat_dense = L.to_dense().unsqueeze(1).expand(-1, self.C_inout, V, V).clone()
101
+ mat_dense *= self.diffusion_time.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
102
+ mat_dense += torch.diag_embed(mass).unsqueeze(1)
103
+
104
+ # Factor the system
105
+ cholesky_factors = torch.linalg.cholesky(mat_dense)
106
+
107
+ # Solve the system
108
+ rhs = x * mass.unsqueeze(-1)
109
+ rhsT = torch.transpose(rhs, 1, 2).unsqueeze(-1)
110
+ sols = torch.cholesky_solve(rhsT, cholesky_factors)
111
+ x_diffuse = torch.transpose(sols.squeeze(-1), 1, 2)
112
+
113
+ else:
114
+ raise ValueError("unrecognized method")
115
+
116
+ return x_diffuse
117
+
118
+
119
+ class SpatialGradientFeatures(nn.Module):
120
+ """
121
+ Compute dot-products between input vectors. Uses a learned complex-linear layer to keep dimension down.
122
+
123
+ Input:
124
+ - vectors: (V,C,2)
125
+ Output:
126
+ - dots: (V,C) dots
127
+ """
128
+
129
+ def __init__(self, C_inout, with_gradient_rotations=True):
130
+ super(SpatialGradientFeatures, self).__init__()
131
+
132
+ self.C_inout = C_inout
133
+ self.with_gradient_rotations = with_gradient_rotations
134
+
135
+ if (self.with_gradient_rotations):
136
+ self.A_re = nn.Linear(self.C_inout, self.C_inout, bias=False)
137
+ self.A_im = nn.Linear(self.C_inout, self.C_inout, bias=False)
138
+ else:
139
+ self.A = nn.Linear(self.C_inout, self.C_inout, bias=False)
140
+
141
+ # self.norm = nn.InstanceNorm1d(C_inout)
142
+
143
+ def forward(self, vectors):
144
+
145
+ vectorsA = vectors # (V,C)
146
+
147
+ if self.with_gradient_rotations:
148
+ vectorsBreal = self.A_re(vectors[..., 0]) - self.A_im(vectors[..., 1])
149
+ vectorsBimag = self.A_re(vectors[..., 1]) + self.A_im(vectors[..., 0])
150
+ else:
151
+ vectorsBreal = self.A(vectors[..., 0])
152
+ vectorsBimag = self.A(vectors[..., 1])
153
+
154
+ dots = vectorsA[..., 0] * vectorsBreal + vectorsA[..., 1] * vectorsBimag
155
+
156
+ return torch.tanh(dots)
157
+
158
+
159
+ class MiniMLP(nn.Sequential):
160
+ '''
161
+ A simple MLP with configurable hidden layer sizes.
162
+ '''
163
+
164
+ def __init__(self, layer_sizes, dropout=False, activation=nn.ReLU, name="miniMLP"):
165
+ super(MiniMLP, self).__init__()
166
+
167
+ for i in range(len(layer_sizes) - 1):
168
+ is_last = (i + 2 == len(layer_sizes))
169
+
170
+ if dropout and i > 0:
171
+ self.add_module(name + "_mlp_layer_dropout_{:03d}".format(i), nn.Dropout(p=.5))
172
+
173
+ # Affine map
174
+ self.add_module(
175
+ name + "_mlp_layer_{:03d}".format(i),
176
+ nn.Linear(
177
+ layer_sizes[i],
178
+ layer_sizes[i + 1],
179
+ ),
180
+ )
181
+
182
+ # Nonlinearity
183
+ # (but not on the last layer)
184
+ if not is_last:
185
+ self.add_module(name + "_mlp_act_{:03d}".format(i), activation())
186
+
187
+
188
+ class DiffusionNetBlock(nn.Module):
189
+ """
190
+ Inputs and outputs are defined at vertices
191
+ """
192
+
193
+ def __init__(self,
194
+ C_width,
195
+ mlp_hidden_dims,
196
+ dropout=True,
197
+ diffusion_method='spectral',
198
+ with_gradient_features=True,
199
+ with_gradient_rotations=True):
200
+ super(DiffusionNetBlock, self).__init__()
201
+
202
+ # Specified dimensions
203
+ self.C_width = C_width
204
+ self.mlp_hidden_dims = mlp_hidden_dims
205
+
206
+ self.dropout = dropout
207
+ self.with_gradient_features = with_gradient_features
208
+ self.with_gradient_rotations = with_gradient_rotations
209
+
210
+ # Diffusion block
211
+ self.diffusion = LearnedTimeDiffusion(self.C_width, method=diffusion_method)
212
+
213
+ self.MLP_C = 2 * self.C_width
214
+
215
+ if self.with_gradient_features:
216
+ self.gradient_features = SpatialGradientFeatures(self.C_width, with_gradient_rotations=self.with_gradient_rotations)
217
+ self.MLP_C += self.C_width
218
+
219
+ # MLPs
220
+ self.mlp = MiniMLP([self.MLP_C] + self.mlp_hidden_dims + [self.C_width], dropout=self.dropout)
221
+
222
+ def forward(self, x_in, mass, L, evals, evecs, gradX, gradY):
223
+
224
+ # Manage dimensions
225
+ B = x_in.shape[0] # batch dimension
226
+ if x_in.shape[-1] != self.C_width:
227
+ raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
228
+ x_in.shape, self.C_width))
229
+
230
+ # Diffusion block
231
+ x_diffuse = self.diffusion(x_in, L, mass, evals, evecs)
232
+
233
+ # Compute gradient features, if using
234
+ if self.with_gradient_features:
235
+
236
+ # Compute gradients
237
+ x_grads = [
238
+ ] # Manually loop over the batch (if there is a batch dimension) since torch.mm() doesn't support batching
239
+ for b in range(B):
240
+ # gradient after diffusion
241
+ x_gradX = torch.mm(gradX[b, ...], x_diffuse[b, ...])
242
+ x_gradY = torch.mm(gradY[b, ...], x_diffuse[b, ...])
243
+
244
+ x_grads.append(torch.stack((x_gradX, x_gradY), dim=-1))
245
+ x_grad = torch.stack(x_grads, dim=0)
246
+
247
+ # Evaluate gradient features
248
+ x_grad_features = self.gradient_features(x_grad)
249
+
250
+ # Stack inputs to mlp
251
+ feature_combined = torch.cat((x_in, x_diffuse, x_grad_features), dim=-1)
252
+ else:
253
+ # Stack inputs to mlp
254
+ feature_combined = torch.cat((x_in, x_diffuse), dim=-1)
255
+
256
+ # Apply the mlp
257
+ x0_out = self.mlp(feature_combined)
258
+
259
+ # Skip connection
260
+ x0_out = x0_out + x_in
261
+
262
+ return x0_out
263
+
264
+
265
+ class DiffusionNet(nn.Module):
266
+
267
+ def __init__(self,
268
+ C_in,
269
+ C_out,
270
+ C_width=128,
271
+ N_block=4,
272
+ last_activation=None,
273
+ outputs_at='vertices',
274
+ mlp_hidden_dims=None,
275
+ dropout=True,
276
+ with_gradient_features=True,
277
+ with_gradient_rotations=True,
278
+ diffusion_method='spectral',
279
+ num_eigenbasis=128):
280
+ """
281
+ Construct a DiffusionNet.
282
+
283
+ Parameters:
284
+ C_in (int): input dimension
285
+ C_out (int): output dimension
286
+ last_activation (func) a function to apply to the final outputs of the network, such as torch.nn.functional.log_softmax (default: None)
287
+ outputs_at (string) produce outputs at various mesh elements by averaging from vertices. One of ['vertices', 'edges', 'faces', 'global_mean']. (default 'vertices', aka points for a point cloud)
288
+ C_width (int): dimension of internal DiffusionNet blocks (default: 128)
289
+ N_block (int): number of DiffusionNet blocks (default: 4)
290
+ mlp_hidden_dims (list of int): a list of hidden layer sizes for MLPs (default: [C_width, C_width])
291
+ dropout (bool): if True, internal MLPs use dropout (default: True)
292
+ diffusion_method (string): how to evaluate diffusion, one of ['spectral', 'implicit_dense']. If implicit_dense is used, can set k_eig=0, saving precompute.
293
+ with_gradient_features (bool): if True, use gradient features (default: True)
294
+ with_gradient_rotations (bool): if True, use gradient also learn a rotation of each gradient. Set to True if your surface has consistently oriented normals, and False otherwise (default: True)
295
+ num_eigenbasis (int): for trunking the eigenvalues eigenvectors
296
+ """
297
+
298
+ super(DiffusionNet, self).__init__()
299
+
300
+ ## Store parameters
301
+
302
+ # Basic parameters
303
+ self.C_in = C_in
304
+ self.C_out = C_out
305
+ self.C_width = C_width
306
+ self.N_block = N_block
307
+
308
+ # Outputs
309
+ self.last_activation = last_activation
310
+ self.outputs_at = outputs_at
311
+ if outputs_at not in ['vertices', 'edges', 'faces', 'global_mean']:
312
+ raise ValueError("invalid setting for outputs_at")
313
+
314
+ # MLP options
315
+ if mlp_hidden_dims == None:
316
+ mlp_hidden_dims = [C_width, C_width]
317
+ self.mlp_hidden_dims = mlp_hidden_dims
318
+ self.dropout = dropout
319
+
320
+ # Diffusion
321
+ self.diffusion_method = diffusion_method
322
+ if diffusion_method not in ['spectral', 'implicit_dense']:
323
+ raise ValueError("invalid setting for diffusion_method")
324
+ self.num_eigenbasis = num_eigenbasis
325
+
326
+ # Gradient features
327
+ self.with_gradient_features = with_gradient_features
328
+ self.with_gradient_rotations = with_gradient_rotations
329
+
330
+ ## Set up the network
331
+
332
+ # First and last affine layers
333
+ self.first_lin = nn.Linear(C_in, C_width)
334
+ self.last_lin = nn.Linear(C_width, C_out)
335
+
336
+ # DiffusionNet blocks
337
+ self.blocks = []
338
+ for i_block in range(self.N_block):
339
+ block = DiffusionNetBlock(C_width=C_width,
340
+ mlp_hidden_dims=mlp_hidden_dims,
341
+ dropout=dropout,
342
+ diffusion_method=diffusion_method,
343
+ with_gradient_features=with_gradient_features,
344
+ with_gradient_rotations=with_gradient_rotations)
345
+
346
+ self.blocks.append(block)
347
+ self.add_module("block_" + str(i_block), self.blocks[-1])
348
+
349
+ def forward(self, x_in, mass, L=None, evals=None, evecs=None, gradX=None, gradY=None, edges=None, faces=None):
350
+ """
351
+ A forward pass on the DiffusionNet.
352
+
353
+ In the notation below, dimension are:
354
+ - C is the input channel dimension (C_in on construction)
355
+ - C_OUT is the output channel dimension (C_out on construction)
356
+ - N is the number of vertices/points, which CAN be different for each forward pass
357
+ - B is an OPTIONAL batch dimension
358
+ - K_EIG is the number of eigenvalues used for spectral acceleration
359
+ Generally, our data layout it is [N,C] or [B,N,C].
360
+
361
+ Call get_operators() to generate geometric quantities mass/L/evals/evecs/gradX/gradY. Note that depending on the options for the DiffusionNet, not all are strictly necessary.
362
+
363
+ Parameters:
364
+ x_in (tensor): Input features, dimension [N,C] or [B,N,C]
365
+ mass (tensor): Mass vector, dimension [N] or [B,N]
366
+ L (tensor): Laplace matrix, sparse tensor with dimension [N,N] or [B,N,N]
367
+ evals (tensor): Eigenvalues of Laplace matrix, dimension [K_EIG] or [B,K_EIG]
368
+ evecs (tensor): Eigenvectors of Laplace matrix, dimension [N,K_EIG] or [B,N,K_EIG]
369
+ gradX (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
370
+ gradY (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
371
+
372
+ Returns:
373
+ x_out (tensor): Output with dimension [N,C_out] or [B,N,C_out]
374
+ """
375
+
376
+ ## Check dimensions, and append batch dimension if not given
377
+ if x_in.shape[-1] != self.C_in:
378
+ raise ValueError("DiffusionNet was constructed with C_in={}, but x_in has last dim={}".format(
379
+ self.C_in, x_in.shape[-1]))
380
+ N = x_in.shape[-2]
381
+ if len(x_in.shape) == 2:
382
+ appended_batch_dim = True
383
+
384
+ # add a batch dim to all inputs
385
+ x_in = x_in.unsqueeze(0)
386
+ mass = mass.unsqueeze(0)
387
+ if L != None:
388
+ L = L.unsqueeze(0)
389
+ if evals != None:
390
+ evals = evals.unsqueeze(0)
391
+ if evecs != None:
392
+ evecs = evecs.unsqueeze(0)
393
+ if gradX != None:
394
+ gradX = gradX.unsqueeze(0)
395
+ if gradY != None:
396
+ gradY = gradY.unsqueeze(0)
397
+ if edges != None:
398
+ edges = edges.unsqueeze(0)
399
+ if faces != None:
400
+ faces = faces.unsqueeze(0)
401
+
402
+ elif len(x_in.shape) == 3:
403
+ appended_batch_dim = False
404
+
405
+ else:
406
+ raise ValueError("x_in should be tensor with shape [N,C] or [B,N,C]")
407
+
408
+ evals = evals[..., :self.num_eigenbasis]
409
+ evecs = evecs[..., :self.num_eigenbasis]
410
+
411
+ # Apply the first linear layer
412
+ x = self.first_lin(x_in)
413
+
414
+ # Apply each of the blocks
415
+ for b in self.blocks:
416
+ x = b(x, mass, L, evals, evecs, gradX, gradY)
417
+
418
+ # Apply the last linear layer
419
+ x = self.last_lin(x)
420
+
421
+ # Remap output to faces/edges if requested
422
+ if self.outputs_at == 'vertices':
423
+ x_out = x
424
+
425
+ elif self.outputs_at == 'edges':
426
+ # Remap to edges
427
+ x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 2)
428
+ edges_gather = edges.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
429
+ xe = torch.gather(x_gather, 1, edges_gather)
430
+ x_out = torch.mean(xe, dim=-1)
431
+
432
+ elif self.outputs_at == 'faces':
433
+ # Remap to faces
434
+ x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 3)
435
+ faces_gather = faces.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
436
+ xf = torch.gather(x_gather, 1, faces_gather)
437
+ x_out = torch.mean(xf, dim=-1)
438
+
439
+ elif self.outputs_at == 'global_mean':
440
+ # Produce a single global mean ouput.
441
+ # Using a weighted mean according to the point mass/area is discretization-invariant.
442
+ # (A naive mean is not discretization-invariant; it could be affected by sampling a region more densely)
443
+ x_out = torch.sum(x * mass.unsqueeze(-1), dim=-2) / torch.sum(mass, dim=-1, keepdim=True)
444
+
445
+ # Apply last nonlinearity if specified
446
+ if self.last_activation != None:
447
+ x_out = self.last_activation(x_out)
448
+
449
+ # Remove batch dim if we added it
450
+ if appended_batch_dim:
451
+ x_out = x_out.squeeze(0)
452
+
453
+ return x_out
shape_models/utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import time
4
+
5
+ import torch
6
+ import hashlib
7
+ import numpy as np
8
+ import scipy
9
+
10
+ # == Pytorch things
11
+
12
+
13
+ def toNP(x):
14
+ """
15
+ Really, definitely convert a torch tensor to a numpy array
16
+ """
17
+ return x.detach().to(torch.device('cpu')).numpy()
18
+
19
+
20
+ def label_smoothing_log_loss(pred, labels, smoothing=0.0):
21
+ n_class = pred.shape[-1]
22
+ one_hot = torch.zeros_like(pred)
23
+ one_hot[labels] = 1.
24
+ one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
25
+ loss = -(one_hot * pred).sum(dim=-1).mean()
26
+ return loss
27
+
28
+
29
+ # Randomly rotate points.
30
+ # Torch in, torch out
31
+ # Note fornow, builds rotation matrix on CPU.
32
+ def random_rotate_points(pts, randgen=None):
33
+ R = random_rotation_matrix(randgen)
34
+ R = torch.from_numpy(R).to(device=pts.device, dtype=pts.dtype)
35
+ return torch.matmul(pts, R)
36
+
37
+
38
+ def random_rotate_points_y(pts):
39
+ angles = torch.rand(1, device=pts.device, dtype=pts.dtype) * (2. * np.pi)
40
+ rot_mats = torch.zeros(3, 3, device=pts.device, dtype=pts.dtype)
41
+ rot_mats[0, 0] = torch.cos(angles)
42
+ rot_mats[0, 2] = torch.sin(angles)
43
+ rot_mats[2, 0] = -torch.sin(angles)
44
+ rot_mats[2, 2] = torch.cos(angles)
45
+ rot_mats[1, 1] = 1.
46
+
47
+ pts = torch.matmul(pts, rot_mats)
48
+ return pts
49
+
50
+
51
+ # Numpy things
52
+
53
+
54
+ # Numpy sparse matrix to pytorch
55
+ def sparse_np_to_torch(A):
56
+ Acoo = A.tocoo()
57
+ values = Acoo.data
58
+ indices = np.vstack((Acoo.row, Acoo.col))
59
+ shape = Acoo.shape
60
+ return torch.sparse.FloatTensor(torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape)).coalesce()
61
+
62
+
63
+ # Pytorch sparse to numpy csc matrix
64
+ def sparse_torch_to_np(A):
65
+ if len(A.shape) != 2:
66
+ raise RuntimeError("should be a matrix-shaped type; dim is : " + str(A.shape))
67
+
68
+ indices = toNP(A.indices())
69
+ values = toNP(A.values())
70
+
71
+ mat = scipy.sparse.coo_matrix((values, indices), shape=A.shape).tocsc()
72
+
73
+ return mat
74
+
75
+
76
+ # Hash a list of numpy arrays
77
+ def hash_arrays(arrs):
78
+ running_hash = hashlib.sha1()
79
+ for arr in arrs:
80
+ binarr = arr.view(np.uint8)
81
+ running_hash.update(binarr)
82
+ return running_hash.hexdigest()
83
+
84
+
85
+ def random_rotation_matrix(randgen=None):
86
+ """
87
+ Creates a random rotation matrix.
88
+ randgen: if given, a np.random.RandomState instance used for random numbers (for reproducibility)
89
+ """
90
+ # adapted from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c
91
+
92
+ if randgen is None:
93
+ randgen = np.random.RandomState()
94
+
95
+ theta, phi, z = tuple(randgen.rand(3).tolist())
96
+
97
+ theta = theta * 2.0 * np.pi # Rotation about the pole (Z).
98
+ phi = phi * 2.0 * np.pi # For direction of pole deflection.
99
+ z = z * 2.0 # For magnitude of pole deflection.
100
+
101
+ # Compute a vector V used for distributing points over the sphere
102
+ # via the reflection I - V Transpose(V). This formulation of V
103
+ # will guarantee that if x[1] and x[2] are uniformly distributed,
104
+ # the reflected points will be uniform on the sphere. Note that V
105
+ # has length sqrt(2) to eliminate the 2 in the Householder matrix.
106
+
107
+ r = np.sqrt(z)
108
+ Vx, Vy, Vz = V = (np.sin(phi) * r, np.cos(phi) * r, np.sqrt(2.0 - z))
109
+
110
+ st = np.sin(theta)
111
+ ct = np.cos(theta)
112
+
113
+ R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1)))
114
+ # Construct the rotation matrix ( V Transpose(V) - I ) R.
115
+
116
+ M = (np.outer(V, V) - np.eye(3)).dot(R)
117
+ return M
118
+
119
+
120
+ # Python string/file utilities
121
+ def ensure_dir_exists(d):
122
+ if not os.path.exists(d):
123
+ os.makedirs(d)
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # from .utils_legacy import *
2
+ # from .geometry import *
3
+ # from .layers import *
utils/descriptors.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ # https://github.com/RobinMagnet/SimplifiedFmapsLearning/blob/main/learn_zo/data/utils.py
4
+ # https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/HKS_functions.py
5
+ def HKS(evals, evects, time_list,scaled=False):
6
+ """
7
+ Returns the Heat Kernel Signature for num_T different values.
8
+ The values of the time are interpolated in logscale between the limits
9
+ given in the HKS paper. These limits only depends on the eigenvalues.
10
+
11
+ Parameters
12
+ ------------------------
13
+ evals : (K,) array of the K eigenvalues
14
+ evecs : (N,K) array with the K eigenvectors
15
+ time_list : (num_T,) Time values to use
16
+ scaled : (bool) whether to scale for each time value
17
+
18
+ Output
19
+ ------------------------
20
+ HKS : (N,num_T) array where each line is the HKS for a given t
21
+ """
22
+ evals_s = np.asarray(evals).flatten()
23
+ t_list = np.asarray(time_list).flatten()
24
+
25
+ coefs = np.exp(-np.outer(t_list, evals_s)) # (num_T,K)
26
+ # weighted_evects = evects[None, :, :] * coefs[:, None,:] # (num_T,N,K)
27
+ # natural_HKS = np.einsum('tnk,nk->nt', weighted_evects, evects)
28
+ natural_HKS = np.einsum('tk,nk,nk->nt', coefs, evects, evects)
29
+
30
+ if scaled:
31
+ inv_scaling = coefs.sum(1) # (num_T)
32
+ return (1/inv_scaling)[None,:] * natural_HKS
33
+
34
+ return natural_HKS
35
+
36
+
37
+ def lm_HKS(evals, evects, landmarks, time_list, scaled=False):
38
+ """
39
+ Returns the Heat Kernel Signature for some landmarks and time values.
40
+
41
+
42
+ Parameters
43
+ ------------------------
44
+ evects : (N,K) array with the K eigenvectors of the Laplace Beltrami operator
45
+ evals : (K,) array of the K corresponding eigenvalues
46
+ landmarks : (p,) indices of landmarks to compute
47
+ time_list : (num_T,) values of t to use
48
+
49
+ Output
50
+ ------------------------
51
+ landmarks_HKS : (N,num_E*p) array where each column is the HKS for a given t for some landmark
52
+ """
53
+
54
+ evals_s = np.asarray(evals).flatten()
55
+ t_list = np.asarray(time_list).flatten()
56
+
57
+ coefs = np.exp(-np.outer(t_list, evals_s)) # (num_T,K)
58
+ weighted_evects = evects[None, landmarks, :] * coefs[:,None,:] # (num_T,p,K)
59
+
60
+ landmarks_HKS = np.einsum('tpk,nk->ptn', weighted_evects, evects) # (p,num_T,N)
61
+ landmarks_HKS = np.einsum('tk,pk,nk->ptn', coefs, evects[landmarks, :], evects) # (p,num_T,N)
62
+
63
+ if scaled:
64
+ inv_scaling = coefs.sum(1) # (num_T,)
65
+ landmarks_HKS = (1/inv_scaling)[None,:,None] * landmarks_HKS # (p,num_T,N)
66
+
67
+ return rearrange(landmarks_HKS, 'p T N -> N (p T)')
68
+
69
+
70
+ def auto_HKS(evals, evects, num_T, landmarks=None, scaled=True):
71
+ """
72
+ Compute HKS with an automatic choice of tile values
73
+
74
+ Parameters
75
+ ------------------------
76
+ evals : (K,) array of K eigenvalues
77
+ evects : (N,K) array with K eigenvectors
78
+ landmarks : (p,) if not None, indices of landmarks to compute.
79
+ num_T : (int) number values of t to use
80
+ Output
81
+ ------------------------
82
+ HKS or lm_HKS : (N,num_E) or (N,p*num_E) array where each column is the WKS for a given e
83
+ for some landmark
84
+ """
85
+
86
+ abs_ev = sorted(np.abs(evals))
87
+ t_list = np.geomspace(4*np.log(10)/abs_ev[-1], 4*np.log(10)/abs_ev[1], num_T)
88
+
89
+ if landmarks is None:
90
+ return HKS(abs_ev, evects, t_list, scaled=scaled)
91
+ else:
92
+ return lm_HKS(abs_ev, evects, landmarks, t_list, scaled=scaled)
93
+
94
+
95
+ # https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/WKS_functions.py
96
+ def WKS(evals, evects, energy_list, sigma, scaled=False):
97
+ """
98
+ Returns the Wave Kernel Signature for some energy values.
99
+
100
+ Parameters
101
+ ------------------------
102
+ evects : (N,K) array with the K eigenvectors of the Laplace Beltrami operator
103
+ evals : (K,) array of the K corresponding eigenvalues
104
+ energy_list : (num_E,) values of e to use
105
+ sigma : (float) [positive] standard deviation to use
106
+ scaled : (bool) Whether to scale each energy level
107
+
108
+ Output
109
+ ------------------------
110
+ WKS : (N,num_E) array where each column is the WKS for a given e
111
+ """
112
+ assert sigma > 0, f"Sigma should be positive ! Given value : {sigma}"
113
+
114
+ evals = np.asarray(evals).flatten()
115
+ indices = np.where(evals > 1e-5)[0].flatten()
116
+ evals = evals[indices]
117
+ evects = evects[:, indices]
118
+
119
+ e_list = np.asarray(energy_list)
120
+ coefs = np.exp(-np.square(e_list[:,None] - np.log(np.abs(evals))[None,:])/(2*sigma**2)) # (num_E,K)
121
+
122
+ # weighted_evects = evects[None, :, :] * coefs[:,None, :] # (num_E,N,K)
123
+
124
+ # natural_WKS = np.einsum('tnk,nk->nt', weighted_evects, evects) # (N,num_E)
125
+ natural_WKS = np.einsum('tk,nk,nk->nt', coefs, evects, evects)
126
+
127
+ if scaled:
128
+ inv_scaling = coefs.sum(1) # (num_E)
129
+ return (1/inv_scaling)[None,:] * natural_WKS
130
+
131
+ return natural_WKS
132
+
133
+
134
+ def lm_WKS(evals, evects, landmarks, energy_list, sigma, scaled=False):
135
+ """
136
+ Returns the Wave Kernel Signature for some landmarks and energy values.
137
+
138
+
139
+ Parameters
140
+ ------------------------
141
+ evects : (N,K) array with the K eigenvectors of the Laplace Beltrami operator
142
+ evals : (K,) array of the K corresponding eigenvalues
143
+ landmarks : (p,) indices of landmarks to compute
144
+ energy_list : (num_E,) values of e to use
145
+ sigma : int - standard deviation
146
+
147
+ Output
148
+ ------------------------
149
+ landmarks_WKS : (N,num_E*p) array where each column is the WKS for a given e for some landmark
150
+ """
151
+ assert sigma > 0, f"Sigma should be positive ! Given value : {sigma}"
152
+
153
+ evals = np.asarray(evals).flatten()
154
+ indices = np.where(evals > 1e-2)[0].flatten()
155
+ evals = evals[indices]
156
+ evects = evects[:,indices]
157
+
158
+ e_list = np.asarray(energy_list)
159
+ coefs = np.exp(-np.square(e_list[:, None] - np.log(np.abs(evals))[None, :]) / (2*sigma**2)) # (num_E,K)
160
+ # weighted_evects = evects[None, landmarks, :] * coefs[:,None,:] # (num_E,p,K)
161
+
162
+ # landmarks_WKS = np.einsum('tpk,nk->ptn', weighted_evects, evects) # (p,num_E,N)
163
+ landmarks_WKS = np.einsum('tk,pk,nk->ptn', coefs, evects[landmarks, :], evects) # (p,num_E,N)
164
+
165
+ if scaled:
166
+ inv_scaling = coefs.sum(1) # (num_E,)
167
+ landmarks_WKS = (1/inv_scaling)[None,:,None] * landmarks_WKS
168
+
169
+ # return landmarks_WKS.reshape(-1,evects.shape[0]).T # (N,p*num_E)
170
+ return rearrange(landmarks_WKS, 'p T N -> N (p T)')
171
+
172
+
173
+ def auto_WKS(evals, evects, num_E, landmarks=None, scaled=True):
174
+ """
175
+ Compute WKS with an automatic choice of scale and energy
176
+
177
+ Parameters
178
+ ------------------------
179
+ evals : (K,) array of K eigenvalues
180
+ evects : (N,K) array with K eigenvectors
181
+ landmarks : (p,) If not None, indices of landmarks to compute.
182
+ num_E : (int) number values of e to use
183
+ Output
184
+ ------------------------
185
+ WKS or lm_WKS : (N,num_E) or (N,p*num_E) array where each column is the WKS for a given e
186
+ and possibly for some landmarks
187
+ """
188
+ abs_ev = sorted(np.abs(evals))
189
+
190
+ e_min,e_max = np.log(abs_ev[1]),np.log(abs_ev[-1])
191
+ sigma = 7*(e_max-e_min)/num_E
192
+
193
+ e_min += 2*sigma
194
+ e_max -= 2*sigma
195
+
196
+ energy_list = np.linspace(e_min,e_max,num_E)
197
+
198
+ if landmarks is None:
199
+ return WKS(abs_ev, evects, energy_list, sigma, scaled=scaled)
200
+ else:
201
+ return lm_WKS(abs_ev, evects, landmarks, energy_list, sigma, scaled=scaled)
202
+
203
+
204
+ def compute_hks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35, normalize=True):
205
+ """
206
+ Compute Heat Kernel Signature (HKS) descriptors.
207
+
208
+ Args:
209
+ evecs: (N, K) eigenvectors of the Laplace-Beltrami operator
210
+ evals: (K,) eigenvalues of the Laplace-Beltrami operator
211
+ mass: (N,) vertex masses
212
+ n_descr: (int) number of descriptors
213
+ subsample_step: (int) subsampling step
214
+ n_eig: (int) number of eigenvectors to use
215
+
216
+ Returns:
217
+ feats: (N, n_descr) HKS descriptors
218
+ """
219
+ feats = auto_HKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
220
+ feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
221
+ if normalize:
222
+ feats_norm2 = np.einsum('np,n->p', feats**2, mass).flatten()
223
+ feats /= np.sqrt(feats_norm2)[None, :]
224
+ return feats.astype(np.float32)
225
+
226
+
227
+ def compute_wks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35, normalize=True):
228
+ """
229
+ Compute Wave Kernel Signature (WKS) descriptors.
230
+
231
+ Args:
232
+ evecs: (N, K) eigenvectors of the Laplace-Beltrami operator
233
+ evals: (K,) eigenvalues of the Laplace-Beltrami operator
234
+ mass: (N,) vertex masses
235
+ n_descr: (int) number of descriptors
236
+ subsample_step: (int) subsampling step
237
+ n_eig: (int) number of eigenvectors to use
238
+
239
+ Returns:
240
+ feats: (N, n_descr) WKS descriptors
241
+ """
242
+ feats = auto_WKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
243
+ feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
244
+ # print("wks_shape",feats.shape, mass.shape)
245
+ if normalize:
246
+ feats_norm2 = np.einsum('np,n->p', feats**2, mass).flatten()
247
+ feats /= np.sqrt(feats_norm2)[None, :]
248
+ # feats_norm2 = np.einsum('np,n->p', feats**2, mass).flatten()
249
+ # feats /= np.sqrt(feats_norm2)[None, :]
250
+ return feats.astype(np.float32)
utils/eval.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def accuracy(p2p, gt_p2p, D1_geod, return_all=False, sqrt_area=None):
2
+ """
3
+ Computes the geodesic accuracy of a vertex to vertex map. The map goes from
4
+ the target shape to the source shape.
5
+ Borrowed from Robin.
6
+ Parameters
7
+ ----------------------
8
+ p2p : (n2,) - vertex to vertex map giving the index of the matched vertex on the source shape
9
+ for each vertex on the target shape (from a functional map point of view)
10
+ gt_p2p : (n2,) - ground truth mapping between the pairs
11
+ D1_geod : (n1,n1) - geodesic distance between pairs of vertices on the source mesh
12
+ return_all : bool - whether to return all the distances or only the average geodesic distance
13
+
14
+ Output
15
+ -----------------------
16
+ acc : float - average accuracy of the vertex to vertex map
17
+ dists : (n2,) - if return_all is True, returns all the pairwise distances
18
+ """
19
+
20
+ dists = D1_geod[(p2p, gt_p2p)]
21
+ if sqrt_area is not None:
22
+ dists /= sqrt_area
23
+ if return_all:
24
+ return dists.mean(), dists
25
+ return dists.mean()
utils/fmap.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import sys
4
+ import numpy as np
5
+ import scipy.linalg
6
+ from tqdm import tqdm
7
+
8
+ ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
9
+ if ROOT_DIR not in sys.path:
10
+ sys.path.append(ROOT_DIR)
11
+
12
+ from utils_fmaps.misc import KNNSearch
13
+
14
+ try:
15
+ import pynndescent
16
+ index = pynndescent.NNDescent(np.random.random((100, 3)), n_jobs=2)
17
+ del index
18
+ ANN = True
19
+ except ImportError:
20
+ ANN = False
21
+
22
+ # https://github.com/RobinMagnet/pyFM
23
+
24
+
25
+ def FM_to_p2p(FM, eigvects1, eigvects2, use_ANN=False):
26
+ if use_ANN and not ANN:
27
+ raise ValueError('Please install pydescent to achieve Approximate Nearest Neighbor')
28
+
29
+ k2, k1 = FM.shape
30
+
31
+ assert k1 <= eigvects1.shape[1], \
32
+ f'At least {k1} should be provided, here only {eigvects1.shape[1]} are given'
33
+ assert k2 <= eigvects2.shape[1], \
34
+ f'At least {k2} should be provided, here only {eigvects2.shape[1]} are given'
35
+
36
+ if use_ANN:
37
+ index = pynndescent.NNDescent(eigvects1[:, :k1] @ FM.T, n_jobs=8)
38
+ matches, _ = index.query(eigvects2[:, :k2], k=1)
39
+ matches = matches.flatten()
40
+ else:
41
+ tree = KNNSearch(eigvects1[:, :k1] @ FM.T)
42
+ matches = tree.query(eigvects2[:, :k2], k=1).flatten()
43
+
44
+ return matches
45
+
46
+
47
+ def p2p_to_FM(p2p, eigvects1, eigvects2, A2=None):
48
+ if A2 is not None:
49
+ if A2.shape[0] != eigvects2.shape[0]:
50
+ raise ValueError("Can't compute pseudo inverse with subsampled eigenvectors")
51
+
52
+ if len(A2.shape) == 1:
53
+ return eigvects2.T @ (A2[:, None] * eigvects1[p2p, :])
54
+
55
+ return eigvects2.T @ A2 @ eigvects1[p2p, :]
56
+
57
+ return scipy.linalg.lstsq(eigvects2, eigvects1[p2p, :])[0]
58
+
59
+
60
+ def zoomout_iteration(eigvects1, eigvects2, FM, step=1, A2=None, use_ANN=False):
61
+ k2, k1 = FM.shape
62
+ try:
63
+ step1, step2 = step
64
+ except TypeError:
65
+ step1 = step
66
+ step2 = step
67
+ new_k1, new_k2 = k1 + step1, k2 + step2
68
+
69
+ p2p = FM_to_p2p(FM, eigvects1, eigvects2, use_ANN=use_ANN)
70
+ FM_zo = p2p_to_FM(p2p, eigvects1[:, :new_k1], eigvects2[:, :new_k2], A2=A2)
71
+
72
+ return FM_zo
73
+
74
+
75
+ def zoomout_refine(eigvects1,
76
+ eigvects2,
77
+ FM,
78
+ nit=10,
79
+ step=1,
80
+ A2=None,
81
+ subsample=None,
82
+ use_ANN=False,
83
+ return_p2p=False,
84
+ verbose=False):
85
+ k2_0, k1_0 = FM.shape
86
+ try:
87
+ step1, step2 = step
88
+ except TypeError:
89
+ step1 = step
90
+ step2 = step
91
+
92
+ assert k1_0 + nit*step1 <= eigvects1.shape[1], \
93
+ f"Not enough eigenvectors on source : \
94
+ {k1_0 + nit*step1} are needed when {eigvects1.shape[1]} are provided"
95
+ assert k2_0 + nit*step2 <= eigvects2.shape[1], \
96
+ f"Not enough eigenvectors on target : \
97
+ {k2_0 + nit*step2} are needed when {eigvects2.shape[1]} are provided"
98
+
99
+ use_subsample = False
100
+ if subsample is not None:
101
+ use_subsample = True
102
+ sub1, sub2 = subsample
103
+
104
+ FM_zo = FM.copy()
105
+
106
+ ANN_adventage = False
107
+ iterable = range(nit) if not verbose else tqdm(range(nit))
108
+ for it in iterable:
109
+ ANN_adventage = use_ANN and (FM_zo.shape[0] > 90) and (FM_zo.shape[1] > 90)
110
+
111
+ if use_subsample:
112
+ FM_zo = zoomout_iteration(eigvects1[sub1], eigvects2[sub2], FM_zo, A2=None, step=step, use_ANN=ANN_adventage)
113
+
114
+ else:
115
+ FM_zo = zoomout_iteration(eigvects1, eigvects2, FM_zo, A2=A2, step=step, use_ANN=ANN_adventage)
116
+
117
+ if return_p2p:
118
+ p2p_zo = FM_to_p2p(FM_zo, eigvects1, eigvects2, use_ANN=False)
119
+ return FM_zo, p2p_zo
120
+
121
+ return FM_zo
utils/geometry.py ADDED
@@ -0,0 +1,951 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import scipy
2
+ import scipy.sparse.linalg as sla
3
+ # ^^^ we NEED to import scipy before torch, or it crashes :(
4
+ # (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
5
+
6
+ import os.path
7
+ import sys
8
+ import random
9
+ from multiprocessing import Pool
10
+
11
+ import numpy as np
12
+ import scipy.spatial
13
+ import torch
14
+ import sklearn.neighbors
15
+
16
+ import robust_laplacian
17
+ import potpourri3d as pp3d
18
+
19
+ from utils.utils_legacy import toNP, ensure_dir_exists, sparse_np_to_torch, sparse_torch_to_np
20
+ from utils.descriptors import compute_hks, compute_wks
21
+
22
+
23
+ def norm(x, highdim=False):
24
+ """
25
+ Computes norm of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
26
+ """
27
+ return torch.norm(x, dim=len(x.shape) - 1)
28
+
29
+
30
+ def norm2(x, highdim=False):
31
+ """
32
+ Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
33
+ """
34
+ return dot(x, x)
35
+
36
+
37
+ def normalize(x, divide_eps=1e-6, highdim=False):
38
+ """
39
+ Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
40
+ """
41
+ if (len(x.shape) == 1):
42
+ raise ValueError("called normalize() on single vector of dim " + str(x.shape) + " are you sure?")
43
+ if (not highdim and x.shape[-1] > 4):
44
+ raise ValueError("called normalize() with large last dimension " + str(x.shape) + " are you sure?")
45
+ return x / (norm(x, highdim=highdim) + divide_eps).unsqueeze(-1)
46
+
47
+
48
+ def face_coords(verts, faces):
49
+ coords = verts[faces]
50
+ return coords
51
+
52
+
53
+ def cross(vec_A, vec_B):
54
+ return torch.cross(vec_A, vec_B, dim=-1)
55
+
56
+
57
+ def dot(vec_A, vec_B):
58
+ return torch.sum(vec_A * vec_B, dim=-1)
59
+
60
+
61
+ # Given (..., 3) vectors and normals, projects out any components of vecs
62
+ # which lies in the direction of normals. Normals are assumed to be unit.
63
+
64
+
65
+ def project_to_tangent(vecs, unit_normals):
66
+ dots = dot(vecs, unit_normals)
67
+ return vecs - unit_normals * dots.unsqueeze(-1)
68
+
69
+
70
+ def face_area(verts, faces):
71
+ coords = face_coords(verts, faces)
72
+ vec_A = coords[:, 1, :] - coords[:, 0, :]
73
+ vec_B = coords[:, 2, :] - coords[:, 0, :]
74
+
75
+ raw_normal = cross(vec_A, vec_B)
76
+ return 0.5 * norm(raw_normal)
77
+
78
+
79
+ def face_normals(verts, faces, normalized=True):
80
+ coords = face_coords(verts, faces)
81
+ vec_A = coords[:, 1, :] - coords[:, 0, :]
82
+ vec_B = coords[:, 2, :] - coords[:, 0, :]
83
+
84
+ raw_normal = cross(vec_A, vec_B)
85
+
86
+ if normalized:
87
+ return normalize(raw_normal)
88
+
89
+ return raw_normal
90
+
91
+
92
+ def neighborhood_normal(points):
93
+ # points: (N, K, 3) array of neighborhood psoitions
94
+ # points should be centered at origin
95
+ # out: (N,3) array of normals
96
+ # numpy in, numpy out
97
+ (u, s, vh) = np.linalg.svd(points, full_matrices=False)
98
+ normal = vh[:, 2, :]
99
+ return normal / np.linalg.norm(normal, axis=-1, keepdims=True)
100
+
101
+
102
+ def mesh_vertex_normals(verts, faces):
103
+ # numpy in / out
104
+ face_n = toNP(face_normals(torch.tensor(verts), torch.tensor(faces))) # ugly torch <---> numpy
105
+ eps = 1e-3
106
+ vertex_normals = np.zeros(verts.shape)
107
+ for i in range(3):
108
+ np.add.at(vertex_normals, faces[:, i], face_n)
109
+
110
+ vertex_normals = vertex_normals / (eps + np.linalg.norm(vertex_normals, axis=-1, keepdims=True))
111
+
112
+ return vertex_normals
113
+
114
+
115
+ def vertex_normals(verts, faces, n_neighbors_cloud=30):
116
+ verts_np = toNP(verts)
117
+
118
+ if faces.numel() == 0: # point cloud
119
+
120
+ _, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
121
+ neigh_points = verts_np[neigh_inds, :]
122
+ neigh_points = neigh_points - verts_np[:, np.newaxis, :]
123
+ normals = neighborhood_normal(neigh_points)
124
+
125
+ else: # mesh
126
+
127
+ normals = mesh_vertex_normals(verts_np, toNP(faces))
128
+
129
+ # if any are NaN, wiggle slightly and recompute
130
+ bad_normals_mask = np.isnan(normals).any(axis=1, keepdims=True)
131
+ if bad_normals_mask.any():
132
+ bbox = np.amax(verts_np, axis=0) - np.amin(verts_np, axis=0)
133
+ scale = np.linalg.norm(bbox) * 1e-4
134
+ wiggle = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5) * scale
135
+ wiggle_verts = verts_np + bad_normals_mask * wiggle
136
+ normals = mesh_vertex_normals(wiggle_verts, toNP(faces))
137
+
138
+ # if still NaN assign random normals (probably means unreferenced verts in mesh)
139
+ bad_normals_mask = np.isnan(normals).any(axis=1)
140
+ if bad_normals_mask.any():
141
+ normals[bad_normals_mask, :] = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5)[bad_normals_mask, :]
142
+ normals = normals / np.linalg.norm(normals, axis=-1)[:, np.newaxis]
143
+
144
+ normals = torch.from_numpy(normals).to(device=verts.device, dtype=verts.dtype)
145
+
146
+ if torch.any(torch.isnan(normals)):
147
+ raise ValueError("NaN normals :(")
148
+
149
+ return normals
150
+
151
+
152
+ def build_tangent_frames(verts, faces, normals=None):
153
+
154
+ V = verts.shape[0]
155
+ dtype = verts.dtype
156
+ device = verts.device
157
+
158
+ if normals is None:
159
+ vert_normals = vertex_normals(verts, faces) # (V,3)
160
+ elif isinstance(normals, np.ndarray):
161
+ vert_normals = torch.from_numpy(normals).to(dtype=dtype, device=device)
162
+ else:
163
+ vert_normals = normals
164
+
165
+ # = find an orthogonal basis
166
+
167
+ basis_cand1 = torch.tensor([1, 0, 0]).to(device=device, dtype=dtype).expand(V, -1)
168
+ basis_cand2 = torch.tensor([0, 1, 0]).to(device=device, dtype=dtype).expand(V, -1)
169
+
170
+ basisX = torch.where((torch.abs(dot(vert_normals, basis_cand1)) < 0.9).unsqueeze(-1), basis_cand1, basis_cand2)
171
+ basisX = project_to_tangent(basisX, vert_normals)
172
+ basisX = normalize(basisX)
173
+ basisY = cross(vert_normals, basisX)
174
+ frames = torch.stack((basisX, basisY, vert_normals), dim=-2)
175
+
176
+ if torch.any(torch.isnan(frames)):
177
+ raise ValueError("NaN coordinate frame! Must be very degenerate")
178
+
179
+ return frames
180
+
181
+
182
+ def build_grad_point_cloud(verts, frames, n_neighbors_cloud=30):
183
+
184
+ verts_np = toNP(verts)
185
+ frames_np = toNP(frames)
186
+
187
+ _, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
188
+ neigh_points = verts_np[neigh_inds, :]
189
+ neigh_vecs = neigh_points - verts_np[:, np.newaxis, :]
190
+
191
+ # TODO this could easily be way faster. For instance we could avoid the weird edges format and the corresponding pure-python loop via some numpy broadcasting of the same logic. The way it works right now is just to share code with the mesh version. But its low priority since its preprocessing code.
192
+
193
+ edge_inds_from = np.repeat(np.arange(verts.shape[0]), n_neighbors_cloud)
194
+ edges = np.stack((edge_inds_from, neigh_inds.flatten()))
195
+ edge_tangent_vecs = edge_tangent_vectors(verts, frames, edges)
196
+
197
+ return build_grad(verts_np, torch.tensor(edges), edge_tangent_vecs)
198
+
199
+
200
+ def edge_tangent_vectors(verts, frames, edges):
201
+ edge_vecs = verts[edges[1, :], :] - verts[edges[0, :], :]
202
+ basisX = frames[edges[0, :], 0, :]
203
+ basisY = frames[edges[0, :], 1, :]
204
+
205
+ compX = dot(edge_vecs, basisX)
206
+ compY = dot(edge_vecs, basisY)
207
+ edge_tangent = torch.stack((compX, compY), dim=-1)
208
+
209
+ return edge_tangent
210
+
211
+
212
+ def build_grad(verts, edges, edge_tangent_vectors):
213
+ """
214
+ Build a (V, V) complex sparse matrix grad operator. Given real inputs at vertices, produces a complex (vector value) at vertices giving the gradient. All values pointwise.
215
+ - edges: (2, E)
216
+ """
217
+
218
+ edges_np = toNP(edges)
219
+ edge_tangent_vectors_np = toNP(edge_tangent_vectors)
220
+
221
+ # TODO find a way to do this in pure numpy?
222
+
223
+ # Build outgoing neighbor lists
224
+ N = verts.shape[0]
225
+ vert_edge_outgoing = [[] for i in range(N)]
226
+ for iE in range(edges_np.shape[1]):
227
+ tail_ind = edges_np[0, iE]
228
+ tip_ind = edges_np[1, iE]
229
+ if tip_ind != tail_ind:
230
+ vert_edge_outgoing[tail_ind].append(iE)
231
+
232
+ # Build local inversion matrix for each vertex
233
+ row_inds = []
234
+ col_inds = []
235
+ data_vals = []
236
+ eps_reg = 1e-5
237
+ for iV in range(N):
238
+ n_neigh = len(vert_edge_outgoing[iV])
239
+
240
+ lhs_mat = np.zeros((n_neigh, 2))
241
+ rhs_mat = np.zeros((n_neigh, n_neigh + 1))
242
+ ind_lookup = [iV]
243
+ for i_neigh in range(n_neigh):
244
+ iE = vert_edge_outgoing[iV][i_neigh]
245
+ jV = edges_np[1, iE]
246
+ ind_lookup.append(jV)
247
+
248
+ edge_vec = edge_tangent_vectors[iE][:]
249
+ w_e = 1.
250
+
251
+ lhs_mat[i_neigh][:] = w_e * edge_vec
252
+ rhs_mat[i_neigh][0] = w_e * (-1)
253
+ rhs_mat[i_neigh][i_neigh + 1] = w_e * 1
254
+
255
+ lhs_T = lhs_mat.T
256
+ lhs_inv = np.linalg.inv(lhs_T @ lhs_mat + eps_reg * np.identity(2)) @ lhs_T
257
+
258
+ sol_mat = lhs_inv @ rhs_mat
259
+ sol_coefs = (sol_mat[0, :] + 1j * sol_mat[1, :]).T
260
+
261
+ for i_neigh in range(n_neigh + 1):
262
+ i_glob = ind_lookup[i_neigh]
263
+
264
+ row_inds.append(iV)
265
+ col_inds.append(i_glob)
266
+ data_vals.append(sol_coefs[i_neigh])
267
+
268
+ # build the sparse matrix
269
+ row_inds = np.array(row_inds)
270
+ col_inds = np.array(col_inds)
271
+ data_vals = np.array(data_vals)
272
+ mat = scipy.sparse.coo_matrix((data_vals, (row_inds, col_inds)), shape=(N, N)).tocsc()
273
+
274
+ return mat
275
+
276
+
277
+ def compute_operators(verts, faces, k_eig, normals=None, normalize_desc=True):
278
+ """
279
+ Builds spectral operators for a mesh/point cloud. Constructs mass matrix, eigenvalues/vectors for Laplacian, and gradient matrix.
280
+
281
+ See get_operators() for a similar routine that wraps this one with a layer of caching.
282
+
283
+ Torch in / torch out.
284
+
285
+ Arguments:
286
+ - vertices: (V,3) vertex positions
287
+ - faces: (F,3) list of triangular faces. If empty, assumed to be a point cloud.
288
+ - k_eig: number of eigenvectors to use
289
+
290
+ Returns:
291
+ - frames: (V,3,3) X/Y/Z coordinate frame at each vertex. Z coordinate is normal (e.g. [:,2,:] for normals)
292
+ - massvec: (V) real diagonal of lumped mass matrix
293
+ - L: (VxV) real sparse matrix of (weak) Laplacian
294
+ - evals: (k) list of eigenvalues of the Laplacian
295
+ - evecs: (V,k) list of eigenvectors of the Laplacian
296
+ - gradX: (VxV) sparse matrix which gives X-component of gradient in the local basis at the vertex
297
+ - gradY: same as gradX but for Y-component of gradient
298
+
299
+ PyTorch doesn't seem to like complex sparse matrices, so we store the "real" and "imaginary" (aka X and Y) gradient matrices separately, rather than as one complex sparse matrix.
300
+
301
+ Note: for a generalized eigenvalue problem, the mass matrix matters! The eigenvectors are only othrthonormal with respect to the mass matrix, like v^H M v, so the mass (given as the diagonal vector massvec) needs to be used in projections, etc.
302
+ """
303
+ device = verts.device
304
+ dtype = verts.dtype
305
+ V = verts.shape[0]
306
+ is_cloud = faces.numel() == 0
307
+
308
+ eps = 1e-8
309
+
310
+ verts_np = toNP(verts).astype(np.float64)
311
+ faces_np = toNP(faces)
312
+ frames = build_tangent_frames(verts, faces, normals=normals)
313
+ frames_np = toNP(frames)
314
+
315
+ # Build the scalar Laplacian
316
+ if is_cloud:
317
+ L, M = robust_laplacian.point_cloud_laplacian(verts_np)
318
+ massvec_np = M.diagonal()
319
+ else:
320
+ # L, M = robust_laplacian.mesh_laplacian(verts_np, faces_np)
321
+ # massvec_np = M.diagonal()
322
+ L = pp3d.cotan_laplacian(verts_np, faces_np, denom_eps=1e-10)
323
+ massvec_np = pp3d.vertex_areas(verts_np, faces_np)
324
+ massvec_np += eps * np.mean(massvec_np)
325
+
326
+ if (np.isnan(L.data).any()):
327
+ raise RuntimeError("NaN Laplace matrix")
328
+ if (np.isnan(massvec_np).any()):
329
+ raise RuntimeError("NaN mass matrix")
330
+
331
+ # Read off neighbors & rotations from the Laplacian
332
+ L_coo = L.tocoo()
333
+ inds_row = L_coo.row
334
+ inds_col = L_coo.col
335
+
336
+ # === Compute the eigenbasis
337
+ if k_eig > 0:
338
+
339
+ # Prepare matrices
340
+ L_eigsh = (L + scipy.sparse.identity(L.shape[0]) * eps).tocsc()
341
+ massvec_eigsh = massvec_np
342
+ Mmat = scipy.sparse.diags(massvec_eigsh)
343
+ eigs_sigma = eps
344
+
345
+ failcount = 0
346
+ while True:
347
+ try:
348
+ # We would be happy here to lower tol or maxiter since we don't need these to be super precise, but for some reason those parameters seem to have no effect
349
+ evals_np, evecs_np = sla.eigsh(L_eigsh, k=k_eig, M=Mmat, sigma=eigs_sigma)
350
+
351
+ # Clip off any eigenvalues that end up slightly negative due to numerical weirdness
352
+ evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))
353
+
354
+ break
355
+ except Exception as e:
356
+ print(e)
357
+ if (failcount > 3):
358
+ raise ValueError("failed to compute eigendecomp")
359
+ failcount += 1
360
+ print("--- decomp failed; adding eps ===> count: " + str(failcount))
361
+ L_eigsh = L_eigsh + scipy.sparse.identity(L.shape[0]) * (eps * 10**failcount)
362
+
363
+ else: #k_eig == 0
364
+ evals_np = np.zeros((0))
365
+ evecs_np = np.zeros((verts.shape[0], 0))
366
+
367
+ # == Build gradient matrices
368
+
369
+ # For meshes, we use the same edges as were used to build the Laplacian. For point clouds, use a whole local neighborhood
370
+ if is_cloud:
371
+ grad_mat_np = build_grad_point_cloud(verts, frames)
372
+ else:
373
+ edges = torch.tensor(np.stack((inds_row, inds_col), axis=0), device=device, dtype=faces.dtype)
374
+ edge_vecs = edge_tangent_vectors(verts, frames, edges)
375
+ grad_mat_np = build_grad(verts.cpu(), edges.cpu(), edge_vecs.cpu())
376
+
377
+ # Split complex gradient in to two real sparse mats (torch doesn't like complex sparse matrices)
378
+ gradX_np = np.real(grad_mat_np)
379
+ gradY_np = np.imag(grad_mat_np)
380
+
381
+ # === Convert back to torch
382
+ massvec = torch.from_numpy(massvec_np).to(device=device, dtype=dtype)
383
+ L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
384
+ evals = torch.from_numpy(evals_np).to(device=device, dtype=dtype)
385
+ evecs = torch.from_numpy(evecs_np).to(device=device, dtype=dtype)
386
+ gradX = sparse_np_to_torch(gradX_np).to(device=device, dtype=dtype)
387
+ gradY = sparse_np_to_torch(gradY_np).to(device=device, dtype=dtype)
388
+
389
+ hks = torch.from_numpy(compute_hks(evecs_np, evals_np, massvec_np, n_descr=128, subsample_step=1, n_eig=128, normalize=normalize_desc)).to(device=device, dtype=dtype)
390
+ wks = torch.from_numpy(compute_wks(evecs_np, evals_np, massvec_np, n_descr=128, subsample_step=1, n_eig=128, normalize=normalize_desc)).to(device=device, dtype=dtype)
391
+
392
+
393
+ return frames, massvec, L, evals, evecs, gradX, gradY, hks, wks
394
+
395
+
396
+ def get_all_operators(verts_list, faces_list, k_eig, op_cache_dir=None, normals=None):
397
+ N = len(verts_list)
398
+
399
+ frames = [None] * N
400
+ massvec = [None] * N
401
+ L = [None] * N
402
+ evals = [None] * N
403
+ evecs = [None] * N
404
+ gradX = [None] * N
405
+ gradY = [None] * N
406
+
407
+ inds = [i for i in range(N)]
408
+ # process in random order
409
+ # random.shuffle(inds)
410
+
411
+ for num, i in enumerate(inds):
412
+ print("get_all_operators() processing {} / {} {:.3f}%".format(num, N, num / N * 100))
413
+ if normals is None:
414
+ outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir)
415
+ else:
416
+ outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir, normals=normals[i])
417
+ frames[i] = outputs[0]
418
+ massvec[i] = outputs[1]
419
+ L[i] = outputs[2]
420
+ evals[i] = outputs[3]
421
+ evecs[i] = outputs[4]
422
+ gradX[i] = outputs[5]
423
+ gradY[i] = outputs[6]
424
+
425
+ return frames, massvec, L, evals, evecs, gradX, gradY
426
+
427
+
428
+ def get_operators(verts, faces, k_eig=128, cache_path=None, normals=None, overwrite_cache=False):
429
+ """
430
+ See documentation for compute_operators(). This essentailly just wraps a call to compute_operators, using a cache if possible.
431
+ All arrays are always computed using double precision for stability, then truncated to single precision floats to store on disk, and finally returned as a tensor with dtype/device matching the `verts` input.
432
+ """
433
+ if type(verts) == torch.Tensor:
434
+ device = verts.device
435
+ dtype = verts.dtype
436
+ verts_np = toNP(verts)
437
+ else:
438
+ device = "cpu"
439
+ dtype = torch.float32
440
+ verts_np = verts.copy()
441
+ verts = torch.from_numpy(verts).float()
442
+ if type(faces) == torch.Tensor:
443
+ faces_np = toNP(faces)
444
+ else:
445
+ faces_np = faces.copy()
446
+ faces = torch.from_numpy(faces).to(device, dtype=torch.int64)
447
+ is_cloud = faces.numel() == 0
448
+
449
+ if (np.isnan(verts_np).any()):
450
+ raise RuntimeError("tried to construct operators from NaN verts")
451
+
452
+ # Check the cache directory
453
+ # Note 1: Collisions here are exceptionally unlikely, so we could probably just use the hash...
454
+ # but for good measure we check values nonetheless.
455
+ # Note 2: There is a small possibility for race conditions to lead to bucket gaps or duplicate
456
+ # entries in this cache. The good news is that that is totally fine, and at most slightly
457
+ # slows performance with rare extra cache misses.
458
+ found = False
459
+
460
+ if cache_path is not None:
461
+ op_cache_dir = os.path.dirname(cache_path)
462
+ ensure_dir_exists(op_cache_dir)
463
+ # print("Building operators for input with hash: " + hash_key_str)
464
+
465
+ # Search through buckets with matching hashes. When the loop exits, this
466
+ # is the bucket index of the file we should write to.
467
+ i_cache_search = 0
468
+ while True:
469
+ try:
470
+ # print('loading path: ' + str(search_path))
471
+ npzfile = np.load(cache_path, allow_pickle=True)
472
+ cache_verts = npzfile["verts"]
473
+ cache_faces = npzfile["faces"]
474
+ cache_k_eig = npzfile["k_eig"].item()
475
+
476
+ # If the cache doesn't match, keep looking
477
+ if (not np.array_equal(verts, cache_verts)) or (not np.array_equal(faces, cache_faces)):
478
+ i_cache_search += 1
479
+ print("hash collision! searching next.")
480
+ continue
481
+
482
+ # print(" cache hit!")
483
+
484
+ # If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
485
+ # entry below more eigenvalues
486
+ if overwrite_cache:
487
+ print(" overwriting cache by request")
488
+ os.remove(cache_path)
489
+ break
490
+
491
+ if cache_k_eig < k_eig:
492
+ print(" overwriting cache --- not enough eigenvalues")
493
+ os.remove(cache_path)
494
+ break
495
+
496
+ if "L_data" not in npzfile:
497
+ print(" overwriting cache --- entries are absent")
498
+ os.remove(cache_path)
499
+ break
500
+
501
+ def read_sp_mat(prefix):
502
+ data = npzfile[prefix + "_data"]
503
+ indices = npzfile[prefix + "_indices"]
504
+ indptr = npzfile[prefix + "_indptr"]
505
+ shape = npzfile[prefix + "_shape"]
506
+ mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
507
+ return mat
508
+
509
+ # This entry matches! Return it.
510
+ frames = npzfile["frames"]
511
+ mass = npzfile["mass"]
512
+ L = read_sp_mat("L")
513
+ evals = npzfile["evals"][:k_eig]
514
+ evecs = npzfile["evecs"][:, :k_eig]
515
+ gradX = read_sp_mat("gradX")
516
+ gradY = read_sp_mat("gradY")
517
+ frames = npzfile["hks"]
518
+ mass = npzfile["wks"]
519
+
520
+ frames = torch.from_numpy(frames).to(device=device, dtype=dtype)
521
+ mass = torch.from_numpy(mass).to(device=device, dtype=dtype)
522
+ L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
523
+ evals = torch.from_numpy(evals).to(device=device, dtype=dtype)
524
+ evecs = torch.from_numpy(evecs).to(device=device, dtype=dtype)
525
+ gradX = sparse_np_to_torch(gradX).to(device=device, dtype=dtype)
526
+ gradY = sparse_np_to_torch(gradY).to(device=device, dtype=dtype)
527
+ hks = torch.from_numpy(hks).to(device=device, dtype=dtype)
528
+ wks = torch.from_numpy(wks).to(device=device, dtype=dtype)
529
+
530
+ found = True
531
+
532
+ break
533
+
534
+ except FileNotFoundError:
535
+ print(" cache miss -- constructing operators")
536
+ break
537
+
538
+ except Exception as E:
539
+ print("unexpected error loading file: " + str(E))
540
+ print("-- constructing operators")
541
+ break
542
+
543
+ if not found:
544
+
545
+ # No matching entry found; recompute.
546
+ frames, mass, L, evals, evecs, gradX, gradY, hks, wks = compute_operators(verts, faces, k_eig, normals=normals)
547
+
548
+ dtype_np = np.float32
549
+
550
+ # Store it in the cache
551
+ if op_cache_dir is not None:
552
+
553
+ L_np = sparse_torch_to_np(L).astype(dtype_np)
554
+ gradX_np = sparse_torch_to_np(gradX).astype(dtype_np)
555
+ gradY_np = sparse_torch_to_np(gradY).astype(dtype_np)
556
+
557
+ np.savez(
558
+ cache_path,
559
+ verts=verts_np.astype(dtype_np),
560
+ frames=toNP(frames).astype(dtype_np),
561
+ faces=faces_np,
562
+ k_eig=k_eig,
563
+ mass=toNP(mass).astype(dtype_np),
564
+ L_data=L_np.data.astype(dtype_np),
565
+ L_indices=L_np.indices,
566
+ L_indptr=L_np.indptr,
567
+ L_shape=L_np.shape,
568
+ evals=toNP(evals).astype(dtype_np),
569
+ evecs=toNP(evecs).astype(dtype_np),
570
+ gradX_data=gradX_np.data.astype(dtype_np),
571
+ gradX_indices=gradX_np.indices,
572
+ gradX_indptr=gradX_np.indptr,
573
+ gradX_shape=gradX_np.shape,
574
+ gradY_data=gradY_np.data.astype(dtype_np),
575
+ gradY_indices=gradY_np.indices,
576
+ gradY_indptr=gradY_np.indptr,
577
+ gradY_shape=gradY_np.shape,
578
+ hks=hks,
579
+ wks=wks
580
+ )
581
+
582
+ return frames, mass, L, evals, evecs, gradX, gradY, hks, wks
583
+
584
+
585
+ def load_operators(filepath):
586
+ npzfile = np.load(filepath, allow_pickle=False)
587
+
588
+ def read_sp_mat(prefix):
589
+ data = npzfile[prefix + '_data']
590
+ indices = npzfile[prefix + '_indices']
591
+ indptr = npzfile[prefix + '_indptr']
592
+ shape = npzfile[prefix + '_shape']
593
+ mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
594
+ return mat
595
+ if 'verts' in npzfile:
596
+ keyverts = 'verts'
597
+ else:
598
+ keyverts = 'vertices'
599
+ return dict(
600
+ vertices=npzfile[keyverts],
601
+ faces=npzfile['faces'],
602
+ frames=npzfile['frames'],
603
+ mass=npzfile['mass'],
604
+ L=read_sp_mat('L'),
605
+ evals=npzfile['evals'],
606
+ evecs=npzfile['evecs'],
607
+ gradX=read_sp_mat('gradX'),
608
+ gradY=read_sp_mat('gradY'),
609
+ hks=npzfile['hks'],
610
+ wks=npzfile['wks'],
611
+ )
612
+
613
+
614
+ def compute_operators_small(verts, faces, k_eig):
615
+ """
616
+ Builds spectral operators for a mesh/point cloud. Constructs mass matrix, eigenvalues/vectors for Laplacian, and gradient matrix.
617
+
618
+ See get_operators() for a similar routine that wraps this one with a layer of caching.
619
+
620
+ Torch in / torch out.
621
+
622
+ Arguments:
623
+ - vertices: (V,3) vertex positions
624
+ - faces: (F,3) list of triangular faces. If empty, assumed to be a point cloud.
625
+ - k_eig: number of eigenvectors to use
626
+
627
+ Returns:
628
+ - massvec: (V) real diagonal of lumped mass matrix
629
+ - L: (VxV) real sparse matrix of (weak) Laplacian
630
+ - evals: (k) list of eigenvalues of the Laplacian
631
+ - evecs: (V,k) list of eigenvectors of the Laplacian
632
+ - gradX: (VxV) sparse matrix which gives X-component of gradient in the local basis at the vertex
633
+ - gradY: same as gradX but for Y-component of gradient
634
+
635
+ PyTorch doesn't seem to like complex sparse matrices, so we store the "real" and "imaginary" (aka X and Y) gradient matrices separately, rather than as one complex sparse matrix.
636
+
637
+ Note: for a generalized eigenvalue problem, the mass matrix matters! The eigenvectors are only othrthonormal with respect to the mass matrix, like v^H M v, so the mass (given as the diagonal vector massvec) needs to be used in projections, etc.
638
+ """
639
+
640
+ device = verts.device
641
+ dtype = verts.dtype
642
+ V = verts.shape[0]
643
+ is_cloud = faces.numel() == 0
644
+
645
+ eps = 1e-8
646
+
647
+ verts_np = toNP(verts).astype(np.float64)
648
+ faces_np = toNP(faces)
649
+ # Build the scalar Laplacian
650
+ if is_cloud:
651
+ L, M = robust_laplacian.point_cloud_laplacian(verts_np)
652
+ massvec_np = M.diagonal()
653
+ else:
654
+ # L, M = robust_laplacian.mesh_laplacian(verts_np, faces_np)
655
+ # massvec_np = M.diagonal()
656
+ L = pp3d.cotan_laplacian(verts_np, faces_np, denom_eps=1e-10)
657
+ massvec_np = pp3d.vertex_areas(verts_np, faces_np)
658
+ massvec_np += eps * np.mean(massvec_np)
659
+
660
+ if (np.isnan(L.data).any()):
661
+ raise RuntimeError("NaN Laplace matrix")
662
+ if (np.isnan(massvec_np).any()):
663
+ raise RuntimeError("NaN mass matrix")
664
+
665
+ # Read off neighbors & rotations from the Laplacian
666
+ L_coo = L.tocoo()
667
+ inds_row = L_coo.row
668
+ inds_col = L_coo.col
669
+
670
+ # === Compute the eigenbasis
671
+ if k_eig > 0:
672
+
673
+ # Prepare matrices
674
+ L_eigsh = (L + scipy.sparse.identity(L.shape[0]) * eps).tocsc()
675
+ massvec_eigsh = massvec_np
676
+ Mmat = scipy.sparse.diags(massvec_eigsh)
677
+ eigs_sigma = eps
678
+
679
+ failcount = 0
680
+ while True:
681
+ try:
682
+ # We would be happy here to lower tol or maxiter since we don't need these to be super precise, but for some reason those parameters seem to have no effect
683
+ evals_np, evecs_np = sla.eigsh(L_eigsh, k=k_eig, M=Mmat, sigma=eigs_sigma)
684
+
685
+ # Clip off any eigenvalues that end up slightly negative due to numerical weirdness
686
+ evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))
687
+
688
+ break
689
+ except Exception as e:
690
+ print(e)
691
+ if (failcount > 3):
692
+ raise ValueError("failed to compute eigendecomp")
693
+ failcount += 1
694
+ print("--- decomp failed; adding eps ===> count: " + str(failcount))
695
+ L_eigsh = L_eigsh + scipy.sparse.identity(L.shape[0]) * (eps * 10**failcount)
696
+
697
+ else: #k_eig == 0
698
+ evals_np = np.zeros((0))
699
+ evecs_np = np.zeros((verts.shape[0], 0))
700
+ # Split complex gradient in to two real sparse mats (torch doesn't like complex sparse matrices)
701
+
702
+ # === Convert back to torch
703
+ massvec = torch.from_numpy(massvec_np).to(device=device, dtype=dtype)
704
+ L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
705
+ evals = torch.from_numpy(evals_np).to(device=device, dtype=dtype)
706
+ evecs = torch.from_numpy(evecs_np).to(device=device, dtype=dtype)
707
+
708
+ return massvec, L, evals, evecs
709
+
710
+
711
+ def get_operators_small(verts, faces, k_eig=128, cache_path=None, overwrite_cache=False):
712
+ """
713
+ See documentation for compute_operators(). This essentailly just wraps a call to compute_operators, using a cache if possible.
714
+ All arrays are always computed using double precision for stability, then truncated to single precision floats to store on disk, and finally returned as a tensor with dtype/device matching the `verts` input.
715
+ """
716
+ if type(verts) == torch.Tensor:
717
+ device = verts.device
718
+ dtype = verts.dtype
719
+ verts_np = toNP(verts)
720
+ else:
721
+ device = "cpu"
722
+ dtype = torch.float32
723
+ verts_np = verts.copy()
724
+ verts = torch.from_numpy(verts).float()
725
+ if type(faces) == torch.Tensor:
726
+ faces_np = toNP(faces)
727
+ else:
728
+ faces_np = faces.copy()
729
+ faces = torch.from_numpy(faces).to(device, dtype=torch.int64)
730
+ is_cloud = faces.numel() == 0
731
+
732
+ if (np.isnan(verts_np).any()):
733
+ raise RuntimeError("tried to construct operators from NaN verts")
734
+
735
+ # Check the cache directory
736
+ # Note 1: Collisions here are exceptionally unlikely, so we could probably just use the hash...
737
+ # but for good measure we check values nonetheless.
738
+ # Note 2: There is a small possibility for race conditions to lead to bucket gaps or duplicate
739
+ # entries in this cache. The good news is that that is totally fine, and at most slightly
740
+ # slows performance with rare extra cache misses.
741
+ found = False
742
+
743
+ if cache_path is not None:
744
+ op_cache_dir = os.path.dirname(cache_path)
745
+ ensure_dir_exists(op_cache_dir)
746
+ # print("Building operators for input with hash: " + hash_key_str)
747
+
748
+ # Search through buckets with matching hashes. When the loop exits, this
749
+ # is the bucket index of the file we should write to.
750
+ i_cache_search = 0
751
+ while True:
752
+ try:
753
+ # print('loading path: ' + str(search_path))
754
+ npzfile = np.load(cache_path, allow_pickle=True)
755
+ cache_verts = npzfile["verts"]
756
+ cache_faces = npzfile["faces"]
757
+ cache_k_eig = npzfile["k_eig"].item()
758
+
759
+ # If the cache doesn't match, keep looking
760
+ if (not np.array_equal(verts, cache_verts)) or (not np.array_equal(faces, cache_faces)):
761
+ i_cache_search += 1
762
+ print("hash collision! searching next.")
763
+ continue
764
+
765
+ # print(" cache hit!")
766
+
767
+ # If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
768
+ # entry below more eigenvalues
769
+ if overwrite_cache:
770
+ print(" overwriting cache by request")
771
+ os.remove(cache_path)
772
+ break
773
+
774
+ if cache_k_eig < k_eig:
775
+ print(" overwriting cache --- not enough eigenvalues")
776
+ os.remove(cache_path)
777
+ break
778
+
779
+ if "L_data" not in npzfile:
780
+ print(" overwriting cache --- entries are absent")
781
+ os.remove(cache_path)
782
+ break
783
+
784
+ def read_sp_mat(prefix):
785
+ data = npzfile[prefix + "_data"]
786
+ indices = npzfile[prefix + "_indices"]
787
+ indptr = npzfile[prefix + "_indptr"]
788
+ shape = npzfile[prefix + "_shape"]
789
+ mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
790
+ return mat
791
+
792
+ # This entry matches! Return it.
793
+ mass = npzfile["mass"]
794
+ L = read_sp_mat("L")
795
+ evals = npzfile["evals"][:k_eig]
796
+ evecs = npzfile["evecs"][:, :k_eig]
797
+ mass = torch.from_numpy(mass).to(device=device, dtype=dtype)
798
+ L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
799
+ evals = torch.from_numpy(evals).to(device=device, dtype=dtype)
800
+ evecs = torch.from_numpy(evecs).to(device=device, dtype=dtype)
801
+
802
+ found = True
803
+
804
+ break
805
+
806
+ except FileNotFoundError:
807
+ print(cache_path)
808
+ print(" cache miss -- constructing operators")
809
+ break
810
+
811
+ except Exception as E:
812
+ print("unexpected error loading file: " + str(E))
813
+ print("-- constructing operators")
814
+ break
815
+
816
+ if not found:
817
+
818
+ # No matching entry found; recompute.
819
+ mass, L, evals, evecs, = compute_operators_small(verts, faces, k_eig)
820
+
821
+ dtype_np = np.float32
822
+
823
+ # Store it in the cache
824
+ if op_cache_dir is not None:
825
+
826
+ L_np = sparse_torch_to_np(L).astype(dtype_np)
827
+
828
+ np.savez(
829
+ cache_path,
830
+ verts=verts_np.astype(dtype_np),
831
+ faces=faces_np,
832
+ k_eig=k_eig,
833
+ mass=toNP(mass).astype(dtype_np),
834
+ L_data=L_np.data.astype(dtype_np),
835
+ L_indices=L_np.indices,
836
+ L_indptr=L_np.indptr,
837
+ L_shape=L_np.shape,
838
+ evals=toNP(evals).astype(dtype_np),
839
+ evecs=toNP(evecs).astype(dtype_np),
840
+ )
841
+
842
+ return mass, L, evals, evecs
843
+
844
+
845
+ def to_basis(values, basis, massvec):
846
+ """
847
+ Transform data in to an orthonormal basis (where orthonormal is wrt to massvec)
848
+ Inputs:
849
+ - values: (B,V,D)
850
+ - basis: (B,V,K)
851
+ - massvec: (B,V)
852
+ Outputs:
853
+ - (B,K,D) transformed values
854
+ """
855
+ basisT = basis.transpose(-2, -1)
856
+ return torch.matmul(basisT, values * massvec.unsqueeze(-1))
857
+
858
+
859
+
860
+ def normalize_positions(pos, faces=None, method='mean', scale_method='max_rad'):
861
+ # center and unit-scale positions
862
+
863
+ if method == 'mean':
864
+ # center using the average point position
865
+ pos = (pos - torch.mean(pos, dim=-2, keepdim=True))
866
+ elif method == 'bbox':
867
+ # center via the middle of the axis-aligned bounding box
868
+ bbox_min = torch.min(pos, dim=-2).values
869
+ bbox_max = torch.max(pos, dim=-2).values
870
+ center = (bbox_max + bbox_min) / 2.
871
+ pos -= center.unsqueeze(-2)
872
+ else:
873
+ raise ValueError("unrecognized method")
874
+
875
+ if scale_method == 'max_rad':
876
+ scale = torch.max(norm(pos), dim=-1, keepdim=True).values.unsqueeze(-1)
877
+ pos = pos / scale
878
+ elif scale_method == 'area':
879
+ if faces is None:
880
+ raise ValueError("must pass faces for area normalization")
881
+ coords = pos[faces]
882
+ vec_A = coords[:, 1, :] - coords[:, 0, :]
883
+ vec_B = coords[:, 2, :] - coords[:, 0, :]
884
+ face_areas = torch.norm(torch.cross(vec_A, vec_B, dim=-1), dim=1) * 0.5
885
+ total_area = torch.sum(face_areas)
886
+ scale = (1. / torch.sqrt(total_area))
887
+ pos = pos * scale
888
+ else:
889
+ raise ValueError("unrecognized scale method")
890
+ return pos
891
+
892
+
893
+ # Finds the k nearest neighbors of source on target.
894
+ # Return is two tensors (distances, indices). Returned points will be sorted in increasing order of distance.
895
+ def find_knn(points_source, points_target, k, largest=False, omit_diagonal=False, method='brute'):
896
+
897
+ if omit_diagonal and points_source.shape[0] != points_target.shape[0]:
898
+ raise ValueError("omit_diagonal can only be used when source and target are same shape")
899
+
900
+ if method != 'cpu_kd' and points_source.shape[0] * points_target.shape[0] > 1e8:
901
+ method = 'cpu_kd'
902
+ print("switching to cpu_kd knn")
903
+
904
+ if method == 'brute':
905
+
906
+ # Expand so both are NxMx3 tensor
907
+ points_source_expand = points_source.unsqueeze(1)
908
+ points_source_expand = points_source_expand.expand(-1, points_target.shape[0], -1)
909
+ points_target_expand = points_target.unsqueeze(0)
910
+ points_target_expand = points_target_expand.expand(points_source.shape[0], -1, -1)
911
+
912
+ diff_mat = points_source_expand - points_target_expand
913
+ dist_mat = norm(diff_mat)
914
+
915
+ if omit_diagonal:
916
+ torch.diagonal(dist_mat)[:] = float('inf')
917
+
918
+ result = torch.topk(dist_mat, k=k, largest=largest, sorted=True)
919
+ return result
920
+
921
+ elif method == 'cpu_kd':
922
+
923
+ if largest:
924
+ raise ValueError("can't do largest with cpu_kd")
925
+
926
+ points_source_np = toNP(points_source)
927
+ points_target_np = toNP(points_target)
928
+
929
+ # Build the tree
930
+ kd_tree = sklearn.neighbors.KDTree(points_target_np)
931
+
932
+ k_search = k + 1 if omit_diagonal else k
933
+ _, neighbors = kd_tree.query(points_source_np, k=k_search)
934
+
935
+ if omit_diagonal:
936
+ # Mask out self element
937
+ mask = neighbors != np.arange(neighbors.shape[0])[:, np.newaxis]
938
+
939
+ # make sure we mask out exactly one element in each row, in rare case of many duplicate points
940
+ mask[np.sum(mask, axis=1) == mask.shape[1], -1] = False
941
+
942
+ neighbors = neighbors[mask].reshape((neighbors.shape[0], neighbors.shape[1] - 1))
943
+
944
+ inds = torch.tensor(neighbors, device=points_source.device, dtype=torch.int64)
945
+ dists = norm(points_source.unsqueeze(1).expand(-1, k, -1) - points_target[inds])
946
+
947
+ return dists, inds
948
+
949
+ else:
950
+ raise ValueError("unrecognized method")
951
+
utils/io.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import json
3
+ import os
4
+ import os.path as osp
5
+ import re
6
+ import shutil
7
+
8
+
9
+ def is_number(s):
10
+ try:
11
+ float(s)
12
+ return True
13
+ except ValueError:
14
+ return False
15
+
16
+
17
+ def may_create_folder(folder_path):
18
+ if not osp.exists(folder_path):
19
+ oldmask = os.umask(000)
20
+ os.makedirs(folder_path, mode=0o777)
21
+ os.umask(oldmask)
22
+ return True
23
+ return False
24
+
25
+
26
+ def make_clean_folder(folder_path):
27
+ success = may_create_folder(folder_path)
28
+ if not success:
29
+ shutil.rmtree(folder_path)
30
+ may_create_folder(folder_path)
31
+
32
+
33
+ def sorted_alphanum(file_list_ordered):
34
+ convert = lambda text: int(text) if text.isdigit() else text
35
+ alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key) if len(c) > 0]
36
+ return sorted(file_list_ordered, key=alphanum_key)
37
+
38
+
39
+ def list_files(folder_path, name_filter, alphanum_sort=False):
40
+ file_list = [p.name for p in list(Path(folder_path).glob(name_filter))]
41
+ if alphanum_sort:
42
+ return sorted_alphanum(file_list)
43
+ else:
44
+ return sorted(file_list)
45
+
46
+
47
+ def list_folders(folder_path, name_filter=None, alphanum_sort=False):
48
+ folders = list()
49
+ for subfolder in Path(folder_path).iterdir():
50
+ if subfolder.is_dir() and not subfolder.name.startswith('.'):
51
+ folder_name = subfolder.name
52
+ if name_filter is not None:
53
+ if name_filter in folder_name:
54
+ folders.append(folder_name)
55
+ else:
56
+ folders.append(folder_name)
57
+ if alphanum_sort:
58
+ return sorted_alphanum(folders)
59
+ else:
60
+ return sorted(folders)
61
+
62
+
63
+ def read_lines(file_path):
64
+ with open(file_path, 'r') as fin:
65
+ lines = [line.strip() for line in fin.readlines() if len(line.strip()) > 0]
66
+ return lines
67
+
68
+
69
+ def read_strings(file_path):
70
+ with open(file_path, 'r') as fin:
71
+ ret = fin.readlines()
72
+ return ''.join(ret)
73
+
74
+
75
+ def read_json(filepath):
76
+ with open(filepath, 'r') as fh:
77
+ ret = json.load(fh)
78
+ return ret
utils/layers.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import os.path
4
+ import random
5
+
6
+ import scipy
7
+ import scipy.sparse.linalg as sla
8
+ # ^^^ we NEED to import scipy before torch, or it crashes :(
9
+ # (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../')
16
+ if ROOT_DIR not in sys.path:
17
+ sys.path.append(ROOT_DIR)
18
+
19
+ from diffusion_net.utils import toNP
20
+ from diffusion_net.geometry import to_basis, from_basis
21
+
22
+
23
+ class LearnedTimeDiffusion(nn.Module):
24
+ """
25
+ Applies diffusion with learned per-channel t.
26
+
27
+ In the spectral domain this becomes
28
+ f_out = e ^ (lambda_i t) f_in
29
+
30
+ Inputs:
31
+ - values: (V,C) in the spectral domain
32
+ - L: (V,V) sparse laplacian
33
+ - evals: (K) eigenvalues
34
+ - mass: (V) mass matrix diagonal
35
+
36
+ (note: L/evals may be omitted as None depending on method)
37
+ Outputs:
38
+ - (V,C) diffused values
39
+ """
40
+
41
+ def __init__(self, C_inout, method='spectral'):
42
+ super(LearnedTimeDiffusion, self).__init__()
43
+ self.C_inout = C_inout
44
+ self.diffusion_time = nn.Parameter(torch.Tensor(C_inout)) # (C)
45
+ self.method = method # one of ['spectral', 'implicit_dense']
46
+
47
+ nn.init.constant_(self.diffusion_time, 0.0)
48
+
49
+ def forward(self, x, L, mass, evals, evecs):
50
+
51
+ # project times to the positive halfspace
52
+ # (and away from 0 in the incredibly rare chance that they get stuck)
53
+ with torch.no_grad():
54
+ self.diffusion_time.data = torch.clamp(self.diffusion_time, min=1e-8)
55
+
56
+ if x.shape[-1] != self.C_inout:
57
+ raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
58
+ x.shape, self.C_inout))
59
+
60
+ if self.method == 'spectral':
61
+
62
+ # Transform to spectral
63
+ x_spec = to_basis(x, evecs, mass)
64
+
65
+ # Diffuse
66
+ time = self.diffusion_time
67
+ diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * time.unsqueeze(0))
68
+ x_diffuse_spec = diffusion_coefs * x_spec
69
+
70
+ # Transform back to per-vertex
71
+ x_diffuse = from_basis(x_diffuse_spec, evecs)
72
+
73
+ elif self.method == 'implicit_dense':
74
+ V = x.shape[-2]
75
+
76
+ # Form the dense matrices (M + tL) with dims (B,C,V,V)
77
+ mat_dense = L.to_dense().unsqueeze(1).expand(-1, self.C_inout, V, V).clone()
78
+ mat_dense *= self.diffusion_time.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
79
+ mat_dense += torch.diag_embed(mass).unsqueeze(1)
80
+
81
+ # Factor the system
82
+ cholesky_factors = torch.linalg.cholesky(mat_dense)
83
+
84
+ # Solve the system
85
+ rhs = x * mass.unsqueeze(-1)
86
+ rhsT = torch.transpose(rhs, 1, 2).unsqueeze(-1)
87
+ sols = torch.cholesky_solve(rhsT, cholesky_factors)
88
+ x_diffuse = torch.transpose(sols.squeeze(-1), 1, 2)
89
+
90
+ else:
91
+ raise ValueError("unrecognized method")
92
+
93
+ return x_diffuse
94
+
95
+
96
+ class SpatialGradientFeatures(nn.Module):
97
+ """
98
+ Compute dot-products between input vectors. Uses a learned complex-linear layer to keep dimension down.
99
+
100
+ Input:
101
+ - vectors: (V,C,2)
102
+ Output:
103
+ - dots: (V,C) dots
104
+ """
105
+
106
+ def __init__(self, C_inout, with_gradient_rotations=True):
107
+ super(SpatialGradientFeatures, self).__init__()
108
+
109
+ self.C_inout = C_inout
110
+ self.with_gradient_rotations = with_gradient_rotations
111
+
112
+ if (self.with_gradient_rotations):
113
+ self.A_re = nn.Linear(self.C_inout, self.C_inout, bias=False)
114
+ self.A_im = nn.Linear(self.C_inout, self.C_inout, bias=False)
115
+ else:
116
+ self.A = nn.Linear(self.C_inout, self.C_inout, bias=False)
117
+
118
+ # self.norm = nn.InstanceNorm1d(C_inout)
119
+
120
+ def forward(self, vectors):
121
+
122
+ vectorsA = vectors # (V,C)
123
+
124
+ if self.with_gradient_rotations:
125
+ vectorsBreal = self.A_re(vectors[..., 0]) - self.A_im(vectors[..., 1])
126
+ vectorsBimag = self.A_re(vectors[..., 1]) + self.A_im(vectors[..., 0])
127
+ else:
128
+ vectorsBreal = self.A(vectors[..., 0])
129
+ vectorsBimag = self.A(vectors[..., 1])
130
+
131
+ dots = vectorsA[..., 0] * vectorsBreal + vectorsA[..., 1] * vectorsBimag
132
+
133
+ return torch.tanh(dots)
134
+
135
+
136
+ class MiniMLP(nn.Sequential):
137
+ '''
138
+ A simple MLP with configurable hidden layer sizes.
139
+ '''
140
+
141
+ def __init__(self, layer_sizes, dropout=False, activation=nn.ReLU, name="miniMLP"):
142
+ super(MiniMLP, self).__init__()
143
+
144
+ for i in range(len(layer_sizes) - 1):
145
+ is_last = (i + 2 == len(layer_sizes))
146
+
147
+ if dropout and i > 0:
148
+ self.add_module(name + "_mlp_layer_dropout_{:03d}".format(i), nn.Dropout(p=.5))
149
+
150
+ # Affine map
151
+ self.add_module(
152
+ name + "_mlp_layer_{:03d}".format(i),
153
+ nn.Linear(
154
+ layer_sizes[i],
155
+ layer_sizes[i + 1],
156
+ ),
157
+ )
158
+
159
+ # Nonlinearity
160
+ # (but not on the last layer)
161
+ if not is_last:
162
+ self.add_module(name + "_mlp_act_{:03d}".format(i), activation())
163
+
164
+
165
+ class DiffusionNetBlock(nn.Module):
166
+ """
167
+ Inputs and outputs are defined at vertices
168
+ """
169
+
170
+ def __init__(self,
171
+ C_width,
172
+ mlp_hidden_dims,
173
+ dropout=True,
174
+ diffusion_method='spectral',
175
+ with_gradient_features=True,
176
+ with_gradient_rotations=True):
177
+ super(DiffusionNetBlock, self).__init__()
178
+
179
+ # Specified dimensions
180
+ self.C_width = C_width
181
+ self.mlp_hidden_dims = mlp_hidden_dims
182
+
183
+ self.dropout = dropout
184
+ self.with_gradient_features = with_gradient_features
185
+ self.with_gradient_rotations = with_gradient_rotations
186
+
187
+ # Diffusion block
188
+ self.diffusion = LearnedTimeDiffusion(self.C_width, method=diffusion_method)
189
+
190
+ self.MLP_C = 2 * self.C_width
191
+
192
+ if self.with_gradient_features:
193
+ self.gradient_features = SpatialGradientFeatures(self.C_width, with_gradient_rotations=self.with_gradient_rotations)
194
+ self.MLP_C += self.C_width
195
+
196
+ # MLPs
197
+ self.mlp = MiniMLP([self.MLP_C] + self.mlp_hidden_dims + [self.C_width], dropout=self.dropout)
198
+
199
+ def forward(self, x_in, mass, L, evals, evecs, gradX, gradY):
200
+
201
+ # Manage dimensions
202
+ B = x_in.shape[0] # batch dimension
203
+ if x_in.shape[-1] != self.C_width:
204
+ raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
205
+ x_in.shape, self.C_width))
206
+
207
+ # Diffusion block
208
+ x_diffuse = self.diffusion(x_in, L, mass, evals, evecs)
209
+
210
+ # Compute gradient features, if using
211
+ if self.with_gradient_features:
212
+
213
+ # Compute gradients
214
+ x_grads = [
215
+ ] # Manually loop over the batch (if there is a batch dimension) since torch.mm() doesn't support batching
216
+ for b in range(B):
217
+ # gradient after diffusion
218
+ x_gradX = torch.mm(gradX[b, ...], x_diffuse[b, ...])
219
+ x_gradY = torch.mm(gradY[b, ...], x_diffuse[b, ...])
220
+
221
+ x_grads.append(torch.stack((x_gradX, x_gradY), dim=-1))
222
+ x_grad = torch.stack(x_grads, dim=0)
223
+
224
+ # Evaluate gradient features
225
+ x_grad_features = self.gradient_features(x_grad)
226
+
227
+ # Stack inputs to mlp
228
+ feature_combined = torch.cat((x_in, x_diffuse, x_grad_features), dim=-1)
229
+ else:
230
+ # Stack inputs to mlp
231
+ feature_combined = torch.cat((x_in, x_diffuse), dim=-1)
232
+
233
+ # Apply the mlp
234
+ x0_out = self.mlp(feature_combined)
235
+
236
+ # Skip connection
237
+ x0_out = x0_out + x_in
238
+
239
+ return x0_out
240
+
241
+
242
+ class DiffusionNet(nn.Module):
243
+
244
+ def __init__(self,
245
+ C_in,
246
+ C_out,
247
+ C_width=128,
248
+ N_block=4,
249
+ last_activation=None,
250
+ outputs_at='vertices',
251
+ mlp_hidden_dims=None,
252
+ dropout=True,
253
+ with_gradient_features=True,
254
+ with_gradient_rotations=True,
255
+ diffusion_method='spectral',
256
+ num_eigenbasis=128):
257
+ """
258
+ Construct a DiffusionNet.
259
+
260
+ Parameters:
261
+ C_in (int): input dimension
262
+ C_out (int): output dimension
263
+ last_activation (func) a function to apply to the final outputs of the network, such as torch.nn.functional.log_softmax (default: None)
264
+ outputs_at (string) produce outputs at various mesh elements by averaging from vertices. One of ['vertices', 'edges', 'faces', 'global_mean']. (default 'vertices', aka points for a point cloud)
265
+ C_width (int): dimension of internal DiffusionNet blocks (default: 128)
266
+ N_block (int): number of DiffusionNet blocks (default: 4)
267
+ mlp_hidden_dims (list of int): a list of hidden layer sizes for MLPs (default: [C_width, C_width])
268
+ dropout (bool): if True, internal MLPs use dropout (default: True)
269
+ diffusion_method (string): how to evaluate diffusion, one of ['spectral', 'implicit_dense']. If implicit_dense is used, can set k_eig=0, saving precompute.
270
+ with_gradient_features (bool): if True, use gradient features (default: True)
271
+ with_gradient_rotations (bool): if True, use gradient also learn a rotation of each gradient. Set to True if your surface has consistently oriented normals, and False otherwise (default: True)
272
+ num_eigenbasis (int): for trunking the eigenvalues eigenvectors
273
+ """
274
+
275
+ super(DiffusionNet, self).__init__()
276
+
277
+ ## Store parameters
278
+
279
+ # Basic parameters
280
+ self.C_in = C_in
281
+ self.C_out = C_out
282
+ self.C_width = C_width
283
+ self.N_block = N_block
284
+
285
+ # Outputs
286
+ self.last_activation = last_activation
287
+ self.outputs_at = outputs_at
288
+ if outputs_at not in ['vertices', 'edges', 'faces', 'global_mean']:
289
+ raise ValueError("invalid setting for outputs_at")
290
+
291
+ # MLP options
292
+ if mlp_hidden_dims == None:
293
+ mlp_hidden_dims = [C_width, C_width]
294
+ self.mlp_hidden_dims = mlp_hidden_dims
295
+ self.dropout = dropout
296
+
297
+ # Diffusion
298
+ self.diffusion_method = diffusion_method
299
+ if diffusion_method not in ['spectral', 'implicit_dense']:
300
+ raise ValueError("invalid setting for diffusion_method")
301
+ self.num_eigenbasis = num_eigenbasis
302
+
303
+ # Gradient features
304
+ self.with_gradient_features = with_gradient_features
305
+ self.with_gradient_rotations = with_gradient_rotations
306
+
307
+ ## Set up the network
308
+
309
+ # First and last affine layers
310
+ self.first_lin = nn.Linear(C_in, C_width)
311
+ self.last_lin = nn.Linear(C_width, C_out)
312
+
313
+ # DiffusionNet blocks
314
+ self.blocks = []
315
+ for i_block in range(self.N_block):
316
+ block = DiffusionNetBlock(C_width=C_width,
317
+ mlp_hidden_dims=mlp_hidden_dims,
318
+ dropout=dropout,
319
+ diffusion_method=diffusion_method,
320
+ with_gradient_features=with_gradient_features,
321
+ with_gradient_rotations=with_gradient_rotations)
322
+
323
+ self.blocks.append(block)
324
+ self.add_module("block_" + str(i_block), self.blocks[-1])
325
+
326
+ def forward(self, x_in, mass, L=None, evals=None, evecs=None, gradX=None, gradY=None, edges=None, faces=None):
327
+ """
328
+ A forward pass on the DiffusionNet.
329
+
330
+ In the notation below, dimension are:
331
+ - C is the input channel dimension (C_in on construction)
332
+ - C_OUT is the output channel dimension (C_out on construction)
333
+ - N is the number of vertices/points, which CAN be different for each forward pass
334
+ - B is an OPTIONAL batch dimension
335
+ - K_EIG is the number of eigenvalues used for spectral acceleration
336
+ Generally, our data layout it is [N,C] or [B,N,C].
337
+
338
+ Call get_operators() to generate geometric quantities mass/L/evals/evecs/gradX/gradY. Note that depending on the options for the DiffusionNet, not all are strictly necessary.
339
+
340
+ Parameters:
341
+ x_in (tensor): Input features, dimension [N,C] or [B,N,C]
342
+ mass (tensor): Mass vector, dimension [N] or [B,N]
343
+ L (tensor): Laplace matrix, sparse tensor with dimension [N,N] or [B,N,N]
344
+ evals (tensor): Eigenvalues of Laplace matrix, dimension [K_EIG] or [B,K_EIG]
345
+ evecs (tensor): Eigenvectors of Laplace matrix, dimension [N,K_EIG] or [B,N,K_EIG]
346
+ gradX (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
347
+ gradY (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
348
+
349
+ Returns:
350
+ x_out (tensor): Output with dimension [N,C_out] or [B,N,C_out]
351
+ """
352
+
353
+ ## Check dimensions, and append batch dimension if not given
354
+ if x_in.shape[-1] != self.C_in:
355
+ raise ValueError("DiffusionNet was constructed with C_in={}, but x_in has last dim={}".format(
356
+ self.C_in, x_in.shape[-1]))
357
+ N = x_in.shape[-2]
358
+ if len(x_in.shape) == 2:
359
+ appended_batch_dim = True
360
+
361
+ # add a batch dim to all inputs
362
+ x_in = x_in.unsqueeze(0)
363
+ mass = mass.unsqueeze(0)
364
+ if L != None:
365
+ L = L.unsqueeze(0)
366
+ if evals != None:
367
+ evals = evals.unsqueeze(0)
368
+ if evecs != None:
369
+ evecs = evecs.unsqueeze(0)
370
+ if gradX != None:
371
+ gradX = gradX.unsqueeze(0)
372
+ if gradY != None:
373
+ gradY = gradY.unsqueeze(0)
374
+ if edges != None:
375
+ edges = edges.unsqueeze(0)
376
+ if faces != None:
377
+ faces = faces.unsqueeze(0)
378
+
379
+ elif len(x_in.shape) == 3:
380
+ appended_batch_dim = False
381
+
382
+ else:
383
+ raise ValueError("x_in should be tensor with shape [N,C] or [B,N,C]")
384
+
385
+ evals = evals[..., :self.num_eigenbasis]
386
+ evecs = evecs[..., :self.num_eigenbasis]
387
+
388
+ # Apply the first linear layer
389
+ x = self.first_lin(x_in)
390
+
391
+ # Apply each of the blocks
392
+ for b in self.blocks:
393
+ x = b(x, mass, L, evals, evecs, gradX, gradY)
394
+
395
+ # Apply the last linear layer
396
+ x = self.last_lin(x)
397
+
398
+ # Remap output to faces/edges if requested
399
+ if self.outputs_at == 'vertices':
400
+ x_out = x
401
+
402
+ elif self.outputs_at == 'edges':
403
+ # Remap to edges
404
+ x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 2)
405
+ edges_gather = edges.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
406
+ xe = torch.gather(x_gather, 1, edges_gather)
407
+ x_out = torch.mean(xe, dim=-1)
408
+
409
+ elif self.outputs_at == 'faces':
410
+ # Remap to faces
411
+ x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 3)
412
+ faces_gather = faces.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
413
+ xf = torch.gather(x_gather, 1, faces_gather)
414
+ x_out = torch.mean(xf, dim=-1)
415
+
416
+ elif self.outputs_at == 'global_mean':
417
+ # Produce a single global mean ouput.
418
+ # Using a weighted mean according to the point mass/area is discretization-invariant.
419
+ # (A naive mean is not discretization-invariant; it could be affected by sampling a region more densely)
420
+ x_out = torch.sum(x * mass.unsqueeze(-1), dim=-2) / torch.sum(mass, dim=-1, keepdim=True)
421
+
422
+ # Apply last nonlinearity if specified
423
+ if self.last_activation != None:
424
+ x_out = self.last_activation(x_out)
425
+
426
+ # Remove batch dim if we added it
427
+ if appended_batch_dim:
428
+ x_out = x_out.squeeze(0)
429
+
430
+ return x_out
utils/mesh.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import os
4
+ from tqdm import tqdm
5
+ import potpourri3d as pp3d
6
+ import open3d as o3d
7
+ import scipy.io as sio
8
+ import numpy as np
9
+ import re
10
+ import sys
11
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
12
+ from shape_data import get_data_dirs
13
+
14
+ # List of file extensions to consider as "mesh" files.
15
+ # Kudos to chatgpt!
16
+ # Add or remove extensions here as needed.
17
+ MESH_EXTENSIONS = {".ply", ".obj", ".off", ".stl", ".fbx", ".gltf", ".glb"}
18
+
19
+ def sorted_alphanum(file_list_ordered):
20
+ def convert(text):
21
+ return int(text) if text.isdigit() else text
22
+ def alphanum_key(key):
23
+ return [convert(c) for c in re.split('([0-9]+)', str(key)) if len(c) > 0]
24
+ return sorted(file_list_ordered, key=alphanum_key)
25
+
26
+
27
+ def list_files(folder_path, name_filter, alphanum_sort=False):
28
+ file_list = [p.name for p in list(Path(folder_path).glob(name_filter))]
29
+ if alphanum_sort:
30
+ return sorted_alphanum(file_list)
31
+ else:
32
+ return sorted(file_list)
33
+
34
+ def find_mesh_files(directory: Path, extensions: set=MESH_EXTENSIONS, alphanum_sort=False) -> list[Path]:
35
+ """
36
+ Recursively find all files in 'directory' whose suffix (lowercased) is in 'extensions'.
37
+ Returns a list of Path objects.
38
+ """
39
+ matches = []
40
+ for path in directory.rglob("*"):
41
+ if path.is_file() and path.suffix.lower() in extensions:
42
+ matches.append(path)
43
+ if alphanum_sort:
44
+ return sorted_alphanum(matches)
45
+ else:
46
+ return sorted(matches)
47
+
48
+ def save_ply(file_name, V, F, Rho=None, color=None):
49
+ """Save mesh information either as an ASCII ply file.
50
+ https://github.com/emmanuel-hartman/BaRe-ESA/blob/main/utils/input_output.py
51
+ Input:
52
+ - file_name: specified path for saving mesh [string]
53
+ - V: vertices of the triangulated surface [nVx3 numpy ndarray]
54
+ - F: faces of the triangulated surface [nFx3 numpy ndarray]
55
+ - Rho: weights defined on the vertices of the triangulated surface [nVx1 numpy ndarray, default=None]
56
+ - color: colormap [nVx3 numpy ndarray of RGB triples]
57
+
58
+ Output:
59
+ - file_name.mat or file_name.ply file containing mesh information
60
+ """
61
+
62
+ # Save as .ply file
63
+ nV = V.shape[0]
64
+ nF = F.shape[0]
65
+ if not ".ply" in file_name:
66
+ file_name += ".ply"
67
+ file = open(file_name, "w")
68
+ lines = ("ply","\n","format ascii 1.0","\n", "element vertex {}".format(nV),"\n","property float x","\n","property float y","\n","property float z","\n")
69
+
70
+ if color is not None:
71
+ lines += ("property uchar red","\n","property uchar green","\n","property uchar blue","\n")
72
+ if Rho is not None:
73
+ lines += ("property uchar alpha","\n")
74
+
75
+ lines += ("element face {}".format(nF),"\n","property list uchar int vertex_index","\n","end_header","\n")
76
+
77
+ file.writelines(lines)
78
+ lines = []
79
+ for i in range(0,nV):
80
+ for j in range(0,3):
81
+ lines.append(str(V[i][j]))
82
+ lines.append(" ")
83
+ if color is not None:
84
+ for j in range(0,3):
85
+ lines.append(str(color[i][j]))
86
+ lines.append(" ")
87
+ if Rho is not None:
88
+ lines.append(str(Rho[i]))
89
+ lines.append(" ")
90
+
91
+ lines.append("\n")
92
+ for i in range(0,nF):
93
+ l = len(F[i,:])
94
+ lines.append(str(l))
95
+ lines.append(" ")
96
+
97
+ for j in range(0,l):
98
+ lines.append(str(F[i,j]))
99
+ lines.append(" ")
100
+ lines.append("\n")
101
+
102
+ file.writelines(lines)
103
+ file.close()
104
+
105
+ def numpy_to_open3d_mesh(V, F):
106
+ # Create an empty TriangleMesh object
107
+ mesh = o3d.geometry.TriangleMesh()
108
+ # Set vertices
109
+ mesh.vertices = o3d.utility.Vector3dVector(V)
110
+ # Set triangles
111
+ mesh.triangles = o3d.utility.Vector3iVector(F)
112
+ return mesh
113
+
114
+
115
+
116
+ def load_mesh(filepath, scale=True, return_vnormals=False):
117
+ V, F = pp3d.read_mesh(filepath)
118
+ mesh = numpy_to_open3d_mesh(V, F)
119
+
120
+ tmat = np.identity(4, dtype=np.float32)
121
+ center = mesh.get_center()
122
+ tmat[:3, 3] = -center
123
+ area = mesh.get_surface_area()
124
+ if scale:
125
+ smat = np.identity(4, dtype=np.float32)
126
+
127
+ smat[:3, :3] = np.identity(3, dtype=np.float32) / np.sqrt(area)
128
+ tmat = smat @ tmat
129
+ mesh.transform(tmat)
130
+
131
+ vertices = np.asarray(mesh.vertices, dtype=np.float32)
132
+ faces = np.asarray(mesh.triangles, dtype=np.int32)
133
+ if return_vnormals:
134
+ mesh.compute_vertex_normals()
135
+ vnormals = np.asarray(mesh.vertex_normals, dtype=np.float32)
136
+ if scale:
137
+ return vertices, faces, vnormals, area, center
138
+ return vertices, faces, vnormals, area, center
139
+ else:
140
+ return vertices, faces
141
+
142
+
143
+ def mesh_geod_matrix(vertices, faces, do_tqdm=False, verbose=False):
144
+ if verbose:
145
+ print("Setting Geodesic matrix bw vertices")
146
+ n_vertices = vertices.shape[0]
147
+ distmat = np.zeros((n_vertices, n_vertices))
148
+ solver = pp3d.MeshHeatMethodDistanceSolver(vertices, faces)
149
+ if do_tqdm:
150
+ iterable = tqdm(range(n_vertices))
151
+ else:
152
+ iterable = range(n_vertices)
153
+ for vertind in iterable:
154
+ distmat[vertind] = np.maximum(solver.compute_distance(vertind), 0)
155
+ geod_mat = distmat
156
+ return geod_mat
157
+
158
+
159
+ def prepare_geod_mats(shapes_folder, out, basename=None):
160
+ if basename is None:
161
+ basename = os.path.basename(os.path.dirname(shapes_folder)) #+ "_" + os.path.basename(shapes_folder)
162
+ case = basename
163
+ case_folder_out = os.path.join(out, case)
164
+ os.makedirs(case_folder_out, exist_ok=True)
165
+ all_shapes = [f for f in os.listdir(shapes_folder) if (".ply" in f) or (".off" in f) or ('.obj' in f)]
166
+ for shape in tqdm(all_shapes, "Processing " + os.path.basename(shapes_folder)):
167
+ vertices, faces = pp3d.read_mesh(os.path.join(shapes_folder, shape))
168
+ areas = pp3d.face_areas(vertices, faces)
169
+ mat = mesh_geod_matrix(vertices, faces, verbose=False)
170
+ dict_save = {
171
+ 'geod_dist': mat,
172
+ 'areas_f': areas
173
+ }
174
+ sio.savemat(os.path.join(case_folder_out, shape[:-4]+'.mat'), dict_save)
175
+
176
+
177
+ if __name__ == "__main__":
178
+ parser = argparse.ArgumentParser(description="What to do.")
179
+ parser.add_argument('--make_geods', required=True, type=int, default=0, help='launch computation of geod matrices')
180
+ parser.add_argument('--data', required=False, type=str, default=None)
181
+ parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
182
+ parser.add_argument('--basename', required=False, type=str, default=None)
183
+ args = parser.parse_args()
184
+ if args.make_geods:
185
+ # from config import get_geod_path, get_dataset_path, get_template_path
186
+ # output = get_geod_path()
187
+ output = os.path.join(args.datadir, "geomats")
188
+ if args.data == "humans":
189
+ all_folders = [get_data_dirs(args.datadir, "faust", 'test')[0], get_data_dirs(args.datadir, "scape", 'test')[0], get_data_dirs(args.datadir, "shrec19", 'test')[0]]
190
+ for folder in all_folders:
191
+ prepare_geod_mats(folder, output)
192
+ elif args.data == "dt4d":
193
+ data_dir, _, corr_dir = get_data_dirs(args.datadir, args.data, 'test')
194
+ all_folders = sorted([f for f in os.listdir(data_dir) if "cross" not in f])
195
+ for folder in all_folders:
196
+ prepare_geod_mats(os.path.join(data_dir, folder), os.path.join(output, "DT4D"), basename=folder)
197
+ elif args.data is not None:
198
+ data_dir, _, corr_dir = get_data_dirs(args.datadir, args.data, 'test')
199
+ prepare_geod_mats(data_dir, output, args.basename)
200
+ # parser = argparse.ArgumentParser(description="Find all mesh files in a folder and list their paths.")
201
+ # parser.add_argument("--folder", type=Path, help="Path to the folder to search (will search recursively).")
202
+ # args = parser.parse_args()
203
+
204
+ # search_folder = args.folder
205
+ # if not search_folder.is_dir():
206
+ # print(f"Error: '{search_folder}' is not a valid directory.")
207
+ # else:
208
+ # # Find all matching files
209
+ # mesh_files = find_mesh_files(search_folder)
210
+
211
+ # # Sort the results for consistency
212
+ # mesh_files.sort()
213
+ # for p in mesh_files:
214
+ # print(p.resolve())
utils/meshplot.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import potpourri3d as pp3d
3
+ import torch
4
+ import meshplot as mp
5
+ from ipygany import PolyMesh, Scene
6
+
7
+
8
+
9
+ colors = np.array([[255, 119, 0], #orange
10
+ [163, 51, 219], #violet
11
+ [0, 246, 205], #bleu clair
12
+ [0, 131, 246], #bleu floncé
13
+ [246, 234, 0], #jaune
14
+ [143, 188, 143], #rouge
15
+ [255, 0, 0]])
16
+
17
+
18
+ def double_plot_surf(surf_1, surf_2,cmap1=None,cmap2=None):
19
+ d = mp.subplot(surf_1.vertices, surf_1.faces, c=cmap1, s=[2, 2, 0])
20
+ mp.subplot(surf_2.vertices, surf_2.faces, c=cmap2, s=[2, 2, 1], data=d)
21
+
22
+ def visu_pts(surf, colors=colors, idx_points=None, n_kpts=5):
23
+ areas = np.linalg.norm(surf.surfel, axis=-1, keepdims=True)
24
+ area = np.sqrt(areas.sum()/2)
25
+ if idx_points is None:
26
+ center = (surf.centers*(areas)).sum(axis=0)/areas.sum()
27
+ surf.updateVertices((surf.vertices - center)/area)
28
+ surf.cotanLaplacian()
29
+ idx_points = surf.get_keypoints(n_points=n_kpts)
30
+ solver = pp3d.MeshHeatMethodDistanceSolver(surf.vertices, surf.faces)
31
+ norm_center = solver.compute_distance(idx_points[-1])
32
+ color_array = np.zeros(surf.vertices.shape)
33
+ for i in range(len(idx_points)):
34
+ coeff = 0.1 + 1.5*norm_center[idx_points[i]]
35
+ i_v = idx_points[i]
36
+ dist = solver.compute_distance(i_v)*2
37
+ #color_array += np.clip(1-dist, 0, np.inf)[:, None]*colors[i][None, :]
38
+ color_array += np.exp(-dist**2/coeff)[:, None]*colors[i][None, :]
39
+ color_array = np.clip(color_array, 0, 255.)
40
+ return color_array
41
+
42
+ def toNP(tens):
43
+ if isinstance(tens, torch.Tensor):
44
+ return tens.detach().squeeze().cpu().numpy()
45
+ return tens
46
+
47
+
48
+ def overlay_surf(shape_vertices, shape_faces, target_vertices, target_faces, colors=["tomato", "darksalmon"]):
49
+ shape_vertices = toNP(shape_vertices)
50
+ target_vertices = toNP(target_vertices)
51
+ mesh_1 = PolyMesh(
52
+ vertices=shape_vertices,
53
+ triangle_indices=shape_faces
54
+ )
55
+ mesh_1.default_color = colors[0]
56
+
57
+
58
+ mesh_2 = PolyMesh(
59
+ vertices=target_vertices,
60
+ triangle_indices=target_faces
61
+ )
62
+ mesh_2.default_color = colors[1]
63
+
64
+
65
+ scene = Scene([mesh_1, mesh_2])
66
+ #scene = Scene([mesh_5])
67
+ return scene, [mesh_1, mesh_2]
utils/misc.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ import yaml
6
+ import omegaconf
7
+ from omegaconf import OmegaConf
8
+ from scipy.spatial import cKDTree
9
+ from pathlib import Path
10
+
11
+
12
+ class KNNSearch(object):
13
+ DTYPE = np.float32
14
+ NJOBS = 4
15
+
16
+ def __init__(self, data):
17
+ self.data = np.asarray(data, dtype=self.DTYPE)
18
+ self.kdtree = cKDTree(self.data)
19
+
20
+ def query(self, kpts, k, return_dists=False):
21
+ kpts = np.asarray(kpts, dtype=self.DTYPE)
22
+ nndists, nnindices = self.kdtree.query(kpts, k=k, workers=self.NJOBS)
23
+ if return_dists:
24
+ return nnindices, nndists
25
+ else:
26
+ return nnindices
27
+
28
+ def query_ball(self, kpt, radius):
29
+ kpt = np.asarray(kpt, dtype=self.DTYPE)
30
+ assert kpt.ndim == 1
31
+ nnindices = self.kdtree.query_ball_point(kpt, radius, n_jobs=self.NJOBS)
32
+ return nnindices
33
+
34
+
35
+ def validate_str(x):
36
+ return x is not None and x != ''
37
+
38
+
39
+ def hashing(arr, M):
40
+ assert isinstance(arr, np.ndarray) and arr.ndim == 2
41
+ N, D = arr.shape
42
+
43
+ hash_vec = np.zeros(N, dtype=np.int64)
44
+ for d in range(D):
45
+ hash_vec += arr[:, d] * M**d
46
+ return hash_vec
47
+
48
+
49
+ def omegaconf_to_dotdict(hparams):
50
+
51
+ def _to_dot_dict(cfg):
52
+ res = {}
53
+ for k, v in cfg.items():
54
+ if v is None:
55
+ res[k] = v
56
+ elif isinstance(v, omegaconf.DictConfig):
57
+ res.update({k + "." + subk: subv for subk, subv in _to_dot_dict(v).items()})
58
+ elif isinstance(v, (str, int, float, bool)):
59
+ res[k] = v
60
+ elif isinstance(v, omegaconf.ListConfig):
61
+ res[k] = omegaconf.OmegaConf.to_container(v, resolve=True)
62
+ else:
63
+ raise RuntimeError('The type of {} is not supported.'.format(type(v)))
64
+ return res
65
+
66
+ return _to_dot_dict(hparams)
67
+
68
+
69
+ def incrange(start, end, step):
70
+ assert step > 0
71
+ res = [start]
72
+ if start <= end:
73
+ while res[-1] + step <= end:
74
+ res.append(res[-1] + step)
75
+ else:
76
+ while res[-1] - step >= end:
77
+ res.append(res[-1] - step)
78
+ return res
79
+
80
+
81
+ def seeding(seed=0):
82
+ torch.manual_seed(seed)
83
+ torch.cuda.manual_seed_all(seed)
84
+ np.random.seed(seed)
85
+ random.seed(seed)
86
+
87
+ torch.backends.cudnn.enabled = True
88
+ torch.backends.cudnn.benchmark = True
89
+ torch.backends.cudnn.deterministic = True
90
+
91
+
92
+ def run_trainer(trainer_cls):
93
+ cfg_cli = OmegaConf.from_cli()
94
+
95
+ assert cfg_cli.run_mode is not None
96
+ if cfg_cli.run_mode == 'train':
97
+ assert cfg_cli.run_cfg is not None
98
+ cfg = OmegaConf.merge(
99
+ OmegaConf.load(cfg_cli.run_cfg),
100
+ cfg_cli,
101
+ )
102
+ OmegaConf.resolve(cfg)
103
+ cfg = omegaconf_to_dotdict(cfg)
104
+ seeding(cfg['seed'])
105
+ trainer = trainer_cls(cfg)
106
+ trainer.train()
107
+ trainer.test()
108
+ elif cfg_cli.run_mode == 'test':
109
+ assert cfg_cli.run_ckpt is not None
110
+ log_dir = str(Path(cfg_cli.run_ckpt).parent)
111
+ cfg = OmegaConf.merge(
112
+ OmegaConf.load(osp.join(log_dir, 'config.yml')),
113
+ cfg_cli,
114
+ )
115
+ OmegaConf.resolve(cfg)
116
+ cfg = omegaconf_to_dotdict(cfg)
117
+ cfg['test_ckpt'] = cfg_cli.run_ckpt
118
+ seeding(cfg['seed'])
119
+ trainer = trainer_cls(cfg)
120
+ trainer.test()
121
+ else:
122
+ raise RuntimeError(f'Mode {cfg_cli.run_mode} is not supported.')
utils/pickle_stuff.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import io
3
+ import importlib
4
+ import sys
5
+ import os
6
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+
8
+
9
+ # Mapping of old module names to new module names
10
+ MODULE_RENAME_MAP = {
11
+ 'module': 'diffu_models',
12
+ 'module.model': 'diffu_models.precond',
13
+ 'module.dit_models': 'diffu_models.dit_models',
14
+ 'module.model': 'diffu_models.precond',
15
+ # add more as needed
16
+ }
17
+
18
+ class RenameUnpickler(pickle.Unpickler):
19
+ def find_class(self, module, name):
20
+ if module in MODULE_RENAME_MAP:
21
+ module = MODULE_RENAME_MAP[module]
22
+ try:
23
+ return super().find_class(module, name)
24
+ except ModuleNotFoundError as e:
25
+ raise ModuleNotFoundError(f"Could not find module '{module}'. You may need to update MODULE_RENAME_MAP.") from e
26
+
27
+ # Usage
28
+ def load_renamed_pickle(file_path):
29
+ with open(file_path, 'rb') as f:
30
+ return RenameUnpickler(f).load()
31
+
32
+ def safe_load_with_fallback(file_path):
33
+ try:
34
+ with open(file_path, 'rb') as f:
35
+ return pickle.load(f)
36
+ except ModuleNotFoundError:
37
+ with open(file_path, 'rb') as f:
38
+ return RenameUnpickler(f).load()
utils/surfaces.py ADDED
@@ -0,0 +1,1377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy as sp
3
+ import scipy.linalg as spLA
4
+ import scipy.sparse.linalg as sla
5
+ import os
6
+ import logging
7
+ import copy
8
+ import potpourri3d as pp3d
9
+ from sklearn.cluster import KMeans#, DBSCAN, SpectralClustering
10
+ from scipy.spatial import cKDTree
11
+ import trimesh as tm
12
+
13
+
14
+ try:
15
+ from vtk import *
16
+ import vtk.util.numpy_support as v2n
17
+
18
+ gotVTK = True
19
+ except ImportError:
20
+ print('could not import VTK functions')
21
+ gotVTK = False
22
+
23
+
24
+ # import kernelFunctions as kfun
25
+
26
+ class vtkFields:
27
+ def __init__(self):
28
+ self.scalars = []
29
+ self.vectors = []
30
+ self.normals = []
31
+ self.tensors = []
32
+
33
+
34
+ # General surface class, fibers possible. The fiber is a vector field of norm 1, defined on each vertex.
35
+ # It must be an array of the same size (number of vertices)x3.
36
+
37
+ class Surface:
38
+
39
+ # Fibers: a list of vectors, with i-th element corresponding to the value of the vector field at vertex i
40
+ # Contains as object:
41
+ # vertices : all the vertices
42
+ # centers: the centers of each face
43
+ # faces: faces along with the id of the faces
44
+ # surfel: surface element of each face (area*normal)
45
+ # List of methods:
46
+ # read : from filename type, call readFILENAME and set all surface attributes
47
+ # updateVertices: update the whole surface after a modification of the vertices
48
+ # computeVertexArea and Normals
49
+ # getEdges
50
+ # LocalSignedDistance distance function in the neighborhood of a shape
51
+ # toPolyData: convert the surface to a polydata vtk object
52
+ # fromPolyDate: guess what
53
+ # Simplify : simplify the meshes
54
+ # flipfaces: invert the indices of faces from [a, b, c] to [a, c, b]
55
+ # smooth: get a smoother surface
56
+ # Isosurface: compute isosurface
57
+ # edgeRecove: ensure that orientation is correct
58
+ # remove isolated: if isolated vertice, remove it
59
+ # laplacianMatrix: compute Laplacian Matrix of the surface graph
60
+ # graphLaplacianMatrix: ???
61
+ # laplacianSegmentation: segment the surface using laplacian properties
62
+ # surfVolume: compute volume inscribed in the surface (+ inscribed infinitesimal volume for each face)
63
+ # surfCenter: compute surface Center
64
+ # surfMoments: compute surface second order moments
65
+ # surfU: compute the informations for surface rigid alignement
66
+ # surfEllipsoid: compute ellipsoid representing the surface
67
+ # savebyu of vitk or vtk2: save the surface in a file
68
+ # concatenate: concatenate to another surface
69
+
70
+ def __init__(self, surf=None, filename=None, FV=None):
71
+ if surf == None:
72
+ if FV == None:
73
+ if filename == None:
74
+ self.vertices = np.empty(0)
75
+ self.centers = np.empty(0)
76
+ self.faces = np.empty(0)
77
+ self.surfel = np.empty(0)
78
+ else:
79
+ if type(filename) is list:
80
+ fvl = []
81
+ for name in filename:
82
+ fvl.append(Surface(filename=name))
83
+ self.concatenate(fvl)
84
+ else:
85
+ self.read(filename)
86
+ else:
87
+ self.vertices = np.copy(FV[1])
88
+ self.faces = np.int_(FV[0])
89
+ self.computeCentersAreas()
90
+
91
+ else:
92
+ self.vertices = np.copy(surf.vertices)
93
+ self.faces = np.copy(surf.faces)
94
+ self.surfel = np.copy(surf.surfel)
95
+ self.centers = np.copy(surf.centers)
96
+ self.computeCentersAreas()
97
+ self.volume, self.vols = self.surfVolume()
98
+ self.center = self.surfCenter()
99
+ self.cotanLaplacian()
100
+
101
+ def read(self, filename):
102
+ (mainPart, ext) = os.path.splitext(filename)
103
+ if ext == '.byu':
104
+ self.readbyu(filename)
105
+ elif ext == '.off':
106
+ self.readOFF(filename)
107
+ elif ext == '.vtk':
108
+ self.readVTK(filename)
109
+ elif ext == '.obj':
110
+ self.readOBJ(filename)
111
+ elif ext == '.ply':
112
+ self.readPLY(filename)
113
+ elif ext == '.tri' or ext == ".ntri":
114
+ self.readTRI(filename)
115
+ else:
116
+ print('Unknown Surface Extension:', ext)
117
+ self.vertices = np.empty(0)
118
+ self.centers = np.empty(0)
119
+ self.faces = np.empty(0)
120
+ self.surfel = np.empty(0)
121
+
122
+ # face centers and area weighted normal
123
+ def computeCentersAreas(self):
124
+ xDef1 = self.vertices[self.faces[:, 0], :]
125
+ xDef2 = self.vertices[self.faces[:, 1], :]
126
+ xDef3 = self.vertices[self.faces[:, 2], :]
127
+ self.centers = (xDef1 + xDef2 + xDef3) / 3
128
+ self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
129
+
130
+ # modify vertices without toplogical change
131
+ def updateVertices(self, x0):
132
+ self.vertices = np.copy(x0)
133
+ xDef1 = self.vertices[self.faces[:, 0], :]
134
+ xDef2 = self.vertices[self.faces[:, 1], :]
135
+ xDef3 = self.vertices[self.faces[:, 2], :]
136
+ self.centers = (xDef1 + xDef2 + xDef3) / 3
137
+ self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
138
+ self.volume, self.vols = self.surfVolume()
139
+
140
+ def computeVertexArea(self):
141
+ # compute areas of faces and vertices
142
+ V = self.vertices
143
+ F = self.faces
144
+ nv = V.shape[0]
145
+ nf = F.shape[0]
146
+ AF = np.zeros([nf, 1])
147
+ AV = np.zeros([nv, 1])
148
+ for k in range(nf):
149
+ # determining if face is obtuse
150
+ x12 = V[F[k, 1], :] - V[F[k, 0], :]
151
+ x13 = V[F[k, 2], :] - V[F[k, 0], :]
152
+ n12 = np.sqrt((x12 ** 2).sum())
153
+ n13 = np.sqrt((x13 ** 2).sum())
154
+ c1 = (x12 * x13).sum() / (n12 * n13)
155
+ x23 = V[F[k, 2], :] - V[F[k, 1], :]
156
+ n23 = np.sqrt((x23 ** 2).sum())
157
+ # n23 = norm(x23) ;
158
+ c2 = -(x12 * x23).sum() / (n12 * n23)
159
+ c3 = (x13 * x23).sum() / (n13 * n23)
160
+ AF[k] = np.sqrt((np.cross(x12, x13) ** 2).sum()) / 2
161
+ if (c1 < 0):
162
+ # face obtuse at vertex 1
163
+ AV[F[k, 0]] += AF[k] / 2
164
+ AV[F[k, 1]] += AF[k] / 4
165
+ AV[F[k, 2]] += AF[k] / 4
166
+ elif (c2 < 0):
167
+ # face obuse at vertex 2
168
+ AV[F[k, 0]] += AF[k] / 4
169
+ AV[F[k, 1]] += AF[k] / 2
170
+ AV[F[k, 2]] += AF[k] / 4
171
+ elif (c3 < 0):
172
+ # face obtuse at vertex 3
173
+ AV[F[k, 0]] += AF[k] / 4
174
+ AV[F[k, 1]] += AF[k] / 4
175
+ AV[F[k, 2]] += AF[k] / 2
176
+ else:
177
+ # non obtuse face
178
+ cot1 = c1 / np.sqrt(1 - c1 ** 2)
179
+ cot2 = c2 / np.sqrt(1 - c2 ** 2)
180
+ cot3 = c3 / np.sqrt(1 - c3 ** 2)
181
+ AV[F[k, 0]] += ((x12 ** 2).sum() * cot3 + (x13 ** 2).sum() * cot2) / 8
182
+ AV[F[k, 1]] += ((x12 ** 2).sum() * cot3 + (x23 ** 2).sum() * cot1) / 8
183
+ AV[F[k, 2]] += ((x13 ** 2).sum() * cot2 + (x23 ** 2).sum() * cot1) / 8
184
+
185
+ for k in range(nv):
186
+ if (np.fabs(AV[k]) < 1e-10):
187
+ print('Warning: vertex ', k, 'has no face; use removeIsolated')
188
+ # print('sum check area:', AF.sum(), AV.sum()
189
+ return AV, AF
190
+
191
+ def computeVertexNormals(self):
192
+ self.computeCentersAreas()
193
+ normals = np.zeros(self.vertices.shape)
194
+ F = self.faces
195
+ for k in range(F.shape[0]):
196
+ normals[F[k, 0]] += self.surfel[k]
197
+ normals[F[k, 1]] += self.surfel[k]
198
+ normals[F[k, 2]] += self.surfel[k]
199
+ af = np.sqrt((normals ** 2).sum(axis=1))
200
+ # logging.info('min area = %.4f'%(af.min()))
201
+ normals /= af.reshape([self.vertices.shape[0], 1])
202
+
203
+ return normals
204
+
205
+ # Computes edges from vertices/faces
206
+ def getEdges(self):
207
+ self.edges = []
208
+
209
+ for k in range(self.faces.shape[0]):
210
+ for kj in (0, 1, 2):
211
+ u = [self.faces[k, kj], self.faces[k, (kj + 1) % 3]]
212
+
213
+ if (u not in self.edges) & (u.reverse() not in self.edges):
214
+ self.edges.append(u)
215
+
216
+ self.edgeFaces = []
217
+
218
+ for u in self.edges:
219
+ self.edgeFaces.append([])
220
+
221
+ for k in range(self.faces.shape[0]):
222
+ for kj in (0, 1, 2):
223
+ u = [self.faces[k, kj], self.faces[k, (kj + 1) % 3]]
224
+
225
+ if u in self.edges:
226
+ kk = self.edges.index(u)
227
+ else:
228
+ u.reverse()
229
+ kk = self.edges.index(u)
230
+
231
+ self.edgeFaces[kk].append(k)
232
+
233
+ self.edges = np.int_(np.array(self.edges))
234
+ self.bdry = np.int_(np.zeros(self.edges.shape[0]))
235
+
236
+ for k in range(self.edges.shape[0]):
237
+ if len(self.edgeFaces[k]) < 2:
238
+ self.bdry[k] = 1
239
+
240
+ # computes the signed distance function in a small neighborhood of a shape
241
+ def LocalSignedDistance(self, data, value):
242
+ d2 = 2 * np.array(data >= value) - 1
243
+ c2 = np.cumsum(d2, axis=0)
244
+ for j in range(2):
245
+ c2 = np.cumsum(c2, axis=j + 1)
246
+ (n0, n1, n2) = c2.shape
247
+
248
+ rad = 3
249
+ diam = 2 * rad + 1
250
+ (x, y, z) = np.mgrid[-rad:rad + 1, -rad:rad + 1, -rad:rad + 1]
251
+ cube = (x ** 2 + y ** 2 + z ** 2)
252
+ maxval = (diam) ** 3
253
+ s = 3.0 * rad ** 2
254
+ res = d2 * s
255
+ u = maxval * np.ones(c2.shape)
256
+ u[rad + 1:n0 - rad, rad + 1:n1 - rad, rad + 1:n2 - rad] = (c2[diam:n0, diam:n1, diam:n2]
257
+ - c2[0:n0 - diam, diam:n1, diam:n2] - c2[diam:n0,
258
+ 0:n1 - diam,
259
+ diam:n2] - c2[
260
+ diam:n0,
261
+ diam:n1,
262
+ 0:n2 - diam]
263
+ + c2[0:n0 - diam, 0:n1 - diam, diam:n2] + c2[diam:n0,
264
+ 0:n1 - diam,
265
+ 0:n2 - diam] + c2[
266
+ 0:n0 - diam,
267
+ diam:n1,
268
+ 0:n2 - diam]
269
+ - c2[0:n0 - diam, 0:n1 - diam, 0:n2 - diam])
270
+
271
+ I = np.nonzero(np.fabs(u) < maxval)
272
+ # print(len(I[0]))
273
+
274
+ for k in range(len(I[0])):
275
+ p = np.array((I[0][k], I[1][k], I[2][k]))
276
+ bmin = p - rad
277
+ bmax = p + rad + 1
278
+ # print(p, bmin, bmax)
279
+ if (d2[p[0], p[1], p[2]] > 0):
280
+ # print(u[p[0],p[1], p[2]])
281
+ # print(d2[bmin[0]:bmax[0], bmin[1]:bmax[1], bmin[2]:bmax[2]].sum())
282
+ res[p[0], p[1], p[2]] = min(
283
+ cube[np.nonzero(d2[bmin[0]:bmax[0], bmin[1]:bmax[1], bmin[2]:bmax[2]] < 0)]) - .25
284
+ else:
285
+ res[p[0], p[1], p[2]] = - min(
286
+ cube[np.nonzero(d2[bmin[0]:bmax[0], bmin[1]:bmax[1], bmin[2]:bmax[2]] > 0)]) - .25
287
+
288
+ return res
289
+
290
+ def toPolyData(self):
291
+ if gotVTK:
292
+ points = vtkPoints()
293
+ for k in range(self.vertices.shape[0]):
294
+ points.InsertNextPoint(self.vertices[k, 0], self.vertices[k, 1], self.vertices[k, 2])
295
+ polys = vtkCellArray()
296
+ for k in range(self.faces.shape[0]):
297
+ polys.InsertNextCell(3)
298
+ for kk in range(3):
299
+ polys.InsertCellPoint(self.faces[k, kk])
300
+ polydata = vtkPolyData()
301
+ polydata.SetPoints(points)
302
+ polydata.SetPolys(polys)
303
+ return polydata
304
+ else:
305
+ raise Exception('Cannot run toPolyData without VTK')
306
+
307
+ def fromPolyData(self, g, scales=[1., 1., 1.]):
308
+ npoints = int(g.GetNumberOfPoints())
309
+ nfaces = int(g.GetNumberOfPolys())
310
+ logging.info('Dimensions: %d %d %d' % (npoints, nfaces, g.GetNumberOfCells()))
311
+ V = np.zeros([npoints, 3])
312
+ for kk in range(npoints):
313
+ V[kk, :] = np.array(g.GetPoint(kk))
314
+ # print(kk, V[kk])
315
+ # print(kk, np.array(g.GetPoint(kk)))
316
+ F = np.zeros([nfaces, 3])
317
+ gf = 0
318
+ for kk in range(g.GetNumberOfCells()):
319
+ c = g.GetCell(kk)
320
+ if (c.GetNumberOfPoints() == 3):
321
+ for ll in range(3):
322
+ F[gf, ll] = c.GetPointId(ll)
323
+ # print(kk, gf, F[gf])
324
+ gf += 1
325
+
326
+ # self.vertices = np.multiply(data.shape-V-1, scales)
327
+ self.vertices = np.multiply(V, scales)
328
+ self.faces = np.int_(F[0:gf, :])
329
+ self.computeCentersAreas()
330
+
331
+ def Simplify(self, target=1000.0):
332
+ if gotVTK:
333
+ polydata = self.toPolyData()
334
+ dc = vtkQuadricDecimation()
335
+ red = 1 - min(np.float(target) / polydata.GetNumberOfPoints(), 1)
336
+ dc.SetTargetReduction(red)
337
+ dc.SetInput(polydata)
338
+ dc.Update()
339
+ g = dc.GetOutput()
340
+ self.fromPolyData(g)
341
+ z = self.surfVolume()
342
+ if (z > 0):
343
+ self.flipFaces()
344
+ print('flipping volume', z, self.surfVolume())
345
+ else:
346
+ raise Exception('Cannot run Simplify without VTK')
347
+
348
+ def flipFaces(self):
349
+ self.faces = self.faces[:, [0, 2, 1]]
350
+ self.computeCentersAreas()
351
+
352
+ def smooth(self, n=30, smooth=0.1):
353
+ if gotVTK:
354
+ g = self.toPolyData()
355
+ print(g)
356
+ smoother = vtkWindowedSincPolyDataFilter()
357
+ smoother.SetInput(g)
358
+ smoother.SetNumberOfIterations(n)
359
+ smoother.SetPassBand(smooth)
360
+ smoother.NonManifoldSmoothingOn()
361
+ smoother.NormalizeCoordinatesOn()
362
+ smoother.GenerateErrorScalarsOn()
363
+ # smoother.GenerateErrorVectorsOn()
364
+ smoother.Update()
365
+ g = smoother.GetOutput()
366
+ self.fromPolyData(g)
367
+ else:
368
+ raise Exception('Cannot run smooth without VTK')
369
+
370
+ # Computes isosurfaces using vtk
371
+ def Isosurface(self, data, value=0.5, target=1000.0, scales=[1., 1., 1.], smooth=0.1, fill_holes=1.):
372
+ if gotVTK:
373
+ # data = self.LocalSignedDistance(data0, value)
374
+ if isinstance(data, vtkImageData):
375
+ img = data
376
+ else:
377
+ img = vtkImageData()
378
+ img.SetDimensions(data.shape)
379
+ img.SetOrigin(0, 0, 0)
380
+ if vtkVersion.GetVTKMajorVersion() >= 6:
381
+ img.AllocateScalars(VTK_FLOAT, 1)
382
+ else:
383
+ img.SetNumberOfScalarComponents(1)
384
+ v = vtkDoubleArray()
385
+ v.SetNumberOfValues(data.size)
386
+ v.SetNumberOfComponents(1)
387
+ for ii, tmp in enumerate(np.ravel(data, order='F')):
388
+ v.SetValue(ii, tmp)
389
+ img.GetPointData().SetScalars(v)
390
+
391
+ cf = vtkContourFilter()
392
+ if vtkVersion.GetVTKMajorVersion() >= 6:
393
+ cf.SetInputData(img)
394
+ else:
395
+ cf.SetInput(img)
396
+ cf.SetValue(0, value)
397
+ cf.SetNumberOfContours(1)
398
+ cf.Update()
399
+ # print(cf
400
+ connectivity = vtkPolyDataConnectivityFilter()
401
+ connectivity.ScalarConnectivityOff()
402
+ connectivity.SetExtractionModeToLargestRegion()
403
+ if vtkVersion.GetVTKMajorVersion() >= 6:
404
+ connectivity.SetInputData(cf.GetOutput())
405
+ else:
406
+ connectivity.SetInput(cf.GetOutput())
407
+ connectivity.Update()
408
+ g = connectivity.GetOutput()
409
+
410
+ if smooth > 0:
411
+ smoother = vtkWindowedSincPolyDataFilter()
412
+ if vtkVersion.GetVTKMajorVersion() >= 6:
413
+ smoother.SetInputData(g)
414
+ else:
415
+ smoother.SetInput(g)
416
+ # else:
417
+ # smoother.SetInputConnection(contour.GetOutputPort())
418
+ smoother.SetNumberOfIterations(30)
419
+ # this has little effect on the error!
420
+ # smoother.BoundarySmoothingOff()
421
+ # smoother.FeatureEdgeSmoothingOff()
422
+ # smoother.SetFeatureAngle(120.0)
423
+ smoother.SetPassBand(smooth) # this increases the error a lot!
424
+ smoother.NonManifoldSmoothingOn()
425
+ # smoother.NormalizeCoordinatesOn()
426
+ # smoother.GenerateErrorScalarsOn()
427
+ # smoother.GenerateErrorVectorsOn()
428
+ smoother.Update()
429
+ g = smoother.GetOutput()
430
+
431
+ # dc = vtkDecimatePro()
432
+ red = 1 - min(np.float(target) / g.GetNumberOfPoints(), 1)
433
+ # print('Reduction: ', red)
434
+ dc = vtkQuadricDecimation()
435
+ dc.SetTargetReduction(red)
436
+ # dc.AttributeErrorMetricOn()
437
+ # dc.SetDegree(10)
438
+ # dc.SetSplitting(0)
439
+ if vtkVersion.GetVTKMajorVersion() >= 6:
440
+ dc.SetInputData(g)
441
+ else:
442
+ dc.SetInput(g)
443
+ # dc.SetInput(g)
444
+ # print(dc)
445
+ dc.Update()
446
+ g = dc.GetOutput()
447
+ # print('points:', g.GetNumberOfPoints())
448
+ cp = vtkCleanPolyData()
449
+ if vtkVersion.GetVTKMajorVersion() >= 6:
450
+ cp.SetInputData(dc.GetOutput())
451
+ else:
452
+ cp.SetInput(dc.GetOutput())
453
+ # cp.SetInput(dc.GetOutput())
454
+ # cp.SetPointMerging(1)
455
+ cp.ConvertPolysToLinesOn()
456
+ cp.SetAbsoluteTolerance(1e-5)
457
+ cp.Update()
458
+ g = cp.GetOutput()
459
+ self.fromPolyData(g, scales)
460
+ z = self.surfVolume()
461
+ if (z > 0):
462
+ self.flipFaces()
463
+ # print('flipping volume', z, self.surfVolume())
464
+ logging.info('flipping volume %.2f %.2f' % (z, self.surfVolume()))
465
+
466
+ # print(g)
467
+ # npoints = int(g.GetNumberOfPoints())
468
+ # nfaces = int(g.GetNumberOfPolys())
469
+ # print('Dimensions:', npoints, nfaces, g.GetNumberOfCells())
470
+ # V = np.zeros([npoints, 3])
471
+ # for kk in range(npoints):
472
+ # V[kk, :] = np.array(g.GetPoint(kk))
473
+ # #print(kk, V[kk])
474
+ # #print(kk, np.array(g.GetPoint(kk)))
475
+ # F = np.zeros([nfaces, 3])
476
+ # gf = 0
477
+ # for kk in range(g.GetNumberOfCells()):
478
+ # c = g.GetCell(kk)
479
+ # if(c.GetNumberOfPoints() == 3):
480
+ # for ll in range(3):
481
+ # F[gf,ll] = c.GetPointId(ll)
482
+ # #print(kk, gf, F[gf])
483
+ # gf += 1
484
+
485
+ # #self.vertices = np.multiply(data.shape-V-1, scales)
486
+ # self.vertices = np.multiply(V, scales)
487
+ # self.faces = np.int_(F[0:gf, :])
488
+ # self.computeCentersAreas()
489
+ else:
490
+ raise Exception('Cannot run Isosurface without VTK')
491
+
492
+ # Ensures that orientation is correct
493
+ def edgeRecover(self):
494
+ v = self.vertices
495
+ f = self.faces
496
+ nv = v.shape[0]
497
+ nf = f.shape[0]
498
+ # faces containing each oriented edge
499
+ edg0 = np.int_(np.zeros((nv, nv)))
500
+ # number of edges between each vertex
501
+ edg = np.int_(np.zeros((nv, nv)))
502
+ # contiguous faces
503
+ edgF = np.int_(np.zeros((nf, nf)))
504
+ for (kf, c) in enumerate(f):
505
+ if (edg0[c[0], c[1]] > 0):
506
+ edg0[c[1], c[0]] = kf + 1
507
+ else:
508
+ edg0[c[0], c[1]] = kf + 1
509
+
510
+ if (edg0[c[1], c[2]] > 0):
511
+ edg0[c[2], c[1]] = kf + 1
512
+ else:
513
+ edg0[c[1], c[2]] = kf + 1
514
+
515
+ if (edg0[c[2], c[0]] > 0):
516
+ edg0[c[0], c[2]] = kf + 1
517
+ else:
518
+ edg0[c[2], c[0]] = kf + 1
519
+
520
+ edg[c[0], c[1]] += 1
521
+ edg[c[1], c[2]] += 1
522
+ edg[c[2], c[0]] += 1
523
+
524
+ for kv in range(nv):
525
+ I2 = np.nonzero(edg0[kv, :])
526
+ for kkv in I2[0].tolist():
527
+ edgF[edg0[kkv, kv] - 1, edg0[kv, kkv] - 1] = kv + 1
528
+
529
+ isOriented = np.int_(np.zeros(f.shape[0]))
530
+ isActive = np.int_(np.zeros(f.shape[0]))
531
+ I = np.nonzero(np.squeeze(edgF[0, :]))
532
+ # list of faces to be oriented
533
+ activeList = [0] + I[0].tolist()
534
+ lastOriented = 0
535
+ isOriented[0] = True
536
+ for k in activeList:
537
+ isActive[k] = True
538
+
539
+ while lastOriented < len(activeList) - 1:
540
+ i = activeList[lastOriented]
541
+ j = activeList[lastOriented + 1]
542
+ I = np.nonzero(edgF[j, :])
543
+ foundOne = False
544
+ for kk in I[0].tolist():
545
+ if (foundOne == False) & (isOriented[kk]):
546
+ foundOne = True
547
+ u1 = edgF[j, kk] - 1
548
+ u2 = edgF[kk, j] - 1
549
+ if not ((edg[u1, u2] == 1) & (edg[u2, u1] == 1)):
550
+ # reorient face j
551
+ edg[f[j, 0], f[j, 1]] -= 1
552
+ edg[f[j, 1], f[j, 2]] -= 1
553
+ edg[f[j, 2], f[j, 0]] -= 1
554
+ a = f[j, 1]
555
+ f[j, 1] = f[j, 2]
556
+ f[j, 2] = a
557
+ edg[f[j, 0], f[j, 1]] += 1
558
+ edg[f[j, 1], f[j, 2]] += 1
559
+ edg[f[j, 2], f[j, 0]] += 1
560
+ elif (not isActive[kk]):
561
+ activeList.append(kk)
562
+ isActive[kk] = True
563
+ if foundOne:
564
+ lastOriented = lastOriented + 1
565
+ isOriented[j] = True
566
+ # print('oriented face', j, lastOriented, 'out of', nf, '; total active', len(activeList))
567
+ else:
568
+ print('Unable to orient face', j)
569
+ return
570
+ self.vertices = v;
571
+ self.faces = f;
572
+
573
+ z, _ = self.surfVolume()
574
+ if (z > 0):
575
+ self.flipFaces()
576
+
577
+ def removeIsolated(self):
578
+ N = self.vertices.shape[0]
579
+ inFace = np.int_(np.zeros(N))
580
+ for k in range(3):
581
+ inFace[self.faces[:, k]] = 1
582
+ J = np.nonzero(inFace)
583
+ self.vertices = self.vertices[J[0], :]
584
+ logging.info('Found %d isolated vertices' % (J[0].shape[0]))
585
+ Q = -np.ones(N)
586
+ for k, j in enumerate(J[0]):
587
+ Q[j] = k
588
+ self.faces = np.int_(Q[self.faces])
589
+
590
+ def laplacianMatrix(self):
591
+ F = self.faces
592
+ V = self.vertices;
593
+ nf = F.shape[0]
594
+ nv = V.shape[0]
595
+
596
+ AV, AF = self.computeVertexArea()
597
+
598
+ # compute edges and detect boundary
599
+ # edm = sp.lil_matrix((nv,nv))
600
+ edm = -np.ones([nv, nv]).astype(np.int32)
601
+ E = np.zeros([3 * nf, 2]).astype(np.int32)
602
+ j = 0
603
+ for k in range(nf):
604
+ if (edm[F[k, 0], F[k, 1]] == -1):
605
+ edm[F[k, 0], F[k, 1]] = j
606
+ edm[F[k, 1], F[k, 0]] = j
607
+ E[j, :] = [F[k, 0], F[k, 1]]
608
+ j = j + 1
609
+ if (edm[F[k, 1], F[k, 2]] == -1):
610
+ edm[F[k, 1], F[k, 2]] = j
611
+ edm[F[k, 2], F[k, 1]] = j
612
+ E[j, :] = [F[k, 1], F[k, 2]]
613
+ j = j + 1
614
+ if (edm[F[k, 0], F[k, 2]] == -1):
615
+ edm[F[k, 2], F[k, 0]] = j
616
+ edm[F[k, 0], F[k, 2]] = j
617
+ E[j, :] = [F[k, 2], F[k, 0]]
618
+ j = j + 1
619
+ E = E[0:j, :]
620
+
621
+ edgeFace = np.zeros([j, nf])
622
+ ne = j
623
+ # print(E)
624
+ for k in range(nf):
625
+ edgeFace[edm[F[k, 0], F[k, 1]], k] = 1
626
+ edgeFace[edm[F[k, 1], F[k, 2]], k] = 1
627
+ edgeFace[edm[F[k, 2], F[k, 0]], k] = 1
628
+
629
+ bEdge = np.zeros([ne, 1])
630
+ bVert = np.zeros([nv, 1])
631
+ edgeAngles = np.zeros([ne, 2])
632
+ for k in range(ne):
633
+ I = np.flatnonzero(edgeFace[k, :])
634
+ # print('I=', I, F[I, :], E.shape)
635
+ # print('E[k, :]=', k, E[k, :])
636
+ # print(k, edgeFace[k, :])
637
+ for u in range(len(I)):
638
+ f = I[u]
639
+ i1l = np.flatnonzero(F[f, :] == E[k, 0])
640
+ i2l = np.flatnonzero(F[f, :] == E[k, 1])
641
+ # print(f, F[f, :])
642
+ # print(i1l, i2l)
643
+ i1 = i1l[0]
644
+ i2 = i2l[0]
645
+ s = i1 + i2
646
+ if s == 1:
647
+ i3 = 2
648
+ elif s == 2:
649
+ i3 = 1
650
+ elif s == 3:
651
+ i3 = 0
652
+ x1 = V[F[f, i1], :] - V[F[f, i3], :]
653
+ x2 = V[F[f, i2], :] - V[F[f, i3], :]
654
+ a = (np.cross(x1, x2) * np.cross(V[F[f, 1], :] - V[F[f, 0], :], V[F[f, 2], :] - V[F[f, 0], :])).sum()
655
+ b = (x1 * x2).sum()
656
+ if (a > 0):
657
+ edgeAngles[k, u] = b / np.sqrt(a)
658
+ else:
659
+ edgeAngles[k, u] = b / np.sqrt(-a)
660
+ if (len(I) == 1):
661
+ # boundary edge
662
+ bEdge[k] = 1
663
+ bVert[E[k, 0]] = 1
664
+ bVert[E[k, 1]] = 1
665
+ edgeAngles[k, 1] = 0
666
+
667
+ # Compute Laplacian matrix
668
+ L = np.zeros([nv, nv])
669
+
670
+ for k in range(ne):
671
+ L[E[k, 0], E[k, 1]] = (edgeAngles[k, 0] + edgeAngles[k, 1]) / 2
672
+ L[E[k, 1], E[k, 0]] = L[E[k, 0], E[k, 1]]
673
+
674
+ for k in range(nv):
675
+ L[k, k] = - L[k, :].sum()
676
+
677
+ A = np.zeros([nv, nv])
678
+ for k in range(nv):
679
+ A[k, k] = AV[k]
680
+
681
+ return L, A
682
+
683
+ def graphLaplacianMatrix(self):
684
+ F = self.faces
685
+ V = self.vertices
686
+ nf = F.shape[0]
687
+ nv = V.shape[0]
688
+
689
+ # compute edges and detect boundary
690
+ # edm = sp.lil_matrix((nv,nv))
691
+ edm = -np.ones([nv, nv])
692
+ E = np.zeros([3 * nf, 2])
693
+ j = 0
694
+ for k in range(nf):
695
+ if (edm[F[k, 0], F[k, 1]] == -1):
696
+ edm[F[k, 0], F[k, 1]] = j
697
+ edm[F[k, 1], F[k, 0]] = j
698
+ E[j, :] = [F[k, 0], F[k, 1]]
699
+ j = j + 1
700
+ if (edm[F[k, 1], F[k, 2]] == -1):
701
+ edm[F[k, 1], F[k, 2]] = j
702
+ edm[F[k, 2], F[k, 1]] = j
703
+ E[j, :] = [F[k, 1], F[k, 2]]
704
+ j = j + 1
705
+ if (edm[F[k, 0], F[k, 2]] == -1):
706
+ edm[F[k, 2], F[k, 0]] = j
707
+ edm[F[k, 0], F[k, 2]] = j
708
+ E[j, :] = [F[k, 2], F[k, 0]]
709
+ j = j + 1
710
+ E = E[0:j, :]
711
+
712
+ edgeFace = np.zeros([j, nf])
713
+ ne = j
714
+ # print(E)
715
+
716
+ # Compute Laplacian matrix
717
+ L = np.zeros([nv, nv])
718
+
719
+ for k in range(ne):
720
+ L[E[k, 0], E[k, 1]] = 1
721
+ L[E[k, 1], E[k, 0]] = 1
722
+
723
+ for k in range(nv):
724
+ L[k, k] = - L[k, :].sum()
725
+
726
+ return L
727
+
728
+ def cotanLaplacian(self, eps=1e-8):
729
+ L = pp3d.cotan_laplacian(self.vertices, self.faces, denom_eps=1e-10)
730
+ massvec_np = pp3d.vertex_areas(self.vertices, self.faces)
731
+ massvec_np += eps * np.mean(massvec_np)
732
+
733
+ if (np.isnan(L.data).any()):
734
+ raise RuntimeError("NaN Laplace matrix")
735
+ if (np.isnan(massvec_np).any()):
736
+ raise RuntimeError("NaN mass matrix")
737
+ self.L = L
738
+ self.massvec_np = massvec_np
739
+ return L, massvec_np
740
+
741
+ def lapEigen(self, n_eig=10, eps=1e-8):
742
+ L_eigsh = (self.L + sp.sparse.identity(self.L.shape[0]) * eps).tocsc()
743
+ massvec_eigsh = self.massvec_np
744
+ Mmat = sp.sparse.diags(massvec_eigsh)
745
+ eigs_sigma = eps#-0.01
746
+ evals, evecs = sla.eigsh(L_eigsh, k=n_eig, M=Mmat, sigma=eigs_sigma)
747
+ return evals, evecs
748
+
749
+ def laplacianSegmentation(self, k, verbose=False):
750
+ # (L, AA) = self.laplacianMatrix()
751
+ # # print((L.shape[0]-k-1, L.shape[0]-2))
752
+ # (D, y) = spLA.eigh(L, AA, eigvals=(L.shape[0] - k, L.shape[0] - 1))
753
+ evals, evecs = self.lapEigen(k*2)
754
+ y = evecs[:, 1:] * np.exp(-2*0.1 * evals[1:]) #**2
755
+ # N = y.shape[0]
756
+ # d = y.shape[1]
757
+ # I = np.argsort(y.sum(axis=1))
758
+ # I0 = np.floor((N - 1) * sp.linspace(0, 1, num=k)).astype(int)
759
+ # # print(y.shape, L.shape, N, k, d)
760
+ # C = y[I0, :].copy()
761
+ #
762
+ # eps = 1e-20
763
+ # Cold = C.copy()
764
+ # u = ((C.reshape([k, 1, d]) - y.reshape([1, N, d])) ** 2).sum(axis=2)
765
+ # T = u.min(axis=0).sum() / (N)
766
+ # # print(T)
767
+ # j = 0
768
+ # while j < 5000:
769
+ # u0 = u - u.min(axis=0).reshape([1, N])
770
+ # w = np.exp(-u0 / T);
771
+ # w = w / (eps + w.sum(axis=0).reshape([1, N]))
772
+ # # print(w.min(), w.max())
773
+ # cost = (u * w).sum() + T * (w * np.log(w + eps)).sum()
774
+ # C = np.dot(w, y) / (eps + w.sum(axis=1).reshape([k, 1]))
775
+ # # print(j, 'cost0 ', cost)
776
+ #
777
+ # u = ((C.reshape([k, 1, d]) - y.reshape([1, N, d])) ** 2).sum(axis=2)
778
+ # cost = (u * w).sum() + T * (w * np.log(w + eps)).sum()
779
+ # err = np.sqrt(((C - Cold) ** 2).sum(axis=1)).sum()
780
+ # # print(j, 'cost ', cost, err, T)
781
+ # if (j > 100) & (err < 1e-4):
782
+ # break
783
+ # j = j + 1
784
+ # Cold = C.copy()
785
+ # T = T * 0.99
786
+ #
787
+ # # print(k, d, C.shape)
788
+ # dst = ((C.reshape([k, 1, d]) - y.reshape([1, N, d])) ** 2).sum(axis=2)
789
+ # md = dst.min(axis=0)
790
+ # idx = np.zeros(N).astype(int)
791
+ # for j in range(N):
792
+ # I = np.flatnonzero(dst[:, j] < md[j] + 1e-10)
793
+ # idx[j] = I[0]
794
+ # I = -np.ones(k).astype(int)
795
+ # kk = 0
796
+ # for j in range(k):
797
+ # if True in (idx == j):
798
+ # I[j] = kk
799
+ # kk += 1
800
+ # idx = I[idx]
801
+ # if idx.max() < (k - 1):
802
+ # print('Warning: kmeans convergence with %d clusters instead of %d' % (idx.max(), k))
803
+ # # ml = w.sum(axis=1)/N
804
+ kmeans = KMeans(n_clusters=k, random_state=0, n_init="auto").fit(y)
805
+ # kmeans = DBSCAN().fit(y)
806
+ idx = kmeans.labels_
807
+ nc = idx.max() + 1
808
+ C = np.zeros([nc, self.vertices.shape[1]])
809
+ a, foo = self.computeVertexArea()
810
+ for k in range(nc):
811
+ I = np.flatnonzero(idx == k)
812
+ nI = len(I)
813
+ # print(a.shape, nI)
814
+ aI = a[I]
815
+ ak = aI.sum()
816
+ C[k, :] = (self.vertices[I, :] * aI).sum(axis=0) / ak;
817
+ mean_eigen = (y*a).sum(axis=0)/a
818
+ tree = cKDTree(y)
819
+ _, indices = tree.query(mean_eigen, k=1)
820
+ index_center = indices[0]
821
+ _, indices = tree.query(C, k=1)
822
+ index_C = indices[0]
823
+
824
+ mesh = tm.Trimesh(vertices=self.vertices, faces=self.faces)
825
+ nv = self.computeVertexNormals()
826
+ rori = self.vertices[index] + 0.001 * (-nv[index, :])
827
+ rdir = -nv[index, :]
828
+
829
+ locations, index_ray, index_tri = mesh.ray.intersects_location(ray_origins=rori[None, :], ray_directions=rdir[None, :])
830
+ if verbose:
831
+ print(locations, index_ray, index_tri)
832
+ print(index_center, index_C)
833
+ return idx, C, index_center, locations[0]
834
+
835
+ def get_keypoints(self, n_eig=30, n_points=5, sym=False):
836
+ evals, evecs = self.lapEigen(n_eig+1)
837
+ pts_evecs = evecs[:, 1:] * np.exp(-2*0.1 * evals[1:]) # approximate geod distance
838
+ tree = cKDTree(pts_evecs)
839
+ _, center_index = tree.query(np.zeros(pts_evecs.shape[-1]), k=1)
840
+ indices = fps_gpt_geod(self.vertices, self.faces, 10, center_index)[1:]
841
+ areas = np.linalg.norm(self.surfel, axis=-1, keepdims=True)
842
+ area = np.sqrt(areas.sum()/2)
843
+ print(area)
844
+ ## Looking for center + provide distances to the center
845
+
846
+ print(center_index)
847
+ solver = pp3d.MeshHeatMethodDistanceSolver(self.vertices, self.faces)
848
+ norm_center = solver.compute_distance(center_index)
849
+
850
+ ## Dist matrix just between selected points
851
+ dist_mat_indices = np.zeros((len(indices), len(indices)))
852
+ all_ii = np.arange(len(indices)) # just to easy code
853
+ for ii, index in enumerate(indices):
854
+ dist_mat_indices[ii, ii!=all_ii] = solver.compute_distance(index)[indices[indices != index]]
855
+
856
+ ## Select points which : farthest from the center, delete their neighbors with dist < 0.5*area
857
+ keep = []
858
+ max_norm = norm_center.max()
859
+ for ii, index in enumerate(indices[:n_points]):
860
+ # distances_compar = (dist_mat_indices[ii, ii!=all_ii] + norm_center[ii]) / norm_center[indices[ii!=all_ii]]
861
+ # min_dist = np.amin(np.abs(distances_compar - 1))
862
+ # print(min_dist/norm_center[ii], min_dist)
863
+ # if min_dist > 0.3:
864
+ # #print(ii, min_dist/norm_center[ii], norm_center[ii], dist_mat_indices[ii, ii!=all_ii], norm_center[indices[indices != index]], indices[7])
865
+ keep.append(index)
866
+ ## Add the index of the center
867
+ ## Add the index of the center
868
+ if sym:
869
+ mesh = tm.base.Trimesh(self.vertices, self.faces, process=False) # Load your mesh
870
+ # Given data
871
+ index = 123 # Your starting point index
872
+ direction = -mesh.vertex_normals[center_index]
873
+
874
+ # Get the starting vertex position
875
+ start_point = mesh.vertices[center_index]
876
+
877
+ # Perform ray intersection
878
+ locations, index_ray, index_tri = mesh.ray.intersects_location(
879
+ ray_origins=start_point[None, :], # Starting point
880
+ ray_directions=direction[None, :] # Direction of travel
881
+ )
882
+
883
+ # If intersections exist
884
+ if len(locations) > 0:
885
+ # Sort intersections by distance from start_point
886
+ distances = np.linalg.norm(locations - start_point, axis=1)
887
+ sorted_indices = np.argsort(distances)
888
+
889
+ # Get the first valid intersection point beyond the start point
890
+ #next_intersection = locations[sorted_indices]
891
+ #print(f"Next intersection at: {next_intersection}")
892
+ intersected_triangle = index_tri[sorted_indices[1]]
893
+ print(f"Intersected Triangle Index: {mesh.faces[intersected_triangle]}")
894
+ center_index = mesh.faces[intersected_triangle][0]
895
+ else:
896
+ print("No intersection found in the given direction.")
897
+ keep.append(center_index)
898
+ return keep
899
+
900
+
901
+ # Computes surface volume
902
+ def surfVolume(self):
903
+ f = self.faces
904
+ v = self.vertices
905
+ t = v[f, :]
906
+ vols = np.linalg.det(t) / 6
907
+ return vols.sum(), vols
908
+
909
+ def surfCenter(self):
910
+ f = self.faces
911
+ v = self.vertices
912
+ center_infs = (v[f, :].sum(axis=1) / 4) * self.vols[:, np.newaxis]
913
+ center = center_infs.sum(axis=0)
914
+ return center
915
+
916
+ def surfMoments(self):
917
+ f = self.faces
918
+ v = self.vertices
919
+ vec_0 = v[f[:, 0], :] + v[f[:, 1], :]
920
+ s_0 = vec_0[:, :, np.newaxis] * vec_0[:, np.newaxis, :]
921
+ vec_1 = v[f[:, 0], :] + v[f[:, 2], :]
922
+ s_1 = vec_1[:, :, np.newaxis] * vec_1[:, np.newaxis, :]
923
+ vec_2 = v[f[:, 1], :] + v[f[:, 2], :]
924
+ s_2 = vec_2[:, :, np.newaxis] * vec_2[:, np.newaxis, :]
925
+ moments_inf = self.vols[:, np.newaxis, np.newaxis] * (1. / 20) * (s_0 + s_1 + s_2)
926
+ return moments_inf.sum(axis=0)
927
+
928
+ def surfF(self):
929
+ f = self.faces
930
+ v = self.vertices
931
+ cent = (v[f, :].sum(axis=1)) / 4.
932
+ F = self.vols[:, np.newaxis] * np.sign(cent) * (cent ** 2)
933
+ return F.sum(axis=0)
934
+
935
+ def surfU(self):
936
+ vol = self.volume
937
+ vertices = self.vertices / pow(vol, 1. / 3)
938
+ vertices -= self.center
939
+ self.updateVertices(vertices)
940
+ moments = self.surfMoments()
941
+ u, s, vh = np.linalg.svd(moments)
942
+ F = self.surfF()
943
+ return np.diag(np.sign(F)) @ u, s
944
+
945
+ def surfEllipsoid(self, u, s, moments):
946
+ coeff = pow(4 * np.pi / 15, 1. / 5) * pow(np.linalg.det(moments), -1. / 10)
947
+ A = coeff * ((u * np.sqrt(s)) @ u.T)
948
+ return u, A
949
+
950
+ # Reads from .off file
951
+ def readOFF(self, offfile):
952
+ with open(offfile, 'r') as f:
953
+ all_lines = f.readlines()
954
+
955
+
956
+ n_vertices = int(all_lines[1].split()[0])
957
+ n_faces = int(all_lines[1].split()[1])
958
+
959
+ vertices_list = []
960
+ for i in range(n_vertices):
961
+ vertices_list.append([float(x) for x in all_lines[2+i].split()[:3]])
962
+
963
+ faces_list = []
964
+ for i in range(n_faces):
965
+ # Be careful to convert to int. Otherwise, you can use np.array(faces_list).astype(np.int32)
966
+ faces_list.append([int(x) for x in all_lines[2+i+n_vertices].split()[1:4]])
967
+ self.faces = np.array(faces_list)
968
+ self.vertices = np.array(vertices_list)
969
+ self.computeCentersAreas() # Reads from .byu file
970
+ def readbyu(self, byufile):
971
+ with open(byufile, 'r') as fbyu:
972
+ ln0 = fbyu.readline()
973
+ ln = ln0.split()
974
+ # read header
975
+ ncomponents = int(ln[0]) # number of components
976
+ npoints = int(ln[1]) # number of vertices
977
+ nfaces = int(ln[2]) # number of faces
978
+ # fscanf(fbyu,'%d',1); % number of edges
979
+ # %ntest = fscanf(fbyu,'%d',1); % number of edges
980
+ for k in range(ncomponents):
981
+ fbyu.readline() # components (ignored)
982
+ # read data
983
+ self.vertices = np.empty([npoints, 3])
984
+ k = -1
985
+ while k < npoints - 1:
986
+ ln = fbyu.readline().split()
987
+ k = k + 1;
988
+ self.vertices[k, 0] = float(ln[0])
989
+ self.vertices[k, 1] = float(ln[1])
990
+ self.vertices[k, 2] = float(ln[2])
991
+ if len(ln) > 3:
992
+ k = k + 1;
993
+ self.vertices[k, 0] = float(ln[3])
994
+ self.vertices[k, 1] = float(ln[4])
995
+ self.vertices[k, 2] = float(ln[5])
996
+
997
+ self.faces = np.empty([nfaces, 3])
998
+ ln = fbyu.readline().split()
999
+ kf = 0
1000
+ j = 0
1001
+ while ln:
1002
+ if kf >= nfaces:
1003
+ break
1004
+ # print(nfaces, kf, ln)
1005
+ for s in ln:
1006
+ self.faces[kf, j] = int(sp.fabs(int(s)))
1007
+ j = j + 1
1008
+ if j == 3:
1009
+ kf = kf + 1
1010
+ j = 0
1011
+ ln = fbyu.readline().split()
1012
+ self.faces = np.int_(self.faces) - 1
1013
+ xDef1 = self.vertices[self.faces[:, 0], :]
1014
+ xDef2 = self.vertices[self.faces[:, 1], :]
1015
+ xDef3 = self.vertices[self.faces[:, 2], :]
1016
+ self.centers = (xDef1 + xDef2 + xDef3) / 3
1017
+ self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
1018
+
1019
+ # Saves in .byu format
1020
+ def savebyu(self, byufile):
1021
+ # FV = readbyu(byufile)
1022
+ # reads from a .byu file into matlab's face vertex structure FV
1023
+
1024
+ with open(byufile, 'w') as fbyu:
1025
+ # copy header
1026
+ ncomponents = 1 # number of components
1027
+ npoints = self.vertices.shape[0] # number of vertices
1028
+ nfaces = self.faces.shape[0] # number of faces
1029
+ nedges = 3 * nfaces # number of edges
1030
+
1031
+ str = '{0: d} {1: d} {2: d} {3: d} 0\n'.format(ncomponents, npoints, nfaces, nedges)
1032
+ fbyu.write(str)
1033
+ str = '1 {0: d}\n'.format(nfaces)
1034
+ fbyu.write(str)
1035
+
1036
+ k = -1
1037
+ while k < (npoints - 1):
1038
+ k = k + 1
1039
+ str = '{0: f} {1: f} {2: f} '.format(self.vertices[k, 0], self.vertices[k, 1], self.vertices[k, 2])
1040
+ fbyu.write(str)
1041
+ if k < (npoints - 1):
1042
+ k = k + 1
1043
+ str = '{0: f} {1: f} {2: f}\n'.format(self.vertices[k, 0], self.vertices[k, 1], self.vertices[k, 2])
1044
+ fbyu.write(str)
1045
+ else:
1046
+ fbyu.write('\n')
1047
+
1048
+ j = 0
1049
+ for k in range(nfaces):
1050
+ for kk in (0, 1):
1051
+ fbyu.write('{0: d} '.format(self.faces[k, kk] + 1))
1052
+ j = j + 1
1053
+ if j == 16:
1054
+ fbyu.write('\n')
1055
+ j = 0
1056
+
1057
+ fbyu.write('{0: d} '.format(-self.faces[k, 2] - 1))
1058
+ j = j + 1
1059
+ if j == 16:
1060
+ fbyu.write('\n')
1061
+ j = 0
1062
+
1063
+ def saveVTK(self, fileName, scalars=None, normals=None, tensors=None, scal_name='scalars', vectors=None,
1064
+ vect_name='vectors'):
1065
+ vf = vtkFields()
1066
+ # print(scalars)
1067
+ if not (scalars == None):
1068
+ vf.scalars.append(scal_name)
1069
+ vf.scalars.append(scalars)
1070
+ if not (vectors == None):
1071
+ vf.vectors.append(vect_name)
1072
+ vf.vectors.append(vectors)
1073
+ if not (normals == None):
1074
+ vf.normals.append('normals')
1075
+ vf.normals.append(normals)
1076
+ if not (tensors == None):
1077
+ vf.tensors.append('tensors')
1078
+ vf.tensors.append(tensors)
1079
+ self.saveVTK2(fileName, vf)
1080
+
1081
+ # Saves in .vtk format
1082
+ def saveVTK2(self, fileName, vtkFields=None):
1083
+ F = self.faces;
1084
+ V = self.vertices;
1085
+
1086
+ with open(fileName, 'w') as fvtkout:
1087
+ fvtkout.write('# vtk DataFile Version 3.0\nSurface Data\nASCII\nDATASET POLYDATA\n')
1088
+ fvtkout.write('\nPOINTS {0: d} float'.format(V.shape[0]))
1089
+ for ll in range(V.shape[0]):
1090
+ fvtkout.write('\n{0: f} {1: f} {2: f}'.format(V[ll, 0], V[ll, 1], V[ll, 2]))
1091
+ fvtkout.write('\nPOLYGONS {0:d} {1:d}'.format(F.shape[0], 4 * F.shape[0]))
1092
+ for ll in range(F.shape[0]):
1093
+ fvtkout.write('\n3 {0: d} {1: d} {2: d}'.format(F[ll, 0], F[ll, 1], F[ll, 2]))
1094
+ if not (vtkFields == None):
1095
+ wrote_pd_hdr = False
1096
+ if len(vtkFields.scalars) > 0:
1097
+ if not wrote_pd_hdr:
1098
+ fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
1099
+ wrote_pd_hdr = True
1100
+ nf = len(vtkFields.scalars) / 2
1101
+ for k in range(nf):
1102
+ fvtkout.write('\nSCALARS ' + vtkFields.scalars[2 * k] + ' float 1\nLOOKUP_TABLE default')
1103
+ for ll in range(V.shape[0]):
1104
+ # print(scalars[ll])
1105
+ fvtkout.write('\n {0: .5f}'.format(vtkFields.scalars[2 * k + 1][ll]))
1106
+ if len(vtkFields.vectors) > 0:
1107
+ if not wrote_pd_hdr:
1108
+ fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
1109
+ wrote_pd_hdr = True
1110
+ nf = len(vtkFields.vectors) / 2
1111
+ for k in range(nf):
1112
+ fvtkout.write('\nVECTORS ' + vtkFields.vectors[2 * k] + ' float')
1113
+ vectors = vtkFields.vectors[2 * k + 1]
1114
+ for ll in range(V.shape[0]):
1115
+ fvtkout.write(
1116
+ '\n {0: .5f} {1: .5f} {2: .5f}'.format(vectors[ll, 0], vectors[ll, 1], vectors[ll, 2]))
1117
+ if len(vtkFields.normals) > 0:
1118
+ if not wrote_pd_hdr:
1119
+ fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
1120
+ wrote_pd_hdr = True
1121
+ nf = len(vtkFields.normals) / 2
1122
+ for k in range(nf):
1123
+ fvtkout.write('\nNORMALS ' + vtkFields.normals[2 * k] + ' float')
1124
+ vectors = vtkFields.normals[2 * k + 1]
1125
+ for ll in range(V.shape[0]):
1126
+ fvtkout.write(
1127
+ '\n {0: .5f} {1: .5f} {2: .5f}'.format(vectors[ll, 0], vectors[ll, 1], vectors[ll, 2]))
1128
+ if len(vtkFields.tensors) > 0:
1129
+ if not wrote_pd_hdr:
1130
+ fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
1131
+ wrote_pd_hdr = True
1132
+ nf = len(vtkFields.tensors) / 2
1133
+ for k in range(nf):
1134
+ fvtkout.write('\nTENSORS ' + vtkFields.tensors[2 * k] + ' float')
1135
+ tensors = vtkFields.tensors[2 * k + 1]
1136
+ for ll in range(V.shape[0]):
1137
+ for kk in range(2):
1138
+ fvtkout.write(
1139
+ '\n {0: .5f} {1: .5f} {2: .5f}'.format(tensors[ll, kk, 0], tensors[ll, kk, 1],
1140
+ tensors[ll, kk, 2]))
1141
+ fvtkout.write('\n')
1142
+
1143
+ # Reads .vtk file
1144
+ def readVTK(self, fileName):
1145
+ if gotVTK:
1146
+ u = vtkPolyDataReader()
1147
+ u.SetFileName(fileName)
1148
+ u.Update()
1149
+ v = u.GetOutput()
1150
+ # print(v)
1151
+ npoints = int(v.GetNumberOfPoints())
1152
+ nfaces = int(v.GetNumberOfPolys())
1153
+ V = np.zeros([npoints, 3])
1154
+ for kk in range(npoints):
1155
+ V[kk, :] = np.array(v.GetPoint(kk))
1156
+
1157
+ F = np.zeros([nfaces, 3])
1158
+ for kk in range(nfaces):
1159
+ c = v.GetCell(kk)
1160
+ for ll in range(3):
1161
+ F[kk, ll] = c.GetPointId(ll)
1162
+
1163
+ self.vertices = V
1164
+ self.faces = np.int_(F)
1165
+ xDef1 = self.vertices[self.faces[:, 0], :]
1166
+ xDef2 = self.vertices[self.faces[:, 1], :]
1167
+ xDef3 = self.vertices[self.faces[:, 2], :]
1168
+ self.centers = (xDef1 + xDef2 + xDef3) / 3
1169
+ self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
1170
+ else:
1171
+ raise Exception('Cannot run readVTK without VTK')
1172
+
1173
+ # Reads .obj file
1174
+ def readOBJ(self, fileName):
1175
+ if gotVTK:
1176
+ u = vtkOBJReader()
1177
+ u.SetFileName(fileName)
1178
+ u.Update()
1179
+ v = u.GetOutput()
1180
+ # print(v)
1181
+ npoints = int(v.GetNumberOfPoints())
1182
+ nfaces = int(v.GetNumberOfPolys())
1183
+ V = np.zeros([npoints, 3])
1184
+ for kk in range(npoints):
1185
+ V[kk, :] = np.array(v.GetPoint(kk))
1186
+
1187
+ F = np.zeros([nfaces, 3])
1188
+ for kk in range(nfaces):
1189
+ c = v.GetCell(kk)
1190
+ for ll in range(3):
1191
+ F[kk, ll] = c.GetPointId(ll)
1192
+
1193
+ self.vertices = V
1194
+ self.faces = np.int_(F)
1195
+ xDef1 = self.vertices[self.faces[:, 0], :]
1196
+ xDef2 = self.vertices[self.faces[:, 1], :]
1197
+ xDef3 = self.vertices[self.faces[:, 2], :]
1198
+ self.centers = (xDef1 + xDef2 + xDef3) / 3
1199
+ self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
1200
+ else:
1201
+ raise Exception('Cannot run readOBJ without VTK')
1202
+
1203
+ def readPLY(self, fileName):
1204
+ if gotVTK:
1205
+ u = vtkPLYReader()
1206
+ u.SetFileName(fileName)
1207
+ u.Update()
1208
+ v = u.GetOutput()
1209
+ # print(v)
1210
+ npoints = int(v.GetNumberOfPoints())
1211
+ nfaces = int(v.GetNumberOfPolys())
1212
+ V = np.zeros([npoints, 3])
1213
+ for kk in range(npoints):
1214
+ V[kk, :] = np.array(v.GetPoint(kk))
1215
+
1216
+ F = np.zeros([nfaces, 3])
1217
+ for kk in range(nfaces):
1218
+ c = v.GetCell(kk)
1219
+ for ll in range(3):
1220
+ F[kk, ll] = c.GetPointId(ll)
1221
+
1222
+ self.vertices = V
1223
+ self.faces = np.int_(F)
1224
+ xDef1 = self.vertices[self.faces[:, 0], :]
1225
+ xDef2 = self.vertices[self.faces[:, 1], :]
1226
+ xDef3 = self.vertices[self.faces[:, 2], :]
1227
+ self.centers = (xDef1 + xDef2 + xDef3) / 3
1228
+ self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
1229
+ else:
1230
+ raise Exception('Cannot run readPLY without VTK')
1231
+
1232
+ def readTRI(self, fileName):
1233
+ vertices = []
1234
+ faces = []
1235
+ with open(fileName, "r") as f:
1236
+ pithc = f.readlines()
1237
+ line_1 = pithc[0]
1238
+ n_pts, n_tri = line_1.split(' ')
1239
+ for i, line in enumerate(pithc[1:]):
1240
+ pouetage = line.split(' ')
1241
+ if i <= int(n_pts):
1242
+ vertices.append([float(poupout) for poupout in pouetage[:3]])
1243
+ # vertices.append([float(poupout) for poupout in pouetage[4:7]])
1244
+ else:
1245
+ faces.append([int(poupout.split("\n")[0]) for poupout in pouetage[1:]])
1246
+ self.vertices = np.array(vertices)
1247
+ self.faces = np.array(faces)
1248
+ xDef1 = self.vertices[self.faces[:, 0], :]
1249
+ xDef2 = self.vertices[self.faces[:, 1], :]
1250
+ xDef3 = self.vertices[self.faces[:, 2], :]
1251
+ self.centers = (xDef1 + xDef2 + xDef3) / 3
1252
+ self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
1253
+
1254
+ def concatenate(self, fvl):
1255
+ nv = 0
1256
+ nf = 0
1257
+ for fv in fvl:
1258
+ nv += fv.vertices.shape[0]
1259
+ nf += fv.faces.shape[0]
1260
+ self.vertices = np.zeros([nv, 3])
1261
+ self.faces = np.zeros([nf, 3], dtype='int')
1262
+
1263
+ nv0 = 0
1264
+ nf0 = 0
1265
+ for fv in fvl:
1266
+ nv = nv0 + fv.vertices.shape[0]
1267
+ nf = nf0 + fv.faces.shape[0]
1268
+ self.vertices[nv0:nv, :] = fv.vertices
1269
+ self.faces[nf0:nf, :] = fv.faces + nv0
1270
+ nv0 = nv
1271
+ nf0 = nf
1272
+ self.computeCentersAreas()
1273
+
1274
+ def get_surf_area(surf):
1275
+ areas = np.linalg.norm(surf.surfel, axis=-1)
1276
+ return areas.sum()/2
1277
+
1278
+ def centroid(surf):
1279
+ areas = np.linalg.norm(surf.surfel, axis=-1, keepdims=True)
1280
+ center = (surf.centers * areas).sum(axis=0) / areas.sum()
1281
+ return center, np.sqrt(areas.sum()/2)
1282
+
1283
+
1284
+ def do_bbox_vertices(vertices):
1285
+ new_verts = np.zeros(vertices.shape)
1286
+ new_verts[:, 0] = (vertices[:, 0] - np.amin(vertices[:, 0]))/(np.amax(vertices[:, 0]) - np.amin(vertices[:, 0]))
1287
+ new_verts[:, 1] = (vertices[:, 1] - np.amin(vertices[:, 1]))/(np.amax(vertices[:, 1]) - np.amin(vertices[:, 1]))
1288
+ new_verts[:, 2] = (vertices[:, 2] - np.amin(vertices[:, 2]))/(np.amax(vertices[:, 1]) - np.amin(vertices[:, 2]))*0.5
1289
+ return new_verts
1290
+
1291
+ def opt_rot_surf(surf_1, surf_2, areas_c):
1292
+ center_1, _ = centroid(surf_1)
1293
+ pts_c1 = surf_1.centers - center_1
1294
+ center_2, _ = centroid(surf_2)
1295
+ pts_c2 = surf_2.centers - center_2
1296
+ to_sum = pts_c1[:, :, np.newaxis] * pts_c2[:, np.newaxis, :]
1297
+ A = (to_sum * areas_c[:, :, np.newaxis]).sum(axis=0)
1298
+ #A = to_sum.sum(axis=0)
1299
+ u, _, v = np.linalg.svd(A)
1300
+ a = np.array([[1, 0, 0], [0, 1, 0], [0, 0, np.sign(np.linalg.det(A))]])
1301
+ O = u @ a @ v
1302
+ return O.T
1303
+
1304
+ def srnf(surf):
1305
+ areas = np.linalg.norm(surf.surfel, axis=-1, keepdims=True) + 1e-8
1306
+ return surf.surfel/np.sqrt(areas)
1307
+
1308
+ def fps_gpt(points, num_samples):
1309
+ """
1310
+ Selects a subset of points using the Farthest Point Sampling algorithm.
1311
+
1312
+ :param points: numpy array of shape (N, D), where N is the number of points and D is the dimension of each point
1313
+ :param num_samples: number of points to select
1314
+ :return: numpy array of indices of the selected points
1315
+ """
1316
+ if num_samples > len(points):
1317
+ raise ValueError("num_samples must be less than or equal to the number of points")
1318
+
1319
+ # Initialize an array to hold the indices of the selected points
1320
+ selected_indices = np.zeros(num_samples, dtype=int)
1321
+
1322
+ center = np.mean(points, axis=0)
1323
+
1324
+ # Array to store the shortest distance of each point to a selected poin
1325
+ shortest_distances = np.full(len(points), np.inf)
1326
+ last_selected_point = center
1327
+ for i in range(1, num_samples):
1328
+
1329
+
1330
+ # Update the shortest distance of each point to the selected points
1331
+ for j, point in enumerate(points):
1332
+ distance = np.linalg.norm(point - last_selected_point)
1333
+ shortest_distances[j] = min(shortest_distances[j], distance)
1334
+
1335
+ # Select the point that is farthest from all selected points
1336
+ selected_indices[i] = np.argmax(shortest_distances)
1337
+ last_selected_point = points[selected_indices[i]]
1338
+ return selected_indices
1339
+
1340
+
1341
+ def fps_gpt_geod(points, faces, num_samples, index_start):
1342
+ """
1343
+ Selects a subset of points using the Farthest Point Sampling algorithm.
1344
+
1345
+ :param points: numpy array of shape (N, D), where N is the number of points and D is the dimension of each point
1346
+ :param num_samples: number of points to select
1347
+ :return: numpy array of indices of the selected points
1348
+ """
1349
+ if num_samples > len(points):
1350
+ raise ValueError("num_samples must be less than or equal to the number of points")
1351
+
1352
+ # Initialize an array to hold the indices of the selected points
1353
+ selected_indices = np.zeros(num_samples, dtype=int)
1354
+
1355
+ # Array to store the shortest distance of each point to a selected poin
1356
+ shortest_distances = np.full(len(points), np.inf)
1357
+ selected_indices[0] = index_start
1358
+ distances = np.zeros((num_samples, len(points)))
1359
+ solver = pp3d.MeshHeatMethodDistanceSolver(points, faces)
1360
+ for i in range(1, num_samples):
1361
+ distances[i] = (solver.compute_distance(selected_indices[i-1]))
1362
+
1363
+ # Update the shortest distance of each point to the selected points
1364
+ for j, point in enumerate(points):
1365
+ distance = distances[i][j]
1366
+ shortest_distances[j] = min(shortest_distances[j], distance)
1367
+
1368
+ # Select the point that is farthest from all selected points
1369
+ selected_indices[i] = np.argmax(shortest_distances)
1370
+ return selected_indices
1371
+
1372
+ # class MultiSurface:
1373
+ # def __init__(self, pattern):
1374
+ # self.surf = []
1375
+ # files = glob.glob(pattern)
1376
+ # for f in files:
1377
+ # self.surf.append(Surface(filename=f))
utils/torch_fmap.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from pykeops.torch import LazyTensor
4
+
5
+
6
+ def euclidean_dist(x, y):
7
+ """
8
+ Args:
9
+ x: pytorch Variable, with shape [m, d]
10
+ y: pytorch Variable, with shape [n, d]
11
+ Returns:
12
+ dist: pytorch Variable, with shape [m, n]
13
+ """
14
+ #bs, m, n = x.size(0), x.size(1), y.size(1)
15
+ xx = torch.pow(x.squeeze(), 2).sum(1, keepdim=True)
16
+ yy = torch.pow(y.squeeze(), 2).sum(1, keepdim=True).t()
17
+ dist = xx + yy - 2 * torch.inner(x.squeeze(), y.squeeze())
18
+ dist = dist.clamp(min=1e-12).sqrt()
19
+ return dist
20
+
21
+ def knnsearch(x, y, alpha=1./0.07, prod=False):
22
+ if prod:
23
+ prods = torch.inner(x.squeeze(), y.squeeze())#/( torch.norm(x.squeeze(), dim=-1)[:, None]*torch.norm(y.squeeze(), dim=-1)[None, :])
24
+ output = F.softmax(alpha*prods, dim=1)
25
+ else:
26
+ distance = euclidean_dist(x, y[None,:])
27
+ output = F.softmax(-alpha*distance, dim=1)
28
+ return output.squeeze()
29
+
30
+ def extract_p2p_torch(reps_shape, reps_template):
31
+ n_ev = reps_shape.shape[-1]
32
+ with torch.no_grad():
33
+ # print((evecs0_dzo @ fmap01_final.squeeze().T).shape)
34
+ # print(evecs1_dzo.shape)
35
+ reps_shape_torch = torch.from_numpy(reps_shape).float().cuda()
36
+ G_i = LazyTensor(reps_shape_torch[:, None, :].contiguous()) # (M**2, 1, 2)
37
+ reps_template_torch = torch.from_numpy(reps_template).float().cuda()
38
+ X_j = LazyTensor(reps_template_torch[None, :, :n_ev].contiguous()) # (1, N, 2)
39
+ D_ij = ((G_i - X_j) ** 2).sum(-1) # (M**2, N) symbolic matrix of squared distances
40
+ indKNN = D_ij.argKmin(1, dim=0).squeeze() # Grid <-> Samples, (M**2, K) integer tensor
41
+ # pmap10_ref = FM_to_p2p(fmap01_final.detach().squeeze().cpu().numpy(), s_dict['evecs'], template_dict['evecs'])
42
+ # print(indKNN[:10], pmap10_ref[:10])
43
+ indKNN_2 = D_ij.argKmin(1, dim=1).squeeze()
44
+ return indKNN.detach().cpu().numpy(), indKNN_2.detach().cpu().numpy()
45
+
46
+ def extract_p2p_torch_fmap(fmap_shape_template, evecs_shape, evecs_template):
47
+ n_ev = fmap_shape_template.shape[-1]
48
+ with torch.no_grad():
49
+ # print((evecs0_dzo @ fmap01_final.squeeze().T).shape)
50
+ # print(evecs1_dzo.shape)
51
+ G_i = LazyTensor((evecs_shape[:, :n_ev] @ fmap_shape_template.squeeze().T)[:, None, :].contiguous()) # (M**2, 1, 2)
52
+ X_j = LazyTensor(evecs_template[None, :, :n_ev].contiguous()) # (1, N, 2)
53
+ D_ij = ((G_i - X_j) ** 2).sum(-1) # (M**2, N) symbolic matrix of squared distances
54
+ indKNN = D_ij.argKmin(1, dim=0).squeeze() # Grid <-> Samples, (M**2, K) integer tensor
55
+ # pmap10_ref = FM_to_p2p(fmap01_final.detach().squeeze().cpu().numpy(), s_dict['evecs'], template_dict['evecs'])
56
+ # print(indKNN[:10], pmap10_ref[:10])
57
+ indKNN_2 = D_ij.argKmin(1, dim=1).squeeze()
58
+ return indKNN.detach().cpu().numpy(), indKNN_2.detach().cpu().numpy()
59
+
60
+ def wlstsq(A, B, w):
61
+ if w is None:
62
+ return torch.linalg.lstsq(A, B).solution
63
+ else:
64
+ assert w.dim() + 1 == A.dim() and w.shape[-1] == A.shape[-2]
65
+ W = torch.diag_embed(w)
66
+ return torch.linalg.lstsq(W @ A, W @ B).solution
67
+
68
+ def torch_zoomout(evecs0, evecs1, evecs_1_trans, fmap01, target_size, step=1):
69
+ assert fmap01.shape[-2] == fmap01.shape[-1], f"square fmap needed, got {fmap01.shape[-2]} and {fmap01.shape[-1]}"
70
+ fs = fmap01.shape[0]
71
+ for i in range(fs, target_size+1, step):
72
+ indKNN, _ = extract_p2p_torch_fmap(fmap01, evecs0, evecs1)
73
+ #fmap01 = wlstsq(evecs1[..., :i], evecs0[indKNN, :i], None)
74
+ fmap01 = evecs_1_trans[:i, :] @ evecs0[indKNN, :i]
75
+ if fmap01.shape[0] < target_size:
76
+ fmap01 = evecs_1_trans[:target_size, :] @ evecs0[indKNN, :target_size]
77
+ return fmap01
utils/utils_func.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import scipy.sparse as sp
4
+ import shutil
5
+ from pathlib import Path
6
+ import torch
7
+ import numpy as np
8
+ import re
9
+ import os
10
+ import requests
11
+
12
+ def ensure_pretrained_file(hf_url: str, save_dir: str = "pretrained", filename: str = "pretrained.pkl", token: str = None):
13
+ """
14
+ Ensure that a pretrained file exists locally.
15
+ If the folder is empty, download from Hugging Face.
16
+
17
+ Args:
18
+ hf_url (str): Hugging Face file URL (resolve/main/...).
19
+ save_dir (str): Directory to store pretrained file.
20
+ filename (str): Name of file to save.
21
+ token (str): Optional Hugging Face token (for gated/private repos).
22
+ """
23
+ os.makedirs(save_dir, exist_ok=True)
24
+ save_path = os.path.join(save_dir, filename)
25
+
26
+ if os.path.exists(save_path):
27
+ print(f"✅ Found pretrained file: {save_path}")
28
+ return save_path
29
+
30
+ headers = {"Authorization": f"Bearer {token}"} if token else {}
31
+
32
+ print(f"⬇️ Downloading pretrained file from {hf_url} to {save_path} ...")
33
+ response = requests.get(hf_url, headers=headers, stream=True)
34
+ response.raise_for_status()
35
+
36
+ with open(save_path, "wb") as f:
37
+ for chunk in response.iter_content(chunk_size=8192):
38
+ f.write(chunk)
39
+
40
+ print("✅ Download complete.")
41
+ return save_path
42
+
43
+
44
+ def may_create_folder(folder_path):
45
+ if not osp.exists(folder_path):
46
+ oldmask = os.umask(000)
47
+ os.makedirs(folder_path, mode=0o777)
48
+ os.umask(oldmask)
49
+ return True
50
+ return False
51
+
52
+
53
+ def make_clean_folder(folder_path):
54
+ success = may_create_folder(folder_path)
55
+ if not success:
56
+ shutil.rmtree(folder_path)
57
+ may_create_folder(folder_path)
58
+
59
+ def str_delta(delta):
60
+ s = delta.total_seconds()
61
+ hours, remainder = divmod(s, 3600)
62
+ minutes, seconds = divmod(remainder, 60)
63
+ return '{:02}:{:02}:{:02}'.format(int(hours), int(minutes), int(seconds))
64
+
65
+ def desc_from_config(config, world_size):
66
+ cond_str = 'uncond' #'cond' if c.dataset_kwargs.use_labels else 'uncond'
67
+ dtype_str = 'fp16' if config["perfs"]["fp16"] else 'fp32'
68
+ name_data = config["data"]["name"]
69
+ if "abs" in config["data"]:
70
+ name_data += "abs" if config["data"]["abs"] else ""
71
+ desc = (f'{name_data:s}-{cond_str:s}-{config["architecture"]["model"]:s}-'
72
+ f'gpus{world_size:d}-batch{config["hyper_params"]["batch_size"]:d}-{dtype_str:s}')
73
+ if config["misc"]["precond"]:
74
+ desc += f'-{config["misc"]["precond"]:s}'
75
+
76
+ if "desc" in config["misc"]:
77
+ if config["misc"]["desc"] is not None:
78
+ desc += f'-{config["misc"]["desc"]}'
79
+ return desc
80
+
81
+ def get_dataset_path(config):
82
+ name_exp = config["data"]["name"]
83
+ if "add_name" in config:
84
+ if config["add_name"]["do"]:
85
+ name_exp += "_" + config["add_name"]["name"]
86
+ dataset_path = os.path.join(config["data"]["root_dir"], name_exp)
87
+ return name_exp, dataset_path
88
+
89
+ def sparse_np_to_torch(A):
90
+ Acoo = A.tocoo()
91
+ values = Acoo.data
92
+ indices = np.vstack((Acoo.row, Acoo.col))
93
+ shape = Acoo.shape
94
+ return torch.sparse_coo_tensor(torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape), dtype=torch.float32).coalesce()
95
+
96
+ def convert_dict(np_dict, device="cpu"):
97
+ torch_dict = {}
98
+ for k, value in np_dict.items():
99
+ if sp.issparse(value):
100
+ torch_dict[k] = sparse_np_to_torch(value).to(device)
101
+ if torch_dict[k].dtype == torch.int32:
102
+ torch_dict[k] = torch_dict[k].long().to(device)
103
+ elif isinstance(value, np.ndarray):
104
+ torch_dict[k] = torch.from_numpy(value).to(device)
105
+ if torch_dict[k].dtype == torch.int32:
106
+ torch_dict[k] = torch_dict[k].squeeze().long().to(device)
107
+ else:
108
+ torch_dict[k] = value
109
+ return torch_dict
110
+
111
+
112
+ def convert_dict_torch(in_dict, device="cpu"):
113
+ torch_dict = {}
114
+ for k, value in in_dict.items():
115
+ if isinstance(value, torch.Tensor):
116
+ torch_dict[k] = in_dict[k].to(device)
117
+ return torch_dict
118
+
119
+ def batchify_dict(torch_dict):
120
+ for k, value in torch_dict.items():
121
+ if isinstance(value, torch.Tensor):
122
+ if torch_dict[k].dtype != torch.int64:
123
+ torch_dict[k] = torch_dict[k].unsqueeze(0)
utils/utils_legacy.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import time
4
+
5
+ import torch
6
+ import hashlib
7
+ import numpy as np
8
+ import scipy
9
+
10
+ def read_lines(file_path):
11
+ with open(file_path, 'r') as fin:
12
+ lines = [line.strip() for line in fin.readlines() if len(line.strip()) > 0]
13
+ return lines
14
+
15
+ # == Pytorch things
16
+
17
+
18
+ def toNP(x):
19
+ """
20
+ Really, definitely convert a torch tensor to a numpy array
21
+ """
22
+ return x.detach().to(torch.device('cpu')).numpy()
23
+
24
+
25
+ def label_smoothing_log_loss(pred, labels, smoothing=0.0):
26
+ n_class = pred.shape[-1]
27
+ one_hot = torch.zeros_like(pred)
28
+ one_hot[labels] = 1.
29
+ one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
30
+ loss = -(one_hot * pred).sum(dim=-1).mean()
31
+ return loss
32
+
33
+
34
+ # Randomly rotate points.
35
+ # Torch in, torch out
36
+ # Note fornow, builds rotation matrix on CPU.
37
+ def random_rotate_points(pts, randgen=None):
38
+ R = random_rotation_matrix(randgen)
39
+ R = torch.from_numpy(R).to(device=pts.device, dtype=pts.dtype)
40
+ return torch.matmul(pts, R)
41
+
42
+
43
+ def random_rotate_points_y(pts):
44
+ angles = torch.rand(1, device=pts.device, dtype=pts.dtype) * (2. * np.pi)
45
+ rot_mats = torch.zeros(3, 3, device=pts.device, dtype=pts.dtype)
46
+ rot_mats[0, 0] = torch.cos(angles)
47
+ rot_mats[0, 2] = torch.sin(angles)
48
+ rot_mats[2, 0] = -torch.sin(angles)
49
+ rot_mats[2, 2] = torch.cos(angles)
50
+ rot_mats[1, 1] = 1.
51
+
52
+ pts = torch.matmul(pts, rot_mats)
53
+ return pts
54
+
55
+
56
+ # Numpy things
57
+
58
+
59
+ # Numpy sparse matrix to pytorch
60
+ def sparse_np_to_torch(A):
61
+ Acoo = A.tocoo()
62
+ values = Acoo.data
63
+ indices = np.vstack((Acoo.row, Acoo.col))
64
+ shape = Acoo.shape
65
+ return torch.sparse_coo_tensor(torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape), dtype=torch.float32).coalesce()
66
+
67
+
68
+ # Pytorch sparse to numpy csc matrix
69
+ def sparse_torch_to_np(A):
70
+ if len(A.shape) != 2:
71
+ raise RuntimeError("should be a matrix-shaped type; dim is : " + str(A.shape))
72
+
73
+ indices = toNP(A.indices())
74
+ values = toNP(A.values())
75
+
76
+ mat = scipy.sparse.coo_matrix((values, indices), shape=A.shape).tocsc()
77
+
78
+ return mat
79
+
80
+
81
+ # Hash a list of numpy arrays
82
+ def hash_arrays(arrs):
83
+ running_hash = hashlib.sha1()
84
+ for arr in arrs:
85
+ binarr = arr.view(np.uint8)
86
+ running_hash.update(binarr)
87
+ return running_hash.hexdigest()
88
+
89
+
90
+ def random_rotation_matrix(randgen=None):
91
+ """
92
+ Creates a random rotation matrix.
93
+ randgen: if given, a np.random.RandomState instance used for random numbers (for reproducibility)
94
+ """
95
+ # adapted from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c
96
+
97
+ if randgen is None:
98
+ randgen = np.random.RandomState()
99
+
100
+ theta, phi, z = tuple(randgen.rand(3).tolist())
101
+
102
+ theta = theta * 2.0 * np.pi # Rotation about the pole (Z).
103
+ phi = phi * 2.0 * np.pi # For direction of pole deflection.
104
+ z = z * 2.0 # For magnitude of pole deflection.
105
+
106
+ # Compute a vector V used for distributing points over the sphere
107
+ # via the reflection I - V Transpose(V). This formulation of V
108
+ # will guarantee that if x[1] and x[2] are uniformly distributed,
109
+ # the reflected points will be uniform on the sphere. Note that V
110
+ # has length sqrt(2) to eliminate the 2 in the Householder matrix.
111
+
112
+ r = np.sqrt(z)
113
+ Vx, Vy, Vz = V = (np.sin(phi) * r, np.cos(phi) * r, np.sqrt(2.0 - z))
114
+
115
+ st = np.sin(theta)
116
+ ct = np.cos(theta)
117
+
118
+ R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1)))
119
+ # Construct the rotation matrix ( V Transpose(V) - I ) R.
120
+
121
+ M = (np.outer(V, V) - np.eye(3)).dot(R)
122
+ return M
123
+
124
+
125
+ # Python string/file utilities
126
+ def ensure_dir_exists(d):
127
+ if not os.path.exists(d):
128
+ os.makedirs(d)
129
+
130
+
zero_shot.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import scipy
4
+ import os
5
+ import sys
6
+ import random
7
+ import numpy as np
8
+ import torch
9
+ import time
10
+ from datetime import datetime
11
+ import importlib
12
+ import json
13
+ import argparse
14
+ from omegaconf import OmegaConf
15
+ from snk.loss import PrismRegularizationLoss
16
+ from snk.prism_decoder import PrismDecoder
17
+ from shape_models.fmap import DFMNet
18
+ from shape_models.encoder import Encoder
19
+ from diffu_models.losses import VELoss, VPLoss, EDMLoss
20
+ from diffu_models.sds import guidance_grad
21
+ from utils.torch_fmap import torch_zoomout, knnsearch, extract_p2p_torch_fmap
22
+ from utils.utils_func import convert_dict, str_delta, ensure_pretrained_file
23
+ from utils.eval import accuracy
24
+ from utils.mesh import save_ply, load_mesh
25
+ from shape_data import get_data_dirs
26
+ from utils.pickle_stuff import safe_load_with_fallback
27
+ from utils.geometry import compute_operators, load_operators
28
+ from utils.surfaces import Surface
29
+ import sys
30
+ try:
31
+ import google.colab
32
+ print("Running Colab")
33
+ from tqdm import tqdm
34
+ except ImportError:
35
+ print("Running local")
36
+ from tqdm.auto import tqdm
37
+
38
+
39
+ def seed_everything(seed=42):
40
+ random.seed(seed)
41
+ os.environ['PYTHONHASHSEED'] = str(seed)
42
+ np.random.seed(seed)
43
+ torch.manual_seed(seed)
44
+ torch.backends.cudnn.deterministic = True
45
+ torch.backends.cudnn.benchmark = False
46
+
47
+ seed_everything()
48
+
49
+ class Tee:
50
+ def __init__(self, *outputs):
51
+ self.outputs = outputs
52
+
53
+ def write(self, message):
54
+ for output in self.outputs:
55
+ output.write(message)
56
+ output.flush() # ensure it's written immediately
57
+
58
+ def flush(self):
59
+ for output in self.outputs:
60
+ output.flush()
61
+
62
+ class DiffModel:
63
+
64
+ def __init__(self, cfg, device="cuda:0"):
65
+ if cfg["train_dir"] == "pretrained":
66
+ url = "https://huggingface.co/daidedou/diffumatch_model/resolve/main/network-snapshot-041216.pkl"
67
+ network_pkl = ensure_pretrained_file(url, "pretrained")
68
+ url_json = "https://huggingface.co/daidedou/diffumatch_model/resolve/main/training_options.json"
69
+ json_filename = ensure_pretrained_file(url_json, "pretrained", filename="training_options.json")
70
+ train_cfg = json.load(open(json_filename))
71
+ else:
72
+ num_exp = cfg["diff_num_exp"]
73
+ files = os.listdir(cfg["train_dir"])
74
+ for file in files:
75
+ if file[:5] == f"{num_exp:05d}":
76
+ netdir = os.path.join(cfg["train_dir"], file)
77
+ train_cfg = json.load(open(os.path.join(netdir, "training_options.json")))
78
+ pkls = [f for f in os.listdir(netdir) if ".pkl" in f]
79
+ nice_pkls = sorted(pkls, key=lambda x: int(x.split(".")[0].split("-")[-1]))
80
+ chosen_pkl = nice_pkls[-1]
81
+ network_pkl = os.path.join(netdir, chosen_pkl)
82
+ print(f'Loading network from "{network_pkl}"...')
83
+ self.net = safe_load_with_fallback(network_pkl)['ema'].to(device)
84
+
85
+ print('Done!')
86
+ loss_name = train_cfg['hyper_params']['loss_name']
87
+ self.loss_sde = None
88
+ if loss_name == "EDMLoss":
89
+ self.loss_sde = EDMLoss()
90
+ elif loss_name == "VPLoss":
91
+ self.loss_sde = VPLoss()
92
+
93
+
94
+ class Matcher(object):
95
+
96
+ def __init__(self, cfg):
97
+ self.cfg = cfg
98
+ self.device = torch.device(f'cuda:{cfg["gpu"]}' if torch.cuda.is_available() else 'cpu')
99
+ self.diffusion_model = None
100
+ if self.cfg.get("sds", False):
101
+ self.diffusion_model = DiffModel(cfg["sds_conf"])
102
+ self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
103
+ self.n_loop = 0
104
+ if self.cfg.get("optimize", False):
105
+ self.n_loop = self.cfg.opt.get("n_loop", 0)
106
+ self.snk = self.cfg.get("snk", False)
107
+ self.fmap_cfg = self.cfg.deepfeat_conf.fmap
108
+ self.dataloaders = dict()
109
+
110
+ def reconf(self, cfg):
111
+ self.cfg = cfg
112
+ self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
113
+ self.n_loop = 0
114
+ if self.cfg.get("optimize", False):
115
+ self.n_loop = self.cfg.opt.get("n_loop", 0)
116
+ self.fmap_cfg = self.cfg.deepfeat_conf.fmap
117
+ self.dataloaders = dict()
118
+
119
+ def _init(self):
120
+ cfg = self.cfg
121
+ self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).to(self.device)
122
+ if self.snk:
123
+ self.encoder = Encoder().to(self.device)
124
+ self.decoder = PrismDecoder(dim_in=515).to(self.device)
125
+ self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
126
+ self.soft_p2p = True
127
+ params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(self.decoder.parameters())
128
+ else:
129
+ params_to_opt = self.fmap_model.parameters()
130
+ self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
131
+ self.eye = torch.eye(self.n_fmap).float().to(self.device)
132
+ self.eye.requires_grad = False
133
+
134
+ def fmap(self, shape_dict, target_dict):
135
+ if self.fmap_cfg.get("use_diff", False):
136
+ C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model({"shape1": shape_dict, "shape2": target_dict}, diff_model=self.diffusion_model, scale=self.fmap_cfg.diffusion.time)
137
+ C12_pred, C12_obj, mask_12 = C12_pred
138
+ C21_pred, C21_obj, mask_21 = C21_pred
139
+ else:
140
+ C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model({"shape1": shape_dict, "shape2": target_dict})
141
+ C12_obj, C21_obj = C12_pred, C21_pred
142
+ mask_12, mask_21 = None, None
143
+ return C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, mask_12, mask_21
144
+
145
+
146
+ def zo_shot(self, shape_dict, target_dict):
147
+ self._init()
148
+ evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
149
+ _, C12_mask_init, _, _, _, _, _ , _, _, _ = self.fmap(shape_dict, target_dict)
150
+ evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
151
+ new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_mask_init.squeeze(), self.cfg["zo_shot"])
152
+ indKNN_new, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
153
+ return new_FM, indKNN_new
154
+
155
+
156
+ def optimize(self, shape_dict, target_dict, target_normals):
157
+ self._init()
158
+ evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
159
+ C12_pred_init, _, _, _ , _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
160
+ evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
161
+ evecs_1trans = evecs1.t() @ torch.diag(shape_dict["mass"])
162
+ n_verts_target = target_dict["vertices"].shape[-2]
163
+
164
+ loss_save = {"cycle": [], "fmap": [], "mse": [], "prism": [], "bij": [], "ortho": [], "sds": [], "lap": [], "proper": []}
165
+ snk_rec = None
166
+ for i in tqdm(range(self.n_loop), "Optimizing matching " + shape_dict['name'] + " " + target_dict['name']):
167
+ C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
168
+ if self.cfg.opt.soft_p2p:
169
+ ### A la SNK
170
+ ## P2P 2 -> 1
171
+ soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_pred.squeeze(), evecs1[:, :self.n_fmap], prod=True)
172
+ C12_new = evecs_trans2[:self.n_fmap, :] @ soft_p2p_21 @ evecs1[:, :self.n_fmap]
173
+ soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_new.squeeze(), evecs1[:, :self.n_fmap], prod=True)
174
+
175
+ ## P2P 1 -> 2
176
+ soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_pred.squeeze(), evecs2[:, :self.n_fmap], prod=True)
177
+ C21_new = evecs_trans1[:self.n_fmap, :] @ soft_p2p_12 @ evecs2[:, :self.n_fmap]
178
+ soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_new.squeeze(), evecs2[:, :self.n_fmap], prod=True)
179
+
180
+ l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(dim=-1).mean()
181
+ else:
182
+ C12_new, C21_new = C12_pred, C21_pred
183
+
184
+ l_ortho = ((C12_new.squeeze() @ C12_new.squeeze().T - self.eye)**2).mean() + ((C21_new.squeeze() @ C21_new.squeeze().T - self.eye)**2).mean()
185
+ l_bij = ((C12_new.squeeze() @ C21_new.squeeze() - self.eye)**2).mean() + ((C21_new.squeeze() @ C12_new.squeeze() - self.eye)**2).mean()
186
+ l_lap = ((C12_new @ torch.diag(shape_dict["evals"][:self.n_fmap]) - torch.diag(target_dict["evals"][:self.n_fmap]) @ C12_new)**2).mean()
187
+ l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(shape_dict["evals"][:self.n_fmap]) @ C21_new)**2).mean()
188
+
189
+
190
+ l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
191
+ if self.snk:
192
+ # Latent vector
193
+ latents = self.encoder(shape_dict)
194
+ latents_duplicate = latents[None, :].repeat(n_verts_target, 1)
195
+
196
+ # Prism decoder
197
+ feats_decode = torch.cat((target_dict["vertices"], latents_duplicate), dim=1)
198
+ snk_rec, prism, rots = self.decoder(target_dict, feats_decode)
199
+ l_prism = self.loss_prism(prism, rots, target_dict["vertices"], target_dict["faces"], target_normals)
200
+ l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec)**2).sum(dim=-1).mean()
201
+ l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(dim=-1).mean()
202
+ l_sds, l_proper = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
203
+ if self.fmap_cfg.get("use_diff", False):
204
+ if self.fmap_cfg.diffusion.get("abs", False):
205
+ C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
206
+ else:
207
+ C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
208
+ grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1, batch_size=self.fmap_cfg.diffusion.batch_sds,
209
+ scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
210
+ with torch.no_grad():
211
+ denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
212
+ targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
213
+
214
+ l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
215
+
216
+ grad_21, _ = guidance_grad(C21_in, self.diffusion_model.net, grad_scale=1, batch_size=self.fmap_cfg.diffusion.batch_sds,
217
+ scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
218
+ #denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
219
+ with torch.no_grad():
220
+ denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
221
+ targets_21 = torch_zoomout(evecs2, evecs1, evecs_1trans, C21_obj.squeeze(), self.cfg.sds_conf.zoomout)#, step=10)
222
+ l_proper_21 = ((C21_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_21.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
223
+ l_proper = l_proper_12 + l_proper_21
224
+
225
+ l_sds = ((torch.abs(C12_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_12.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
226
+ l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
227
+ loss = torch.as_tensor(0.).float().to(self.device)
228
+ if self.cfg.loss.get("ortho", 0) > 0:
229
+ loss += self.cfg.loss.get("ortho", 0) * l_ortho
230
+ if self.cfg.loss.get("bij", 0) > 0:
231
+ loss += self.cfg.loss.get("bij", 0) * l_bij
232
+ if self.cfg.loss.get("lap", 0) > 0:
233
+ loss += self.cfg.loss.get("lap", 0) * l_lap
234
+ if self.cfg.loss.get("cycle", 0) > 0:
235
+ loss += self.cfg.loss.get("cycle", 0) * l_cycle
236
+ if self.cfg.loss.get("mse_rec", 0) > 0:
237
+ loss += self.cfg.loss.get("mse_rec", 0) * l_mse
238
+ if self.cfg.loss.get("prism_rec", 0) > 0:
239
+ loss += self.cfg.loss.get("prism_rec", 0) * l_prism
240
+ if self.cfg.loss.get("sds", 0) > 0 and self.fmap_cfg.get("use_diff", False):
241
+ loss += self.cfg.loss.get("sds", 0) * l_sds
242
+ if self.cfg.loss.get("proper", 0) > 0 and self.fmap_cfg.get("use_diff", False):
243
+ loss += self.cfg.loss.get("proper", 0) * l_proper
244
+
245
+ loss.backward()
246
+ self.optim.step()
247
+ self.optim.zero_grad()
248
+
249
+ loss_save["cycle"].append(l_cycle.item())
250
+ loss_save["ortho"].append(l_ortho.item())
251
+ loss_save["bij"].append(l_bij.item())
252
+ loss_save["sds"].append(l_sds.item())
253
+ loss_save["proper"].append(l_proper.item())
254
+ loss_save["mse"].append(l_mse.item())
255
+ loss_save["prism"].append(l_prism.item())
256
+ indKNN_new_init, _ = extract_p2p_torch_fmap(C12_pred_init, evecs1, evecs2)
257
+ indKNN_new, _ = extract_p2p_torch_fmap(C12_new, evecs1, evecs2)
258
+ return C12_new, indKNN_new, indKNN_new_init, snk_rec, loss_save
259
+
260
+
261
+
262
+ def match(self, pair_batch, output_pair, geod_path, refine=True, eval=False):
263
+ shape_dict, _, target_dict, _, target_normals, mapinfo = pair_batch
264
+ shape_dict_device = convert_dict(shape_dict, self.device)
265
+ target_dict_device = convert_dict(target_dict, self.device)
266
+ print(shape_dict_device["vertices"].device)
267
+ os.makedirs(output_pair, exist_ok=True)
268
+
269
+
270
+ if self.cfg["optimize"]:
271
+ C12_new, p2p, p2p_init, snk_rec, loss_save = self.optimize(shape_dict_device, target_dict_device, target_normals.to(self.device))
272
+ np.save(os.path.join(output_pair, "p2p_init.npy"), p2p_init)
273
+ np.save(os.path.join(output_pair, "losses.npy"), loss_save)
274
+ else:
275
+ C12_new, p2p = self.zo_shot(shape_dict_device, target_dict_device)
276
+ snk_rec, loss_save = None, None
277
+ np.save(os.path.join(output_pair, "fmap.npy"), C12_new.detach().squeeze().cpu().numpy())
278
+ np.save(os.path.join(output_pair, "p2p.npy"), p2p)
279
+ if snk_rec is not None:
280
+ save_ply(os.path.join(output_pair, "rec.ply"), snk_rec.detach().squeeze().cpu().numpy(), target_dict["faces"])
281
+
282
+ if refine:
283
+ evecs1, evecs2 = shape_dict_device["evecs"], target_dict_device["evecs"]
284
+ evecs_2trans = evecs2.t() @ torch.diag(target_dict_device["mass"])
285
+ new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze(), 128)#, step=10)
286
+ p2p_refined_zo, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
287
+ np.save(os.path.join(output_pair, "p2p_zo.npy"), p2p)
288
+ if eval:
289
+ file_i, vts_1, vts_2 = mapinfo
290
+ mat_loaded = scipy.io.loadmat(os.path.join(geod_path, file_i + ".mat"))
291
+ A_geod, sqrt_area = mat_loaded['geod_dist'], np.sqrt(mat_loaded['areas_f'].sum())
292
+ _, dist = accuracy(p2p[vts_2], vts_1, A_geod,
293
+ sqrt_area=sqrt_area,
294
+ return_all=True)
295
+ if refine:
296
+ _, dist_zo = accuracy(p2p_refined_zo[vts_2], vts_1, A_geod,
297
+ sqrt_area=sqrt_area,
298
+ return_all=True)
299
+ np.savetxt(os.path.join(output_pair, "dists.txt"), (dist.mean(), dist_zo.mean()))
300
+ return p2p, p2p_refined_zo, loss_save, dist.mean(), dist_zo.mean()
301
+ return p2p, loss_save, dist.mean()
302
+ return p2p, loss_save
303
+
304
+
305
+
306
+
307
+ def _dataset_epoch(self, dataset, name_dataset, save_dir, data_dir):
308
+ os.makedirs(save_dir, exist_ok=True)
309
+ # dloader = DataLoader(dataset, collate_fn=collate_default, batch_size=1)
310
+ num_pairs = len(dataset)
311
+ id_pair = 0
312
+ all_accs = []
313
+ all_accs_zo = []
314
+ t1 = datetime.now()
315
+ save_txt = os.path.join(save_dir, "log.txt")
316
+ # Open a file for writing
317
+ log_file = open(save_txt, 'w')
318
+ # Replace sys.stdout with Tee that writes to both console and file
319
+ sys.stdout = Tee(sys.__stdout__, log_file)
320
+
321
+ for batch in dset:
322
+ shape_dict, _, target_dict, _, _, _ = batch
323
+ print("Pair: " + shape_dict['name'] + " " + target_dict['name'])
324
+ name_exp = os.path.join(save_dir, shape_dict['name'], target_dict['name'])
325
+ if self.cfg.get("refine", False):
326
+ _, _, _, dist, dist_zo = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True, refine=True)
327
+ else:
328
+ _, _, dist = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True, refine=False)
329
+ delta = datetime.now() - t1
330
+ fm_delta = str_delta(delta)
331
+ remains = ((delta/(id_pair+1))*num_pairs) - delta
332
+ fm_remains = str_delta(remains)
333
+ all_accs.append(dist)
334
+ accs_mean = np.mean(all_accs)
335
+ if self.cfg.get("refine", False):
336
+ all_accs_zo.append(dist_zo)
337
+ accs_zo = np.mean(all_accs_zo)
338
+ print(f"error: {dist}, zo: {dist_zo}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, mean zo: {accs_zo}, full time: {fm_delta}, remains: {fm_remains}")
339
+ else:
340
+ print(f"error: {dist}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, full time: {fm_delta}, remains: {fm_remains}")
341
+ id_pair += 1
342
+ if self.cfg.get("refine", False):
343
+ print(f"mean error : {np.mean(all_accs)}, mean error refined: {np.mean(all_accs_zo)}")
344
+ else:
345
+ print(f"mean error : {np.mean(all_accs)}")
346
+ sys.stdout = sys.__stdout__
347
+
348
+ def load_data(self, file, num_evecs=200, make_cache=False, factor=None):
349
+ name = os.path.basename(os.path.splitext(file)[0])
350
+ cache_file = "single_" + name + ".npz"
351
+ verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True)
352
+ cache_path = os.path.join(self.cfg.cache, cache_file)
353
+ print("Cache is: ", cache_path)
354
+ if not os.path.exists(cache_path) or make_cache:
355
+ print("Computing operators ...")
356
+ compute_operators(verts_shape, faces, vnormals, num_evecs, cache_path, force_save=make_cache)
357
+ data_dict = load_operators(cache_path)
358
+ data_dict['name'] = name
359
+ data_dict_torch = convert_dict(data_dict, self.device)
360
+ #batchify_dict(data_dict_torch)
361
+ return data_dict_torch, area_shape
362
+
363
+ def match_files(self, file_shape, file_target):
364
+ batch_shape, _ = self.load_data(file_shape)
365
+ batch_target, _ = self.load_data(file_target)
366
+ target_surf = Surface(filename=file_target)
367
+ target_normals = torch.from_numpy(target_surf.surfel/np.linalg.norm(target_surf.surfel, axis=-1, keepdims=True)).float().to(self.device)
368
+ batch = batch_shape, None, batch_target, target_normals, None, None
369
+ output_folder = os.path.join(self.cfg.output, batch_shape["name"] + "_" + batch_shape["target"])
370
+ p2p, _ = self.match(batch, output_folder, None)
371
+ return batch_shape, batch_target, p2p
372
+
373
+
374
+
375
+
376
+ if __name__ == '__main__':
377
+ parser = argparse.ArgumentParser(description="Launch the SDS demo over datasets")
378
+ parser.add_argument('--dataset', type=str, default="SCAPE", help='name of the dataset')
379
+ parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
380
+ parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
381
+ parser.add_argument('--output', type=str, default="results", help="where to store experience results")
382
+ args = parser.parse_args()
383
+
384
+ arg_cfg = OmegaConf.from_dotlist(
385
+ [f"{k}={v}" for k, v in vars(args).items() if v is not None]
386
+ )
387
+ yaml_cfg = OmegaConf.load(args.config)
388
+ cfg = OmegaConf.merge(yaml_cfg, arg_cfg)
389
+ dataset_name = args.dataset.lower()
390
+ if cfg.get("oriented", False):
391
+ dataset_name += "_ori"
392
+ shape_cls = getattr(importlib.import_module(f'shape_data.{args.dataset.lower()}'), 'ShapeDataset')
393
+ pair_cls = getattr(importlib.import_module(f'shape_data.{args.dataset.lower()}'), 'ShapePairDataset')
394
+ data_dir, name_data_geo, corr_dir = get_data_dirs(args.datadir, dataset_name, 'test')
395
+ name_data_geo = "_".join(name_data_geo.split("_")[:2])
396
+ dset_shape = shape_cls(data_dir, "cache/fmaps", "test", oriented=cfg.get("oriented", False))
397
+ print("Preprocessing shapes done.")
398
+ dset = pair_cls(corr_dir, 'test', dset_shape, rotate=cfg.get("rotate", False))
399
+ exp_time = time.strftime('%y-%m-%d_%H-%M-%S')
400
+ output_logs = os.path.join(args.output, name_data_geo, exp_time)
401
+ matcher = Matcher(cfg)
402
+ matcher._dataset_epoch(dset, name_data_geo, output_logs, args.datadir)