Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	
		daidedou
		
	commited on
		
		
					Commit 
							
							·
						
						458efe2
	
1
								Parent(s):
							
							cf4ac70
								
first_try
Browse files- app.py +272 -0
- notebook_helpers.py +88 -0
- requirements.txt +45 -0
- shape_models/__init__.py +3 -0
- shape_models/encoder.py +24 -0
- shape_models/fmap.py +186 -0
- shape_models/geometry.py +897 -0
- shape_models/layers.py +453 -0
- shape_models/utils.py +123 -0
- utils/__init__.py +3 -0
- utils/descriptors.py +250 -0
- utils/eval.py +25 -0
- utils/fmap.py +121 -0
- utils/geometry.py +951 -0
- utils/io.py +78 -0
- utils/layers.py +430 -0
- utils/mesh.py +214 -0
- utils/meshplot.py +67 -0
- utils/misc.py +122 -0
- utils/pickle_stuff.py +38 -0
- utils/surfaces.py +1377 -0
- utils/torch_fmap.py +77 -0
- utils/utils_func.py +123 -0
- utils/utils_legacy.py +130 -0
- zero_shot.py +402 -0
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,272 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Simple Gradio app for two-mesh initialization and run phases.
         | 
| 3 | 
            +
            - Upload two meshes (.ply, .obj, .off)
         | 
| 4 | 
            +
            - (Optional) upload a YAML config to override defaults
         | 
| 5 | 
            +
            - Adjust a few numeric settings (sane ranges). Defaults pulled from the provided YAML when present.
         | 
| 6 | 
            +
            - Click **Init** to generate "initialization maps" (here: position/normal-based vertex colors) for both meshes.
         | 
| 7 | 
            +
            - Click **Run** to simulate an iterative evolution with a progress bar, then output another pair of colored meshes.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            Replace the bodies of `make_initialization_maps` and `run_evolution` with your real pipeline as needed.
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            Tested with: gradio >= 4.0, trimesh, pyyaml, numpy.
         | 
| 12 | 
            +
            """
         | 
| 13 | 
            +
            from __future__ import annotations
         | 
| 14 | 
            +
            import os
         | 
| 15 | 
            +
            import io
         | 
| 16 | 
            +
            import time
         | 
| 17 | 
            +
            import json
         | 
| 18 | 
            +
            import tempfile
         | 
| 19 | 
            +
            from typing import Dict, Tuple, Optional
         | 
| 20 | 
            +
            from omegaconf import OmegaConf
         | 
| 21 | 
            +
            import gradio as gr
         | 
| 22 | 
            +
            import numpy as np
         | 
| 23 | 
            +
            import trimesh
         | 
| 24 | 
            +
            import zero_shot
         | 
| 25 | 
            +
            import yaml
         | 
| 26 | 
            +
            from utils.surfaces import Surface
         | 
| 27 | 
            +
            import notebook_helpers as helper
         | 
| 28 | 
            +
            from utils.meshplot import visu_pts
         | 
| 29 | 
            +
            from utils.torch_fmap import extract_p2p_torch_fmap, torch_zoomout
         | 
| 30 | 
            +
            import torch
         | 
| 31 | 
            +
            import argparse
         | 
| 32 | 
            +
            # -----------------------------
         | 
| 33 | 
            +
            # Utils
         | 
| 34 | 
            +
            # -----------------------------
         | 
| 35 | 
            +
            SUPPORTED_EXTS = {".ply", ".obj", ".off", ".stl", ".glb", ".gltf"}
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            def _safe_ext(path: str) -> str:
         | 
| 38 | 
            +
                for ext in SUPPORTED_EXTS:
         | 
| 39 | 
            +
                    if path.lower().endswith(ext):
         | 
| 40 | 
            +
                        return ext
         | 
| 41 | 
            +
                return os.path.splitext(path)[1].lower()
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def normalize_vertices(vertices: np.ndarray) -> np.ndarray:
         | 
| 45 | 
            +
                v = vertices.astype(np.float64)
         | 
| 46 | 
            +
                v = v - v.mean(axis=0, keepdims=True)
         | 
| 47 | 
            +
                scale = np.linalg.norm(v, axis=1).max()
         | 
| 48 | 
            +
                if scale == 0:
         | 
| 49 | 
            +
                    scale = 1.0
         | 
| 50 | 
            +
                v = v / scale
         | 
| 51 | 
            +
                return v.astype(np.float32)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            def ensure_vertex_colors(mesh: trimesh.Trimesh, colors: np.ndarray) -> trimesh.Trimesh:
         | 
| 56 | 
            +
                out = mesh.copy()
         | 
| 57 | 
            +
                if colors.shape[1] == 3:
         | 
| 58 | 
            +
                    rgba = np.concatenate([colors, 255*np.ones((colors.shape[0],1), dtype=np.uint8)], axis=1)
         | 
| 59 | 
            +
                else:
         | 
| 60 | 
            +
                    rgba = colors
         | 
| 61 | 
            +
                out.visual.vertex_colors = rgba
         | 
| 62 | 
            +
                return out
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def export_for_view(surf: Surface, colors: np.ndarray, basename: str, outdir: str) -> Tuple[str, str]:
         | 
| 66 | 
            +
                """Export to PLY (with vertex colors) and GLB for Model3D preview."""
         | 
| 67 | 
            +
                glb_path = os.path.join(outdir, f"{basename}.glb")
         | 
| 68 | 
            +
                mesh = trimesh.Trimesh(surf.vertices, surf.faces, process=False)
         | 
| 69 | 
            +
                colored_mesh = ensure_vertex_colors(mesh, colors)
         | 
| 70 | 
            +
                colored_mesh.export(glb_path)
         | 
| 71 | 
            +
                return glb_path
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            # -----------------------------
         | 
| 75 | 
            +
            # Algorithm placeholders (replace with your real pipeline)
         | 
| 76 | 
            +
            # -----------------------------
         | 
| 77 | 
            +
            DEFAULT_SETTINGS = {
         | 
| 78 | 
            +
                "deepfeat_conf.fmap.lambda_": 1,
         | 
| 79 | 
            +
                "sds_conf.zoomout": 40.0,
         | 
| 80 | 
            +
                "diffusion.time": 1.0,
         | 
| 81 | 
            +
                "opt.n_loop": 300,
         | 
| 82 | 
            +
                "loss.sds": 1.0,
         | 
| 83 | 
            +
                "loss.proper": 1.0,
         | 
| 84 | 
            +
            }
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            FLOAT_SLIDERS = {
         | 
| 87 | 
            +
                # name: (min, max, step)
         | 
| 88 | 
            +
                "deepfeat_conf.fmap.lambda_": (1e-3, 10.0, 1e-3),
         | 
| 89 | 
            +
                "sds_conf.zoomout": (1e-3, 10.0, 1e-3),
         | 
| 90 | 
            +
                "diffusion.time": (1e-3, 10.0, 1e-3),
         | 
| 91 | 
            +
                "loss.sds": (1e-3, 10.0, 1e-3),
         | 
| 92 | 
            +
                "loss.proper": (1e-3, 10.0, 1e-3),
         | 
| 93 | 
            +
            }
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            INT_SLIDERS = {
         | 
| 96 | 
            +
                "opt.n_loop": (1, 5000, 1),
         | 
| 97 | 
            +
            }
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
            def flatten_yaml_floats(d: Dict, prefix: str = "") -> Dict[str, float]:
         | 
| 101 | 
            +
                flat = {}
         | 
| 102 | 
            +
                for k, v in d.items():
         | 
| 103 | 
            +
                    key = f"{prefix}.{k}" if prefix else str(k)
         | 
| 104 | 
            +
                    if isinstance(v, dict):
         | 
| 105 | 
            +
                        flat.update(flatten_yaml_floats(v, key))
         | 
| 106 | 
            +
                    elif isinstance(v, (int, float)):
         | 
| 107 | 
            +
                        flat[key] = float(v)
         | 
| 108 | 
            +
                return flat
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            def read_yaml_defaults(yaml_path: Optional[str]) -> Dict[str, float]:
         | 
| 112 | 
            +
                if yaml_path and os.path.exists(yaml_path):
         | 
| 113 | 
            +
                    with open(yaml_path, "r") as f:
         | 
| 114 | 
            +
                        data = yaml.safe_load(f)
         | 
| 115 | 
            +
                    flat = flatten_yaml_floats(data)
         | 
| 116 | 
            +
                    # Only keep known keys we expose as controls
         | 
| 117 | 
            +
                    defaults = DEFAULT_SETTINGS.copy()
         | 
| 118 | 
            +
                    for k in list(DEFAULT_SETTINGS.keys()):
         | 
| 119 | 
            +
                        if k in flat:
         | 
| 120 | 
            +
                            defaults[k] = float(flat[k])
         | 
| 121 | 
            +
                    return defaults
         | 
| 122 | 
            +
                return DEFAULT_SETTINGS.copy()
         | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
            class Datadicts:
         | 
| 128 | 
            +
                def __init__(self, shape_path, target_path):
         | 
| 129 | 
            +
                    self.shape_path = shape_path
         | 
| 130 | 
            +
                    basename_1 = os.path.basename(shape_path)
         | 
| 131 | 
            +
                    self.shape_dict, _ = helper.load_data(shape_path, "tmp/" + os.path.splitext(basename_1)[0]+".npz", "source", make_cache=True)
         | 
| 132 | 
            +
                    self.shape_surf = Surface(filename=shape_path)
         | 
| 133 | 
            +
                    self.target_path = target_path
         | 
| 134 | 
            +
                    basename_2 = os.path.basename(target_path)
         | 
| 135 | 
            +
                    self.target_dict, _ = helper.load_data(target_path, "tmp/" + os.path.splitext(basename_2)[0]+".npz", "target", make_cache=True)
         | 
| 136 | 
            +
                    self.target_surf = Surface(filename=target_path)
         | 
| 137 | 
            +
                    self.cmap1 = visu_pts(self.shape_surf)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            # -----------------------------
         | 
| 140 | 
            +
            # Gradio UI
         | 
| 141 | 
            +
            # -----------------------------
         | 
| 142 | 
            +
            TMP_ROOT = tempfile.mkdtemp(prefix="meshapp_")
         | 
| 143 | 
            +
             | 
| 144 | 
            +
            def save_array_txt(arr):
         | 
| 145 | 
            +
                # Create a temporary file with .txt suffix
         | 
| 146 | 
            +
                with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w") as f:
         | 
| 147 | 
            +
                    np.savetxt(f, arr.astype(int), fmt="%d")  # save as text
         | 
| 148 | 
            +
                    return f.name
         | 
| 149 | 
            +
             | 
| 150 | 
            +
            def build_outputs(surf_a: Surface, surf_b: Surface, cmap_a: np.ndarray, p2p: np.ndarray, tag: str) -> Tuple[str, str, str, str]:
         | 
| 151 | 
            +
                outdir = os.path.join(TMP_ROOT, tag)
         | 
| 152 | 
            +
                os.makedirs(outdir, exist_ok=True)
         | 
| 153 | 
            +
                glb_a = export_for_view(surf_a, cmap_a, f"A_{tag}", outdir)
         | 
| 154 | 
            +
                cmap_b = cmap_a[p2p]
         | 
| 155 | 
            +
                glb_b = export_for_view(surf_b, cmap_b, f"B_{tag}", outdir)
         | 
| 156 | 
            +
                out_file = save_array_txt(p2p)
         | 
| 157 | 
            +
                return glb_a, glb_b, out_file
         | 
| 158 | 
            +
             | 
| 159 | 
            +
             | 
| 160 | 
            +
            def init_clicked(mesh1_path, mesh2_path,
         | 
| 161 | 
            +
                             lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val):
         | 
| 162 | 
            +
                cfg.deepfeat_conf.fmap.lambda_ = lambda_val
         | 
| 163 | 
            +
                cfg.sds_conf.zoomout = zoomout_val
         | 
| 164 | 
            +
                cfg.deepfeat_conf.fmap.diffusion.time = time_val
         | 
| 165 | 
            +
                cfg.opt.n_loop = nloop_val
         | 
| 166 | 
            +
                cfg.loss.sds = sds_val
         | 
| 167 | 
            +
                cfg.loss.proper = proper_val
         | 
| 168 | 
            +
                matcher.reconf(cfg)
         | 
| 169 | 
            +
                if not mesh1_path or not mesh2_path:
         | 
| 170 | 
            +
                    raise gr.Error("Please upload both meshes.")
         | 
| 171 | 
            +
                matcher._init()
         | 
| 172 | 
            +
                global datadicts
         | 
| 173 | 
            +
                datadicts = Datadicts(mesh1_path, mesh2_path)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                C12_pred_init, C21_pred_init, feat1, feat2, evecs_trans1, evecs_trans2 = matcher.fmap_model({"shape1": datadicts.shape_dict, "shape2": datadicts.target_dict}, diff_model=matcher.diffusion_model, scale=matcher.fmap_cfg.diffusion.time)
         | 
| 176 | 
            +
                C12_pred, C12_obj, mask_12 = C12_pred_init
         | 
| 177 | 
            +
                p2p_init, _ = extract_p2p_torch_fmap(C12_obj, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
         | 
| 178 | 
            +
                return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_init, tag="init")
         | 
| 179 | 
            +
             | 
| 180 | 
            +
             | 
| 181 | 
            +
            def run_clicked(mesh1_path, mesh2_path, yaml_path, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val, progress=gr.Progress(track_tqdm=True)):
         | 
| 182 | 
            +
                if not mesh1_path or not mesh2_path:
         | 
| 183 | 
            +
                    raise gr.Error("Please upload both meshes.")
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                cfg.deepfeat_conf.fmap.lambda_ = lambda_val
         | 
| 186 | 
            +
                cfg.sds_conf.zoomout = zoomout_val
         | 
| 187 | 
            +
                cfg.deepfeat_conf.fmap.diffusion.time = time_val
         | 
| 188 | 
            +
                cfg.opt.n_loop = nloop_val
         | 
| 189 | 
            +
                cfg.loss.sds = sds_val
         | 
| 190 | 
            +
                cfg.loss.proper = proper_val
         | 
| 191 | 
            +
                matcher.reconf(cfg)
         | 
| 192 | 
            +
                if not mesh1_path or not mesh2_path:
         | 
| 193 | 
            +
                    raise gr.Error("Please upload both meshes.")
         | 
| 194 | 
            +
                matcher._init()
         | 
| 195 | 
            +
                global datadicts
         | 
| 196 | 
            +
                if datadicts is None:
         | 
| 197 | 
            +
                    datadicts = Datadicts(mesh1_path, mesh2_path)
         | 
| 198 | 
            +
                elif datadicts is not None:
         | 
| 199 | 
            +
                    if not (datadicts.shape_path == mesh1_path and datadicts.target_path == mesh2_path):
         | 
| 200 | 
            +
                        datadicts = Datadicts(mesh1_path, mesh2_path)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                target_normals = torch.from_numpy(datadicts.target_surf.surfel/np.linalg.norm(datadicts.target_surf.surfel, axis=-1, keepdims=True)).float().to("cuda")
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                C12_new, p2p, p2p_init, _, loss_save = matcher.optimize(datadicts.shape_dict, datadicts.target_dict, target_normals)
         | 
| 205 | 
            +
                evecs1, evecs2 = datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"]
         | 
| 206 | 
            +
                evecs_2trans = evecs2.t() @ torch.diag(datadicts.target_dict["mass"])
         | 
| 207 | 
            +
                C12_end_zo = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze()[:15, :15], 150)# matcher.cfg.sds_conf.zoomout)
         | 
| 208 | 
            +
                p2p_zo, _ = extract_p2p_torch_fmap(C12_end_zo, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
         | 
| 209 | 
            +
                return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_zo, tag="run")
         | 
| 210 | 
            +
             | 
| 211 | 
            +
             | 
| 212 | 
            +
            with gr.Blocks(title="DiffuMatch demo") as demo:
         | 
| 213 | 
            +
                text_in = "Upload two meshes and try our ICCV zero-shot method DiffuMatch! \n"
         | 
| 214 | 
            +
                text_in += "*Init* will give you a rough correspondence, and you can click on *Run* to see if our method is able to match the two shapes! \n"
         | 
| 215 | 
            +
                text_in += "*Recommended*: The method requires that the meshes are aligned (rotation-wise) to work well. Also might not work with scans (but try it out!)."
         | 
| 216 | 
            +
                gr.Markdown(text_in)
         | 
| 217 | 
            +
                with gr.Row():
         | 
| 218 | 
            +
                    mesh1 = gr.File(label="Mesh A (.ply/.obj/.off)")
         | 
| 219 | 
            +
                    mesh2 = gr.File(label="Mesh B (.ply/.obj/.off)")
         | 
| 220 | 
            +
                    yaml_file = gr.File(label="Optional YAML config", file_types=[".yaml", ".yml"], visible=True)
         | 
| 221 | 
            +
                # except Exception:
         | 
| 222 | 
            +
                with gr.Accordion("Settings", open=True):
         | 
| 223 | 
            +
                    with gr.Row():
         | 
| 224 | 
            +
                        lambda_val = gr.Slider(minimum=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][0], maximum=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][1], step=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][2], value=1, label="deepfeat_conf.fmap.lambda_")
         | 
| 225 | 
            +
                        zoomout_val = gr.Slider(minimum=FLOAT_SLIDERS["sds_conf.zoomout"][0], maximum=FLOAT_SLIDERS["sds_conf.zoomout"][1], step=FLOAT_SLIDERS["sds_conf.zoomout"][2], value=40, label="sds_conf.zoomout")
         | 
| 226 | 
            +
                        time_val = gr.Slider(minimum=FLOAT_SLIDERS["diffusion.time"][0], maximum=FLOAT_SLIDERS["diffusion.time"][1], step=FLOAT_SLIDERS["diffusion.time"][2], value=1, label="diffusion.time")
         | 
| 227 | 
            +
                    with gr.Row():
         | 
| 228 | 
            +
                        nloop_val = gr.Slider(minimum=INT_SLIDERS["opt.n_loop"][0], maximum=INT_SLIDERS["opt.n_loop"][1], step=INT_SLIDERS["opt.n_loop"][2], value=300, label="opt.n_loop")
         | 
| 229 | 
            +
                        sds_val = gr.Slider(minimum=FLOAT_SLIDERS["loss.sds"][0], maximum=FLOAT_SLIDERS["loss.sds"][1], step=FLOAT_SLIDERS["loss.sds"][2], value=1, label="loss.sds")
         | 
| 230 | 
            +
                        proper_val = gr.Slider(minimum=FLOAT_SLIDERS["loss.proper"][0], maximum=FLOAT_SLIDERS["loss.proper"][1], step=FLOAT_SLIDERS["loss.proper"][2], value=1, label="loss.proper")
         | 
| 231 | 
            +
                
         | 
| 232 | 
            +
                with gr.Row():
         | 
| 233 | 
            +
                    init_btn = gr.Button("Init", variant="primary")
         | 
| 234 | 
            +
                    run_btn = gr.Button("Run", variant="secondary")
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                gr.Markdown("### Outputs\nEach stage exports both **GLB** (preview below) and **PLY** (download links) with per‑vertex colors.")
         | 
| 237 | 
            +
                with gr.Tab("Init"):
         | 
| 238 | 
            +
                    with gr.Row():
         | 
| 239 | 
            +
                        init_view_a = gr.Model3D(label="Shape")
         | 
| 240 | 
            +
                        init_view_b = gr.Model3D(label="Target correspondence (init)")
         | 
| 241 | 
            +
                    with gr.Row():
         | 
| 242 | 
            +
                        out_file_init = gr.File(label="Download correspondences TXT")
         | 
| 243 | 
            +
                with gr.Tab("Run"):
         | 
| 244 | 
            +
                    with gr.Row():
         | 
| 245 | 
            +
                        run_view_a = gr.Model3D(label="Shape")
         | 
| 246 | 
            +
                        run_view_b = gr.Model3D(label="Target correspondence (run)")
         | 
| 247 | 
            +
                    with gr.Row():
         | 
| 248 | 
            +
                        out_file = gr.File(label="Download correspondences TXT")
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                init_btn.click(
         | 
| 251 | 
            +
                    fn=init_clicked,
         | 
| 252 | 
            +
                    inputs=[mesh1, mesh2, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val],
         | 
| 253 | 
            +
                    outputs=[init_view_a, init_view_b, out_file_init],
         | 
| 254 | 
            +
                    api_name="init",
         | 
| 255 | 
            +
                )
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                run_btn.click(
         | 
| 258 | 
            +
                    fn=run_clicked,
         | 
| 259 | 
            +
                    inputs=[mesh1, mesh2, yaml_file, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val],
         | 
| 260 | 
            +
                    outputs=[run_view_a, run_view_b, out_file],
         | 
| 261 | 
            +
                    api_name="run",
         | 
| 262 | 
            +
                )
         | 
| 263 | 
            +
             | 
| 264 | 
            +
            if __name__ == "__main__":
         | 
| 265 | 
            +
                parser = argparse.ArgumentParser(description="Launch the gradio demo")
         | 
| 266 | 
            +
                parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
         | 
| 267 | 
            +
                parser.add_argument('--share', action="store_true")
         | 
| 268 | 
            +
                args = parser.parse_args()
         | 
| 269 | 
            +
                cfg = OmegaConf.load(args.config)
         | 
| 270 | 
            +
                matcher = zero_shot.Matcher(cfg)
         | 
| 271 | 
            +
                datadicts = None
         | 
| 272 | 
            +
                demo.launch(share=args.share)
         | 
    	
        notebook_helpers.py
    ADDED
    
    | @@ -0,0 +1,88 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from utils.mesh import load_mesh
         | 
| 2 | 
            +
            from utils.geometry import get_operators, load_operators
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            from utils.utils_func import convert_dict
         | 
| 5 | 
            +
            from utils.surfaces import Surface
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            device = "cuda:0"
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            def load_data(file, cache_path, name, num_evecs=128, make_cache=False, factor=None):
         | 
| 12 | 
            +
                if factor is None:
         | 
| 13 | 
            +
                    verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True)
         | 
| 14 | 
            +
                else:
         | 
| 15 | 
            +
                    verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True, scale=False)
         | 
| 16 | 
            +
                    verts_shape = verts_shape/factor
         | 
| 17 | 
            +
                    area_shape /= factor**2
         | 
| 18 | 
            +
                # print("Cache is: ", cache_path)
         | 
| 19 | 
            +
                if not os.path.exists(cache_path) or make_cache:
         | 
| 20 | 
            +
                    print("Computing operators ...")
         | 
| 21 | 
            +
                    get_operators(verts_shape, faces, num_evecs, cache_path, vnormals)
         | 
| 22 | 
            +
                data_dict = load_operators(cache_path)
         | 
| 23 | 
            +
                data_dict['name'] = name
         | 
| 24 | 
            +
                data_dict['normals'] = vnormals
         | 
| 25 | 
            +
                data_dict['vertices'] = verts_shape
         | 
| 26 | 
            +
                data_dict_torch = convert_dict(data_dict, device)
         | 
| 27 | 
            +
                #batchify_dict(data_dict_torch)
         | 
| 28 | 
            +
                return data_dict_torch, area_shape
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            def get_map_info(file_1, file_2, dict_1, dict_2, dataset):
         | 
| 31 | 
            +
                shape_dict, target_dict = dict_1, dict_2
         | 
| 32 | 
            +
                name_1, name_2 = shape_dict["name"], target_dict["name"]
         | 
| 33 | 
            +
                if dataset is None:
         | 
| 34 | 
            +
                    vts_1, vts_2 = np.arange(shape_dict['vertices'].shape[0]), np.arange(target_dict['vertices'].shape[0])
         | 
| 35 | 
            +
                    map_info = (file_1, file_2, vts_1, vts_2)
         | 
| 36 | 
            +
                file_vts_1 = file_1.replace("shapes", "correspondences")[:-4] + ".vts"
         | 
| 37 | 
            +
                vts_1 = np.loadtxt(file_vts_1).astype(np.int32) - 1
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                file_vts_2 = file_2.replace("shapes", "correspondences")[:-4] + ".vts"
         | 
| 40 | 
            +
                vts_2 = np.loadtxt(file_vts_2).astype(np.int32) - 1
         | 
| 41 | 
            +
                if "DT4D" in dataset:    
         | 
| 42 | 
            +
                    file_vts_1 = file_1.replace("shapes", "correspondences")[:-4] + ".vts"
         | 
| 43 | 
            +
                    vts_1 = np.loadtxt(file_vts_1).astype(np.int32) - 1
         | 
| 44 | 
            +
                
         | 
| 45 | 
            +
                    file_vts_2 = file_2.replace("shapes", "correspondences")[:-4] + ".vts"
         | 
| 46 | 
            +
                    vts_2 = np.loadtxt(file_vts_2).astype(np.int32) - 1
         | 
| 47 | 
            +
                    if name_1 == name_2:
         | 
| 48 | 
            +
                        map_info = (file_1, file_2, vts_1, vts_2)
         | 
| 49 | 
            +
                    elif ("crypto" in name_1) or ("crypto" in name_2):
         | 
| 50 | 
            +
                        name_cat_1, name_cat_2 = name_1.split(os.sep)[0], name_2.split(os.sep)[0]
         | 
| 51 | 
            +
                        data_path = os.path.dirname(os.path.dirname(os.path.dirname(file_1)))
         | 
| 52 | 
            +
                        map_file = os.path.join(data_path, "correspondences/cross_category_corres", f"{name_cat_1}_{name_cat_2}.vts")
         | 
| 53 | 
            +
                        if os.path.exists(map_file):
         | 
| 54 | 
            +
                            map_idx = np.loadtxt(map_file).astype(np.int32) - 1
         | 
| 55 | 
            +
                            map_info = (file_1, file_2, vts_1, vts_2[map_idx])
         | 
| 56 | 
            +
                    else:
         | 
| 57 | 
            +
                        print("NO GROUND TRUTH PAIRS")
         | 
| 58 | 
            +
                else:
         | 
| 59 | 
            +
                    map_info = (file_1,  file_2, vts_1, vts_2)
         | 
| 60 | 
            +
                return map_info
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def load_pair(cache, id_1, id_2, name_1, name_2, dataset):
         | 
| 64 | 
            +
                if "SCAPE" in dataset:
         | 
| 65 | 
            +
                    os.makedirs(cache, exist_ok=True) 
         | 
| 66 | 
            +
                    cache_file = os.path.join(cache, f"mesh{id_1:03d}_mesh_256k_0n.npz")
         | 
| 67 | 
            +
                    shape_file = f"data/{dataset}/shapes/mesh{id_1:03d}.ply"
         | 
| 68 | 
            +
                    shape_surf = Surface(filename=shape_file)
         | 
| 69 | 
            +
                    shape_dict, _  = load_data(shape_file, cache_file, str(id_1))
         | 
| 70 | 
            +
                    
         | 
| 71 | 
            +
                    cache_file = os.path.join(cache, f"mesh{id_2:03d}_mesh_256k_0n.npz")
         | 
| 72 | 
            +
                    target_file = f"data/{dataset}/shapes/mesh{id_2:03d}.ply"
         | 
| 73 | 
            +
                    target_surf = Surface(filename=target_file)
         | 
| 74 | 
            +
                    target_dict, _  = load_data(target_file, cache_file, str(id_2))
         | 
| 75 | 
            +
                    map_info = get_map_info(shape_file, target_file, shape_dict, target_dict, "SCAPE")
         | 
| 76 | 
            +
                elif "DT4D" in dataset:
         | 
| 77 | 
            +
                    cache_file = os.path.join(cache, f"{name_1}_mesh_256k_0n.npz")
         | 
| 78 | 
            +
                    shape_file = f"data/DT4D_r_ori/shapes/{name_1}.ply"
         | 
| 79 | 
            +
                    shape_surf = Surface(filename=shape_file)
         | 
| 80 | 
            +
                    shape_dict, _  = load_data(shape_file, cache_file, name_1)
         | 
| 81 | 
            +
                    cache_file = os.path.join(cache, f"{name_2}_mesh_256k_0n.npz")
         | 
| 82 | 
            +
                    # cache_file = f"../nonrigiddiff/cache/snk/{name_2}.npz"
         | 
| 83 | 
            +
                    cache_file = f"../nonrigiddiff/cache/attentive/{name_2}_mesh_256k_0n.npz"
         | 
| 84 | 
            +
                    target_file = f"data/DT4D_r_ori/shapes/{name_2}.ply"
         | 
| 85 | 
            +
                    target_surf = Surface(filename=target_file)
         | 
| 86 | 
            +
                    target_dict, _  = load_data(target_file, cache_file, name_2)
         | 
| 87 | 
            +
                    map_info = get_map_info(shape_file, target_file, shape_dict, target_dict, "DT4D")
         | 
| 88 | 
            +
                return shape_surf, target_surf, shape_dict, target_dict, map_info
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,45 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            comm==0.2.1
         | 
| 2 | 
            +
            configer
         | 
| 3 | 
            +
            configparser==6.0.0
         | 
| 4 | 
            +
            contourpy==1.1.1
         | 
| 5 | 
            +
            cycler==0.12.1
         | 
| 6 | 
            +
            fonttools==4.47.2
         | 
| 7 | 
            +
            freetype-py==2.4.0
         | 
| 8 | 
            +
            imageio==2.33.1
         | 
| 9 | 
            +
            ipygany==0.5.0
         | 
| 10 | 
            +
            ipywidgets==8.1.7
         | 
| 11 | 
            +
            nbformat==5.10.4
         | 
| 12 | 
            +
            jupyter
         | 
| 13 | 
            +
            kiwisolver==1.4.5
         | 
| 14 | 
            +
            lxml==5.1.0
         | 
| 15 | 
            +
            matplotlib==3.7.4
         | 
| 16 | 
            +
            networkx==3.1
         | 
| 17 | 
            +
            open3d>=0.18.0
         | 
| 18 | 
            +
            pyglet==2.0.10
         | 
| 19 | 
            +
            pyopengl==3.1.0
         | 
| 20 | 
            +
            pyparsing==3.1.1
         | 
| 21 | 
            +
            pyrender==0.1.45
         | 
| 22 | 
            +
            scipy>=1.10.1
         | 
| 23 | 
            +
            shapely==2.0.2
         | 
| 24 | 
            +
            smplx[all]
         | 
| 25 | 
            +
            tqdm==4.66.1
         | 
| 26 | 
            +
            trimesh==4.0.10
         | 
| 27 | 
            +
            vtk==9.3.0
         | 
| 28 | 
            +
            timm==1.0.11
         | 
| 29 | 
            +
            potpourri3d 
         | 
| 30 | 
            +
            pythreejs==2.4.2
         | 
| 31 | 
            +
            lapy
         | 
| 32 | 
            +
            roma
         | 
| 33 | 
            +
            webdataset
         | 
| 34 | 
            +
            lmdb
         | 
| 35 | 
            +
            robust_laplacian
         | 
| 36 | 
            +
            termcolor
         | 
| 37 | 
            +
            omegaconf
         | 
| 38 | 
            +
            pykeops==2.2.3
         | 
| 39 | 
            +
            scikit-image
         | 
| 40 | 
            +
            pyfmaps
         | 
| 41 | 
            +
            wandb
         | 
| 42 | 
            +
            gradio==4.44.1
         | 
| 43 | 
            +
            pydantic==2.10.6
         | 
| 44 | 
            +
            traitlets==5.7.1
         | 
| 45 | 
            +
            https://github.com/skoch9/meshplot/archive/0.4.0.tar.gz
         | 
    	
        shape_models/__init__.py
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .utils import *
         | 
| 2 | 
            +
            from .geometry import *
         | 
| 3 | 
            +
            from .layers import *
         | 
    	
        shape_models/encoder.py
    ADDED
    
    | @@ -0,0 +1,24 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch 
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            from .layers import DiffusionNet
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            class Encoder(nn.Module):
         | 
| 6 | 
            +
                def __init__(self, with_grad=True, key_verts="vertices"):
         | 
| 7 | 
            +
                    super(Encoder, self).__init__()
         | 
| 8 | 
            +
                    self.diff_net = DiffusionNet(
         | 
| 9 | 
            +
                         C_in=3,
         | 
| 10 | 
            +
                         C_out=512,
         | 
| 11 | 
            +
                         C_width=128,
         | 
| 12 | 
            +
                         N_block=4,
         | 
| 13 | 
            +
                         dropout=True,
         | 
| 14 | 
            +
                         with_gradient_features=with_grad,
         | 
| 15 | 
            +
                         with_gradient_rotations=with_grad,
         | 
| 16 | 
            +
                    )
         | 
| 17 | 
            +
                    self.key_verts = key_verts
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
                def forward(self, shape_dict):
         | 
| 21 | 
            +
                    feats = self.diff_net(shape_dict[self.key_verts], shape_dict["mass"], shape_dict["L"], evals=shape_dict["evals"], 
         | 
| 22 | 
            +
                                           evecs=shape_dict["evecs"], gradX=shape_dict["gradX"], gradY=shape_dict["gradY"], faces=shape_dict["faces"])
         | 
| 23 | 
            +
                    x_out = torch.max(feats, dim=0).values
         | 
| 24 | 
            +
                    return x_out
         | 
    	
        shape_models/fmap.py
    ADDED
    
    | @@ -0,0 +1,186 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from copy import deepcopy
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import torch.nn as nn
         | 
| 4 | 
            +
            import torch.nn.functional as F
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # feature extractor
         | 
| 7 | 
            +
            from .layers import DiffusionNet
         | 
| 8 | 
            +
            from omegaconf import OmegaConf
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def get_mask(evals1, evals2, gamma=0.5, device="cpu"):
         | 
| 12 | 
            +
                scaling_factor = max(torch.max(evals1), torch.max(evals2))
         | 
| 13 | 
            +
                evals1, evals2 = evals1.to(device) / scaling_factor, evals2.to(device) / scaling_factor
         | 
| 14 | 
            +
                evals_gamma1, evals_gamma2 = (evals1 ** gamma)[None, :], (evals2 ** gamma)[:, None]
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                M_re = evals_gamma2 / (evals_gamma2.square() + 1) - evals_gamma1 / (evals_gamma1.square() + 1)
         | 
| 17 | 
            +
                M_im = 1 / (evals_gamma2.square() + 1) - 1 / (evals_gamma1.square() + 1)
         | 
| 18 | 
            +
                return M_re.square() + M_im.square()
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            def get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y):
         | 
| 21 | 
            +
                # compute linear operator matrix representation C1 and C2
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                F_hat = torch.bmm(evecs_trans_x, feat_x)
         | 
| 24 | 
            +
                G_hat = torch.bmm(evecs_trans_y, feat_y)
         | 
| 25 | 
            +
                A, B = F_hat, G_hat
         | 
| 26 | 
            +
                
         | 
| 27 | 
            +
                A_t, B_t = A.transpose(1, 2), B.transpose(1, 2)
         | 
| 28 | 
            +
                A_A_t, B_B_t = torch.bmm(A, A_t), torch.bmm(B, B_t)
         | 
| 29 | 
            +
                B_A_t, A_B_t = torch.bmm(B, A_t), torch.bmm(A, B_t)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                C12 = torch.bmm(B_A_t, torch.inverse(A_A_t))
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                C21 = torch.bmm(A_B_t, torch.inverse(B_B_t))
         | 
| 34 | 
            +
                return [C12, C21]
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            def get_mask_noise(C, diff_model, scale=1, N_est=200, device="cuda", normalize=True, absolute=False):
         | 
| 37 | 
            +
                with torch.no_grad():
         | 
| 38 | 
            +
                    sig = torch.ones([1, 1, 1, 1], device=device) * scale
         | 
| 39 | 
            +
                    
         | 
| 40 | 
            +
                    noise = torch.randn((N_est, 1, 30, 30), device=device)
         | 
| 41 | 
            +
                    if absolute:
         | 
| 42 | 
            +
                        #noisy_new = torch.abs(C[None, :, :] + noise *scale)
         | 
| 43 | 
            +
                        noisy_new = torch.abs(torch.abs(C[None, :, :]) + noise * scale)
         | 
| 44 | 
            +
                    else:
         | 
| 45 | 
            +
                        noisy_new = C[None, :, :] + noise *scale
         | 
| 46 | 
            +
                    denoised = diff_model.net(noisy_new, sig)
         | 
| 47 | 
            +
                    mask_squared = torch.mean(torch.abs(noisy_new - denoised)/(2*scale), dim=0)/torch.mean(torch.abs(noisy_new), dim=0)
         | 
| 48 | 
            +
                    mask_median = torch.median(mask_squared).item()
         | 
| 49 | 
            +
                    if normalize:
         | 
| 50 | 
            +
                        mask_median = torch.median(mask_squared).item()
         | 
| 51 | 
            +
                        M_denoised = torch.sqrt(torch.clamp(mask_squared-mask_median/2, 0, mask_median*2))
         | 
| 52 | 
            +
                    else:
         | 
| 53 | 
            +
                        M_denoised = torch.sqrt(mask_squared)
         | 
| 54 | 
            +
                return M_denoised-M_denoised.min()#, mask_median*2)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            class RegularizedFMNet(nn.Module):
         | 
| 58 | 
            +
                """Compute the functional map matrix representation."""
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def __init__(self, lambda_=1e-3, resolvant_gamma=0.5, use_resolvent=False):
         | 
| 61 | 
            +
                    super().__init__()
         | 
| 62 | 
            +
                    self.lambda_ = lambda_
         | 
| 63 | 
            +
                    self.use_resolvent = use_resolvent
         | 
| 64 | 
            +
                    self.resolvant_gamma = resolvant_gamma
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                def forward(self, feat_x, feat_y, evals_x, evals_y, evecs_trans_x, evecs_trans_y, diff_conf=None):
         | 
| 67 | 
            +
                    # compute linear operator matrix representation C1 and C2
         | 
| 68 | 
            +
                    evecs_trans_x, evecs_trans_y = evecs_trans_x.unsqueeze(0), evecs_trans_y.unsqueeze(0)
         | 
| 69 | 
            +
                    evals_x, evals_y = evals_x.unsqueeze(0), evals_y.unsqueeze(0)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    F_hat = torch.bmm(evecs_trans_x, feat_x)
         | 
| 72 | 
            +
                    G_hat = torch.bmm(evecs_trans_y, feat_y)
         | 
| 73 | 
            +
                    A, B = F_hat, G_hat
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
                    if diff_conf is not None:
         | 
| 78 | 
            +
                        diff_model, scale, normalize, absolute, N_est = diff_conf
         | 
| 79 | 
            +
                        C12_raw, C21_raw = get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y)
         | 
| 80 | 
            +
                        D12 = get_mask_noise(C12_raw, diff_model, scale=scale, N_est=N_est, normalize=normalize, absolute=absolute)
         | 
| 81 | 
            +
                        D21 = get_mask_noise(C21_raw, diff_model, scale=scale, N_est=N_est, normalize=normalize, absolute=absolute)
         | 
| 82 | 
            +
                    elif self.use_resolvent:
         | 
| 83 | 
            +
                        D12 = get_mask(evals_x.flatten(), evals_y.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
         | 
| 84 | 
            +
                        D21 = get_mask(evals_y.flatten(), evals_x.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
         | 
| 85 | 
            +
                    else:
         | 
| 86 | 
            +
                        D12 = (torch.unsqueeze(evals_y, 2) - torch.unsqueeze(evals_x, 1))**2
         | 
| 87 | 
            +
                        D21 = (torch.unsqueeze(evals_x, 2) - torch.unsqueeze(evals_y, 1))**2
         | 
| 88 | 
            +
                    
         | 
| 89 | 
            +
                    A_t, B_t = A.transpose(1, 2), B.transpose(1, 2)
         | 
| 90 | 
            +
                    A_A_t, B_B_t = torch.bmm(A, A_t), torch.bmm(B, B_t)
         | 
| 91 | 
            +
                    B_A_t, A_B_t = torch.bmm(B, A_t), torch.bmm(A, B_t)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    C12_i = []
         | 
| 94 | 
            +
                    for i in range(evals_x.size(1)):
         | 
| 95 | 
            +
                        D12_i = torch.cat([torch.diag(D12[bs, i, :].flatten()).unsqueeze(0) for bs in range(evals_x.size(0))], dim=0)
         | 
| 96 | 
            +
                        C12 = torch.bmm(torch.inverse(A_A_t + self.lambda_ * D12_i), B_A_t[:, i, :].unsqueeze(1).transpose(1, 2))
         | 
| 97 | 
            +
                        C12_i.append(C12.transpose(1, 2))
         | 
| 98 | 
            +
                    C12 = torch.cat(C12_i, dim=1)
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    C21_i = []
         | 
| 101 | 
            +
                    for i in range(evals_y.size(1)):
         | 
| 102 | 
            +
                        D21_i = torch.cat([torch.diag(D21[bs, i, :].flatten()).unsqueeze(0) for bs in range(evals_y.size(0))], dim=0)
         | 
| 103 | 
            +
                        C21 = torch.bmm(torch.inverse(B_B_t + self.lambda_ * D21_i), A_B_t[:, i, :].unsqueeze(1).transpose(1, 2))
         | 
| 104 | 
            +
                        C21_i.append(C21.transpose(1, 2))
         | 
| 105 | 
            +
                    C21 = torch.cat(C21_i, dim=1)
         | 
| 106 | 
            +
                    if diff_conf is not None:
         | 
| 107 | 
            +
                        return [C12_raw, C12, D12], [C21_raw, C21, D21]
         | 
| 108 | 
            +
                    else:
         | 
| 109 | 
            +
                        return [C12, C21]
         | 
| 110 | 
            +
             | 
| 111 | 
            +
             | 
| 112 | 
            +
            class DFMNet(nn.Module):
         | 
| 113 | 
            +
                """
         | 
| 114 | 
            +
                Compilation of the global model :
         | 
| 115 | 
            +
                - diffusion net as feature extractor
         | 
| 116 | 
            +
                - fmap + q-fmap
         | 
| 117 | 
            +
                - unsupervised loss
         | 
| 118 | 
            +
                """
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                def __init__(self, cfg):
         | 
| 121 | 
            +
                    super().__init__()
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    # feature extractor #
         | 
| 124 | 
            +
                    with_grad=True
         | 
| 125 | 
            +
                    self.feat = cfg["feat"]
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    self.feature_extractor = DiffusionNet(
         | 
| 128 | 
            +
                         C_in=cfg["C_in"],
         | 
| 129 | 
            +
                         C_out=cfg["n_feat"],
         | 
| 130 | 
            +
                         C_width=128,
         | 
| 131 | 
            +
                         N_block=4,
         | 
| 132 | 
            +
                         dropout=True,
         | 
| 133 | 
            +
                         with_gradient_features=with_grad,
         | 
| 134 | 
            +
                         with_gradient_rotations=with_grad,
         | 
| 135 | 
            +
                    )
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    # regularized fmap
         | 
| 138 | 
            +
                    self.fmreg_net = RegularizedFMNet(lambda_=cfg["lambda_"],
         | 
| 139 | 
            +
                                                      resolvant_gamma=cfg.get("resolvent_gamma", 0.5), use_resolvent=cfg.get("use_resolvent", False))
         | 
| 140 | 
            +
                    # parameters
         | 
| 141 | 
            +
                    self.n_fmap = cfg["n_fmap"]
         | 
| 142 | 
            +
                    if cfg.get("diffusion", None) is not None:
         | 
| 143 | 
            +
                        self.normalize = cfg["diffusion"]["normalize"]
         | 
| 144 | 
            +
                        self.abs = cfg["diffusion"]["abs"]
         | 
| 145 | 
            +
                        self.N_est = cfg.diffusion.get("batch_mask", 100)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                def forward(self, batch, diff_model=None, scale=1):
         | 
| 148 | 
            +
                    if self.feat == "xyz":
         | 
| 149 | 
            +
                        feat_1, feat_2 = batch["shape1"]["vertices"], batch["shape2"]["vertices"]
         | 
| 150 | 
            +
                    elif self.feat == "wks":
         | 
| 151 | 
            +
                        feat_1, feat_2 = batch["shape1"]["wks"], batch["shape2"]["wks"]
         | 
| 152 | 
            +
                    elif self.feat == "hks":
         | 
| 153 | 
            +
                        feat_1, feat_2 = batch["shape1"]["hks"], batch["shape2"]["hks"]
         | 
| 154 | 
            +
                    else:
         | 
| 155 | 
            +
                        raise Exception("Unknow Feature")
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
                    verts1, faces1, mass1, L1, evals1, evecs1, gradX1, gradY1 = (feat_1, batch["shape1"]["faces"],
         | 
| 159 | 
            +
                                                                                 batch["shape1"]["mass"], batch["shape1"]["L"],
         | 
| 160 | 
            +
                                                                                 batch["shape1"]["evals"], batch["shape1"]["evecs"],
         | 
| 161 | 
            +
                                                                                 batch["shape1"]["gradX"], batch["shape1"]["gradY"])
         | 
| 162 | 
            +
                    verts2, faces2, mass2, L2, evals2, evecs2, gradX2, gradY2 = (feat_2, batch["shape2"]["faces"],
         | 
| 163 | 
            +
                                                                                 batch["shape2"]["mass"], batch["shape2"]["L"],
         | 
| 164 | 
            +
                                                                                 batch["shape2"]["evals"], batch["shape2"]["evecs"],
         | 
| 165 | 
            +
                                                                                 batch["shape2"]["gradX"], batch["shape2"]["gradY"])
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    # set features to vertices
         | 
| 168 | 
            +
                    features1, features2 = verts1, verts2
         | 
| 169 | 
            +
                    # print(features1.shape, features2.shape)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    feat1 = self.feature_extractor(features1, mass1, L=L1, evals=evals1, evecs=evecs1,
         | 
| 172 | 
            +
                                                   gradX=gradX1, gradY=gradY1, faces=faces1).unsqueeze(0)
         | 
| 173 | 
            +
                    feat2 = self.feature_extractor(features2, mass2, L=L2, evals=evals2, evecs=evecs2,
         | 
| 174 | 
            +
                                                   gradX=gradX2, gradY=gradY2, faces=faces2).unsqueeze(0)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    evecs_trans1, evecs_trans2 = evecs1.t()[:self.n_fmap] @ torch.diag(mass1.squeeze()), evecs2.squeeze().t()[:self.n_fmap].squeeze() @ torch.diag(mass2.squeeze())
         | 
| 177 | 
            +
                    evals1, evals2 = evals1[:self.n_fmap], evals2[:self.n_fmap]
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    #
         | 
| 180 | 
            +
                    if diff_model is not None:
         | 
| 181 | 
            +
                        C12_pred, C21_pred = self.fmreg_net(feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2,
         | 
| 182 | 
            +
                                                        diff_conf=[diff_model, scale, self.normalize, self.abs, self.N_est])
         | 
| 183 | 
            +
                    else:
         | 
| 184 | 
            +
                        C12_pred, C21_pred = self.fmreg_net(feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    return C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2
         | 
    	
        shape_models/geometry.py
    ADDED
    
    | @@ -0,0 +1,897 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import scipy
         | 
| 2 | 
            +
            import scipy.sparse.linalg as sla
         | 
| 3 | 
            +
            # ^^^ we NEED to import scipy before torch, or it crashes :(
         | 
| 4 | 
            +
            # (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import os.path
         | 
| 7 | 
            +
            import sys
         | 
| 8 | 
            +
            import random
         | 
| 9 | 
            +
            from multiprocessing import Pool
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            import scipy.spatial
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            import sklearn.neighbors
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            import robust_laplacian
         | 
| 17 | 
            +
            import potpourri3d as pp3d
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def norm(x, highdim=False):
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                Computes norm of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
                return torch.norm(x, dim=len(x.shape) - 1)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def norm2(x, highdim=False):
         | 
| 28 | 
            +
                """
         | 
| 29 | 
            +
                Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
         | 
| 30 | 
            +
                """
         | 
| 31 | 
            +
                return dot(x, x)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            def normalize(x, divide_eps=1e-6, highdim=False):
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
                Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
         | 
| 37 | 
            +
                """
         | 
| 38 | 
            +
                if (len(x.shape) == 1):
         | 
| 39 | 
            +
                    raise ValueError("called normalize() on single vector of dim " + str(x.shape) + " are you sure?")
         | 
| 40 | 
            +
                if (not highdim and x.shape[-1] > 4):
         | 
| 41 | 
            +
                    raise ValueError("called normalize() with large last dimension " + str(x.shape) + " are you sure?")
         | 
| 42 | 
            +
                return x / (norm(x, highdim=highdim) + divide_eps).unsqueeze(-1)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            def face_coords(verts, faces):
         | 
| 46 | 
            +
                coords = verts[faces]
         | 
| 47 | 
            +
                return coords
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            def cross(vec_A, vec_B):
         | 
| 51 | 
            +
                return torch.cross(vec_A, vec_B, dim=-1)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def dot(vec_A, vec_B):
         | 
| 55 | 
            +
                return torch.sum(vec_A * vec_B, dim=-1)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
             | 
| 58 | 
            +
            # Given (..., 3) vectors and normals, projects out any components of vecs
         | 
| 59 | 
            +
            # which lies in the direction of normals. Normals are assumed to be unit.
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def project_to_tangent(vecs, unit_normals):
         | 
| 63 | 
            +
                dots = dot(vecs, unit_normals)
         | 
| 64 | 
            +
                return vecs - unit_normals * dots.unsqueeze(-1)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            def face_area(verts, faces):
         | 
| 68 | 
            +
                coords = face_coords(verts, faces)
         | 
| 69 | 
            +
                vec_A = coords[:, 1, :] - coords[:, 0, :]
         | 
| 70 | 
            +
                vec_B = coords[:, 2, :] - coords[:, 0, :]
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                raw_normal = cross(vec_A, vec_B)
         | 
| 73 | 
            +
                return 0.5 * norm(raw_normal)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def face_normals(verts, faces, normalized=True):
         | 
| 77 | 
            +
                coords = face_coords(verts, faces)
         | 
| 78 | 
            +
                vec_A = coords[:, 1, :] - coords[:, 0, :]
         | 
| 79 | 
            +
                vec_B = coords[:, 2, :] - coords[:, 0, :]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                raw_normal = cross(vec_A, vec_B)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                if normalized:
         | 
| 84 | 
            +
                    return normalize(raw_normal)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                return raw_normal
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            def neighborhood_normal(points):
         | 
| 90 | 
            +
                # points: (N, K, 3) array of neighborhood psoitions
         | 
| 91 | 
            +
                # points should be centered at origin
         | 
| 92 | 
            +
                # out: (N,3) array of normals
         | 
| 93 | 
            +
                # numpy in, numpy out
         | 
| 94 | 
            +
                (u, s, vh) = np.linalg.svd(points, full_matrices=False)
         | 
| 95 | 
            +
                normal = vh[:, 2, :]
         | 
| 96 | 
            +
                return normal / np.linalg.norm(normal, axis=-1, keepdims=True)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            def mesh_vertex_normals(verts, faces):
         | 
| 100 | 
            +
                # numpy in / out
         | 
| 101 | 
            +
                face_n = toNP(face_normals(torch.tensor(verts), torch.tensor(faces)))  # ugly torch <---> numpy
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                vertex_normals = np.zeros(verts.shape)
         | 
| 104 | 
            +
                for i in range(3):
         | 
| 105 | 
            +
                    np.add.at(vertex_normals, faces[:, i], face_n)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                vertex_normals = vertex_normals / np.linalg.norm(vertex_normals, axis=-1, keepdims=True)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                return vertex_normals
         | 
| 110 | 
            +
             | 
| 111 | 
            +
             | 
| 112 | 
            +
            def vertex_normals(verts, faces, n_neighbors_cloud=30):
         | 
| 113 | 
            +
                verts_np = toNP(verts)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                if faces.numel() == 0:  # point cloud
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    _, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
         | 
| 118 | 
            +
                    neigh_points = verts_np[neigh_inds, :]
         | 
| 119 | 
            +
                    neigh_points = neigh_points - verts_np[:, np.newaxis, :]
         | 
| 120 | 
            +
                    normals = neighborhood_normal(neigh_points)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                else:  # mesh
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    normals = mesh_vertex_normals(verts_np, toNP(faces))
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    # if any are NaN, wiggle slightly and recompute
         | 
| 127 | 
            +
                    bad_normals_mask = np.isnan(normals).any(axis=1, keepdims=True)
         | 
| 128 | 
            +
                    if bad_normals_mask.any():
         | 
| 129 | 
            +
                        bbox = np.amax(verts_np, axis=0) - np.amin(verts_np, axis=0)
         | 
| 130 | 
            +
                        scale = np.linalg.norm(bbox) * 1e-4
         | 
| 131 | 
            +
                        wiggle = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5) * scale
         | 
| 132 | 
            +
                        wiggle_verts = verts_np + bad_normals_mask * wiggle
         | 
| 133 | 
            +
                        normals = mesh_vertex_normals(wiggle_verts, toNP(faces))
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # if still NaN assign random normals (probably means unreferenced verts in mesh)
         | 
| 136 | 
            +
                    bad_normals_mask = np.isnan(normals).any(axis=1)
         | 
| 137 | 
            +
                    if bad_normals_mask.any():
         | 
| 138 | 
            +
                        normals[bad_normals_mask, :] = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5)[bad_normals_mask, :]
         | 
| 139 | 
            +
                        normals = normals / np.linalg.norm(normals, axis=-1)[:, np.newaxis]
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                normals = torch.from_numpy(normals).to(device=verts.device, dtype=verts.dtype)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                if torch.any(torch.isnan(normals)):
         | 
| 144 | 
            +
                    raise ValueError("NaN normals :(")
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                return normals
         | 
| 147 | 
            +
             | 
| 148 | 
            +
             | 
| 149 | 
            +
            def build_tangent_frames(verts, faces, normals=None):
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                V = verts.shape[0]
         | 
| 152 | 
            +
                dtype = verts.dtype
         | 
| 153 | 
            +
                device = verts.device
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                if normals == None:
         | 
| 156 | 
            +
                    vert_normals = vertex_normals(verts, faces)  # (V,3)
         | 
| 157 | 
            +
                else:
         | 
| 158 | 
            +
                    vert_normals = normals
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                # = find an orthogonal basis
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                basis_cand1 = torch.tensor([1, 0, 0]).to(device=device, dtype=dtype).expand(V, -1)
         | 
| 163 | 
            +
                basis_cand2 = torch.tensor([0, 1, 0]).to(device=device, dtype=dtype).expand(V, -1)
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                basisX = torch.where((torch.abs(dot(vert_normals, basis_cand1)) < 0.9).unsqueeze(-1), basis_cand1, basis_cand2)
         | 
| 166 | 
            +
                basisX = project_to_tangent(basisX, vert_normals)
         | 
| 167 | 
            +
                basisX = normalize(basisX)
         | 
| 168 | 
            +
                basisY = cross(vert_normals, basisX)
         | 
| 169 | 
            +
                frames = torch.stack((basisX, basisY, vert_normals), dim=-2)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                if torch.any(torch.isnan(frames)):
         | 
| 172 | 
            +
                    raise ValueError("NaN coordinate frame! Must be very degenerate")
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                return frames
         | 
| 175 | 
            +
             | 
| 176 | 
            +
             | 
| 177 | 
            +
            def build_grad_point_cloud(verts, frames, n_neighbors_cloud=30):
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                verts_np = toNP(verts)
         | 
| 180 | 
            +
                frames_np = toNP(frames)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                _, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
         | 
| 183 | 
            +
                neigh_points = verts_np[neigh_inds, :]
         | 
| 184 | 
            +
                neigh_vecs = neigh_points - verts_np[:, np.newaxis, :]
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                # TODO this could easily be way faster. For instance we could avoid the weird edges format and the corresponding pure-python loop via some numpy broadcasting of the same logic. The way it works right now is just to share code with the mesh version. But its low priority since its preprocessing code.
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                edge_inds_from = np.repeat(np.arange(verts.shape[0]), n_neighbors_cloud)
         | 
| 189 | 
            +
                edges = np.stack((edge_inds_from, neigh_inds.flatten()))
         | 
| 190 | 
            +
                edge_tangent_vecs = edge_tangent_vectors(verts, frames, edges)
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                return build_grad(verts_np, torch.tensor(edges), edge_tangent_vecs)
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| 195 | 
            +
            def edge_tangent_vectors(verts, frames, edges):
         | 
| 196 | 
            +
                edge_vecs = verts[edges[1, :], :] - verts[edges[0, :], :]
         | 
| 197 | 
            +
                basisX = frames[edges[0, :], 0, :]
         | 
| 198 | 
            +
                basisY = frames[edges[0, :], 1, :]
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                compX = dot(edge_vecs, basisX)
         | 
| 201 | 
            +
                compY = dot(edge_vecs, basisY)
         | 
| 202 | 
            +
                edge_tangent = torch.stack((compX, compY), dim=-1)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                return edge_tangent
         | 
| 205 | 
            +
             | 
| 206 | 
            +
             | 
| 207 | 
            +
            def build_grad(verts, edges, edge_tangent_vectors):
         | 
| 208 | 
            +
                """
         | 
| 209 | 
            +
                Build a (V, V) complex sparse matrix grad operator. Given real inputs at vertices, produces a complex (vector value) at vertices giving the gradient. All values pointwise.
         | 
| 210 | 
            +
                - edges: (2, E)
         | 
| 211 | 
            +
                """
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                edges_np = toNP(edges)
         | 
| 214 | 
            +
                edge_tangent_vectors_np = toNP(edge_tangent_vectors)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                # TODO find a way to do this in pure numpy?
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                # Build outgoing neighbor lists
         | 
| 219 | 
            +
                N = verts.shape[0]
         | 
| 220 | 
            +
                vert_edge_outgoing = [[] for i in range(N)]
         | 
| 221 | 
            +
                for iE in range(edges_np.shape[1]):
         | 
| 222 | 
            +
                    tail_ind = edges_np[0, iE]
         | 
| 223 | 
            +
                    tip_ind = edges_np[1, iE]
         | 
| 224 | 
            +
                    if tip_ind != tail_ind:
         | 
| 225 | 
            +
                        vert_edge_outgoing[tail_ind].append(iE)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                # Build local inversion matrix for each vertex
         | 
| 228 | 
            +
                row_inds = []
         | 
| 229 | 
            +
                col_inds = []
         | 
| 230 | 
            +
                data_vals = []
         | 
| 231 | 
            +
                eps_reg = 1e-5
         | 
| 232 | 
            +
                for iV in range(N):
         | 
| 233 | 
            +
                    n_neigh = len(vert_edge_outgoing[iV])
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                    lhs_mat = np.zeros((n_neigh, 2))
         | 
| 236 | 
            +
                    rhs_mat = np.zeros((n_neigh, n_neigh + 1))
         | 
| 237 | 
            +
                    ind_lookup = [iV]
         | 
| 238 | 
            +
                    for i_neigh in range(n_neigh):
         | 
| 239 | 
            +
                        iE = vert_edge_outgoing[iV][i_neigh]
         | 
| 240 | 
            +
                        jV = edges_np[1, iE]
         | 
| 241 | 
            +
                        ind_lookup.append(jV)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                        edge_vec = edge_tangent_vectors[iE][:]
         | 
| 244 | 
            +
                        w_e = 1.
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                        lhs_mat[i_neigh][:] = w_e * edge_vec
         | 
| 247 | 
            +
                        rhs_mat[i_neigh][0] = w_e * (-1)
         | 
| 248 | 
            +
                        rhs_mat[i_neigh][i_neigh + 1] = w_e * 1
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                    lhs_T = lhs_mat.T
         | 
| 251 | 
            +
                    lhs_inv = np.linalg.inv(lhs_T @ lhs_mat + eps_reg * np.identity(2)) @ lhs_T
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                    sol_mat = lhs_inv @ rhs_mat
         | 
| 254 | 
            +
                    sol_coefs = (sol_mat[0, :] + 1j * sol_mat[1, :]).T
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    for i_neigh in range(n_neigh + 1):
         | 
| 257 | 
            +
                        i_glob = ind_lookup[i_neigh]
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                        row_inds.append(iV)
         | 
| 260 | 
            +
                        col_inds.append(i_glob)
         | 
| 261 | 
            +
                        data_vals.append(sol_coefs[i_neigh])
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                # build the sparse matrix
         | 
| 264 | 
            +
                row_inds = np.array(row_inds)
         | 
| 265 | 
            +
                col_inds = np.array(col_inds)
         | 
| 266 | 
            +
                data_vals = np.array(data_vals)
         | 
| 267 | 
            +
                mat = scipy.sparse.coo_matrix((data_vals, (row_inds, col_inds)), shape=(N, N)).tocsc()
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                return mat
         | 
| 270 | 
            +
             | 
| 271 | 
            +
             | 
| 272 | 
            +
            def compute_operators(verts, faces, k_eig, normals=None):
         | 
| 273 | 
            +
                """
         | 
| 274 | 
            +
                Builds spectral operators for a mesh/point cloud. Constructs mass matrix, eigenvalues/vectors for Laplacian, and gradient matrix.
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                See get_operators() for a similar routine that wraps this one with a layer of caching.
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                Torch in / torch out.
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                Arguments:
         | 
| 281 | 
            +
                  - vertices: (V,3) vertex positions
         | 
| 282 | 
            +
                  - faces: (F,3) list of triangular faces. If empty, assumed to be a point cloud.
         | 
| 283 | 
            +
                  - k_eig: number of eigenvectors to use
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                Returns:
         | 
| 286 | 
            +
                  - frames: (V,3,3) X/Y/Z coordinate frame at each vertex. Z coordinate is normal (e.g. [:,2,:] for normals)
         | 
| 287 | 
            +
                  - massvec: (V) real diagonal of lumped mass matrix
         | 
| 288 | 
            +
                  - L: (VxV) real sparse matrix of (weak) Laplacian
         | 
| 289 | 
            +
                  - evals: (k) list of eigenvalues of the Laplacian
         | 
| 290 | 
            +
                  - evecs: (V,k) list of eigenvectors of the Laplacian 
         | 
| 291 | 
            +
                  - gradX: (VxV) sparse matrix which gives X-component of gradient in the local basis at the vertex
         | 
| 292 | 
            +
                  - gradY: same as gradX but for Y-component of gradient
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                PyTorch doesn't seem to like complex sparse matrices, so we store the "real" and "imaginary" (aka X and Y) gradient matrices separately, rather than as one complex sparse matrix.
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                Note: for a generalized eigenvalue problem, the mass matrix matters! The eigenvectors are only othrthonormal with respect to the mass matrix, like v^H M v, so the mass (given as the diagonal vector massvec) needs to be used in projections, etc.
         | 
| 297 | 
            +
                """
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                device = verts.device
         | 
| 300 | 
            +
                dtype = verts.dtype
         | 
| 301 | 
            +
                V = verts.shape[0]
         | 
| 302 | 
            +
                is_cloud = faces.numel() == 0
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                eps = 1e-8
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                verts_np = toNP(verts).astype(np.float64)
         | 
| 307 | 
            +
                faces_np = toNP(faces)
         | 
| 308 | 
            +
                frames = build_tangent_frames(verts, faces, normals=normals)
         | 
| 309 | 
            +
                frames_np = toNP(frames)
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                # Build the scalar Laplacian
         | 
| 312 | 
            +
                if is_cloud:
         | 
| 313 | 
            +
                    L, M = robust_laplacian.point_cloud_laplacian(verts_np)
         | 
| 314 | 
            +
                    massvec_np = M.diagonal()
         | 
| 315 | 
            +
                else:
         | 
| 316 | 
            +
                    # L, M = robust_laplacian.mesh_laplacian(verts_np, faces_np)
         | 
| 317 | 
            +
                    # massvec_np = M.diagonal()
         | 
| 318 | 
            +
                    L = pp3d.cotan_laplacian(verts_np, faces_np, denom_eps=1e-10)
         | 
| 319 | 
            +
                    massvec_np = pp3d.vertex_areas(verts_np, faces_np)
         | 
| 320 | 
            +
                    massvec_np += eps * np.mean(massvec_np)
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                if (np.isnan(L.data).any()):
         | 
| 323 | 
            +
                    raise RuntimeError("NaN Laplace matrix")
         | 
| 324 | 
            +
                if (np.isnan(massvec_np).any()):
         | 
| 325 | 
            +
                    raise RuntimeError("NaN mass matrix")
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                # Read off neighbors & rotations from the Laplacian
         | 
| 328 | 
            +
                L_coo = L.tocoo()
         | 
| 329 | 
            +
                inds_row = L_coo.row
         | 
| 330 | 
            +
                inds_col = L_coo.col
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                # === Compute the eigenbasis
         | 
| 333 | 
            +
                if k_eig > 0:
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                    # Prepare matrices
         | 
| 336 | 
            +
                    L_eigsh = (L + scipy.sparse.identity(L.shape[0]) * eps).tocsc()
         | 
| 337 | 
            +
                    massvec_eigsh = massvec_np
         | 
| 338 | 
            +
                    Mmat = scipy.sparse.diags(massvec_eigsh)
         | 
| 339 | 
            +
                    eigs_sigma = eps
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                    failcount = 0
         | 
| 342 | 
            +
                    while True:
         | 
| 343 | 
            +
                        try:
         | 
| 344 | 
            +
                            # We would be happy here to lower tol or maxiter since we don't need these to be super precise, but for some reason those parameters seem to have no effect
         | 
| 345 | 
            +
                            evals_np, evecs_np = sla.eigsh(L_eigsh, k=k_eig, M=Mmat, sigma=eigs_sigma)
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                            # Clip off any eigenvalues that end up slightly negative due to numerical weirdness
         | 
| 348 | 
            +
                            evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                            break
         | 
| 351 | 
            +
                        except Exception as e:
         | 
| 352 | 
            +
                            print(e)
         | 
| 353 | 
            +
                            if (failcount > 3):
         | 
| 354 | 
            +
                                raise ValueError("failed to compute eigendecomp")
         | 
| 355 | 
            +
                            failcount += 1
         | 
| 356 | 
            +
                            print("--- decomp failed; adding eps ===> count: " + str(failcount))
         | 
| 357 | 
            +
                            L_eigsh = L_eigsh + scipy.sparse.identity(L.shape[0]) * (eps * 10**failcount)
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                else:  #k_eig == 0
         | 
| 360 | 
            +
                    evals_np = np.zeros((0))
         | 
| 361 | 
            +
                    evecs_np = np.zeros((verts.shape[0], 0))
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                # == Build gradient matrices
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                # For meshes, we use the same edges as were used to build the Laplacian. For point clouds, use a whole local neighborhood
         | 
| 366 | 
            +
                if is_cloud:
         | 
| 367 | 
            +
                    grad_mat_np = build_grad_point_cloud(verts, frames)
         | 
| 368 | 
            +
                else:
         | 
| 369 | 
            +
                    edges = torch.tensor(np.stack((inds_row, inds_col), axis=0), device=device, dtype=faces.dtype)
         | 
| 370 | 
            +
                    edge_vecs = edge_tangent_vectors(verts, frames, edges)
         | 
| 371 | 
            +
                    grad_mat_np = build_grad(verts.cpu(), edges.cpu(), edge_vecs.cpu())
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                # Split complex gradient in to two real sparse mats (torch doesn't like complex sparse matrices)
         | 
| 374 | 
            +
                gradX_np = np.real(grad_mat_np)
         | 
| 375 | 
            +
                gradY_np = np.imag(grad_mat_np)
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                # === Convert back to torch
         | 
| 378 | 
            +
                massvec = torch.from_numpy(massvec_np).to(device=device, dtype=dtype)
         | 
| 379 | 
            +
                L = utils.sparse_np_to_torch(L).to(device=device, dtype=dtype)
         | 
| 380 | 
            +
                evals = torch.from_numpy(evals_np).to(device=device, dtype=dtype)
         | 
| 381 | 
            +
                evecs = torch.from_numpy(evecs_np).to(device=device, dtype=dtype)
         | 
| 382 | 
            +
                gradX = utils.sparse_np_to_torch(gradX_np).to(device=device, dtype=dtype)
         | 
| 383 | 
            +
                gradY = utils.sparse_np_to_torch(gradY_np).to(device=device, dtype=dtype)
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                return frames, massvec, L, evals, evecs, gradX, gradY
         | 
| 386 | 
            +
             | 
| 387 | 
            +
             | 
| 388 | 
            +
            def get_all_operators(verts_list, faces_list, k_eig, op_cache_dir=None, normals=None):
         | 
| 389 | 
            +
                N = len(verts_list)
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                frames = [None] * N
         | 
| 392 | 
            +
                massvec = [None] * N
         | 
| 393 | 
            +
                L = [None] * N
         | 
| 394 | 
            +
                evals = [None] * N
         | 
| 395 | 
            +
                evecs = [None] * N
         | 
| 396 | 
            +
                gradX = [None] * N
         | 
| 397 | 
            +
                gradY = [None] * N
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                inds = [i for i in range(N)]
         | 
| 400 | 
            +
                # process in random order
         | 
| 401 | 
            +
                # random.shuffle(inds)
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                for num, i in enumerate(inds):
         | 
| 404 | 
            +
                    print("get_all_operators() processing {} / {} {:.3f}%".format(num, N, num / N * 100))
         | 
| 405 | 
            +
                    if normals is None:
         | 
| 406 | 
            +
                        outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir)
         | 
| 407 | 
            +
                    else:
         | 
| 408 | 
            +
                        outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir, normals=normals[i])
         | 
| 409 | 
            +
                    frames[i] = outputs[0]
         | 
| 410 | 
            +
                    massvec[i] = outputs[1]
         | 
| 411 | 
            +
                    L[i] = outputs[2]
         | 
| 412 | 
            +
                    evals[i] = outputs[3]
         | 
| 413 | 
            +
                    evecs[i] = outputs[4]
         | 
| 414 | 
            +
                    gradX[i] = outputs[5]
         | 
| 415 | 
            +
                    gradY[i] = outputs[6]
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                return frames, massvec, L, evals, evecs, gradX, gradY
         | 
| 418 | 
            +
             | 
| 419 | 
            +
             | 
| 420 | 
            +
            def get_operators(verts, faces, k_eig=128, op_cache_dir=None, normals=None, overwrite_cache=False):
         | 
| 421 | 
            +
                """
         | 
| 422 | 
            +
                See documentation for compute_operators(). This essentailly just wraps a call to compute_operators, using a cache if possible.
         | 
| 423 | 
            +
                All arrays are always computed using double precision for stability, then truncated to single precision floats to store on disk, and finally returned as a tensor with dtype/device matching the `verts` input.
         | 
| 424 | 
            +
                """
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                device = verts.device
         | 
| 427 | 
            +
                dtype = verts.dtype
         | 
| 428 | 
            +
                verts_np = toNP(verts)
         | 
| 429 | 
            +
                faces_np = toNP(faces)
         | 
| 430 | 
            +
                is_cloud = faces.numel() == 0
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                if (np.isnan(verts_np).any()):
         | 
| 433 | 
            +
                    raise RuntimeError("tried to construct operators from NaN verts")
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                # Check the cache directory
         | 
| 436 | 
            +
                # Note 1: Collisions here are exceptionally unlikely, so we could probably just use the hash...
         | 
| 437 | 
            +
                #         but for good measure we check values nonetheless.
         | 
| 438 | 
            +
                # Note 2: There is a small possibility for race conditions to lead to bucket gaps or duplicate
         | 
| 439 | 
            +
                #         entries in this cache. The good news is that that is totally fine, and at most slightly
         | 
| 440 | 
            +
                #         slows performance with rare extra cache misses.
         | 
| 441 | 
            +
                found = False
         | 
| 442 | 
            +
                if op_cache_dir is not None:
         | 
| 443 | 
            +
                    utils.ensure_dir_exists(op_cache_dir)
         | 
| 444 | 
            +
                    hash_key_str = str(utils.hash_arrays((verts_np, faces_np)))
         | 
| 445 | 
            +
                    # print("Building operators for input with hash: " + hash_key_str)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                    # Search through buckets with matching hashes.  When the loop exits, this
         | 
| 448 | 
            +
                    # is the bucket index of the file we should write to.
         | 
| 449 | 
            +
                    i_cache_search = 0
         | 
| 450 | 
            +
                    while True:
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                        # Form the name of the file to check
         | 
| 453 | 
            +
                        search_path = os.path.join(op_cache_dir, hash_key_str + "_" + str(i_cache_search) + ".npz")
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                        try:
         | 
| 456 | 
            +
                            # print('loading path: ' + str(search_path))
         | 
| 457 | 
            +
                            npzfile = np.load(search_path, allow_pickle=True)
         | 
| 458 | 
            +
                            cache_verts = npzfile["verts"]
         | 
| 459 | 
            +
                            cache_faces = npzfile["faces"]
         | 
| 460 | 
            +
                            cache_k_eig = npzfile["k_eig"].item()
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                            # If the cache doesn't match, keep looking
         | 
| 463 | 
            +
                            if (not np.array_equal(verts, cache_verts)) or (not np.array_equal(faces, cache_faces)):
         | 
| 464 | 
            +
                                i_cache_search += 1
         | 
| 465 | 
            +
                                print("hash collision! searching next.")
         | 
| 466 | 
            +
                                continue
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                            # print("  cache hit!")
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                            # If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
         | 
| 471 | 
            +
                            # entry below more eigenvalues
         | 
| 472 | 
            +
                            if overwrite_cache:
         | 
| 473 | 
            +
                                print("  overwriting cache by request")
         | 
| 474 | 
            +
                                os.remove(search_path)
         | 
| 475 | 
            +
                                break
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                            if cache_k_eig < k_eig:
         | 
| 478 | 
            +
                                print("  overwriting cache --- not enough eigenvalues")
         | 
| 479 | 
            +
                                os.remove(search_path)
         | 
| 480 | 
            +
                                break
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                            if "L_data" not in npzfile:
         | 
| 483 | 
            +
                                print("  overwriting cache --- entries are absent")
         | 
| 484 | 
            +
                                os.remove(search_path)
         | 
| 485 | 
            +
                                break
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                            def read_sp_mat(prefix):
         | 
| 488 | 
            +
                                data = npzfile[prefix + "_data"]
         | 
| 489 | 
            +
                                indices = npzfile[prefix + "_indices"]
         | 
| 490 | 
            +
                                indptr = npzfile[prefix + "_indptr"]
         | 
| 491 | 
            +
                                shape = npzfile[prefix + "_shape"]
         | 
| 492 | 
            +
                                mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
         | 
| 493 | 
            +
                                return mat
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                            # This entry matches! Return it.
         | 
| 496 | 
            +
                            frames = npzfile["frames"]
         | 
| 497 | 
            +
                            mass = npzfile["mass"]
         | 
| 498 | 
            +
                            L = read_sp_mat("L")
         | 
| 499 | 
            +
                            evals = npzfile["evals"][:k_eig]
         | 
| 500 | 
            +
                            evecs = npzfile["evecs"][:, :k_eig]
         | 
| 501 | 
            +
                            gradX = read_sp_mat("gradX")
         | 
| 502 | 
            +
                            gradY = read_sp_mat("gradY")
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                            frames = torch.from_numpy(frames).to(device=device, dtype=dtype)
         | 
| 505 | 
            +
                            mass = torch.from_numpy(mass).to(device=device, dtype=dtype)
         | 
| 506 | 
            +
                            L = utils.sparse_np_to_torch(L).to(device=device, dtype=dtype)
         | 
| 507 | 
            +
                            evals = torch.from_numpy(evals).to(device=device, dtype=dtype)
         | 
| 508 | 
            +
                            evecs = torch.from_numpy(evecs).to(device=device, dtype=dtype)
         | 
| 509 | 
            +
                            gradX = utils.sparse_np_to_torch(gradX).to(device=device, dtype=dtype)
         | 
| 510 | 
            +
                            gradY = utils.sparse_np_to_torch(gradY).to(device=device, dtype=dtype)
         | 
| 511 | 
            +
             | 
| 512 | 
            +
                            found = True
         | 
| 513 | 
            +
             | 
| 514 | 
            +
                            break
         | 
| 515 | 
            +
             | 
| 516 | 
            +
                        except FileNotFoundError:
         | 
| 517 | 
            +
                            print("  cache miss -- constructing operators")
         | 
| 518 | 
            +
                            break
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                        except Exception as E:
         | 
| 521 | 
            +
                            print("unexpected error loading file: " + str(E))
         | 
| 522 | 
            +
                            print("-- constructing operators")
         | 
| 523 | 
            +
                            break
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                if not found:
         | 
| 526 | 
            +
             | 
| 527 | 
            +
                    # No matching entry found; recompute.
         | 
| 528 | 
            +
                    frames, mass, L, evals, evecs, gradX, gradY = compute_operators(verts, faces, k_eig, normals=normals)
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                    dtype_np = np.float32
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                    # Store it in the cache
         | 
| 533 | 
            +
                    if op_cache_dir is not None:
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                        L_np = utils.sparse_torch_to_np(L).astype(dtype_np)
         | 
| 536 | 
            +
                        gradX_np = utils.sparse_torch_to_np(gradX).astype(dtype_np)
         | 
| 537 | 
            +
                        gradY_np = utils.sparse_torch_to_np(gradY).astype(dtype_np)
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                        np.savez(
         | 
| 540 | 
            +
                            search_path,
         | 
| 541 | 
            +
                            verts=verts_np.astype(dtype_np),
         | 
| 542 | 
            +
                            frames=toNP(frames).astype(dtype_np),
         | 
| 543 | 
            +
                            faces=faces_np,
         | 
| 544 | 
            +
                            k_eig=k_eig,
         | 
| 545 | 
            +
                            mass=toNP(mass).astype(dtype_np),
         | 
| 546 | 
            +
                            L_data=L_np.data.astype(dtype_np),
         | 
| 547 | 
            +
                            L_indices=L_np.indices,
         | 
| 548 | 
            +
                            L_indptr=L_np.indptr,
         | 
| 549 | 
            +
                            L_shape=L_np.shape,
         | 
| 550 | 
            +
                            evals=toNP(evals).astype(dtype_np),
         | 
| 551 | 
            +
                            evecs=toNP(evecs).astype(dtype_np),
         | 
| 552 | 
            +
                            gradX_data=gradX_np.data.astype(dtype_np),
         | 
| 553 | 
            +
                            gradX_indices=gradX_np.indices,
         | 
| 554 | 
            +
                            gradX_indptr=gradX_np.indptr,
         | 
| 555 | 
            +
                            gradX_shape=gradX_np.shape,
         | 
| 556 | 
            +
                            gradY_data=gradY_np.data.astype(dtype_np),
         | 
| 557 | 
            +
                            gradY_indices=gradY_np.indices,
         | 
| 558 | 
            +
                            gradY_indptr=gradY_np.indptr,
         | 
| 559 | 
            +
                            gradY_shape=gradY_np.shape,
         | 
| 560 | 
            +
                        )
         | 
| 561 | 
            +
             | 
| 562 | 
            +
                return frames, mass, L, evals, evecs, gradX, gradY
         | 
| 563 | 
            +
             | 
| 564 | 
            +
             | 
| 565 | 
            +
            def to_basis(values, basis, massvec):
         | 
| 566 | 
            +
                """
         | 
| 567 | 
            +
                Transform data in to an orthonormal basis (where orthonormal is wrt to massvec)
         | 
| 568 | 
            +
                Inputs:
         | 
| 569 | 
            +
                  - values: (B,V,D)
         | 
| 570 | 
            +
                  - basis: (B,V,K)
         | 
| 571 | 
            +
                  - massvec: (B,V)
         | 
| 572 | 
            +
                Outputs:
         | 
| 573 | 
            +
                  - (B,K,D) transformed values
         | 
| 574 | 
            +
                """
         | 
| 575 | 
            +
                basisT = basis.transpose(-2, -1)
         | 
| 576 | 
            +
                return torch.matmul(basisT, values * massvec.unsqueeze(-1))
         | 
| 577 | 
            +
             | 
| 578 | 
            +
             | 
| 579 | 
            +
            def from_basis(values, basis):
         | 
| 580 | 
            +
                """
         | 
| 581 | 
            +
                Transform data out of an orthonormal basis
         | 
| 582 | 
            +
                Inputs:
         | 
| 583 | 
            +
                  - values: (K,D)
         | 
| 584 | 
            +
                  - basis: (V,K)
         | 
| 585 | 
            +
                Outputs:
         | 
| 586 | 
            +
                  - (V,D) reconstructed values
         | 
| 587 | 
            +
                """
         | 
| 588 | 
            +
                if values.is_complex() or basis.is_complex():
         | 
| 589 | 
            +
                    return utils.cmatmul(utils.ensure_complex(basis), utils.ensure_complex(values))
         | 
| 590 | 
            +
                else:
         | 
| 591 | 
            +
                    return torch.matmul(basis, values)
         | 
| 592 | 
            +
             | 
| 593 | 
            +
             | 
| 594 | 
            +
            def compute_hks(evals, evecs, scales):
         | 
| 595 | 
            +
                """
         | 
| 596 | 
            +
                Inputs:
         | 
| 597 | 
            +
                  - evals: (K) eigenvalues
         | 
| 598 | 
            +
                  - evecs: (V,K) values
         | 
| 599 | 
            +
                  - scales: (S) times
         | 
| 600 | 
            +
                Outputs:
         | 
| 601 | 
            +
                  - (V,S) hks values
         | 
| 602 | 
            +
                """
         | 
| 603 | 
            +
             | 
| 604 | 
            +
                # expand batch
         | 
| 605 | 
            +
                if len(evals.shape) == 1:
         | 
| 606 | 
            +
                    expand_batch = True
         | 
| 607 | 
            +
                    evals = evals.unsqueeze(0)
         | 
| 608 | 
            +
                    evecs = evecs.unsqueeze(0)
         | 
| 609 | 
            +
                    scales = scales.unsqueeze(0)
         | 
| 610 | 
            +
                else:
         | 
| 611 | 
            +
                    expand_batch = False
         | 
| 612 | 
            +
             | 
| 613 | 
            +
                # TODO could be a matmul
         | 
| 614 | 
            +
                power_coefs = torch.exp(-evals.unsqueeze(1) * scales.unsqueeze(-1)).unsqueeze(1)  # (B,1,S,K)
         | 
| 615 | 
            +
                terms = power_coefs * (evecs * evecs).unsqueeze(2)  # (B,V,S,K)
         | 
| 616 | 
            +
             | 
| 617 | 
            +
                out = torch.sum(terms, dim=-1)  # (B,V,S)
         | 
| 618 | 
            +
             | 
| 619 | 
            +
                if expand_batch:
         | 
| 620 | 
            +
                    return out.squeeze(0)
         | 
| 621 | 
            +
                else:
         | 
| 622 | 
            +
                    return out
         | 
| 623 | 
            +
             | 
| 624 | 
            +
             | 
| 625 | 
            +
            def compute_hks_autoscale(evals, evecs, count):
         | 
| 626 | 
            +
                # these scales roughly approximate those suggested in the hks paper
         | 
| 627 | 
            +
                scales = torch.logspace(-2, 0., steps=count, device=evals.device, dtype=evals.dtype)
         | 
| 628 | 
            +
                return compute_hks(evals, evecs, scales)
         | 
| 629 | 
            +
             | 
| 630 | 
            +
             | 
| 631 | 
            +
            def normalize_positions(pos, faces=None, method='mean', scale_method='max_rad'):
         | 
| 632 | 
            +
                # center and unit-scale positions
         | 
| 633 | 
            +
             | 
| 634 | 
            +
                if method == 'mean':
         | 
| 635 | 
            +
                    # center using the average point position
         | 
| 636 | 
            +
                    pos = (pos - torch.mean(pos, dim=-2, keepdim=True))
         | 
| 637 | 
            +
                elif method == 'bbox':
         | 
| 638 | 
            +
                    # center via the middle of the axis-aligned bounding box
         | 
| 639 | 
            +
                    bbox_min = torch.min(pos, dim=-2).values
         | 
| 640 | 
            +
                    bbox_max = torch.max(pos, dim=-2).values
         | 
| 641 | 
            +
                    center = (bbox_max + bbox_min) / 2.
         | 
| 642 | 
            +
                    pos -= center.unsqueeze(-2)
         | 
| 643 | 
            +
                else:
         | 
| 644 | 
            +
                    raise ValueError("unrecognized method")
         | 
| 645 | 
            +
             | 
| 646 | 
            +
                if scale_method == 'max_rad':
         | 
| 647 | 
            +
                    scale = torch.max(norm(pos), dim=-1, keepdim=True).values.unsqueeze(-1)
         | 
| 648 | 
            +
                    pos = pos / scale
         | 
| 649 | 
            +
                elif scale_method == 'area':
         | 
| 650 | 
            +
                    if faces is None:
         | 
| 651 | 
            +
                        raise ValueError("must pass faces for area normalization")
         | 
| 652 | 
            +
                    coords = pos[faces]
         | 
| 653 | 
            +
                    vec_A = coords[:, 1, :] - coords[:, 0, :]
         | 
| 654 | 
            +
                    vec_B = coords[:, 2, :] - coords[:, 0, :]
         | 
| 655 | 
            +
                    face_areas = torch.norm(torch.cross(vec_A, vec_B, dim=-1), dim=1) * 0.5
         | 
| 656 | 
            +
                    total_area = torch.sum(face_areas)
         | 
| 657 | 
            +
                    scale = (1. / torch.sqrt(total_area))
         | 
| 658 | 
            +
                    pos = pos * scale
         | 
| 659 | 
            +
                else:
         | 
| 660 | 
            +
                    raise ValueError("unrecognized scale method")
         | 
| 661 | 
            +
                return pos
         | 
| 662 | 
            +
             | 
| 663 | 
            +
             | 
| 664 | 
            +
            # Finds the k nearest neighbors of source on target.
         | 
| 665 | 
            +
            # Return is two tensors (distances, indices). Returned points will be sorted in increasing order of distance.
         | 
| 666 | 
            +
            def find_knn(points_source, points_target, k, largest=False, omit_diagonal=False, method='brute'):
         | 
| 667 | 
            +
             | 
| 668 | 
            +
                if omit_diagonal and points_source.shape[0] != points_target.shape[0]:
         | 
| 669 | 
            +
                    raise ValueError("omit_diagonal can only be used when source and target are same shape")
         | 
| 670 | 
            +
             | 
| 671 | 
            +
                if method != 'cpu_kd' and points_source.shape[0] * points_target.shape[0] > 1e8:
         | 
| 672 | 
            +
                    method = 'cpu_kd'
         | 
| 673 | 
            +
                    print("switching to cpu_kd knn")
         | 
| 674 | 
            +
             | 
| 675 | 
            +
                if method == 'brute':
         | 
| 676 | 
            +
             | 
| 677 | 
            +
                    # Expand so both are NxMx3 tensor
         | 
| 678 | 
            +
                    points_source_expand = points_source.unsqueeze(1)
         | 
| 679 | 
            +
                    points_source_expand = points_source_expand.expand(-1, points_target.shape[0], -1)
         | 
| 680 | 
            +
                    points_target_expand = points_target.unsqueeze(0)
         | 
| 681 | 
            +
                    points_target_expand = points_target_expand.expand(points_source.shape[0], -1, -1)
         | 
| 682 | 
            +
             | 
| 683 | 
            +
                    diff_mat = points_source_expand - points_target_expand
         | 
| 684 | 
            +
                    dist_mat = norm(diff_mat)
         | 
| 685 | 
            +
             | 
| 686 | 
            +
                    if omit_diagonal:
         | 
| 687 | 
            +
                        torch.diagonal(dist_mat)[:] = float('inf')
         | 
| 688 | 
            +
             | 
| 689 | 
            +
                    result = torch.topk(dist_mat, k=k, largest=largest, sorted=True)
         | 
| 690 | 
            +
                    return result
         | 
| 691 | 
            +
             | 
| 692 | 
            +
                elif method == 'cpu_kd':
         | 
| 693 | 
            +
             | 
| 694 | 
            +
                    if largest:
         | 
| 695 | 
            +
                        raise ValueError("can't do largest with cpu_kd")
         | 
| 696 | 
            +
             | 
| 697 | 
            +
                    points_source_np = toNP(points_source)
         | 
| 698 | 
            +
                    points_target_np = toNP(points_target)
         | 
| 699 | 
            +
             | 
| 700 | 
            +
                    # Build the tree
         | 
| 701 | 
            +
                    kd_tree = sklearn.neighbors.KDTree(points_target_np)
         | 
| 702 | 
            +
             | 
| 703 | 
            +
                    k_search = k + 1 if omit_diagonal else k
         | 
| 704 | 
            +
                    _, neighbors = kd_tree.query(points_source_np, k=k_search)
         | 
| 705 | 
            +
             | 
| 706 | 
            +
                    if omit_diagonal:
         | 
| 707 | 
            +
                        # Mask out self element
         | 
| 708 | 
            +
                        mask = neighbors != np.arange(neighbors.shape[0])[:, np.newaxis]
         | 
| 709 | 
            +
             | 
| 710 | 
            +
                        # make sure we mask out exactly one element in each row, in rare case of many duplicate points
         | 
| 711 | 
            +
                        mask[np.sum(mask, axis=1) == mask.shape[1], -1] = False
         | 
| 712 | 
            +
             | 
| 713 | 
            +
                        neighbors = neighbors[mask].reshape((neighbors.shape[0], neighbors.shape[1] - 1))
         | 
| 714 | 
            +
             | 
| 715 | 
            +
                    inds = torch.tensor(neighbors, device=points_source.device, dtype=torch.int64)
         | 
| 716 | 
            +
                    dists = norm(points_source.unsqueeze(1).expand(-1, k, -1) - points_target[inds])
         | 
| 717 | 
            +
             | 
| 718 | 
            +
                    return dists, inds
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                else:
         | 
| 721 | 
            +
                    raise ValueError("unrecognized method")
         | 
| 722 | 
            +
             | 
| 723 | 
            +
             | 
| 724 | 
            +
            def farthest_point_sampling(points, n_sample):
         | 
| 725 | 
            +
                # Torch in, torch out. Returns a |V| mask with n_sample elements set to true.
         | 
| 726 | 
            +
             | 
| 727 | 
            +
                N = points.shape[0]
         | 
| 728 | 
            +
                if (n_sample > N):
         | 
| 729 | 
            +
                    raise ValueError("not enough points to sample")
         | 
| 730 | 
            +
             | 
| 731 | 
            +
                chosen_mask = torch.zeros(N, dtype=torch.bool, device=points.device)
         | 
| 732 | 
            +
                min_dists = torch.ones(N, dtype=points.dtype, device=points.device) * float('inf')
         | 
| 733 | 
            +
             | 
| 734 | 
            +
                # pick the centermost first point
         | 
| 735 | 
            +
                points = normalize_positions(points)
         | 
| 736 | 
            +
                i = torch.min(norm2(points), dim=0).indices
         | 
| 737 | 
            +
                chosen_mask[i] = True
         | 
| 738 | 
            +
             | 
| 739 | 
            +
                for _ in range(n_sample - 1):
         | 
| 740 | 
            +
             | 
| 741 | 
            +
                    # update distance
         | 
| 742 | 
            +
                    dists = norm2(points[i, :].unsqueeze(0) - points)
         | 
| 743 | 
            +
                    min_dists = torch.minimum(dists, min_dists)
         | 
| 744 | 
            +
             | 
| 745 | 
            +
                    # take the farthest
         | 
| 746 | 
            +
                    i = torch.max(min_dists, dim=0).indices.item()
         | 
| 747 | 
            +
                    chosen_mask[i] = True
         | 
| 748 | 
            +
             | 
| 749 | 
            +
                return chosen_mask
         | 
| 750 | 
            +
             | 
| 751 | 
            +
             | 
| 752 | 
            +
            def geodesic_label_errors(target_verts,
         | 
| 753 | 
            +
                                      target_faces,
         | 
| 754 | 
            +
                                      pred_labels,
         | 
| 755 | 
            +
                                      gt_labels,
         | 
| 756 | 
            +
                                      normalization='diameter',
         | 
| 757 | 
            +
                                      geodesic_cache_dir=None):
         | 
| 758 | 
            +
                """
         | 
| 759 | 
            +
                Return a vector of distances between predicted and ground-truth lables (normalized by geodesic diameter or area)
         | 
| 760 | 
            +
             | 
| 761 | 
            +
                This method is SLOW when it needs to recompute geodesic distances.
         | 
| 762 | 
            +
                """
         | 
| 763 | 
            +
             | 
| 764 | 
            +
                # move all to numpy cpu
         | 
| 765 | 
            +
                target_verts = toNP(target_verts)
         | 
| 766 | 
            +
                target_faces = toNP(target_faces)
         | 
| 767 | 
            +
             | 
| 768 | 
            +
                pred_labels = toNP(pred_labels)
         | 
| 769 | 
            +
                gt_labels = toNP(gt_labels)
         | 
| 770 | 
            +
             | 
| 771 | 
            +
                dists = get_all_pairs_geodesic_distance(target_verts, target_faces, geodesic_cache_dir)
         | 
| 772 | 
            +
             | 
| 773 | 
            +
                result_dists = dists[pred_labels, gt_labels]
         | 
| 774 | 
            +
             | 
| 775 | 
            +
                if normalization == 'diameter':
         | 
| 776 | 
            +
                    geodesic_diameter = np.max(dists)
         | 
| 777 | 
            +
                    normalized_result_dists = result_dists / geodesic_diameter
         | 
| 778 | 
            +
                elif normalization == 'area':
         | 
| 779 | 
            +
                    total_area = torch.sum(face_area(torch.tensor(target_verts), torch.tensor(target_faces)))
         | 
| 780 | 
            +
                    normalized_result_dists = result_dists / torch.sqrt(total_area)
         | 
| 781 | 
            +
                else:
         | 
| 782 | 
            +
                    raise ValueError('unrecognized normalization')
         | 
| 783 | 
            +
             | 
| 784 | 
            +
                return normalized_result_dists
         | 
| 785 | 
            +
             | 
| 786 | 
            +
             | 
| 787 | 
            +
            # This function and the helper class below are to support parallel computation of all-pairs geodesic distance
         | 
| 788 | 
            +
            def all_pairs_geodesic_worker(verts, faces, i):
         | 
| 789 | 
            +
                import igl
         | 
| 790 | 
            +
             | 
| 791 | 
            +
                N = verts.shape[0]
         | 
| 792 | 
            +
             | 
| 793 | 
            +
                # TODO: this re-does a ton of work, since it is called independently each time. Some custom C++ code could surely make it faster.
         | 
| 794 | 
            +
                sources = np.array([i])[:, np.newaxis]
         | 
| 795 | 
            +
                targets = np.arange(N)[:, np.newaxis]
         | 
| 796 | 
            +
                dist_vec = igl.exact_geodesic(verts, faces, sources, targets)
         | 
| 797 | 
            +
             | 
| 798 | 
            +
                return dist_vec
         | 
| 799 | 
            +
             | 
| 800 | 
            +
             | 
| 801 | 
            +
            class AllPairsGeodesicEngine(object):
         | 
| 802 | 
            +
             | 
| 803 | 
            +
                def __init__(self, verts, faces):
         | 
| 804 | 
            +
                    self.verts = verts
         | 
| 805 | 
            +
                    self.faces = faces
         | 
| 806 | 
            +
             | 
| 807 | 
            +
                def __call__(self, i):
         | 
| 808 | 
            +
                    return all_pairs_geodesic_worker(self.verts, self.faces, i)
         | 
| 809 | 
            +
             | 
| 810 | 
            +
             | 
| 811 | 
            +
            def get_all_pairs_geodesic_distance(verts_np, faces_np, geodesic_cache_dir=None):
         | 
| 812 | 
            +
                """
         | 
| 813 | 
            +
                Return a gigantic VxV dense matrix containing the all-pairs geodesic distance matrix. Internally caches, recomputing only if necessary.
         | 
| 814 | 
            +
             | 
| 815 | 
            +
                (numpy in, numpy out)
         | 
| 816 | 
            +
                """
         | 
| 817 | 
            +
             | 
| 818 | 
            +
                # need libigl for geodesic call
         | 
| 819 | 
            +
                try:
         | 
| 820 | 
            +
                    import igl
         | 
| 821 | 
            +
                except ImportError as e:
         | 
| 822 | 
            +
                    raise ImportError("Must have python libigl installed for all-pairs geodesics. `conda install -c conda-forge igl`")
         | 
| 823 | 
            +
             | 
| 824 | 
            +
                # Check the cache
         | 
| 825 | 
            +
                found = False
         | 
| 826 | 
            +
                if geodesic_cache_dir is not None:
         | 
| 827 | 
            +
                    utils.ensure_dir_exists(geodesic_cache_dir)
         | 
| 828 | 
            +
                    hash_key_str = str(utils.hash_arrays((verts_np, faces_np)))
         | 
| 829 | 
            +
                    # print("Building operators for input with hash: " + hash_key_str)
         | 
| 830 | 
            +
             | 
| 831 | 
            +
                    # Search through buckets with matching hashes.  When the loop exits, this
         | 
| 832 | 
            +
                    # is the bucket index of the file we should write to.
         | 
| 833 | 
            +
                    i_cache_search = 0
         | 
| 834 | 
            +
                    while True:
         | 
| 835 | 
            +
             | 
| 836 | 
            +
                        # Form the name of the file to check
         | 
| 837 | 
            +
                        search_path = os.path.join(geodesic_cache_dir, hash_key_str + "_" + str(i_cache_search) + ".npz")
         | 
| 838 | 
            +
             | 
| 839 | 
            +
                        try:
         | 
| 840 | 
            +
                            npzfile = np.load(search_path, allow_pickle=True)
         | 
| 841 | 
            +
                            cache_verts = npzfile["verts"]
         | 
| 842 | 
            +
                            cache_faces = npzfile["faces"]
         | 
| 843 | 
            +
             | 
| 844 | 
            +
                            # If the cache doesn't match, keep looking
         | 
| 845 | 
            +
                            if (not np.array_equal(verts_np, cache_verts)) or (not np.array_equal(faces_np, cache_faces)):
         | 
| 846 | 
            +
                                i_cache_search += 1
         | 
| 847 | 
            +
                                continue
         | 
| 848 | 
            +
             | 
| 849 | 
            +
                            # This entry matches! Return it.
         | 
| 850 | 
            +
                            found = True
         | 
| 851 | 
            +
                            result_dists = npzfile["dist"]
         | 
| 852 | 
            +
                            break
         | 
| 853 | 
            +
             | 
| 854 | 
            +
                        except FileNotFoundError:
         | 
| 855 | 
            +
                            break
         | 
| 856 | 
            +
             | 
| 857 | 
            +
                if not found:
         | 
| 858 | 
            +
             | 
| 859 | 
            +
                    print("Computing all-pairs geodesic distance (warning: SLOW!)")
         | 
| 860 | 
            +
             | 
| 861 | 
            +
                    # Not found, compute from scratch
         | 
| 862 | 
            +
                    # warning: slowwwwwww
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                    N = verts_np.shape[0]
         | 
| 865 | 
            +
             | 
| 866 | 
            +
                    try:
         | 
| 867 | 
            +
                        pool = Pool(None)  # on 8 processors
         | 
| 868 | 
            +
                        engine = AllPairsGeodesicEngine(verts_np, faces_np)
         | 
| 869 | 
            +
                        outputs = pool.map(engine, range(N))
         | 
| 870 | 
            +
                    finally:  # To make sure processes are closed in the end, even if errors happen
         | 
| 871 | 
            +
                        pool.close()
         | 
| 872 | 
            +
                        pool.join()
         | 
| 873 | 
            +
             | 
| 874 | 
            +
                    result_dists = np.array(outputs)
         | 
| 875 | 
            +
             | 
| 876 | 
            +
                    # replace any failed values with nan
         | 
| 877 | 
            +
                    result_dists = np.nan_to_num(result_dists, nan=np.nan, posinf=np.nan, neginf=np.nan)
         | 
| 878 | 
            +
             | 
| 879 | 
            +
                    # we expect that this should be a symmetric matrix, but it might not be. Take the min of the symmetric values to make it symmetric
         | 
| 880 | 
            +
                    result_dists = np.fmin(result_dists, np.transpose(result_dists))
         | 
| 881 | 
            +
             | 
| 882 | 
            +
                    # on rare occaisions MMP fails, yielding nan/inf; set it to the largest non-failed value if so
         | 
| 883 | 
            +
                    max_dist = np.nanmax(result_dists)
         | 
| 884 | 
            +
                    result_dists = np.nan_to_num(result_dists, nan=max_dist, posinf=max_dist, neginf=max_dist)
         | 
| 885 | 
            +
             | 
| 886 | 
            +
                    print("...finished computing all-pairs geodesic distance")
         | 
| 887 | 
            +
             | 
| 888 | 
            +
                    # put it in the cache if possible
         | 
| 889 | 
            +
                    if geodesic_cache_dir is not None:
         | 
| 890 | 
            +
             | 
| 891 | 
            +
                        print("saving geodesic distances to cache: " + str(geodesic_cache_dir))
         | 
| 892 | 
            +
             | 
| 893 | 
            +
                        # TODO we're potentially saving a double precision but only using a single
         | 
| 894 | 
            +
                        # precision here; could save storage by always saving as floats
         | 
| 895 | 
            +
                        np.savez(search_path, verts=verts_np, faces=faces_np, dist=result_dists)
         | 
| 896 | 
            +
             | 
| 897 | 
            +
                return result_dists
         | 
    	
        shape_models/layers.py
    ADDED
    
    | @@ -0,0 +1,453 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import os.path
         | 
| 4 | 
            +
            import random
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import scipy
         | 
| 7 | 
            +
            import scipy.sparse.linalg as sla
         | 
| 8 | 
            +
            # ^^^ we NEED to import scipy before torch, or it crashes :(
         | 
| 9 | 
            +
            # (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            import torch.nn as nn
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def to_basis(values, basis, massvec):
         | 
| 18 | 
            +
                """
         | 
| 19 | 
            +
                Transform data in to an orthonormal basis (where orthonormal is wrt to massvec)
         | 
| 20 | 
            +
                Inputs:
         | 
| 21 | 
            +
                  - values: (B,V,D)
         | 
| 22 | 
            +
                  - basis: (B,V,K)
         | 
| 23 | 
            +
                  - massvec: (B,V)
         | 
| 24 | 
            +
                Outputs:
         | 
| 25 | 
            +
                  - (B,K,D) transformed values
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                basisT = basis.transpose(-2, -1)
         | 
| 28 | 
            +
                return torch.matmul(basisT, values * massvec.unsqueeze(-1))
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def from_basis(values, basis):
         | 
| 32 | 
            +
                """
         | 
| 33 | 
            +
                Transform data out of an orthonormal basis
         | 
| 34 | 
            +
                Inputs:
         | 
| 35 | 
            +
                  - values: (K,D)
         | 
| 36 | 
            +
                  - basis: (V,K)
         | 
| 37 | 
            +
                Outputs:
         | 
| 38 | 
            +
                  - (V,D) reconstructed values
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
                if values.is_complex() or basis.is_complex():
         | 
| 41 | 
            +
                    return utils.cmatmul(utils.ensure_complex(basis), utils.ensure_complex(values))
         | 
| 42 | 
            +
                else:
         | 
| 43 | 
            +
                    return torch.matmul(basis, values)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            class LearnedTimeDiffusion(nn.Module):
         | 
| 47 | 
            +
                """
         | 
| 48 | 
            +
                Applies diffusion with learned per-channel t.
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                In the spectral domain this becomes 
         | 
| 51 | 
            +
                    f_out = e ^ (lambda_i t) f_in
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                Inputs:
         | 
| 54 | 
            +
                  - values: (V,C) in the spectral domain
         | 
| 55 | 
            +
                  - L: (V,V) sparse laplacian
         | 
| 56 | 
            +
                  - evals: (K) eigenvalues
         | 
| 57 | 
            +
                  - mass: (V) mass matrix diagonal
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                  (note: L/evals may be omitted as None depending on method)
         | 
| 60 | 
            +
                Outputs:
         | 
| 61 | 
            +
                  - (V,C) diffused values 
         | 
| 62 | 
            +
                """
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                def __init__(self, C_inout, method='spectral'):
         | 
| 65 | 
            +
                    super(LearnedTimeDiffusion, self).__init__()
         | 
| 66 | 
            +
                    self.C_inout = C_inout
         | 
| 67 | 
            +
                    self.diffusion_time = nn.Parameter(torch.Tensor(C_inout))  # (C)
         | 
| 68 | 
            +
                    self.method = method  # one of ['spectral', 'implicit_dense']
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    nn.init.constant_(self.diffusion_time, 0.0)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def forward(self, x, L, mass, evals, evecs):
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    # project times to the positive halfspace
         | 
| 75 | 
            +
                    # (and away from 0 in the incredibly rare chance that they get stuck)
         | 
| 76 | 
            +
                    with torch.no_grad():
         | 
| 77 | 
            +
                        self.diffusion_time.data = torch.clamp(self.diffusion_time, min=1e-8)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    if x.shape[-1] != self.C_inout:
         | 
| 80 | 
            +
                        raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
         | 
| 81 | 
            +
                            x.shape, self.C_inout))
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    if self.method == 'spectral':
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                        # Transform to spectral
         | 
| 86 | 
            +
                        x_spec = to_basis(x, evecs, mass)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                        # Diffuse
         | 
| 89 | 
            +
                        time = self.diffusion_time
         | 
| 90 | 
            +
                        diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * time.unsqueeze(0))
         | 
| 91 | 
            +
                        x_diffuse_spec = diffusion_coefs * x_spec
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                        # Transform back to per-vertex
         | 
| 94 | 
            +
                        x_diffuse = from_basis(x_diffuse_spec, evecs)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    elif self.method == 'implicit_dense':
         | 
| 97 | 
            +
                        V = x.shape[-2]
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                        # Form the dense matrices (M + tL) with dims (B,C,V,V)
         | 
| 100 | 
            +
                        mat_dense = L.to_dense().unsqueeze(1).expand(-1, self.C_inout, V, V).clone()
         | 
| 101 | 
            +
                        mat_dense *= self.diffusion_time.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
         | 
| 102 | 
            +
                        mat_dense += torch.diag_embed(mass).unsqueeze(1)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                        # Factor the system
         | 
| 105 | 
            +
                        cholesky_factors = torch.linalg.cholesky(mat_dense)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                        # Solve the system
         | 
| 108 | 
            +
                        rhs = x * mass.unsqueeze(-1)
         | 
| 109 | 
            +
                        rhsT = torch.transpose(rhs, 1, 2).unsqueeze(-1)
         | 
| 110 | 
            +
                        sols = torch.cholesky_solve(rhsT, cholesky_factors)
         | 
| 111 | 
            +
                        x_diffuse = torch.transpose(sols.squeeze(-1), 1, 2)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    else:
         | 
| 114 | 
            +
                        raise ValueError("unrecognized method")
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    return x_diffuse
         | 
| 117 | 
            +
             | 
| 118 | 
            +
             | 
| 119 | 
            +
            class SpatialGradientFeatures(nn.Module):
         | 
| 120 | 
            +
                """
         | 
| 121 | 
            +
                Compute dot-products between input vectors. Uses a learned complex-linear layer to keep dimension down.
         | 
| 122 | 
            +
                
         | 
| 123 | 
            +
                Input:
         | 
| 124 | 
            +
                    - vectors: (V,C,2)
         | 
| 125 | 
            +
                Output:
         | 
| 126 | 
            +
                    - dots: (V,C) dots 
         | 
| 127 | 
            +
                """
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                def __init__(self, C_inout, with_gradient_rotations=True):
         | 
| 130 | 
            +
                    super(SpatialGradientFeatures, self).__init__()
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    self.C_inout = C_inout
         | 
| 133 | 
            +
                    self.with_gradient_rotations = with_gradient_rotations
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    if (self.with_gradient_rotations):
         | 
| 136 | 
            +
                        self.A_re = nn.Linear(self.C_inout, self.C_inout, bias=False)
         | 
| 137 | 
            +
                        self.A_im = nn.Linear(self.C_inout, self.C_inout, bias=False)
         | 
| 138 | 
            +
                    else:
         | 
| 139 | 
            +
                        self.A = nn.Linear(self.C_inout, self.C_inout, bias=False)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    # self.norm = nn.InstanceNorm1d(C_inout)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                def forward(self, vectors):
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    vectorsA = vectors  # (V,C)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    if self.with_gradient_rotations:
         | 
| 148 | 
            +
                        vectorsBreal = self.A_re(vectors[..., 0]) - self.A_im(vectors[..., 1])
         | 
| 149 | 
            +
                        vectorsBimag = self.A_re(vectors[..., 1]) + self.A_im(vectors[..., 0])
         | 
| 150 | 
            +
                    else:
         | 
| 151 | 
            +
                        vectorsBreal = self.A(vectors[..., 0])
         | 
| 152 | 
            +
                        vectorsBimag = self.A(vectors[..., 1])
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    dots = vectorsA[..., 0] * vectorsBreal + vectorsA[..., 1] * vectorsBimag
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    return torch.tanh(dots)
         | 
| 157 | 
            +
             | 
| 158 | 
            +
             | 
| 159 | 
            +
            class MiniMLP(nn.Sequential):
         | 
| 160 | 
            +
                '''
         | 
| 161 | 
            +
                A simple MLP with configurable hidden layer sizes.
         | 
| 162 | 
            +
                '''
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                def __init__(self, layer_sizes, dropout=False, activation=nn.ReLU, name="miniMLP"):
         | 
| 165 | 
            +
                    super(MiniMLP, self).__init__()
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    for i in range(len(layer_sizes) - 1):
         | 
| 168 | 
            +
                        is_last = (i + 2 == len(layer_sizes))
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                        if dropout and i > 0:
         | 
| 171 | 
            +
                            self.add_module(name + "_mlp_layer_dropout_{:03d}".format(i), nn.Dropout(p=.5))
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                        # Affine map
         | 
| 174 | 
            +
                        self.add_module(
         | 
| 175 | 
            +
                            name + "_mlp_layer_{:03d}".format(i),
         | 
| 176 | 
            +
                            nn.Linear(
         | 
| 177 | 
            +
                                layer_sizes[i],
         | 
| 178 | 
            +
                                layer_sizes[i + 1],
         | 
| 179 | 
            +
                            ),
         | 
| 180 | 
            +
                        )
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                        # Nonlinearity
         | 
| 183 | 
            +
                        # (but not on the last layer)
         | 
| 184 | 
            +
                        if not is_last:
         | 
| 185 | 
            +
                            self.add_module(name + "_mlp_act_{:03d}".format(i), activation())
         | 
| 186 | 
            +
             | 
| 187 | 
            +
             | 
| 188 | 
            +
            class DiffusionNetBlock(nn.Module):
         | 
| 189 | 
            +
                """
         | 
| 190 | 
            +
                Inputs and outputs are defined at vertices
         | 
| 191 | 
            +
                """
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                def __init__(self,
         | 
| 194 | 
            +
                             C_width,
         | 
| 195 | 
            +
                             mlp_hidden_dims,
         | 
| 196 | 
            +
                             dropout=True,
         | 
| 197 | 
            +
                             diffusion_method='spectral',
         | 
| 198 | 
            +
                             with_gradient_features=True,
         | 
| 199 | 
            +
                             with_gradient_rotations=True):
         | 
| 200 | 
            +
                    super(DiffusionNetBlock, self).__init__()
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    # Specified dimensions
         | 
| 203 | 
            +
                    self.C_width = C_width
         | 
| 204 | 
            +
                    self.mlp_hidden_dims = mlp_hidden_dims
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                    self.dropout = dropout
         | 
| 207 | 
            +
                    self.with_gradient_features = with_gradient_features
         | 
| 208 | 
            +
                    self.with_gradient_rotations = with_gradient_rotations
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    # Diffusion block
         | 
| 211 | 
            +
                    self.diffusion = LearnedTimeDiffusion(self.C_width, method=diffusion_method)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    self.MLP_C = 2 * self.C_width
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    if self.with_gradient_features:
         | 
| 216 | 
            +
                        self.gradient_features = SpatialGradientFeatures(self.C_width, with_gradient_rotations=self.with_gradient_rotations)
         | 
| 217 | 
            +
                        self.MLP_C += self.C_width
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    # MLPs
         | 
| 220 | 
            +
                    self.mlp = MiniMLP([self.MLP_C] + self.mlp_hidden_dims + [self.C_width], dropout=self.dropout)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                def forward(self, x_in, mass, L, evals, evecs, gradX, gradY):
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    # Manage dimensions
         | 
| 225 | 
            +
                    B = x_in.shape[0]  # batch dimension
         | 
| 226 | 
            +
                    if x_in.shape[-1] != self.C_width:
         | 
| 227 | 
            +
                        raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
         | 
| 228 | 
            +
                            x_in.shape, self.C_width))
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    # Diffusion block
         | 
| 231 | 
            +
                    x_diffuse = self.diffusion(x_in, L, mass, evals, evecs)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    # Compute gradient features, if using
         | 
| 234 | 
            +
                    if self.with_gradient_features:
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                        # Compute gradients
         | 
| 237 | 
            +
                        x_grads = [
         | 
| 238 | 
            +
                        ]  # Manually loop over the batch (if there is a batch dimension) since torch.mm() doesn't support batching
         | 
| 239 | 
            +
                        for b in range(B):
         | 
| 240 | 
            +
                            # gradient after diffusion
         | 
| 241 | 
            +
                            x_gradX = torch.mm(gradX[b, ...], x_diffuse[b, ...])
         | 
| 242 | 
            +
                            x_gradY = torch.mm(gradY[b, ...], x_diffuse[b, ...])
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                            x_grads.append(torch.stack((x_gradX, x_gradY), dim=-1))
         | 
| 245 | 
            +
                        x_grad = torch.stack(x_grads, dim=0)
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                        # Evaluate gradient features
         | 
| 248 | 
            +
                        x_grad_features = self.gradient_features(x_grad)
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                        # Stack inputs to mlp
         | 
| 251 | 
            +
                        feature_combined = torch.cat((x_in, x_diffuse, x_grad_features), dim=-1)
         | 
| 252 | 
            +
                    else:
         | 
| 253 | 
            +
                        # Stack inputs to mlp
         | 
| 254 | 
            +
                        feature_combined = torch.cat((x_in, x_diffuse), dim=-1)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    # Apply the mlp
         | 
| 257 | 
            +
                    x0_out = self.mlp(feature_combined)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                    # Skip connection
         | 
| 260 | 
            +
                    x0_out = x0_out + x_in
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    return x0_out
         | 
| 263 | 
            +
             | 
| 264 | 
            +
             | 
| 265 | 
            +
            class DiffusionNet(nn.Module):
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                def __init__(self,
         | 
| 268 | 
            +
                             C_in,
         | 
| 269 | 
            +
                             C_out,
         | 
| 270 | 
            +
                             C_width=128,
         | 
| 271 | 
            +
                             N_block=4,
         | 
| 272 | 
            +
                             last_activation=None,
         | 
| 273 | 
            +
                             outputs_at='vertices',
         | 
| 274 | 
            +
                             mlp_hidden_dims=None,
         | 
| 275 | 
            +
                             dropout=True,
         | 
| 276 | 
            +
                             with_gradient_features=True,
         | 
| 277 | 
            +
                             with_gradient_rotations=True,
         | 
| 278 | 
            +
                             diffusion_method='spectral',
         | 
| 279 | 
            +
                             num_eigenbasis=128):
         | 
| 280 | 
            +
                    """
         | 
| 281 | 
            +
                    Construct a DiffusionNet.
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    Parameters:
         | 
| 284 | 
            +
                        C_in (int):                     input dimension 
         | 
| 285 | 
            +
                        C_out (int):                    output dimension 
         | 
| 286 | 
            +
                        last_activation (func)          a function to apply to the final outputs of the network, such as torch.nn.functional.log_softmax (default: None)
         | 
| 287 | 
            +
                        outputs_at (string)             produce outputs at various mesh elements by averaging from vertices. One of ['vertices', 'edges', 'faces', 'global_mean']. (default 'vertices', aka points for a point cloud)
         | 
| 288 | 
            +
                        C_width (int):                  dimension of internal DiffusionNet blocks (default: 128)
         | 
| 289 | 
            +
                        N_block (int):                  number of DiffusionNet blocks (default: 4)
         | 
| 290 | 
            +
                        mlp_hidden_dims (list of int):  a list of hidden layer sizes for MLPs (default: [C_width, C_width])
         | 
| 291 | 
            +
                        dropout (bool):                 if True, internal MLPs use dropout (default: True)
         | 
| 292 | 
            +
                        diffusion_method (string):      how to evaluate diffusion, one of ['spectral', 'implicit_dense']. If implicit_dense is used, can set k_eig=0, saving precompute.
         | 
| 293 | 
            +
                        with_gradient_features (bool):  if True, use gradient features (default: True)
         | 
| 294 | 
            +
                        with_gradient_rotations (bool): if True, use gradient also learn a rotation of each gradient. Set to True if your surface has consistently oriented normals, and False otherwise (default: True)
         | 
| 295 | 
            +
                        num_eigenbasis (int):           for trunking the eigenvalues eigenvectors
         | 
| 296 | 
            +
                    """
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    super(DiffusionNet, self).__init__()
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    ## Store parameters
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                    # Basic parameters
         | 
| 303 | 
            +
                    self.C_in = C_in
         | 
| 304 | 
            +
                    self.C_out = C_out
         | 
| 305 | 
            +
                    self.C_width = C_width
         | 
| 306 | 
            +
                    self.N_block = N_block
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    # Outputs
         | 
| 309 | 
            +
                    self.last_activation = last_activation
         | 
| 310 | 
            +
                    self.outputs_at = outputs_at
         | 
| 311 | 
            +
                    if outputs_at not in ['vertices', 'edges', 'faces', 'global_mean']:
         | 
| 312 | 
            +
                        raise ValueError("invalid setting for outputs_at")
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                    # MLP options
         | 
| 315 | 
            +
                    if mlp_hidden_dims == None:
         | 
| 316 | 
            +
                        mlp_hidden_dims = [C_width, C_width]
         | 
| 317 | 
            +
                    self.mlp_hidden_dims = mlp_hidden_dims
         | 
| 318 | 
            +
                    self.dropout = dropout
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    # Diffusion
         | 
| 321 | 
            +
                    self.diffusion_method = diffusion_method
         | 
| 322 | 
            +
                    if diffusion_method not in ['spectral', 'implicit_dense']:
         | 
| 323 | 
            +
                        raise ValueError("invalid setting for diffusion_method")
         | 
| 324 | 
            +
                    self.num_eigenbasis = num_eigenbasis
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                    # Gradient features
         | 
| 327 | 
            +
                    self.with_gradient_features = with_gradient_features
         | 
| 328 | 
            +
                    self.with_gradient_rotations = with_gradient_rotations
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    ## Set up the network
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                    # First and last affine layers
         | 
| 333 | 
            +
                    self.first_lin = nn.Linear(C_in, C_width)
         | 
| 334 | 
            +
                    self.last_lin = nn.Linear(C_width, C_out)
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    # DiffusionNet blocks
         | 
| 337 | 
            +
                    self.blocks = []
         | 
| 338 | 
            +
                    for i_block in range(self.N_block):
         | 
| 339 | 
            +
                        block = DiffusionNetBlock(C_width=C_width,
         | 
| 340 | 
            +
                                                  mlp_hidden_dims=mlp_hidden_dims,
         | 
| 341 | 
            +
                                                  dropout=dropout,
         | 
| 342 | 
            +
                                                  diffusion_method=diffusion_method,
         | 
| 343 | 
            +
                                                  with_gradient_features=with_gradient_features,
         | 
| 344 | 
            +
                                                  with_gradient_rotations=with_gradient_rotations)
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                        self.blocks.append(block)
         | 
| 347 | 
            +
                        self.add_module("block_" + str(i_block), self.blocks[-1])
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                def forward(self, x_in, mass, L=None, evals=None, evecs=None, gradX=None, gradY=None, edges=None, faces=None):
         | 
| 350 | 
            +
                    """
         | 
| 351 | 
            +
                    A forward pass on the DiffusionNet.
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                    In the notation below, dimension are:
         | 
| 354 | 
            +
                        - C is the input channel dimension (C_in on construction)
         | 
| 355 | 
            +
                        - C_OUT is the output channel dimension (C_out on construction)
         | 
| 356 | 
            +
                        - N is the number of vertices/points, which CAN be different for each forward pass
         | 
| 357 | 
            +
                        - B is an OPTIONAL batch dimension
         | 
| 358 | 
            +
                        - K_EIG is the number of eigenvalues used for spectral acceleration
         | 
| 359 | 
            +
                    Generally, our data layout it is [N,C] or [B,N,C].
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                    Call get_operators() to generate geometric quantities mass/L/evals/evecs/gradX/gradY. Note that depending on the options for the DiffusionNet, not all are strictly necessary.
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                    Parameters:
         | 
| 364 | 
            +
                        x_in (tensor):      Input features, dimension [N,C] or [B,N,C]
         | 
| 365 | 
            +
                        mass (tensor):      Mass vector, dimension [N] or [B,N]
         | 
| 366 | 
            +
                        L (tensor):         Laplace matrix, sparse tensor with dimension [N,N] or [B,N,N]
         | 
| 367 | 
            +
                        evals (tensor):     Eigenvalues of Laplace matrix, dimension [K_EIG] or [B,K_EIG]
         | 
| 368 | 
            +
                        evecs (tensor):     Eigenvectors of Laplace matrix, dimension [N,K_EIG] or [B,N,K_EIG]
         | 
| 369 | 
            +
                        gradX (tensor):     Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
         | 
| 370 | 
            +
                        gradY (tensor):     Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    Returns:
         | 
| 373 | 
            +
                        x_out (tensor):    Output with dimension [N,C_out] or [B,N,C_out]
         | 
| 374 | 
            +
                    """
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    ## Check dimensions, and append batch dimension if not given
         | 
| 377 | 
            +
                    if x_in.shape[-1] != self.C_in:
         | 
| 378 | 
            +
                        raise ValueError("DiffusionNet was constructed with C_in={}, but x_in has last dim={}".format(
         | 
| 379 | 
            +
                            self.C_in, x_in.shape[-1]))
         | 
| 380 | 
            +
                    N = x_in.shape[-2]
         | 
| 381 | 
            +
                    if len(x_in.shape) == 2:
         | 
| 382 | 
            +
                        appended_batch_dim = True
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                        # add a batch dim to all inputs
         | 
| 385 | 
            +
                        x_in = x_in.unsqueeze(0)
         | 
| 386 | 
            +
                        mass = mass.unsqueeze(0)
         | 
| 387 | 
            +
                        if L != None:
         | 
| 388 | 
            +
                            L = L.unsqueeze(0)
         | 
| 389 | 
            +
                        if evals != None:
         | 
| 390 | 
            +
                            evals = evals.unsqueeze(0)
         | 
| 391 | 
            +
                        if evecs != None:
         | 
| 392 | 
            +
                            evecs = evecs.unsqueeze(0)
         | 
| 393 | 
            +
                        if gradX != None:
         | 
| 394 | 
            +
                            gradX = gradX.unsqueeze(0)
         | 
| 395 | 
            +
                        if gradY != None:
         | 
| 396 | 
            +
                            gradY = gradY.unsqueeze(0)
         | 
| 397 | 
            +
                        if edges != None:
         | 
| 398 | 
            +
                            edges = edges.unsqueeze(0)
         | 
| 399 | 
            +
                        if faces != None:
         | 
| 400 | 
            +
                            faces = faces.unsqueeze(0)
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                    elif len(x_in.shape) == 3:
         | 
| 403 | 
            +
                        appended_batch_dim = False
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                    else:
         | 
| 406 | 
            +
                        raise ValueError("x_in should be tensor with shape [N,C] or [B,N,C]")
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                    evals = evals[..., :self.num_eigenbasis]
         | 
| 409 | 
            +
                    evecs = evecs[..., :self.num_eigenbasis]
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    # Apply the first linear layer
         | 
| 412 | 
            +
                    x = self.first_lin(x_in)
         | 
| 413 | 
            +
             | 
| 414 | 
            +
                    # Apply each of the blocks
         | 
| 415 | 
            +
                    for b in self.blocks:
         | 
| 416 | 
            +
                        x = b(x, mass, L, evals, evecs, gradX, gradY)
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                    # Apply the last linear layer
         | 
| 419 | 
            +
                    x = self.last_lin(x)
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                    # Remap output to faces/edges if requested
         | 
| 422 | 
            +
                    if self.outputs_at == 'vertices':
         | 
| 423 | 
            +
                        x_out = x
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                    elif self.outputs_at == 'edges':
         | 
| 426 | 
            +
                        # Remap to edges
         | 
| 427 | 
            +
                        x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 2)
         | 
| 428 | 
            +
                        edges_gather = edges.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
         | 
| 429 | 
            +
                        xe = torch.gather(x_gather, 1, edges_gather)
         | 
| 430 | 
            +
                        x_out = torch.mean(xe, dim=-1)
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                    elif self.outputs_at == 'faces':
         | 
| 433 | 
            +
                        # Remap to faces
         | 
| 434 | 
            +
                        x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 3)
         | 
| 435 | 
            +
                        faces_gather = faces.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
         | 
| 436 | 
            +
                        xf = torch.gather(x_gather, 1, faces_gather)
         | 
| 437 | 
            +
                        x_out = torch.mean(xf, dim=-1)
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                    elif self.outputs_at == 'global_mean':
         | 
| 440 | 
            +
                        # Produce a single global mean ouput.
         | 
| 441 | 
            +
                        # Using a weighted mean according to the point mass/area is discretization-invariant.
         | 
| 442 | 
            +
                        # (A naive mean is not discretization-invariant; it could be affected by sampling a region more densely)
         | 
| 443 | 
            +
                        x_out = torch.sum(x * mass.unsqueeze(-1), dim=-2) / torch.sum(mass, dim=-1, keepdim=True)
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                    # Apply last nonlinearity if specified
         | 
| 446 | 
            +
                    if self.last_activation != None:
         | 
| 447 | 
            +
                        x_out = self.last_activation(x_out)
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    # Remove batch dim if we added it
         | 
| 450 | 
            +
                    if appended_batch_dim:
         | 
| 451 | 
            +
                        x_out = x_out.squeeze(0)
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                    return x_out
         | 
    	
        shape_models/utils.py
    ADDED
    
    | @@ -0,0 +1,123 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import time
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import hashlib
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import scipy
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # == Pytorch things
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def toNP(x):
         | 
| 14 | 
            +
                """
         | 
| 15 | 
            +
                Really, definitely convert a torch tensor to a numpy array
         | 
| 16 | 
            +
                """
         | 
| 17 | 
            +
                return x.detach().to(torch.device('cpu')).numpy()
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def label_smoothing_log_loss(pred, labels, smoothing=0.0):
         | 
| 21 | 
            +
                n_class = pred.shape[-1]
         | 
| 22 | 
            +
                one_hot = torch.zeros_like(pred)
         | 
| 23 | 
            +
                one_hot[labels] = 1.
         | 
| 24 | 
            +
                one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
         | 
| 25 | 
            +
                loss = -(one_hot * pred).sum(dim=-1).mean()
         | 
| 26 | 
            +
                return loss
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            # Randomly rotate points.
         | 
| 30 | 
            +
            # Torch in, torch out
         | 
| 31 | 
            +
            # Note fornow, builds rotation matrix on CPU.
         | 
| 32 | 
            +
            def random_rotate_points(pts, randgen=None):
         | 
| 33 | 
            +
                R = random_rotation_matrix(randgen)
         | 
| 34 | 
            +
                R = torch.from_numpy(R).to(device=pts.device, dtype=pts.dtype)
         | 
| 35 | 
            +
                return torch.matmul(pts, R)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            def random_rotate_points_y(pts):
         | 
| 39 | 
            +
                angles = torch.rand(1, device=pts.device, dtype=pts.dtype) * (2. * np.pi)
         | 
| 40 | 
            +
                rot_mats = torch.zeros(3, 3, device=pts.device, dtype=pts.dtype)
         | 
| 41 | 
            +
                rot_mats[0, 0] = torch.cos(angles)
         | 
| 42 | 
            +
                rot_mats[0, 2] = torch.sin(angles)
         | 
| 43 | 
            +
                rot_mats[2, 0] = -torch.sin(angles)
         | 
| 44 | 
            +
                rot_mats[2, 2] = torch.cos(angles)
         | 
| 45 | 
            +
                rot_mats[1, 1] = 1.
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                pts = torch.matmul(pts, rot_mats)
         | 
| 48 | 
            +
                return pts
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            # Numpy things
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            # Numpy sparse matrix to pytorch
         | 
| 55 | 
            +
            def sparse_np_to_torch(A):
         | 
| 56 | 
            +
                Acoo = A.tocoo()
         | 
| 57 | 
            +
                values = Acoo.data
         | 
| 58 | 
            +
                indices = np.vstack((Acoo.row, Acoo.col))
         | 
| 59 | 
            +
                shape = Acoo.shape
         | 
| 60 | 
            +
                return torch.sparse.FloatTensor(torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape)).coalesce()
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            # Pytorch sparse to numpy csc matrix
         | 
| 64 | 
            +
            def sparse_torch_to_np(A):
         | 
| 65 | 
            +
                if len(A.shape) != 2:
         | 
| 66 | 
            +
                    raise RuntimeError("should be a matrix-shaped type; dim is : " + str(A.shape))
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                indices = toNP(A.indices())
         | 
| 69 | 
            +
                values = toNP(A.values())
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                mat = scipy.sparse.coo_matrix((values, indices), shape=A.shape).tocsc()
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                return mat
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            # Hash a list of numpy arrays
         | 
| 77 | 
            +
            def hash_arrays(arrs):
         | 
| 78 | 
            +
                running_hash = hashlib.sha1()
         | 
| 79 | 
            +
                for arr in arrs:
         | 
| 80 | 
            +
                    binarr = arr.view(np.uint8)
         | 
| 81 | 
            +
                    running_hash.update(binarr)
         | 
| 82 | 
            +
                return running_hash.hexdigest()
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            def random_rotation_matrix(randgen=None):
         | 
| 86 | 
            +
                """
         | 
| 87 | 
            +
                Creates a random rotation matrix.
         | 
| 88 | 
            +
                randgen: if given, a np.random.RandomState instance used for random numbers (for reproducibility)
         | 
| 89 | 
            +
                """
         | 
| 90 | 
            +
                # adapted from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                if randgen is None:
         | 
| 93 | 
            +
                    randgen = np.random.RandomState()
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                theta, phi, z = tuple(randgen.rand(3).tolist())
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                theta = theta * 2.0 * np.pi  # Rotation about the pole (Z).
         | 
| 98 | 
            +
                phi = phi * 2.0 * np.pi  # For direction of pole deflection.
         | 
| 99 | 
            +
                z = z * 2.0  # For magnitude of pole deflection.
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                # Compute a vector V used for distributing points over the sphere
         | 
| 102 | 
            +
                # via the reflection I - V Transpose(V).  This formulation of V
         | 
| 103 | 
            +
                # will guarantee that if x[1] and x[2] are uniformly distributed,
         | 
| 104 | 
            +
                # the reflected points will be uniform on the sphere.  Note that V
         | 
| 105 | 
            +
                # has length sqrt(2) to eliminate the 2 in the Householder matrix.
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                r = np.sqrt(z)
         | 
| 108 | 
            +
                Vx, Vy, Vz = V = (np.sin(phi) * r, np.cos(phi) * r, np.sqrt(2.0 - z))
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                st = np.sin(theta)
         | 
| 111 | 
            +
                ct = np.cos(theta)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1)))
         | 
| 114 | 
            +
                # Construct the rotation matrix  ( V Transpose(V) - I ) R.
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                M = (np.outer(V, V) - np.eye(3)).dot(R)
         | 
| 117 | 
            +
                return M
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
            # Python string/file utilities
         | 
| 121 | 
            +
            def ensure_dir_exists(d):
         | 
| 122 | 
            +
                if not os.path.exists(d):
         | 
| 123 | 
            +
                    os.makedirs(d)
         | 
    	
        utils/__init__.py
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # from .utils_legacy import *
         | 
| 2 | 
            +
            # from .geometry import *
         | 
| 3 | 
            +
            # from .layers import *
         | 
    	
        utils/descriptors.py
    ADDED
    
    | @@ -0,0 +1,250 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # https://github.com/RobinMagnet/SimplifiedFmapsLearning/blob/main/learn_zo/data/utils.py
         | 
| 4 | 
            +
            # https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/HKS_functions.py
         | 
| 5 | 
            +
            def HKS(evals, evects, time_list,scaled=False):
         | 
| 6 | 
            +
                """
         | 
| 7 | 
            +
                Returns the Heat Kernel Signature for num_T different values.
         | 
| 8 | 
            +
                The values of the time are interpolated in logscale between the limits
         | 
| 9 | 
            +
                given in the HKS paper. These limits only depends on the eigenvalues.
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                Parameters
         | 
| 12 | 
            +
                ------------------------
         | 
| 13 | 
            +
                evals     : (K,) array of the K eigenvalues
         | 
| 14 | 
            +
                evecs     : (N,K) array with the K eigenvectors
         | 
| 15 | 
            +
                time_list : (num_T,) Time values to use
         | 
| 16 | 
            +
                scaled    : (bool) whether to scale for each time value
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                Output
         | 
| 19 | 
            +
                ------------------------
         | 
| 20 | 
            +
                HKS : (N,num_T) array where each line is the HKS for a given t
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                evals_s = np.asarray(evals).flatten()
         | 
| 23 | 
            +
                t_list = np.asarray(time_list).flatten()
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                coefs = np.exp(-np.outer(t_list, evals_s))  # (num_T,K)
         | 
| 26 | 
            +
                # weighted_evects = evects[None, :, :] * coefs[:, None,:]  # (num_T,N,K)
         | 
| 27 | 
            +
                # natural_HKS = np.einsum('tnk,nk->nt', weighted_evects, evects)
         | 
| 28 | 
            +
                natural_HKS = np.einsum('tk,nk,nk->nt', coefs, evects, evects)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                if scaled:
         | 
| 31 | 
            +
                    inv_scaling = coefs.sum(1)  # (num_T)
         | 
| 32 | 
            +
                    return (1/inv_scaling)[None,:] * natural_HKS
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                return natural_HKS
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def lm_HKS(evals, evects, landmarks, time_list, scaled=False):
         | 
| 38 | 
            +
                """
         | 
| 39 | 
            +
                Returns the Heat Kernel Signature for some landmarks and time values.
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
                Parameters
         | 
| 43 | 
            +
                ------------------------
         | 
| 44 | 
            +
                evects      : (N,K) array with the K eigenvectors of the Laplace Beltrami operator
         | 
| 45 | 
            +
                evals       : (K,) array of the K corresponding eigenvalues
         | 
| 46 | 
            +
                landmarks   : (p,) indices of landmarks to compute
         | 
| 47 | 
            +
                time_list   : (num_T,) values of t to use
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                Output
         | 
| 50 | 
            +
                ------------------------
         | 
| 51 | 
            +
                landmarks_HKS : (N,num_E*p) array where each column is the HKS for a given t for some landmark
         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                evals_s = np.asarray(evals).flatten()
         | 
| 55 | 
            +
                t_list = np.asarray(time_list).flatten()
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                coefs = np.exp(-np.outer(t_list, evals_s))  # (num_T,K)
         | 
| 58 | 
            +
                weighted_evects = evects[None, landmarks, :] * coefs[:,None,:]  # (num_T,p,K)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                landmarks_HKS = np.einsum('tpk,nk->ptn', weighted_evects, evects)  # (p,num_T,N)
         | 
| 61 | 
            +
                landmarks_HKS = np.einsum('tk,pk,nk->ptn', coefs, evects[landmarks, :], evects)  # (p,num_T,N)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                if scaled:
         | 
| 64 | 
            +
                    inv_scaling = coefs.sum(1)  # (num_T,)
         | 
| 65 | 
            +
                    landmarks_HKS = (1/inv_scaling)[None,:,None] * landmarks_HKS  # (p,num_T,N)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                return rearrange(landmarks_HKS, 'p T N -> N (p T)')
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            def auto_HKS(evals, evects, num_T, landmarks=None, scaled=True):
         | 
| 71 | 
            +
                """
         | 
| 72 | 
            +
                Compute HKS with an automatic choice of tile values
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                Parameters
         | 
| 75 | 
            +
                ------------------------
         | 
| 76 | 
            +
                evals       : (K,) array of  K eigenvalues
         | 
| 77 | 
            +
                evects      : (N,K) array with K eigenvectors
         | 
| 78 | 
            +
                landmarks   : (p,) if not None, indices of landmarks to compute.
         | 
| 79 | 
            +
                num_T       : (int) number values of t to use
         | 
| 80 | 
            +
                Output
         | 
| 81 | 
            +
                ------------------------
         | 
| 82 | 
            +
                HKS or lm_HKS : (N,num_E) or (N,p*num_E)  array where each column is the WKS for a given e
         | 
| 83 | 
            +
                                for some landmark
         | 
| 84 | 
            +
                """
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                abs_ev = sorted(np.abs(evals))
         | 
| 87 | 
            +
                t_list = np.geomspace(4*np.log(10)/abs_ev[-1], 4*np.log(10)/abs_ev[1], num_T)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                if landmarks is None:
         | 
| 90 | 
            +
                    return HKS(abs_ev, evects, t_list, scaled=scaled)
         | 
| 91 | 
            +
                else:
         | 
| 92 | 
            +
                    return lm_HKS(abs_ev, evects, landmarks, t_list, scaled=scaled)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
            # https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/WKS_functions.py
         | 
| 96 | 
            +
            def WKS(evals, evects, energy_list, sigma, scaled=False):
         | 
| 97 | 
            +
                """
         | 
| 98 | 
            +
                Returns the Wave Kernel Signature for some energy values.
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                Parameters
         | 
| 101 | 
            +
                ------------------------
         | 
| 102 | 
            +
                evects      : (N,K) array with the K eigenvectors of the Laplace Beltrami operator
         | 
| 103 | 
            +
                evals       : (K,) array of the K corresponding eigenvalues
         | 
| 104 | 
            +
                energy_list : (num_E,) values of e to use
         | 
| 105 | 
            +
                sigma       : (float) [positive] standard deviation to use
         | 
| 106 | 
            +
                scaled      : (bool) Whether to scale each energy level
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                Output
         | 
| 109 | 
            +
                ------------------------
         | 
| 110 | 
            +
                WKS : (N,num_E) array where each column is the WKS for a given e
         | 
| 111 | 
            +
                """
         | 
| 112 | 
            +
                assert sigma > 0, f"Sigma should be positive ! Given value : {sigma}"
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                evals = np.asarray(evals).flatten()
         | 
| 115 | 
            +
                indices = np.where(evals > 1e-5)[0].flatten()
         | 
| 116 | 
            +
                evals = evals[indices]
         | 
| 117 | 
            +
                evects = evects[:, indices]
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                e_list = np.asarray(energy_list)
         | 
| 120 | 
            +
                coefs = np.exp(-np.square(e_list[:,None] - np.log(np.abs(evals))[None,:])/(2*sigma**2))  # (num_E,K)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                # weighted_evects = evects[None, :, :] * coefs[:,None, :]  # (num_E,N,K)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                # natural_WKS = np.einsum('tnk,nk->nt', weighted_evects, evects)  # (N,num_E)
         | 
| 125 | 
            +
                natural_WKS = np.einsum('tk,nk,nk->nt', coefs, evects, evects)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                if scaled:
         | 
| 128 | 
            +
                    inv_scaling = coefs.sum(1)  # (num_E)
         | 
| 129 | 
            +
                    return (1/inv_scaling)[None,:] * natural_WKS
         | 
| 130 | 
            +
                
         | 
| 131 | 
            +
                return natural_WKS
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            def lm_WKS(evals, evects, landmarks, energy_list, sigma, scaled=False):
         | 
| 135 | 
            +
                """
         | 
| 136 | 
            +
                Returns the Wave Kernel Signature for some landmarks and energy values.
         | 
| 137 | 
            +
             | 
| 138 | 
            +
             | 
| 139 | 
            +
                Parameters
         | 
| 140 | 
            +
                ------------------------
         | 
| 141 | 
            +
                evects      : (N,K) array with the K eigenvectors of the Laplace Beltrami operator
         | 
| 142 | 
            +
                evals       : (K,) array of the K corresponding eigenvalues
         | 
| 143 | 
            +
                landmarks   : (p,) indices of landmarks to compute
         | 
| 144 | 
            +
                energy_list : (num_E,) values of e to use
         | 
| 145 | 
            +
                sigma       : int - standard deviation
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                Output
         | 
| 148 | 
            +
                ------------------------
         | 
| 149 | 
            +
                landmarks_WKS : (N,num_E*p) array where each column is the WKS for a given e for some landmark
         | 
| 150 | 
            +
                """
         | 
| 151 | 
            +
                assert sigma > 0, f"Sigma should be positive ! Given value : {sigma}"
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                evals = np.asarray(evals).flatten()
         | 
| 154 | 
            +
                indices = np.where(evals > 1e-2)[0].flatten()
         | 
| 155 | 
            +
                evals = evals[indices]
         | 
| 156 | 
            +
                evects = evects[:,indices]
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                e_list = np.asarray(energy_list)
         | 
| 159 | 
            +
                coefs = np.exp(-np.square(e_list[:, None] - np.log(np.abs(evals))[None, :]) / (2*sigma**2))  # (num_E,K)
         | 
| 160 | 
            +
                # weighted_evects = evects[None, landmarks, :] * coefs[:,None,:]  # (num_E,p,K)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                # landmarks_WKS = np.einsum('tpk,nk->ptn', weighted_evects, evects)  # (p,num_E,N)
         | 
| 163 | 
            +
                landmarks_WKS = np.einsum('tk,pk,nk->ptn', coefs, evects[landmarks, :], evects)  # (p,num_E,N)
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                if scaled:
         | 
| 166 | 
            +
                    inv_scaling = coefs.sum(1)  # (num_E,)
         | 
| 167 | 
            +
                    landmarks_WKS = (1/inv_scaling)[None,:,None] * landmarks_WKS
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                # return landmarks_WKS.reshape(-1,evects.shape[0]).T  # (N,p*num_E)
         | 
| 170 | 
            +
                return rearrange(landmarks_WKS, 'p T N -> N (p T)')
         | 
| 171 | 
            +
             | 
| 172 | 
            +
             | 
| 173 | 
            +
            def auto_WKS(evals, evects, num_E, landmarks=None, scaled=True):
         | 
| 174 | 
            +
                """
         | 
| 175 | 
            +
                Compute WKS with an automatic choice of scale and energy
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                Parameters
         | 
| 178 | 
            +
                ------------------------
         | 
| 179 | 
            +
                evals       : (K,) array of  K eigenvalues
         | 
| 180 | 
            +
                evects      : (N,K) array with K eigenvectors
         | 
| 181 | 
            +
                landmarks   : (p,) If not None, indices of landmarks to compute.
         | 
| 182 | 
            +
                num_E       : (int) number values of e to use
         | 
| 183 | 
            +
                Output
         | 
| 184 | 
            +
                ------------------------
         | 
| 185 | 
            +
                WKS or lm_WKS : (N,num_E) or (N,p*num_E)  array where each column is the WKS for a given e
         | 
| 186 | 
            +
                                and possibly for some landmarks
         | 
| 187 | 
            +
                """
         | 
| 188 | 
            +
                abs_ev = sorted(np.abs(evals))
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                e_min,e_max = np.log(abs_ev[1]),np.log(abs_ev[-1])
         | 
| 191 | 
            +
                sigma = 7*(e_max-e_min)/num_E
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                e_min += 2*sigma
         | 
| 194 | 
            +
                e_max -= 2*sigma
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                energy_list = np.linspace(e_min,e_max,num_E)
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                if landmarks is None:
         | 
| 199 | 
            +
                    return WKS(abs_ev, evects, energy_list, sigma, scaled=scaled)
         | 
| 200 | 
            +
                else:
         | 
| 201 | 
            +
                    return lm_WKS(abs_ev, evects, landmarks, energy_list, sigma, scaled=scaled)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
             | 
| 204 | 
            +
            def compute_hks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35, normalize=True):
         | 
| 205 | 
            +
                """
         | 
| 206 | 
            +
                Compute Heat Kernel Signature (HKS) descriptors.
         | 
| 207 | 
            +
                
         | 
| 208 | 
            +
                Args:
         | 
| 209 | 
            +
                    evecs: (N, K) eigenvectors of the Laplace-Beltrami operator
         | 
| 210 | 
            +
                    evals: (K,) eigenvalues of the Laplace-Beltrami operator
         | 
| 211 | 
            +
                    mass: (N,) vertex masses
         | 
| 212 | 
            +
                    n_descr: (int) number of descriptors
         | 
| 213 | 
            +
                    subsample_step: (int) subsampling step
         | 
| 214 | 
            +
                    n_eig: (int) number of eigenvectors to use
         | 
| 215 | 
            +
                
         | 
| 216 | 
            +
                Returns:
         | 
| 217 | 
            +
                    feats: (N, n_descr) HKS descriptors
         | 
| 218 | 
            +
                """
         | 
| 219 | 
            +
                feats = auto_HKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
         | 
| 220 | 
            +
                feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
         | 
| 221 | 
            +
                if normalize:
         | 
| 222 | 
            +
                    feats_norm2 = np.einsum('np,n->p', feats**2, mass).flatten()
         | 
| 223 | 
            +
                    feats /= np.sqrt(feats_norm2)[None, :]
         | 
| 224 | 
            +
                return feats.astype(np.float32)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
             | 
| 227 | 
            +
            def compute_wks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35, normalize=True):
         | 
| 228 | 
            +
                """
         | 
| 229 | 
            +
                Compute Wave Kernel Signature (WKS) descriptors.
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                Args:
         | 
| 232 | 
            +
                    evecs: (N, K) eigenvectors of the Laplace-Beltrami operator
         | 
| 233 | 
            +
                    evals: (K,) eigenvalues of the Laplace-Beltrami operator
         | 
| 234 | 
            +
                    mass: (N,) vertex masses
         | 
| 235 | 
            +
                    n_descr: (int) number of descriptors
         | 
| 236 | 
            +
                    subsample_step: (int) subsampling step
         | 
| 237 | 
            +
                    n_eig: (int) number of eigenvectors to use
         | 
| 238 | 
            +
                
         | 
| 239 | 
            +
                Returns:
         | 
| 240 | 
            +
                    feats: (N, n_descr) WKS descriptors
         | 
| 241 | 
            +
                """
         | 
| 242 | 
            +
                feats = auto_WKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
         | 
| 243 | 
            +
                feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
         | 
| 244 | 
            +
                # print("wks_shape",feats.shape, mass.shape)
         | 
| 245 | 
            +
                if normalize:
         | 
| 246 | 
            +
                    feats_norm2 = np.einsum('np,n->p', feats**2, mass).flatten()
         | 
| 247 | 
            +
                    feats /= np.sqrt(feats_norm2)[None, :]
         | 
| 248 | 
            +
                # feats_norm2 = np.einsum('np,n->p', feats**2, mass).flatten()
         | 
| 249 | 
            +
                # feats /= np.sqrt(feats_norm2)[None, :]
         | 
| 250 | 
            +
                return feats.astype(np.float32)
         | 
    	
        utils/eval.py
    ADDED
    
    | @@ -0,0 +1,25 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            def accuracy(p2p, gt_p2p, D1_geod, return_all=False, sqrt_area=None):
         | 
| 2 | 
            +
                """
         | 
| 3 | 
            +
                Computes the geodesic accuracy of a vertex to vertex map. The map goes from
         | 
| 4 | 
            +
                the target shape to the source shape.
         | 
| 5 | 
            +
                Borrowed from Robin.
         | 
| 6 | 
            +
                Parameters
         | 
| 7 | 
            +
                ----------------------
         | 
| 8 | 
            +
                p2p        : (n2,) - vertex to vertex map giving the index of the matched vertex on the source shape
         | 
| 9 | 
            +
                             for each vertex on the target shape (from a functional map point of view)
         | 
| 10 | 
            +
                gt_p2p     : (n2,) - ground truth mapping between the pairs
         | 
| 11 | 
            +
                D1_geod    : (n1,n1) - geodesic distance between pairs of vertices on the source mesh
         | 
| 12 | 
            +
                return_all : bool - whether to return all the distances or only the average geodesic distance
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                Output
         | 
| 15 | 
            +
                -----------------------
         | 
| 16 | 
            +
                acc   : float - average accuracy of the vertex to vertex map
         | 
| 17 | 
            +
                dists : (n2,) - if return_all is True, returns all the pairwise distances
         | 
| 18 | 
            +
                """
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                dists = D1_geod[(p2p, gt_p2p)]
         | 
| 21 | 
            +
                if sqrt_area is not None:
         | 
| 22 | 
            +
                    dists /= sqrt_area
         | 
| 23 | 
            +
                if return_all:
         | 
| 24 | 
            +
                    return dists.mean(), dists
         | 
| 25 | 
            +
                return dists.mean()
         | 
    	
        utils/fmap.py
    ADDED
    
    | @@ -0,0 +1,121 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import os.path as osp
         | 
| 3 | 
            +
            import sys
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import scipy.linalg
         | 
| 6 | 
            +
            from tqdm import tqdm
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
         | 
| 9 | 
            +
            if ROOT_DIR not in sys.path:
         | 
| 10 | 
            +
                sys.path.append(ROOT_DIR)
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from utils_fmaps.misc import KNNSearch
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            try:
         | 
| 15 | 
            +
                import pynndescent
         | 
| 16 | 
            +
                index = pynndescent.NNDescent(np.random.random((100, 3)), n_jobs=2)
         | 
| 17 | 
            +
                del index
         | 
| 18 | 
            +
                ANN = True
         | 
| 19 | 
            +
            except ImportError:
         | 
| 20 | 
            +
                ANN = False
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            # https://github.com/RobinMagnet/pyFM
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def FM_to_p2p(FM, eigvects1, eigvects2, use_ANN=False):
         | 
| 26 | 
            +
                if use_ANN and not ANN:
         | 
| 27 | 
            +
                    raise ValueError('Please install pydescent to achieve Approximate Nearest Neighbor')
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                k2, k1 = FM.shape
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                assert k1 <= eigvects1.shape[1], \
         | 
| 32 | 
            +
                    f'At least {k1} should be provided, here only {eigvects1.shape[1]} are given'
         | 
| 33 | 
            +
                assert k2 <= eigvects2.shape[1], \
         | 
| 34 | 
            +
                    f'At least {k2} should be provided, here only {eigvects2.shape[1]} are given'
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                if use_ANN:
         | 
| 37 | 
            +
                    index = pynndescent.NNDescent(eigvects1[:, :k1] @ FM.T, n_jobs=8)
         | 
| 38 | 
            +
                    matches, _ = index.query(eigvects2[:, :k2], k=1)
         | 
| 39 | 
            +
                    matches = matches.flatten()
         | 
| 40 | 
            +
                else:
         | 
| 41 | 
            +
                    tree = KNNSearch(eigvects1[:, :k1] @ FM.T)
         | 
| 42 | 
            +
                    matches = tree.query(eigvects2[:, :k2], k=1).flatten()
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                return matches
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def p2p_to_FM(p2p, eigvects1, eigvects2, A2=None):
         | 
| 48 | 
            +
                if A2 is not None:
         | 
| 49 | 
            +
                    if A2.shape[0] != eigvects2.shape[0]:
         | 
| 50 | 
            +
                        raise ValueError("Can't compute pseudo inverse with subsampled eigenvectors")
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    if len(A2.shape) == 1:
         | 
| 53 | 
            +
                        return eigvects2.T @ (A2[:, None] * eigvects1[p2p, :])
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    return eigvects2.T @ A2 @ eigvects1[p2p, :]
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                return scipy.linalg.lstsq(eigvects2, eigvects1[p2p, :])[0]
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            def zoomout_iteration(eigvects1, eigvects2, FM, step=1, A2=None, use_ANN=False):
         | 
| 61 | 
            +
                k2, k1 = FM.shape
         | 
| 62 | 
            +
                try:
         | 
| 63 | 
            +
                    step1, step2 = step
         | 
| 64 | 
            +
                except TypeError:
         | 
| 65 | 
            +
                    step1 = step
         | 
| 66 | 
            +
                    step2 = step
         | 
| 67 | 
            +
                new_k1, new_k2 = k1 + step1, k2 + step2
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                p2p = FM_to_p2p(FM, eigvects1, eigvects2, use_ANN=use_ANN)
         | 
| 70 | 
            +
                FM_zo = p2p_to_FM(p2p, eigvects1[:, :new_k1], eigvects2[:, :new_k2], A2=A2)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                return FM_zo
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            def zoomout_refine(eigvects1,
         | 
| 76 | 
            +
                               eigvects2,
         | 
| 77 | 
            +
                               FM,
         | 
| 78 | 
            +
                               nit=10,
         | 
| 79 | 
            +
                               step=1,
         | 
| 80 | 
            +
                               A2=None,
         | 
| 81 | 
            +
                               subsample=None,
         | 
| 82 | 
            +
                               use_ANN=False,
         | 
| 83 | 
            +
                               return_p2p=False,
         | 
| 84 | 
            +
                               verbose=False):
         | 
| 85 | 
            +
                k2_0, k1_0 = FM.shape
         | 
| 86 | 
            +
                try:
         | 
| 87 | 
            +
                    step1, step2 = step
         | 
| 88 | 
            +
                except TypeError:
         | 
| 89 | 
            +
                    step1 = step
         | 
| 90 | 
            +
                    step2 = step
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                assert k1_0 + nit*step1 <= eigvects1.shape[1], \
         | 
| 93 | 
            +
                    f"Not enough eigenvectors on source : \
         | 
| 94 | 
            +
                    {k1_0 + nit*step1} are needed when {eigvects1.shape[1]} are provided"
         | 
| 95 | 
            +
                assert k2_0 + nit*step2 <= eigvects2.shape[1], \
         | 
| 96 | 
            +
                    f"Not enough eigenvectors on target : \
         | 
| 97 | 
            +
                    {k2_0 + nit*step2} are needed when {eigvects2.shape[1]} are provided"
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                use_subsample = False
         | 
| 100 | 
            +
                if subsample is not None:
         | 
| 101 | 
            +
                    use_subsample = True
         | 
| 102 | 
            +
                    sub1, sub2 = subsample
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                FM_zo = FM.copy()
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                ANN_adventage = False
         | 
| 107 | 
            +
                iterable = range(nit) if not verbose else tqdm(range(nit))
         | 
| 108 | 
            +
                for it in iterable:
         | 
| 109 | 
            +
                    ANN_adventage = use_ANN and (FM_zo.shape[0] > 90) and (FM_zo.shape[1] > 90)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    if use_subsample:
         | 
| 112 | 
            +
                        FM_zo = zoomout_iteration(eigvects1[sub1], eigvects2[sub2], FM_zo, A2=None, step=step, use_ANN=ANN_adventage)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    else:
         | 
| 115 | 
            +
                        FM_zo = zoomout_iteration(eigvects1, eigvects2, FM_zo, A2=A2, step=step, use_ANN=ANN_adventage)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                if return_p2p:
         | 
| 118 | 
            +
                    p2p_zo = FM_to_p2p(FM_zo, eigvects1, eigvects2, use_ANN=False)
         | 
| 119 | 
            +
                    return FM_zo, p2p_zo
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                return FM_zo
         | 
    	
        utils/geometry.py
    ADDED
    
    | @@ -0,0 +1,951 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import scipy
         | 
| 2 | 
            +
            import scipy.sparse.linalg as sla
         | 
| 3 | 
            +
            # ^^^ we NEED to import scipy before torch, or it crashes :(
         | 
| 4 | 
            +
            # (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import os.path
         | 
| 7 | 
            +
            import sys
         | 
| 8 | 
            +
            import random
         | 
| 9 | 
            +
            from multiprocessing import Pool
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            import scipy.spatial
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            import sklearn.neighbors
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            import robust_laplacian
         | 
| 17 | 
            +
            import potpourri3d as pp3d
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from utils.utils_legacy import toNP, ensure_dir_exists, sparse_np_to_torch, sparse_torch_to_np
         | 
| 20 | 
            +
            from utils.descriptors import compute_hks, compute_wks
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def norm(x, highdim=False):
         | 
| 24 | 
            +
                """
         | 
| 25 | 
            +
                Computes norm of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                return torch.norm(x, dim=len(x.shape) - 1)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            def norm2(x, highdim=False):
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
                return dot(x, x)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def normalize(x, divide_eps=1e-6, highdim=False):
         | 
| 38 | 
            +
                """
         | 
| 39 | 
            +
                Computes norm^2 of an array of vectors. Given (shape,d), returns (shape) after norm along last dimension
         | 
| 40 | 
            +
                """
         | 
| 41 | 
            +
                if (len(x.shape) == 1):
         | 
| 42 | 
            +
                    raise ValueError("called normalize() on single vector of dim " + str(x.shape) + " are you sure?")
         | 
| 43 | 
            +
                if (not highdim and x.shape[-1] > 4):
         | 
| 44 | 
            +
                    raise ValueError("called normalize() with large last dimension " + str(x.shape) + " are you sure?")
         | 
| 45 | 
            +
                return x / (norm(x, highdim=highdim) + divide_eps).unsqueeze(-1)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def face_coords(verts, faces):
         | 
| 49 | 
            +
                coords = verts[faces]
         | 
| 50 | 
            +
                return coords
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def cross(vec_A, vec_B):
         | 
| 54 | 
            +
                return torch.cross(vec_A, vec_B, dim=-1)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def dot(vec_A, vec_B):
         | 
| 58 | 
            +
                return torch.sum(vec_A * vec_B, dim=-1)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            # Given (..., 3) vectors and normals, projects out any components of vecs
         | 
| 62 | 
            +
            # which lies in the direction of normals. Normals are assumed to be unit.
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def project_to_tangent(vecs, unit_normals):
         | 
| 66 | 
            +
                dots = dot(vecs, unit_normals)
         | 
| 67 | 
            +
                return vecs - unit_normals * dots.unsqueeze(-1)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            def face_area(verts, faces):
         | 
| 71 | 
            +
                coords = face_coords(verts, faces)
         | 
| 72 | 
            +
                vec_A = coords[:, 1, :] - coords[:, 0, :]
         | 
| 73 | 
            +
                vec_B = coords[:, 2, :] - coords[:, 0, :]
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                raw_normal = cross(vec_A, vec_B)
         | 
| 76 | 
            +
                return 0.5 * norm(raw_normal)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            def face_normals(verts, faces, normalized=True):
         | 
| 80 | 
            +
                coords = face_coords(verts, faces)
         | 
| 81 | 
            +
                vec_A = coords[:, 1, :] - coords[:, 0, :]
         | 
| 82 | 
            +
                vec_B = coords[:, 2, :] - coords[:, 0, :]
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                raw_normal = cross(vec_A, vec_B)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                if normalized:
         | 
| 87 | 
            +
                    return normalize(raw_normal)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                return raw_normal
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            def neighborhood_normal(points):
         | 
| 93 | 
            +
                # points: (N, K, 3) array of neighborhood psoitions
         | 
| 94 | 
            +
                # points should be centered at origin
         | 
| 95 | 
            +
                # out: (N,3) array of normals
         | 
| 96 | 
            +
                # numpy in, numpy out
         | 
| 97 | 
            +
                (u, s, vh) = np.linalg.svd(points, full_matrices=False)
         | 
| 98 | 
            +
                normal = vh[:, 2, :]
         | 
| 99 | 
            +
                return normal / np.linalg.norm(normal, axis=-1, keepdims=True)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
| 102 | 
            +
            def mesh_vertex_normals(verts, faces):
         | 
| 103 | 
            +
                # numpy in / out
         | 
| 104 | 
            +
                face_n = toNP(face_normals(torch.tensor(verts), torch.tensor(faces)))  # ugly torch <---> numpy
         | 
| 105 | 
            +
                eps = 1e-3
         | 
| 106 | 
            +
                vertex_normals = np.zeros(verts.shape)
         | 
| 107 | 
            +
                for i in range(3):
         | 
| 108 | 
            +
                    np.add.at(vertex_normals, faces[:, i], face_n)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                vertex_normals = vertex_normals / (eps + np.linalg.norm(vertex_normals, axis=-1, keepdims=True))
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                return vertex_normals
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            def vertex_normals(verts, faces, n_neighbors_cloud=30):
         | 
| 116 | 
            +
                verts_np = toNP(verts)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                if faces.numel() == 0:  # point cloud
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    _, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
         | 
| 121 | 
            +
                    neigh_points = verts_np[neigh_inds, :]
         | 
| 122 | 
            +
                    neigh_points = neigh_points - verts_np[:, np.newaxis, :]
         | 
| 123 | 
            +
                    normals = neighborhood_normal(neigh_points)
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                else:  # mesh
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    normals = mesh_vertex_normals(verts_np, toNP(faces))
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    # if any are NaN, wiggle slightly and recompute
         | 
| 130 | 
            +
                    bad_normals_mask = np.isnan(normals).any(axis=1, keepdims=True)
         | 
| 131 | 
            +
                    if bad_normals_mask.any():
         | 
| 132 | 
            +
                        bbox = np.amax(verts_np, axis=0) - np.amin(verts_np, axis=0)
         | 
| 133 | 
            +
                        scale = np.linalg.norm(bbox) * 1e-4
         | 
| 134 | 
            +
                        wiggle = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5) * scale
         | 
| 135 | 
            +
                        wiggle_verts = verts_np + bad_normals_mask * wiggle
         | 
| 136 | 
            +
                        normals = mesh_vertex_normals(wiggle_verts, toNP(faces))
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    # if still NaN assign random normals (probably means unreferenced verts in mesh)
         | 
| 139 | 
            +
                    bad_normals_mask = np.isnan(normals).any(axis=1)
         | 
| 140 | 
            +
                    if bad_normals_mask.any():
         | 
| 141 | 
            +
                        normals[bad_normals_mask, :] = (np.random.RandomState(seed=777).rand(*verts.shape) - 0.5)[bad_normals_mask, :]
         | 
| 142 | 
            +
                        normals = normals / np.linalg.norm(normals, axis=-1)[:, np.newaxis]
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                normals = torch.from_numpy(normals).to(device=verts.device, dtype=verts.dtype)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                if torch.any(torch.isnan(normals)):
         | 
| 147 | 
            +
                    raise ValueError("NaN normals :(")
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                return normals
         | 
| 150 | 
            +
             | 
| 151 | 
            +
             | 
| 152 | 
            +
            def build_tangent_frames(verts, faces, normals=None):
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                V = verts.shape[0]
         | 
| 155 | 
            +
                dtype = verts.dtype
         | 
| 156 | 
            +
                device = verts.device
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                if normals is None:
         | 
| 159 | 
            +
                    vert_normals = vertex_normals(verts, faces)  # (V,3)
         | 
| 160 | 
            +
                elif isinstance(normals, np.ndarray):
         | 
| 161 | 
            +
                    vert_normals = torch.from_numpy(normals).to(dtype=dtype, device=device)
         | 
| 162 | 
            +
                else:
         | 
| 163 | 
            +
                    vert_normals = normals
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                # = find an orthogonal basis
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                basis_cand1 = torch.tensor([1, 0, 0]).to(device=device, dtype=dtype).expand(V, -1)
         | 
| 168 | 
            +
                basis_cand2 = torch.tensor([0, 1, 0]).to(device=device, dtype=dtype).expand(V, -1)
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                basisX = torch.where((torch.abs(dot(vert_normals, basis_cand1)) < 0.9).unsqueeze(-1), basis_cand1, basis_cand2)
         | 
| 171 | 
            +
                basisX = project_to_tangent(basisX, vert_normals)
         | 
| 172 | 
            +
                basisX = normalize(basisX)
         | 
| 173 | 
            +
                basisY = cross(vert_normals, basisX)
         | 
| 174 | 
            +
                frames = torch.stack((basisX, basisY, vert_normals), dim=-2)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                if torch.any(torch.isnan(frames)):
         | 
| 177 | 
            +
                    raise ValueError("NaN coordinate frame! Must be very degenerate")
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                return frames
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
            def build_grad_point_cloud(verts, frames, n_neighbors_cloud=30):
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                verts_np = toNP(verts)
         | 
| 185 | 
            +
                frames_np = toNP(frames)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                _, neigh_inds = find_knn(verts, verts, n_neighbors_cloud, omit_diagonal=True, method='cpu_kd')
         | 
| 188 | 
            +
                neigh_points = verts_np[neigh_inds, :]
         | 
| 189 | 
            +
                neigh_vecs = neigh_points - verts_np[:, np.newaxis, :]
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                # TODO this could easily be way faster. For instance we could avoid the weird edges format and the corresponding pure-python loop via some numpy broadcasting of the same logic. The way it works right now is just to share code with the mesh version. But its low priority since its preprocessing code.
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                edge_inds_from = np.repeat(np.arange(verts.shape[0]), n_neighbors_cloud)
         | 
| 194 | 
            +
                edges = np.stack((edge_inds_from, neigh_inds.flatten()))
         | 
| 195 | 
            +
                edge_tangent_vecs = edge_tangent_vectors(verts, frames, edges)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                return build_grad(verts_np, torch.tensor(edges), edge_tangent_vecs)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
             | 
| 200 | 
            +
            def edge_tangent_vectors(verts, frames, edges):
         | 
| 201 | 
            +
                edge_vecs = verts[edges[1, :], :] - verts[edges[0, :], :]
         | 
| 202 | 
            +
                basisX = frames[edges[0, :], 0, :]
         | 
| 203 | 
            +
                basisY = frames[edges[0, :], 1, :]
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                compX = dot(edge_vecs, basisX)
         | 
| 206 | 
            +
                compY = dot(edge_vecs, basisY)
         | 
| 207 | 
            +
                edge_tangent = torch.stack((compX, compY), dim=-1)
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                return edge_tangent
         | 
| 210 | 
            +
             | 
| 211 | 
            +
             | 
| 212 | 
            +
            def build_grad(verts, edges, edge_tangent_vectors):
         | 
| 213 | 
            +
                """
         | 
| 214 | 
            +
                Build a (V, V) complex sparse matrix grad operator. Given real inputs at vertices, produces a complex (vector value) at vertices giving the gradient. All values pointwise.
         | 
| 215 | 
            +
                - edges: (2, E)
         | 
| 216 | 
            +
                """
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                edges_np = toNP(edges)
         | 
| 219 | 
            +
                edge_tangent_vectors_np = toNP(edge_tangent_vectors)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                # TODO find a way to do this in pure numpy?
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                # Build outgoing neighbor lists
         | 
| 224 | 
            +
                N = verts.shape[0]
         | 
| 225 | 
            +
                vert_edge_outgoing = [[] for i in range(N)]
         | 
| 226 | 
            +
                for iE in range(edges_np.shape[1]):
         | 
| 227 | 
            +
                    tail_ind = edges_np[0, iE]
         | 
| 228 | 
            +
                    tip_ind = edges_np[1, iE]
         | 
| 229 | 
            +
                    if tip_ind != tail_ind:
         | 
| 230 | 
            +
                        vert_edge_outgoing[tail_ind].append(iE)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                # Build local inversion matrix for each vertex
         | 
| 233 | 
            +
                row_inds = []
         | 
| 234 | 
            +
                col_inds = []
         | 
| 235 | 
            +
                data_vals = []
         | 
| 236 | 
            +
                eps_reg = 1e-5
         | 
| 237 | 
            +
                for iV in range(N):
         | 
| 238 | 
            +
                    n_neigh = len(vert_edge_outgoing[iV])
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    lhs_mat = np.zeros((n_neigh, 2))
         | 
| 241 | 
            +
                    rhs_mat = np.zeros((n_neigh, n_neigh + 1))
         | 
| 242 | 
            +
                    ind_lookup = [iV]
         | 
| 243 | 
            +
                    for i_neigh in range(n_neigh):
         | 
| 244 | 
            +
                        iE = vert_edge_outgoing[iV][i_neigh]
         | 
| 245 | 
            +
                        jV = edges_np[1, iE]
         | 
| 246 | 
            +
                        ind_lookup.append(jV)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                        edge_vec = edge_tangent_vectors[iE][:]
         | 
| 249 | 
            +
                        w_e = 1.
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                        lhs_mat[i_neigh][:] = w_e * edge_vec
         | 
| 252 | 
            +
                        rhs_mat[i_neigh][0] = w_e * (-1)
         | 
| 253 | 
            +
                        rhs_mat[i_neigh][i_neigh + 1] = w_e * 1
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                    lhs_T = lhs_mat.T
         | 
| 256 | 
            +
                    lhs_inv = np.linalg.inv(lhs_T @ lhs_mat + eps_reg * np.identity(2)) @ lhs_T
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    sol_mat = lhs_inv @ rhs_mat
         | 
| 259 | 
            +
                    sol_coefs = (sol_mat[0, :] + 1j * sol_mat[1, :]).T
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    for i_neigh in range(n_neigh + 1):
         | 
| 262 | 
            +
                        i_glob = ind_lookup[i_neigh]
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                        row_inds.append(iV)
         | 
| 265 | 
            +
                        col_inds.append(i_glob)
         | 
| 266 | 
            +
                        data_vals.append(sol_coefs[i_neigh])
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                # build the sparse matrix
         | 
| 269 | 
            +
                row_inds = np.array(row_inds)
         | 
| 270 | 
            +
                col_inds = np.array(col_inds)
         | 
| 271 | 
            +
                data_vals = np.array(data_vals)
         | 
| 272 | 
            +
                mat = scipy.sparse.coo_matrix((data_vals, (row_inds, col_inds)), shape=(N, N)).tocsc()
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                return mat
         | 
| 275 | 
            +
             | 
| 276 | 
            +
             | 
| 277 | 
            +
            def compute_operators(verts, faces, k_eig, normals=None, normalize_desc=True):
         | 
| 278 | 
            +
                """
         | 
| 279 | 
            +
                Builds spectral operators for a mesh/point cloud. Constructs mass matrix, eigenvalues/vectors for Laplacian, and gradient matrix.
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                See get_operators() for a similar routine that wraps this one with a layer of caching.
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                Torch in / torch out.
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                Arguments:
         | 
| 286 | 
            +
                  - vertices: (V,3) vertex positions
         | 
| 287 | 
            +
                  - faces: (F,3) list of triangular faces. If empty, assumed to be a point cloud.
         | 
| 288 | 
            +
                  - k_eig: number of eigenvectors to use
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                Returns:
         | 
| 291 | 
            +
                  - frames: (V,3,3) X/Y/Z coordinate frame at each vertex. Z coordinate is normal (e.g. [:,2,:] for normals)
         | 
| 292 | 
            +
                  - massvec: (V) real diagonal of lumped mass matrix
         | 
| 293 | 
            +
                  - L: (VxV) real sparse matrix of (weak) Laplacian
         | 
| 294 | 
            +
                  - evals: (k) list of eigenvalues of the Laplacian
         | 
| 295 | 
            +
                  - evecs: (V,k) list of eigenvectors of the Laplacian 
         | 
| 296 | 
            +
                  - gradX: (VxV) sparse matrix which gives X-component of gradient in the local basis at the vertex
         | 
| 297 | 
            +
                  - gradY: same as gradX but for Y-component of gradient
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                PyTorch doesn't seem to like complex sparse matrices, so we store the "real" and "imaginary" (aka X and Y) gradient matrices separately, rather than as one complex sparse matrix.
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                Note: for a generalized eigenvalue problem, the mass matrix matters! The eigenvectors are only othrthonormal with respect to the mass matrix, like v^H M v, so the mass (given as the diagonal vector massvec) needs to be used in projections, etc.
         | 
| 302 | 
            +
                """
         | 
| 303 | 
            +
                device = verts.device
         | 
| 304 | 
            +
                dtype = verts.dtype
         | 
| 305 | 
            +
                V = verts.shape[0]
         | 
| 306 | 
            +
                is_cloud = faces.numel() == 0
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                eps = 1e-8
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                verts_np = toNP(verts).astype(np.float64)
         | 
| 311 | 
            +
                faces_np = toNP(faces)
         | 
| 312 | 
            +
                frames = build_tangent_frames(verts, faces, normals=normals)
         | 
| 313 | 
            +
                frames_np = toNP(frames)
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                # Build the scalar Laplacian
         | 
| 316 | 
            +
                if is_cloud:
         | 
| 317 | 
            +
                    L, M = robust_laplacian.point_cloud_laplacian(verts_np)
         | 
| 318 | 
            +
                    massvec_np = M.diagonal()
         | 
| 319 | 
            +
                else:
         | 
| 320 | 
            +
                    # L, M = robust_laplacian.mesh_laplacian(verts_np, faces_np)
         | 
| 321 | 
            +
                    # massvec_np = M.diagonal()
         | 
| 322 | 
            +
                    L = pp3d.cotan_laplacian(verts_np, faces_np, denom_eps=1e-10)
         | 
| 323 | 
            +
                    massvec_np = pp3d.vertex_areas(verts_np, faces_np)
         | 
| 324 | 
            +
                    massvec_np += eps * np.mean(massvec_np)
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                if (np.isnan(L.data).any()):
         | 
| 327 | 
            +
                    raise RuntimeError("NaN Laplace matrix")
         | 
| 328 | 
            +
                if (np.isnan(massvec_np).any()):
         | 
| 329 | 
            +
                    raise RuntimeError("NaN mass matrix")
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                # Read off neighbors & rotations from the Laplacian
         | 
| 332 | 
            +
                L_coo = L.tocoo()
         | 
| 333 | 
            +
                inds_row = L_coo.row
         | 
| 334 | 
            +
                inds_col = L_coo.col
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                # === Compute the eigenbasis
         | 
| 337 | 
            +
                if k_eig > 0:
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                    # Prepare matrices
         | 
| 340 | 
            +
                    L_eigsh = (L + scipy.sparse.identity(L.shape[0]) * eps).tocsc()
         | 
| 341 | 
            +
                    massvec_eigsh = massvec_np
         | 
| 342 | 
            +
                    Mmat = scipy.sparse.diags(massvec_eigsh)
         | 
| 343 | 
            +
                    eigs_sigma = eps
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                    failcount = 0
         | 
| 346 | 
            +
                    while True:
         | 
| 347 | 
            +
                        try:
         | 
| 348 | 
            +
                            # We would be happy here to lower tol or maxiter since we don't need these to be super precise, but for some reason those parameters seem to have no effect
         | 
| 349 | 
            +
                            evals_np, evecs_np = sla.eigsh(L_eigsh, k=k_eig, M=Mmat, sigma=eigs_sigma)
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                            # Clip off any eigenvalues that end up slightly negative due to numerical weirdness
         | 
| 352 | 
            +
                            evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                            break
         | 
| 355 | 
            +
                        except Exception as e:
         | 
| 356 | 
            +
                            print(e)
         | 
| 357 | 
            +
                            if (failcount > 3):
         | 
| 358 | 
            +
                                raise ValueError("failed to compute eigendecomp")
         | 
| 359 | 
            +
                            failcount += 1
         | 
| 360 | 
            +
                            print("--- decomp failed; adding eps ===> count: " + str(failcount))
         | 
| 361 | 
            +
                            L_eigsh = L_eigsh + scipy.sparse.identity(L.shape[0]) * (eps * 10**failcount)
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                else:  #k_eig == 0
         | 
| 364 | 
            +
                    evals_np = np.zeros((0))
         | 
| 365 | 
            +
                    evecs_np = np.zeros((verts.shape[0], 0))
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                # == Build gradient matrices
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                # For meshes, we use the same edges as were used to build the Laplacian. For point clouds, use a whole local neighborhood
         | 
| 370 | 
            +
                if is_cloud:
         | 
| 371 | 
            +
                    grad_mat_np = build_grad_point_cloud(verts, frames)
         | 
| 372 | 
            +
                else:
         | 
| 373 | 
            +
                    edges = torch.tensor(np.stack((inds_row, inds_col), axis=0), device=device, dtype=faces.dtype)
         | 
| 374 | 
            +
                    edge_vecs = edge_tangent_vectors(verts, frames, edges)
         | 
| 375 | 
            +
                    grad_mat_np = build_grad(verts.cpu(), edges.cpu(), edge_vecs.cpu())
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                # Split complex gradient in to two real sparse mats (torch doesn't like complex sparse matrices)
         | 
| 378 | 
            +
                gradX_np = np.real(grad_mat_np)
         | 
| 379 | 
            +
                gradY_np = np.imag(grad_mat_np)
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                # === Convert back to torch
         | 
| 382 | 
            +
                massvec = torch.from_numpy(massvec_np).to(device=device, dtype=dtype)
         | 
| 383 | 
            +
                L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
         | 
| 384 | 
            +
                evals = torch.from_numpy(evals_np).to(device=device, dtype=dtype)
         | 
| 385 | 
            +
                evecs = torch.from_numpy(evecs_np).to(device=device, dtype=dtype)
         | 
| 386 | 
            +
                gradX = sparse_np_to_torch(gradX_np).to(device=device, dtype=dtype)
         | 
| 387 | 
            +
                gradY = sparse_np_to_torch(gradY_np).to(device=device, dtype=dtype)
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                hks = torch.from_numpy(compute_hks(evecs_np, evals_np, massvec_np, n_descr=128, subsample_step=1, n_eig=128, normalize=normalize_desc)).to(device=device, dtype=dtype)
         | 
| 390 | 
            +
                wks = torch.from_numpy(compute_wks(evecs_np, evals_np, massvec_np, n_descr=128, subsample_step=1, n_eig=128, normalize=normalize_desc)).to(device=device, dtype=dtype)
         | 
| 391 | 
            +
             | 
| 392 | 
            +
             | 
| 393 | 
            +
                return frames, massvec, L, evals, evecs, gradX, gradY, hks, wks
         | 
| 394 | 
            +
             | 
| 395 | 
            +
             | 
| 396 | 
            +
            def get_all_operators(verts_list, faces_list, k_eig, op_cache_dir=None, normals=None):
         | 
| 397 | 
            +
                N = len(verts_list)
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                frames = [None] * N
         | 
| 400 | 
            +
                massvec = [None] * N
         | 
| 401 | 
            +
                L = [None] * N
         | 
| 402 | 
            +
                evals = [None] * N
         | 
| 403 | 
            +
                evecs = [None] * N
         | 
| 404 | 
            +
                gradX = [None] * N
         | 
| 405 | 
            +
                gradY = [None] * N
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                inds = [i for i in range(N)]
         | 
| 408 | 
            +
                # process in random order
         | 
| 409 | 
            +
                # random.shuffle(inds)
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                for num, i in enumerate(inds):
         | 
| 412 | 
            +
                    print("get_all_operators() processing {} / {} {:.3f}%".format(num, N, num / N * 100))
         | 
| 413 | 
            +
                    if normals is None:
         | 
| 414 | 
            +
                        outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir)
         | 
| 415 | 
            +
                    else:
         | 
| 416 | 
            +
                        outputs = get_operators(verts_list[i], faces_list[i], k_eig, op_cache_dir, normals=normals[i])
         | 
| 417 | 
            +
                    frames[i] = outputs[0]
         | 
| 418 | 
            +
                    massvec[i] = outputs[1]
         | 
| 419 | 
            +
                    L[i] = outputs[2]
         | 
| 420 | 
            +
                    evals[i] = outputs[3]
         | 
| 421 | 
            +
                    evecs[i] = outputs[4]
         | 
| 422 | 
            +
                    gradX[i] = outputs[5]
         | 
| 423 | 
            +
                    gradY[i] = outputs[6]
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                return frames, massvec, L, evals, evecs, gradX, gradY
         | 
| 426 | 
            +
             | 
| 427 | 
            +
             | 
| 428 | 
            +
            def get_operators(verts, faces, k_eig=128, cache_path=None, normals=None, overwrite_cache=False):
         | 
| 429 | 
            +
                """
         | 
| 430 | 
            +
                See documentation for compute_operators(). This essentailly just wraps a call to compute_operators, using a cache if possible.
         | 
| 431 | 
            +
                All arrays are always computed using double precision for stability, then truncated to single precision floats to store on disk, and finally returned as a tensor with dtype/device matching the `verts` input.
         | 
| 432 | 
            +
                """
         | 
| 433 | 
            +
                if type(verts) == torch.Tensor:
         | 
| 434 | 
            +
                    device = verts.device
         | 
| 435 | 
            +
                    dtype = verts.dtype
         | 
| 436 | 
            +
                    verts_np = toNP(verts)
         | 
| 437 | 
            +
                else:
         | 
| 438 | 
            +
                    device = "cpu"
         | 
| 439 | 
            +
                    dtype = torch.float32
         | 
| 440 | 
            +
                    verts_np = verts.copy()
         | 
| 441 | 
            +
                    verts = torch.from_numpy(verts).float()
         | 
| 442 | 
            +
                if type(faces) == torch.Tensor:
         | 
| 443 | 
            +
                    faces_np = toNP(faces)
         | 
| 444 | 
            +
                else:
         | 
| 445 | 
            +
                    faces_np = faces.copy()
         | 
| 446 | 
            +
                    faces = torch.from_numpy(faces).to(device, dtype=torch.int64)
         | 
| 447 | 
            +
                is_cloud = faces.numel() == 0
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                if (np.isnan(verts_np).any()):
         | 
| 450 | 
            +
                    raise RuntimeError("tried to construct operators from NaN verts")
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                # Check the cache directory
         | 
| 453 | 
            +
                # Note 1: Collisions here are exceptionally unlikely, so we could probably just use the hash...
         | 
| 454 | 
            +
                #         but for good measure we check values nonetheless.
         | 
| 455 | 
            +
                # Note 2: There is a small possibility for race conditions to lead to bucket gaps or duplicate
         | 
| 456 | 
            +
                #         entries in this cache. The good news is that that is totally fine, and at most slightly
         | 
| 457 | 
            +
                #         slows performance with rare extra cache misses.
         | 
| 458 | 
            +
                found = False
         | 
| 459 | 
            +
                
         | 
| 460 | 
            +
                if cache_path is not None:
         | 
| 461 | 
            +
                    op_cache_dir = os.path.dirname(cache_path)
         | 
| 462 | 
            +
                    ensure_dir_exists(op_cache_dir)
         | 
| 463 | 
            +
                    # print("Building operators for input with hash: " + hash_key_str)
         | 
| 464 | 
            +
             | 
| 465 | 
            +
                    # Search through buckets with matching hashes.  When the loop exits, this
         | 
| 466 | 
            +
                    # is the bucket index of the file we should write to.
         | 
| 467 | 
            +
                    i_cache_search = 0
         | 
| 468 | 
            +
                    while True:
         | 
| 469 | 
            +
                        try:
         | 
| 470 | 
            +
                            # print('loading path: ' + str(search_path))
         | 
| 471 | 
            +
                            npzfile = np.load(cache_path, allow_pickle=True)
         | 
| 472 | 
            +
                            cache_verts = npzfile["verts"]
         | 
| 473 | 
            +
                            cache_faces = npzfile["faces"]
         | 
| 474 | 
            +
                            cache_k_eig = npzfile["k_eig"].item()
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                            # If the cache doesn't match, keep looking
         | 
| 477 | 
            +
                            if (not np.array_equal(verts, cache_verts)) or (not np.array_equal(faces, cache_faces)):
         | 
| 478 | 
            +
                                i_cache_search += 1
         | 
| 479 | 
            +
                                print("hash collision! searching next.")
         | 
| 480 | 
            +
                                continue
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                            # print("  cache hit!")
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                            # If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
         | 
| 485 | 
            +
                            # entry below more eigenvalues
         | 
| 486 | 
            +
                            if overwrite_cache:
         | 
| 487 | 
            +
                                print("  overwriting cache by request")
         | 
| 488 | 
            +
                                os.remove(cache_path)
         | 
| 489 | 
            +
                                break
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                            if cache_k_eig < k_eig:
         | 
| 492 | 
            +
                                print("  overwriting cache --- not enough eigenvalues")
         | 
| 493 | 
            +
                                os.remove(cache_path)
         | 
| 494 | 
            +
                                break
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                            if "L_data" not in npzfile:
         | 
| 497 | 
            +
                                print("  overwriting cache --- entries are absent")
         | 
| 498 | 
            +
                                os.remove(cache_path)
         | 
| 499 | 
            +
                                break
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                            def read_sp_mat(prefix):
         | 
| 502 | 
            +
                                data = npzfile[prefix + "_data"]
         | 
| 503 | 
            +
                                indices = npzfile[prefix + "_indices"]
         | 
| 504 | 
            +
                                indptr = npzfile[prefix + "_indptr"]
         | 
| 505 | 
            +
                                shape = npzfile[prefix + "_shape"]
         | 
| 506 | 
            +
                                mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
         | 
| 507 | 
            +
                                return mat
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                            # This entry matches! Return it.
         | 
| 510 | 
            +
                            frames = npzfile["frames"]
         | 
| 511 | 
            +
                            mass = npzfile["mass"]
         | 
| 512 | 
            +
                            L = read_sp_mat("L")
         | 
| 513 | 
            +
                            evals = npzfile["evals"][:k_eig]
         | 
| 514 | 
            +
                            evecs = npzfile["evecs"][:, :k_eig]
         | 
| 515 | 
            +
                            gradX = read_sp_mat("gradX")
         | 
| 516 | 
            +
                            gradY = read_sp_mat("gradY")
         | 
| 517 | 
            +
                            frames = npzfile["hks"]
         | 
| 518 | 
            +
                            mass = npzfile["wks"]
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                            frames = torch.from_numpy(frames).to(device=device, dtype=dtype)
         | 
| 521 | 
            +
                            mass = torch.from_numpy(mass).to(device=device, dtype=dtype)
         | 
| 522 | 
            +
                            L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
         | 
| 523 | 
            +
                            evals = torch.from_numpy(evals).to(device=device, dtype=dtype)
         | 
| 524 | 
            +
                            evecs = torch.from_numpy(evecs).to(device=device, dtype=dtype)
         | 
| 525 | 
            +
                            gradX = sparse_np_to_torch(gradX).to(device=device, dtype=dtype)
         | 
| 526 | 
            +
                            gradY = sparse_np_to_torch(gradY).to(device=device, dtype=dtype)
         | 
| 527 | 
            +
                            hks = torch.from_numpy(hks).to(device=device, dtype=dtype)
         | 
| 528 | 
            +
                            wks = torch.from_numpy(wks).to(device=device, dtype=dtype)
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                            found = True
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                            break
         | 
| 533 | 
            +
             | 
| 534 | 
            +
                        except FileNotFoundError:
         | 
| 535 | 
            +
                            print("  cache miss -- constructing operators")
         | 
| 536 | 
            +
                            break
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                        except Exception as E:
         | 
| 539 | 
            +
                            print("unexpected error loading file: " + str(E))
         | 
| 540 | 
            +
                            print("-- constructing operators")
         | 
| 541 | 
            +
                            break
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                if not found:
         | 
| 544 | 
            +
             | 
| 545 | 
            +
                    # No matching entry found; recompute.
         | 
| 546 | 
            +
                    frames, mass, L, evals, evecs, gradX, gradY, hks, wks = compute_operators(verts, faces, k_eig, normals=normals)
         | 
| 547 | 
            +
             | 
| 548 | 
            +
                    dtype_np = np.float32
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                    # Store it in the cache
         | 
| 551 | 
            +
                    if op_cache_dir is not None:
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                        L_np = sparse_torch_to_np(L).astype(dtype_np)
         | 
| 554 | 
            +
                        gradX_np = sparse_torch_to_np(gradX).astype(dtype_np)
         | 
| 555 | 
            +
                        gradY_np = sparse_torch_to_np(gradY).astype(dtype_np)
         | 
| 556 | 
            +
             | 
| 557 | 
            +
                        np.savez(
         | 
| 558 | 
            +
                            cache_path,
         | 
| 559 | 
            +
                            verts=verts_np.astype(dtype_np),
         | 
| 560 | 
            +
                            frames=toNP(frames).astype(dtype_np),
         | 
| 561 | 
            +
                            faces=faces_np,
         | 
| 562 | 
            +
                            k_eig=k_eig,
         | 
| 563 | 
            +
                            mass=toNP(mass).astype(dtype_np),
         | 
| 564 | 
            +
                            L_data=L_np.data.astype(dtype_np),
         | 
| 565 | 
            +
                            L_indices=L_np.indices,
         | 
| 566 | 
            +
                            L_indptr=L_np.indptr,
         | 
| 567 | 
            +
                            L_shape=L_np.shape,
         | 
| 568 | 
            +
                            evals=toNP(evals).astype(dtype_np),
         | 
| 569 | 
            +
                            evecs=toNP(evecs).astype(dtype_np),
         | 
| 570 | 
            +
                            gradX_data=gradX_np.data.astype(dtype_np),
         | 
| 571 | 
            +
                            gradX_indices=gradX_np.indices,
         | 
| 572 | 
            +
                            gradX_indptr=gradX_np.indptr,
         | 
| 573 | 
            +
                            gradX_shape=gradX_np.shape,
         | 
| 574 | 
            +
                            gradY_data=gradY_np.data.astype(dtype_np),
         | 
| 575 | 
            +
                            gradY_indices=gradY_np.indices,
         | 
| 576 | 
            +
                            gradY_indptr=gradY_np.indptr,
         | 
| 577 | 
            +
                            gradY_shape=gradY_np.shape,
         | 
| 578 | 
            +
                            hks=hks,
         | 
| 579 | 
            +
                            wks=wks
         | 
| 580 | 
            +
                        )
         | 
| 581 | 
            +
             | 
| 582 | 
            +
                return frames, mass, L, evals, evecs, gradX, gradY, hks, wks
         | 
| 583 | 
            +
             | 
| 584 | 
            +
             | 
| 585 | 
            +
            def load_operators(filepath):
         | 
| 586 | 
            +
                npzfile = np.load(filepath, allow_pickle=False)
         | 
| 587 | 
            +
             | 
| 588 | 
            +
                def read_sp_mat(prefix):
         | 
| 589 | 
            +
                    data = npzfile[prefix + '_data']
         | 
| 590 | 
            +
                    indices = npzfile[prefix + '_indices']
         | 
| 591 | 
            +
                    indptr = npzfile[prefix + '_indptr']
         | 
| 592 | 
            +
                    shape = npzfile[prefix + '_shape']
         | 
| 593 | 
            +
                    mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
         | 
| 594 | 
            +
                    return mat
         | 
| 595 | 
            +
                if 'verts' in npzfile:
         | 
| 596 | 
            +
                    keyverts = 'verts'
         | 
| 597 | 
            +
                else:
         | 
| 598 | 
            +
                    keyverts = 'vertices'
         | 
| 599 | 
            +
                return dict(
         | 
| 600 | 
            +
                    vertices=npzfile[keyverts],
         | 
| 601 | 
            +
                    faces=npzfile['faces'],
         | 
| 602 | 
            +
                    frames=npzfile['frames'],
         | 
| 603 | 
            +
                    mass=npzfile['mass'],
         | 
| 604 | 
            +
                    L=read_sp_mat('L'),
         | 
| 605 | 
            +
                    evals=npzfile['evals'],
         | 
| 606 | 
            +
                    evecs=npzfile['evecs'],
         | 
| 607 | 
            +
                    gradX=read_sp_mat('gradX'),
         | 
| 608 | 
            +
                    gradY=read_sp_mat('gradY'),
         | 
| 609 | 
            +
                    hks=npzfile['hks'],
         | 
| 610 | 
            +
                    wks=npzfile['wks'], 
         | 
| 611 | 
            +
                )
         | 
| 612 | 
            +
             | 
| 613 | 
            +
             | 
| 614 | 
            +
            def compute_operators_small(verts, faces, k_eig):
         | 
| 615 | 
            +
                """
         | 
| 616 | 
            +
                Builds spectral operators for a mesh/point cloud. Constructs mass matrix, eigenvalues/vectors for Laplacian, and gradient matrix.
         | 
| 617 | 
            +
             | 
| 618 | 
            +
                See get_operators() for a similar routine that wraps this one with a layer of caching.
         | 
| 619 | 
            +
             | 
| 620 | 
            +
                Torch in / torch out.
         | 
| 621 | 
            +
             | 
| 622 | 
            +
                Arguments:
         | 
| 623 | 
            +
                  - vertices: (V,3) vertex positions
         | 
| 624 | 
            +
                  - faces: (F,3) list of triangular faces. If empty, assumed to be a point cloud.
         | 
| 625 | 
            +
                  - k_eig: number of eigenvectors to use
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                Returns:
         | 
| 628 | 
            +
                  - massvec: (V) real diagonal of lumped mass matrix
         | 
| 629 | 
            +
                  - L: (VxV) real sparse matrix of (weak) Laplacian
         | 
| 630 | 
            +
                  - evals: (k) list of eigenvalues of the Laplacian
         | 
| 631 | 
            +
                  - evecs: (V,k) list of eigenvectors of the Laplacian 
         | 
| 632 | 
            +
                  - gradX: (VxV) sparse matrix which gives X-component of gradient in the local basis at the vertex
         | 
| 633 | 
            +
                  - gradY: same as gradX but for Y-component of gradient
         | 
| 634 | 
            +
             | 
| 635 | 
            +
                PyTorch doesn't seem to like complex sparse matrices, so we store the "real" and "imaginary" (aka X and Y) gradient matrices separately, rather than as one complex sparse matrix.
         | 
| 636 | 
            +
             | 
| 637 | 
            +
                Note: for a generalized eigenvalue problem, the mass matrix matters! The eigenvectors are only othrthonormal with respect to the mass matrix, like v^H M v, so the mass (given as the diagonal vector massvec) needs to be used in projections, etc.
         | 
| 638 | 
            +
                """
         | 
| 639 | 
            +
             | 
| 640 | 
            +
                device = verts.device
         | 
| 641 | 
            +
                dtype = verts.dtype
         | 
| 642 | 
            +
                V = verts.shape[0]
         | 
| 643 | 
            +
                is_cloud = faces.numel() == 0
         | 
| 644 | 
            +
             | 
| 645 | 
            +
                eps = 1e-8
         | 
| 646 | 
            +
             | 
| 647 | 
            +
                verts_np = toNP(verts).astype(np.float64)
         | 
| 648 | 
            +
                faces_np = toNP(faces)
         | 
| 649 | 
            +
                # Build the scalar Laplacian
         | 
| 650 | 
            +
                if is_cloud:
         | 
| 651 | 
            +
                    L, M = robust_laplacian.point_cloud_laplacian(verts_np)
         | 
| 652 | 
            +
                    massvec_np = M.diagonal()
         | 
| 653 | 
            +
                else:
         | 
| 654 | 
            +
                    # L, M = robust_laplacian.mesh_laplacian(verts_np, faces_np)
         | 
| 655 | 
            +
                    # massvec_np = M.diagonal()
         | 
| 656 | 
            +
                    L = pp3d.cotan_laplacian(verts_np, faces_np, denom_eps=1e-10)
         | 
| 657 | 
            +
                    massvec_np = pp3d.vertex_areas(verts_np, faces_np)
         | 
| 658 | 
            +
                    massvec_np += eps * np.mean(massvec_np)
         | 
| 659 | 
            +
             | 
| 660 | 
            +
                if (np.isnan(L.data).any()):
         | 
| 661 | 
            +
                    raise RuntimeError("NaN Laplace matrix")
         | 
| 662 | 
            +
                if (np.isnan(massvec_np).any()):
         | 
| 663 | 
            +
                    raise RuntimeError("NaN mass matrix")
         | 
| 664 | 
            +
             | 
| 665 | 
            +
                # Read off neighbors & rotations from the Laplacian
         | 
| 666 | 
            +
                L_coo = L.tocoo()
         | 
| 667 | 
            +
                inds_row = L_coo.row
         | 
| 668 | 
            +
                inds_col = L_coo.col
         | 
| 669 | 
            +
             | 
| 670 | 
            +
                # === Compute the eigenbasis
         | 
| 671 | 
            +
                if k_eig > 0:
         | 
| 672 | 
            +
             | 
| 673 | 
            +
                    # Prepare matrices
         | 
| 674 | 
            +
                    L_eigsh = (L + scipy.sparse.identity(L.shape[0]) * eps).tocsc()
         | 
| 675 | 
            +
                    massvec_eigsh = massvec_np
         | 
| 676 | 
            +
                    Mmat = scipy.sparse.diags(massvec_eigsh)
         | 
| 677 | 
            +
                    eigs_sigma = eps
         | 
| 678 | 
            +
             | 
| 679 | 
            +
                    failcount = 0
         | 
| 680 | 
            +
                    while True:
         | 
| 681 | 
            +
                        try:
         | 
| 682 | 
            +
                            # We would be happy here to lower tol or maxiter since we don't need these to be super precise, but for some reason those parameters seem to have no effect
         | 
| 683 | 
            +
                            evals_np, evecs_np = sla.eigsh(L_eigsh, k=k_eig, M=Mmat, sigma=eigs_sigma)
         | 
| 684 | 
            +
             | 
| 685 | 
            +
                            # Clip off any eigenvalues that end up slightly negative due to numerical weirdness
         | 
| 686 | 
            +
                            evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))
         | 
| 687 | 
            +
             | 
| 688 | 
            +
                            break
         | 
| 689 | 
            +
                        except Exception as e:
         | 
| 690 | 
            +
                            print(e)
         | 
| 691 | 
            +
                            if (failcount > 3):
         | 
| 692 | 
            +
                                raise ValueError("failed to compute eigendecomp")
         | 
| 693 | 
            +
                            failcount += 1
         | 
| 694 | 
            +
                            print("--- decomp failed; adding eps ===> count: " + str(failcount))
         | 
| 695 | 
            +
                            L_eigsh = L_eigsh + scipy.sparse.identity(L.shape[0]) * (eps * 10**failcount)
         | 
| 696 | 
            +
             | 
| 697 | 
            +
                else:  #k_eig == 0
         | 
| 698 | 
            +
                    evals_np = np.zeros((0))
         | 
| 699 | 
            +
                    evecs_np = np.zeros((verts.shape[0], 0))
         | 
| 700 | 
            +
                # Split complex gradient in to two real sparse mats (torch doesn't like complex sparse matrices)
         | 
| 701 | 
            +
             | 
| 702 | 
            +
                # === Convert back to torch
         | 
| 703 | 
            +
                massvec = torch.from_numpy(massvec_np).to(device=device, dtype=dtype)
         | 
| 704 | 
            +
                L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
         | 
| 705 | 
            +
                evals = torch.from_numpy(evals_np).to(device=device, dtype=dtype)
         | 
| 706 | 
            +
                evecs = torch.from_numpy(evecs_np).to(device=device, dtype=dtype)
         | 
| 707 | 
            +
             | 
| 708 | 
            +
                return massvec, L, evals, evecs
         | 
| 709 | 
            +
             | 
| 710 | 
            +
             | 
| 711 | 
            +
            def get_operators_small(verts, faces, k_eig=128, cache_path=None, overwrite_cache=False):
         | 
| 712 | 
            +
                """
         | 
| 713 | 
            +
                See documentation for compute_operators(). This essentailly just wraps a call to compute_operators, using a cache if possible.
         | 
| 714 | 
            +
                All arrays are always computed using double precision for stability, then truncated to single precision floats to store on disk, and finally returned as a tensor with dtype/device matching the `verts` input.
         | 
| 715 | 
            +
                """
         | 
| 716 | 
            +
                if type(verts) == torch.Tensor:
         | 
| 717 | 
            +
                    device = verts.device
         | 
| 718 | 
            +
                    dtype = verts.dtype
         | 
| 719 | 
            +
                    verts_np = toNP(verts)
         | 
| 720 | 
            +
                else:
         | 
| 721 | 
            +
                    device = "cpu"
         | 
| 722 | 
            +
                    dtype = torch.float32
         | 
| 723 | 
            +
                    verts_np = verts.copy()
         | 
| 724 | 
            +
                    verts = torch.from_numpy(verts).float()
         | 
| 725 | 
            +
                if type(faces) == torch.Tensor:
         | 
| 726 | 
            +
                    faces_np = toNP(faces)
         | 
| 727 | 
            +
                else:
         | 
| 728 | 
            +
                    faces_np = faces.copy()
         | 
| 729 | 
            +
                    faces = torch.from_numpy(faces).to(device, dtype=torch.int64)
         | 
| 730 | 
            +
                is_cloud = faces.numel() == 0
         | 
| 731 | 
            +
             | 
| 732 | 
            +
                if (np.isnan(verts_np).any()):
         | 
| 733 | 
            +
                    raise RuntimeError("tried to construct operators from NaN verts")
         | 
| 734 | 
            +
             | 
| 735 | 
            +
                # Check the cache directory
         | 
| 736 | 
            +
                # Note 1: Collisions here are exceptionally unlikely, so we could probably just use the hash...
         | 
| 737 | 
            +
                #         but for good measure we check values nonetheless.
         | 
| 738 | 
            +
                # Note 2: There is a small possibility for race conditions to lead to bucket gaps or duplicate
         | 
| 739 | 
            +
                #         entries in this cache. The good news is that that is totally fine, and at most slightly
         | 
| 740 | 
            +
                #         slows performance with rare extra cache misses.
         | 
| 741 | 
            +
                found = False
         | 
| 742 | 
            +
                
         | 
| 743 | 
            +
                if cache_path is not None:
         | 
| 744 | 
            +
                    op_cache_dir = os.path.dirname(cache_path)
         | 
| 745 | 
            +
                    ensure_dir_exists(op_cache_dir)
         | 
| 746 | 
            +
                    # print("Building operators for input with hash: " + hash_key_str)
         | 
| 747 | 
            +
             | 
| 748 | 
            +
                    # Search through buckets with matching hashes.  When the loop exits, this
         | 
| 749 | 
            +
                    # is the bucket index of the file we should write to.
         | 
| 750 | 
            +
                    i_cache_search = 0
         | 
| 751 | 
            +
                    while True:
         | 
| 752 | 
            +
                        try:
         | 
| 753 | 
            +
                            # print('loading path: ' + str(search_path))
         | 
| 754 | 
            +
                            npzfile = np.load(cache_path, allow_pickle=True)
         | 
| 755 | 
            +
                            cache_verts = npzfile["verts"]
         | 
| 756 | 
            +
                            cache_faces = npzfile["faces"]
         | 
| 757 | 
            +
                            cache_k_eig = npzfile["k_eig"].item()
         | 
| 758 | 
            +
             | 
| 759 | 
            +
                            # If the cache doesn't match, keep looking
         | 
| 760 | 
            +
                            if (not np.array_equal(verts, cache_verts)) or (not np.array_equal(faces, cache_faces)):
         | 
| 761 | 
            +
                                i_cache_search += 1
         | 
| 762 | 
            +
                                print("hash collision! searching next.")
         | 
| 763 | 
            +
                                continue
         | 
| 764 | 
            +
             | 
| 765 | 
            +
                            # print("  cache hit!")
         | 
| 766 | 
            +
             | 
| 767 | 
            +
                            # If we're overwriting, or there aren't enough eigenvalues, just delete it; we'll create a new
         | 
| 768 | 
            +
                            # entry below more eigenvalues
         | 
| 769 | 
            +
                            if overwrite_cache:
         | 
| 770 | 
            +
                                print("  overwriting cache by request")
         | 
| 771 | 
            +
                                os.remove(cache_path)
         | 
| 772 | 
            +
                                break
         | 
| 773 | 
            +
             | 
| 774 | 
            +
                            if cache_k_eig < k_eig:
         | 
| 775 | 
            +
                                print("  overwriting cache --- not enough eigenvalues")
         | 
| 776 | 
            +
                                os.remove(cache_path)
         | 
| 777 | 
            +
                                break
         | 
| 778 | 
            +
             | 
| 779 | 
            +
                            if "L_data" not in npzfile:
         | 
| 780 | 
            +
                                print("  overwriting cache --- entries are absent")
         | 
| 781 | 
            +
                                os.remove(cache_path)
         | 
| 782 | 
            +
                                break
         | 
| 783 | 
            +
             | 
| 784 | 
            +
                            def read_sp_mat(prefix):
         | 
| 785 | 
            +
                                data = npzfile[prefix + "_data"]
         | 
| 786 | 
            +
                                indices = npzfile[prefix + "_indices"]
         | 
| 787 | 
            +
                                indptr = npzfile[prefix + "_indptr"]
         | 
| 788 | 
            +
                                shape = npzfile[prefix + "_shape"]
         | 
| 789 | 
            +
                                mat = scipy.sparse.csc_matrix((data, indices, indptr), shape=shape)
         | 
| 790 | 
            +
                                return mat
         | 
| 791 | 
            +
             | 
| 792 | 
            +
                            # This entry matches! Return it.
         | 
| 793 | 
            +
                            mass = npzfile["mass"]
         | 
| 794 | 
            +
                            L = read_sp_mat("L")
         | 
| 795 | 
            +
                            evals = npzfile["evals"][:k_eig]
         | 
| 796 | 
            +
                            evecs = npzfile["evecs"][:, :k_eig]
         | 
| 797 | 
            +
                            mass = torch.from_numpy(mass).to(device=device, dtype=dtype)
         | 
| 798 | 
            +
                            L = sparse_np_to_torch(L).to(device=device, dtype=dtype)
         | 
| 799 | 
            +
                            evals = torch.from_numpy(evals).to(device=device, dtype=dtype)
         | 
| 800 | 
            +
                            evecs = torch.from_numpy(evecs).to(device=device, dtype=dtype)
         | 
| 801 | 
            +
             | 
| 802 | 
            +
                            found = True
         | 
| 803 | 
            +
             | 
| 804 | 
            +
                            break
         | 
| 805 | 
            +
             | 
| 806 | 
            +
                        except FileNotFoundError:
         | 
| 807 | 
            +
                            print(cache_path)
         | 
| 808 | 
            +
                            print("  cache miss -- constructing operators")
         | 
| 809 | 
            +
                            break
         | 
| 810 | 
            +
             | 
| 811 | 
            +
                        except Exception as E:
         | 
| 812 | 
            +
                            print("unexpected error loading file: " + str(E))
         | 
| 813 | 
            +
                            print("-- constructing operators")
         | 
| 814 | 
            +
                            break
         | 
| 815 | 
            +
             | 
| 816 | 
            +
                if not found:
         | 
| 817 | 
            +
             | 
| 818 | 
            +
                    # No matching entry found; recompute.
         | 
| 819 | 
            +
                    mass, L, evals, evecs, = compute_operators_small(verts, faces, k_eig)
         | 
| 820 | 
            +
             | 
| 821 | 
            +
                    dtype_np = np.float32
         | 
| 822 | 
            +
             | 
| 823 | 
            +
                    # Store it in the cache
         | 
| 824 | 
            +
                    if op_cache_dir is not None:
         | 
| 825 | 
            +
             | 
| 826 | 
            +
                        L_np = sparse_torch_to_np(L).astype(dtype_np)
         | 
| 827 | 
            +
             | 
| 828 | 
            +
                        np.savez(
         | 
| 829 | 
            +
                            cache_path,
         | 
| 830 | 
            +
                            verts=verts_np.astype(dtype_np),
         | 
| 831 | 
            +
                            faces=faces_np,
         | 
| 832 | 
            +
                            k_eig=k_eig,
         | 
| 833 | 
            +
                            mass=toNP(mass).astype(dtype_np),
         | 
| 834 | 
            +
                            L_data=L_np.data.astype(dtype_np),
         | 
| 835 | 
            +
                            L_indices=L_np.indices,
         | 
| 836 | 
            +
                            L_indptr=L_np.indptr,
         | 
| 837 | 
            +
                            L_shape=L_np.shape,
         | 
| 838 | 
            +
                            evals=toNP(evals).astype(dtype_np),
         | 
| 839 | 
            +
                            evecs=toNP(evecs).astype(dtype_np),
         | 
| 840 | 
            +
                        )
         | 
| 841 | 
            +
             | 
| 842 | 
            +
                return mass, L, evals, evecs
         | 
| 843 | 
            +
             | 
| 844 | 
            +
             | 
| 845 | 
            +
            def to_basis(values, basis, massvec):
         | 
| 846 | 
            +
                """
         | 
| 847 | 
            +
                Transform data in to an orthonormal basis (where orthonormal is wrt to massvec)
         | 
| 848 | 
            +
                Inputs:
         | 
| 849 | 
            +
                  - values: (B,V,D)
         | 
| 850 | 
            +
                  - basis: (B,V,K)
         | 
| 851 | 
            +
                  - massvec: (B,V)
         | 
| 852 | 
            +
                Outputs:
         | 
| 853 | 
            +
                  - (B,K,D) transformed values
         | 
| 854 | 
            +
                """
         | 
| 855 | 
            +
                basisT = basis.transpose(-2, -1)
         | 
| 856 | 
            +
                return torch.matmul(basisT, values * massvec.unsqueeze(-1))
         | 
| 857 | 
            +
             | 
| 858 | 
            +
             | 
| 859 | 
            +
             | 
| 860 | 
            +
            def normalize_positions(pos, faces=None, method='mean', scale_method='max_rad'):
         | 
| 861 | 
            +
                # center and unit-scale positions
         | 
| 862 | 
            +
             | 
| 863 | 
            +
                if method == 'mean':
         | 
| 864 | 
            +
                    # center using the average point position
         | 
| 865 | 
            +
                    pos = (pos - torch.mean(pos, dim=-2, keepdim=True))
         | 
| 866 | 
            +
                elif method == 'bbox':
         | 
| 867 | 
            +
                    # center via the middle of the axis-aligned bounding box
         | 
| 868 | 
            +
                    bbox_min = torch.min(pos, dim=-2).values
         | 
| 869 | 
            +
                    bbox_max = torch.max(pos, dim=-2).values
         | 
| 870 | 
            +
                    center = (bbox_max + bbox_min) / 2.
         | 
| 871 | 
            +
                    pos -= center.unsqueeze(-2)
         | 
| 872 | 
            +
                else:
         | 
| 873 | 
            +
                    raise ValueError("unrecognized method")
         | 
| 874 | 
            +
             | 
| 875 | 
            +
                if scale_method == 'max_rad':
         | 
| 876 | 
            +
                    scale = torch.max(norm(pos), dim=-1, keepdim=True).values.unsqueeze(-1)
         | 
| 877 | 
            +
                    pos = pos / scale
         | 
| 878 | 
            +
                elif scale_method == 'area':
         | 
| 879 | 
            +
                    if faces is None:
         | 
| 880 | 
            +
                        raise ValueError("must pass faces for area normalization")
         | 
| 881 | 
            +
                    coords = pos[faces]
         | 
| 882 | 
            +
                    vec_A = coords[:, 1, :] - coords[:, 0, :]
         | 
| 883 | 
            +
                    vec_B = coords[:, 2, :] - coords[:, 0, :]
         | 
| 884 | 
            +
                    face_areas = torch.norm(torch.cross(vec_A, vec_B, dim=-1), dim=1) * 0.5
         | 
| 885 | 
            +
                    total_area = torch.sum(face_areas)
         | 
| 886 | 
            +
                    scale = (1. / torch.sqrt(total_area))
         | 
| 887 | 
            +
                    pos = pos * scale
         | 
| 888 | 
            +
                else:
         | 
| 889 | 
            +
                    raise ValueError("unrecognized scale method")
         | 
| 890 | 
            +
                return pos
         | 
| 891 | 
            +
             | 
| 892 | 
            +
             | 
| 893 | 
            +
            # Finds the k nearest neighbors of source on target.
         | 
| 894 | 
            +
            # Return is two tensors (distances, indices). Returned points will be sorted in increasing order of distance.
         | 
| 895 | 
            +
            def find_knn(points_source, points_target, k, largest=False, omit_diagonal=False, method='brute'):
         | 
| 896 | 
            +
             | 
| 897 | 
            +
                if omit_diagonal and points_source.shape[0] != points_target.shape[0]:
         | 
| 898 | 
            +
                    raise ValueError("omit_diagonal can only be used when source and target are same shape")
         | 
| 899 | 
            +
             | 
| 900 | 
            +
                if method != 'cpu_kd' and points_source.shape[0] * points_target.shape[0] > 1e8:
         | 
| 901 | 
            +
                    method = 'cpu_kd'
         | 
| 902 | 
            +
                    print("switching to cpu_kd knn")
         | 
| 903 | 
            +
             | 
| 904 | 
            +
                if method == 'brute':
         | 
| 905 | 
            +
             | 
| 906 | 
            +
                    # Expand so both are NxMx3 tensor
         | 
| 907 | 
            +
                    points_source_expand = points_source.unsqueeze(1)
         | 
| 908 | 
            +
                    points_source_expand = points_source_expand.expand(-1, points_target.shape[0], -1)
         | 
| 909 | 
            +
                    points_target_expand = points_target.unsqueeze(0)
         | 
| 910 | 
            +
                    points_target_expand = points_target_expand.expand(points_source.shape[0], -1, -1)
         | 
| 911 | 
            +
             | 
| 912 | 
            +
                    diff_mat = points_source_expand - points_target_expand
         | 
| 913 | 
            +
                    dist_mat = norm(diff_mat)
         | 
| 914 | 
            +
             | 
| 915 | 
            +
                    if omit_diagonal:
         | 
| 916 | 
            +
                        torch.diagonal(dist_mat)[:] = float('inf')
         | 
| 917 | 
            +
             | 
| 918 | 
            +
                    result = torch.topk(dist_mat, k=k, largest=largest, sorted=True)
         | 
| 919 | 
            +
                    return result
         | 
| 920 | 
            +
             | 
| 921 | 
            +
                elif method == 'cpu_kd':
         | 
| 922 | 
            +
             | 
| 923 | 
            +
                    if largest:
         | 
| 924 | 
            +
                        raise ValueError("can't do largest with cpu_kd")
         | 
| 925 | 
            +
             | 
| 926 | 
            +
                    points_source_np = toNP(points_source)
         | 
| 927 | 
            +
                    points_target_np = toNP(points_target)
         | 
| 928 | 
            +
             | 
| 929 | 
            +
                    # Build the tree
         | 
| 930 | 
            +
                    kd_tree = sklearn.neighbors.KDTree(points_target_np)
         | 
| 931 | 
            +
             | 
| 932 | 
            +
                    k_search = k + 1 if omit_diagonal else k
         | 
| 933 | 
            +
                    _, neighbors = kd_tree.query(points_source_np, k=k_search)
         | 
| 934 | 
            +
             | 
| 935 | 
            +
                    if omit_diagonal:
         | 
| 936 | 
            +
                        # Mask out self element
         | 
| 937 | 
            +
                        mask = neighbors != np.arange(neighbors.shape[0])[:, np.newaxis]
         | 
| 938 | 
            +
             | 
| 939 | 
            +
                        # make sure we mask out exactly one element in each row, in rare case of many duplicate points
         | 
| 940 | 
            +
                        mask[np.sum(mask, axis=1) == mask.shape[1], -1] = False
         | 
| 941 | 
            +
             | 
| 942 | 
            +
                        neighbors = neighbors[mask].reshape((neighbors.shape[0], neighbors.shape[1] - 1))
         | 
| 943 | 
            +
             | 
| 944 | 
            +
                    inds = torch.tensor(neighbors, device=points_source.device, dtype=torch.int64)
         | 
| 945 | 
            +
                    dists = norm(points_source.unsqueeze(1).expand(-1, k, -1) - points_target[inds])
         | 
| 946 | 
            +
             | 
| 947 | 
            +
                    return dists, inds
         | 
| 948 | 
            +
             | 
| 949 | 
            +
                else:
         | 
| 950 | 
            +
                    raise ValueError("unrecognized method")
         | 
| 951 | 
            +
             | 
    	
        utils/io.py
    ADDED
    
    | @@ -0,0 +1,78 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from pathlib import Path
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import os.path as osp
         | 
| 5 | 
            +
            import re
         | 
| 6 | 
            +
            import shutil
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            def is_number(s):
         | 
| 10 | 
            +
                try:
         | 
| 11 | 
            +
                    float(s)
         | 
| 12 | 
            +
                    return True
         | 
| 13 | 
            +
                except ValueError:
         | 
| 14 | 
            +
                    return False
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def may_create_folder(folder_path):
         | 
| 18 | 
            +
                if not osp.exists(folder_path):
         | 
| 19 | 
            +
                    oldmask = os.umask(000)
         | 
| 20 | 
            +
                    os.makedirs(folder_path, mode=0o777)
         | 
| 21 | 
            +
                    os.umask(oldmask)
         | 
| 22 | 
            +
                    return True
         | 
| 23 | 
            +
                return False
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def make_clean_folder(folder_path):
         | 
| 27 | 
            +
                success = may_create_folder(folder_path)
         | 
| 28 | 
            +
                if not success:
         | 
| 29 | 
            +
                    shutil.rmtree(folder_path)
         | 
| 30 | 
            +
                    may_create_folder(folder_path)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def sorted_alphanum(file_list_ordered):
         | 
| 34 | 
            +
                convert = lambda text: int(text) if text.isdigit() else text
         | 
| 35 | 
            +
                alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key) if len(c) > 0]
         | 
| 36 | 
            +
                return sorted(file_list_ordered, key=alphanum_key)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def list_files(folder_path, name_filter, alphanum_sort=False):
         | 
| 40 | 
            +
                file_list = [p.name for p in list(Path(folder_path).glob(name_filter))]
         | 
| 41 | 
            +
                if alphanum_sort:
         | 
| 42 | 
            +
                    return sorted_alphanum(file_list)
         | 
| 43 | 
            +
                else:
         | 
| 44 | 
            +
                    return sorted(file_list)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def list_folders(folder_path, name_filter=None, alphanum_sort=False):
         | 
| 48 | 
            +
                folders = list()
         | 
| 49 | 
            +
                for subfolder in Path(folder_path).iterdir():
         | 
| 50 | 
            +
                    if subfolder.is_dir() and not subfolder.name.startswith('.'):
         | 
| 51 | 
            +
                        folder_name = subfolder.name
         | 
| 52 | 
            +
                        if name_filter is not None:
         | 
| 53 | 
            +
                            if name_filter in folder_name:
         | 
| 54 | 
            +
                                folders.append(folder_name)
         | 
| 55 | 
            +
                        else:
         | 
| 56 | 
            +
                            folders.append(folder_name)
         | 
| 57 | 
            +
                if alphanum_sort:
         | 
| 58 | 
            +
                    return sorted_alphanum(folders)
         | 
| 59 | 
            +
                else:
         | 
| 60 | 
            +
                    return sorted(folders)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def read_lines(file_path):
         | 
| 64 | 
            +
                with open(file_path, 'r') as fin:
         | 
| 65 | 
            +
                    lines = [line.strip() for line in fin.readlines() if len(line.strip()) > 0]
         | 
| 66 | 
            +
                return lines
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            def read_strings(file_path):
         | 
| 70 | 
            +
                with open(file_path, 'r') as fin:
         | 
| 71 | 
            +
                    ret = fin.readlines()
         | 
| 72 | 
            +
                return ''.join(ret)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            def read_json(filepath):
         | 
| 76 | 
            +
                with open(filepath, 'r') as fh:
         | 
| 77 | 
            +
                    ret = json.load(fh)
         | 
| 78 | 
            +
                return ret
         | 
    	
        utils/layers.py
    ADDED
    
    | @@ -0,0 +1,430 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import os.path
         | 
| 4 | 
            +
            import random
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import scipy
         | 
| 7 | 
            +
            import scipy.sparse.linalg as sla
         | 
| 8 | 
            +
            # ^^^ we NEED to import scipy before torch, or it crashes :(
         | 
| 9 | 
            +
            # (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            import torch.nn as nn
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../')
         | 
| 16 | 
            +
            if ROOT_DIR not in sys.path:
         | 
| 17 | 
            +
                sys.path.append(ROOT_DIR)
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from diffusion_net.utils import toNP
         | 
| 20 | 
            +
            from diffusion_net.geometry import to_basis, from_basis
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            class LearnedTimeDiffusion(nn.Module):
         | 
| 24 | 
            +
                """
         | 
| 25 | 
            +
                Applies diffusion with learned per-channel t.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                In the spectral domain this becomes 
         | 
| 28 | 
            +
                    f_out = e ^ (lambda_i t) f_in
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                Inputs:
         | 
| 31 | 
            +
                  - values: (V,C) in the spectral domain
         | 
| 32 | 
            +
                  - L: (V,V) sparse laplacian
         | 
| 33 | 
            +
                  - evals: (K) eigenvalues
         | 
| 34 | 
            +
                  - mass: (V) mass matrix diagonal
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                  (note: L/evals may be omitted as None depending on method)
         | 
| 37 | 
            +
                Outputs:
         | 
| 38 | 
            +
                  - (V,C) diffused values 
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def __init__(self, C_inout, method='spectral'):
         | 
| 42 | 
            +
                    super(LearnedTimeDiffusion, self).__init__()
         | 
| 43 | 
            +
                    self.C_inout = C_inout
         | 
| 44 | 
            +
                    self.diffusion_time = nn.Parameter(torch.Tensor(C_inout))  # (C)
         | 
| 45 | 
            +
                    self.method = method  # one of ['spectral', 'implicit_dense']
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    nn.init.constant_(self.diffusion_time, 0.0)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                def forward(self, x, L, mass, evals, evecs):
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    # project times to the positive halfspace
         | 
| 52 | 
            +
                    # (and away from 0 in the incredibly rare chance that they get stuck)
         | 
| 53 | 
            +
                    with torch.no_grad():
         | 
| 54 | 
            +
                        self.diffusion_time.data = torch.clamp(self.diffusion_time, min=1e-8)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    if x.shape[-1] != self.C_inout:
         | 
| 57 | 
            +
                        raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
         | 
| 58 | 
            +
                            x.shape, self.C_inout))
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    if self.method == 'spectral':
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                        # Transform to spectral
         | 
| 63 | 
            +
                        x_spec = to_basis(x, evecs, mass)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                        # Diffuse
         | 
| 66 | 
            +
                        time = self.diffusion_time
         | 
| 67 | 
            +
                        diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * time.unsqueeze(0))
         | 
| 68 | 
            +
                        x_diffuse_spec = diffusion_coefs * x_spec
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                        # Transform back to per-vertex
         | 
| 71 | 
            +
                        x_diffuse = from_basis(x_diffuse_spec, evecs)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    elif self.method == 'implicit_dense':
         | 
| 74 | 
            +
                        V = x.shape[-2]
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                        # Form the dense matrices (M + tL) with dims (B,C,V,V)
         | 
| 77 | 
            +
                        mat_dense = L.to_dense().unsqueeze(1).expand(-1, self.C_inout, V, V).clone()
         | 
| 78 | 
            +
                        mat_dense *= self.diffusion_time.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
         | 
| 79 | 
            +
                        mat_dense += torch.diag_embed(mass).unsqueeze(1)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                        # Factor the system
         | 
| 82 | 
            +
                        cholesky_factors = torch.linalg.cholesky(mat_dense)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                        # Solve the system
         | 
| 85 | 
            +
                        rhs = x * mass.unsqueeze(-1)
         | 
| 86 | 
            +
                        rhsT = torch.transpose(rhs, 1, 2).unsqueeze(-1)
         | 
| 87 | 
            +
                        sols = torch.cholesky_solve(rhsT, cholesky_factors)
         | 
| 88 | 
            +
                        x_diffuse = torch.transpose(sols.squeeze(-1), 1, 2)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    else:
         | 
| 91 | 
            +
                        raise ValueError("unrecognized method")
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    return x_diffuse
         | 
| 94 | 
            +
             | 
| 95 | 
            +
             | 
| 96 | 
            +
            class SpatialGradientFeatures(nn.Module):
         | 
| 97 | 
            +
                """
         | 
| 98 | 
            +
                Compute dot-products between input vectors. Uses a learned complex-linear layer to keep dimension down.
         | 
| 99 | 
            +
                
         | 
| 100 | 
            +
                Input:
         | 
| 101 | 
            +
                    - vectors: (V,C,2)
         | 
| 102 | 
            +
                Output:
         | 
| 103 | 
            +
                    - dots: (V,C) dots 
         | 
| 104 | 
            +
                """
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                def __init__(self, C_inout, with_gradient_rotations=True):
         | 
| 107 | 
            +
                    super(SpatialGradientFeatures, self).__init__()
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    self.C_inout = C_inout
         | 
| 110 | 
            +
                    self.with_gradient_rotations = with_gradient_rotations
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    if (self.with_gradient_rotations):
         | 
| 113 | 
            +
                        self.A_re = nn.Linear(self.C_inout, self.C_inout, bias=False)
         | 
| 114 | 
            +
                        self.A_im = nn.Linear(self.C_inout, self.C_inout, bias=False)
         | 
| 115 | 
            +
                    else:
         | 
| 116 | 
            +
                        self.A = nn.Linear(self.C_inout, self.C_inout, bias=False)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    # self.norm = nn.InstanceNorm1d(C_inout)
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                def forward(self, vectors):
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    vectorsA = vectors  # (V,C)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    if self.with_gradient_rotations:
         | 
| 125 | 
            +
                        vectorsBreal = self.A_re(vectors[..., 0]) - self.A_im(vectors[..., 1])
         | 
| 126 | 
            +
                        vectorsBimag = self.A_re(vectors[..., 1]) + self.A_im(vectors[..., 0])
         | 
| 127 | 
            +
                    else:
         | 
| 128 | 
            +
                        vectorsBreal = self.A(vectors[..., 0])
         | 
| 129 | 
            +
                        vectorsBimag = self.A(vectors[..., 1])
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    dots = vectorsA[..., 0] * vectorsBreal + vectorsA[..., 1] * vectorsBimag
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    return torch.tanh(dots)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
             | 
| 136 | 
            +
            class MiniMLP(nn.Sequential):
         | 
| 137 | 
            +
                '''
         | 
| 138 | 
            +
                A simple MLP with configurable hidden layer sizes.
         | 
| 139 | 
            +
                '''
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                def __init__(self, layer_sizes, dropout=False, activation=nn.ReLU, name="miniMLP"):
         | 
| 142 | 
            +
                    super(MiniMLP, self).__init__()
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    for i in range(len(layer_sizes) - 1):
         | 
| 145 | 
            +
                        is_last = (i + 2 == len(layer_sizes))
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                        if dropout and i > 0:
         | 
| 148 | 
            +
                            self.add_module(name + "_mlp_layer_dropout_{:03d}".format(i), nn.Dropout(p=.5))
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                        # Affine map
         | 
| 151 | 
            +
                        self.add_module(
         | 
| 152 | 
            +
                            name + "_mlp_layer_{:03d}".format(i),
         | 
| 153 | 
            +
                            nn.Linear(
         | 
| 154 | 
            +
                                layer_sizes[i],
         | 
| 155 | 
            +
                                layer_sizes[i + 1],
         | 
| 156 | 
            +
                            ),
         | 
| 157 | 
            +
                        )
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                        # Nonlinearity
         | 
| 160 | 
            +
                        # (but not on the last layer)
         | 
| 161 | 
            +
                        if not is_last:
         | 
| 162 | 
            +
                            self.add_module(name + "_mlp_act_{:03d}".format(i), activation())
         | 
| 163 | 
            +
             | 
| 164 | 
            +
             | 
| 165 | 
            +
            class DiffusionNetBlock(nn.Module):
         | 
| 166 | 
            +
                """
         | 
| 167 | 
            +
                Inputs and outputs are defined at vertices
         | 
| 168 | 
            +
                """
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                def __init__(self,
         | 
| 171 | 
            +
                             C_width,
         | 
| 172 | 
            +
                             mlp_hidden_dims,
         | 
| 173 | 
            +
                             dropout=True,
         | 
| 174 | 
            +
                             diffusion_method='spectral',
         | 
| 175 | 
            +
                             with_gradient_features=True,
         | 
| 176 | 
            +
                             with_gradient_rotations=True):
         | 
| 177 | 
            +
                    super(DiffusionNetBlock, self).__init__()
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    # Specified dimensions
         | 
| 180 | 
            +
                    self.C_width = C_width
         | 
| 181 | 
            +
                    self.mlp_hidden_dims = mlp_hidden_dims
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    self.dropout = dropout
         | 
| 184 | 
            +
                    self.with_gradient_features = with_gradient_features
         | 
| 185 | 
            +
                    self.with_gradient_rotations = with_gradient_rotations
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    # Diffusion block
         | 
| 188 | 
            +
                    self.diffusion = LearnedTimeDiffusion(self.C_width, method=diffusion_method)
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                    self.MLP_C = 2 * self.C_width
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    if self.with_gradient_features:
         | 
| 193 | 
            +
                        self.gradient_features = SpatialGradientFeatures(self.C_width, with_gradient_rotations=self.with_gradient_rotations)
         | 
| 194 | 
            +
                        self.MLP_C += self.C_width
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    # MLPs
         | 
| 197 | 
            +
                    self.mlp = MiniMLP([self.MLP_C] + self.mlp_hidden_dims + [self.C_width], dropout=self.dropout)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                def forward(self, x_in, mass, L, evals, evecs, gradX, gradY):
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    # Manage dimensions
         | 
| 202 | 
            +
                    B = x_in.shape[0]  # batch dimension
         | 
| 203 | 
            +
                    if x_in.shape[-1] != self.C_width:
         | 
| 204 | 
            +
                        raise ValueError("Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
         | 
| 205 | 
            +
                            x_in.shape, self.C_width))
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    # Diffusion block
         | 
| 208 | 
            +
                    x_diffuse = self.diffusion(x_in, L, mass, evals, evecs)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    # Compute gradient features, if using
         | 
| 211 | 
            +
                    if self.with_gradient_features:
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                        # Compute gradients
         | 
| 214 | 
            +
                        x_grads = [
         | 
| 215 | 
            +
                        ]  # Manually loop over the batch (if there is a batch dimension) since torch.mm() doesn't support batching
         | 
| 216 | 
            +
                        for b in range(B):
         | 
| 217 | 
            +
                            # gradient after diffusion
         | 
| 218 | 
            +
                            x_gradX = torch.mm(gradX[b, ...], x_diffuse[b, ...])
         | 
| 219 | 
            +
                            x_gradY = torch.mm(gradY[b, ...], x_diffuse[b, ...])
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                            x_grads.append(torch.stack((x_gradX, x_gradY), dim=-1))
         | 
| 222 | 
            +
                        x_grad = torch.stack(x_grads, dim=0)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                        # Evaluate gradient features
         | 
| 225 | 
            +
                        x_grad_features = self.gradient_features(x_grad)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                        # Stack inputs to mlp
         | 
| 228 | 
            +
                        feature_combined = torch.cat((x_in, x_diffuse, x_grad_features), dim=-1)
         | 
| 229 | 
            +
                    else:
         | 
| 230 | 
            +
                        # Stack inputs to mlp
         | 
| 231 | 
            +
                        feature_combined = torch.cat((x_in, x_diffuse), dim=-1)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    # Apply the mlp
         | 
| 234 | 
            +
                    x0_out = self.mlp(feature_combined)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    # Skip connection
         | 
| 237 | 
            +
                    x0_out = x0_out + x_in
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                    return x0_out
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            class DiffusionNet(nn.Module):
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                def __init__(self,
         | 
| 245 | 
            +
                             C_in,
         | 
| 246 | 
            +
                             C_out,
         | 
| 247 | 
            +
                             C_width=128,
         | 
| 248 | 
            +
                             N_block=4,
         | 
| 249 | 
            +
                             last_activation=None,
         | 
| 250 | 
            +
                             outputs_at='vertices',
         | 
| 251 | 
            +
                             mlp_hidden_dims=None,
         | 
| 252 | 
            +
                             dropout=True,
         | 
| 253 | 
            +
                             with_gradient_features=True,
         | 
| 254 | 
            +
                             with_gradient_rotations=True,
         | 
| 255 | 
            +
                             diffusion_method='spectral',
         | 
| 256 | 
            +
                             num_eigenbasis=128):
         | 
| 257 | 
            +
                    """
         | 
| 258 | 
            +
                    Construct a DiffusionNet.
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                    Parameters:
         | 
| 261 | 
            +
                        C_in (int):                     input dimension 
         | 
| 262 | 
            +
                        C_out (int):                    output dimension 
         | 
| 263 | 
            +
                        last_activation (func)          a function to apply to the final outputs of the network, such as torch.nn.functional.log_softmax (default: None)
         | 
| 264 | 
            +
                        outputs_at (string)             produce outputs at various mesh elements by averaging from vertices. One of ['vertices', 'edges', 'faces', 'global_mean']. (default 'vertices', aka points for a point cloud)
         | 
| 265 | 
            +
                        C_width (int):                  dimension of internal DiffusionNet blocks (default: 128)
         | 
| 266 | 
            +
                        N_block (int):                  number of DiffusionNet blocks (default: 4)
         | 
| 267 | 
            +
                        mlp_hidden_dims (list of int):  a list of hidden layer sizes for MLPs (default: [C_width, C_width])
         | 
| 268 | 
            +
                        dropout (bool):                 if True, internal MLPs use dropout (default: True)
         | 
| 269 | 
            +
                        diffusion_method (string):      how to evaluate diffusion, one of ['spectral', 'implicit_dense']. If implicit_dense is used, can set k_eig=0, saving precompute.
         | 
| 270 | 
            +
                        with_gradient_features (bool):  if True, use gradient features (default: True)
         | 
| 271 | 
            +
                        with_gradient_rotations (bool): if True, use gradient also learn a rotation of each gradient. Set to True if your surface has consistently oriented normals, and False otherwise (default: True)
         | 
| 272 | 
            +
                        num_eigenbasis (int):           for trunking the eigenvalues eigenvectors
         | 
| 273 | 
            +
                    """
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    super(DiffusionNet, self).__init__()
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    ## Store parameters
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    # Basic parameters
         | 
| 280 | 
            +
                    self.C_in = C_in
         | 
| 281 | 
            +
                    self.C_out = C_out
         | 
| 282 | 
            +
                    self.C_width = C_width
         | 
| 283 | 
            +
                    self.N_block = N_block
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                    # Outputs
         | 
| 286 | 
            +
                    self.last_activation = last_activation
         | 
| 287 | 
            +
                    self.outputs_at = outputs_at
         | 
| 288 | 
            +
                    if outputs_at not in ['vertices', 'edges', 'faces', 'global_mean']:
         | 
| 289 | 
            +
                        raise ValueError("invalid setting for outputs_at")
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    # MLP options
         | 
| 292 | 
            +
                    if mlp_hidden_dims == None:
         | 
| 293 | 
            +
                        mlp_hidden_dims = [C_width, C_width]
         | 
| 294 | 
            +
                    self.mlp_hidden_dims = mlp_hidden_dims
         | 
| 295 | 
            +
                    self.dropout = dropout
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                    # Diffusion
         | 
| 298 | 
            +
                    self.diffusion_method = diffusion_method
         | 
| 299 | 
            +
                    if diffusion_method not in ['spectral', 'implicit_dense']:
         | 
| 300 | 
            +
                        raise ValueError("invalid setting for diffusion_method")
         | 
| 301 | 
            +
                    self.num_eigenbasis = num_eigenbasis
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    # Gradient features
         | 
| 304 | 
            +
                    self.with_gradient_features = with_gradient_features
         | 
| 305 | 
            +
                    self.with_gradient_rotations = with_gradient_rotations
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    ## Set up the network
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                    # First and last affine layers
         | 
| 310 | 
            +
                    self.first_lin = nn.Linear(C_in, C_width)
         | 
| 311 | 
            +
                    self.last_lin = nn.Linear(C_width, C_out)
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    # DiffusionNet blocks
         | 
| 314 | 
            +
                    self.blocks = []
         | 
| 315 | 
            +
                    for i_block in range(self.N_block):
         | 
| 316 | 
            +
                        block = DiffusionNetBlock(C_width=C_width,
         | 
| 317 | 
            +
                                                  mlp_hidden_dims=mlp_hidden_dims,
         | 
| 318 | 
            +
                                                  dropout=dropout,
         | 
| 319 | 
            +
                                                  diffusion_method=diffusion_method,
         | 
| 320 | 
            +
                                                  with_gradient_features=with_gradient_features,
         | 
| 321 | 
            +
                                                  with_gradient_rotations=with_gradient_rotations)
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                        self.blocks.append(block)
         | 
| 324 | 
            +
                        self.add_module("block_" + str(i_block), self.blocks[-1])
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                def forward(self, x_in, mass, L=None, evals=None, evecs=None, gradX=None, gradY=None, edges=None, faces=None):
         | 
| 327 | 
            +
                    """
         | 
| 328 | 
            +
                    A forward pass on the DiffusionNet.
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    In the notation below, dimension are:
         | 
| 331 | 
            +
                        - C is the input channel dimension (C_in on construction)
         | 
| 332 | 
            +
                        - C_OUT is the output channel dimension (C_out on construction)
         | 
| 333 | 
            +
                        - N is the number of vertices/points, which CAN be different for each forward pass
         | 
| 334 | 
            +
                        - B is an OPTIONAL batch dimension
         | 
| 335 | 
            +
                        - K_EIG is the number of eigenvalues used for spectral acceleration
         | 
| 336 | 
            +
                    Generally, our data layout it is [N,C] or [B,N,C].
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    Call get_operators() to generate geometric quantities mass/L/evals/evecs/gradX/gradY. Note that depending on the options for the DiffusionNet, not all are strictly necessary.
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                    Parameters:
         | 
| 341 | 
            +
                        x_in (tensor):      Input features, dimension [N,C] or [B,N,C]
         | 
| 342 | 
            +
                        mass (tensor):      Mass vector, dimension [N] or [B,N]
         | 
| 343 | 
            +
                        L (tensor):         Laplace matrix, sparse tensor with dimension [N,N] or [B,N,N]
         | 
| 344 | 
            +
                        evals (tensor):     Eigenvalues of Laplace matrix, dimension [K_EIG] or [B,K_EIG]
         | 
| 345 | 
            +
                        evecs (tensor):     Eigenvectors of Laplace matrix, dimension [N,K_EIG] or [B,N,K_EIG]
         | 
| 346 | 
            +
                        gradX (tensor):     Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
         | 
| 347 | 
            +
                        gradY (tensor):     Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                    Returns:
         | 
| 350 | 
            +
                        x_out (tensor):    Output with dimension [N,C_out] or [B,N,C_out]
         | 
| 351 | 
            +
                    """
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                    ## Check dimensions, and append batch dimension if not given
         | 
| 354 | 
            +
                    if x_in.shape[-1] != self.C_in:
         | 
| 355 | 
            +
                        raise ValueError("DiffusionNet was constructed with C_in={}, but x_in has last dim={}".format(
         | 
| 356 | 
            +
                            self.C_in, x_in.shape[-1]))
         | 
| 357 | 
            +
                    N = x_in.shape[-2]
         | 
| 358 | 
            +
                    if len(x_in.shape) == 2:
         | 
| 359 | 
            +
                        appended_batch_dim = True
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                        # add a batch dim to all inputs
         | 
| 362 | 
            +
                        x_in = x_in.unsqueeze(0)
         | 
| 363 | 
            +
                        mass = mass.unsqueeze(0)
         | 
| 364 | 
            +
                        if L != None:
         | 
| 365 | 
            +
                            L = L.unsqueeze(0)
         | 
| 366 | 
            +
                        if evals != None:
         | 
| 367 | 
            +
                            evals = evals.unsqueeze(0)
         | 
| 368 | 
            +
                        if evecs != None:
         | 
| 369 | 
            +
                            evecs = evecs.unsqueeze(0)
         | 
| 370 | 
            +
                        if gradX != None:
         | 
| 371 | 
            +
                            gradX = gradX.unsqueeze(0)
         | 
| 372 | 
            +
                        if gradY != None:
         | 
| 373 | 
            +
                            gradY = gradY.unsqueeze(0)
         | 
| 374 | 
            +
                        if edges != None:
         | 
| 375 | 
            +
                            edges = edges.unsqueeze(0)
         | 
| 376 | 
            +
                        if faces != None:
         | 
| 377 | 
            +
                            faces = faces.unsqueeze(0)
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                    elif len(x_in.shape) == 3:
         | 
| 380 | 
            +
                        appended_batch_dim = False
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                    else:
         | 
| 383 | 
            +
                        raise ValueError("x_in should be tensor with shape [N,C] or [B,N,C]")
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                    evals = evals[..., :self.num_eigenbasis]
         | 
| 386 | 
            +
                    evecs = evecs[..., :self.num_eigenbasis]
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                    # Apply the first linear layer
         | 
| 389 | 
            +
                    x = self.first_lin(x_in)
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    # Apply each of the blocks
         | 
| 392 | 
            +
                    for b in self.blocks:
         | 
| 393 | 
            +
                        x = b(x, mass, L, evals, evecs, gradX, gradY)
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                    # Apply the last linear layer
         | 
| 396 | 
            +
                    x = self.last_lin(x)
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                    # Remap output to faces/edges if requested
         | 
| 399 | 
            +
                    if self.outputs_at == 'vertices':
         | 
| 400 | 
            +
                        x_out = x
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                    elif self.outputs_at == 'edges':
         | 
| 403 | 
            +
                        # Remap to edges
         | 
| 404 | 
            +
                        x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 2)
         | 
| 405 | 
            +
                        edges_gather = edges.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
         | 
| 406 | 
            +
                        xe = torch.gather(x_gather, 1, edges_gather)
         | 
| 407 | 
            +
                        x_out = torch.mean(xe, dim=-1)
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                    elif self.outputs_at == 'faces':
         | 
| 410 | 
            +
                        # Remap to faces
         | 
| 411 | 
            +
                        x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 3)
         | 
| 412 | 
            +
                        faces_gather = faces.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
         | 
| 413 | 
            +
                        xf = torch.gather(x_gather, 1, faces_gather)
         | 
| 414 | 
            +
                        x_out = torch.mean(xf, dim=-1)
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                    elif self.outputs_at == 'global_mean':
         | 
| 417 | 
            +
                        # Produce a single global mean ouput.
         | 
| 418 | 
            +
                        # Using a weighted mean according to the point mass/area is discretization-invariant.
         | 
| 419 | 
            +
                        # (A naive mean is not discretization-invariant; it could be affected by sampling a region more densely)
         | 
| 420 | 
            +
                        x_out = torch.sum(x * mass.unsqueeze(-1), dim=-2) / torch.sum(mass, dim=-1, keepdim=True)
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                    # Apply last nonlinearity if specified
         | 
| 423 | 
            +
                    if self.last_activation != None:
         | 
| 424 | 
            +
                        x_out = self.last_activation(x_out)
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                    # Remove batch dim if we added it
         | 
| 427 | 
            +
                    if appended_batch_dim:
         | 
| 428 | 
            +
                        x_out = x_out.squeeze(0)
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                    return x_out
         | 
    	
        utils/mesh.py
    ADDED
    
    | @@ -0,0 +1,214 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            from pathlib import Path
         | 
| 3 | 
            +
            import os 
         | 
| 4 | 
            +
            from tqdm import tqdm
         | 
| 5 | 
            +
            import potpourri3d as pp3d
         | 
| 6 | 
            +
            import open3d as o3d
         | 
| 7 | 
            +
            import scipy.io as sio
         | 
| 8 | 
            +
            import numpy as np
         | 
| 9 | 
            +
            import re
         | 
| 10 | 
            +
            import sys
         | 
| 11 | 
            +
            sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
         | 
| 12 | 
            +
            from shape_data import get_data_dirs
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # List of file extensions to consider as "mesh" files.
         | 
| 15 | 
            +
            # Kudos to chatgpt!
         | 
| 16 | 
            +
            # Add or remove extensions here as needed.
         | 
| 17 | 
            +
            MESH_EXTENSIONS = {".ply", ".obj", ".off", ".stl", ".fbx", ".gltf", ".glb"}
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            def sorted_alphanum(file_list_ordered):
         | 
| 20 | 
            +
                def convert(text):
         | 
| 21 | 
            +
                    return int(text) if text.isdigit() else text
         | 
| 22 | 
            +
                def alphanum_key(key):
         | 
| 23 | 
            +
                    return [convert(c) for c in re.split('([0-9]+)', str(key)) if len(c) > 0]
         | 
| 24 | 
            +
                return sorted(file_list_ordered, key=alphanum_key)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def list_files(folder_path, name_filter, alphanum_sort=False):
         | 
| 28 | 
            +
                file_list = [p.name for p in list(Path(folder_path).glob(name_filter))]
         | 
| 29 | 
            +
                if alphanum_sort:
         | 
| 30 | 
            +
                    return sorted_alphanum(file_list)
         | 
| 31 | 
            +
                else:
         | 
| 32 | 
            +
                    return sorted(file_list)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            def find_mesh_files(directory: Path, extensions: set=MESH_EXTENSIONS, alphanum_sort=False) -> list[Path]:
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
                Recursively find all files in 'directory' whose suffix (lowercased) is in 'extensions'.
         | 
| 37 | 
            +
                Returns a list of Path objects.
         | 
| 38 | 
            +
                """
         | 
| 39 | 
            +
                matches = []
         | 
| 40 | 
            +
                for path in directory.rglob("*"):
         | 
| 41 | 
            +
                    if path.is_file() and path.suffix.lower() in extensions:
         | 
| 42 | 
            +
                        matches.append(path)
         | 
| 43 | 
            +
                if alphanum_sort:
         | 
| 44 | 
            +
                    return sorted_alphanum(matches)
         | 
| 45 | 
            +
                else:
         | 
| 46 | 
            +
                    return sorted(matches)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            def save_ply(file_name, V, F, Rho=None, color=None):
         | 
| 49 | 
            +
                """Save mesh information either as an ASCII ply file.
         | 
| 50 | 
            +
                https://github.com/emmanuel-hartman/BaRe-ESA/blob/main/utils/input_output.py
         | 
| 51 | 
            +
                Input:
         | 
| 52 | 
            +
                    - file_name: specified path for saving mesh [string]
         | 
| 53 | 
            +
                    - V: vertices of the triangulated surface [nVx3 numpy ndarray]
         | 
| 54 | 
            +
                    - F: faces of the triangulated surface [nFx3 numpy ndarray]
         | 
| 55 | 
            +
                    - Rho: weights defined on the vertices of the triangulated surface [nVx1 numpy ndarray, default=None]
         | 
| 56 | 
            +
                    - color: colormap [nVx3 numpy ndarray of RGB triples]
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                Output:
         | 
| 59 | 
            +
                    - file_name.mat or file_name.ply file containing mesh information
         | 
| 60 | 
            +
                """
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                # Save as .ply file
         | 
| 63 | 
            +
                nV = V.shape[0]
         | 
| 64 | 
            +
                nF = F.shape[0]   
         | 
| 65 | 
            +
                if not ".ply" in file_name:
         | 
| 66 | 
            +
                    file_name += ".ply"
         | 
| 67 | 
            +
                file = open(file_name, "w")
         | 
| 68 | 
            +
                lines = ("ply","\n","format ascii 1.0","\n", "element vertex {}".format(nV),"\n","property float x","\n","property float y","\n","property float z","\n")
         | 
| 69 | 
            +
                
         | 
| 70 | 
            +
                if color is not None:
         | 
| 71 | 
            +
                    lines += ("property uchar red","\n","property uchar green","\n","property uchar blue","\n")
         | 
| 72 | 
            +
                    if Rho is not None:
         | 
| 73 | 
            +
                        lines += ("property uchar alpha","\n")
         | 
| 74 | 
            +
                
         | 
| 75 | 
            +
                lines += ("element face {}".format(nF),"\n","property list uchar int vertex_index","\n","end_header","\n")
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                file.writelines(lines)
         | 
| 78 | 
            +
                lines = []
         | 
| 79 | 
            +
                for i in range(0,nV):
         | 
| 80 | 
            +
                    for j in range(0,3):
         | 
| 81 | 
            +
                        lines.append(str(V[i][j]))
         | 
| 82 | 
            +
                        lines.append(" ")
         | 
| 83 | 
            +
                    if color is not None:
         | 
| 84 | 
            +
                        for j in range(0,3):
         | 
| 85 | 
            +
                            lines.append(str(color[i][j]))
         | 
| 86 | 
            +
                            lines.append(" ")
         | 
| 87 | 
            +
                        if Rho is not None:
         | 
| 88 | 
            +
                            lines.append(str(Rho[i]))
         | 
| 89 | 
            +
                            lines.append(" ")
         | 
| 90 | 
            +
                                
         | 
| 91 | 
            +
                    lines.append("\n")
         | 
| 92 | 
            +
                for i in range(0,nF):
         | 
| 93 | 
            +
                    l = len(F[i,:])
         | 
| 94 | 
            +
                    lines.append(str(l))
         | 
| 95 | 
            +
                    lines.append(" ")
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    for j in range(0,l):
         | 
| 98 | 
            +
                        lines.append(str(F[i,j]))
         | 
| 99 | 
            +
                        lines.append(" ")
         | 
| 100 | 
            +
                    lines.append("\n")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                file.writelines(lines)
         | 
| 103 | 
            +
                file.close()
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            def numpy_to_open3d_mesh(V, F):
         | 
| 106 | 
            +
                # Create an empty TriangleMesh object
         | 
| 107 | 
            +
                mesh = o3d.geometry.TriangleMesh()
         | 
| 108 | 
            +
                # Set vertices
         | 
| 109 | 
            +
                mesh.vertices = o3d.utility.Vector3dVector(V)
         | 
| 110 | 
            +
                # Set triangles
         | 
| 111 | 
            +
                mesh.triangles = o3d.utility.Vector3iVector(F)
         | 
| 112 | 
            +
                return mesh
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            def load_mesh(filepath, scale=True, return_vnormals=False):
         | 
| 117 | 
            +
                V, F = pp3d.read_mesh(filepath)
         | 
| 118 | 
            +
                mesh = numpy_to_open3d_mesh(V, F)
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                tmat = np.identity(4, dtype=np.float32)
         | 
| 121 | 
            +
                center = mesh.get_center()
         | 
| 122 | 
            +
                tmat[:3, 3] = -center
         | 
| 123 | 
            +
                area = mesh.get_surface_area()
         | 
| 124 | 
            +
                if scale:
         | 
| 125 | 
            +
                    smat = np.identity(4, dtype=np.float32)
         | 
| 126 | 
            +
                    
         | 
| 127 | 
            +
                    smat[:3, :3] = np.identity(3, dtype=np.float32) / np.sqrt(area)
         | 
| 128 | 
            +
                    tmat = smat @ tmat
         | 
| 129 | 
            +
                mesh.transform(tmat)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                vertices = np.asarray(mesh.vertices, dtype=np.float32)
         | 
| 132 | 
            +
                faces = np.asarray(mesh.triangles, dtype=np.int32)
         | 
| 133 | 
            +
                if return_vnormals:
         | 
| 134 | 
            +
                    mesh.compute_vertex_normals()
         | 
| 135 | 
            +
                    vnormals = np.asarray(mesh.vertex_normals, dtype=np.float32)
         | 
| 136 | 
            +
                    if scale:
         | 
| 137 | 
            +
                        return vertices, faces, vnormals, area, center
         | 
| 138 | 
            +
                    return vertices, faces, vnormals, area, center
         | 
| 139 | 
            +
                else:
         | 
| 140 | 
            +
                    return vertices, faces
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            def mesh_geod_matrix(vertices, faces, do_tqdm=False, verbose=False):
         | 
| 144 | 
            +
                if verbose:
         | 
| 145 | 
            +
                    print("Setting Geodesic matrix bw vertices")
         | 
| 146 | 
            +
                n_vertices = vertices.shape[0]
         | 
| 147 | 
            +
                distmat = np.zeros((n_vertices, n_vertices))
         | 
| 148 | 
            +
                solver = pp3d.MeshHeatMethodDistanceSolver(vertices, faces)
         | 
| 149 | 
            +
                if do_tqdm:
         | 
| 150 | 
            +
                    iterable = tqdm(range(n_vertices))
         | 
| 151 | 
            +
                else:
         | 
| 152 | 
            +
                    iterable = range(n_vertices)
         | 
| 153 | 
            +
                for vertind in iterable:
         | 
| 154 | 
            +
                    distmat[vertind] = np.maximum(solver.compute_distance(vertind), 0)
         | 
| 155 | 
            +
                geod_mat = distmat
         | 
| 156 | 
            +
                return geod_mat
         | 
| 157 | 
            +
             | 
| 158 | 
            +
             | 
| 159 | 
            +
            def prepare_geod_mats(shapes_folder, out, basename=None):
         | 
| 160 | 
            +
                if basename is None:
         | 
| 161 | 
            +
                    basename = os.path.basename(os.path.dirname(shapes_folder)) #+ "_" + os.path.basename(shapes_folder)
         | 
| 162 | 
            +
                case = basename
         | 
| 163 | 
            +
                case_folder_out = os.path.join(out, case)
         | 
| 164 | 
            +
                os.makedirs(case_folder_out, exist_ok=True)
         | 
| 165 | 
            +
                all_shapes = [f for f in os.listdir(shapes_folder) if (".ply" in f) or (".off" in f) or ('.obj' in f)]
         | 
| 166 | 
            +
                for shape in tqdm(all_shapes, "Processing " + os.path.basename(shapes_folder)):
         | 
| 167 | 
            +
                    vertices, faces = pp3d.read_mesh(os.path.join(shapes_folder, shape))
         | 
| 168 | 
            +
                    areas = pp3d.face_areas(vertices, faces)
         | 
| 169 | 
            +
                    mat = mesh_geod_matrix(vertices, faces, verbose=False)
         | 
| 170 | 
            +
                    dict_save = {
         | 
| 171 | 
            +
                        'geod_dist': mat,
         | 
| 172 | 
            +
                        'areas_f': areas
         | 
| 173 | 
            +
                    }
         | 
| 174 | 
            +
                    sio.savemat(os.path.join(case_folder_out, shape[:-4]+'.mat'), dict_save)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
             | 
| 177 | 
            +
            if __name__ == "__main__":
         | 
| 178 | 
            +
                parser = argparse.ArgumentParser(description="What to do.")
         | 
| 179 | 
            +
                parser.add_argument('--make_geods', required=True, type=int, default=0, help='launch computation of geod matrices')
         | 
| 180 | 
            +
                parser.add_argument('--data', required=False, type=str, default=None)
         | 
| 181 | 
            +
                parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
         | 
| 182 | 
            +
                parser.add_argument('--basename', required=False, type=str, default=None)
         | 
| 183 | 
            +
                args = parser.parse_args()
         | 
| 184 | 
            +
                if args.make_geods:
         | 
| 185 | 
            +
                    # from config import get_geod_path, get_dataset_path, get_template_path
         | 
| 186 | 
            +
                    # output = get_geod_path()
         | 
| 187 | 
            +
                    output = os.path.join(args.datadir, "geomats")
         | 
| 188 | 
            +
                    if args.data == "humans":
         | 
| 189 | 
            +
                        all_folders = [get_data_dirs(args.datadir, "faust", 'test')[0], get_data_dirs(args.datadir, "scape", 'test')[0], get_data_dirs(args.datadir, "shrec19", 'test')[0]]
         | 
| 190 | 
            +
                        for folder in all_folders:
         | 
| 191 | 
            +
                            prepare_geod_mats(folder, output)
         | 
| 192 | 
            +
                    elif args.data == "dt4d":
         | 
| 193 | 
            +
                        data_dir, _, corr_dir = get_data_dirs(args.datadir, args.data, 'test')
         | 
| 194 | 
            +
                        all_folders = sorted([f for f in os.listdir(data_dir) if "cross" not in f])
         | 
| 195 | 
            +
                        for folder in all_folders:
         | 
| 196 | 
            +
                            prepare_geod_mats(os.path.join(data_dir, folder), os.path.join(output, "DT4D"), basename=folder)
         | 
| 197 | 
            +
                    elif args.data is not None:
         | 
| 198 | 
            +
                        data_dir, _, corr_dir = get_data_dirs(args.datadir, args.data, 'test')
         | 
| 199 | 
            +
                        prepare_geod_mats(data_dir, output, args.basename)
         | 
| 200 | 
            +
                # parser = argparse.ArgumentParser(description="Find all mesh files in a folder and list their paths.")
         | 
| 201 | 
            +
                # parser.add_argument("--folder", type=Path, help="Path to the folder to search (will search recursively).")
         | 
| 202 | 
            +
                # args = parser.parse_args()
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                # search_folder = args.folder
         | 
| 205 | 
            +
                # if not search_folder.is_dir():
         | 
| 206 | 
            +
                #     print(f"Error: '{search_folder}' is not a valid directory.")
         | 
| 207 | 
            +
                # else:
         | 
| 208 | 
            +
                #     # Find all matching files
         | 
| 209 | 
            +
                #     mesh_files = find_mesh_files(search_folder)
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                #     # Sort the results for consistency
         | 
| 212 | 
            +
                #     mesh_files.sort()
         | 
| 213 | 
            +
                #     for p in mesh_files:
         | 
| 214 | 
            +
                #         print(p.resolve())
         | 
    	
        utils/meshplot.py
    ADDED
    
    | @@ -0,0 +1,67 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            import potpourri3d as pp3d 
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import meshplot as mp
         | 
| 5 | 
            +
            from ipygany import PolyMesh, Scene
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            colors = np.array([[255, 119, 0], #orange
         | 
| 10 | 
            +
                                       [163, 51, 219], #violet
         | 
| 11 | 
            +
                                       [0, 246, 205], #bleu clair
         | 
| 12 | 
            +
                                       [0, 131, 246], #bleu floncé 
         | 
| 13 | 
            +
                                       [246, 234, 0], #jaune
         | 
| 14 | 
            +
                                       [143, 188, 143], #rouge
         | 
| 15 | 
            +
                                [255, 0, 0]]) 
         | 
| 16 | 
            +
                
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            def double_plot_surf(surf_1, surf_2,cmap1=None,cmap2=None):
         | 
| 19 | 
            +
                d = mp.subplot(surf_1.vertices, surf_1.faces, c=cmap1, s=[2, 2, 0])
         | 
| 20 | 
            +
                mp.subplot(surf_2.vertices, surf_2.faces, c=cmap2, s=[2, 2, 1], data=d)
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            def visu_pts(surf, colors=colors, idx_points=None, n_kpts=5):
         | 
| 23 | 
            +
                areas = np.linalg.norm(surf.surfel, axis=-1, keepdims=True)
         | 
| 24 | 
            +
                area = np.sqrt(areas.sum()/2)
         | 
| 25 | 
            +
                if idx_points is None:
         | 
| 26 | 
            +
                    center = (surf.centers*(areas)).sum(axis=0)/areas.sum()
         | 
| 27 | 
            +
                    surf.updateVertices((surf.vertices - center)/area)
         | 
| 28 | 
            +
                    surf.cotanLaplacian()
         | 
| 29 | 
            +
                    idx_points = surf.get_keypoints(n_points=n_kpts)
         | 
| 30 | 
            +
                solver = pp3d.MeshHeatMethodDistanceSolver(surf.vertices, surf.faces)
         | 
| 31 | 
            +
                norm_center = solver.compute_distance(idx_points[-1])
         | 
| 32 | 
            +
                color_array = np.zeros(surf.vertices.shape)
         | 
| 33 | 
            +
                for i in range(len(idx_points)):
         | 
| 34 | 
            +
                    coeff = 0.1 + 1.5*norm_center[idx_points[i]]
         | 
| 35 | 
            +
                    i_v = idx_points[i]
         | 
| 36 | 
            +
                    dist = solver.compute_distance(i_v)*2
         | 
| 37 | 
            +
                    #color_array += np.clip(1-dist, 0, np.inf)[:, None]*colors[i][None, :]
         | 
| 38 | 
            +
                    color_array += np.exp(-dist**2/coeff)[:, None]*colors[i][None, :]
         | 
| 39 | 
            +
                color_array = np.clip(color_array, 0, 255.)
         | 
| 40 | 
            +
                return color_array
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            def toNP(tens):
         | 
| 43 | 
            +
                if isinstance(tens, torch.Tensor):
         | 
| 44 | 
            +
                    return tens.detach().squeeze().cpu().numpy()
         | 
| 45 | 
            +
                return tens
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def overlay_surf(shape_vertices, shape_faces, target_vertices, target_faces, colors=["tomato", "darksalmon"]):
         | 
| 49 | 
            +
                shape_vertices = toNP(shape_vertices)
         | 
| 50 | 
            +
                target_vertices = toNP(target_vertices)
         | 
| 51 | 
            +
                mesh_1 = PolyMesh(
         | 
| 52 | 
            +
                    vertices=shape_vertices,
         | 
| 53 | 
            +
                    triangle_indices=shape_faces
         | 
| 54 | 
            +
                )
         | 
| 55 | 
            +
                mesh_1.default_color = colors[0]
         | 
| 56 | 
            +
                
         | 
| 57 | 
            +
                
         | 
| 58 | 
            +
                mesh_2 = PolyMesh(
         | 
| 59 | 
            +
                    vertices=target_vertices,
         | 
| 60 | 
            +
                    triangle_indices=target_faces
         | 
| 61 | 
            +
                )
         | 
| 62 | 
            +
                mesh_2.default_color = colors[1]
         | 
| 63 | 
            +
                
         | 
| 64 | 
            +
                
         | 
| 65 | 
            +
                scene = Scene([mesh_1, mesh_2])
         | 
| 66 | 
            +
                #scene = Scene([mesh_5])
         | 
| 67 | 
            +
                return scene, [mesh_1, mesh_2]
         | 
    	
        utils/misc.py
    ADDED
    
    | @@ -0,0 +1,122 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os.path as osp
         | 
| 2 | 
            +
            import random
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import yaml
         | 
| 6 | 
            +
            import omegaconf
         | 
| 7 | 
            +
            from omegaconf import OmegaConf
         | 
| 8 | 
            +
            from scipy.spatial import cKDTree
         | 
| 9 | 
            +
            from pathlib import Path
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class KNNSearch(object):
         | 
| 13 | 
            +
                DTYPE = np.float32
         | 
| 14 | 
            +
                NJOBS = 4
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                def __init__(self, data):
         | 
| 17 | 
            +
                    self.data = np.asarray(data, dtype=self.DTYPE)
         | 
| 18 | 
            +
                    self.kdtree = cKDTree(self.data)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                def query(self, kpts, k, return_dists=False):
         | 
| 21 | 
            +
                    kpts = np.asarray(kpts, dtype=self.DTYPE)
         | 
| 22 | 
            +
                    nndists, nnindices = self.kdtree.query(kpts, k=k, workers=self.NJOBS)
         | 
| 23 | 
            +
                    if return_dists:
         | 
| 24 | 
            +
                        return nnindices, nndists
         | 
| 25 | 
            +
                    else:
         | 
| 26 | 
            +
                        return nnindices
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                def query_ball(self, kpt, radius):
         | 
| 29 | 
            +
                    kpt = np.asarray(kpt, dtype=self.DTYPE)
         | 
| 30 | 
            +
                    assert kpt.ndim == 1
         | 
| 31 | 
            +
                    nnindices = self.kdtree.query_ball_point(kpt, radius, n_jobs=self.NJOBS)
         | 
| 32 | 
            +
                    return nnindices
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            def validate_str(x):
         | 
| 36 | 
            +
                return x is not None and x != ''
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def hashing(arr, M):
         | 
| 40 | 
            +
                assert isinstance(arr, np.ndarray) and arr.ndim == 2
         | 
| 41 | 
            +
                N, D = arr.shape
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                hash_vec = np.zeros(N, dtype=np.int64)
         | 
| 44 | 
            +
                for d in range(D):
         | 
| 45 | 
            +
                    hash_vec += arr[:, d] * M**d
         | 
| 46 | 
            +
                return hash_vec
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            def omegaconf_to_dotdict(hparams):
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def _to_dot_dict(cfg):
         | 
| 52 | 
            +
                    res = {}
         | 
| 53 | 
            +
                    for k, v in cfg.items():
         | 
| 54 | 
            +
                        if v is None:
         | 
| 55 | 
            +
                            res[k] = v
         | 
| 56 | 
            +
                        elif isinstance(v, omegaconf.DictConfig):
         | 
| 57 | 
            +
                            res.update({k + "." + subk: subv for subk, subv in _to_dot_dict(v).items()})
         | 
| 58 | 
            +
                        elif isinstance(v, (str, int, float, bool)):
         | 
| 59 | 
            +
                            res[k] = v
         | 
| 60 | 
            +
                        elif isinstance(v, omegaconf.ListConfig):
         | 
| 61 | 
            +
                            res[k] = omegaconf.OmegaConf.to_container(v, resolve=True)
         | 
| 62 | 
            +
                        else:
         | 
| 63 | 
            +
                            raise RuntimeError('The type of {} is not supported.'.format(type(v)))
         | 
| 64 | 
            +
                    return res
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                return _to_dot_dict(hparams)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            def incrange(start, end, step):
         | 
| 70 | 
            +
                assert step > 0
         | 
| 71 | 
            +
                res = [start]
         | 
| 72 | 
            +
                if start <= end:
         | 
| 73 | 
            +
                    while res[-1] + step <= end:
         | 
| 74 | 
            +
                        res.append(res[-1] + step)
         | 
| 75 | 
            +
                else:
         | 
| 76 | 
            +
                    while res[-1] - step >= end:
         | 
| 77 | 
            +
                        res.append(res[-1] - step)
         | 
| 78 | 
            +
                return res
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
            def seeding(seed=0):
         | 
| 82 | 
            +
                torch.manual_seed(seed)
         | 
| 83 | 
            +
                torch.cuda.manual_seed_all(seed)
         | 
| 84 | 
            +
                np.random.seed(seed)
         | 
| 85 | 
            +
                random.seed(seed)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                torch.backends.cudnn.enabled = True
         | 
| 88 | 
            +
                torch.backends.cudnn.benchmark = True
         | 
| 89 | 
            +
                torch.backends.cudnn.deterministic = True
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            def run_trainer(trainer_cls):
         | 
| 93 | 
            +
                cfg_cli = OmegaConf.from_cli()
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                assert cfg_cli.run_mode is not None
         | 
| 96 | 
            +
                if cfg_cli.run_mode == 'train':
         | 
| 97 | 
            +
                    assert cfg_cli.run_cfg is not None
         | 
| 98 | 
            +
                    cfg = OmegaConf.merge(
         | 
| 99 | 
            +
                        OmegaConf.load(cfg_cli.run_cfg),
         | 
| 100 | 
            +
                        cfg_cli,
         | 
| 101 | 
            +
                    )
         | 
| 102 | 
            +
                    OmegaConf.resolve(cfg)
         | 
| 103 | 
            +
                    cfg = omegaconf_to_dotdict(cfg)
         | 
| 104 | 
            +
                    seeding(cfg['seed'])
         | 
| 105 | 
            +
                    trainer = trainer_cls(cfg)
         | 
| 106 | 
            +
                    trainer.train()
         | 
| 107 | 
            +
                    trainer.test()
         | 
| 108 | 
            +
                elif cfg_cli.run_mode == 'test':
         | 
| 109 | 
            +
                    assert cfg_cli.run_ckpt is not None
         | 
| 110 | 
            +
                    log_dir = str(Path(cfg_cli.run_ckpt).parent)
         | 
| 111 | 
            +
                    cfg = OmegaConf.merge(
         | 
| 112 | 
            +
                        OmegaConf.load(osp.join(log_dir, 'config.yml')),
         | 
| 113 | 
            +
                        cfg_cli,
         | 
| 114 | 
            +
                    )
         | 
| 115 | 
            +
                    OmegaConf.resolve(cfg)
         | 
| 116 | 
            +
                    cfg = omegaconf_to_dotdict(cfg)
         | 
| 117 | 
            +
                    cfg['test_ckpt'] = cfg_cli.run_ckpt
         | 
| 118 | 
            +
                    seeding(cfg['seed'])
         | 
| 119 | 
            +
                    trainer = trainer_cls(cfg)
         | 
| 120 | 
            +
                    trainer.test()
         | 
| 121 | 
            +
                else:
         | 
| 122 | 
            +
                    raise RuntimeError(f'Mode {cfg_cli.run_mode} is not supported.')
         | 
    	
        utils/pickle_stuff.py
    ADDED
    
    | @@ -0,0 +1,38 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import pickle
         | 
| 2 | 
            +
            import io
         | 
| 3 | 
            +
            import importlib
         | 
| 4 | 
            +
            import sys 
         | 
| 5 | 
            +
            import os 
         | 
| 6 | 
            +
            sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Mapping of old module names to new module names
         | 
| 10 | 
            +
            MODULE_RENAME_MAP = {
         | 
| 11 | 
            +
                'module': 'diffu_models',
         | 
| 12 | 
            +
                'module.model': 'diffu_models.precond',
         | 
| 13 | 
            +
                'module.dit_models': 'diffu_models.dit_models',
         | 
| 14 | 
            +
                'module.model': 'diffu_models.precond',
         | 
| 15 | 
            +
                # add more as needed
         | 
| 16 | 
            +
            }
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            class RenameUnpickler(pickle.Unpickler):
         | 
| 19 | 
            +
                def find_class(self, module, name):
         | 
| 20 | 
            +
                    if module in MODULE_RENAME_MAP:
         | 
| 21 | 
            +
                        module = MODULE_RENAME_MAP[module]
         | 
| 22 | 
            +
                    try:
         | 
| 23 | 
            +
                        return super().find_class(module, name)
         | 
| 24 | 
            +
                    except ModuleNotFoundError as e:
         | 
| 25 | 
            +
                        raise ModuleNotFoundError(f"Could not find module '{module}'. You may need to update MODULE_RENAME_MAP.") from e
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            # Usage
         | 
| 28 | 
            +
            def load_renamed_pickle(file_path):
         | 
| 29 | 
            +
                with open(file_path, 'rb') as f:
         | 
| 30 | 
            +
                    return RenameUnpickler(f).load()
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            def safe_load_with_fallback(file_path):
         | 
| 33 | 
            +
                try:
         | 
| 34 | 
            +
                    with open(file_path, 'rb') as f:
         | 
| 35 | 
            +
                        return pickle.load(f)
         | 
| 36 | 
            +
                except ModuleNotFoundError:
         | 
| 37 | 
            +
                    with open(file_path, 'rb') as f:
         | 
| 38 | 
            +
                        return RenameUnpickler(f).load()
         | 
    	
        utils/surfaces.py
    ADDED
    
    | @@ -0,0 +1,1377 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            import scipy as sp
         | 
| 3 | 
            +
            import scipy.linalg as spLA
         | 
| 4 | 
            +
            import scipy.sparse.linalg as sla
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            import logging
         | 
| 7 | 
            +
            import copy
         | 
| 8 | 
            +
            import potpourri3d as pp3d
         | 
| 9 | 
            +
            from sklearn.cluster import KMeans#, DBSCAN, SpectralClustering
         | 
| 10 | 
            +
            from scipy.spatial import cKDTree
         | 
| 11 | 
            +
            import trimesh as tm
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            try:
         | 
| 15 | 
            +
                from vtk import *
         | 
| 16 | 
            +
                import vtk.util.numpy_support as v2n
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                gotVTK = True
         | 
| 19 | 
            +
            except ImportError:
         | 
| 20 | 
            +
                print('could not import VTK functions')
         | 
| 21 | 
            +
                gotVTK = False
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            # import kernelFunctions as kfun
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            class vtkFields:
         | 
| 27 | 
            +
                def __init__(self):
         | 
| 28 | 
            +
                    self.scalars = []
         | 
| 29 | 
            +
                    self.vectors = []
         | 
| 30 | 
            +
                    self.normals = []
         | 
| 31 | 
            +
                    self.tensors = []
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            # General surface class, fibers possible. The fiber is a vector field of norm 1, defined on each vertex.
         | 
| 35 | 
            +
            # It must be an array of the same size (number of vertices)x3.
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            class Surface:
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                # Fibers: a list of vectors, with i-th element corresponding to the value of the vector field at vertex i
         | 
| 40 | 
            +
                # Contains as object:
         | 
| 41 | 
            +
                # vertices : all the vertices
         | 
| 42 | 
            +
                # centers: the centers of each face
         | 
| 43 | 
            +
                # faces: faces along with the id of the faces
         | 
| 44 | 
            +
                # surfel: surface element of each face (area*normal)
         | 
| 45 | 
            +
                # List of methods:
         | 
| 46 | 
            +
                # read : from filename type, call readFILENAME and set all surface attributes
         | 
| 47 | 
            +
                # updateVertices: update the whole surface after a modification of the vertices
         | 
| 48 | 
            +
                # computeVertexArea and Normals
         | 
| 49 | 
            +
                # getEdges
         | 
| 50 | 
            +
                # LocalSignedDistance distance function in the neighborhood of a shape
         | 
| 51 | 
            +
                # toPolyData: convert the surface to a polydata vtk object
         | 
| 52 | 
            +
                # fromPolyDate: guess what
         | 
| 53 | 
            +
                # Simplify : simplify the meshes
         | 
| 54 | 
            +
                # flipfaces: invert the indices of faces from [a, b, c] to [a, c, b]
         | 
| 55 | 
            +
                # smooth: get a smoother surface
         | 
| 56 | 
            +
                # Isosurface: compute isosurface
         | 
| 57 | 
            +
                # edgeRecove: ensure that orientation is correct
         | 
| 58 | 
            +
                # remove isolated: if isolated vertice, remove it
         | 
| 59 | 
            +
                # laplacianMatrix: compute Laplacian Matrix of the surface graph
         | 
| 60 | 
            +
                # graphLaplacianMatrix: ???
         | 
| 61 | 
            +
                # laplacianSegmentation: segment the surface using laplacian properties
         | 
| 62 | 
            +
                # surfVolume: compute volume inscribed in the surface (+ inscribed infinitesimal volume for each face)
         | 
| 63 | 
            +
                # surfCenter: compute surface Center
         | 
| 64 | 
            +
                # surfMoments: compute surface second order moments
         | 
| 65 | 
            +
                # surfU: compute the informations for surface rigid alignement
         | 
| 66 | 
            +
                # surfEllipsoid: compute ellipsoid representing the surface
         | 
| 67 | 
            +
                # savebyu of vitk or vtk2: save the surface in a file
         | 
| 68 | 
            +
                # concatenate: concatenate to another surface
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def __init__(self, surf=None, filename=None, FV=None):
         | 
| 71 | 
            +
                    if surf == None:
         | 
| 72 | 
            +
                        if FV == None:
         | 
| 73 | 
            +
                            if filename == None:
         | 
| 74 | 
            +
                                self.vertices = np.empty(0)
         | 
| 75 | 
            +
                                self.centers = np.empty(0)
         | 
| 76 | 
            +
                                self.faces = np.empty(0)
         | 
| 77 | 
            +
                                self.surfel = np.empty(0)
         | 
| 78 | 
            +
                            else:
         | 
| 79 | 
            +
                                if type(filename) is list:
         | 
| 80 | 
            +
                                    fvl = []
         | 
| 81 | 
            +
                                    for name in filename:
         | 
| 82 | 
            +
                                        fvl.append(Surface(filename=name))
         | 
| 83 | 
            +
                                    self.concatenate(fvl)
         | 
| 84 | 
            +
                                else:
         | 
| 85 | 
            +
                                    self.read(filename)
         | 
| 86 | 
            +
                        else:
         | 
| 87 | 
            +
                            self.vertices = np.copy(FV[1])
         | 
| 88 | 
            +
                            self.faces = np.int_(FV[0])
         | 
| 89 | 
            +
                            self.computeCentersAreas()
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    else:
         | 
| 92 | 
            +
                        self.vertices = np.copy(surf.vertices)
         | 
| 93 | 
            +
                        self.faces = np.copy(surf.faces)
         | 
| 94 | 
            +
                        self.surfel = np.copy(surf.surfel)
         | 
| 95 | 
            +
                        self.centers = np.copy(surf.centers)
         | 
| 96 | 
            +
                        self.computeCentersAreas()
         | 
| 97 | 
            +
                    self.volume, self.vols = self.surfVolume()
         | 
| 98 | 
            +
                    self.center = self.surfCenter()
         | 
| 99 | 
            +
                    self.cotanLaplacian()
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                def read(self, filename):
         | 
| 102 | 
            +
                    (mainPart, ext) = os.path.splitext(filename)
         | 
| 103 | 
            +
                    if ext == '.byu':
         | 
| 104 | 
            +
                        self.readbyu(filename)
         | 
| 105 | 
            +
                    elif ext == '.off':
         | 
| 106 | 
            +
                        self.readOFF(filename)
         | 
| 107 | 
            +
                    elif ext == '.vtk':
         | 
| 108 | 
            +
                        self.readVTK(filename)
         | 
| 109 | 
            +
                    elif ext == '.obj':
         | 
| 110 | 
            +
                        self.readOBJ(filename)
         | 
| 111 | 
            +
                    elif ext == '.ply':
         | 
| 112 | 
            +
                        self.readPLY(filename)
         | 
| 113 | 
            +
                    elif ext == '.tri' or ext == ".ntri":
         | 
| 114 | 
            +
                        self.readTRI(filename)
         | 
| 115 | 
            +
                    else:
         | 
| 116 | 
            +
                        print('Unknown Surface Extension:', ext)
         | 
| 117 | 
            +
                        self.vertices = np.empty(0)
         | 
| 118 | 
            +
                        self.centers = np.empty(0)
         | 
| 119 | 
            +
                        self.faces = np.empty(0)
         | 
| 120 | 
            +
                        self.surfel = np.empty(0)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                # face centers and area weighted normal
         | 
| 123 | 
            +
                def computeCentersAreas(self):
         | 
| 124 | 
            +
                    xDef1 = self.vertices[self.faces[:, 0], :]
         | 
| 125 | 
            +
                    xDef2 = self.vertices[self.faces[:, 1], :]
         | 
| 126 | 
            +
                    xDef3 = self.vertices[self.faces[:, 2], :]
         | 
| 127 | 
            +
                    self.centers = (xDef1 + xDef2 + xDef3) / 3
         | 
| 128 | 
            +
                    self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                # modify vertices without toplogical change
         | 
| 131 | 
            +
                def updateVertices(self, x0):
         | 
| 132 | 
            +
                    self.vertices = np.copy(x0)
         | 
| 133 | 
            +
                    xDef1 = self.vertices[self.faces[:, 0], :]
         | 
| 134 | 
            +
                    xDef2 = self.vertices[self.faces[:, 1], :]
         | 
| 135 | 
            +
                    xDef3 = self.vertices[self.faces[:, 2], :]
         | 
| 136 | 
            +
                    self.centers = (xDef1 + xDef2 + xDef3) / 3
         | 
| 137 | 
            +
                    self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
         | 
| 138 | 
            +
                    self.volume, self.vols = self.surfVolume()
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                def computeVertexArea(self):
         | 
| 141 | 
            +
                    # compute areas of faces and vertices
         | 
| 142 | 
            +
                    V = self.vertices
         | 
| 143 | 
            +
                    F = self.faces
         | 
| 144 | 
            +
                    nv = V.shape[0]
         | 
| 145 | 
            +
                    nf = F.shape[0]
         | 
| 146 | 
            +
                    AF = np.zeros([nf, 1])
         | 
| 147 | 
            +
                    AV = np.zeros([nv, 1])
         | 
| 148 | 
            +
                    for k in range(nf):
         | 
| 149 | 
            +
                        # determining if face is obtuse
         | 
| 150 | 
            +
                        x12 = V[F[k, 1], :] - V[F[k, 0], :]
         | 
| 151 | 
            +
                        x13 = V[F[k, 2], :] - V[F[k, 0], :]
         | 
| 152 | 
            +
                        n12 = np.sqrt((x12 ** 2).sum())
         | 
| 153 | 
            +
                        n13 = np.sqrt((x13 ** 2).sum())
         | 
| 154 | 
            +
                        c1 = (x12 * x13).sum() / (n12 * n13)
         | 
| 155 | 
            +
                        x23 = V[F[k, 2], :] - V[F[k, 1], :]
         | 
| 156 | 
            +
                        n23 = np.sqrt((x23 ** 2).sum())
         | 
| 157 | 
            +
                        # n23 = norm(x23) ;
         | 
| 158 | 
            +
                        c2 = -(x12 * x23).sum() / (n12 * n23)
         | 
| 159 | 
            +
                        c3 = (x13 * x23).sum() / (n13 * n23)
         | 
| 160 | 
            +
                        AF[k] = np.sqrt((np.cross(x12, x13) ** 2).sum()) / 2
         | 
| 161 | 
            +
                        if (c1 < 0):
         | 
| 162 | 
            +
                            # face obtuse at vertex 1
         | 
| 163 | 
            +
                            AV[F[k, 0]] += AF[k] / 2
         | 
| 164 | 
            +
                            AV[F[k, 1]] += AF[k] / 4
         | 
| 165 | 
            +
                            AV[F[k, 2]] += AF[k] / 4
         | 
| 166 | 
            +
                        elif (c2 < 0):
         | 
| 167 | 
            +
                            # face obuse at vertex 2
         | 
| 168 | 
            +
                            AV[F[k, 0]] += AF[k] / 4
         | 
| 169 | 
            +
                            AV[F[k, 1]] += AF[k] / 2
         | 
| 170 | 
            +
                            AV[F[k, 2]] += AF[k] / 4
         | 
| 171 | 
            +
                        elif (c3 < 0):
         | 
| 172 | 
            +
                            # face obtuse at vertex 3
         | 
| 173 | 
            +
                            AV[F[k, 0]] += AF[k] / 4
         | 
| 174 | 
            +
                            AV[F[k, 1]] += AF[k] / 4
         | 
| 175 | 
            +
                            AV[F[k, 2]] += AF[k] / 2
         | 
| 176 | 
            +
                        else:
         | 
| 177 | 
            +
                            # non obtuse face
         | 
| 178 | 
            +
                            cot1 = c1 / np.sqrt(1 - c1 ** 2)
         | 
| 179 | 
            +
                            cot2 = c2 / np.sqrt(1 - c2 ** 2)
         | 
| 180 | 
            +
                            cot3 = c3 / np.sqrt(1 - c3 ** 2)
         | 
| 181 | 
            +
                            AV[F[k, 0]] += ((x12 ** 2).sum() * cot3 + (x13 ** 2).sum() * cot2) / 8
         | 
| 182 | 
            +
                            AV[F[k, 1]] += ((x12 ** 2).sum() * cot3 + (x23 ** 2).sum() * cot1) / 8
         | 
| 183 | 
            +
                            AV[F[k, 2]] += ((x13 ** 2).sum() * cot2 + (x23 ** 2).sum() * cot1) / 8
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    for k in range(nv):
         | 
| 186 | 
            +
                        if (np.fabs(AV[k]) < 1e-10):
         | 
| 187 | 
            +
                            print('Warning: vertex ', k, 'has no face; use removeIsolated')
         | 
| 188 | 
            +
                    # print('sum check area:', AF.sum(), AV.sum()
         | 
| 189 | 
            +
                    return AV, AF
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                def computeVertexNormals(self):
         | 
| 192 | 
            +
                    self.computeCentersAreas()
         | 
| 193 | 
            +
                    normals = np.zeros(self.vertices.shape)
         | 
| 194 | 
            +
                    F = self.faces
         | 
| 195 | 
            +
                    for k in range(F.shape[0]):
         | 
| 196 | 
            +
                        normals[F[k, 0]] += self.surfel[k]
         | 
| 197 | 
            +
                        normals[F[k, 1]] += self.surfel[k]
         | 
| 198 | 
            +
                        normals[F[k, 2]] += self.surfel[k]
         | 
| 199 | 
            +
                    af = np.sqrt((normals ** 2).sum(axis=1))
         | 
| 200 | 
            +
                    # logging.info('min area = %.4f'%(af.min()))
         | 
| 201 | 
            +
                    normals /= af.reshape([self.vertices.shape[0], 1])
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    return normals
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                # Computes edges from vertices/faces
         | 
| 206 | 
            +
                def getEdges(self):
         | 
| 207 | 
            +
                    self.edges = []
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                    for k in range(self.faces.shape[0]):
         | 
| 210 | 
            +
                        for kj in (0, 1, 2):
         | 
| 211 | 
            +
                            u = [self.faces[k, kj], self.faces[k, (kj + 1) % 3]]
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                            if (u not in self.edges) & (u.reverse() not in self.edges):
         | 
| 214 | 
            +
                                self.edges.append(u)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    self.edgeFaces = []
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    for u in self.edges:
         | 
| 219 | 
            +
                        self.edgeFaces.append([])
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    for k in range(self.faces.shape[0]):
         | 
| 222 | 
            +
                        for kj in (0, 1, 2):
         | 
| 223 | 
            +
                            u = [self.faces[k, kj], self.faces[k, (kj + 1) % 3]]
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                            if u in self.edges:
         | 
| 226 | 
            +
                                kk = self.edges.index(u)
         | 
| 227 | 
            +
                            else:
         | 
| 228 | 
            +
                                u.reverse()
         | 
| 229 | 
            +
                                kk = self.edges.index(u)
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                            self.edgeFaces[kk].append(k)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    self.edges = np.int_(np.array(self.edges))
         | 
| 234 | 
            +
                    self.bdry = np.int_(np.zeros(self.edges.shape[0]))
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    for k in range(self.edges.shape[0]):
         | 
| 237 | 
            +
                        if len(self.edgeFaces[k]) < 2:
         | 
| 238 | 
            +
                            self.bdry[k] = 1
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                # computes the signed distance function in a small neighborhood of a shape
         | 
| 241 | 
            +
                def LocalSignedDistance(self, data, value):
         | 
| 242 | 
            +
                    d2 = 2 * np.array(data >= value) - 1
         | 
| 243 | 
            +
                    c2 = np.cumsum(d2, axis=0)
         | 
| 244 | 
            +
                    for j in range(2):
         | 
| 245 | 
            +
                        c2 = np.cumsum(c2, axis=j + 1)
         | 
| 246 | 
            +
                    (n0, n1, n2) = c2.shape
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    rad = 3
         | 
| 249 | 
            +
                    diam = 2 * rad + 1
         | 
| 250 | 
            +
                    (x, y, z) = np.mgrid[-rad:rad + 1, -rad:rad + 1, -rad:rad + 1]
         | 
| 251 | 
            +
                    cube = (x ** 2 + y ** 2 + z ** 2)
         | 
| 252 | 
            +
                    maxval = (diam) ** 3
         | 
| 253 | 
            +
                    s = 3.0 * rad ** 2
         | 
| 254 | 
            +
                    res = d2 * s
         | 
| 255 | 
            +
                    u = maxval * np.ones(c2.shape)
         | 
| 256 | 
            +
                    u[rad + 1:n0 - rad, rad + 1:n1 - rad, rad + 1:n2 - rad] = (c2[diam:n0, diam:n1, diam:n2]
         | 
| 257 | 
            +
                                                                               - c2[0:n0 - diam, diam:n1, diam:n2] - c2[diam:n0,
         | 
| 258 | 
            +
                                                                                                                     0:n1 - diam,
         | 
| 259 | 
            +
                                                                                                                     diam:n2] - c2[
         | 
| 260 | 
            +
                                                                                                                                diam:n0,
         | 
| 261 | 
            +
                                                                                                                                diam:n1,
         | 
| 262 | 
            +
                                                                                                                                0:n2 - diam]
         | 
| 263 | 
            +
                                                                               + c2[0:n0 - diam, 0:n1 - diam, diam:n2] + c2[diam:n0,
         | 
| 264 | 
            +
                                                                                                                         0:n1 - diam,
         | 
| 265 | 
            +
                                                                                                                         0:n2 - diam] + c2[
         | 
| 266 | 
            +
                                                                                                                                        0:n0 - diam,
         | 
| 267 | 
            +
                                                                                                                                        diam:n1,
         | 
| 268 | 
            +
                                                                                                                                        0:n2 - diam]
         | 
| 269 | 
            +
                                                                               - c2[0:n0 - diam, 0:n1 - diam, 0:n2 - diam])
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    I = np.nonzero(np.fabs(u) < maxval)
         | 
| 272 | 
            +
                    # print(len(I[0]))
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                    for k in range(len(I[0])):
         | 
| 275 | 
            +
                        p = np.array((I[0][k], I[1][k], I[2][k]))
         | 
| 276 | 
            +
                        bmin = p - rad
         | 
| 277 | 
            +
                        bmax = p + rad + 1
         | 
| 278 | 
            +
                        # print(p, bmin, bmax)
         | 
| 279 | 
            +
                        if (d2[p[0], p[1], p[2]] > 0):
         | 
| 280 | 
            +
                            # print(u[p[0],p[1], p[2]])
         | 
| 281 | 
            +
                            # print(d2[bmin[0]:bmax[0], bmin[1]:bmax[1], bmin[2]:bmax[2]].sum())
         | 
| 282 | 
            +
                            res[p[0], p[1], p[2]] = min(
         | 
| 283 | 
            +
                                cube[np.nonzero(d2[bmin[0]:bmax[0], bmin[1]:bmax[1], bmin[2]:bmax[2]] < 0)]) - .25
         | 
| 284 | 
            +
                        else:
         | 
| 285 | 
            +
                            res[p[0], p[1], p[2]] = - min(
         | 
| 286 | 
            +
                                cube[np.nonzero(d2[bmin[0]:bmax[0], bmin[1]:bmax[1], bmin[2]:bmax[2]] > 0)]) - .25
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    return res
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                def toPolyData(self):
         | 
| 291 | 
            +
                    if gotVTK:
         | 
| 292 | 
            +
                        points = vtkPoints()
         | 
| 293 | 
            +
                        for k in range(self.vertices.shape[0]):
         | 
| 294 | 
            +
                            points.InsertNextPoint(self.vertices[k, 0], self.vertices[k, 1], self.vertices[k, 2])
         | 
| 295 | 
            +
                        polys = vtkCellArray()
         | 
| 296 | 
            +
                        for k in range(self.faces.shape[0]):
         | 
| 297 | 
            +
                            polys.InsertNextCell(3)
         | 
| 298 | 
            +
                            for kk in range(3):
         | 
| 299 | 
            +
                                polys.InsertCellPoint(self.faces[k, kk])
         | 
| 300 | 
            +
                        polydata = vtkPolyData()
         | 
| 301 | 
            +
                        polydata.SetPoints(points)
         | 
| 302 | 
            +
                        polydata.SetPolys(polys)
         | 
| 303 | 
            +
                        return polydata
         | 
| 304 | 
            +
                    else:
         | 
| 305 | 
            +
                        raise Exception('Cannot run toPolyData without VTK')
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                def fromPolyData(self, g, scales=[1., 1., 1.]):
         | 
| 308 | 
            +
                    npoints = int(g.GetNumberOfPoints())
         | 
| 309 | 
            +
                    nfaces = int(g.GetNumberOfPolys())
         | 
| 310 | 
            +
                    logging.info('Dimensions: %d %d %d' % (npoints, nfaces, g.GetNumberOfCells()))
         | 
| 311 | 
            +
                    V = np.zeros([npoints, 3])
         | 
| 312 | 
            +
                    for kk in range(npoints):
         | 
| 313 | 
            +
                        V[kk, :] = np.array(g.GetPoint(kk))
         | 
| 314 | 
            +
                        # print(kk, V[kk])
         | 
| 315 | 
            +
                        # print(kk, np.array(g.GetPoint(kk)))
         | 
| 316 | 
            +
                    F = np.zeros([nfaces, 3])
         | 
| 317 | 
            +
                    gf = 0
         | 
| 318 | 
            +
                    for kk in range(g.GetNumberOfCells()):
         | 
| 319 | 
            +
                        c = g.GetCell(kk)
         | 
| 320 | 
            +
                        if (c.GetNumberOfPoints() == 3):
         | 
| 321 | 
            +
                            for ll in range(3):
         | 
| 322 | 
            +
                                F[gf, ll] = c.GetPointId(ll)
         | 
| 323 | 
            +
                                # print(kk, gf, F[gf])
         | 
| 324 | 
            +
                            gf += 1
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                            # self.vertices = np.multiply(data.shape-V-1, scales)
         | 
| 327 | 
            +
                    self.vertices = np.multiply(V, scales)
         | 
| 328 | 
            +
                    self.faces = np.int_(F[0:gf, :])
         | 
| 329 | 
            +
                    self.computeCentersAreas()
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                def Simplify(self, target=1000.0):
         | 
| 332 | 
            +
                    if gotVTK:
         | 
| 333 | 
            +
                        polydata = self.toPolyData()
         | 
| 334 | 
            +
                        dc = vtkQuadricDecimation()
         | 
| 335 | 
            +
                        red = 1 - min(np.float(target) / polydata.GetNumberOfPoints(), 1)
         | 
| 336 | 
            +
                        dc.SetTargetReduction(red)
         | 
| 337 | 
            +
                        dc.SetInput(polydata)
         | 
| 338 | 
            +
                        dc.Update()
         | 
| 339 | 
            +
                        g = dc.GetOutput()
         | 
| 340 | 
            +
                        self.fromPolyData(g)
         | 
| 341 | 
            +
                        z = self.surfVolume()
         | 
| 342 | 
            +
                        if (z > 0):
         | 
| 343 | 
            +
                            self.flipFaces()
         | 
| 344 | 
            +
                            print('flipping volume', z, self.surfVolume())
         | 
| 345 | 
            +
                    else:
         | 
| 346 | 
            +
                        raise Exception('Cannot run Simplify without VTK')
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                def flipFaces(self):
         | 
| 349 | 
            +
                    self.faces = self.faces[:, [0, 2, 1]]
         | 
| 350 | 
            +
                    self.computeCentersAreas()
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                def smooth(self, n=30, smooth=0.1):
         | 
| 353 | 
            +
                    if gotVTK:
         | 
| 354 | 
            +
                        g = self.toPolyData()
         | 
| 355 | 
            +
                        print(g)
         | 
| 356 | 
            +
                        smoother = vtkWindowedSincPolyDataFilter()
         | 
| 357 | 
            +
                        smoother.SetInput(g)
         | 
| 358 | 
            +
                        smoother.SetNumberOfIterations(n)
         | 
| 359 | 
            +
                        smoother.SetPassBand(smooth)
         | 
| 360 | 
            +
                        smoother.NonManifoldSmoothingOn()
         | 
| 361 | 
            +
                        smoother.NormalizeCoordinatesOn()
         | 
| 362 | 
            +
                        smoother.GenerateErrorScalarsOn()
         | 
| 363 | 
            +
                        # smoother.GenerateErrorVectorsOn()
         | 
| 364 | 
            +
                        smoother.Update()
         | 
| 365 | 
            +
                        g = smoother.GetOutput()
         | 
| 366 | 
            +
                        self.fromPolyData(g)
         | 
| 367 | 
            +
                    else:
         | 
| 368 | 
            +
                        raise Exception('Cannot run smooth without VTK')
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                # Computes isosurfaces using vtk
         | 
| 371 | 
            +
                def Isosurface(self, data, value=0.5, target=1000.0, scales=[1., 1., 1.], smooth=0.1, fill_holes=1.):
         | 
| 372 | 
            +
                    if gotVTK:
         | 
| 373 | 
            +
                        # data = self.LocalSignedDistance(data0, value)
         | 
| 374 | 
            +
                        if isinstance(data, vtkImageData):
         | 
| 375 | 
            +
                            img = data
         | 
| 376 | 
            +
                        else:
         | 
| 377 | 
            +
                            img = vtkImageData()
         | 
| 378 | 
            +
                            img.SetDimensions(data.shape)
         | 
| 379 | 
            +
                            img.SetOrigin(0, 0, 0)
         | 
| 380 | 
            +
                            if vtkVersion.GetVTKMajorVersion() >= 6:
         | 
| 381 | 
            +
                                img.AllocateScalars(VTK_FLOAT, 1)
         | 
| 382 | 
            +
                            else:
         | 
| 383 | 
            +
                                img.SetNumberOfScalarComponents(1)
         | 
| 384 | 
            +
                            v = vtkDoubleArray()
         | 
| 385 | 
            +
                            v.SetNumberOfValues(data.size)
         | 
| 386 | 
            +
                            v.SetNumberOfComponents(1)
         | 
| 387 | 
            +
                            for ii, tmp in enumerate(np.ravel(data, order='F')):
         | 
| 388 | 
            +
                                v.SetValue(ii, tmp)
         | 
| 389 | 
            +
                                img.GetPointData().SetScalars(v)
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                        cf = vtkContourFilter()
         | 
| 392 | 
            +
                        if vtkVersion.GetVTKMajorVersion() >= 6:
         | 
| 393 | 
            +
                            cf.SetInputData(img)
         | 
| 394 | 
            +
                        else:
         | 
| 395 | 
            +
                            cf.SetInput(img)
         | 
| 396 | 
            +
                        cf.SetValue(0, value)
         | 
| 397 | 
            +
                        cf.SetNumberOfContours(1)
         | 
| 398 | 
            +
                        cf.Update()
         | 
| 399 | 
            +
                        # print(cf
         | 
| 400 | 
            +
                        connectivity = vtkPolyDataConnectivityFilter()
         | 
| 401 | 
            +
                        connectivity.ScalarConnectivityOff()
         | 
| 402 | 
            +
                        connectivity.SetExtractionModeToLargestRegion()
         | 
| 403 | 
            +
                        if vtkVersion.GetVTKMajorVersion() >= 6:
         | 
| 404 | 
            +
                            connectivity.SetInputData(cf.GetOutput())
         | 
| 405 | 
            +
                        else:
         | 
| 406 | 
            +
                            connectivity.SetInput(cf.GetOutput())
         | 
| 407 | 
            +
                        connectivity.Update()
         | 
| 408 | 
            +
                        g = connectivity.GetOutput()
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                        if smooth > 0:
         | 
| 411 | 
            +
                            smoother = vtkWindowedSincPolyDataFilter()
         | 
| 412 | 
            +
                            if vtkVersion.GetVTKMajorVersion() >= 6:
         | 
| 413 | 
            +
                                smoother.SetInputData(g)
         | 
| 414 | 
            +
                            else:
         | 
| 415 | 
            +
                                smoother.SetInput(g)
         | 
| 416 | 
            +
                            #     else:
         | 
| 417 | 
            +
                            # smoother.SetInputConnection(contour.GetOutputPort())
         | 
| 418 | 
            +
                            smoother.SetNumberOfIterations(30)
         | 
| 419 | 
            +
                            # this has little effect on the error!
         | 
| 420 | 
            +
                            # smoother.BoundarySmoothingOff()
         | 
| 421 | 
            +
                            # smoother.FeatureEdgeSmoothingOff()
         | 
| 422 | 
            +
                            # smoother.SetFeatureAngle(120.0)
         | 
| 423 | 
            +
                            smoother.SetPassBand(smooth)  # this increases the error a lot!
         | 
| 424 | 
            +
                            smoother.NonManifoldSmoothingOn()
         | 
| 425 | 
            +
                            # smoother.NormalizeCoordinatesOn()
         | 
| 426 | 
            +
                            # smoother.GenerateErrorScalarsOn()
         | 
| 427 | 
            +
                            # smoother.GenerateErrorVectorsOn()
         | 
| 428 | 
            +
                            smoother.Update()
         | 
| 429 | 
            +
                            g = smoother.GetOutput()
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                        # dc = vtkDecimatePro()
         | 
| 432 | 
            +
                        red = 1 - min(np.float(target) / g.GetNumberOfPoints(), 1)
         | 
| 433 | 
            +
                        # print('Reduction: ', red)
         | 
| 434 | 
            +
                        dc = vtkQuadricDecimation()
         | 
| 435 | 
            +
                        dc.SetTargetReduction(red)
         | 
| 436 | 
            +
                        # dc.AttributeErrorMetricOn()
         | 
| 437 | 
            +
                        # dc.SetDegree(10)
         | 
| 438 | 
            +
                        # dc.SetSplitting(0)
         | 
| 439 | 
            +
                        if vtkVersion.GetVTKMajorVersion() >= 6:
         | 
| 440 | 
            +
                            dc.SetInputData(g)
         | 
| 441 | 
            +
                        else:
         | 
| 442 | 
            +
                            dc.SetInput(g)
         | 
| 443 | 
            +
                            # dc.SetInput(g)
         | 
| 444 | 
            +
                        # print(dc)
         | 
| 445 | 
            +
                        dc.Update()
         | 
| 446 | 
            +
                        g = dc.GetOutput()
         | 
| 447 | 
            +
                        # print('points:', g.GetNumberOfPoints())
         | 
| 448 | 
            +
                        cp = vtkCleanPolyData()
         | 
| 449 | 
            +
                        if vtkVersion.GetVTKMajorVersion() >= 6:
         | 
| 450 | 
            +
                            cp.SetInputData(dc.GetOutput())
         | 
| 451 | 
            +
                        else:
         | 
| 452 | 
            +
                            cp.SetInput(dc.GetOutput())
         | 
| 453 | 
            +
                            #        cp.SetInput(dc.GetOutput())
         | 
| 454 | 
            +
                        # cp.SetPointMerging(1)
         | 
| 455 | 
            +
                        cp.ConvertPolysToLinesOn()
         | 
| 456 | 
            +
                        cp.SetAbsoluteTolerance(1e-5)
         | 
| 457 | 
            +
                        cp.Update()
         | 
| 458 | 
            +
                        g = cp.GetOutput()
         | 
| 459 | 
            +
                        self.fromPolyData(g, scales)
         | 
| 460 | 
            +
                        z = self.surfVolume()
         | 
| 461 | 
            +
                        if (z > 0):
         | 
| 462 | 
            +
                            self.flipFaces()
         | 
| 463 | 
            +
                            # print('flipping volume', z, self.surfVolume())
         | 
| 464 | 
            +
                            logging.info('flipping volume %.2f %.2f' % (z, self.surfVolume()))
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                        # print(g)
         | 
| 467 | 
            +
                        # npoints = int(g.GetNumberOfPoints())
         | 
| 468 | 
            +
                        # nfaces = int(g.GetNumberOfPolys())
         | 
| 469 | 
            +
                        # print('Dimensions:', npoints, nfaces, g.GetNumberOfCells())
         | 
| 470 | 
            +
                        # V = np.zeros([npoints, 3])
         | 
| 471 | 
            +
                        # for kk in range(npoints):
         | 
| 472 | 
            +
                        #     V[kk, :] = np.array(g.GetPoint(kk))
         | 
| 473 | 
            +
                        #     #print(kk, V[kk])
         | 
| 474 | 
            +
                        #     #print(kk, np.array(g.GetPoint(kk)))
         | 
| 475 | 
            +
                        # F = np.zeros([nfaces, 3])
         | 
| 476 | 
            +
                        # gf = 0
         | 
| 477 | 
            +
                        # for kk in range(g.GetNumberOfCells()):
         | 
| 478 | 
            +
                        #     c = g.GetCell(kk)
         | 
| 479 | 
            +
                        #     if(c.GetNumberOfPoints() == 3):
         | 
| 480 | 
            +
                        #         for ll in range(3):
         | 
| 481 | 
            +
                        #             F[gf,ll] = c.GetPointId(ll)
         | 
| 482 | 
            +
                        #             #print(kk, gf, F[gf])
         | 
| 483 | 
            +
                        #         gf += 1
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                        #         #self.vertices = np.multiply(data.shape-V-1, scales)
         | 
| 486 | 
            +
                        # self.vertices = np.multiply(V, scales)
         | 
| 487 | 
            +
                        # self.faces = np.int_(F[0:gf, :])
         | 
| 488 | 
            +
                        # self.computeCentersAreas()
         | 
| 489 | 
            +
                    else:
         | 
| 490 | 
            +
                        raise Exception('Cannot run Isosurface without VTK')
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                # Ensures that orientation is correct
         | 
| 493 | 
            +
                def edgeRecover(self):
         | 
| 494 | 
            +
                    v = self.vertices
         | 
| 495 | 
            +
                    f = self.faces
         | 
| 496 | 
            +
                    nv = v.shape[0]
         | 
| 497 | 
            +
                    nf = f.shape[0]
         | 
| 498 | 
            +
                    # faces containing each oriented edge
         | 
| 499 | 
            +
                    edg0 = np.int_(np.zeros((nv, nv)))
         | 
| 500 | 
            +
                    # number of edges between each vertex
         | 
| 501 | 
            +
                    edg = np.int_(np.zeros((nv, nv)))
         | 
| 502 | 
            +
                    # contiguous faces
         | 
| 503 | 
            +
                    edgF = np.int_(np.zeros((nf, nf)))
         | 
| 504 | 
            +
                    for (kf, c) in enumerate(f):
         | 
| 505 | 
            +
                        if (edg0[c[0], c[1]] > 0):
         | 
| 506 | 
            +
                            edg0[c[1], c[0]] = kf + 1
         | 
| 507 | 
            +
                        else:
         | 
| 508 | 
            +
                            edg0[c[0], c[1]] = kf + 1
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                        if (edg0[c[1], c[2]] > 0):
         | 
| 511 | 
            +
                            edg0[c[2], c[1]] = kf + 1
         | 
| 512 | 
            +
                        else:
         | 
| 513 | 
            +
                            edg0[c[1], c[2]] = kf + 1
         | 
| 514 | 
            +
             | 
| 515 | 
            +
                        if (edg0[c[2], c[0]] > 0):
         | 
| 516 | 
            +
                            edg0[c[0], c[2]] = kf + 1
         | 
| 517 | 
            +
                        else:
         | 
| 518 | 
            +
                            edg0[c[2], c[0]] = kf + 1
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                        edg[c[0], c[1]] += 1
         | 
| 521 | 
            +
                        edg[c[1], c[2]] += 1
         | 
| 522 | 
            +
                        edg[c[2], c[0]] += 1
         | 
| 523 | 
            +
             | 
| 524 | 
            +
                    for kv in range(nv):
         | 
| 525 | 
            +
                        I2 = np.nonzero(edg0[kv, :])
         | 
| 526 | 
            +
                        for kkv in I2[0].tolist():
         | 
| 527 | 
            +
                            edgF[edg0[kkv, kv] - 1, edg0[kv, kkv] - 1] = kv + 1
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                    isOriented = np.int_(np.zeros(f.shape[0]))
         | 
| 530 | 
            +
                    isActive = np.int_(np.zeros(f.shape[0]))
         | 
| 531 | 
            +
                    I = np.nonzero(np.squeeze(edgF[0, :]))
         | 
| 532 | 
            +
                    # list of faces to be oriented
         | 
| 533 | 
            +
                    activeList = [0] + I[0].tolist()
         | 
| 534 | 
            +
                    lastOriented = 0
         | 
| 535 | 
            +
                    isOriented[0] = True
         | 
| 536 | 
            +
                    for k in activeList:
         | 
| 537 | 
            +
                        isActive[k] = True
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                    while lastOriented < len(activeList) - 1:
         | 
| 540 | 
            +
                        i = activeList[lastOriented]
         | 
| 541 | 
            +
                        j = activeList[lastOriented + 1]
         | 
| 542 | 
            +
                        I = np.nonzero(edgF[j, :])
         | 
| 543 | 
            +
                        foundOne = False
         | 
| 544 | 
            +
                        for kk in I[0].tolist():
         | 
| 545 | 
            +
                            if (foundOne == False) & (isOriented[kk]):
         | 
| 546 | 
            +
                                foundOne = True
         | 
| 547 | 
            +
                                u1 = edgF[j, kk] - 1
         | 
| 548 | 
            +
                                u2 = edgF[kk, j] - 1
         | 
| 549 | 
            +
                                if not ((edg[u1, u2] == 1) & (edg[u2, u1] == 1)):
         | 
| 550 | 
            +
                                    # reorient face j
         | 
| 551 | 
            +
                                    edg[f[j, 0], f[j, 1]] -= 1
         | 
| 552 | 
            +
                                    edg[f[j, 1], f[j, 2]] -= 1
         | 
| 553 | 
            +
                                    edg[f[j, 2], f[j, 0]] -= 1
         | 
| 554 | 
            +
                                    a = f[j, 1]
         | 
| 555 | 
            +
                                    f[j, 1] = f[j, 2]
         | 
| 556 | 
            +
                                    f[j, 2] = a
         | 
| 557 | 
            +
                                    edg[f[j, 0], f[j, 1]] += 1
         | 
| 558 | 
            +
                                    edg[f[j, 1], f[j, 2]] += 1
         | 
| 559 | 
            +
                                    edg[f[j, 2], f[j, 0]] += 1
         | 
| 560 | 
            +
                            elif (not isActive[kk]):
         | 
| 561 | 
            +
                                activeList.append(kk)
         | 
| 562 | 
            +
                                isActive[kk] = True
         | 
| 563 | 
            +
                        if foundOne:
         | 
| 564 | 
            +
                            lastOriented = lastOriented + 1
         | 
| 565 | 
            +
                            isOriented[j] = True
         | 
| 566 | 
            +
                            # print('oriented face', j, lastOriented,  'out of',  nf,  ';  total active', len(activeList))
         | 
| 567 | 
            +
                        else:
         | 
| 568 | 
            +
                            print('Unable to orient face', j)
         | 
| 569 | 
            +
                            return
         | 
| 570 | 
            +
                    self.vertices = v;
         | 
| 571 | 
            +
                    self.faces = f;
         | 
| 572 | 
            +
             | 
| 573 | 
            +
                    z, _ = self.surfVolume()
         | 
| 574 | 
            +
                    if (z > 0):
         | 
| 575 | 
            +
                        self.flipFaces()
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                def removeIsolated(self):
         | 
| 578 | 
            +
                    N = self.vertices.shape[0]
         | 
| 579 | 
            +
                    inFace = np.int_(np.zeros(N))
         | 
| 580 | 
            +
                    for k in range(3):
         | 
| 581 | 
            +
                        inFace[self.faces[:, k]] = 1
         | 
| 582 | 
            +
                    J = np.nonzero(inFace)
         | 
| 583 | 
            +
                    self.vertices = self.vertices[J[0], :]
         | 
| 584 | 
            +
                    logging.info('Found %d isolated vertices' % (J[0].shape[0]))
         | 
| 585 | 
            +
                    Q = -np.ones(N)
         | 
| 586 | 
            +
                    for k, j in enumerate(J[0]):
         | 
| 587 | 
            +
                        Q[j] = k
         | 
| 588 | 
            +
                    self.faces = np.int_(Q[self.faces])
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                def laplacianMatrix(self):
         | 
| 591 | 
            +
                    F = self.faces
         | 
| 592 | 
            +
                    V = self.vertices;
         | 
| 593 | 
            +
                    nf = F.shape[0]
         | 
| 594 | 
            +
                    nv = V.shape[0]
         | 
| 595 | 
            +
             | 
| 596 | 
            +
                    AV, AF = self.computeVertexArea()
         | 
| 597 | 
            +
             | 
| 598 | 
            +
                    # compute edges and detect boundary
         | 
| 599 | 
            +
                    # edm = sp.lil_matrix((nv,nv))
         | 
| 600 | 
            +
                    edm = -np.ones([nv, nv]).astype(np.int32)
         | 
| 601 | 
            +
                    E = np.zeros([3 * nf, 2]).astype(np.int32)
         | 
| 602 | 
            +
                    j = 0
         | 
| 603 | 
            +
                    for k in range(nf):
         | 
| 604 | 
            +
                        if (edm[F[k, 0], F[k, 1]] == -1):
         | 
| 605 | 
            +
                            edm[F[k, 0], F[k, 1]] = j
         | 
| 606 | 
            +
                            edm[F[k, 1], F[k, 0]] = j
         | 
| 607 | 
            +
                            E[j, :] = [F[k, 0], F[k, 1]]
         | 
| 608 | 
            +
                            j = j + 1
         | 
| 609 | 
            +
                        if (edm[F[k, 1], F[k, 2]] == -1):
         | 
| 610 | 
            +
                            edm[F[k, 1], F[k, 2]] = j
         | 
| 611 | 
            +
                            edm[F[k, 2], F[k, 1]] = j
         | 
| 612 | 
            +
                            E[j, :] = [F[k, 1], F[k, 2]]
         | 
| 613 | 
            +
                            j = j + 1
         | 
| 614 | 
            +
                        if (edm[F[k, 0], F[k, 2]] == -1):
         | 
| 615 | 
            +
                            edm[F[k, 2], F[k, 0]] = j
         | 
| 616 | 
            +
                            edm[F[k, 0], F[k, 2]] = j
         | 
| 617 | 
            +
                            E[j, :] = [F[k, 2], F[k, 0]]
         | 
| 618 | 
            +
                            j = j + 1
         | 
| 619 | 
            +
                    E = E[0:j, :]
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                    edgeFace = np.zeros([j, nf])
         | 
| 622 | 
            +
                    ne = j
         | 
| 623 | 
            +
                    # print(E)
         | 
| 624 | 
            +
                    for k in range(nf):
         | 
| 625 | 
            +
                        edgeFace[edm[F[k, 0], F[k, 1]], k] = 1
         | 
| 626 | 
            +
                        edgeFace[edm[F[k, 1], F[k, 2]], k] = 1
         | 
| 627 | 
            +
                        edgeFace[edm[F[k, 2], F[k, 0]], k] = 1
         | 
| 628 | 
            +
             | 
| 629 | 
            +
                    bEdge = np.zeros([ne, 1])
         | 
| 630 | 
            +
                    bVert = np.zeros([nv, 1])
         | 
| 631 | 
            +
                    edgeAngles = np.zeros([ne, 2])
         | 
| 632 | 
            +
                    for k in range(ne):
         | 
| 633 | 
            +
                        I = np.flatnonzero(edgeFace[k, :])
         | 
| 634 | 
            +
                        # print('I=', I, F[I, :], E.shape)
         | 
| 635 | 
            +
                        # print('E[k, :]=', k, E[k, :])
         | 
| 636 | 
            +
                        # print(k, edgeFace[k, :])
         | 
| 637 | 
            +
                        for u in range(len(I)):
         | 
| 638 | 
            +
                            f = I[u]
         | 
| 639 | 
            +
                            i1l = np.flatnonzero(F[f, :] == E[k, 0])
         | 
| 640 | 
            +
                            i2l = np.flatnonzero(F[f, :] == E[k, 1])
         | 
| 641 | 
            +
                            # print(f, F[f, :])
         | 
| 642 | 
            +
                            # print(i1l, i2l)
         | 
| 643 | 
            +
                            i1 = i1l[0]
         | 
| 644 | 
            +
                            i2 = i2l[0]
         | 
| 645 | 
            +
                            s = i1 + i2
         | 
| 646 | 
            +
                            if s == 1:
         | 
| 647 | 
            +
                                i3 = 2
         | 
| 648 | 
            +
                            elif s == 2:
         | 
| 649 | 
            +
                                i3 = 1
         | 
| 650 | 
            +
                            elif s == 3:
         | 
| 651 | 
            +
                                i3 = 0
         | 
| 652 | 
            +
                            x1 = V[F[f, i1], :] - V[F[f, i3], :]
         | 
| 653 | 
            +
                            x2 = V[F[f, i2], :] - V[F[f, i3], :]
         | 
| 654 | 
            +
                            a = (np.cross(x1, x2) * np.cross(V[F[f, 1], :] - V[F[f, 0], :], V[F[f, 2], :] - V[F[f, 0], :])).sum()
         | 
| 655 | 
            +
                            b = (x1 * x2).sum()
         | 
| 656 | 
            +
                            if (a > 0):
         | 
| 657 | 
            +
                                edgeAngles[k, u] = b / np.sqrt(a)
         | 
| 658 | 
            +
                            else:
         | 
| 659 | 
            +
                                edgeAngles[k, u] = b / np.sqrt(-a)
         | 
| 660 | 
            +
                        if (len(I) == 1):
         | 
| 661 | 
            +
                            # boundary edge
         | 
| 662 | 
            +
                            bEdge[k] = 1
         | 
| 663 | 
            +
                            bVert[E[k, 0]] = 1
         | 
| 664 | 
            +
                            bVert[E[k, 1]] = 1
         | 
| 665 | 
            +
                            edgeAngles[k, 1] = 0
         | 
| 666 | 
            +
             | 
| 667 | 
            +
                            # Compute Laplacian matrix
         | 
| 668 | 
            +
                    L = np.zeros([nv, nv])
         | 
| 669 | 
            +
             | 
| 670 | 
            +
                    for k in range(ne):
         | 
| 671 | 
            +
                        L[E[k, 0], E[k, 1]] = (edgeAngles[k, 0] + edgeAngles[k, 1]) / 2
         | 
| 672 | 
            +
                        L[E[k, 1], E[k, 0]] = L[E[k, 0], E[k, 1]]
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                    for k in range(nv):
         | 
| 675 | 
            +
                        L[k, k] = - L[k, :].sum()
         | 
| 676 | 
            +
             | 
| 677 | 
            +
                    A = np.zeros([nv, nv])
         | 
| 678 | 
            +
                    for k in range(nv):
         | 
| 679 | 
            +
                        A[k, k] = AV[k]
         | 
| 680 | 
            +
             | 
| 681 | 
            +
                    return L, A
         | 
| 682 | 
            +
             | 
| 683 | 
            +
                def graphLaplacianMatrix(self):
         | 
| 684 | 
            +
                    F = self.faces
         | 
| 685 | 
            +
                    V = self.vertices
         | 
| 686 | 
            +
                    nf = F.shape[0]
         | 
| 687 | 
            +
                    nv = V.shape[0]
         | 
| 688 | 
            +
             | 
| 689 | 
            +
                    # compute edges and detect boundary
         | 
| 690 | 
            +
                    # edm = sp.lil_matrix((nv,nv))
         | 
| 691 | 
            +
                    edm = -np.ones([nv, nv])
         | 
| 692 | 
            +
                    E = np.zeros([3 * nf, 2])
         | 
| 693 | 
            +
                    j = 0
         | 
| 694 | 
            +
                    for k in range(nf):
         | 
| 695 | 
            +
                        if (edm[F[k, 0], F[k, 1]] == -1):
         | 
| 696 | 
            +
                            edm[F[k, 0], F[k, 1]] = j
         | 
| 697 | 
            +
                            edm[F[k, 1], F[k, 0]] = j
         | 
| 698 | 
            +
                            E[j, :] = [F[k, 0], F[k, 1]]
         | 
| 699 | 
            +
                            j = j + 1
         | 
| 700 | 
            +
                        if (edm[F[k, 1], F[k, 2]] == -1):
         | 
| 701 | 
            +
                            edm[F[k, 1], F[k, 2]] = j
         | 
| 702 | 
            +
                            edm[F[k, 2], F[k, 1]] = j
         | 
| 703 | 
            +
                            E[j, :] = [F[k, 1], F[k, 2]]
         | 
| 704 | 
            +
                            j = j + 1
         | 
| 705 | 
            +
                        if (edm[F[k, 0], F[k, 2]] == -1):
         | 
| 706 | 
            +
                            edm[F[k, 2], F[k, 0]] = j
         | 
| 707 | 
            +
                            edm[F[k, 0], F[k, 2]] = j
         | 
| 708 | 
            +
                            E[j, :] = [F[k, 2], F[k, 0]]
         | 
| 709 | 
            +
                            j = j + 1
         | 
| 710 | 
            +
                    E = E[0:j, :]
         | 
| 711 | 
            +
             | 
| 712 | 
            +
                    edgeFace = np.zeros([j, nf])
         | 
| 713 | 
            +
                    ne = j
         | 
| 714 | 
            +
                    # print(E)
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                    # Compute Laplacian matrix
         | 
| 717 | 
            +
                    L = np.zeros([nv, nv])
         | 
| 718 | 
            +
             | 
| 719 | 
            +
                    for k in range(ne):
         | 
| 720 | 
            +
                        L[E[k, 0], E[k, 1]] = 1
         | 
| 721 | 
            +
                        L[E[k, 1], E[k, 0]] = 1
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                    for k in range(nv):
         | 
| 724 | 
            +
                        L[k, k] = - L[k, :].sum()
         | 
| 725 | 
            +
             | 
| 726 | 
            +
                    return L
         | 
| 727 | 
            +
             | 
| 728 | 
            +
                def cotanLaplacian(self, eps=1e-8):
         | 
| 729 | 
            +
                    L = pp3d.cotan_laplacian(self.vertices, self.faces, denom_eps=1e-10)
         | 
| 730 | 
            +
                    massvec_np = pp3d.vertex_areas(self.vertices, self.faces)
         | 
| 731 | 
            +
                    massvec_np += eps * np.mean(massvec_np)
         | 
| 732 | 
            +
             | 
| 733 | 
            +
                    if (np.isnan(L.data).any()):
         | 
| 734 | 
            +
                        raise RuntimeError("NaN Laplace matrix")
         | 
| 735 | 
            +
                    if (np.isnan(massvec_np).any()):
         | 
| 736 | 
            +
                        raise RuntimeError("NaN mass matrix")
         | 
| 737 | 
            +
                    self.L = L 
         | 
| 738 | 
            +
                    self.massvec_np = massvec_np
         | 
| 739 | 
            +
                    return L, massvec_np
         | 
| 740 | 
            +
             | 
| 741 | 
            +
                def lapEigen(self, n_eig=10, eps=1e-8):
         | 
| 742 | 
            +
                    L_eigsh = (self.L + sp.sparse.identity(self.L.shape[0]) * eps).tocsc()
         | 
| 743 | 
            +
                    massvec_eigsh = self.massvec_np
         | 
| 744 | 
            +
                    Mmat = sp.sparse.diags(massvec_eigsh)
         | 
| 745 | 
            +
                    eigs_sigma = eps#-0.01
         | 
| 746 | 
            +
                    evals, evecs = sla.eigsh(L_eigsh, k=n_eig, M=Mmat, sigma=eigs_sigma)
         | 
| 747 | 
            +
                    return evals, evecs
         | 
| 748 | 
            +
             | 
| 749 | 
            +
                def laplacianSegmentation(self, k, verbose=False):
         | 
| 750 | 
            +
                    # (L, AA) = self.laplacianMatrix()
         | 
| 751 | 
            +
                    # # print((L.shape[0]-k-1, L.shape[0]-2))
         | 
| 752 | 
            +
                    # (D, y) = spLA.eigh(L, AA, eigvals=(L.shape[0] - k, L.shape[0] - 1))
         | 
| 753 | 
            +
                    evals, evecs = self.lapEigen(k*2)
         | 
| 754 | 
            +
                    y = evecs[:, 1:] * np.exp(-2*0.1 * evals[1:]) #**2
         | 
| 755 | 
            +
                    # N = y.shape[0]
         | 
| 756 | 
            +
                    # d = y.shape[1]
         | 
| 757 | 
            +
                    # I = np.argsort(y.sum(axis=1))
         | 
| 758 | 
            +
                    # I0 = np.floor((N - 1) * sp.linspace(0, 1, num=k)).astype(int)
         | 
| 759 | 
            +
                    # # print(y.shape, L.shape, N, k, d)
         | 
| 760 | 
            +
                    # C = y[I0, :].copy()
         | 
| 761 | 
            +
                    #
         | 
| 762 | 
            +
                    # eps = 1e-20
         | 
| 763 | 
            +
                    # Cold = C.copy()
         | 
| 764 | 
            +
                    # u = ((C.reshape([k, 1, d]) - y.reshape([1, N, d])) ** 2).sum(axis=2)
         | 
| 765 | 
            +
                    # T = u.min(axis=0).sum() / (N)
         | 
| 766 | 
            +
                    # # print(T)
         | 
| 767 | 
            +
                    # j = 0
         | 
| 768 | 
            +
                    # while j < 5000:
         | 
| 769 | 
            +
                    #     u0 = u - u.min(axis=0).reshape([1, N])
         | 
| 770 | 
            +
                    #     w = np.exp(-u0 / T);
         | 
| 771 | 
            +
                    #     w = w / (eps + w.sum(axis=0).reshape([1, N]))
         | 
| 772 | 
            +
                    #     # print(w.min(), w.max())
         | 
| 773 | 
            +
                    #     cost = (u * w).sum() + T * (w * np.log(w + eps)).sum()
         | 
| 774 | 
            +
                    #     C = np.dot(w, y) / (eps + w.sum(axis=1).reshape([k, 1]))
         | 
| 775 | 
            +
                    #     # print(j, 'cost0 ', cost)
         | 
| 776 | 
            +
                    #
         | 
| 777 | 
            +
                    #     u = ((C.reshape([k, 1, d]) - y.reshape([1, N, d])) ** 2).sum(axis=2)
         | 
| 778 | 
            +
                    #     cost = (u * w).sum() + T * (w * np.log(w + eps)).sum()
         | 
| 779 | 
            +
                    #     err = np.sqrt(((C - Cold) ** 2).sum(axis=1)).sum()
         | 
| 780 | 
            +
                    #     # print(j, 'cost ', cost, err, T)
         | 
| 781 | 
            +
                    #     if (j > 100) & (err < 1e-4):
         | 
| 782 | 
            +
                    #         break
         | 
| 783 | 
            +
                    #     j = j + 1
         | 
| 784 | 
            +
                    #     Cold = C.copy()
         | 
| 785 | 
            +
                    #     T = T * 0.99
         | 
| 786 | 
            +
                    #
         | 
| 787 | 
            +
                    #     # print(k, d, C.shape)
         | 
| 788 | 
            +
                    # dst = ((C.reshape([k, 1, d]) - y.reshape([1, N, d])) ** 2).sum(axis=2)
         | 
| 789 | 
            +
                    # md = dst.min(axis=0)
         | 
| 790 | 
            +
                    # idx = np.zeros(N).astype(int)
         | 
| 791 | 
            +
                    # for j in range(N):
         | 
| 792 | 
            +
                    #     I = np.flatnonzero(dst[:, j] < md[j] + 1e-10)
         | 
| 793 | 
            +
                    #     idx[j] = I[0]
         | 
| 794 | 
            +
                    # I = -np.ones(k).astype(int)
         | 
| 795 | 
            +
                    # kk = 0
         | 
| 796 | 
            +
                    # for j in range(k):
         | 
| 797 | 
            +
                    #     if True in (idx == j):
         | 
| 798 | 
            +
                    #         I[j] = kk
         | 
| 799 | 
            +
                    #         kk += 1
         | 
| 800 | 
            +
                    # idx = I[idx]
         | 
| 801 | 
            +
                    # if idx.max() < (k - 1):
         | 
| 802 | 
            +
                    #     print('Warning: kmeans convergence with %d clusters instead of %d' % (idx.max(), k))
         | 
| 803 | 
            +
                    #     # ml = w.sum(axis=1)/N
         | 
| 804 | 
            +
                    kmeans = KMeans(n_clusters=k, random_state=0, n_init="auto").fit(y)
         | 
| 805 | 
            +
                    # kmeans = DBSCAN().fit(y)
         | 
| 806 | 
            +
                    idx = kmeans.labels_
         | 
| 807 | 
            +
                    nc = idx.max() + 1
         | 
| 808 | 
            +
                    C = np.zeros([nc, self.vertices.shape[1]])
         | 
| 809 | 
            +
                    a, foo = self.computeVertexArea()
         | 
| 810 | 
            +
                    for k in range(nc):
         | 
| 811 | 
            +
                        I = np.flatnonzero(idx == k)
         | 
| 812 | 
            +
                        nI = len(I)
         | 
| 813 | 
            +
                        # print(a.shape, nI)
         | 
| 814 | 
            +
                        aI = a[I]
         | 
| 815 | 
            +
                        ak = aI.sum()
         | 
| 816 | 
            +
                        C[k, :] = (self.vertices[I, :] * aI).sum(axis=0) / ak;
         | 
| 817 | 
            +
                    mean_eigen = (y*a).sum(axis=0)/a
         | 
| 818 | 
            +
                    tree = cKDTree(y)
         | 
| 819 | 
            +
                    _, indices = tree.query(mean_eigen, k=1)
         | 
| 820 | 
            +
                    index_center = indices[0]
         | 
| 821 | 
            +
                    _, indices = tree.query(C, k=1)
         | 
| 822 | 
            +
                    index_C = indices[0]
         | 
| 823 | 
            +
             | 
| 824 | 
            +
                    mesh = tm.Trimesh(vertices=self.vertices, faces=self.faces)
         | 
| 825 | 
            +
                    nv = self.computeVertexNormals()
         | 
| 826 | 
            +
                    rori = self.vertices[index] + 0.001 * (-nv[index, :])
         | 
| 827 | 
            +
                    rdir = -nv[index, :]
         | 
| 828 | 
            +
             | 
| 829 | 
            +
                    locations, index_ray, index_tri = mesh.ray.intersects_location(ray_origins=rori[None, :], ray_directions=rdir[None, :])
         | 
| 830 | 
            +
                    if verbose:
         | 
| 831 | 
            +
                        print(locations, index_ray, index_tri)
         | 
| 832 | 
            +
                        print(index_center, index_C)
         | 
| 833 | 
            +
                    return idx, C, index_center, locations[0]
         | 
| 834 | 
            +
             | 
| 835 | 
            +
                def get_keypoints(self, n_eig=30, n_points=5, sym=False):
         | 
| 836 | 
            +
                    evals, evecs = self.lapEigen(n_eig+1)
         | 
| 837 | 
            +
                    pts_evecs = evecs[:, 1:] * np.exp(-2*0.1 * evals[1:]) # approximate geod distance
         | 
| 838 | 
            +
                    tree = cKDTree(pts_evecs)
         | 
| 839 | 
            +
                    _, center_index = tree.query(np.zeros(pts_evecs.shape[-1]), k=1)
         | 
| 840 | 
            +
                    indices = fps_gpt_geod(self.vertices, self.faces, 10, center_index)[1:]
         | 
| 841 | 
            +
                    areas = np.linalg.norm(self.surfel, axis=-1, keepdims=True)
         | 
| 842 | 
            +
                    area = np.sqrt(areas.sum()/2)
         | 
| 843 | 
            +
                    print(area)
         | 
| 844 | 
            +
                    ## Looking for center + provide distances to the center
         | 
| 845 | 
            +
                    
         | 
| 846 | 
            +
                    print(center_index)
         | 
| 847 | 
            +
                    solver = pp3d.MeshHeatMethodDistanceSolver(self.vertices, self.faces)
         | 
| 848 | 
            +
                    norm_center = solver.compute_distance(center_index)
         | 
| 849 | 
            +
                
         | 
| 850 | 
            +
                    ## Dist matrix just between selected points
         | 
| 851 | 
            +
                    dist_mat_indices = np.zeros((len(indices), len(indices)))
         | 
| 852 | 
            +
                    all_ii = np.arange(len(indices)) # just to easy code
         | 
| 853 | 
            +
                    for ii, index in enumerate(indices):
         | 
| 854 | 
            +
                        dist_mat_indices[ii, ii!=all_ii] = solver.compute_distance(index)[indices[indices != index]]
         | 
| 855 | 
            +
                
         | 
| 856 | 
            +
                    ## Select points which : farthest from the center, delete their neighbors with dist < 0.5*area
         | 
| 857 | 
            +
                    keep = []
         | 
| 858 | 
            +
                    max_norm = norm_center.max()
         | 
| 859 | 
            +
                    for ii, index in enumerate(indices[:n_points]):
         | 
| 860 | 
            +
                        # distances_compar = (dist_mat_indices[ii, ii!=all_ii] + norm_center[ii]) / norm_center[indices[ii!=all_ii]]
         | 
| 861 | 
            +
                        # min_dist = np.amin(np.abs(distances_compar - 1))
         | 
| 862 | 
            +
                        # print(min_dist/norm_center[ii], min_dist)
         | 
| 863 | 
            +
                        # if min_dist > 0.3:
         | 
| 864 | 
            +
                        #     #print(ii, min_dist/norm_center[ii], norm_center[ii], dist_mat_indices[ii, ii!=all_ii], norm_center[indices[indices != index]], indices[7])
         | 
| 865 | 
            +
                        keep.append(index)
         | 
| 866 | 
            +
                    ## Add the index of the center
         | 
| 867 | 
            +
                    ## Add the index of the center 
         | 
| 868 | 
            +
                    if sym:
         | 
| 869 | 
            +
                        mesh = tm.base.Trimesh(self.vertices, self.faces, process=False)  # Load your mesh
         | 
| 870 | 
            +
                        # Given data
         | 
| 871 | 
            +
                        index = 123  # Your starting point index
         | 
| 872 | 
            +
                        direction = -mesh.vertex_normals[center_index]
         | 
| 873 | 
            +
                        
         | 
| 874 | 
            +
                        # Get the starting vertex position
         | 
| 875 | 
            +
                        start_point = mesh.vertices[center_index]
         | 
| 876 | 
            +
                        
         | 
| 877 | 
            +
                        # Perform ray intersection
         | 
| 878 | 
            +
                        locations, index_ray, index_tri = mesh.ray.intersects_location(
         | 
| 879 | 
            +
                            ray_origins=start_point[None, :],  # Starting point
         | 
| 880 | 
            +
                            ray_directions=direction[None, :]  # Direction of travel
         | 
| 881 | 
            +
                        )
         | 
| 882 | 
            +
                        
         | 
| 883 | 
            +
                        # If intersections exist
         | 
| 884 | 
            +
                        if len(locations) > 0:
         | 
| 885 | 
            +
                            # Sort intersections by distance from start_point
         | 
| 886 | 
            +
                            distances = np.linalg.norm(locations - start_point, axis=1)
         | 
| 887 | 
            +
                            sorted_indices = np.argsort(distances)
         | 
| 888 | 
            +
                        
         | 
| 889 | 
            +
                            # Get the first valid intersection point beyond the start point
         | 
| 890 | 
            +
                            #next_intersection = locations[sorted_indices]
         | 
| 891 | 
            +
                            #print(f"Next intersection at: {next_intersection}")
         | 
| 892 | 
            +
                            intersected_triangle = index_tri[sorted_indices[1]]
         | 
| 893 | 
            +
                            print(f"Intersected Triangle Index: {mesh.faces[intersected_triangle]}")
         | 
| 894 | 
            +
                            center_index = mesh.faces[intersected_triangle][0]
         | 
| 895 | 
            +
                        else:
         | 
| 896 | 
            +
                            print("No intersection found in the given direction.")
         | 
| 897 | 
            +
                    keep.append(center_index)
         | 
| 898 | 
            +
                    return keep
         | 
| 899 | 
            +
             | 
| 900 | 
            +
             | 
| 901 | 
            +
                # Computes surface volume
         | 
| 902 | 
            +
                def surfVolume(self):
         | 
| 903 | 
            +
                    f = self.faces
         | 
| 904 | 
            +
                    v = self.vertices
         | 
| 905 | 
            +
                    t = v[f, :]
         | 
| 906 | 
            +
                    vols = np.linalg.det(t) / 6
         | 
| 907 | 
            +
                    return vols.sum(), vols
         | 
| 908 | 
            +
             | 
| 909 | 
            +
                def surfCenter(self):
         | 
| 910 | 
            +
                    f = self.faces
         | 
| 911 | 
            +
                    v = self.vertices
         | 
| 912 | 
            +
                    center_infs = (v[f, :].sum(axis=1) / 4) * self.vols[:, np.newaxis]
         | 
| 913 | 
            +
                    center = center_infs.sum(axis=0)
         | 
| 914 | 
            +
                    return center
         | 
| 915 | 
            +
             | 
| 916 | 
            +
                def surfMoments(self):
         | 
| 917 | 
            +
                    f = self.faces
         | 
| 918 | 
            +
                    v = self.vertices
         | 
| 919 | 
            +
                    vec_0 = v[f[:, 0], :] + v[f[:, 1], :]
         | 
| 920 | 
            +
                    s_0 = vec_0[:, :, np.newaxis] * vec_0[:, np.newaxis, :]
         | 
| 921 | 
            +
                    vec_1 = v[f[:, 0], :] + v[f[:, 2], :]
         | 
| 922 | 
            +
                    s_1 = vec_1[:, :, np.newaxis] * vec_1[:, np.newaxis, :]
         | 
| 923 | 
            +
                    vec_2 = v[f[:, 1], :] + v[f[:, 2], :]
         | 
| 924 | 
            +
                    s_2 = vec_2[:, :, np.newaxis] * vec_2[:, np.newaxis, :]
         | 
| 925 | 
            +
                    moments_inf = self.vols[:, np.newaxis, np.newaxis] * (1. / 20) * (s_0 + s_1 + s_2)
         | 
| 926 | 
            +
                    return moments_inf.sum(axis=0)
         | 
| 927 | 
            +
             | 
| 928 | 
            +
                def surfF(self):
         | 
| 929 | 
            +
                    f = self.faces
         | 
| 930 | 
            +
                    v = self.vertices
         | 
| 931 | 
            +
                    cent = (v[f, :].sum(axis=1)) / 4.
         | 
| 932 | 
            +
                    F = self.vols[:, np.newaxis] * np.sign(cent) * (cent ** 2)
         | 
| 933 | 
            +
                    return F.sum(axis=0)
         | 
| 934 | 
            +
             | 
| 935 | 
            +
                def surfU(self):
         | 
| 936 | 
            +
                    vol = self.volume
         | 
| 937 | 
            +
                    vertices = self.vertices / pow(vol, 1. / 3)
         | 
| 938 | 
            +
                    vertices -= self.center
         | 
| 939 | 
            +
                    self.updateVertices(vertices)
         | 
| 940 | 
            +
                    moments = self.surfMoments()
         | 
| 941 | 
            +
                    u, s, vh = np.linalg.svd(moments)
         | 
| 942 | 
            +
                    F = self.surfF()
         | 
| 943 | 
            +
                    return np.diag(np.sign(F)) @ u, s
         | 
| 944 | 
            +
             | 
| 945 | 
            +
                def surfEllipsoid(self, u, s, moments):
         | 
| 946 | 
            +
                    coeff = pow(4 * np.pi / 15, 1. / 5) * pow(np.linalg.det(moments), -1. / 10)
         | 
| 947 | 
            +
                    A = coeff * ((u * np.sqrt(s)) @ u.T)
         | 
| 948 | 
            +
                    return u, A
         | 
| 949 | 
            +
             | 
| 950 | 
            +
                # Reads from .off file
         | 
| 951 | 
            +
                def readOFF(self, offfile):
         | 
| 952 | 
            +
                    with open(offfile, 'r') as f:
         | 
| 953 | 
            +
                        all_lines = f.readlines()
         | 
| 954 | 
            +
                
         | 
| 955 | 
            +
                
         | 
| 956 | 
            +
                    n_vertices = int(all_lines[1].split()[0])
         | 
| 957 | 
            +
                    n_faces = int(all_lines[1].split()[1])
         | 
| 958 | 
            +
             | 
| 959 | 
            +
                    vertices_list = []
         | 
| 960 | 
            +
                    for i in range(n_vertices):
         | 
| 961 | 
            +
                        vertices_list.append([float(x) for x in all_lines[2+i].split()[:3]])
         | 
| 962 | 
            +
             | 
| 963 | 
            +
                    faces_list = []
         | 
| 964 | 
            +
                    for i in range(n_faces):
         | 
| 965 | 
            +
                        # Be careful to convert to int. Otherwise, you can use np.array(faces_list).astype(np.int32)
         | 
| 966 | 
            +
                        faces_list.append([int(x) for x in all_lines[2+i+n_vertices].split()[1:4]])
         | 
| 967 | 
            +
                    self.faces = np.array(faces_list)
         | 
| 968 | 
            +
                    self.vertices = np.array(vertices_list)
         | 
| 969 | 
            +
                    self.computeCentersAreas()    # Reads from .byu file
         | 
| 970 | 
            +
                def readbyu(self, byufile):
         | 
| 971 | 
            +
                    with open(byufile, 'r') as fbyu:
         | 
| 972 | 
            +
                        ln0 = fbyu.readline()
         | 
| 973 | 
            +
                        ln = ln0.split()
         | 
| 974 | 
            +
                        # read header
         | 
| 975 | 
            +
                        ncomponents = int(ln[0])  # number of components
         | 
| 976 | 
            +
                        npoints = int(ln[1])  # number of vertices
         | 
| 977 | 
            +
                        nfaces = int(ln[2])  # number of faces
         | 
| 978 | 
            +
                        # fscanf(fbyu,'%d',1);		% number of edges
         | 
| 979 | 
            +
                        # %ntest = fscanf(fbyu,'%d',1);		% number of edges
         | 
| 980 | 
            +
                        for k in range(ncomponents):
         | 
| 981 | 
            +
                            fbyu.readline()  # components (ignored)
         | 
| 982 | 
            +
                        # read data
         | 
| 983 | 
            +
                        self.vertices = np.empty([npoints, 3])
         | 
| 984 | 
            +
                        k = -1
         | 
| 985 | 
            +
                        while k < npoints - 1:
         | 
| 986 | 
            +
                            ln = fbyu.readline().split()
         | 
| 987 | 
            +
                            k = k + 1;
         | 
| 988 | 
            +
                            self.vertices[k, 0] = float(ln[0])
         | 
| 989 | 
            +
                            self.vertices[k, 1] = float(ln[1])
         | 
| 990 | 
            +
                            self.vertices[k, 2] = float(ln[2])
         | 
| 991 | 
            +
                            if len(ln) > 3:
         | 
| 992 | 
            +
                                k = k + 1;
         | 
| 993 | 
            +
                                self.vertices[k, 0] = float(ln[3])
         | 
| 994 | 
            +
                                self.vertices[k, 1] = float(ln[4])
         | 
| 995 | 
            +
                                self.vertices[k, 2] = float(ln[5])
         | 
| 996 | 
            +
             | 
| 997 | 
            +
                        self.faces = np.empty([nfaces, 3])
         | 
| 998 | 
            +
                        ln = fbyu.readline().split()
         | 
| 999 | 
            +
                        kf = 0
         | 
| 1000 | 
            +
                        j = 0
         | 
| 1001 | 
            +
                        while ln:
         | 
| 1002 | 
            +
                            if kf >= nfaces:
         | 
| 1003 | 
            +
                                break
         | 
| 1004 | 
            +
                                # print(nfaces, kf, ln)
         | 
| 1005 | 
            +
                            for s in ln:
         | 
| 1006 | 
            +
                                self.faces[kf, j] = int(sp.fabs(int(s)))
         | 
| 1007 | 
            +
                                j = j + 1
         | 
| 1008 | 
            +
                                if j == 3:
         | 
| 1009 | 
            +
                                    kf = kf + 1
         | 
| 1010 | 
            +
                                    j = 0
         | 
| 1011 | 
            +
                            ln = fbyu.readline().split()
         | 
| 1012 | 
            +
                    self.faces = np.int_(self.faces) - 1
         | 
| 1013 | 
            +
                    xDef1 = self.vertices[self.faces[:, 0], :]
         | 
| 1014 | 
            +
                    xDef2 = self.vertices[self.faces[:, 1], :]
         | 
| 1015 | 
            +
                    xDef3 = self.vertices[self.faces[:, 2], :]
         | 
| 1016 | 
            +
                    self.centers = (xDef1 + xDef2 + xDef3) / 3
         | 
| 1017 | 
            +
                    self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
         | 
| 1018 | 
            +
             | 
| 1019 | 
            +
                # Saves in .byu format
         | 
| 1020 | 
            +
                def savebyu(self, byufile):
         | 
| 1021 | 
            +
                    # FV = readbyu(byufile)
         | 
| 1022 | 
            +
                    # reads from a .byu file into matlab's face vertex structure FV
         | 
| 1023 | 
            +
             | 
| 1024 | 
            +
                    with open(byufile, 'w') as fbyu:
         | 
| 1025 | 
            +
                        # copy header
         | 
| 1026 | 
            +
                        ncomponents = 1  # number of components
         | 
| 1027 | 
            +
                        npoints = self.vertices.shape[0]  # number of vertices
         | 
| 1028 | 
            +
                        nfaces = self.faces.shape[0]  # number of faces
         | 
| 1029 | 
            +
                        nedges = 3 * nfaces  # number of edges
         | 
| 1030 | 
            +
             | 
| 1031 | 
            +
                        str = '{0: d} {1: d} {2: d} {3: d} 0\n'.format(ncomponents, npoints, nfaces, nedges)
         | 
| 1032 | 
            +
                        fbyu.write(str)
         | 
| 1033 | 
            +
                        str = '1 {0: d}\n'.format(nfaces)
         | 
| 1034 | 
            +
                        fbyu.write(str)
         | 
| 1035 | 
            +
             | 
| 1036 | 
            +
                        k = -1
         | 
| 1037 | 
            +
                        while k < (npoints - 1):
         | 
| 1038 | 
            +
                            k = k + 1
         | 
| 1039 | 
            +
                            str = '{0: f} {1: f} {2: f} '.format(self.vertices[k, 0], self.vertices[k, 1], self.vertices[k, 2])
         | 
| 1040 | 
            +
                            fbyu.write(str)
         | 
| 1041 | 
            +
                            if k < (npoints - 1):
         | 
| 1042 | 
            +
                                k = k + 1
         | 
| 1043 | 
            +
                                str = '{0: f} {1: f} {2: f}\n'.format(self.vertices[k, 0], self.vertices[k, 1], self.vertices[k, 2])
         | 
| 1044 | 
            +
                                fbyu.write(str)
         | 
| 1045 | 
            +
                            else:
         | 
| 1046 | 
            +
                                fbyu.write('\n')
         | 
| 1047 | 
            +
             | 
| 1048 | 
            +
                        j = 0
         | 
| 1049 | 
            +
                        for k in range(nfaces):
         | 
| 1050 | 
            +
                            for kk in (0, 1):
         | 
| 1051 | 
            +
                                fbyu.write('{0: d} '.format(self.faces[k, kk] + 1))
         | 
| 1052 | 
            +
                                j = j + 1
         | 
| 1053 | 
            +
                                if j == 16:
         | 
| 1054 | 
            +
                                    fbyu.write('\n')
         | 
| 1055 | 
            +
                                    j = 0
         | 
| 1056 | 
            +
             | 
| 1057 | 
            +
                            fbyu.write('{0: d} '.format(-self.faces[k, 2] - 1))
         | 
| 1058 | 
            +
                            j = j + 1
         | 
| 1059 | 
            +
                            if j == 16:
         | 
| 1060 | 
            +
                                fbyu.write('\n')
         | 
| 1061 | 
            +
                                j = 0
         | 
| 1062 | 
            +
             | 
| 1063 | 
            +
                def saveVTK(self, fileName, scalars=None, normals=None, tensors=None, scal_name='scalars', vectors=None,
         | 
| 1064 | 
            +
                            vect_name='vectors'):
         | 
| 1065 | 
            +
                    vf = vtkFields()
         | 
| 1066 | 
            +
                    # print(scalars)
         | 
| 1067 | 
            +
                    if not (scalars == None):
         | 
| 1068 | 
            +
                        vf.scalars.append(scal_name)
         | 
| 1069 | 
            +
                        vf.scalars.append(scalars)
         | 
| 1070 | 
            +
                    if not (vectors == None):
         | 
| 1071 | 
            +
                        vf.vectors.append(vect_name)
         | 
| 1072 | 
            +
                        vf.vectors.append(vectors)
         | 
| 1073 | 
            +
                    if not (normals == None):
         | 
| 1074 | 
            +
                        vf.normals.append('normals')
         | 
| 1075 | 
            +
                        vf.normals.append(normals)
         | 
| 1076 | 
            +
                    if not (tensors == None):
         | 
| 1077 | 
            +
                        vf.tensors.append('tensors')
         | 
| 1078 | 
            +
                        vf.tensors.append(tensors)
         | 
| 1079 | 
            +
                    self.saveVTK2(fileName, vf)
         | 
| 1080 | 
            +
             | 
| 1081 | 
            +
                # Saves in .vtk format
         | 
| 1082 | 
            +
                def saveVTK2(self, fileName, vtkFields=None):
         | 
| 1083 | 
            +
                    F = self.faces;
         | 
| 1084 | 
            +
                    V = self.vertices;
         | 
| 1085 | 
            +
             | 
| 1086 | 
            +
                    with open(fileName, 'w') as fvtkout:
         | 
| 1087 | 
            +
                        fvtkout.write('# vtk DataFile Version 3.0\nSurface Data\nASCII\nDATASET POLYDATA\n')
         | 
| 1088 | 
            +
                        fvtkout.write('\nPOINTS {0: d} float'.format(V.shape[0]))
         | 
| 1089 | 
            +
                        for ll in range(V.shape[0]):
         | 
| 1090 | 
            +
                            fvtkout.write('\n{0: f} {1: f} {2: f}'.format(V[ll, 0], V[ll, 1], V[ll, 2]))
         | 
| 1091 | 
            +
                        fvtkout.write('\nPOLYGONS {0:d} {1:d}'.format(F.shape[0], 4 * F.shape[0]))
         | 
| 1092 | 
            +
                        for ll in range(F.shape[0]):
         | 
| 1093 | 
            +
                            fvtkout.write('\n3 {0: d} {1: d} {2: d}'.format(F[ll, 0], F[ll, 1], F[ll, 2]))
         | 
| 1094 | 
            +
                        if not (vtkFields == None):
         | 
| 1095 | 
            +
                            wrote_pd_hdr = False
         | 
| 1096 | 
            +
                            if len(vtkFields.scalars) > 0:
         | 
| 1097 | 
            +
                                if not wrote_pd_hdr:
         | 
| 1098 | 
            +
                                    fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
         | 
| 1099 | 
            +
                                    wrote_pd_hdr = True
         | 
| 1100 | 
            +
                                nf = len(vtkFields.scalars) / 2
         | 
| 1101 | 
            +
                                for k in range(nf):
         | 
| 1102 | 
            +
                                    fvtkout.write('\nSCALARS ' + vtkFields.scalars[2 * k] + ' float 1\nLOOKUP_TABLE default')
         | 
| 1103 | 
            +
                                    for ll in range(V.shape[0]):
         | 
| 1104 | 
            +
                                        # print(scalars[ll])
         | 
| 1105 | 
            +
                                        fvtkout.write('\n {0: .5f}'.format(vtkFields.scalars[2 * k + 1][ll]))
         | 
| 1106 | 
            +
                            if len(vtkFields.vectors) > 0:
         | 
| 1107 | 
            +
                                if not wrote_pd_hdr:
         | 
| 1108 | 
            +
                                    fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
         | 
| 1109 | 
            +
                                    wrote_pd_hdr = True
         | 
| 1110 | 
            +
                                nf = len(vtkFields.vectors) / 2
         | 
| 1111 | 
            +
                                for k in range(nf):
         | 
| 1112 | 
            +
                                    fvtkout.write('\nVECTORS ' + vtkFields.vectors[2 * k] + ' float')
         | 
| 1113 | 
            +
                                    vectors = vtkFields.vectors[2 * k + 1]
         | 
| 1114 | 
            +
                                    for ll in range(V.shape[0]):
         | 
| 1115 | 
            +
                                        fvtkout.write(
         | 
| 1116 | 
            +
                                            '\n {0: .5f} {1: .5f} {2: .5f}'.format(vectors[ll, 0], vectors[ll, 1], vectors[ll, 2]))
         | 
| 1117 | 
            +
                            if len(vtkFields.normals) > 0:
         | 
| 1118 | 
            +
                                if not wrote_pd_hdr:
         | 
| 1119 | 
            +
                                    fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
         | 
| 1120 | 
            +
                                    wrote_pd_hdr = True
         | 
| 1121 | 
            +
                                nf = len(vtkFields.normals) / 2
         | 
| 1122 | 
            +
                                for k in range(nf):
         | 
| 1123 | 
            +
                                    fvtkout.write('\nNORMALS ' + vtkFields.normals[2 * k] + ' float')
         | 
| 1124 | 
            +
                                    vectors = vtkFields.normals[2 * k + 1]
         | 
| 1125 | 
            +
                                    for ll in range(V.shape[0]):
         | 
| 1126 | 
            +
                                        fvtkout.write(
         | 
| 1127 | 
            +
                                            '\n {0: .5f} {1: .5f} {2: .5f}'.format(vectors[ll, 0], vectors[ll, 1], vectors[ll, 2]))
         | 
| 1128 | 
            +
                            if len(vtkFields.tensors) > 0:
         | 
| 1129 | 
            +
                                if not wrote_pd_hdr:
         | 
| 1130 | 
            +
                                    fvtkout.write(('\nPOINT_DATA {0: d}').format(V.shape[0]))
         | 
| 1131 | 
            +
                                    wrote_pd_hdr = True
         | 
| 1132 | 
            +
                                nf = len(vtkFields.tensors) / 2
         | 
| 1133 | 
            +
                                for k in range(nf):
         | 
| 1134 | 
            +
                                    fvtkout.write('\nTENSORS ' + vtkFields.tensors[2 * k] + ' float')
         | 
| 1135 | 
            +
                                    tensors = vtkFields.tensors[2 * k + 1]
         | 
| 1136 | 
            +
                                    for ll in range(V.shape[0]):
         | 
| 1137 | 
            +
                                        for kk in range(2):
         | 
| 1138 | 
            +
                                            fvtkout.write(
         | 
| 1139 | 
            +
                                                '\n {0: .5f} {1: .5f} {2: .5f}'.format(tensors[ll, kk, 0], tensors[ll, kk, 1],
         | 
| 1140 | 
            +
                                                                                       tensors[ll, kk, 2]))
         | 
| 1141 | 
            +
                            fvtkout.write('\n')
         | 
| 1142 | 
            +
             | 
| 1143 | 
            +
                # Reads .vtk file
         | 
| 1144 | 
            +
                def readVTK(self, fileName):
         | 
| 1145 | 
            +
                    if gotVTK:
         | 
| 1146 | 
            +
                        u = vtkPolyDataReader()
         | 
| 1147 | 
            +
                        u.SetFileName(fileName)
         | 
| 1148 | 
            +
                        u.Update()
         | 
| 1149 | 
            +
                        v = u.GetOutput()
         | 
| 1150 | 
            +
                        # print(v)
         | 
| 1151 | 
            +
                        npoints = int(v.GetNumberOfPoints())
         | 
| 1152 | 
            +
                        nfaces = int(v.GetNumberOfPolys())
         | 
| 1153 | 
            +
                        V = np.zeros([npoints, 3])
         | 
| 1154 | 
            +
                        for kk in range(npoints):
         | 
| 1155 | 
            +
                            V[kk, :] = np.array(v.GetPoint(kk))
         | 
| 1156 | 
            +
             | 
| 1157 | 
            +
                        F = np.zeros([nfaces, 3])
         | 
| 1158 | 
            +
                        for kk in range(nfaces):
         | 
| 1159 | 
            +
                            c = v.GetCell(kk)
         | 
| 1160 | 
            +
                            for ll in range(3):
         | 
| 1161 | 
            +
                                F[kk, ll] = c.GetPointId(ll)
         | 
| 1162 | 
            +
             | 
| 1163 | 
            +
                        self.vertices = V
         | 
| 1164 | 
            +
                        self.faces = np.int_(F)
         | 
| 1165 | 
            +
                        xDef1 = self.vertices[self.faces[:, 0], :]
         | 
| 1166 | 
            +
                        xDef2 = self.vertices[self.faces[:, 1], :]
         | 
| 1167 | 
            +
                        xDef3 = self.vertices[self.faces[:, 2], :]
         | 
| 1168 | 
            +
                        self.centers = (xDef1 + xDef2 + xDef3) / 3
         | 
| 1169 | 
            +
                        self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
         | 
| 1170 | 
            +
                    else:
         | 
| 1171 | 
            +
                        raise Exception('Cannot run readVTK without VTK')
         | 
| 1172 | 
            +
             | 
| 1173 | 
            +
                # Reads .obj file
         | 
| 1174 | 
            +
                def readOBJ(self, fileName):
         | 
| 1175 | 
            +
                    if gotVTK:
         | 
| 1176 | 
            +
                        u = vtkOBJReader()
         | 
| 1177 | 
            +
                        u.SetFileName(fileName)
         | 
| 1178 | 
            +
                        u.Update()
         | 
| 1179 | 
            +
                        v = u.GetOutput()
         | 
| 1180 | 
            +
                        # print(v)
         | 
| 1181 | 
            +
                        npoints = int(v.GetNumberOfPoints())
         | 
| 1182 | 
            +
                        nfaces = int(v.GetNumberOfPolys())
         | 
| 1183 | 
            +
                        V = np.zeros([npoints, 3])
         | 
| 1184 | 
            +
                        for kk in range(npoints):
         | 
| 1185 | 
            +
                            V[kk, :] = np.array(v.GetPoint(kk))
         | 
| 1186 | 
            +
             | 
| 1187 | 
            +
                        F = np.zeros([nfaces, 3])
         | 
| 1188 | 
            +
                        for kk in range(nfaces):
         | 
| 1189 | 
            +
                            c = v.GetCell(kk)
         | 
| 1190 | 
            +
                            for ll in range(3):
         | 
| 1191 | 
            +
                                F[kk, ll] = c.GetPointId(ll)
         | 
| 1192 | 
            +
             | 
| 1193 | 
            +
                        self.vertices = V
         | 
| 1194 | 
            +
                        self.faces = np.int_(F)
         | 
| 1195 | 
            +
                        xDef1 = self.vertices[self.faces[:, 0], :]
         | 
| 1196 | 
            +
                        xDef2 = self.vertices[self.faces[:, 1], :]
         | 
| 1197 | 
            +
                        xDef3 = self.vertices[self.faces[:, 2], :]
         | 
| 1198 | 
            +
                        self.centers = (xDef1 + xDef2 + xDef3) / 3
         | 
| 1199 | 
            +
                        self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
         | 
| 1200 | 
            +
                    else:
         | 
| 1201 | 
            +
                        raise Exception('Cannot run readOBJ without VTK')
         | 
| 1202 | 
            +
             | 
| 1203 | 
            +
                def readPLY(self, fileName):
         | 
| 1204 | 
            +
                    if gotVTK:
         | 
| 1205 | 
            +
                        u = vtkPLYReader()
         | 
| 1206 | 
            +
                        u.SetFileName(fileName)
         | 
| 1207 | 
            +
                        u.Update()
         | 
| 1208 | 
            +
                        v = u.GetOutput()
         | 
| 1209 | 
            +
                        # print(v)
         | 
| 1210 | 
            +
                        npoints = int(v.GetNumberOfPoints())
         | 
| 1211 | 
            +
                        nfaces = int(v.GetNumberOfPolys())
         | 
| 1212 | 
            +
                        V = np.zeros([npoints, 3])
         | 
| 1213 | 
            +
                        for kk in range(npoints):
         | 
| 1214 | 
            +
                            V[kk, :] = np.array(v.GetPoint(kk))
         | 
| 1215 | 
            +
             | 
| 1216 | 
            +
                        F = np.zeros([nfaces, 3])
         | 
| 1217 | 
            +
                        for kk in range(nfaces):
         | 
| 1218 | 
            +
                            c = v.GetCell(kk)
         | 
| 1219 | 
            +
                            for ll in range(3):
         | 
| 1220 | 
            +
                                F[kk, ll] = c.GetPointId(ll)
         | 
| 1221 | 
            +
             | 
| 1222 | 
            +
                        self.vertices = V
         | 
| 1223 | 
            +
                        self.faces = np.int_(F)
         | 
| 1224 | 
            +
                        xDef1 = self.vertices[self.faces[:, 0], :]
         | 
| 1225 | 
            +
                        xDef2 = self.vertices[self.faces[:, 1], :]
         | 
| 1226 | 
            +
                        xDef3 = self.vertices[self.faces[:, 2], :]
         | 
| 1227 | 
            +
                        self.centers = (xDef1 + xDef2 + xDef3) / 3
         | 
| 1228 | 
            +
                        self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
         | 
| 1229 | 
            +
                    else:
         | 
| 1230 | 
            +
                        raise Exception('Cannot run readPLY without VTK')
         | 
| 1231 | 
            +
             | 
| 1232 | 
            +
                def readTRI(self, fileName):
         | 
| 1233 | 
            +
                    vertices = []
         | 
| 1234 | 
            +
                    faces = []
         | 
| 1235 | 
            +
                    with open(fileName, "r") as f:
         | 
| 1236 | 
            +
                        pithc = f.readlines()
         | 
| 1237 | 
            +
                        line_1 = pithc[0]
         | 
| 1238 | 
            +
                        n_pts, n_tri = line_1.split(' ')
         | 
| 1239 | 
            +
                        for i, line in enumerate(pithc[1:]):
         | 
| 1240 | 
            +
                            pouetage = line.split(' ')
         | 
| 1241 | 
            +
                            if i <= int(n_pts):
         | 
| 1242 | 
            +
                                vertices.append([float(poupout) for poupout in pouetage[:3]])
         | 
| 1243 | 
            +
                                # vertices.append([float(poupout) for poupout in pouetage[4:7]])
         | 
| 1244 | 
            +
                            else:
         | 
| 1245 | 
            +
                                faces.append([int(poupout.split("\n")[0]) for poupout in pouetage[1:]])
         | 
| 1246 | 
            +
                    self.vertices = np.array(vertices)
         | 
| 1247 | 
            +
                    self.faces = np.array(faces)
         | 
| 1248 | 
            +
                    xDef1 = self.vertices[self.faces[:, 0], :]
         | 
| 1249 | 
            +
                    xDef2 = self.vertices[self.faces[:, 1], :]
         | 
| 1250 | 
            +
                    xDef3 = self.vertices[self.faces[:, 2], :]
         | 
| 1251 | 
            +
                    self.centers = (xDef1 + xDef2 + xDef3) / 3
         | 
| 1252 | 
            +
                    self.surfel = np.cross(xDef2 - xDef1, xDef3 - xDef1)
         | 
| 1253 | 
            +
             | 
| 1254 | 
            +
                def concatenate(self, fvl):
         | 
| 1255 | 
            +
                    nv = 0
         | 
| 1256 | 
            +
                    nf = 0
         | 
| 1257 | 
            +
                    for fv in fvl:
         | 
| 1258 | 
            +
                        nv += fv.vertices.shape[0]
         | 
| 1259 | 
            +
                        nf += fv.faces.shape[0]
         | 
| 1260 | 
            +
                    self.vertices = np.zeros([nv, 3])
         | 
| 1261 | 
            +
                    self.faces = np.zeros([nf, 3], dtype='int')
         | 
| 1262 | 
            +
             | 
| 1263 | 
            +
                    nv0 = 0
         | 
| 1264 | 
            +
                    nf0 = 0
         | 
| 1265 | 
            +
                    for fv in fvl:
         | 
| 1266 | 
            +
                        nv = nv0 + fv.vertices.shape[0]
         | 
| 1267 | 
            +
                        nf = nf0 + fv.faces.shape[0]
         | 
| 1268 | 
            +
                        self.vertices[nv0:nv, :] = fv.vertices
         | 
| 1269 | 
            +
                        self.faces[nf0:nf, :] = fv.faces + nv0
         | 
| 1270 | 
            +
                        nv0 = nv
         | 
| 1271 | 
            +
                        nf0 = nf
         | 
| 1272 | 
            +
                    self.computeCentersAreas()
         | 
| 1273 | 
            +
             | 
| 1274 | 
            +
            def get_surf_area(surf):
         | 
| 1275 | 
            +
                areas = np.linalg.norm(surf.surfel, axis=-1)
         | 
| 1276 | 
            +
                return areas.sum()/2
         | 
| 1277 | 
            +
             | 
| 1278 | 
            +
            def centroid(surf):
         | 
| 1279 | 
            +
                areas = np.linalg.norm(surf.surfel, axis=-1, keepdims=True)
         | 
| 1280 | 
            +
                center = (surf.centers * areas).sum(axis=0) / areas.sum()
         | 
| 1281 | 
            +
                return center, np.sqrt(areas.sum()/2)
         | 
| 1282 | 
            +
             | 
| 1283 | 
            +
             | 
| 1284 | 
            +
            def do_bbox_vertices(vertices):
         | 
| 1285 | 
            +
                new_verts = np.zeros(vertices.shape)
         | 
| 1286 | 
            +
                new_verts[:, 0] = (vertices[:, 0] - np.amin(vertices[:, 0]))/(np.amax(vertices[:, 0]) - np.amin(vertices[:, 0]))
         | 
| 1287 | 
            +
                new_verts[:, 1] = (vertices[:, 1] - np.amin(vertices[:, 1]))/(np.amax(vertices[:, 1]) - np.amin(vertices[:, 1]))
         | 
| 1288 | 
            +
                new_verts[:, 2] = (vertices[:, 2] - np.amin(vertices[:, 2]))/(np.amax(vertices[:, 1]) - np.amin(vertices[:, 2]))*0.5
         | 
| 1289 | 
            +
                return new_verts
         | 
| 1290 | 
            +
             | 
| 1291 | 
            +
            def opt_rot_surf(surf_1, surf_2, areas_c):
         | 
| 1292 | 
            +
                center_1, _ = centroid(surf_1)
         | 
| 1293 | 
            +
                pts_c1 = surf_1.centers - center_1
         | 
| 1294 | 
            +
                center_2, _ = centroid(surf_2)
         | 
| 1295 | 
            +
                pts_c2 = surf_2.centers - center_2 
         | 
| 1296 | 
            +
                to_sum = pts_c1[:, :, np.newaxis] * pts_c2[:, np.newaxis, :]
         | 
| 1297 | 
            +
                A = (to_sum * areas_c[:, :, np.newaxis]).sum(axis=0)
         | 
| 1298 | 
            +
                #A = to_sum.sum(axis=0)
         | 
| 1299 | 
            +
                u, _, v = np.linalg.svd(A)
         | 
| 1300 | 
            +
                a = np.array([[1, 0, 0], [0, 1, 0], [0, 0, np.sign(np.linalg.det(A))]])
         | 
| 1301 | 
            +
                O = u @ a @ v
         | 
| 1302 | 
            +
                return O.T
         | 
| 1303 | 
            +
             | 
| 1304 | 
            +
            def srnf(surf):
         | 
| 1305 | 
            +
                areas = np.linalg.norm(surf.surfel, axis=-1, keepdims=True) + 1e-8
         | 
| 1306 | 
            +
                return surf.surfel/np.sqrt(areas)
         | 
| 1307 | 
            +
             | 
| 1308 | 
            +
            def fps_gpt(points, num_samples):
         | 
| 1309 | 
            +
                """
         | 
| 1310 | 
            +
                Selects a subset of points using the Farthest Point Sampling algorithm.
         | 
| 1311 | 
            +
             | 
| 1312 | 
            +
                :param points: numpy array of shape (N, D), where N is the number of points and D is the dimension of each point
         | 
| 1313 | 
            +
                :param num_samples: number of points to select
         | 
| 1314 | 
            +
                :return: numpy array of indices of the selected points
         | 
| 1315 | 
            +
                """
         | 
| 1316 | 
            +
                if num_samples > len(points):
         | 
| 1317 | 
            +
                    raise ValueError("num_samples must be less than or equal to the number of points")
         | 
| 1318 | 
            +
             | 
| 1319 | 
            +
                # Initialize an array to hold the indices of the selected points
         | 
| 1320 | 
            +
                selected_indices = np.zeros(num_samples, dtype=int)
         | 
| 1321 | 
            +
             | 
| 1322 | 
            +
                center = np.mean(points, axis=0)
         | 
| 1323 | 
            +
             | 
| 1324 | 
            +
                # Array to store the shortest distance of each point to a selected poin
         | 
| 1325 | 
            +
                shortest_distances = np.full(len(points), np.inf)
         | 
| 1326 | 
            +
                last_selected_point = center
         | 
| 1327 | 
            +
                for i in range(1, num_samples):
         | 
| 1328 | 
            +
                    
         | 
| 1329 | 
            +
             | 
| 1330 | 
            +
                    # Update the shortest distance of each point to the selected points
         | 
| 1331 | 
            +
                    for j, point in enumerate(points):
         | 
| 1332 | 
            +
                        distance = np.linalg.norm(point - last_selected_point)
         | 
| 1333 | 
            +
                        shortest_distances[j] = min(shortest_distances[j], distance)
         | 
| 1334 | 
            +
             | 
| 1335 | 
            +
                    # Select the point that is farthest from all selected points
         | 
| 1336 | 
            +
                    selected_indices[i] = np.argmax(shortest_distances)
         | 
| 1337 | 
            +
                    last_selected_point = points[selected_indices[i]]
         | 
| 1338 | 
            +
                return selected_indices
         | 
| 1339 | 
            +
             | 
| 1340 | 
            +
             | 
| 1341 | 
            +
            def fps_gpt_geod(points, faces, num_samples, index_start):
         | 
| 1342 | 
            +
                """
         | 
| 1343 | 
            +
                Selects a subset of points using the Farthest Point Sampling algorithm.
         | 
| 1344 | 
            +
             | 
| 1345 | 
            +
                :param points: numpy array of shape (N, D), where N is the number of points and D is the dimension of each point
         | 
| 1346 | 
            +
                :param num_samples: number of points to select
         | 
| 1347 | 
            +
                :return: numpy array of indices of the selected points
         | 
| 1348 | 
            +
                """
         | 
| 1349 | 
            +
                if num_samples > len(points):
         | 
| 1350 | 
            +
                    raise ValueError("num_samples must be less than or equal to the number of points")
         | 
| 1351 | 
            +
             | 
| 1352 | 
            +
                # Initialize an array to hold the indices of the selected points
         | 
| 1353 | 
            +
                selected_indices = np.zeros(num_samples, dtype=int)
         | 
| 1354 | 
            +
             | 
| 1355 | 
            +
                # Array to store the shortest distance of each point to a selected poin
         | 
| 1356 | 
            +
                shortest_distances = np.full(len(points), np.inf)
         | 
| 1357 | 
            +
                selected_indices[0] = index_start
         | 
| 1358 | 
            +
                distances = np.zeros((num_samples, len(points)))
         | 
| 1359 | 
            +
                solver = pp3d.MeshHeatMethodDistanceSolver(points, faces)
         | 
| 1360 | 
            +
                for i in range(1, num_samples):
         | 
| 1361 | 
            +
                    distances[i] = (solver.compute_distance(selected_indices[i-1]))
         | 
| 1362 | 
            +
             | 
| 1363 | 
            +
                    # Update the shortest distance of each point to the selected points
         | 
| 1364 | 
            +
                    for j, point in enumerate(points):
         | 
| 1365 | 
            +
                        distance = distances[i][j]
         | 
| 1366 | 
            +
                        shortest_distances[j] = min(shortest_distances[j], distance)
         | 
| 1367 | 
            +
             | 
| 1368 | 
            +
                    # Select the point that is farthest from all selected points
         | 
| 1369 | 
            +
                    selected_indices[i] = np.argmax(shortest_distances)
         | 
| 1370 | 
            +
                return selected_indices
         | 
| 1371 | 
            +
             | 
| 1372 | 
            +
            # class MultiSurface:
         | 
| 1373 | 
            +
            #     def __init__(self, pattern):
         | 
| 1374 | 
            +
            #         self.surf = []
         | 
| 1375 | 
            +
            #         files = glob.glob(pattern)
         | 
| 1376 | 
            +
            #         for f in files:
         | 
| 1377 | 
            +
            #             self.surf.append(Surface(filename=f))
         | 
    	
        utils/torch_fmap.py
    ADDED
    
    | @@ -0,0 +1,77 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn.functional as F
         | 
| 3 | 
            +
            from pykeops.torch import LazyTensor
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def euclidean_dist(x, y):
         | 
| 7 | 
            +
                """
         | 
| 8 | 
            +
                Args:
         | 
| 9 | 
            +
                  x: pytorch Variable, with shape [m, d]
         | 
| 10 | 
            +
                  y: pytorch Variable, with shape [n, d]
         | 
| 11 | 
            +
                Returns:
         | 
| 12 | 
            +
                  dist: pytorch Variable, with shape [m, n]
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
                #bs, m, n = x.size(0), x.size(1), y.size(1)
         | 
| 15 | 
            +
                xx = torch.pow(x.squeeze(), 2).sum(1, keepdim=True)
         | 
| 16 | 
            +
                yy = torch.pow(y.squeeze(), 2).sum(1, keepdim=True).t()
         | 
| 17 | 
            +
                dist = xx + yy - 2 * torch.inner(x.squeeze(), y.squeeze())
         | 
| 18 | 
            +
                dist = dist.clamp(min=1e-12).sqrt() 
         | 
| 19 | 
            +
                return dist
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            def knnsearch(x, y, alpha=1./0.07, prod=False):
         | 
| 22 | 
            +
                if prod:
         | 
| 23 | 
            +
                    prods = torch.inner(x.squeeze(), y.squeeze())#/( torch.norm(x.squeeze(), dim=-1)[:, None]*torch.norm(y.squeeze(), dim=-1)[None, :])
         | 
| 24 | 
            +
                    output = F.softmax(alpha*prods, dim=1)
         | 
| 25 | 
            +
                else:
         | 
| 26 | 
            +
                    distance = euclidean_dist(x, y[None,:])
         | 
| 27 | 
            +
                    output = F.softmax(-alpha*distance, dim=1)
         | 
| 28 | 
            +
                return output.squeeze()
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            def extract_p2p_torch(reps_shape, reps_template):
         | 
| 31 | 
            +
                n_ev = reps_shape.shape[-1]
         | 
| 32 | 
            +
                with torch.no_grad():
         | 
| 33 | 
            +
                    # print((evecs0_dzo @ fmap01_final.squeeze().T).shape)
         | 
| 34 | 
            +
                    # print(evecs1_dzo.shape)
         | 
| 35 | 
            +
                    reps_shape_torch = torch.from_numpy(reps_shape).float().cuda()
         | 
| 36 | 
            +
                    G_i = LazyTensor(reps_shape_torch[:, None, :].contiguous())  # (M**2, 1, 2)
         | 
| 37 | 
            +
                    reps_template_torch = torch.from_numpy(reps_template).float().cuda()
         | 
| 38 | 
            +
                    X_j = LazyTensor(reps_template_torch[None, :, :n_ev].contiguous())  # (1, N, 2)
         | 
| 39 | 
            +
                    D_ij = ((G_i - X_j) ** 2).sum(-1)  # (M**2, N) symbolic matrix of squared distances
         | 
| 40 | 
            +
                    indKNN = D_ij.argKmin(1, dim=0).squeeze()  # Grid <-> Samples, (M**2, K) integer tensor
         | 
| 41 | 
            +
                    # pmap10_ref = FM_to_p2p(fmap01_final.detach().squeeze().cpu().numpy(), s_dict['evecs'], template_dict['evecs'])
         | 
| 42 | 
            +
                    # print(indKNN[:10], pmap10_ref[:10])
         | 
| 43 | 
            +
                    indKNN_2 = D_ij.argKmin(1, dim=1).squeeze()
         | 
| 44 | 
            +
                return indKNN.detach().cpu().numpy(), indKNN_2.detach().cpu().numpy()
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            def extract_p2p_torch_fmap(fmap_shape_template, evecs_shape, evecs_template):
         | 
| 47 | 
            +
                n_ev = fmap_shape_template.shape[-1]
         | 
| 48 | 
            +
                with torch.no_grad():
         | 
| 49 | 
            +
                    # print((evecs0_dzo @ fmap01_final.squeeze().T).shape)
         | 
| 50 | 
            +
                    # print(evecs1_dzo.shape)
         | 
| 51 | 
            +
                    G_i = LazyTensor((evecs_shape[:, :n_ev] @ fmap_shape_template.squeeze().T)[:, None, :].contiguous())  # (M**2, 1, 2)
         | 
| 52 | 
            +
                    X_j = LazyTensor(evecs_template[None, :, :n_ev].contiguous())  # (1, N, 2)
         | 
| 53 | 
            +
                    D_ij = ((G_i - X_j) ** 2).sum(-1)  # (M**2, N) symbolic matrix of squared distances
         | 
| 54 | 
            +
                    indKNN = D_ij.argKmin(1, dim=0).squeeze()  # Grid <-> Samples, (M**2, K) integer tensor
         | 
| 55 | 
            +
                    # pmap10_ref = FM_to_p2p(fmap01_final.detach().squeeze().cpu().numpy(), s_dict['evecs'], template_dict['evecs'])
         | 
| 56 | 
            +
                    # print(indKNN[:10], pmap10_ref[:10])
         | 
| 57 | 
            +
                    indKNN_2 = D_ij.argKmin(1, dim=1).squeeze()
         | 
| 58 | 
            +
                return indKNN.detach().cpu().numpy(), indKNN_2.detach().cpu().numpy()
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            def wlstsq(A, B, w):
         | 
| 61 | 
            +
                if w is None:
         | 
| 62 | 
            +
                    return torch.linalg.lstsq(A, B).solution
         | 
| 63 | 
            +
                else:
         | 
| 64 | 
            +
                    assert w.dim() + 1 == A.dim() and w.shape[-1] == A.shape[-2]
         | 
| 65 | 
            +
                    W = torch.diag_embed(w)
         | 
| 66 | 
            +
                    return torch.linalg.lstsq(W @ A, W @ B).solution
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            def torch_zoomout(evecs0, evecs1, evecs_1_trans, fmap01, target_size, step=1):
         | 
| 69 | 
            +
                assert fmap01.shape[-2] == fmap01.shape[-1], f"square fmap needed, got {fmap01.shape[-2]} and {fmap01.shape[-1]}"
         | 
| 70 | 
            +
                fs = fmap01.shape[0]
         | 
| 71 | 
            +
                for i in range(fs, target_size+1, step):
         | 
| 72 | 
            +
                    indKNN, _ = extract_p2p_torch_fmap(fmap01, evecs0, evecs1)
         | 
| 73 | 
            +
                    #fmap01 = wlstsq(evecs1[..., :i], evecs0[indKNN, :i], None)
         | 
| 74 | 
            +
                    fmap01 = evecs_1_trans[:i, :] @ evecs0[indKNN, :i]
         | 
| 75 | 
            +
                if fmap01.shape[0] < target_size:
         | 
| 76 | 
            +
                    fmap01 = evecs_1_trans[:target_size, :] @ evecs0[indKNN, :target_size]
         | 
| 77 | 
            +
                return fmap01
         | 
    	
        utils/utils_func.py
    ADDED
    
    | @@ -0,0 +1,123 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import os.path as osp
         | 
| 3 | 
            +
            import scipy.sparse as sp
         | 
| 4 | 
            +
            import shutil
         | 
| 5 | 
            +
            from pathlib import Path
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import re
         | 
| 9 | 
            +
            import os
         | 
| 10 | 
            +
            import requests
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            def ensure_pretrained_file(hf_url: str, save_dir: str = "pretrained", filename: str = "pretrained.pkl", token: str = None):
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
                Ensure that a pretrained file exists locally.
         | 
| 15 | 
            +
                If the folder is empty, download from Hugging Face.
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                Args:
         | 
| 18 | 
            +
                    hf_url (str): Hugging Face file URL (resolve/main/...).
         | 
| 19 | 
            +
                    save_dir (str): Directory to store pretrained file.
         | 
| 20 | 
            +
                    filename (str): Name of file to save.
         | 
| 21 | 
            +
                    token (str): Optional Hugging Face token (for gated/private repos).
         | 
| 22 | 
            +
                """
         | 
| 23 | 
            +
                os.makedirs(save_dir, exist_ok=True)
         | 
| 24 | 
            +
                save_path = os.path.join(save_dir, filename)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                if os.path.exists(save_path):
         | 
| 27 | 
            +
                    print(f"✅ Found pretrained file: {save_path}")
         | 
| 28 | 
            +
                    return save_path
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                headers = {"Authorization": f"Bearer {token}"} if token else {}
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                print(f"⬇️ Downloading pretrained file from {hf_url} to {save_path} ...")
         | 
| 33 | 
            +
                response = requests.get(hf_url, headers=headers, stream=True)
         | 
| 34 | 
            +
                response.raise_for_status()
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                with open(save_path, "wb") as f:
         | 
| 37 | 
            +
                    for chunk in response.iter_content(chunk_size=8192):
         | 
| 38 | 
            +
                        f.write(chunk)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                print("✅ Download complete.")
         | 
| 41 | 
            +
                return save_path
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def may_create_folder(folder_path):
         | 
| 45 | 
            +
                if not osp.exists(folder_path):
         | 
| 46 | 
            +
                    oldmask = os.umask(000)
         | 
| 47 | 
            +
                    os.makedirs(folder_path, mode=0o777)
         | 
| 48 | 
            +
                    os.umask(oldmask)
         | 
| 49 | 
            +
                    return True
         | 
| 50 | 
            +
                return False
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def make_clean_folder(folder_path):
         | 
| 54 | 
            +
                success = may_create_folder(folder_path)
         | 
| 55 | 
            +
                if not success:
         | 
| 56 | 
            +
                    shutil.rmtree(folder_path)
         | 
| 57 | 
            +
                    may_create_folder(folder_path)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            def str_delta(delta):
         | 
| 60 | 
            +
                s = delta.total_seconds()
         | 
| 61 | 
            +
                hours, remainder = divmod(s, 3600)
         | 
| 62 | 
            +
                minutes, seconds = divmod(remainder, 60)
         | 
| 63 | 
            +
                return '{:02}:{:02}:{:02}'.format(int(hours), int(minutes), int(seconds))
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            def desc_from_config(config, world_size):
         | 
| 66 | 
            +
                cond_str = 'uncond' #'cond' if c.dataset_kwargs.use_labels else 'uncond'
         | 
| 67 | 
            +
                dtype_str = 'fp16' if config["perfs"]["fp16"] else 'fp32'
         | 
| 68 | 
            +
                name_data = config["data"]["name"]
         | 
| 69 | 
            +
                if "abs" in config["data"]:
         | 
| 70 | 
            +
                    name_data += "abs" if config["data"]["abs"] else ""
         | 
| 71 | 
            +
                desc = (f'{name_data:s}-{cond_str:s}-{config["architecture"]["model"]:s}-'
         | 
| 72 | 
            +
                        f'gpus{world_size:d}-batch{config["hyper_params"]["batch_size"]:d}-{dtype_str:s}')
         | 
| 73 | 
            +
                if config["misc"]["precond"]:
         | 
| 74 | 
            +
                    desc += f'-{config["misc"]["precond"]:s}'
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                if "desc" in config["misc"]:
         | 
| 77 | 
            +
                    if config["misc"]["desc"] is not None:
         | 
| 78 | 
            +
                        desc += f'-{config["misc"]["desc"]}'
         | 
| 79 | 
            +
                return desc
         | 
| 80 | 
            +
             | 
| 81 | 
            +
            def get_dataset_path(config):
         | 
| 82 | 
            +
                name_exp = config["data"]["name"]
         | 
| 83 | 
            +
                if "add_name" in config:
         | 
| 84 | 
            +
                    if config["add_name"]["do"]:
         | 
| 85 | 
            +
                        name_exp += "_" + config["add_name"]["name"]
         | 
| 86 | 
            +
                dataset_path = os.path.join(config["data"]["root_dir"], name_exp)
         | 
| 87 | 
            +
                return name_exp, dataset_path
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            def sparse_np_to_torch(A):
         | 
| 90 | 
            +
                Acoo = A.tocoo()
         | 
| 91 | 
            +
                values = Acoo.data
         | 
| 92 | 
            +
                indices = np.vstack((Acoo.row, Acoo.col))
         | 
| 93 | 
            +
                shape = Acoo.shape
         | 
| 94 | 
            +
                return torch.sparse_coo_tensor(torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape), dtype=torch.float32).coalesce()
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            def convert_dict(np_dict, device="cpu"):
         | 
| 97 | 
            +
                torch_dict = {}
         | 
| 98 | 
            +
                for k, value in np_dict.items():
         | 
| 99 | 
            +
                    if sp.issparse(value):
         | 
| 100 | 
            +
                        torch_dict[k] = sparse_np_to_torch(value).to(device)
         | 
| 101 | 
            +
                        if torch_dict[k].dtype == torch.int32:
         | 
| 102 | 
            +
                            torch_dict[k] = torch_dict[k].long().to(device)
         | 
| 103 | 
            +
                    elif isinstance(value, np.ndarray):
         | 
| 104 | 
            +
                        torch_dict[k] = torch.from_numpy(value).to(device)
         | 
| 105 | 
            +
                        if torch_dict[k].dtype == torch.int32:
         | 
| 106 | 
            +
                            torch_dict[k] = torch_dict[k].squeeze().long().to(device)
         | 
| 107 | 
            +
                    else:
         | 
| 108 | 
            +
                        torch_dict[k] = value
         | 
| 109 | 
            +
                return torch_dict
         | 
| 110 | 
            +
             | 
| 111 | 
            +
             | 
| 112 | 
            +
            def convert_dict_torch(in_dict, device="cpu"):
         | 
| 113 | 
            +
                torch_dict = {}
         | 
| 114 | 
            +
                for k, value in in_dict.items():
         | 
| 115 | 
            +
                    if isinstance(value, torch.Tensor):
         | 
| 116 | 
            +
                        torch_dict[k] = in_dict[k].to(device)
         | 
| 117 | 
            +
                return torch_dict
         | 
| 118 | 
            +
             | 
| 119 | 
            +
            def batchify_dict(torch_dict):
         | 
| 120 | 
            +
                for k, value in torch_dict.items():
         | 
| 121 | 
            +
                    if isinstance(value, torch.Tensor):
         | 
| 122 | 
            +
                        if torch_dict[k].dtype != torch.int64:
         | 
| 123 | 
            +
                            torch_dict[k] = torch_dict[k].unsqueeze(0)
         | 
    	
        utils/utils_legacy.py
    ADDED
    
    | @@ -0,0 +1,130 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import time
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import hashlib
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import scipy
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            def read_lines(file_path):
         | 
| 11 | 
            +
                with open(file_path, 'r') as fin:
         | 
| 12 | 
            +
                    lines = [line.strip() for line in fin.readlines() if len(line.strip()) > 0]
         | 
| 13 | 
            +
                return lines
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # == Pytorch things
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def toNP(x):
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                Really, definitely convert a torch tensor to a numpy array
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                return x.detach().to(torch.device('cpu')).numpy()
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def label_smoothing_log_loss(pred, labels, smoothing=0.0):
         | 
| 26 | 
            +
                n_class = pred.shape[-1]
         | 
| 27 | 
            +
                one_hot = torch.zeros_like(pred)
         | 
| 28 | 
            +
                one_hot[labels] = 1.
         | 
| 29 | 
            +
                one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
         | 
| 30 | 
            +
                loss = -(one_hot * pred).sum(dim=-1).mean()
         | 
| 31 | 
            +
                return loss
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            # Randomly rotate points.
         | 
| 35 | 
            +
            # Torch in, torch out
         | 
| 36 | 
            +
            # Note fornow, builds rotation matrix on CPU.
         | 
| 37 | 
            +
            def random_rotate_points(pts, randgen=None):
         | 
| 38 | 
            +
                R = random_rotation_matrix(randgen)
         | 
| 39 | 
            +
                R = torch.from_numpy(R).to(device=pts.device, dtype=pts.dtype)
         | 
| 40 | 
            +
                return torch.matmul(pts, R)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def random_rotate_points_y(pts):
         | 
| 44 | 
            +
                angles = torch.rand(1, device=pts.device, dtype=pts.dtype) * (2. * np.pi)
         | 
| 45 | 
            +
                rot_mats = torch.zeros(3, 3, device=pts.device, dtype=pts.dtype)
         | 
| 46 | 
            +
                rot_mats[0, 0] = torch.cos(angles)
         | 
| 47 | 
            +
                rot_mats[0, 2] = torch.sin(angles)
         | 
| 48 | 
            +
                rot_mats[2, 0] = -torch.sin(angles)
         | 
| 49 | 
            +
                rot_mats[2, 2] = torch.cos(angles)
         | 
| 50 | 
            +
                rot_mats[1, 1] = 1.
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                pts = torch.matmul(pts, rot_mats)
         | 
| 53 | 
            +
                return pts
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            # Numpy things
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            # Numpy sparse matrix to pytorch
         | 
| 60 | 
            +
            def sparse_np_to_torch(A):
         | 
| 61 | 
            +
                Acoo = A.tocoo()
         | 
| 62 | 
            +
                values = Acoo.data
         | 
| 63 | 
            +
                indices = np.vstack((Acoo.row, Acoo.col))
         | 
| 64 | 
            +
                shape = Acoo.shape
         | 
| 65 | 
            +
                return torch.sparse_coo_tensor(torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape), dtype=torch.float32).coalesce()
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            # Pytorch sparse to numpy csc matrix
         | 
| 69 | 
            +
            def sparse_torch_to_np(A):
         | 
| 70 | 
            +
                if len(A.shape) != 2:
         | 
| 71 | 
            +
                    raise RuntimeError("should be a matrix-shaped type; dim is : " + str(A.shape))
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                indices = toNP(A.indices())
         | 
| 74 | 
            +
                values = toNP(A.values())
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                mat = scipy.sparse.coo_matrix((values, indices), shape=A.shape).tocsc()
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                return mat
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
            # Hash a list of numpy arrays
         | 
| 82 | 
            +
            def hash_arrays(arrs):
         | 
| 83 | 
            +
                running_hash = hashlib.sha1()
         | 
| 84 | 
            +
                for arr in arrs:
         | 
| 85 | 
            +
                    binarr = arr.view(np.uint8)
         | 
| 86 | 
            +
                    running_hash.update(binarr)
         | 
| 87 | 
            +
                return running_hash.hexdigest()
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            def random_rotation_matrix(randgen=None):
         | 
| 91 | 
            +
                """
         | 
| 92 | 
            +
                Creates a random rotation matrix.
         | 
| 93 | 
            +
                randgen: if given, a np.random.RandomState instance used for random numbers (for reproducibility)
         | 
| 94 | 
            +
                """
         | 
| 95 | 
            +
                # adapted from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                if randgen is None:
         | 
| 98 | 
            +
                    randgen = np.random.RandomState()
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                theta, phi, z = tuple(randgen.rand(3).tolist())
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                theta = theta * 2.0 * np.pi  # Rotation about the pole (Z).
         | 
| 103 | 
            +
                phi = phi * 2.0 * np.pi  # For direction of pole deflection.
         | 
| 104 | 
            +
                z = z * 2.0  # For magnitude of pole deflection.
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                # Compute a vector V used for distributing points over the sphere
         | 
| 107 | 
            +
                # via the reflection I - V Transpose(V).  This formulation of V
         | 
| 108 | 
            +
                # will guarantee that if x[1] and x[2] are uniformly distributed,
         | 
| 109 | 
            +
                # the reflected points will be uniform on the sphere.  Note that V
         | 
| 110 | 
            +
                # has length sqrt(2) to eliminate the 2 in the Householder matrix.
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                r = np.sqrt(z)
         | 
| 113 | 
            +
                Vx, Vy, Vz = V = (np.sin(phi) * r, np.cos(phi) * r, np.sqrt(2.0 - z))
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                st = np.sin(theta)
         | 
| 116 | 
            +
                ct = np.cos(theta)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1)))
         | 
| 119 | 
            +
                # Construct the rotation matrix  ( V Transpose(V) - I ) R.
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                M = (np.outer(V, V) - np.eye(3)).dot(R)
         | 
| 122 | 
            +
                return M
         | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
            +
            # Python string/file utilities
         | 
| 126 | 
            +
            def ensure_dir_exists(d):
         | 
| 127 | 
            +
                if not os.path.exists(d):
         | 
| 128 | 
            +
                    os.makedirs(d)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
    	
        zero_shot.py
    ADDED
    
    | @@ -0,0 +1,402 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import scipy
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            import sys
         | 
| 6 | 
            +
            import random
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import time
         | 
| 10 | 
            +
            from datetime import datetime
         | 
| 11 | 
            +
            import importlib
         | 
| 12 | 
            +
            import json
         | 
| 13 | 
            +
            import argparse
         | 
| 14 | 
            +
            from omegaconf import OmegaConf
         | 
| 15 | 
            +
            from snk.loss import PrismRegularizationLoss
         | 
| 16 | 
            +
            from snk.prism_decoder import PrismDecoder
         | 
| 17 | 
            +
            from shape_models.fmap import DFMNet
         | 
| 18 | 
            +
            from shape_models.encoder import Encoder
         | 
| 19 | 
            +
            from diffu_models.losses import VELoss, VPLoss, EDMLoss
         | 
| 20 | 
            +
            from diffu_models.sds import guidance_grad
         | 
| 21 | 
            +
            from utils.torch_fmap import torch_zoomout, knnsearch, extract_p2p_torch_fmap
         | 
| 22 | 
            +
            from utils.utils_func import convert_dict, str_delta, ensure_pretrained_file
         | 
| 23 | 
            +
            from utils.eval import accuracy
         | 
| 24 | 
            +
            from utils.mesh import save_ply, load_mesh
         | 
| 25 | 
            +
            from shape_data import get_data_dirs
         | 
| 26 | 
            +
            from utils.pickle_stuff import safe_load_with_fallback
         | 
| 27 | 
            +
            from utils.geometry import compute_operators, load_operators
         | 
| 28 | 
            +
            from utils.surfaces import Surface
         | 
| 29 | 
            +
            import sys
         | 
| 30 | 
            +
            try:
         | 
| 31 | 
            +
                import google.colab
         | 
| 32 | 
            +
                print("Running Colab")
         | 
| 33 | 
            +
                from tqdm import tqdm
         | 
| 34 | 
            +
            except ImportError:
         | 
| 35 | 
            +
                print("Running local")
         | 
| 36 | 
            +
                from tqdm.auto import tqdm
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def seed_everything(seed=42):
         | 
| 40 | 
            +
              random.seed(seed)
         | 
| 41 | 
            +
              os.environ['PYTHONHASHSEED'] = str(seed)
         | 
| 42 | 
            +
              np.random.seed(seed)
         | 
| 43 | 
            +
              torch.manual_seed(seed)
         | 
| 44 | 
            +
              torch.backends.cudnn.deterministic = True
         | 
| 45 | 
            +
              torch.backends.cudnn.benchmark = False
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            seed_everything()
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            class Tee:
         | 
| 50 | 
            +
                def __init__(self, *outputs):
         | 
| 51 | 
            +
                    self.outputs = outputs
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def write(self, message):
         | 
| 54 | 
            +
                    for output in self.outputs:
         | 
| 55 | 
            +
                        output.write(message)
         | 
| 56 | 
            +
                        output.flush()  # ensure it's written immediately
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def flush(self):
         | 
| 59 | 
            +
                    for output in self.outputs:
         | 
| 60 | 
            +
                        output.flush()
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            class DiffModel:
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                def __init__(self, cfg, device="cuda:0"):
         | 
| 65 | 
            +
                    if cfg["train_dir"] == "pretrained":
         | 
| 66 | 
            +
                        url = "https://huggingface.co/daidedou/diffumatch_model/resolve/main/network-snapshot-041216.pkl"
         | 
| 67 | 
            +
                        network_pkl = ensure_pretrained_file(url, "pretrained")
         | 
| 68 | 
            +
                        url_json = "https://huggingface.co/daidedou/diffumatch_model/resolve/main/training_options.json"
         | 
| 69 | 
            +
                        json_filename = ensure_pretrained_file(url_json, "pretrained", filename="training_options.json")
         | 
| 70 | 
            +
                        train_cfg = json.load(open(json_filename))
         | 
| 71 | 
            +
                    else:
         | 
| 72 | 
            +
                        num_exp = cfg["diff_num_exp"]
         | 
| 73 | 
            +
                        files = os.listdir(cfg["train_dir"])
         | 
| 74 | 
            +
                        for file in files:
         | 
| 75 | 
            +
                            if file[:5] == f"{num_exp:05d}":
         | 
| 76 | 
            +
                                netdir = os.path.join(cfg["train_dir"], file)
         | 
| 77 | 
            +
                        train_cfg = json.load(open(os.path.join(netdir, "training_options.json")))
         | 
| 78 | 
            +
                        pkls = [f for f in os.listdir(netdir) if ".pkl" in f]
         | 
| 79 | 
            +
                        nice_pkls = sorted(pkls, key=lambda x: int(x.split(".")[0].split("-")[-1]))
         | 
| 80 | 
            +
                        chosen_pkl = nice_pkls[-1]
         | 
| 81 | 
            +
                        network_pkl = os.path.join(netdir, chosen_pkl)
         | 
| 82 | 
            +
                    print(f'Loading network from "{network_pkl}"...')
         | 
| 83 | 
            +
                    self.net = safe_load_with_fallback(network_pkl)['ema'].to(device)
         | 
| 84 | 
            +
                    
         | 
| 85 | 
            +
                    print('Done!')
         | 
| 86 | 
            +
                    loss_name = train_cfg['hyper_params']['loss_name']
         | 
| 87 | 
            +
                    self.loss_sde = None
         | 
| 88 | 
            +
                    if loss_name == "EDMLoss":
         | 
| 89 | 
            +
                        self.loss_sde = EDMLoss()
         | 
| 90 | 
            +
                    elif loss_name == "VPLoss":
         | 
| 91 | 
            +
                        self.loss_sde = VPLoss()
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            class Matcher(object):
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def __init__(self, cfg):
         | 
| 97 | 
            +
                    self.cfg = cfg
         | 
| 98 | 
            +
                    self.device = torch.device(f'cuda:{cfg["gpu"]}' if torch.cuda.is_available() else 'cpu')
         | 
| 99 | 
            +
                    self.diffusion_model = None
         | 
| 100 | 
            +
                    if self.cfg.get("sds", False):
         | 
| 101 | 
            +
                        self.diffusion_model = DiffModel(cfg["sds_conf"])
         | 
| 102 | 
            +
                    self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
         | 
| 103 | 
            +
                    self.n_loop = 0
         | 
| 104 | 
            +
                    if self.cfg.get("optimize", False):
         | 
| 105 | 
            +
                        self.n_loop = self.cfg.opt.get("n_loop", 0)
         | 
| 106 | 
            +
                    self.snk = self.cfg.get("snk", False)
         | 
| 107 | 
            +
                    self.fmap_cfg = self.cfg.deepfeat_conf.fmap
         | 
| 108 | 
            +
                    self.dataloaders = dict()
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                def reconf(self, cfg):
         | 
| 111 | 
            +
                    self.cfg = cfg
         | 
| 112 | 
            +
                    self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
         | 
| 113 | 
            +
                    self.n_loop = 0
         | 
| 114 | 
            +
                    if self.cfg.get("optimize", False):
         | 
| 115 | 
            +
                        self.n_loop = self.cfg.opt.get("n_loop", 0)
         | 
| 116 | 
            +
                    self.fmap_cfg = self.cfg.deepfeat_conf.fmap
         | 
| 117 | 
            +
                    self.dataloaders = dict()
         | 
| 118 | 
            +
                    
         | 
| 119 | 
            +
                def _init(self):
         | 
| 120 | 
            +
                    cfg = self.cfg
         | 
| 121 | 
            +
                    self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).to(self.device)
         | 
| 122 | 
            +
                    if self.snk:
         | 
| 123 | 
            +
                        self.encoder = Encoder().to(self.device)
         | 
| 124 | 
            +
                        self.decoder = PrismDecoder(dim_in=515).to(self.device)
         | 
| 125 | 
            +
                        self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
         | 
| 126 | 
            +
                        self.soft_p2p = True
         | 
| 127 | 
            +
                        params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(self.decoder.parameters())
         | 
| 128 | 
            +
                    else:
         | 
| 129 | 
            +
                        params_to_opt = self.fmap_model.parameters()
         | 
| 130 | 
            +
                    self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
         | 
| 131 | 
            +
                    self.eye = torch.eye(self.n_fmap).float().to(self.device)
         | 
| 132 | 
            +
                    self.eye.requires_grad = False
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def fmap(self, shape_dict, target_dict):
         | 
| 135 | 
            +
                    if self.fmap_cfg.get("use_diff", False):
         | 
| 136 | 
            +
                        C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model({"shape1": shape_dict, "shape2": target_dict}, diff_model=self.diffusion_model, scale=self.fmap_cfg.diffusion.time)
         | 
| 137 | 
            +
                        C12_pred, C12_obj, mask_12 = C12_pred
         | 
| 138 | 
            +
                        C21_pred, C21_obj, mask_21 = C21_pred
         | 
| 139 | 
            +
                    else:
         | 
| 140 | 
            +
                        C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model({"shape1": shape_dict, "shape2": target_dict})
         | 
| 141 | 
            +
                        C12_obj, C21_obj = C12_pred, C21_pred
         | 
| 142 | 
            +
                        mask_12, mask_21 = None, None
         | 
| 143 | 
            +
                    return C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, mask_12, mask_21
         | 
| 144 | 
            +
                
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                def zo_shot(self, shape_dict, target_dict):
         | 
| 147 | 
            +
                    self._init()
         | 
| 148 | 
            +
                    evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
         | 
| 149 | 
            +
                    _, C12_mask_init, _, _, _, _, _ , _, _, _ = self.fmap(shape_dict, target_dict)
         | 
| 150 | 
            +
                    evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
         | 
| 151 | 
            +
                    new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_mask_init.squeeze(), self.cfg["zo_shot"])
         | 
| 152 | 
            +
                    indKNN_new, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
         | 
| 153 | 
            +
                    return new_FM, indKNN_new
         | 
| 154 | 
            +
             | 
| 155 | 
            +
             | 
| 156 | 
            +
                def optimize(self, shape_dict, target_dict, target_normals):
         | 
| 157 | 
            +
                    self._init()
         | 
| 158 | 
            +
                    evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
         | 
| 159 | 
            +
                    C12_pred_init, _, _, _ , _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
         | 
| 160 | 
            +
                    evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
         | 
| 161 | 
            +
                    evecs_1trans = evecs1.t() @ torch.diag(shape_dict["mass"])
         | 
| 162 | 
            +
                    n_verts_target = target_dict["vertices"].shape[-2]
         | 
| 163 | 
            +
                    
         | 
| 164 | 
            +
                    loss_save = {"cycle": [], "fmap": [], "mse": [], "prism": [], "bij": [], "ortho": [], "sds": [], "lap": [], "proper": []}
         | 
| 165 | 
            +
                    snk_rec = None
         | 
| 166 | 
            +
                    for i in tqdm(range(self.n_loop), "Optimizing matching " + shape_dict['name'] + " " + target_dict['name']):
         | 
| 167 | 
            +
                        C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
         | 
| 168 | 
            +
                        if self.cfg.opt.soft_p2p:
         | 
| 169 | 
            +
                            ### A la SNK
         | 
| 170 | 
            +
                            ## P2P 2 -> 1
         | 
| 171 | 
            +
                            soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_pred.squeeze(), evecs1[:, :self.n_fmap], prod=True)
         | 
| 172 | 
            +
                            C12_new = evecs_trans2[:self.n_fmap, :] @ soft_p2p_21 @ evecs1[:, :self.n_fmap]
         | 
| 173 | 
            +
                            soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_new.squeeze(), evecs1[:, :self.n_fmap], prod=True)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                            ## P2P 1 -> 2 
         | 
| 176 | 
            +
                            soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_pred.squeeze(), evecs2[:, :self.n_fmap], prod=True)
         | 
| 177 | 
            +
                            C21_new = evecs_trans1[:self.n_fmap, :] @ soft_p2p_12 @ evecs2[:, :self.n_fmap]
         | 
| 178 | 
            +
                            soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_new.squeeze(), evecs2[:, :self.n_fmap], prod=True)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                            l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(dim=-1).mean()
         | 
| 181 | 
            +
                        else:
         | 
| 182 | 
            +
                            C12_new, C21_new = C12_pred, C21_pred
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                        l_ortho = ((C12_new.squeeze() @ C12_new.squeeze().T - self.eye)**2).mean() + ((C21_new.squeeze() @ C21_new.squeeze().T - self.eye)**2).mean()
         | 
| 185 | 
            +
                        l_bij = ((C12_new.squeeze() @ C21_new.squeeze() - self.eye)**2).mean() + ((C21_new.squeeze() @ C12_new.squeeze() - self.eye)**2).mean()
         | 
| 186 | 
            +
                        l_lap = ((C12_new @ torch.diag(shape_dict["evals"][:self.n_fmap]) - torch.diag(target_dict["evals"][:self.n_fmap]) @ C12_new)**2).mean()
         | 
| 187 | 
            +
                        l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(shape_dict["evals"][:self.n_fmap]) @ C21_new)**2).mean()
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
                        l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
         | 
| 191 | 
            +
                        if self.snk:
         | 
| 192 | 
            +
                            # Latent vector 
         | 
| 193 | 
            +
                            latents = self.encoder(shape_dict)
         | 
| 194 | 
            +
                            latents_duplicate = latents[None, :].repeat(n_verts_target, 1)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                            # Prism decoder
         | 
| 197 | 
            +
                            feats_decode = torch.cat((target_dict["vertices"], latents_duplicate), dim=1)
         | 
| 198 | 
            +
                            snk_rec, prism, rots = self.decoder(target_dict, feats_decode)
         | 
| 199 | 
            +
                            l_prism = self.loss_prism(prism, rots, target_dict["vertices"], target_dict["faces"], target_normals)
         | 
| 200 | 
            +
                            l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec)**2).sum(dim=-1).mean()
         | 
| 201 | 
            +
                            l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(dim=-1).mean()
         | 
| 202 | 
            +
                        l_sds, l_proper = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
         | 
| 203 | 
            +
                        if self.fmap_cfg.get("use_diff", False):
         | 
| 204 | 
            +
                            if self.fmap_cfg.diffusion.get("abs", False):
         | 
| 205 | 
            +
                                C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
         | 
| 206 | 
            +
                            else:
         | 
| 207 | 
            +
                                C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
         | 
| 208 | 
            +
                            grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1, batch_size=self.fmap_cfg.diffusion.batch_sds, 
         | 
| 209 | 
            +
                                                       scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
         | 
| 210 | 
            +
                            with torch.no_grad():
         | 
| 211 | 
            +
                                denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
         | 
| 212 | 
            +
                            targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)   
         | 
| 213 | 
            +
                                         
         | 
| 214 | 
            +
                            l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                            grad_21, _ = guidance_grad(C21_in, self.diffusion_model.net, grad_scale=1, batch_size=self.fmap_cfg.diffusion.batch_sds, 
         | 
| 217 | 
            +
                                                       scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
         | 
| 218 | 
            +
                            #denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
         | 
| 219 | 
            +
                            with torch.no_grad():
         | 
| 220 | 
            +
                                denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21 
         | 
| 221 | 
            +
                            targets_21 = torch_zoomout(evecs2, evecs1, evecs_1trans, C21_obj.squeeze(), self.cfg.sds_conf.zoomout)#, step=10)
         | 
| 222 | 
            +
                            l_proper_21 = ((C21_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_21.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
         | 
| 223 | 
            +
                            l_proper = l_proper_12 + l_proper_21
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                            l_sds = ((torch.abs(C12_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_12.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
         | 
| 226 | 
            +
                            l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
         | 
| 227 | 
            +
                        loss = torch.as_tensor(0.).float().to(self.device)
         | 
| 228 | 
            +
                        if self.cfg.loss.get("ortho", 0) > 0:
         | 
| 229 | 
            +
                            loss += self.cfg.loss.get("ortho", 0) *  l_ortho
         | 
| 230 | 
            +
                        if self.cfg.loss.get("bij", 0) > 0:
         | 
| 231 | 
            +
                            loss += self.cfg.loss.get("bij", 0) *  l_bij
         | 
| 232 | 
            +
                        if self.cfg.loss.get("lap", 0) > 0:
         | 
| 233 | 
            +
                            loss += self.cfg.loss.get("lap", 0) *  l_lap 
         | 
| 234 | 
            +
                        if self.cfg.loss.get("cycle", 0) > 0:
         | 
| 235 | 
            +
                            loss += self.cfg.loss.get("cycle", 0) *  l_cycle
         | 
| 236 | 
            +
                        if self.cfg.loss.get("mse_rec", 0) > 0:
         | 
| 237 | 
            +
                            loss += self.cfg.loss.get("mse_rec", 0) *  l_mse
         | 
| 238 | 
            +
                        if self.cfg.loss.get("prism_rec", 0) > 0:
         | 
| 239 | 
            +
                            loss += self.cfg.loss.get("prism_rec", 0) *  l_prism
         | 
| 240 | 
            +
                        if self.cfg.loss.get("sds", 0) > 0 and self.fmap_cfg.get("use_diff", False):
         | 
| 241 | 
            +
                            loss += self.cfg.loss.get("sds", 0) * l_sds
         | 
| 242 | 
            +
                        if self.cfg.loss.get("proper", 0) > 0 and self.fmap_cfg.get("use_diff", False):
         | 
| 243 | 
            +
                            loss += self.cfg.loss.get("proper", 0) * l_proper
         | 
| 244 | 
            +
                    
         | 
| 245 | 
            +
                        loss.backward()
         | 
| 246 | 
            +
                        self.optim.step()
         | 
| 247 | 
            +
                        self.optim.zero_grad()
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                        loss_save["cycle"].append(l_cycle.item())
         | 
| 250 | 
            +
                        loss_save["ortho"].append(l_ortho.item())
         | 
| 251 | 
            +
                        loss_save["bij"].append(l_bij.item())
         | 
| 252 | 
            +
                        loss_save["sds"].append(l_sds.item())
         | 
| 253 | 
            +
                        loss_save["proper"].append(l_proper.item())
         | 
| 254 | 
            +
                        loss_save["mse"].append(l_mse.item())
         | 
| 255 | 
            +
                        loss_save["prism"].append(l_prism.item())
         | 
| 256 | 
            +
                    indKNN_new_init, _ = extract_p2p_torch_fmap(C12_pred_init, evecs1, evecs2)
         | 
| 257 | 
            +
                    indKNN_new, _ = extract_p2p_torch_fmap(C12_new, evecs1, evecs2)
         | 
| 258 | 
            +
                    return C12_new, indKNN_new, indKNN_new_init, snk_rec, loss_save
         | 
| 259 | 
            +
                
         | 
| 260 | 
            +
             | 
| 261 | 
            +
             | 
| 262 | 
            +
                def match(self, pair_batch, output_pair, geod_path, refine=True, eval=False):
         | 
| 263 | 
            +
                    shape_dict, _, target_dict, _, target_normals, mapinfo = pair_batch 
         | 
| 264 | 
            +
                    shape_dict_device = convert_dict(shape_dict, self.device)
         | 
| 265 | 
            +
                    target_dict_device = convert_dict(target_dict, self.device)
         | 
| 266 | 
            +
                    print(shape_dict_device["vertices"].device)
         | 
| 267 | 
            +
                    os.makedirs(output_pair, exist_ok=True)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
             | 
| 270 | 
            +
                    if self.cfg["optimize"]:
         | 
| 271 | 
            +
                        C12_new, p2p, p2p_init, snk_rec, loss_save = self.optimize(shape_dict_device, target_dict_device, target_normals.to(self.device))
         | 
| 272 | 
            +
                        np.save(os.path.join(output_pair, "p2p_init.npy"), p2p_init)
         | 
| 273 | 
            +
                        np.save(os.path.join(output_pair, "losses.npy"), loss_save)
         | 
| 274 | 
            +
                    else:
         | 
| 275 | 
            +
                        C12_new, p2p = self.zo_shot(shape_dict_device, target_dict_device)
         | 
| 276 | 
            +
                        snk_rec, loss_save = None, None
         | 
| 277 | 
            +
                    np.save(os.path.join(output_pair, "fmap.npy"), C12_new.detach().squeeze().cpu().numpy())
         | 
| 278 | 
            +
                    np.save(os.path.join(output_pair, "p2p.npy"), p2p)
         | 
| 279 | 
            +
                    if snk_rec is not None:
         | 
| 280 | 
            +
                        save_ply(os.path.join(output_pair, "rec.ply"), snk_rec.detach().squeeze().cpu().numpy(), target_dict["faces"])
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    if refine:
         | 
| 283 | 
            +
                        evecs1, evecs2 = shape_dict_device["evecs"], target_dict_device["evecs"]
         | 
| 284 | 
            +
                        evecs_2trans = evecs2.t() @ torch.diag(target_dict_device["mass"])
         | 
| 285 | 
            +
                        new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze(), 128)#, step=10)
         | 
| 286 | 
            +
                        p2p_refined_zo, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
         | 
| 287 | 
            +
                        np.save(os.path.join(output_pair, "p2p_zo.npy"), p2p)
         | 
| 288 | 
            +
                    if eval:
         | 
| 289 | 
            +
                        file_i, vts_1, vts_2 = mapinfo
         | 
| 290 | 
            +
                        mat_loaded = scipy.io.loadmat(os.path.join(geod_path, file_i + ".mat"))
         | 
| 291 | 
            +
                        A_geod, sqrt_area = mat_loaded['geod_dist'], np.sqrt(mat_loaded['areas_f'].sum())
         | 
| 292 | 
            +
                        _, dist = accuracy(p2p[vts_2], vts_1, A_geod,
         | 
| 293 | 
            +
                                                        sqrt_area=sqrt_area,
         | 
| 294 | 
            +
                                                        return_all=True)
         | 
| 295 | 
            +
                        if refine:
         | 
| 296 | 
            +
                            _, dist_zo = accuracy(p2p_refined_zo[vts_2], vts_1, A_geod,
         | 
| 297 | 
            +
                                                        sqrt_area=sqrt_area,
         | 
| 298 | 
            +
                                                        return_all=True)
         | 
| 299 | 
            +
                            np.savetxt(os.path.join(output_pair, "dists.txt"), (dist.mean(), dist_zo.mean()))
         | 
| 300 | 
            +
                            return p2p, p2p_refined_zo, loss_save, dist.mean(), dist_zo.mean()
         | 
| 301 | 
            +
                        return p2p, loss_save, dist.mean()
         | 
| 302 | 
            +
                    return p2p, loss_save
         | 
| 303 | 
            +
                    
         | 
| 304 | 
            +
             | 
| 305 | 
            +
             | 
| 306 | 
            +
             | 
| 307 | 
            +
                def _dataset_epoch(self, dataset, name_dataset, save_dir, data_dir):
         | 
| 308 | 
            +
                    os.makedirs(save_dir, exist_ok=True)
         | 
| 309 | 
            +
                    # dloader = DataLoader(dataset, collate_fn=collate_default, batch_size=1)
         | 
| 310 | 
            +
                    num_pairs = len(dataset)
         | 
| 311 | 
            +
                    id_pair = 0
         | 
| 312 | 
            +
                    all_accs = []
         | 
| 313 | 
            +
                    all_accs_zo = []
         | 
| 314 | 
            +
                    t1 = datetime.now()
         | 
| 315 | 
            +
                    save_txt = os.path.join(save_dir, "log.txt")
         | 
| 316 | 
            +
                    # Open a file for writing
         | 
| 317 | 
            +
                    log_file = open(save_txt, 'w')
         | 
| 318 | 
            +
                    # Replace sys.stdout with Tee that writes to both console and file
         | 
| 319 | 
            +
                    sys.stdout = Tee(sys.__stdout__, log_file)
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    for batch in dset:
         | 
| 322 | 
            +
                        shape_dict, _, target_dict, _, _, _ = batch
         | 
| 323 | 
            +
                        print("Pair: " + shape_dict['name'] + " " + target_dict['name'])
         | 
| 324 | 
            +
                        name_exp = os.path.join(save_dir, shape_dict['name'], target_dict['name'])
         | 
| 325 | 
            +
                        if self.cfg.get("refine", False):
         | 
| 326 | 
            +
                            _, _, _, dist, dist_zo = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True, refine=True)
         | 
| 327 | 
            +
                        else:
         | 
| 328 | 
            +
                            _, _, dist = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True, refine=False)
         | 
| 329 | 
            +
                        delta = datetime.now() - t1
         | 
| 330 | 
            +
                        fm_delta = str_delta(delta)
         | 
| 331 | 
            +
                        remains = ((delta/(id_pair+1))*num_pairs) - delta
         | 
| 332 | 
            +
                        fm_remains = str_delta(remains)
         | 
| 333 | 
            +
                        all_accs.append(dist)
         | 
| 334 | 
            +
                        accs_mean = np.mean(all_accs)
         | 
| 335 | 
            +
                        if self.cfg.get("refine", False):
         | 
| 336 | 
            +
                            all_accs_zo.append(dist_zo)
         | 
| 337 | 
            +
                            accs_zo = np.mean(all_accs_zo)
         | 
| 338 | 
            +
                            print(f"error: {dist}, zo: {dist_zo}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, mean zo: {accs_zo}, full time: {fm_delta}, remains: {fm_remains}")
         | 
| 339 | 
            +
                        else:
         | 
| 340 | 
            +
                            print(f"error: {dist}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, full time: {fm_delta}, remains: {fm_remains}")
         | 
| 341 | 
            +
                        id_pair += 1
         | 
| 342 | 
            +
                    if self.cfg.get("refine", False):
         | 
| 343 | 
            +
                        print(f"mean error : {np.mean(all_accs)}, mean error refined: {np.mean(all_accs_zo)}")
         | 
| 344 | 
            +
                    else:
         | 
| 345 | 
            +
                        print(f"mean error : {np.mean(all_accs)}")
         | 
| 346 | 
            +
                    sys.stdout = sys.__stdout__ 
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                def load_data(self, file, num_evecs=200, make_cache=False, factor=None):
         | 
| 349 | 
            +
                    name = os.path.basename(os.path.splitext(file)[0])
         | 
| 350 | 
            +
                    cache_file = "single_" + name + ".npz"
         | 
| 351 | 
            +
                    verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True)
         | 
| 352 | 
            +
                    cache_path = os.path.join(self.cfg.cache, cache_file)
         | 
| 353 | 
            +
                    print("Cache is: ", cache_path)
         | 
| 354 | 
            +
                    if not os.path.exists(cache_path) or make_cache:
         | 
| 355 | 
            +
                        print("Computing operators ...")
         | 
| 356 | 
            +
                        compute_operators(verts_shape, faces, vnormals, num_evecs, cache_path, force_save=make_cache)
         | 
| 357 | 
            +
                    data_dict = load_operators(cache_path)
         | 
| 358 | 
            +
                    data_dict['name'] = name
         | 
| 359 | 
            +
                    data_dict_torch = convert_dict(data_dict, self.device)
         | 
| 360 | 
            +
                    #batchify_dict(data_dict_torch)
         | 
| 361 | 
            +
                    return data_dict_torch, area_shape
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                def match_files(self, file_shape, file_target):
         | 
| 364 | 
            +
                    batch_shape, _ = self.load_data(file_shape)
         | 
| 365 | 
            +
                    batch_target, _ = self.load_data(file_target) 
         | 
| 366 | 
            +
                    target_surf = Surface(filename=file_target)
         | 
| 367 | 
            +
                    target_normals = torch.from_numpy(target_surf.surfel/np.linalg.norm(target_surf.surfel, axis=-1, keepdims=True)).float().to(self.device)
         | 
| 368 | 
            +
                    batch = batch_shape, None, batch_target, target_normals, None, None
         | 
| 369 | 
            +
                    output_folder = os.path.join(self.cfg.output, batch_shape["name"] + "_" + batch_shape["target"])
         | 
| 370 | 
            +
                    p2p, _ = self.match(batch, output_folder, None)
         | 
| 371 | 
            +
                    return batch_shape, batch_target, p2p
         | 
| 372 | 
            +
             | 
| 373 | 
            +
             | 
| 374 | 
            +
             | 
| 375 | 
            +
             | 
| 376 | 
            +
            if __name__ == '__main__':
         | 
| 377 | 
            +
                parser = argparse.ArgumentParser(description="Launch the SDS demo over datasets")
         | 
| 378 | 
            +
                parser.add_argument('--dataset', type=str, default="SCAPE", help='name of the dataset')
         | 
| 379 | 
            +
                parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')    
         | 
| 380 | 
            +
                parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
         | 
| 381 | 
            +
                parser.add_argument('--output', type=str, default="results", help="where to store experience results")
         | 
| 382 | 
            +
                args = parser.parse_args()
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                arg_cfg = OmegaConf.from_dotlist(
         | 
| 385 | 
            +
                    [f"{k}={v}" for k, v in vars(args).items() if v is not None]
         | 
| 386 | 
            +
                )
         | 
| 387 | 
            +
                yaml_cfg = OmegaConf.load(args.config)
         | 
| 388 | 
            +
                cfg = OmegaConf.merge(yaml_cfg, arg_cfg)
         | 
| 389 | 
            +
                dataset_name = args.dataset.lower()
         | 
| 390 | 
            +
                if cfg.get("oriented", False):
         | 
| 391 | 
            +
                    dataset_name += "_ori"
         | 
| 392 | 
            +
                shape_cls = getattr(importlib.import_module(f'shape_data.{args.dataset.lower()}'), 'ShapeDataset')
         | 
| 393 | 
            +
                pair_cls = getattr(importlib.import_module(f'shape_data.{args.dataset.lower()}'), 'ShapePairDataset')
         | 
| 394 | 
            +
                data_dir, name_data_geo, corr_dir = get_data_dirs(args.datadir, dataset_name, 'test')
         | 
| 395 | 
            +
                name_data_geo = "_".join(name_data_geo.split("_")[:2])
         | 
| 396 | 
            +
                dset_shape = shape_cls(data_dir, "cache/fmaps", "test", oriented=cfg.get("oriented", False))
         | 
| 397 | 
            +
                print("Preprocessing shapes done.")
         | 
| 398 | 
            +
                dset = pair_cls(corr_dir, 'test', dset_shape, rotate=cfg.get("rotate", False))
         | 
| 399 | 
            +
                exp_time = time.strftime('%y-%m-%d_%H-%M-%S')
         | 
| 400 | 
            +
                output_logs = os.path.join(args.output, name_data_geo, exp_time)
         | 
| 401 | 
            +
                matcher = Matcher(cfg)
         | 
| 402 | 
            +
                matcher._dataset_epoch(dset, name_data_geo, output_logs, args.datadir)
         |