# diffumatch/utils/layers.py
import sys
import os
import os.path
import random
import scipy
import scipy.sparse.linalg as sla
# ^^^ we NEED to import scipy before torch, or it crashes :(
# (observed on Ubuntu 20.04 w/ torch 1.6.0 and scipy 1.5.2 installed via conda)
import numpy as np
import torch
import torch.nn as nn
ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
sys.path.append(ROOT_DIR)
from diffusion_net.utils import toNP
from diffusion_net.geometry import to_basis, from_basis
class LearnedTimeDiffusion(nn.Module):
"""
Applies diffusion with learned per-channel t.
In the spectral domain this becomes
f_out = e ^ (lambda_i t) f_in
Inputs:
- values: (V,C) in the spectral domain
- L: (V,V) sparse laplacian
- evals: (K) eigenvalues
- mass: (V) mass matrix diagonal
(note: L/evals may be omitted as None depending on method)
Outputs:
- (V,C) diffused values
"""
def __init__(self, C_inout, method='spectral'):
super(LearnedTimeDiffusion, self).__init__()
self.C_inout = C_inout
self.diffusion_time = nn.Parameter(torch.Tensor(C_inout)) # (C)
self.method = method # one of ['spectral', 'implicit_dense']
nn.init.constant_(self.diffusion_time, 0.0)
def forward(self, x, L, mass, evals, evecs):
# project times to the positive halfspace
# (and away from 0 in the incredibly rare chance that they get stuck)
with torch.no_grad():
self.diffusion_time.data = torch.clamp(self.diffusion_time, min=1e-8)
        if x.shape[-1] != self.C_inout:
            raise ValueError(
                "Tensor has wrong shape = {}. Last dim should equal the number of channels = {}".format(
                    x.shape, self.C_inout))
if self.method == 'spectral':
# Transform to spectral
x_spec = to_basis(x, evecs, mass)
# Diffuse
time = self.diffusion_time
diffusion_coefs = torch.exp(-evals.unsqueeze(-1) * time.unsqueeze(0))
x_diffuse_spec = diffusion_coefs * x_spec
# Transform back to per-vertex
x_diffuse = from_basis(x_diffuse_spec, evecs)
elif self.method == 'implicit_dense':
V = x.shape[-2]
# Form the dense matrices (M + tL) with dims (B,C,V,V)
mat_dense = L.to_dense().unsqueeze(1).expand(-1, self.C_inout, V, V).clone()
mat_dense *= self.diffusion_time.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
mat_dense += torch.diag_embed(mass).unsqueeze(1)
# Factor the system
cholesky_factors = torch.linalg.cholesky(mat_dense)
# Solve the system
rhs = x * mass.unsqueeze(-1)
rhsT = torch.transpose(rhs, 1, 2).unsqueeze(-1)
sols = torch.cholesky_solve(rhsT, cholesky_factors)
x_diffuse = torch.transpose(sols.squeeze(-1), 1, 2)
else:
raise ValueError("unrecognized method")
return x_diffuse
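
# Minimal usage sketch (illustrative addition, not part of the original module):
# runs LearnedTimeDiffusion in 'spectral' mode on a synthetic orthonormal basis.
# The random Q below stands in for real Laplacian eigenvectors; on real meshes
# these operators come from the diffusion_net geometry utilities.
def _demo_learned_time_diffusion():
    B, V, C, K = 1, 32, 8, 16
    diffusion = LearnedTimeDiffusion(C, method='spectral')
    mass = torch.ones(B, V)                        # uniform vertex masses
    Q, _ = torch.linalg.qr(torch.randn(V, K))      # orthonormal columns ~ evecs
    evecs = Q.unsqueeze(0)
    evals = torch.linspace(0., 10., K).unsqueeze(0)
    x = torch.randn(B, V, C)
    x_diffuse = diffusion(x, None, mass, evals, evecs)  # L is unused in spectral mode
    print(x_diffuse.shape)  # torch.Size([1, 32, 8])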
class SpatialGradientFeatures(nn.Module):
"""
Compute dot-products between input vectors. Uses a learned complex-linear layer to keep dimension down.
Input:
- vectors: (V,C,2)
Output:
- dots: (V,C) dots
"""
def __init__(self, C_inout, with_gradient_rotations=True):
super(SpatialGradientFeatures, self).__init__()
self.C_inout = C_inout
self.with_gradient_rotations = with_gradient_rotations
if (self.with_gradient_rotations):
self.A_re = nn.Linear(self.C_inout, self.C_inout, bias=False)
self.A_im = nn.Linear(self.C_inout, self.C_inout, bias=False)
else:
self.A = nn.Linear(self.C_inout, self.C_inout, bias=False)
# self.norm = nn.InstanceNorm1d(C_inout)
def forward(self, vectors):
        vectorsA = vectors  # (V,C,2)
if self.with_gradient_rotations:
vectorsBreal = self.A_re(vectors[..., 0]) - self.A_im(vectors[..., 1])
vectorsBimag = self.A_re(vectors[..., 1]) + self.A_im(vectors[..., 0])
else:
vectorsBreal = self.A(vectors[..., 0])
vectorsBimag = self.A(vectors[..., 1])
dots = vectorsA[..., 0] * vectorsBreal + vectorsA[..., 1] * vectorsBimag
return torch.tanh(dots)
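
# Minimal usage sketch (illustrative addition, not part of the original module):
# feeds random per-vertex tangent vectors, stored as (real, imag) pairs in the
# last dimension, through SpatialGradientFeatures.
def _demo_spatial_gradient_features():
    B, V, C = 1, 32, 8
    grad_features = SpatialGradientFeatures(C, with_gradient_rotations=True)
    vectors = torch.randn(B, V, C, 2)   # (..., 0) = real part, (..., 1) = imag part
    dots = grad_features(vectors)
    print(dots.shape)  # torch.Size([1, 32, 8]); values lie in (-1, 1) due to tanh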
class MiniMLP(nn.Sequential):
    """
    A simple MLP with configurable hidden layer sizes.
    Optionally applies dropout before every linear layer except the first.
    """
def __init__(self, layer_sizes, dropout=False, activation=nn.ReLU, name="miniMLP"):
super(MiniMLP, self).__init__()
for i in range(len(layer_sizes) - 1):
is_last = (i + 2 == len(layer_sizes))
if dropout and i > 0:
self.add_module(name + "_mlp_layer_dropout_{:03d}".format(i), nn.Dropout(p=.5))
# Affine map
self.add_module(
name + "_mlp_layer_{:03d}".format(i),
nn.Linear(
layer_sizes[i],
layer_sizes[i + 1],
),
)
# Nonlinearity
# (but not on the last layer)
if not is_last:
self.add_module(name + "_mlp_act_{:03d}".format(i), activation())
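
# Minimal usage sketch (illustrative addition, not part of the original module):
# a 16 -> 32 -> 8 MiniMLP. ReLU follows every layer except the last, and with
# dropout=True a Dropout(p=0.5) would precede every linear layer except the first.
def _demo_mini_mlp():
    mlp = MiniMLP([16, 32, 8], dropout=False)
    y = mlp(torch.randn(100, 16))   # acts on the last dimension
    print(y.shape)  # torch.Size([100, 8])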
class DiffusionNetBlock(nn.Module):
"""
Inputs and outputs are defined at vertices
"""
def __init__(self,
C_width,
mlp_hidden_dims,
dropout=True,
diffusion_method='spectral',
with_gradient_features=True,
with_gradient_rotations=True):
super(DiffusionNetBlock, self).__init__()
# Specified dimensions
self.C_width = C_width
self.mlp_hidden_dims = mlp_hidden_dims
self.dropout = dropout
self.with_gradient_features = with_gradient_features
self.with_gradient_rotations = with_gradient_rotations
# Diffusion block
self.diffusion = LearnedTimeDiffusion(self.C_width, method=diffusion_method)
self.MLP_C = 2 * self.C_width
if self.with_gradient_features:
self.gradient_features = SpatialGradientFeatures(self.C_width, with_gradient_rotations=self.with_gradient_rotations)
self.MLP_C += self.C_width
# MLPs
self.mlp = MiniMLP([self.MLP_C] + self.mlp_hidden_dims + [self.C_width], dropout=self.dropout)
def forward(self, x_in, mass, L, evals, evecs, gradX, gradY):
# Manage dimensions
B = x_in.shape[0] # batch dimension
        if x_in.shape[-1] != self.C_width:
            raise ValueError(
                "Tensor has wrong shape = {}. Last dim should equal the number of channels = {}".format(
                    x_in.shape, self.C_width))
# Diffusion block
x_diffuse = self.diffusion(x_in, L, mass, evals, evecs)
# Compute gradient features, if using
if self.with_gradient_features:
# Compute gradients
            # Manually loop over the batch (if there is a batch dimension),
            # since torch.mm() doesn't support batching
            x_grads = []
for b in range(B):
# gradient after diffusion
x_gradX = torch.mm(gradX[b, ...], x_diffuse[b, ...])
x_gradY = torch.mm(gradY[b, ...], x_diffuse[b, ...])
x_grads.append(torch.stack((x_gradX, x_gradY), dim=-1))
x_grad = torch.stack(x_grads, dim=0)
# Evaluate gradient features
x_grad_features = self.gradient_features(x_grad)
# Stack inputs to mlp
feature_combined = torch.cat((x_in, x_diffuse, x_grad_features), dim=-1)
else:
# Stack inputs to mlp
feature_combined = torch.cat((x_in, x_diffuse), dim=-1)
# Apply the mlp
x0_out = self.mlp(feature_combined)
# Skip connection
x0_out = x0_out + x_in
return x0_out
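
# Minimal usage sketch (illustrative addition, not part of the original module):
# runs one DiffusionNetBlock with synthetic operators. The random dense
# gradX/gradY stand in for the sparse half-gradient matrices that the
# diffusion_net geometry utilities would normally provide.
def _demo_diffusion_net_block():
    B, V, C, K = 1, 32, 8, 16
    block = DiffusionNetBlock(C_width=C, mlp_hidden_dims=[C, C], dropout=False)
    mass = torch.ones(B, V)
    Q, _ = torch.linalg.qr(torch.randn(V, K))
    evecs, evals = Q.unsqueeze(0), torch.linspace(0., 10., K).unsqueeze(0)
    gradX = 0.1 * torch.randn(B, V, V)
    gradY = 0.1 * torch.randn(B, V, V)
    x = torch.randn(B, V, C)
    x_out = block(x, mass, None, evals, evecs, gradX, gradY)  # L unused in spectral mode
    print(x_out.shape)  # torch.Size([1, 32, 8])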
class DiffusionNet(nn.Module):
def __init__(self,
C_in,
C_out,
C_width=128,
N_block=4,
last_activation=None,
outputs_at='vertices',
mlp_hidden_dims=None,
dropout=True,
with_gradient_features=True,
with_gradient_rotations=True,
diffusion_method='spectral',
num_eigenbasis=128):
"""
Construct a DiffusionNet.
Parameters:
C_in (int): input dimension
C_out (int): output dimension
            last_activation (func): a function applied to the final network outputs, such as torch.nn.functional.log_softmax (default: None)
            outputs_at (string): which mesh elements to produce outputs at, averaging from vertices where needed. One of ['vertices', 'edges', 'faces', 'global_mean'] (default: 'vertices', aka points for a point cloud)
C_width (int): dimension of internal DiffusionNet blocks (default: 128)
N_block (int): number of DiffusionNet blocks (default: 4)
mlp_hidden_dims (list of int): a list of hidden layer sizes for MLPs (default: [C_width, C_width])
dropout (bool): if True, internal MLPs use dropout (default: True)
            diffusion_method (string): how to evaluate diffusion, one of ['spectral', 'implicit_dense']. If 'implicit_dense' is used, k_eig=0 can be set during precomputation, skipping the eigendecomposition.
            with_gradient_features (bool): if True, use gradient features (default: True)
            with_gradient_rotations (bool): if True, also learn a rotation applied to each gradient. Set to True if your surface has consistently oriented normals, and False otherwise (default: True)
            num_eigenbasis (int): number of eigenpairs to keep; the eigenvalues/eigenvectors are truncated to this many entries (default: 128)
"""
super(DiffusionNet, self).__init__()
## Store parameters
# Basic parameters
self.C_in = C_in
self.C_out = C_out
self.C_width = C_width
self.N_block = N_block
# Outputs
self.last_activation = last_activation
self.outputs_at = outputs_at
if outputs_at not in ['vertices', 'edges', 'faces', 'global_mean']:
raise ValueError("invalid setting for outputs_at")
# MLP options
        if mlp_hidden_dims is None:
            mlp_hidden_dims = [C_width, C_width]
self.mlp_hidden_dims = mlp_hidden_dims
self.dropout = dropout
# Diffusion
self.diffusion_method = diffusion_method
if diffusion_method not in ['spectral', 'implicit_dense']:
raise ValueError("invalid setting for diffusion_method")
self.num_eigenbasis = num_eigenbasis
# Gradient features
self.with_gradient_features = with_gradient_features
self.with_gradient_rotations = with_gradient_rotations
## Set up the network
# First and last affine layers
self.first_lin = nn.Linear(C_in, C_width)
self.last_lin = nn.Linear(C_width, C_out)
# DiffusionNet blocks
self.blocks = []
for i_block in range(self.N_block):
block = DiffusionNetBlock(C_width=C_width,
mlp_hidden_dims=mlp_hidden_dims,
dropout=dropout,
diffusion_method=diffusion_method,
with_gradient_features=with_gradient_features,
with_gradient_rotations=with_gradient_rotations)
self.blocks.append(block)
self.add_module("block_" + str(i_block), self.blocks[-1])
def forward(self, x_in, mass, L=None, evals=None, evecs=None, gradX=None, gradY=None, edges=None, faces=None):
"""
A forward pass on the DiffusionNet.
        In the notation below, dimensions are:
- C is the input channel dimension (C_in on construction)
- C_OUT is the output channel dimension (C_out on construction)
- N is the number of vertices/points, which CAN be different for each forward pass
- B is an OPTIONAL batch dimension
- K_EIG is the number of eigenvalues used for spectral acceleration
        Generally, our data layout is [N,C] or [B,N,C].
Call get_operators() to generate geometric quantities mass/L/evals/evecs/gradX/gradY. Note that depending on the options for the DiffusionNet, not all are strictly necessary.
Parameters:
x_in (tensor): Input features, dimension [N,C] or [B,N,C]
mass (tensor): Mass vector, dimension [N] or [B,N]
L (tensor): Laplace matrix, sparse tensor with dimension [N,N] or [B,N,N]
evals (tensor): Eigenvalues of Laplace matrix, dimension [K_EIG] or [B,K_EIG]
evecs (tensor): Eigenvectors of Laplace matrix, dimension [N,K_EIG] or [B,N,K_EIG]
gradX (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
gradY (tensor): Half of gradient matrix, sparse real tensor with dimension [N,N] or [B,N,N]
Returns:
x_out (tensor): Output with dimension [N,C_out] or [B,N,C_out]
"""
## Check dimensions, and append batch dimension if not given
if x_in.shape[-1] != self.C_in:
raise ValueError("DiffusionNet was constructed with C_in={}, but x_in has last dim={}".format(
self.C_in, x_in.shape[-1]))
N = x_in.shape[-2]
if len(x_in.shape) == 2:
appended_batch_dim = True
# add a batch dim to all inputs
x_in = x_in.unsqueeze(0)
mass = mass.unsqueeze(0)
            if L is not None:
                L = L.unsqueeze(0)
            if evals is not None:
                evals = evals.unsqueeze(0)
            if evecs is not None:
                evecs = evecs.unsqueeze(0)
            if gradX is not None:
                gradX = gradX.unsqueeze(0)
            if gradY is not None:
                gradY = gradY.unsqueeze(0)
            if edges is not None:
                edges = edges.unsqueeze(0)
            if faces is not None:
                faces = faces.unsqueeze(0)
elif len(x_in.shape) == 3:
appended_batch_dim = False
else:
raise ValueError("x_in should be tensor with shape [N,C] or [B,N,C]")
        # Truncate the eigenbasis (these may be None, e.g. for 'implicit_dense' diffusion)
        if evals is not None:
            evals = evals[..., :self.num_eigenbasis]
        if evecs is not None:
            evecs = evecs[..., :self.num_eigenbasis]
# Apply the first linear layer
x = self.first_lin(x_in)
# Apply each of the blocks
for b in self.blocks:
x = b(x, mass, L, evals, evecs, gradX, gradY)
# Apply the last linear layer
x = self.last_lin(x)
# Remap output to faces/edges if requested
if self.outputs_at == 'vertices':
x_out = x
elif self.outputs_at == 'edges':
# Remap to edges
x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 2)
edges_gather = edges.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
xe = torch.gather(x_gather, 1, edges_gather)
x_out = torch.mean(xe, dim=-1)
elif self.outputs_at == 'faces':
# Remap to faces
x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 3)
faces_gather = faces.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
xf = torch.gather(x_gather, 1, faces_gather)
x_out = torch.mean(xf, dim=-1)
elif self.outputs_at == 'global_mean':
            # Produce a single global mean output.
# Using a weighted mean according to the point mass/area is discretization-invariant.
# (A naive mean is not discretization-invariant; it could be affected by sampling a region more densely)
x_out = torch.sum(x * mass.unsqueeze(-1), dim=-2) / torch.sum(mass, dim=-1, keepdim=True)
# Apply last nonlinearity if specified
        if self.last_activation is not None:
x_out = self.last_activation(x_out)
# Remove batch dim if we added it
if appended_batch_dim:
x_out = x_out.squeeze(0)
return x_out
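
# End-to-end usage sketch (illustrative addition, not part of the original
# module), again with synthetic operators. On real data, mass/L/evals/evecs/
# gradX/gradY would come from the diffusion_net geometry precomputation
# (get_operators); the synthetic stand-ins below only match the shapes.
def _demo_diffusion_net():
    V, K = 64, 32
    model = DiffusionNet(C_in=3, C_out=10, C_width=32, N_block=2,
                         dropout=False, num_eigenbasis=K)
    mass = torch.ones(V)
    Q, _ = torch.linalg.qr(torch.randn(V, K))
    evecs, evals = Q, torch.linspace(0., 10., K)
    gradX = 0.1 * torch.randn(V, V)
    gradY = 0.1 * torch.randn(V, V)
    x_in = torch.randn(V, 3)   # e.g. raw vertex positions as input features
    x_out = model(x_in, mass, L=None, evals=evals, evecs=evecs,
                  gradX=gradX, gradY=gradY)
    print(x_out.shape)  # torch.Size([64, 10])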