import sys
import os
import os.path
import random

import scipy
import scipy.sparse.linalg as sla
# ^^^ we NEED to import scipy before torch, or it crashes :(
# (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)

import numpy as np
import torch
import torch.nn as nn

ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from diffusion_net.utils import toNP
from diffusion_net.geometry import to_basis, from_basis


class LearnedTimeDiffusion(nn.Module):
    """
    Applies diffusion with a learned per-channel time t.

    In the spectral domain this becomes
        f_out = e ^ (-lambda_i t) f_in

    Inputs:
      - x: (V,C) values at vertices
      - L: (V,V) sparse Laplacian
      - evals: (K) eigenvalues
      - mass: (V) mass matrix diagonal

      (note: L/evals may be omitted as None depending on method)

    Outputs:
      - (V,C) diffused values
    """

    def __init__(self, C_inout, method='spectral'):
        super(LearnedTimeDiffusion, self).__init__()
        self.C_inout = C_inout
        self.diffusion_time = nn.Parameter(torch.Tensor(C_inout))  # (C)
        self.method = method  # one of ['spectral', 'implicit_dense']

        nn.init.constant_(self.diffusion_time, 0.0)

    def forward(self, x, L, mass, evals, evecs):

        # project times to the positive halfspace
        # (and away from 0 in the incredibly rare chance that they get stuck)
        with torch.no_grad():
            self.diffusion_time.data = torch.clamp(self.diffusion_time, min=1e-8)

        if x.shape[-1] != self.C_inout:
            raise ValueError(
                "Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
                    x.shape, self.C_inout))

        if self.method == 'spectral':

            # Transform to spectral
            x_spec = to_basis(x, evecs, mass)

            # Diffuse
            time = self.diffusion_time
            diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * time.unsqueeze(0))
            x_diffuse_spec = diffusion_coefs * x_spec

            # Transform back to per-vertex
            x_diffuse = from_basis(x_diffuse_spec, evecs)

        elif self.method == 'implicit_dense':
            V = x.shape[-2]

            # Form the dense matrices (M + tL) with dims (B,C,V,V)
            mat_dense = L.to_dense().unsqueeze(1).expand(-1, self.C_inout, V, V).clone()
            mat_dense *= self.diffusion_time.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            mat_dense += torch.diag_embed(mass).unsqueeze(1)

            # Factor the system
            cholesky_factors = torch.linalg.cholesky(mat_dense)

            # Solve the system
            rhs = x * mass.unsqueeze(-1)
            rhsT = torch.transpose(rhs, 1, 2).unsqueeze(-1)
            sols = torch.cholesky_solve(rhsT, cholesky_factors)
            x_diffuse = torch.transpose(sols.squeeze(-1), 1, 2)

        else:
            raise ValueError("unrecognized method")

        return x_diffuse
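
# --- Example (illustrative sketch, not part of the original module) ---
# A minimal run of LearnedTimeDiffusion in 'spectral' mode, mainly to show the
# expected shapes. The random evals/evecs below are stand-ins for a real
# Laplacian eigenbasis (normally produced by the repo's precompute), so the
# output is only shape-meaningful, not a true diffusion.
def _demo_learned_time_diffusion():
    B, V, K, C = 1, 10, 8, 4
    x = torch.randn(B, V, C)              # per-vertex features
    mass = torch.rand(B, V) + 0.1         # positive lumped mass
    evals = torch.rand(B, K)              # stand-in eigenvalues (nonnegative)
    evecs = torch.randn(B, V, K)          # stand-in eigenvectors
    diffusion = LearnedTimeDiffusion(C, method='spectral')
    out = diffusion(x, None, mass, evals, evecs)  # L is unused in 'spectral' mode
    assert out.shape == (B, V, C)
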
class SpatialGradientFeatures(nn.Module):
    """
    Compute dot-products between input vectors. Uses a learned complex-linear
    layer to keep the dimension down.

    Input:
      - vectors: (V,C,2)
    Output:
      - dots: (V,C)
    """

    def __init__(self, C_inout, with_gradient_rotations=True):
        super(SpatialGradientFeatures, self).__init__()

        self.C_inout = C_inout
        self.with_gradient_rotations = with_gradient_rotations

        if self.with_gradient_rotations:
            self.A_re = nn.Linear(self.C_inout, self.C_inout, bias=False)
            self.A_im = nn.Linear(self.C_inout, self.C_inout, bias=False)
        else:
            self.A = nn.Linear(self.C_inout, self.C_inout, bias=False)

        # self.norm = nn.InstanceNorm1d(C_inout)

    def forward(self, vectors):

        vectorsA = vectors  # (V,C,2)

        if self.with_gradient_rotations:
            vectorsBreal = self.A_re(vectors[..., 0]) - self.A_im(vectors[..., 1])
            vectorsBimag = self.A_re(vectors[..., 1]) + self.A_im(vectors[..., 0])
        else:
            vectorsBreal = self.A(vectors[..., 0])
            vectorsBimag = self.A(vectors[..., 1])

        dots = vectorsA[..., 0] * vectorsBreal + vectorsA[..., 1] * vectorsBimag

        return torch.tanh(dots)


class MiniMLP(nn.Sequential):
    '''
    A simple MLP with configurable hidden layer sizes.
    '''

    def __init__(self, layer_sizes, dropout=False, activation=nn.ReLU, name="miniMLP"):
        super(MiniMLP, self).__init__()

        for i in range(len(layer_sizes) - 1):
            is_last = (i + 2 == len(layer_sizes))

            if dropout and i > 0:
                self.add_module(
                    name + "_mlp_layer_dropout_{:03d}".format(i),
                    nn.Dropout(p=.5)
                )

            # Affine map
            self.add_module(
                name + "_mlp_layer_{:03d}".format(i),
                nn.Linear(
                    layer_sizes[i],
                    layer_sizes[i + 1],
                ),
            )

            # Nonlinearity
            # (but not on the last layer)
            if not is_last:
                self.add_module(
                    name + "_mlp_act_{:03d}".format(i),
                    activation()
                )
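
# --- Example (illustrative sketch, not part of the original module) ---
# Quick shape checks for the two small modules above, on random inputs.
def _demo_gradient_features_and_mlp():
    V, C = 10, 4
    vec = torch.randn(V, C, 2)            # tangent vectors as (x, y) pairs
    feat = SpatialGradientFeatures(C)
    assert feat(vec).shape == (V, C)

    mlp = MiniMLP([16, 32, 32, 8], dropout=True)
    assert mlp(torch.randn(V, 16)).shape == (V, 8)
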
class DiffusionNetBlock(nn.Module):
    """
    Inputs and outputs are defined at vertices.
    """

    def __init__(self, C_width,
                 mlp_hidden_dims,
                 dropout=True,
                 diffusion_method='spectral',
                 with_gradient_features=True,
                 with_gradient_rotations=True):
        super(DiffusionNetBlock, self).__init__()

        # Specified dimensions
        self.C_width = C_width
        self.mlp_hidden_dims = mlp_hidden_dims

        self.dropout = dropout
        self.with_gradient_features = with_gradient_features
        self.with_gradient_rotations = with_gradient_rotations

        # Diffusion block
        self.diffusion = LearnedTimeDiffusion(self.C_width, method=diffusion_method)

        self.MLP_C = 2 * self.C_width

        if self.with_gradient_features:
            self.gradient_features = SpatialGradientFeatures(
                self.C_width, with_gradient_rotations=self.with_gradient_rotations)
            self.MLP_C += self.C_width

        # MLPs
        self.mlp = MiniMLP([self.MLP_C] + self.mlp_hidden_dims + [self.C_width],
                           dropout=self.dropout)

    def forward(self, x_in, mass, L, evals, evecs, gradX, gradY):

        # Manage dimensions
        B = x_in.shape[0]  # batch dimension
        if x_in.shape[-1] != self.C_width:
            raise ValueError(
                "Tensor has wrong shape = {}. Last dim shape should have number of channels = {}".format(
                    x_in.shape, self.C_width))

        # Diffusion block
        x_diffuse = self.diffusion(x_in, L, mass, evals, evecs)

        # Compute gradient features, if using
        if self.with_gradient_features:

            # Compute gradients
            x_grads = []
            # Manually loop over the batch (if there is a batch dimension)
            # since torch.mm() doesn't support batching
            for b in range(B):
                # gradient after diffusion
                x_gradX = torch.mm(gradX[b, ...], x_diffuse[b, ...])
                x_gradY = torch.mm(gradY[b, ...], x_diffuse[b, ...])

                x_grads.append(torch.stack((x_gradX, x_gradY), dim=-1))
            x_grad = torch.stack(x_grads, dim=0)

            # Evaluate gradient features
            x_grad_features = self.gradient_features(x_grad)

            # Stack inputs to MLP
            feature_combined = torch.cat((x_in, x_diffuse, x_grad_features), dim=-1)
        else:
            # Stack inputs to MLP
            feature_combined = torch.cat((x_in, x_diffuse), dim=-1)

        # Apply the MLP
        x0_out = self.mlp(feature_combined)

        # Skip connection
        x0_out = x0_out + x_in

        return x0_out
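
# --- Example (illustrative sketch, not part of the original module) ---
# A single DiffusionNetBlock on random stand-in operators; dense random
# gradX/gradY are used here in place of the real (sparse) gradient matrices.
def _demo_diffusion_net_block():
    B, V, K, C = 1, 10, 8, 32
    block = DiffusionNetBlock(C_width=C, mlp_hidden_dims=[C, C])
    x = torch.randn(B, V, C)
    mass = torch.rand(B, V) + 0.1
    evals, evecs = torch.rand(B, K), torch.randn(B, V, K)
    gradX, gradY = torch.randn(B, V, V), torch.randn(B, V, V)
    out = block(x, mass, None, evals, evecs, gradX, gradY)  # L unused for 'spectral'
    assert out.shape == (B, V, C)
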
class DiffusionNet(nn.Module):

    def __init__(self, C_in, C_out, C_width=128, N_block=4, last_activation=None,
                 outputs_at='vertices', mlp_hidden_dims=None, dropout=True,
                 with_gradient_features=True, with_gradient_rotations=True,
                 diffusion_method='spectral', num_eigenbasis=128):
        """
        Construct a DiffusionNet.

        Parameters:
            C_in (int): input dimension
            C_out (int): output dimension
            last_activation (func): a function to apply to the final outputs of the network, such as torch.nn.functional.log_softmax (default: None)
            outputs_at (string): produce outputs at various mesh elements by averaging from vertices. One of ['vertices', 'edges', 'faces', 'global_mean']. (default 'vertices', aka points for a point cloud)
            C_width (int): dimension of internal DiffusionNet blocks (default: 128)
            N_block (int): number of DiffusionNet blocks (default: 4)
            mlp_hidden_dims (list of int): a list of hidden layer sizes for MLPs (default: [C_width, C_width])
            dropout (bool): if True, internal MLPs use dropout (default: True)
            diffusion_method (string): how to evaluate diffusion, one of ['spectral', 'implicit_dense']. If 'implicit_dense' is used, can set k_eig=0, saving precompute.
            with_gradient_features (bool): if True, use gradient features (default: True)
            with_gradient_rotations (bool): if True, also learn a rotation of each gradient. Set to True if your surface has consistently oriented normals, and False otherwise (default: True)
            num_eigenbasis (int): number of eigenvalue/eigenvector pairs to keep when truncating the spectral basis (default: 128)
        """

        super(DiffusionNet, self).__init__()

        ## Store parameters

        # Basic parameters
        self.C_in = C_in
        self.C_out = C_out
        self.C_width = C_width
        self.N_block = N_block

        # Outputs
        self.last_activation = last_activation
        self.outputs_at = outputs_at
        if outputs_at not in ['vertices', 'edges', 'faces', 'global_mean']:
            raise ValueError("invalid setting for outputs_at")

        # MLP options
        if mlp_hidden_dims is None:
            mlp_hidden_dims = [C_width, C_width]
        self.mlp_hidden_dims = mlp_hidden_dims
        self.dropout = dropout

        # Diffusion
        self.diffusion_method = diffusion_method
        if diffusion_method not in ['spectral', 'implicit_dense']:
            raise ValueError("invalid setting for diffusion_method")
        self.num_eigenbasis = num_eigenbasis

        # Gradient features
        self.with_gradient_features = with_gradient_features
        self.with_gradient_rotations = with_gradient_rotations

        ## Set up the network

        # First and last affine layers
        self.first_lin = nn.Linear(C_in, C_width)
        self.last_lin = nn.Linear(C_width, C_out)

        # DiffusionNet blocks
        self.blocks = []
        for i_block in range(self.N_block):
            block = DiffusionNetBlock(C_width=C_width,
                                      mlp_hidden_dims=mlp_hidden_dims,
                                      dropout=dropout,
                                      diffusion_method=diffusion_method,
                                      with_gradient_features=with_gradient_features,
                                      with_gradient_rotations=with_gradient_rotations)

            self.blocks.append(block)
            self.add_module("block_" + str(i_block), self.blocks[-1])
    def forward(self, x_in, mass, L=None, evals=None, evecs=None, gradX=None, gradY=None, edges=None, faces=None):
        """
        A forward pass on the DiffusionNet.

        In the notation below, dimensions are:
          - C is the input channel dimension (C_in on construction)
          - C_OUT is the output channel dimension (C_out on construction)
          - N is the number of vertices/points, which CAN be different for each forward pass
          - B is an OPTIONAL batch dimension
          - K_EIG is the number of eigenvalues used for spectral acceleration

        Generally, our data layout is [N,C] or [B,N,C].

        Call get_operators() to generate geometric quantities mass/L/evals/evecs/gradX/gradY.
        Note that depending on the options for the DiffusionNet, not all are strictly necessary.

        Parameters:
            x_in (tensor): Input features, dimension [N,C] or [B,N,C]
            mass (tensor): Mass vector, dimension [N] or [B,N]
            L (tensor): Laplace matrix, sparse tensor with dimension [N,N] or [B,N,N]
            evals (tensor): Eigenvalues of Laplace matrix, dimension [K_EIG] or [B,K_EIG]
            evecs (tensor): Eigenvectors of Laplace matrix, dimension [N,K_EIG] or [B,N,K_EIG]
            gradX (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
            gradY (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]

        Returns:
            x_out (tensor): Output with dimension [N,C_out] or [B,N,C_out]
        """

        ## Check dimensions, and append batch dimension if not given
        if x_in.shape[-1] != self.C_in:
            raise ValueError("DiffusionNet was constructed with C_in={}, but x_in has last dim={}".format(
                self.C_in, x_in.shape[-1]))
        N = x_in.shape[-2]
        if len(x_in.shape) == 2:
            appended_batch_dim = True

            # add a batch dim to all inputs
            x_in = x_in.unsqueeze(0)
            mass = mass.unsqueeze(0)
            if L is not None: L = L.unsqueeze(0)
            if evals is not None: evals = evals.unsqueeze(0)
            if evecs is not None: evecs = evecs.unsqueeze(0)
            if gradX is not None: gradX = gradX.unsqueeze(0)
            if gradY is not None: gradY = gradY.unsqueeze(0)
            if edges is not None: edges = edges.unsqueeze(0)
            if faces is not None: faces = faces.unsqueeze(0)

        elif len(x_in.shape) == 3:
            appended_batch_dim = False

        else:
            raise ValueError("x_in should be tensor with shape [N,C] or [B,N,C]")

        # Truncate the spectral basis to the first num_eigenbasis entries
        if evals is not None: evals = evals[..., :self.num_eigenbasis]
        if evecs is not None: evecs = evecs[..., :self.num_eigenbasis]

        # Apply the first linear layer
        x = self.first_lin(x_in)

        # Apply each of the blocks
        for b in self.blocks:
            x = b(x, mass, L, evals, evecs, gradX, gradY)

        # Apply the last linear layer
        x = self.last_lin(x)

        # Remap output to faces/edges if requested
        if self.outputs_at == 'vertices':
            x_out = x

        elif self.outputs_at == 'edges':
            # Remap to edges
            x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 2)
            edges_gather = edges.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
            xe = torch.gather(x_gather, 1, edges_gather)
            x_out = torch.mean(xe, dim=-1)

        elif self.outputs_at == 'faces':
            # Remap to faces
            x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 3)
            faces_gather = faces.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
            xf = torch.gather(x_gather, 1, faces_gather)
            x_out = torch.mean(xf, dim=-1)

        elif self.outputs_at == 'global_mean':
            # Produce a single global mean output.
            # Using a weighted mean according to the point mass/area is discretization-invariant.
            # (A naive mean is not discretization-invariant; it could be affected
            #  by sampling a region more densely)
            x_out = torch.sum(x * mass.unsqueeze(-1), dim=-2) / torch.sum(mass, dim=-1, keepdim=True)

        # Apply last nonlinearity if specified
        if self.last_activation is not None:
            x_out = self.last_activation(x_out)

        # Remove batch dim if we added it
        if appended_batch_dim:
            x_out = x_out.squeeze(0)

        return x_out
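
# --- Example (illustrative sketch, not part of the original module) ---
# End-to-end shape check on random stand-in operators. On real data, mass,
# evals, evecs, gradX, gradY come from the repo's get_operators() precompute;
# the dense random matrices here only exercise the plumbing.
def _demo_diffusion_net():
    V, K, C_in, C_out = 10, 8, 3, 5
    model = DiffusionNet(C_in=C_in, C_out=C_out, C_width=32, N_block=2,
                         num_eigenbasis=K)
    x = torch.randn(V, C_in)              # unbatched [N,C] input
    mass = torch.rand(V) + 0.1
    evals, evecs = torch.rand(K), torch.randn(V, K)
    gradX, gradY = torch.randn(V, V), torch.randn(V, V)
    out = model(x, mass, evals=evals, evecs=evecs, gradX=gradX, gradY=gradY)
    assert out.shape == (V, C_out)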